koichi12 commited on
Commit
a8eed2c
·
verified ·
1 Parent(s): 49fc886

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 +3 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py +363 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/config.py +752 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py +264 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py +1324 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py +1524 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py +2445 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py +1156 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py +643 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py +1428 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py +299 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h +2 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h +1 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DynamicLibrary.h +34 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h +1 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h +324 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h +86 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h +9 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContext.h +9 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h +115 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Exceptions.h +174 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h +85 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h +28 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h +189 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h +263 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h +14 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h +20 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h +229 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h +371 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h +18 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h +88 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h +130 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h +37 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h +12 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h +14 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h +21 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h +33 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h +394 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h +14 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h +238 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h +22 -0
.gitattributes CHANGED
@@ -76,3 +76,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/
76
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
77
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
78
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
 
 
76
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
77
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
78
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
79
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 filter=lfs diff=lfs merge=lfs -text
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26a7288b7315d658acab1073f02c4f18cd1d27eeadde102958f0317dad6656e0
3
+ size 150200
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc ADDED
Binary file (8.74 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc ADDED
Binary file (18.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc ADDED
Binary file (6.13 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc ADDED
Binary file (7.37 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc ADDED
Binary file (67.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc ADDED
Binary file (18.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc ADDED
Binary file (4.85 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pyre-strict
2
+
3
+ from typing import List
4
+
5
+ import torch
6
+
7
+ from . import config, ir, scheduler
8
+ from .dependencies import WeakDep
9
+ from .utils import tuple_sorted
10
+
11
+ overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
12
+
13
+
14
+ def sink_waits(
15
+ snodes: List["scheduler.BaseSchedulerNode"],
16
+ ) -> List["scheduler.BaseSchedulerNode"]:
17
+ """
18
+ Greedily moves waits as late as possible (i.e. until we reach a use). Optimal in terms of
19
+ communication overlap.
20
+ """
21
+ new_order = []
22
+ cur_waits = set()
23
+ for snode in snodes:
24
+ if isinstance(snode.node, ir.Wait):
25
+ cur_waits.add(snode)
26
+ else:
27
+ for wait in tuple_sorted(cur_waits):
28
+ if snode in wait.node_users:
29
+ new_order.append(wait)
30
+ cur_waits.remove(wait)
31
+ new_order.append(snode)
32
+ new_order.extend(tuple_sorted(cur_waits))
33
+ return new_order
34
+
35
+
36
+ def raise_comms(
37
+ snodes: List["scheduler.BaseSchedulerNode"],
38
+ ) -> List["scheduler.BaseSchedulerNode"]:
39
+ """
40
+ Greedily moves comms as early as possible (i.e. until we reach an input).
41
+ Optimal in terms of communication overlap.
42
+
43
+ TODO: We might want to adjust this in the future to account for memory limitations.
44
+ e.g. when we are compiling FSDP, this heuristics will cause the all-gathers to be prefetched as soon as possible,
45
+ which is the beginning of the forwards pass. We'll have to either do a special pass for FSDP,
46
+ or we'll want to redo this pass with memory considerations so we handle the FSDP case in a general way.
47
+ """
48
+ new_order_reversed: List["scheduler.BaseSchedulerNode"] = []
49
+ cur_comms: List["scheduler.BaseSchedulerNode"] = []
50
+ for snode in reversed(snodes):
51
+ if isinstance(snode.node, ir.CollectiveKernel):
52
+ cur_comms.append(snode)
53
+ else:
54
+ for comm in cur_comms:
55
+ assert len(comm.inverse_users) > 0
56
+ while len(cur_comms) > 0 and any(
57
+ snode in comm.inverse_users for comm in cur_comms
58
+ ):
59
+ comm = cur_comms.pop(0)
60
+ new_order_reversed.append(comm)
61
+ new_order_reversed.append(snode)
62
+ assert len(cur_comms) <= 1
63
+ new_order_reversed.extend(tuple_sorted(cur_comms))
64
+ return new_order_reversed[::-1]
65
+
66
+
67
+ def get_ancestors(node):
68
+ ancestors = set()
69
+ cur_nodes = [node]
70
+ while len(cur_nodes) > 0:
71
+ new_nodes = []
72
+ for node in cur_nodes:
73
+ for inp in node.inverse_users:
74
+ if inp not in ancestors:
75
+ ancestors.add(inp)
76
+ new_nodes.append(inp)
77
+ cur_nodes = new_nodes
78
+ return ancestors
79
+
80
+
81
+ def get_descendants(node):
82
+ descendants = set()
83
+ cur_nodes = [node]
84
+ while len(cur_nodes) > 0:
85
+ new_nodes = []
86
+ for node in cur_nodes:
87
+ for inp in node.node_users:
88
+ if inp not in descendants:
89
+ descendants.add(inp)
90
+ new_nodes.append(inp)
91
+ cur_nodes = new_nodes
92
+ return descendants
93
+
94
+
95
+ def decide_global_ordering_of_comms(nodes: List["scheduler.BaseSchedulerNode"]):
96
+ """
97
+ Decide global ordering of comms, by just enforcing the ordering that's in the input graph
98
+ (might not be the same ordering as the eager mode program).
99
+ TODO: Come up with a better approach
100
+ """
101
+ comm_nodes = [n for n in nodes if isinstance(n.node, ir.CollectiveKernel)]
102
+ for i in range(1, len(comm_nodes)):
103
+ # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
104
+ comm_nodes[i].add_fake_dep(WeakDep(comm_nodes[i - 1].get_name()))
105
+
106
+
107
+ def assert_no_comm_nodes(snodes: List["scheduler.BaseSchedulerNode"]) -> None:
108
+ assert not any(isinstance(snode.node, ir.CollectiveKernel) for snode in snodes)
109
+
110
+
111
+ def estimate_op_runtime(snode: "scheduler.BaseSchedulerNode") -> float:
112
+ """
113
+ Returns estimated op runtime in nanoseconds (ns)
114
+ """
115
+ if config.estimate_op_runtime == "default":
116
+ runtime = snode.get_estimated_runtime()
117
+ else:
118
+ assert callable(config.estimate_op_runtime)
119
+ runtime = config.estimate_op_runtime(snode)
120
+ return runtime
121
+
122
+
123
+ def reorder_compute_for_overlap(
124
+ snodes: List["scheduler.BaseSchedulerNode"],
125
+ ) -> List["scheduler.BaseSchedulerNode"]:
126
+ """
127
+ Decides a global ordering of all compute and communication nodes,
128
+ assuming that we already have a global ordering of communication nodes.
129
+
130
+ Overall scheduling procedure is:
131
+ Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
132
+ that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N.
133
+ Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
134
+ Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
135
+ We prioritize compute nodes that are needed sooner.
136
+ Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
137
+ Step 4: We schedule comm N + 1.
138
+ Repeat this for subsequent comm nodes.
139
+ """
140
+ final_order = []
141
+
142
+ comm_nodes = []
143
+ for snode in snodes:
144
+ if isinstance(snode.node, ir.CollectiveKernel):
145
+ comm_nodes.append(snode)
146
+ if len(comm_nodes) == 0:
147
+ # if there is no comm nodes, return the current order
148
+ return snodes
149
+
150
+ comm_ancestors = {node: get_ancestors(node) for node in comm_nodes}
151
+ comm_descendants = {node: get_descendants(node) for node in comm_nodes}
152
+
153
+ indeg = dict.fromkeys(snodes, 0)
154
+ for snode in snodes:
155
+ for user in snode.node_users:
156
+ if user in indeg:
157
+ indeg[user] += 1
158
+ ready_to_schedule_nodes = {node for node in snodes if indeg[node] == 0}
159
+
160
+ unscheduled_nodes = set()
161
+ unscheduled_nodes = set(snodes)
162
+
163
+ def schedule_node(snode):
164
+ """
165
+ Schedule a single node.
166
+ """
167
+ assert snode in unscheduled_nodes
168
+ assert snode in ready_to_schedule_nodes
169
+ ready_to_schedule_nodes.remove(snode)
170
+ unscheduled_nodes.remove(snode)
171
+ final_order.append(snode)
172
+ for user in tuple_sorted(snode.node_users):
173
+ if user in indeg:
174
+ indeg[user] -= 1
175
+ if indeg[user] == 0:
176
+ ready_to_schedule_nodes.add(user)
177
+
178
+ def schedule_nodes(snodes):
179
+ """
180
+ Schedules all nodes in `snodes` in an arbitrary topologically valid order.
181
+ """
182
+ all_nodes = set(snodes)
183
+ assert all(node in unscheduled_nodes for node in all_nodes)
184
+ while len(all_nodes) > 0:
185
+ # NOTE: since model graph is always a DAG and does not have circular dependency inside,
186
+ # there should be at least one node that is a "free node" (i.e. indeg == 0),
187
+ # hence infinite loop is not possible. But we check here just to be safe.
188
+ progress = False
189
+ for node in tuple_sorted(all_nodes):
190
+ if node in ready_to_schedule_nodes:
191
+ schedule_node(node)
192
+ all_nodes.remove(node)
193
+ progress = True
194
+ if not progress:
195
+ raise Exception(
196
+ "Unable to find a free node (indeg == 0). This is an impossible state to reach. "
197
+ "Please report a bug to PyTorch."
198
+ )
199
+
200
+ # First, schedule all compute nodes that are required by first comm node,
201
+ # as well as the first comm node itself.
202
+ assert len(comm_nodes) > 0
203
+ schedule_nodes(
204
+ list(comm_ancestors[comm_nodes[0]]) + [comm_nodes[0]],
205
+ )
206
+
207
+ rolled_over_compute_cost = 0
208
+ for idx in range(1, len(comm_ancestors)):
209
+ # Step 1: Given that we've currently scheduled comm `idx-1`, we now schedule
210
+ # all compute nodes that are required for comm `idx` but do not depend on comm `idx-1`,
211
+ # to run at the same time with comm `idx-1`.
212
+ needed_by_next_comm_and_ready_compute_nodes = unscheduled_nodes & (
213
+ comm_ancestors[comm_nodes[idx]] - comm_descendants[comm_nodes[idx - 1]]
214
+ )
215
+ assert_no_comm_nodes(needed_by_next_comm_and_ready_compute_nodes)
216
+
217
+ total_compute_runtime_cost = rolled_over_compute_cost + sum(
218
+ [
219
+ estimate_op_runtime(node)
220
+ for node in needed_by_next_comm_and_ready_compute_nodes
221
+ ]
222
+ )
223
+ prev_comm_runtime_cost = estimate_op_runtime(comm_nodes[idx - 1])
224
+ schedule_nodes(tuple_sorted(needed_by_next_comm_and_ready_compute_nodes))
225
+
226
+ # Step 2: If all those compute nodes are sufficient to overlap comm `idx-1`, we're done.
227
+ # Otherwise, we now need to look elsewhere to find compute that overlaps with comm `idx`.
228
+ # We prioritize compute nodes that are needed sooner.
229
+ step1_runtime_cost = total_compute_runtime_cost
230
+ if step1_runtime_cost >= prev_comm_runtime_cost:
231
+ pass
232
+ else:
233
+ # Find all ready to schedule compute nodes that do not depend on comm `idx-1`.
234
+ ready_to_schedule_compute_nodes = tuple_sorted(
235
+ ready_to_schedule_nodes - comm_descendants[comm_nodes[idx - 1]]
236
+ )
237
+ assert_no_comm_nodes(ready_to_schedule_compute_nodes)
238
+
239
+ def earliest_comm_descendant(node):
240
+ for idx in range(len(comm_nodes)):
241
+ if node in comm_ancestors[comm_nodes[idx]]:
242
+ return idx
243
+ return len(comm_nodes)
244
+
245
+ # Prioritize compute nodes that are needed sooner.
246
+ ready_to_schedule_compute_nodes = sorted(
247
+ ready_to_schedule_compute_nodes, key=earliest_comm_descendant
248
+ )
249
+
250
+ for snode in ready_to_schedule_compute_nodes:
251
+ if total_compute_runtime_cost >= prev_comm_runtime_cost:
252
+ # If accumulated compute runtime cost is greater than comm `idx-1` runtime cost,
253
+ # it means we have maximized overlap for comm `idx-1`, and hence we stop looking
254
+ # for more compute to schedule.
255
+ break
256
+ compute_runtime_cost = estimate_op_runtime(snode)
257
+ # If we're not able to leverage more than half of this
258
+ # node's compute to overlap, we skip it.
259
+ # TODO: Smarter heuristics here
260
+ if (
261
+ prev_comm_runtime_cost - total_compute_runtime_cost
262
+ ) <= compute_runtime_cost / 2:
263
+ continue
264
+ schedule_node(snode)
265
+ total_compute_runtime_cost += compute_runtime_cost
266
+ rollable_compute_cost = total_compute_runtime_cost - step1_runtime_cost
267
+
268
+ # Step 3: We schedule the compute nodes dependent on comm `idx-1` and required for comm `idx`.
269
+ needed_by_next_comm_nodes = unscheduled_nodes & comm_ancestors[comm_nodes[idx]]
270
+ schedule_nodes(list(needed_by_next_comm_nodes))
271
+
272
+ # Step 4: We schedule comm `idx`.
273
+ schedule_nodes([comm_nodes[idx]])
274
+
275
+ is_prev_comm_blocking_next_comm = len(needed_by_next_comm_nodes) > 0
276
+ # The idea here is that if there are no compute nodes from Step 3
277
+ # (i.e. if prev comm is not blocking next comm), we can roll over the compute nodes
278
+ # in Step 2 to overlap with the next comm, since they're not required to finish
279
+ # before the next comm starts.
280
+ if is_prev_comm_blocking_next_comm:
281
+ rolled_over_compute_cost = 0
282
+ else:
283
+ rolled_over_compute_cost = rollable_compute_cost # type: ignore[assignment]
284
+
285
+ schedule_nodes(unscheduled_nodes)
286
+ return final_order
287
+
288
+
289
+ def node_summary(snode):
290
+ detail = ""
291
+ if isinstance(snode.node, ir.ExternKernelOut):
292
+ detail = f" ({snode.node.python_kernel_name})"
293
+ out_tensor_info = ""
294
+ if (
295
+ hasattr(snode.node, "layout")
296
+ and hasattr(snode.node.layout, "size")
297
+ and hasattr(snode.node.layout, "stride")
298
+ ):
299
+ out_tensor_info = (
300
+ f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})"
301
+ )
302
+ node_name = ""
303
+ if hasattr(snode.node, "name"):
304
+ node_name = snode.node.name
305
+ return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"
306
+
307
+
308
+ def visualize_overlap(order):
309
+ total_est_runtime: float = 0.0
310
+ cur_comm_node = None
311
+ for snode in order:
312
+ if cur_comm_node is None:
313
+ if isinstance(snode.node, ir.CollectiveKernel):
314
+ total_est_runtime += estimate_op_runtime(snode)
315
+ cur_comm_node = snode.node
316
+ elif isinstance(snode.node, ir.Wait):
317
+ raise Exception(
318
+ "Wait is not expected when there is no collective running"
319
+ )
320
+ else: # exposed compute op
321
+ total_est_runtime += estimate_op_runtime(snode)
322
+ overlap_log.debug(f"{node_summary(snode)}") # noqa: G004
323
+ else: # cur_comm_node is not None
324
+ if isinstance(snode.node, ir.CollectiveKernel):
325
+ raise Exception(
326
+ "Found two collectives running at the same time. "
327
+ "`visualize_overlap` needs to be updated to handle this case"
328
+ )
329
+ elif isinstance(snode.node, ir.Wait): # end of this comm op
330
+ overlap_log.debug(f"{node_summary(snode)}") # noqa: G004
331
+ cur_comm_node = None
332
+ else: # overlapped compute op
333
+ overlap_log.debug(f"| {node_summary(snode)}") # noqa: G004
334
+ overlap_log.debug(
335
+ f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004
336
+ )
337
+
338
+
339
+ def reorder_compute_and_comm_for_overlap(
340
+ snodes: List["scheduler.BaseSchedulerNode"],
341
+ ) -> List["scheduler.BaseSchedulerNode"]:
342
+ order = snodes
343
+ for p in config.reorder_for_compute_comm_overlap_passes:
344
+ if isinstance(p, str) and p in globals():
345
+ p = globals()[p] # it is a builtin pass
346
+ if torch.distributed.get_rank() == 0:
347
+ overlap_log.debug(
348
+ f"==== Visualize overlap before reordering pass {p} ====" # noqa: G004
349
+ )
350
+ try:
351
+ visualize_overlap(order)
352
+ except Exception as e:
353
+ overlap_log.debug(str(e))
354
+ order = p(order) # type: ignore[operator]
355
+ if torch.distributed.get_rank() == 0:
356
+ overlap_log.debug(
357
+ f"==== Visualize overlap after reordering pass {p} ====" # noqa: G004
358
+ )
359
+ try:
360
+ visualize_overlap(order)
361
+ except Exception as e:
362
+ overlap_log.debug(str(e))
363
+ return order
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/config.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os # noqa: C101
2
+ import sys
3
+ from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
4
+
5
+ import torch
6
+
7
+
8
+ def is_fbcode():
9
+ return not hasattr(torch.version, "git_version")
10
+
11
+
12
+ # add some debug printouts
13
+ debug = False
14
+
15
+ # add inf and NaN checkers
16
+ debug_check_inf_and_nan = False
17
+
18
+ # Whether to disable a progress bar for autotuning
19
+ disable_progress = True
20
+
21
+ # Whether to enable printing the source code for each future
22
+ verbose_progress = False
23
+
24
+ # use fx aot graph codegen cache
25
+ fx_graph_cache = os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE") == "1"
26
+
27
+ # use cpp wrapper instead of python wrapper
28
+ cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"
29
+
30
+ # codegen cpp wrapper code in an ABI compatible mode
31
+ abi_compatible = (
32
+ os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1"
33
+ )
34
+
35
+ c_shim_version = os.environ.get(
36
+ "TORCHINDUCTOR_C_SHIM_VERSION", "1" if is_fbcode() else "2"
37
+ )
38
+
39
+ # dead code elimination
40
+ dce = False
41
+
42
+ # assume weight tensors are fixed size
43
+ static_weight_shapes = True
44
+
45
+ # put correctness assertions in generated code
46
+ size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
47
+ nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
48
+
49
+ # enable loop reordering based on input orders
50
+ pick_loop_orders = True
51
+
52
+ # reuse a kernel input as the output
53
+ inplace_buffers = True
54
+
55
+ # reuse a buffer for an unrelated purpose
56
+ allow_buffer_reuse = True
57
+
58
+ # Enable pooled allocations for non-output tensors
59
+ memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1"
60
+
61
+ # How to organize memory under memory_planning=True:
62
+ # - "none": do not try to pool storage, just reuse
63
+ # - "intermediates": all non-outputs share storage, outputs each get unique storage
64
+ # - "outputs": two pools, one for intermediates (freed on return) and one for outputs
65
+ # - "combined": a single pool for both intermediates and outputs
66
+ memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates")
67
+
68
+ # codegen benchmark harness
69
+ benchmark_harness = True
70
+
71
+ # fuse pointwise into templates
72
+ epilogue_fusion = True
73
+
74
+ # do epilogue fusions before other fusions
75
+ epilogue_fusion_first = False
76
+
77
+ # enable pattern match+replace optimizations
78
+ pattern_matcher = True
79
+
80
+ # register custom graph optimization pass hook. so far, pre/post passes are
81
+ # only applied before/after pattern_matcher in post_grad_passes.
82
+ #
83
+ # def my_custom_pre_pass(graph: torch.fx.graph.Graph):
84
+ # # my custom graph optimization pass
85
+ # ...
86
+ #
87
+ # def my_custom_post_pass(graph: torch.fx.graph.Graph):
88
+ # # my custom graph optimization pass
89
+ # ...
90
+ #
91
+ # torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass
92
+ # torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass
93
+ post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
94
+ post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
95
+
96
+ # Registers a custom pregrad pass. Note that the pre-grad IR is 1.
97
+ # non-functional, 2. non-normalized, and 3. prone to change. Ideally we should
98
+ # use post-grad passes.
99
+ pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
100
+
101
+ # Optimize away split cat patterns (Experimental)
102
+ split_cat_fx_passes = True
103
+
104
+ # Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability.
105
+ efficient_conv_bn_eval_fx_passes = False
106
+
107
+ # Enable predispatch aten IR for export
108
+ is_predispatch = False
109
+
110
+ # Deprecated
111
+ group_fusion = False
112
+
113
+ # Deprecated
114
+ batch_fusion = True
115
+
116
+ # Pre grad group/batch fusion and options in order, set to empty dict to disable fusion.
117
+ # Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions.
118
+ pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {
119
+ "batch_linear": {},
120
+ "batch_linear_lhs": {},
121
+ "batch_layernorm": {},
122
+ "batch_tanh": {},
123
+ "batch_relu": {},
124
+ "batch_sigmoid": {},
125
+ }
126
+
127
+ # Post grad group/batch fusion and options, set to empty dict to disable fusion.
128
+ # Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
129
+ post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
130
+
131
+ # enable reordering pass for improving memory locality
132
+ reorder_for_locality = True
133
+
134
+ # Scale down RBLOCK for better occupancy
135
+ dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1"
136
+
137
+ # this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32
138
+ # but the mul gets fused with other pointwise ops instead.
139
+ force_fuse_int_mm_with_mul = False
140
+
141
# for pattern torch.mm(a, b.to(dtype)) with cuda tensors,
# enable torch._inductor.kernel.mm.tuned_mixed_mm fused kernel.
# Autotune will compare perf with normal cast->then->mm option
use_mixed_mm = False

# enable runtime numeric check for pre/post grad fx passes
# floating point provides limited accuracy (about 7 decimal digits for single precision
# floating point numbers, about 16 decimal digits for double precision floating point numbers)
# according to PyTorch documentation.
# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
fx_passes_numeric_check: Dict[str, Any] = {
    "pre_grad": False,
    "precision": 1e-4,
    "num_iterations": 1,
    "requires_optimizer": True,
}

# for pattern torch.mm(a, b.to(dtype)) with cuda tensors, always use
# torch._inductor.kernel.mm.tuned_mixed_mm's fused kernel.
# Autotune will not compare with normal cast->then->mm option.
# (if force_mixed_mm is true, the use_mixed_mm flag will be ignored)
force_mixed_mm = False

# enable reordering pass for increasing overlap between compute and communication
reorder_for_compute_comm_overlap = False

# passes (in execution order) for increasing overlap between compute and communication
# for built-in passes, use string name; for user-defined passes, pass in the function handle
reorder_for_compute_comm_overlap_passes = [
    "reorder_compute_for_overlap",
    "sink_waits",
    "raise_comms",
]

# runtime estimation function for ops
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default"

# unit: GB/s, uni-directional P2P bandwidth per card
# default value is NVLink
intra_node_bw = 300

# unit: GB/s, uni-directional P2P bandwidth per node
# default value is InfiniBand
inter_node_bw = 25

# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"

# enable slow autotuning passes to select pointwise/reductions algorithms
max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"

# enable slow autotuning passes to select gemm algorithms
max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"

# enable autotune local cache
use_autotune_local_cache = True

# enable autotune remote cache
use_autotune_remote_cache = (
    os.environ.get("TORCH_INDUCTOR_AUTOTUNE_REMOTE_CACHE") == "1"
)

# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations
# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations
# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure
# that triton does not use TF32 wherever cublas would not use TF32
force_same_precision = (
    True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1"
)
# Specify candidate backends for gemm autotune.
# Possible choices are combinations of: ATen, Triton, CUTLASS.
# ATen: default Pytorch ATen kernels.
# Triton: Triton templates defined in torch inductor.
# CUTLASS: Cutlass templates and kernels.
max_autotune_gemm_backends = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON"
).upper()

# the value used as a fallback for the unbacked SymInts
# that can appear in the input shapes (e.g., in autotuning)
unbacked_symint_fallback = 8192

# enable searching global and local cache regardless of `max_autotune`
search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"

save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"

# We will disable creating subprocess for autotuning if this is False
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"

# If autotuning in subprocess, whether to use multiple devices
autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"

# enable coordinate-descent style tuning of triton kernel launch configs
coordinate_descent_tuning = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
)
coordinate_descent_check_all_directions = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1"
)
coordinate_descent_search_radius = int(
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1")
)

# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions
layout_opt_default = "1" if not torch.version.hip else "0"
layout_optimization = (
    os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1"
)

force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1"


# Whether to keep the output strides the same as eager after layout optimization.
keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"

# Enabling this will let compiler print warning messages if a generated triton
# kernel has inputs with mixed layouts. This is helpful for perf debugging
# since kernel with mixed layout inputs may run much slower than one whose inputs
# have uniform layouts.
warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"

# control store vs recompute heuristic
# For fanouts, rematerialization can lead to exponential blowup. So, have
# smaller threshold
realize_reads_threshold = 4
realize_opcount_threshold = 30

# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold = 8

# fallback to eager for random/dropout, this is slow but useful for debugging
fallback_random = False

# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks = True

# fuse even in cases without common reads
aggressive_fusion = False

# For each fused kernel in the wrapper, comment with the nodes that get fused.
# Useful for debugging fusion.
debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")

# how many nodes to allow into a single fusion
max_fusion_size = 64

# max number of inputs to generate cat as a pointwise op with masked loads
max_pointwise_cat_inputs = 8

# replace small reductions with pointwise, disable with `= 1`
unroll_reductions_threshold = 8

# Add extra comments to output code (causes compile cache misses)
comment_origin = False

# Convert 1x1 convs into matmuls
conv_1x1_as_mm = False

# Enable split reductions for better utilization when the dimension
# being reduced over is large (by splitting it)
split_reductions = True

benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"

# Enable constant and index_expr folding
constant_and_index_propagation = True

# we always add constants into graph.constants without
# performing any constant-inlining optimization
always_keep_tensor_constants = False

# assert that indirect indexing does not read / write out of bounds
assert_indirect_indexing = True

# constant folding on the joint graph
joint_graph_constant_folding = True

# Enable indirect_indexing asserts for decompositions and lowerings
debug_index_asserts = False

# warnings intended for PyTorch developers, disable for point releases
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings = is_fbcode() or is_nightly_or_source

# The multiprocessing start method to use for inductor workers in the codecache.
# TODO: fork is not safe in a multithreaded environment, we should evaluate changing
# the default to spawn.
worker_start_method = "fork"
332
+
333
+
334
def decide_compile_threads():
    """
    Here are the precedence to decide compile_threads
    1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by
       setting this to 1 to make pdb happy.
    2. Set to 1 if it's win32 platform or it's a fbcode build
    3. decide by the number of CPU cores
    """
    env_override = os.environ.get("TORCHINDUCTOR_COMPILE_THREADS")
    if env_override is not None:
        return int(env_override)

    # Async compile workers are disabled on Windows and in fbcode builds.
    if sys.platform == "win32" or is_fbcode():
        return 1

    # Prefer the affinity mask (cores actually usable by this process) when
    # the platform exposes it; fall back to the total logical core count.
    try:
        usable_cores = len(os.sched_getaffinity(0))
    except AttributeError:
        usable_cores = os.cpu_count()
    assert usable_cores
    return min(32, usable_cores)
354
+
355
+
356
# number of parallel workers used for async compilation (see decide_compile_threads)
compile_threads = decide_compile_threads()

# gemm autotuning global cache dir
if is_fbcode():
    from libfb.py import parutil

    try:
        if __package__:
            global_cache_dir = parutil.get_dir_path(
                os.path.join(__package__.replace(".", os.sep), "fb/cache")
            )
        else:
            global_cache_dir = parutil.get_dir_path("fb/cache")
    except ValueError:
        global_cache_dir = None
else:
    global_cache_dir = None

# If kernel is fused, the name is generated from the origin node op names
# for larger kernels limit this
kernel_name_max_ops = 10

# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"

# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"

# Mark the wrapper call in PyTorch profiler
profiler_mark_wrapper_call = False

# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for
# every intermediate for which we can correlate it with an intermediate
# from the original FX graph
generate_intermediate_hooks = False

# Populate traceback field on IRNode; good for debugging why origin_node is
# not populated, or finding out where an IRNode was constructed
debug_ir_traceback = False

# used for debugging to make sure config is properly set
_raise_error_for_testing = False

# TORCHINDUCTOR_PROFILE turns on bandwidth profiling; any value other than
# "1" is additionally treated as a regex filtering which kernels are profiled.
_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth = _profile_var != ""
profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var
# Specify a file where we print out the profiling results.
# None means we do not dump results to a file.
profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None)

# TODO: remove later
disable_cpp_codegen = False


# Freezing will attempt to inline weights as constants in optimization
# and run constant folding and other optimizations on them. After freezing, weights
# can no longer be updated.
freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"

# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
# of potentially keeping multiple copies of weights.
freezing_discard_parameters: bool = False

# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
# should be run with this flag both on and off to make sure we have coverage.
allow_stack_allocation: bool = (
    os.environ.get("TORCHINDUCTOR_STACK_ALLOCATION", "1") == "1"
)

# Enables an alternate DSO interface (the "minimal ArrayRef interface") intended
# to maximize performance for use cases that it can accommodate at the expense of
# generality. In brief:
# - inputs and outputs are ArrayRefTensor<T> (note that strides are required, but the
#   tensor must be contiguous)
# - constant handling is unchanged because it is not a per-inference-iteration bottleneck
#
# When the DSO is generated in this mode, the usual interface will also be supported,
# but performance for that interface may be degraded.
use_minimal_arrayref_interface: bool = False

# decompose some memory bound matmul/bmm to mul
decompose_mem_bound_mm: bool = False
438
+
439
+
440
# config specific to codegen/cpp.py
class cpp:
    # number of threads for C++ kernels; -1 means use torch.get_num_threads()
    threads = -1

    # Do not generate loops when the condition doesn't hold, like:
    # for(long i0=4096; i0<4096; i0+=1)
    no_redundant_loops = True

    # Assume number of threads is dynamic, don't specialize thread number.
    # Kernels don't recompile on thread number changes with this flag on.
    # For single-threaded workload, turning it on would incur a slight
    # performance degradation.
    dynamic_threads = False

    # SIMD vector width to target; None means auto-detect
    simdlen: Optional[int] = None
    # minimum number of elements in a loop before it is parallelized
    min_chunk_size = 4096
    # candidate C++ compilers, tried in order
    cxx = (
        None,  # download gcc12 from conda-forge if conda is installed
        # "g++-12",
        # "g++-11",
        # "g++-10",
        # "clang++",
        os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
        # "g++.par",
    )
    # Allow kernel performance profiling via PyTorch profiler
    enable_kernel_profile = False

    # enable weight prepacking to get a better performance; may lead to large memory footprint
    weight_prepack = True

    # Inject a bug into our relu implementation; useful for testing our repro
    # extraction and minification functionality.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY: Optional[str] = None
    inject_log1p_bug_TESTING_ONLY: Optional[str] = None

    # If None, autodetect whether or not AVX512/AVX2 can be used.  Otherwise,
    # force usage as specified, without testing.
    vec_isa_ok: Optional[bool] = None

    # similar to config.triton.descriptive_names
    descriptive_names = "original_aten"

    # how many nodes to allow into a single horizontal fusion
    max_horizontal_fusion_size = 16

    # Make scatter_reduce fallback when reduce is sum to avoid performance regression
    # using atomic_add.
    fallback_scatter_reduce_sum = True

    # Use -funsafe-math-optimizations when compiling
    enable_unsafe_math_opt_flag = False

    # Use -ffp-contract when compiling
    enable_floating_point_contract_flag = False
497
+
498
+
499
# config specific to codegen/triton.py
class triton:
    # Use cudagraphs on output code
    cudagraphs = False

    # Use cudagraph trees for memory pooling if `cudagraphs` is True
    cudagraph_trees = True

    # assertions not on the fast path, steady state
    slow_path_cudagraph_asserts = True

    # TODO - need to debug why this prevents cleanup
    cudagraph_trees_history_recording = False

    # assertions on the fast path
    fast_path_cudagraph_asserts = False

    # skip warmup for cudagraph trees
    skip_cudagraph_warmup = False

    # Synchronize before and after every compiled graph.
    debug_sync_graph = False

    # Synchronize after every kernel launch, to help pinpoint bugs
    debug_sync_kernel = False

    # Always load full blocks (rather than broadcasting inside the block)
    dense_indexing = False

    # limit tiling dimensions
    max_tiles = 2

    # use triton.autotune for pointwise ops with complex layouts
    # this should only be disabled for debugging/testing
    autotune_pointwise = True

    # max autotune gemm with cublasLt
    autotune_cublasLt = True

    # should we stop a fusion to allow better tiling?
    tiling_prevents_pointwise_fusion = True
    tiling_prevents_reduction_fusion = True

    # should we give different names to kernels
    # Note: This is orthogonal to descriptive_names - this is deciding whether
    # our triton kernel names should all be `triton_` (to maximize caching) or
    # whether they should be unique.
    unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"

    # should we put op names in kernel names
    # False: No special names (just triton__1, triton__2, etc.)
    # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
    # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
    # "inductor_node": Maps to the node name in the FX graph passed to Inductor
    descriptive_names = "original_aten"

    # use alternate codegen for smaller reductions
    persistent_reductions = (
        os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
    )

    # 0/False: disable
    # 1/True: enable, use tuning to pick between different subkernels
    # 2: enable, force using persistent reduction (for debugging)
    # 3: enable, force using non-persistent reduction (for debugging)
    multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0"))

    # hint to Triton when arguments are divisible by 16
    divisible_by_16 = True

    # these are not enforced, but they are used by asserts in triton_heuristics.py
    # NOTE: mobilevit_s in timm_models required X to be set to the higher value 2048

    # Max RBLOCK will be large for multi-kernel since we do more aggressive
    # persistent reduction.
    max_block = {
        "X": 2048,
        "Y": 1024,
        "Z": 1024,
        "R": 4096 * (16 if multi_kernel else 1),
    }

    # Minimum RBLOCK to be used for a TritonSplitScanKernel
    # NOTE: This also indirectly controls the size of workspace buffer required
    min_split_scan_rblock = 256

    # Store the generated cubin files for cpp wrapper code to load
    store_cubin = False

    # the max number of spills we allow for the configs we benchmark.
    # Setting this to 0 means we skip a config if it spills even a single
    # register.
    # Setting it to a larger value allows a config spilling a small amount
    # of registers being benchmarked.
    #
    # NOTE: triton will always report >0 register spills for kernels using sin/cos.
    # (check this issue https://github.com/openai/triton/issues/1756 )
    # So far we see a fixed 8 spilled registers for kernels using sin/cos.
    # Raise the threshold to 16 to be safe.
    # We should revisit this once we understand more of the source of register spills.
    spill_threshold: int = 16

    # Generate code containing the newer tl.make_block_ptr() API for loads/store
    use_block_ptr = False

    # Inject a bug into our relu implementation; useful for testing our repro
    # extraction and minification functionality.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY: Optional[str] = None
608
+
609
+
610
# config specific to AOTInductor (ahead-of-time compiled artifacts)
class aot_inductor:
    # AOTInductor output path
    # If an absolute path is specified, the generated lib files will be stored under the directory;
    # If a relative path is specified, it will be used as a subdirectory under the default caching path;
    # If not specified, a temp directory will be created under the default caching path.
    # If the specified path contains something like "model.so", the sub-string will be used
    # to name the generated library.
    output_path = ""

    # emit debug-friendly compilation output when AOT_INDUCTOR_DEBUG_COMPILE=1
    debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1"

    # Serialized tree spec for flattening inputs
    serialized_in_spec = ""

    # Serialized tree spec for flattening outputs
    serialized_out_spec = ""

    # flag to decide whether to create a submodule for constant graph.
    use_runtime_constant_folding: bool = False
629
+
630
+
631
# config specific to CUDA (CUTLASS) template kernels
class cuda:
    # CUDA arch to use for CUDA template kernel compilation.
    # e.g. "70", "75", "80", "90", etc.
    # When arch is None, Inductor uses torch.cuda.get_device_capability(0).
    arch: Optional[str] = None

    # CUDA version to use for CUDA template kernel compilation.
    # e.g. "11.4", "12.1", etc.
    # When version is None, Inductor uses torch.version.cuda.
    version: Optional[str] = None

    # Optimization level for the host compiler.
    compile_opt_level = "-O1"

    # Whether to enable device LTO (link-time-optimization).
    enable_cuda_lto = False

    # Whether to keep intermediate files during compilation.
    # NOTE(review): the name suggests this toggles ptxas info output —
    # confirm which behavior the build actually keys off this flag.
    enable_ptxas_info = False

    # Whether to enable debug info, e.g. line number, cutlass debug info.
    enable_debug_info = False

    # Whether to use fast math.
    use_fast_math = False

    # Path to the CUTLASS repo root directory.
    # The default path only works under PyTorch local development environment.
    cutlass_dir = os.environ.get(
        "TORCHINDUCTOR_CUTLASS_DIR",
        os.path.abspath(
            os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
        ),
    )

    # Configures the maximum number of CUTLASS configs to profile in max_autotune.
    # By default it's None, so that all CUTLASS configs are tuned.
    # This is mainly used to reduce test time in CI.
    cutlass_max_profiling_configs: Optional[int] = None

    # Path to CUDA NVCC.
    # NVCC search order:
    # 1) cuda_cxx set in this config
    # 2) CUDACXX environment variable
    # 3) CUDA_HOME environment variable
    # 4) default system search PATH.
    cuda_cxx: Optional[str] = None

    # If set to True, it will ensure that only GEMM ops capable of
    # epilogue fusion via CUTLASS Epilogue Visitor Trees ( EVT )
    # are enabled for the CUTLASS backend.
    cutlass_only_evt_capable_ops: bool = False
683
+
684
+
685
# create a directory containing lots of debug information
class trace:
    # master switch for all debugging flags below
    enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    # Save debug information to a temporary directory
    # If not specified, a temp directory will be created by system
    debug_dir: Optional[str] = None

    # Save python logger call >=logging.DEBUG
    debug_log = False

    # Save python logger call >=logging.INFO
    info_log = False

    # Save input FX graph (post decomps, pre optimization)
    fx_graph = True

    # Save FX graph after transformations
    fx_graph_transformed = True

    # Save TorchInductor IR before fusion pass
    ir_pre_fusion = True

    # Save TorchInductor IR after fusion pass
    ir_post_fusion = True

    # Copy generated code to trace dir
    output_code = True

    # SVG figure showing post-fusion graph
    graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1"

    # SVG figure showing fx with fusion
    draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1"

    # We draw our fx graphs with the "record" shape attribute by default.
    # Sometimes, when the graph is very complex, we may hit dot errors like below:
    #   "flat edge between adjacent nodes one of which has a record shape -
    #    replace records with HTML-like labels"
    # and thus fail to generate a graph. So, let's give the user an option
    # to specify the shape attribute for the dot graph. For example, passing
    # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like labels
    # to workaround the above failure.
    dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None)

    # Store cProfile (see snakeviz to view)
    compile_profile = False

    # Upload the .tar.gz file
    # Needs to be overridden based on specific environment needs
    upload_tar: Optional[Callable[[str], None]] = None

    # dump autotuning results alongside the other trace artifacts
    log_autotuning_results: bool = False
739
+
740
+
741
# Config entries that are skipped when pickling/saving the config state.
_save_config_ignore = {
    # workaround: "Can't pickle <function ...>"
    "trace.upload_tar",
}

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403

from torch.utils._config_module import install_config_module

# adds patch, save_config, etc
install_config_module(sys.modules[__name__])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ from typing import Any, Callable, Dict, Optional
3
+
4
+ import torch
5
+ import torch.utils._pytree as pytree
6
+
7
# shorthand for the aten op namespace
aten = torch.ops.aten

# We would like to split modules into two subgraphs for runtime weight updates to work correctly.
# The use case and more information could be found at:
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
META_TAG = "MODULE_TYPE"  # node.meta key carrying the subgraph tag
MODULE_TAG = "_MAIN_MODULE"  # tag: node stays in the main module
CONST_MODULE_TAG = "_CONST_MODULE"  # tag: node belongs to the constant subgraph
15
+
16
+
17
def replace_node_with_constant(gm, node, constant, name=None):
    """Replace ``node`` in ``gm``'s graph with a get_attr referencing ``constant``.

    If ``name`` is not given, a fresh ``_frozen_param{i}`` attribute name is
    generated (tracked via ``gm._frozen_param_count`` so successive calls do
    not rescan from zero).
    """
    graph = gm.graph

    if name:
        qualname = name
    else:
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0
        idx = gm._frozen_param_count
        # advance past any attribute names that are already taken
        while hasattr(gm, f"_frozen_param{idx}"):
            idx += 1
        qualname = f"_frozen_param{idx}"
        gm._frozen_param_count = idx + 1

    with graph.inserting_before(node):
        replacement = graph.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(replacement)
        # preserve metadata (e.g. tensor meta, stack traces) on the new node
        replacement.meta.update(node.meta)
        graph.erase_node(node)

    # register_buffer suppresses the fx `does not reference an nn.Module,
    # nn.Parameter, or buffer` warning; setattr keeps plain attribute access.
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)
44
+
45
+
46
class ConstantFolder(torch.fx.Interpreter):
    """Interpreter that runs the graph with placeholders marked unknown,
    recording every node whose output is a tensor computable purely from
    constants in ``self.node_replacements`` (candidates for folding).
    """

    def __init__(
        self,
        gm,
        skip_constructors=False,
    ):
        super().__init__(gm)
        # node -> the constant tensor that could replace it
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        # node -> number of its uses consumed by folded users so far
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        # sentinel for values that are not statically known
        self.unknown_value = object()
        # if True, do not fold constructor-like ops (no tensor inputs)
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()

    def is_impure(self, node: torch.fx.node.Node):
        # Nodes reported impure are evaluated but never folded.
        if node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq
            # We only folding fp32_weight -> q
            # int8_weight and leave dq in graph to be fused
            return True
        return False

    def node_to_last_non_output_use(self):
        """Map each node to the inputs for which it is the last non-output use
        (walking the graph in reverse), so env entries can be freed eagerly."""
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp):
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            pytree.tree_map_only(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node):
        """Evaluate one node; returns its value or ``self.unknown_value`` when
        the result cannot be (or should not be) constant-folded."""
        if node.target == "output":
            # because we remove nodes from env on last non output use,
            # re-define them now or we'll get error in interpreter
            def set_env(arg):
                self.env[arg] = self.unknown_value

            pytree.tree_map_only(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        # unknown-ness propagates: any unknown input makes the output unknown
        if self.unknown_value in flattened_inputs:
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # TODO - constant folding triton kernel returns the inputs -- fix this
        # NOTE(review): this compares node.name (which fx may suffix with _1,
        # _2, ...) rather than node.target — confirm this matches all wrappers.
        if (
            node.op == "call_function"
            and node.name == "triton_kernel_wrapper_functional_proxy"
        ):
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning into tensor would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or on inputs which we did not make constant
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            # count the uses of each input consumed by this foldable node
            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            # if every use of an input has been folded away, it does not need
            # its own replacement constant — drop it to save memory
            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        # Hook for subclasses to veto folding of particular tensors.
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        # Hook for subclasses to customize how replacements are recorded.
        self.node_replacements[node] = tensor

    def run(self):
        """Run the interpreter with all placeholders marked unknown."""
        env = {}
        for n in self.module.graph.nodes:
            if n.op == "placeholder":
                env[n] = self.unknown_value
        return super().run(initial_env=env)
182
+
183
+
184
@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    """Fold constant-computable nodes in ``gm`` into get_attr constants.

    ``constraint_fn``, if given, may veto folding of individual nodes by
    returning False. Unused get_attr entries are removed afterwards, and the
    module is re-linted and recompiled in place.
    """
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node, constant in cf.node_replacements.items():
        if constraint_fn is not None and not constraint_fn(node):
            continue
        replace_node_with_constant(gm, node, constant)

    # drop attributes whose get_attr nodes are now unused
    erased_params = []
    for node in gm.graph.nodes:
        if node.op == "get_attr" and len(node.users) == 0:
            if hasattr(gm, node.target):
                delattr(gm, node.target)
            erased_params.append(node)

    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
208
+
209
@torch.utils._python_dispatch._disable_current_modes()
def constant_graph_tag(gm: torch.fx.GraphModule):
    """Tag every node's meta[META_TAG] as either CONST_MODULE_TAG (foldable
    or consumed by folding) or MODULE_TAG (stays in the main module)."""
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node in gm.graph.nodes:
        if (
            node.op == "get_attr"
            or node in cf.node_replacements
            or node in cf.replaced_uses
        ):
            node.meta[META_TAG] = CONST_MODULE_TAG
        else:
            node.meta[META_TAG] = MODULE_TAG
223
+
224
+
225
def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Construct a GraphModule which corresponds to the part which could be
    constant folded in provided gm.
    """

    constant_graph_tag(gm)
    # We rewrite the tags, if it's a constant being directly consumed, without
    # any folding opportunity, we keep it in main gm.
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            used_to_fold = False
            for u in node.users:
                if u.meta[META_TAG] == CONST_MODULE_TAG:
                    used_to_fold = True
                    break
            if not used_to_fold:
                node.meta[META_TAG] = MODULE_TAG

    new_graph = torch.fx.Graph()

    # copy the constant-tagged subgraph into new_graph, remembering which
    # copied nodes feed back into the main module (those become outputs)
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    output_nodes = []
    for node in gm.graph.nodes:
        if node.meta[META_TAG] == MODULE_TAG:
            continue

        new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
        node_remapping[node] = new_node

        # a copied node consumed by any main-module node is an output of
        # the constant subgraph
        for user in node.users:
            if user.meta[META_TAG] == MODULE_TAG:
                output_nodes.append(new_node)
                break

    new_graph.output(tuple(output_nodes))
    new_graph.lint()
    new_gm = torch.fx.GraphModule(gm, new_graph)

    return new_gm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py ADDED
@@ -0,0 +1,1324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import logging
3
+ import operator
4
+ import os
5
+ import re
6
+ import sys
7
+ import time
8
+ from collections import defaultdict
9
+ from contextlib import contextmanager
10
+ from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple
11
+
12
+ import sympy
13
+
14
+ import torch
15
+ import torch._logging
16
+ import torch.fx
17
+ from torch._decomp import get_decompositions
18
+ from torch._dynamo.utils import defake, dynamo_timed
19
+ from torch._logging import LazyString, trace_structured
20
+ from torch._subclasses.fake_tensor import FakeTensor
21
+ from torch.fx.experimental._backward_state import BackwardState
22
+ from torch.fx.experimental.sym_node import magic_methods, method_to_operator
23
+ from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
24
+ from torch.utils._mode_utils import no_dispatch
25
+
26
+ from . import config, ir
27
+ from .codegen.common import (
28
+ DeviceOpOverrides,
29
+ get_device_op_overrides,
30
+ get_scheduling_for_device,
31
+ get_wrapper_codegen_for_device,
32
+ register_backend_for_device,
33
+ )
34
+ from .codegen.cpp_wrapper_cpu import CppWrapperCpu
35
+ from .codegen.cpp_wrapper_cuda import CppWrapperCuda
36
+ from .codegen.wrapper import WrapperCodeGen
37
+ from .exc import (
38
+ CppWrapperCodeGenError,
39
+ LoweringException,
40
+ MissingOperatorWithDecomp,
41
+ MissingOperatorWithoutDecomp,
42
+ )
43
+ from .ir import (
44
+ Constant,
45
+ FixedLayout,
46
+ InputBuffer,
47
+ Pointwise,
48
+ Reduction,
49
+ StorageBox,
50
+ TensorBox,
51
+ )
52
+ from .lowering import (
53
+ constrain_to_fx_strides,
54
+ FALLBACK_ALLOW_LIST,
55
+ fallback_handler,
56
+ fallback_node_due_to_unsupported_type,
57
+ layout_constraints,
58
+ lowerings,
59
+ make_fallback,
60
+ needs_realized_inputs,
61
+ unsupported_output_tensor,
62
+ )
63
+ from .sizevars import SizeVarAllocator
64
+ from .utils import convert_shape_to_inductor, gather_origins, get_sympy_Expr_dtype
65
+ from .virtualized import V
66
+
67
# Module-level loggers: `log` for general debug output, plus two artifact
# loggers that torch._logging can enable independently (TORCH_LOGS=perf_hints /
# output_code).
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")


# Internal (fbcode) builds log generated module code; in OSS builds the hook
# is a no-op with the same call signature.
if config.is_fbcode():
    from torch._inductor.fb.utils import log_module_code
else:

    def log_module_code(*args, **kwargs):
        pass
78
+
79
+
80
+ def supported_dtype_of_cpp_wrapper(dtype, cuda):
81
+ supported_dtype = {
82
+ torch.float32,
83
+ torch.float64,
84
+ torch.int64,
85
+ torch.int32,
86
+ torch.int16,
87
+ torch.int8,
88
+ torch.uint8,
89
+ torch.bool,
90
+ torch.bfloat16,
91
+ torch.complex32,
92
+ torch.complex64,
93
+ torch.complex128,
94
+ torch.float16,
95
+ }
96
+ if cuda:
97
+ supported_dtype.add(torch.float8_e4m3fn)
98
+ supported_dtype.add(torch.float8_e5m2)
99
+ supported_dtype.add(torch.float8_e4m3fnuz)
100
+ supported_dtype.add(torch.float8_e5m2fnuz)
101
+
102
+ return dtype in supported_dtype
103
+
104
+
105
def may_get_constant_buffer_dtype(constant_buffer):
    """
    Best-effort dtype inference for a sympy-valued constant buffer.

    Returns torch.int64 for integers, delegates general expressions to
    get_sympy_Expr_dtype, and falls back to the symbol's own
    is_integer/is_float assumptions; returns None when undecidable.
    """
    assert isinstance(
        constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
    ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
    if isinstance(constant_buffer, sympy.core.numbers.Integer):
        return torch.int64

    # NOTE: sympy.Symbol is itself a sympy.Expr, so symbols also take this
    # branch; the assumption-based checks below are a fallback only.
    if isinstance(constant_buffer, sympy.Expr):
        return get_sympy_Expr_dtype(constant_buffer)

    if constant_buffer.is_integer:
        return torch.int64
    if constant_buffer.is_float:
        return torch.float32
    return None
121
+
122
+
123
+ def is_magic_method(op):
124
+ magic_ops = {method_to_operator(m) for m in magic_methods}
125
+ return op in magic_ops
126
+
127
+
128
def getattr_recursive(obj, target):
    """
    Resolve a dotted attribute path (e.g. ``"sub.weight"``) starting at
    ``obj``, raising RuntimeError naming the last resolvable prefix when
    any component is missing.
    """
    parts = target.split(".")
    current = obj
    for depth, part in enumerate(parts):
        if not hasattr(current, part):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(parts[:depth])}"
            )
        current = getattr(current, part)
    return current
138
+
139
+
140
+ class GraphLowering(torch.fx.Interpreter):
141
+ graph_outputs: List[ir.IRNode]
142
+
143
    def symbolic_sizes_strides(self, ex: torch.Tensor):
        """
        Support dynamic shapes and dynamic strides by assigning variables
        to each dimension. We duck-shape tensors, so if two tensors
        have the same size they get assigned the same symbolic variable.

        Returns a (sizes, strides) pair of lists of sympy expressions
        (or plain ints where the dimension is static).
        """
        if self.reuse_shape_env:
            # Dynamo already allocated symbols for this tensor; just convert
            # SymInt sizes/strides into inductor's sympy representation.
            return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
                ex.stride()
            )
        else:
            from torch._dynamo.source import ConstantSource

            # TODO: this should not be needed once #93059 lands
            # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
            # TODO: make a dedicated UnknownSource for this?
            # NB: This is using the legacy default behavior from
            # create_symbolic_sizes_strides_storage_offset but we hope we can
            # just delete this entirely
            source = ConstantSource(
                f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
            )
            (
                size,
                stride,
                _,  # storage offset is discarded here
            ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(
                ex,
                source,
            )

            # Unwrap SymInt -> raw sympy expr; static dims stay as ints.
            size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
            stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
            return size, stride
177
+
178
+ def static_sizes_strides(self, ex: torch.Tensor):
179
+ """
180
+ Primarily used to weights
181
+ """
182
+ size = [sympy.Integer(i) for i in ex.size()]
183
+ stride = [sympy.Integer(i) for i in ex.stride()]
184
+ return size, stride
185
+
186
    def init_backend_registration(self):
        # Lazily register the default codegen backends, only if no backend
        # has been registered for the device yet (a custom backend may have
        # been installed via register_backend_for_device already).
        if get_scheduling_for_device("cpu") is None:
            from .codegen.cpp import CppScheduling

            register_backend_for_device("cpu", CppScheduling, WrapperCodeGen)

        if get_scheduling_for_device("cuda") is None:
            from .codegen.cuda_combined_scheduling import CUDACombinedScheduling

            # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
            register_backend_for_device("cuda", CUDACombinedScheduling, WrapperCodeGen)
197
+
198
    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Optional[List[torch.Tensor]] = None,
        shape_env=None,
        num_static_inputs=None,
        graph_id=None,
        cpp_wrapper=False,
        aot_mode=False,
        user_visible_outputs=frozenset(),
        layout_opt=None,
        extern_node_serializer=None,
        is_inference=False,
        is_const_graph=False,
        const_output_index=None,
        const_code=None,
        const_module=None,
        name=None,
    ):
        """
        Lowering state for one FX graph.

        :param gm: the FX graph to lower.
        :param shape_env: shared ShapeEnv from dynamo; a fresh one is created
            (and ``reuse_shape_env`` left False) when not provided.
        :param layout_opt: force layout optimization on/off; when None it is
            decided heuristically via ``decide_layout_opt``.
        :param const_module: a previously-lowered constant graph whose
            device/constant bookkeeping is shared (note: shared by reference).
        :param name: optional prefix used by ``qualify_name`` for all
            generated buffer/constant names.
        """
        super().__init__(gm)

        self.example_inputs = example_inputs
        self.layout_opt = (
            layout_opt
            if layout_opt is not None
            else self.decide_layout_opt(gm, is_inference=is_inference)
        )
        self.num_channels_last_conv = 0
        self.is_inference = is_inference
        self.is_const_graph = is_const_graph
        self.const_code = const_code
        self.const_module = const_module

        self.extra_traceback = False  # we do our own error wrapping
        if shape_env is None:
            shape_env = ShapeEnv()
            self.reuse_shape_env = False
        else:
            # NOTE(review): this assignment is redundant — `self._shape_env`
            # is unconditionally assigned again right below.
            self._shape_env = shape_env
            self.reuse_shape_env = True
        self._shape_env = shape_env
        self.sizevars = SizeVarAllocator(shape_env)
        self.graph_input_names: List[str] = []
        self.graph_inputs: Dict[str, TensorBox] = {}
        self.graph_inputs_original: Dict[str, InputBuffer] = {}
        # Device bookkeeping is shared (by reference) with const_module if any.
        self.device_types: Set[str] = (
            const_module.device_types if const_module else set()
        )
        self.device_idxs: Set[int] = const_module.device_idxs if const_module else set()
        self.cuda = False
        self.buffers: List[ir.Buffer] = []
        self.const_output_index: Dict[str, int] = (
            const_output_index if const_output_index else {}
        )
        self.folded_constants: Set[str] = (
            set(const_output_index.keys()) if const_output_index else set()
        )
        self.constants: Dict[str, torch.Tensor] = (
            const_module.constants if const_module else {}
        )
        self.constant_reprs: Dict[str, str] = {}
        self.removed_buffers: Set[str] = set()
        self.removed_inplace_buffers: Set[str] = set()
        self.mutated_buffers: Set[str] = set()
        self.never_reuse_buffers: Set[str] = set()
        self.inplaced_to_remove: Set[str] = set()
        self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
        self.wrapper_code: WrapperCodeGen = None  # type: ignore[assignment]
        # See `ProxyExecutor Design Note` in ir.py for more details
        self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
        self.extern_node_serializer: Optional[
            Callable[[List[ir.ExternKernelNode]], Any]
        ] = extern_node_serializer
        self.current_node: torch.fx.Node = None  # type: ignore[assignment]
        self.num_static_inputs = num_static_inputs
        self.lists: Dict[str, List[str]] = {}
        self.mutated_inputs: Set[str] = set()
        self.mutated_input_idxs: List[int] = []
        self.name_to_buffer: Dict[str, ir.Buffer] = {}
        self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
        self.creation_time = time.time()
        self.name = name
        self.cpp_wrapper = cpp_wrapper

        # record multi_kernel choice for cpp_wrapper so the second pass knows
        # which sub-kernel is picked. Copy cpp_wrapper to another variable
        # since cpp_wrapper flag is set to false for the first pass of codegen.
        self.record_multi_kernel_choice = cpp_wrapper
        self.multi_kernel_to_choice: Dict[str, int] = {}

        self.aot_mode = aot_mode
        self.graph_id = graph_id
        self.scheduler: "torch._inductor.scheduler.Scheduler" = None  # type: ignore[assignment]
        self.nodes_prefer_channels_last = (
            self.find_nodes_prefer_channels_last() if self.layout_opt else set()
        )
        self._warned_fallback = {"aten.convolution_backward"}
        self.user_visible_outputs = user_visible_outputs
        self.cache_key: str = ""  # This is the cache key for the compiled artifact
        self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
        self.cache_linemap: List[
            Tuple[int, str]
        ] = (
            []
        )  # This is the linemap used by the profiler to mark custom compiled kernels getting run
        # Used if lowering encounters cases where cudagraphs are not supported
        self.disable_cudagraphs_reason: Optional[str] = None

        # only keeping one node per device for stack trace purposes
        self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
        self.orig_gm: torch.fx.GraphModule = gm.__copy__()
        self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
            "dynamo_flat_name_to_original_fqn", {}
        )
        self.allocated_constant_name = (
            const_module.allocated_constant_name if const_module is not None else {}
        )
        self.init_backend_registration()
316
+
317
    @staticmethod
    def decide_layout_opt(gm, *, is_inference) -> bool:
        """
        Decide if we should enable layout optimization for this graph based on
        heuristics.

        Returns True to lower convolutions with channels-last layout.  The
        decision is driven by config flags, the number/kind of convolutions
        in the graph, and (for inference) benchmark-derived FLOP weights.
        """
        if not config.layout_optimization:
            return False

        if config.force_layout_optimization:
            return True

        conv_nodes = [
            n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
        ]
        nconv = len(conv_nodes)

        if nconv == 0:
            return False

        # For cpu backend and mkldnn enabled, we always use channels_last for better performance.
        if (
            torch.backends.mkldnn.enabled
            and torch.backends.mkldnn.is_available()
            and all(
                n.args[idx].meta["val"].device == torch.device("cpu")
                for n in conv_nodes
                for idx in [0, 1]
            )
        ):
            return True

        # Following models are skipped due to this:
        # jx_nest_base
        # volo_d1_224
        if len(list(gm.graph.nodes)) >= 300 * nconv:
            log.debug("Skipped layout opt because only a few conv")
            return False

        # Dynamic-shaped conv input/weight: skip entirely.
        if any(
            has_free_symbols(n.args[idx].meta["val"])
            for n in conv_nodes
            for idx in [0, 1]
        ):
            log.debug(
                "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
            )
            return False

        # n.args[-1] is presumably the `groups` argument of
        # aten.convolution; args[1] is the weight — TODO confirm against the
        # aten.convolution.default schema.
        def is_grouped(n):
            return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1

        def is_in_out_channel(n):
            return (
                n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)
                and n.args[1].meta["val"].size(2) > 1
            )

        def is_small_channel(n):
            return (
                n.args[1].meta["val"].size(0) <= 64
                and n.args[1].meta["val"].size(1) <= 64
            )

        # only grouped convolutions benchmarked as slower in conv samples for inference only
        if is_inference:
            from torch.utils.flop_counter import FlopCounterMode

            flop_counts: Dict[str, float] = defaultdict(float)
            for node in conv_nodes:
                success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
                    node
                )

                if success:
                    with FlopCounterMode(display=False) as flop_counter_mode:
                        with V.fake_mode:
                            node.target(*args, **kwargs)

                    counted_flops = flop_counter_mode.get_total_flops()
                    if is_grouped(node):
                        node_type = "grouped"
                    elif is_small_channel(node):
                        node_type = "small"
                    elif is_in_out_channel(node):
                        node_type = "in_out"
                    else:
                        node_type = "default"

                    flop_counts[node_type] += counted_flops
                else:
                    log.debug("Conv inputs meta not found")

            # average benchmarked channels last speedup / slowdown, < 1 is speedup.
            # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
            # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
            GROUPED_MULTIPLIER = 1.358
            DEFAULT_MULTIPLIER = 0.823
            IN_OUT_MULTIPLIER = 0.725
            SMALL_MULTIPLIER = 0.783

            total_flops = sum(flop_counts.values())
            # TODO - get different values per hardware
            weighted_flops = (
                flop_counts["grouped"] * GROUPED_MULTIPLIER
                + flop_counts["small"] * SMALL_MULTIPLIER
                + flop_counts["in_out"] * IN_OUT_MULTIPLIER
                + flop_counts["default"] * DEFAULT_MULTIPLIER
            )
            do_layout_opt = weighted_flops <= total_flops
            if not do_layout_opt:
                log.debug(
                    "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
                    total_flops,
                    weighted_flops,
                )
            return do_layout_opt

        # Channels last layout can dramatically hurt grouped conv perf. E.g.
        # Conv with arguments like
        # {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
        # "stride": [2, 2], "padding": [1, 1], "groups": 2}
        # slows down 31x using channels last..

        # But a lot of timm models use depthwise separable convolution which will
        # result in grouped convolution with in-channel size == 1.
        # For those grouped convolution, channels last still helps a lot.
        # E.g.
        # Conv with arguments
        # {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
        # "stride": [2, 2], "padding": [1, 1], "groups": 58}
        # get 1.86x speedup with channels last layout.
        #
        # The following heuristics skip using channels-last if the model contains
        # grouped convolution with in-channels > 1.
        if any(map(is_grouped, conv_nodes)):
            log.debug(
                "Skip layout opt because found grouped convolution with >1 in_channels!"
            )
            return False

        # For some models that contain convolution with larger in-channel than out-channel, applying
        # channels last hurts performance.
        # Following models are skipped due to this:
        # - pytorch_unet
        # - phlippe_densenet (slightly worse)
        # - Background_Matting (1.22x -> 0.821x)
        # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
        if any(map(is_in_out_channel, conv_nodes)):
            log.debug(
                "Skip layout opt because some convolutions have smaller out_channel"
            )
            return False

        # Following models are skipped due to this:
        # - functorch_maml_omniglot
        if all(map(is_small_channel, conv_nodes)):
            log.debug("Skip layout opt because all convolution channels are too small")
            return False

        return True
478
+
479
+ def qualify_name(self, name: str) -> str:
480
+ """Prepend the given name with the graph name if any."""
481
+ if self.name is not None:
482
+ return f"{self.name}_{name}"
483
+ return name
484
+
485
    def make_subgraph(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        subgraph_name: str,
    ) -> "GraphLowering":
        """
        Make a subgraph of the current graph with all inherited
        parts, except the graph module (`gm`) and `example_inputs`.
        The subgraphs are lowered separately, but intended to be
        inlined in the parent graph's codegening. Hence the need
        for maintaining the same `shape_env` and other properties.
        The subgraph name is qualified by the parent graph's name.

        Note: only the settings listed below are inherited; anything else
        (e.g. const_module, user_visible_outputs) falls back to the
        GraphLowering defaults.
        """
        return GraphLowering(
            gm=gm,
            example_inputs=example_inputs,
            shape_env=self._shape_env,
            cpp_wrapper=self.cpp_wrapper,
            aot_mode=self.aot_mode,
            extern_node_serializer=self.extern_node_serializer,
            is_inference=self.is_inference,
            name=self.qualify_name(subgraph_name),
        )
509
+
510
    def find_nodes_prefer_channels_last(self):
        """
        The rule to decide if an node prefer channels last is simple.
        1. if it's input/output of a convolution
        2. if one of its user prefers channels last

        We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs;
        Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers
        channels last.

        Consider the scenario: conv -> batch-norm -> relu -> conv
        Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies:
        1. the output of batch-norm should be channels last initially since its input is a conv's output.
           Forcing the batch-norm's output to be contiguous results in the first copy
        2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output.
           We need convert it to channels last layout which results in the second copy.
        With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies
        can be saved.

        Returns the set of fx nodes that prefer channels-last layout.
        """
        output_set = set()
        # Reverse topological order: a node's users are visited before the
        # node itself, so rule 2 propagates backwards in a single pass.
        for n in reversed(self.module.graph.nodes):
            if n.target == torch.ops.aten.convolution.default:
                output_set.add(n)
                continue

            for user in n.users:
                if user in output_set:
                    output_set.add(n)
                    break

        # need a second pass to add downstream nodes of those channel last nodes to the sets.
        # This pass is especially needed to avoid mix-layout kernel inputs in backward pass.
        #
        # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned
        # from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
        # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last
        # tensors and passed to a kernel.
        #
        # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x.
        # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x .
        # This also helps the following models:
        # - res2net101_26w_4s
        # - res2net50_14w_8s
        # - sebotnet33ts_256
        for n in self.module.graph.nodes:
            if n in output_set:
                for child in n.users:
                    output_set.add(child)

        return output_set
560
+
561
+ def warn_fallback(self, name):
562
+ if name not in self._warned_fallback:
563
+ self._warned_fallback.add(name)
564
+ perf_hint_log.info("Using FallbackKernel: %s", name)
565
+
566
    def add_device_info(self, device: torch.device):
        # Track every device type/index touched by the graph.
        self.device_types.add(device.type)
        if device.index is not None:
            self.device_idxs.add(device.index)
        # Remember the first fx node seen per device (goes through the global
        # V.graph rather than self — presumably so the active graph's current
        # node is used; confirm against virtualized.V semantics).
        if V.graph.current_node and device not in self.device_node_mapping:
            self.device_node_mapping[device] = V.graph.current_node
572
+
573
    @property
    def fake_mode(self):
        # Delegates to the globally-installed fake tensor mode (virtualized.V).
        return V.fake_mode
576
+
577
+ def get_buffer(self, buffer_name: str):
578
+ if buffer_name in self.name_to_buffer:
579
+ return self.name_to_buffer[buffer_name]
580
+ if buffer_name in self.graph_inputs:
581
+ return self.graph_inputs[buffer_name]
582
+ return None
583
+
584
+ def get_dtype(self, buffer_name: str):
585
+ if buffer_name in self.constants:
586
+ return self.constants[buffer_name].dtype
587
+ if buffer_name in self.name_to_buffer:
588
+ return self.name_to_buffer[buffer_name].get_dtype()
589
+ if buffer_name in self.graph_inputs:
590
+ return self.graph_inputs[buffer_name].get_dtype()
591
+ m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
592
+ if m:
593
+ return self.get_dtype(m.group(1))
594
+ raise KeyError(f"could not find {buffer_name}")
595
+
596
+ def get_numel(self, buffer_name: str):
597
+ from .ir import MultiOutputLayout
598
+
599
+ if buffer_name in self.constants:
600
+ return self.constants[buffer_name].numel()
601
+ if buffer_name in self.name_to_buffer:
602
+ buf = self.name_to_buffer[buffer_name]
603
+ if isinstance(getattr(buf, "layout", None), MultiOutputLayout):
604
+ return 1
605
+ return buf.get_numel()
606
+ if buffer_name in self.graph_inputs:
607
+ return self.graph_inputs[buffer_name].get_numel()
608
+ raise KeyError(f"could not find {buffer_name}")
609
+
610
    @dynamo_timed
    def run(self, *args):
        # Thin timed wrapper over torch.fx.Interpreter.run; dynamo_timed
        # records the lowering time for dynamo's compile-time stats.
        return super().run(*args)
613
+
614
    def register_buffer(self, buffer: ir.Buffer):
        """Assign the next sequential ``buf<N>`` name to ``buffer`` and record it."""
        name = self.qualify_name(f"buf{len(self.buffers)}")
        self.buffers.append(buffer)
        self.name_to_buffer[name] = buffer
        # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
        if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements():
            self.add_device_info(buffer.get_device())
        return name
622
+
623
+ def register_list(self, buffer_names: List[str]):
624
+ name = self.qualify_name("list_" + "_".join(buffer_names))
625
+ self.lists[name] = buffer_names
626
+ return name
627
+
628
    def register_users_of(self, node_output):
        """Record, per read buffer name, the IR values that read from it."""

        def register(value):
            # Recurse into containers; note there is deliberately no early
            # return, so a value that is both a sequence and an IRNode would
            # be handled by both branches (sequences are not IRNodes today).
            if isinstance(value, (list, tuple)):
                for x in value:
                    register(x)
            if isinstance(value, ir.IRNode):
                # Only values wrapped at least two levels deep
                # (value.data.data is an IRNode) are registered — presumably
                # this targets TensorBox(StorageBox(...)) shapes; other
                # IRNodes are skipped. TODO confirm intent.
                if (
                    not hasattr(value, "data")
                    or not isinstance(value.data, ir.IRNode)
                    or not (
                        hasattr(value.data, "data")
                        and isinstance(value.data.data, ir.IRNode)
                    )
                ):
                    return

                for read_name in value.get_read_names():
                    self.name_to_users[read_name].append(value)

        register(node_output)
648
+
649
+ def mark_buffer_mutated(self, name: str):
650
+ """
651
+ When a buffer is mutated we need to make sure all the reads to
652
+ the old version are realized before the mutation happens.
653
+ """
654
+ assert isinstance(name, str)
655
+ self.mutated_buffers.add(name)
656
+
657
+ if name not in self.name_to_users:
658
+ return
659
+
660
+ for user in self.name_to_users[name]:
661
+ user.realize()
662
+
663
    def add_tensor_constant(self, data, name=None):
        """
        Register ``data`` as a graph constant and return a TensorBox over a
        ConstantBuffer referring to it.  Identical constants are deduplicated
        (unless runtime constant folding is enabled for AOT inductor), and
        names are sanitized/uniquified.
        """

        def allocate(name):
            if not config.aot_inductor.use_runtime_constant_folding:
                # Dedup: reuse an existing constant that matches exactly in
                # size/stride/dtype/device and content.
                for constant_name, value in self.constants.items():
                    if (
                        not data.is_mkldnn
                        and data.size() == value.size()
                        and data.stride() == value.stride()
                        and data.dtype == value.dtype
                        and data.device == value.device
                        and torch.eq(data, value).all()
                    ):
                        return constant_name

            if name is None:
                name = f"constant{len(self.constants)}"
            if name[0].isdigit():
                name = f"constant_{name}"
            name = self.qualify_name(name)
            # We may generate a var name for each constant in the codegen.
            # Let's only keep sane characters.
            prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name)
            name = prefix
            cnt = 0
            # Uniquify against already-registered constants.
            while name in self.constants:
                name = f"{prefix}_{cnt}"
                cnt += 1
            self.constants[name] = data
            self.constant_reprs[name] = (
                f"{data.device!r} {data.dtype!r} "
                f"{tuple(data.size())!r} {tuple(data.stride())!r} "
                f"{hash(data):x}"
            )
            return name

        new_name = allocate(name)
        # Remember the original (pre-sanitization / pre-dedup) name.
        self.allocated_constant_name[new_name] = name

        return TensorBox.create(
            ir.ConstantBuffer(
                new_name,
                FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
            )
        )
707
+
708
    def constant_name(self, name: str, device_override: Optional[torch.device]):
        """
        We AOT copy constants to the devices they are needed on.
        If device_override doesn't match the constant's device, then
        copy it and return a different name.
        """
        # NOTE: the device comparison runs first, so an unknown `name`
        # raises KeyError even when device_override is None.
        if self.constants[name].device == device_override or device_override is None:
            return name
        # `index or 0` maps both None and 0 to 0 in the generated name.
        alt_name = f"{name}_{device_override.type}{device_override.index or 0}"
        if alt_name not in self.constants:
            self.constants[alt_name] = self.constants[name].to(device_override)
        return alt_name
720
+
721
    def placeholder(self, target: str, args, kwargs):
        """
        Lower a graph input.  Symbolic/scalar inputs become sympy
        expressions, BackwardState is ignored, and tensors become
        TensorBox-wrapped InputBuffers with static or symbolic layout.
        """
        example = super().placeholder(target, args, kwargs)
        self.graph_input_names.append(target)
        if isinstance(example, SymTypes):
            expr = example.node.expr
            self.graph_inputs[target] = expr
            return expr
        elif isinstance(example, (int, bool, float)):
            expr = sympy.sympify(example)
            self.graph_inputs[target] = expr
            return expr
        if isinstance(example, BackwardState):
            # Ignored arg, must be unused
            # Alternately we could filter this out in AotAutograd
            return None
        assert isinstance(example, torch.Tensor), example
        # todo(chilli): We can remove the last check once we turn buffers into
        # static shape tensors. That's a hack to workaround Inductor believing
        # the buffer should be static but us passing in a fake tensor with
        # symbolic shapes.
        if not example._has_symbolic_sizes_strides:
            # the first N inputs are weights
            sizes, strides = self.static_sizes_strides(example)
        else:
            sizes, strides = self.symbolic_sizes_strides(example)
        # TODO(jansel): handle input aliasing
        target = self.qualify_name(target)
        tensor = TensorBox.create(
            InputBuffer(
                target,
                FixedLayout(example.device, example.dtype, sizes, strides),
            )
        )
        self.graph_inputs[target] = tensor
        # Keep a handle on the raw InputBuffer (tensor.data.data) before any
        # later reinterpretation wraps it.
        self.graph_inputs_original[target] = tensor.data.data
        self.add_device_info(example.device)
        return tensor
758
+
759
    def call_function(self, target, args, kwargs):
        """
        Lower a call_function fx node via the registered lowering for
        ``target``, creating an implicit fallback when allowed.  Any
        exception from the lowering is re-wrapped as LoweringException.
        """
        # getitem on python containers is plain indexing, not a lowering.
        if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
            return super().call_function(target, args, kwargs)

        if hasattr(target, "_inductor_lowering_function"):
            # passthrough lowerings from .pattern_matcher
            return target(*args, **kwargs)

        def get_custom_op_layout_constraints(target, args, kwargs):
            # Custom operations that require preserving stride order
            # which run through implicit fallback must constrain their
            # arguments' fx strides
            layout_constraint = None
            if torch._C.Tag.needs_fixed_stride_order in target.tags:
                # We have to set the current args because call_function will immediately
                # evaluate this lowering after creating the fallback, without evaluating
                # the layout constraint
                args, kwargs = constrain_to_fx_strides(
                    self.current_node, *args, **kwargs
                )
                # Also register the layout constraint so when the fallback
                # is used again, we can constrain the args to the same layout
                layout_constraint = constrain_to_fx_strides
            return layout_constraint, args, kwargs

        if target not in lowerings:
            assert isinstance(
                target, torch._ops.OpOverload
            ), f"{target} is not an OpOverload"
            base_name = target.name().split(".")[0]
            if base_name in FALLBACK_ALLOW_LIST:
                make_fallback(target)
            elif config.implicit_fallbacks:
                layout_constraint, args, kwargs = get_custom_op_layout_constraints(
                    target, args, kwargs
                )
                error = (
                    MissingOperatorWithDecomp
                    if get_decompositions([target])
                    else MissingOperatorWithoutDecomp
                )
                log.info(
                    "Creating implicit fallback for:\n%s",
                    error.operator_str(target, args, kwargs),
                )
                # Registers a fallback lowering; the call below then uses it.
                make_fallback(target, layout_constraint)

            elif get_decompositions([target]):
                # There isn't a good way to dynamically patch this in
                # since AOT Autograd already ran. The error message tells
                # the user how to fix it.
                raise MissingOperatorWithDecomp(target, args, kwargs)
            else:
                raise MissingOperatorWithoutDecomp(target, args, kwargs)

        try:
            log.debug(" via %s", lowerings[target])
            out = lowerings[target](*args, **kwargs)
            return out
        except Exception as e:
            # Preserve the original traceback but surface a LoweringException
            # carrying the op and its arguments.
            raise LoweringException(e, target, args, kwargs).with_traceback(
                e.__traceback__
            ) from None
822
+
823
+ @staticmethod
824
+ def can_inline_constant(t: torch.Tensor) -> bool:
825
+ """
826
+ True if this is a small constant attr that will be inlined.
827
+ """
828
+ return len(t.shape) == 1 and t.shape[0] <= 8
829
+
830
+ def get_attr(self, target, args, kwargs):
831
+ # this is a constant
832
+ value = getattr_recursive(self.module, target)
833
+
834
+ if isinstance(value, torch.fx.GraphModule):
835
+ return ir.Subgraph(name=target, graph_module=value)
836
+
837
+ if (
838
+ config.aot_inductor.use_runtime_constant_folding
839
+ or config.always_keep_tensor_constants
840
+ or unsupported_output_tensor(value)
841
+ ):
842
+ return self.add_tensor_constant(value, target)
843
+
844
+ with no_dispatch():
845
+ if value.shape == ():
846
+ return Constant(value.item(), value.dtype, value.device)
847
+ if self.can_inline_constant(value):
848
+ # tensor lowering has constant inlining logic
849
+ from .lowering import tensor
850
+
851
+ return tensor(value.tolist(), dtype=value.dtype, device=value.device)
852
+
853
+ return self.add_tensor_constant(value, target)
854
+
855
+ def call_module(self, target, args, kwargs):
856
+ raise AssertionError()
857
+
858
+ def call_method(self, target, args, kwargs):
859
+ raise AssertionError()
860
+
861
+ def output(self, target, args, kwargs):
862
+ result = super().output(target, args, kwargs)
863
+ assert isinstance(result, (tuple, list)), type(result)
864
+ assert all(
865
+ isinstance(
866
+ x,
867
+ (
868
+ TensorBox,
869
+ ir.Constant,
870
+ type(None),
871
+ ir.ConstantBuffer,
872
+ sympy.Expr,
873
+ sympy.logic.boolalg.Boolean,
874
+ int,
875
+ ),
876
+ )
877
+ for x in result
878
+ ), result
879
+ self.graph_outputs = [ir.ExternKernel.realize_input(x) for x in result]
880
+ value: ir.IRNode
881
+ for name, value in self.graph_inputs.items():
882
+ assert isinstance(
883
+ value, (TensorBox, sympy.Expr)
884
+ ), f"Unsupported inductor graph input type: {type(value)}"
885
+ if not isinstance(value, TensorBox):
886
+ continue
887
+ value.realize()
888
+ assert isinstance(value, TensorBox)
889
+ value = value.data
890
+ assert isinstance(value, ir.StorageBox)
891
+ value_storage_box = value
892
+ value = value.data
893
+ if not isinstance(value, InputBuffer) or value.get_name() != name:
894
+ # one of our inputs was mutated, need to turn that into a copy
895
+ ir.MutationLayout.realize_into(value, self.graph_inputs_original[name])
896
+ # replace output with mutated input
897
+ try:
898
+ ind = self.graph_outputs.index(value_storage_box)
899
+ self.graph_outputs[ind] = self.graph_inputs_original[name]
900
+ except ValueError:
901
+ pass
902
+
903
+ self.finalize()
904
+ log.debug(
905
+ "Force channels last inputs for %d conv for the current graph with id %d",
906
+ self.num_channels_last_conv,
907
+ self.graph_id if self.graph_id is not None else -1,
908
+ )
909
+
910
+ def finalize(self):
911
+ for buf in self.buffers:
912
+ buf.decide_layout()
913
+
914
+ @contextmanager
915
+ def set_current_node(self, node: torch.fx.Node):
916
+ old = self.current_node
917
+ try:
918
+ self.current_node = node
919
+ yield
920
+ finally:
921
+ self.current_node = old
922
+
923
+ def run_node(self, n: torch.fx.Node):
924
+ def debug(msg):
925
+ log.debug("lowering %s %s", LazyString(n.format_node), msg)
926
+
927
+ origins = {n}
928
+ if n.op == "call_function":
929
+ args, kwargs = self.fetch_args_kwargs_from_env(n)
930
+ origins |= gather_origins(args, kwargs)
931
+ with ir.IRNode.current_origins(origins), self.set_current_node(
932
+ n
933
+ ), V.set_current_node(n):
934
+ if (
935
+ n.op == "call_function"
936
+ and n.target is not operator.getitem
937
+ and fallback_node_due_to_unsupported_type(n)
938
+ ):
939
+ debug("fallback_handler")
940
+ result = fallback_handler(n.target, add_to_fallback_set=False)(
941
+ *args, **kwargs # type: ignore[possibly-undefined]
942
+ )
943
+ elif n.op == "call_function" and n.target in layout_constraints:
944
+ debug("layout_constraints")
945
+ args, kwargs = layout_constraints[n.target](n, *args, **kwargs) # type: ignore[index]
946
+ result = self.call_function(n.target, args, kwargs)
947
+ elif is_magic_method(n.target):
948
+ # TODO: this is sus, it probably should be handled in the
949
+ # lowerings themselves similarly to sym_size/sym-stride
950
+ debug("is_magic_method")
951
+ if isinstance(n.meta["val"], torch.SymInt):
952
+ result = n.meta["val"].node.expr
953
+ else:
954
+ result = super().run_node(n)
955
+ else:
956
+ debug("")
957
+ result = super().run_node(n)
958
+
959
+ # require the same stride order for dense outputs,
960
+ # 1. user-land view() will not throw because inductor
961
+ # output different strides than eager
962
+ # long term the solution is to make view() always succeed
963
+ # with infallible strides.
964
+ # 2: as_strided ops, we need make sure its input has same size/stride with
965
+ # eager model to align with eager behavior.
966
+ as_strided_ops = [
967
+ torch.ops.aten.as_strided.default,
968
+ torch.ops.aten.as_strided_.default,
969
+ torch.ops.aten.as_strided_scatter.default,
970
+ ]
971
+ is_output = any(user.op == "output" for user in n.users)
972
+ is_input_for_as_strided = any(
973
+ user.target in as_strided_ops for user in n.users
974
+ )
975
+ if (
976
+ is_output
977
+ and isinstance(result, TensorBox)
978
+ and isinstance(result.data, ir.BaseView)
979
+ ):
980
+ # Realize so that outputs are correctly aliased
981
+ result.realize()
982
+
983
+ if (is_output or is_input_for_as_strided) and isinstance(
984
+ n.meta["val"], torch.Tensor
985
+ ):
986
+ strides = n.meta["val"].stride()
987
+ dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
988
+ # requiring a stride order for a non-dense output wouldn't
989
+ # recreate the same strides, and would fail with view, defer for now.
990
+ if dense and len(strides):
991
+ stride_order = ir.get_stride_order(strides)
992
+ if (
993
+ len(result.get_size()) == 4
994
+ and n in self.nodes_prefer_channels_last
995
+ and n.name not in self.user_visible_outputs
996
+ and not is_input_for_as_strided
997
+ ):
998
+ stride_order = ir.NHWC_STRIDE_ORDER
999
+ result = ir.ExternKernel.require_stride_order(result, stride_order)
1000
+
1001
+ # Realize if (1) any user need inputs realized, or (2) there is
1002
+ # already too many reads and rematerializing can be bad.
1003
+ num_users = len(set(n.users))
1004
+ if num_users > 1 and isinstance(result, TensorBox):
1005
+ for user in n.users:
1006
+ if user.target in needs_realized_inputs:
1007
+ result.realize_hint()
1008
+ # This inclusion is somewhat controversial (from
1009
+ # discussion between Horace, Natalia, and Elias).
1010
+ # Currently, it's not very clear why this is helpful.
1011
+ # The general idea here is that even though a node may
1012
+ # have FlexibleLayout, we still often *treat* it as if
1013
+ # it was contiguous. This appears to sometimes result in
1014
+ # suboptimal behavior.
1015
+ #
1016
+ # When we do a better job selecting layout, we should
1017
+ # revisit this.
1018
+ need_fixed_layout = [
1019
+ torch.ops.aten.convolution_backward.default,
1020
+ torch.ops.aten.mm.default,
1021
+ torch.ops.aten._int_mm.default,
1022
+ ]
1023
+ if not self.layout_opt:
1024
+ need_fixed_layout.append(torch.ops.aten.convolution.default)
1025
+ if torch._C._has_mkldnn:
1026
+ need_fixed_layout += [
1027
+ torch.ops.mkldnn._convolution_pointwise.default,
1028
+ torch.ops.mkldnn._convolution_pointwise.binary,
1029
+ torch.ops.mkldnn._convolution_pointwise_.binary,
1030
+ torch.ops.mkldnn._convolution_transpose_pointwise.default,
1031
+ torch.ops.mkldnn._linear_pointwise.default,
1032
+ torch.ops.mkldnn._linear_pointwise.binary,
1033
+ torch.ops.aten.mkldnn_rnn_layer.default,
1034
+ torch.ops.onednn.qconv2d_pointwise.default,
1035
+ torch.ops.onednn.qconv2d_pointwise.binary,
1036
+ torch.ops.onednn.qlinear_pointwise.default,
1037
+ torch.ops.onednn.qlinear_pointwise.tensor,
1038
+ ]
1039
+ if torch._C.has_mkl:
1040
+ need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
1041
+ if user.target in need_fixed_layout:
1042
+ result = ir.ExternKernel.require_stride_order(
1043
+ result, ir.get_stride_order(n.meta["val"].stride())
1044
+ )
1045
+ if user.op == "output":
1046
+ if isinstance(result.data.data, (Pointwise, Reduction)):
1047
+ result.realize()
1048
+
1049
+ # TODO(jansel): introduce a store vs inline choice
1050
+ result.mark_reuse(len(n.users))
1051
+
1052
+ # Realize if the IRNode already has accumulated lots of reads
1053
+ if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
1054
+ # Prevent excessive accumulation in a computed buffer, when
1055
+ # there are multiple branches each with small number of memory
1056
+ # reads, but they converge to a user.
1057
+ result.realize_hint()
1058
+
1059
+ # Realize if a Pointwise has too much stuff to be inlined.
1060
+ # As this may cause RecursionError during Inductor's evaluation.
1061
+ if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
1062
+ curr = result.data.data
1063
+ if isinstance(curr, Pointwise):
1064
+ # Use inner fn as a rough proxy. Good enough.
1065
+ if curr.has_large_inner_fn():
1066
+ result.realize()
1067
+
1068
+ # This is not complete, but it doesn't have to be: origin_node
1069
+ # tracking is best effort. The logic here critically relies on direct
1070
+ # TensorBox -> StorageBox denoting a non-view; we don't bother trying
1071
+ # to get views to work. Feel free to add any extra cases as needed.
1072
+ #
1073
+ # Note: we can't YOLO tree_map over this result, because if there are
1074
+ # buffers or a view involved, we might not be able to validly assign
1075
+ # the origin_node here.
1076
+ if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
1077
+ if isinstance(result.data.data, ir.Loops):
1078
+ result.data.data.origin_node = n
1079
+ elif isinstance(result.data.data, ir.Buffer):
1080
+ result.data.data.origin_node = n
1081
+ if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
1082
+ result.data.data.data, ir.Loops
1083
+ ):
1084
+ result.data.data.data.origin_node = n
1085
+ # Not really multi-output, can straightforwardly recurse in
1086
+ elif (
1087
+ isinstance(result.data.data, ir.MultiOutput)
1088
+ and not result.data.data.indices
1089
+ ):
1090
+ if isinstance(result.data.data.inputs[0], ir.Buffer):
1091
+ result.data.data.inputs[0].origin_node = n
1092
+
1093
+ self.register_users_of(result)
1094
+
1095
+ return result
1096
+
1097
+ def validate_can_generate_cpp_wrapper(self):
1098
+ if config.disable_cpp_codegen:
1099
+ raise CppWrapperCodeGenError("C++ codegen is disabled")
1100
+
1101
+ if sys.platform not in ["linux", "darwin"]:
1102
+ raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}")
1103
+
1104
+ for value in self.graph_inputs.values():
1105
+ dtype = None
1106
+ if isinstance(value, TensorBox):
1107
+ dtype = value.get_dtype()
1108
+ elif isinstance(
1109
+ value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
1110
+ ):
1111
+ dtype = may_get_constant_buffer_dtype(value)
1112
+
1113
+ if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
1114
+ raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")
1115
+
1116
+ def init_wrapper_code(self):
1117
+ self.cuda = "cuda" in self.device_types
1118
+ if self.cpp_wrapper:
1119
+ self.validate_can_generate_cpp_wrapper()
1120
+ self.wrapper_code = CppWrapperCuda() if self.cuda else CppWrapperCpu()
1121
+ else:
1122
+ device_types = self.device_types.copy()
1123
+ device_types.discard("cpu")
1124
+ # TODO(Eikan): Only support mixing cpu and other device now.
1125
+ assert len(device_types) <= 1, "Does not support mixing {}".format(
1126
+ "+".join(device_types)
1127
+ )
1128
+ only_cpu = len(device_types) == 0
1129
+ device_type = "cpu" if only_cpu else device_types.pop()
1130
+
1131
+ self.device_ops = get_device_op_overrides(device_type)
1132
+ wrapper_code_gen_cls = get_wrapper_codegen_for_device(device_type)
1133
+ assert (
1134
+ wrapper_code_gen_cls is not None
1135
+ ), f"Device {device_type} not supported"
1136
+ self.wrapper_code = wrapper_code_gen_cls()
1137
+
1138
+ if self.const_module:
1139
+ # If we have const module, we could reuse the kernels
1140
+ # This could avoid duplication and save time on doing recompilation (if Triton.)
1141
+ self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
1142
+ self.wrapper_code.src_to_kernel = (
1143
+ self.const_module.wrapper_code.src_to_kernel
1144
+ )
1145
+
1146
+ def codegen_with_cpp_wrapper(self):
1147
+ """
1148
+ For CPU, the cpp wrapper codegen is done in one pass.
1149
+ For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
1150
+ wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
1151
+ generate cpp wrapper code and compile it to a dynamic library in the second pass.
1152
+ """
1153
+ if "cuda" in self.device_types:
1154
+ # first pass
1155
+ self.cpp_wrapper = False
1156
+ compiled = self.compile_to_module().call
1157
+
1158
+ def materialize(x):
1159
+ if isinstance(x, (torch.SymInt, torch.SymFloat)):
1160
+ # Need concrete value to run dynamic shapes and tune the result
1161
+ return x.node.hint
1162
+ elif isinstance(x, FakeTensor):
1163
+ return defake(x)
1164
+ else:
1165
+ assert isinstance(
1166
+ x, torch.Tensor
1167
+ ), "Unknown type when creating real inputs" + str(type(x))
1168
+ return x
1169
+
1170
+ if tracing_context := torch._guards.TracingContext.try_get():
1171
+ if tracing_context.output_strides:
1172
+ tracing_context.output_strides.clear()
1173
+
1174
+ params_flat = [
1175
+ param
1176
+ for param in tracing_context.params_flat # type: ignore[union-attr]
1177
+ if param is not None
1178
+ ]
1179
+ real_inputs = [
1180
+ materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
1181
+ ]
1182
+ else:
1183
+ real_inputs = [materialize(x) for x in V.real_inputs]
1184
+
1185
+ with torch.utils._python_dispatch._disable_current_modes():
1186
+ assert self.example_inputs is not None
1187
+ compiled(real_inputs)
1188
+ del real_inputs
1189
+
1190
+ # second pass
1191
+ # TODO: reuse self.scheduler from the first pass to speed up the second pass
1192
+ self.cpp_wrapper = True
1193
+ self.removed_buffers.clear()
1194
+ self.inplaced_to_remove.clear()
1195
+ return self.codegen()
1196
+ else:
1197
+ # cpu
1198
+ return self.codegen()
1199
+
1200
+ def codegen(self):
1201
+ from .scheduler import Scheduler
1202
+
1203
+ self.init_wrapper_code()
1204
+
1205
+ self.scheduler = Scheduler(self.buffers)
1206
+ V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
1207
+
1208
+ self.scheduler.codegen()
1209
+ return self.wrapper_code.generate(self.is_inference)
1210
+
1211
+ def codegen_subgraph(self, parent_graph):
1212
+ """
1213
+ This is a more compact version of the `codegen()` above
1214
+ where we codegen this graph as a subgraph of some parent
1215
+ graph. The parent graph is passed as an argument: the
1216
+ intention is to inline codegening of the subgraph in
1217
+ the parent graph's wrapper code (including the generated
1218
+ kerenls). The wrapper code is not finalized (via `.generate()`
1219
+ call), as this will be done in the parent graph's `codegen()`.
1220
+ """
1221
+ from .scheduler import Scheduler
1222
+
1223
+ self.wrapper_code = parent_graph.wrapper_code
1224
+ self.device_ops = parent_graph.device_ops
1225
+ self.cpp_wrapper = parent_graph.cpp_wrapper
1226
+
1227
+ self.scheduler = Scheduler(self.buffers)
1228
+ self.scheduler.codegen()
1229
+
1230
+ def count_bytes(self):
1231
+ from .scheduler import Scheduler
1232
+
1233
+ scheduler = Scheduler(self.buffers)
1234
+
1235
+ total_bytes = 0
1236
+ node_counts = []
1237
+ node_runtimes = []
1238
+ for node in scheduler.nodes:
1239
+ num_bytes = node.get_read_write_buffers_sizes()
1240
+ total_bytes += num_bytes
1241
+ node_counts.append((node, num_bytes // 4))
1242
+ node_runtimes.append((node, node.get_estimated_runtime()))
1243
+ return total_bytes, node_counts, node_runtimes
1244
+
1245
+ @dynamo_timed(phase_name="code_gen")
1246
+ def compile_to_module(self):
1247
+ from .codecache import PyCodeCache
1248
+
1249
+ code, linemap = (
1250
+ self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
1251
+ )
1252
+ linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
1253
+ key, path = PyCodeCache.write(code)
1254
+ mod = PyCodeCache.load_by_key_path(
1255
+ key, path, linemap=linemap, attrs=self.constants
1256
+ )
1257
+ self.cache_key = key
1258
+ self.cache_path = path
1259
+ self.cache_linemap = linemap
1260
+
1261
+ # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
1262
+ # TODO. Revisit this once the logging API is more mature
1263
+ assert mod.__file__ is not None
1264
+
1265
+ log_module_code(mod.__file__)
1266
+ log.debug("Output code written to: %s", mod.__file__)
1267
+ output_code_log.debug("Output code: \n%s", code)
1268
+ trace_structured(
1269
+ "inductor_output_code",
1270
+ lambda: {"filename": mod.__file__},
1271
+ payload_fn=lambda: code,
1272
+ )
1273
+ output_code_log.info("Output code written to: %s", mod.__file__)
1274
+ if config.benchmark_kernel:
1275
+ print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
1276
+ V.debug.output_code(mod.__file__)
1277
+ V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
1278
+ return mod
1279
+
1280
+ def compile_to_fn(self):
1281
+ if self.aot_mode:
1282
+ from .codecache import AotCodeCompiler
1283
+
1284
+ assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
1285
+ code, linemap = self.codegen_with_cpp_wrapper()
1286
+ output_code_log.debug("Output code: \n%s", code)
1287
+
1288
+ serialized_extern_kernel_nodes = None
1289
+ if (
1290
+ config.is_fbcode()
1291
+ and self.extern_kernel_nodes
1292
+ and self.extern_node_serializer
1293
+ ):
1294
+ serialized_extern_kernel_nodes = self.extern_node_serializer(
1295
+ self.extern_kernel_nodes
1296
+ )
1297
+ output_code_log.debug(
1298
+ "Serialized Extern Kernel Nodes: \n%s",
1299
+ serialized_extern_kernel_nodes,
1300
+ )
1301
+
1302
+ # Directly return the file path with the compiled code
1303
+ return AotCodeCompiler.compile(
1304
+ self, code, serialized_extern_kernel_nodes, cuda=self.cuda
1305
+ )
1306
+ else:
1307
+ return self.compile_to_module().call
1308
+
1309
+ def get_output_names(self):
1310
+ return [
1311
+ node.get_name()
1312
+ for node in self.graph_outputs
1313
+ if not isinstance(node, ir.NoneAsConstantBuffer)
1314
+ and not isinstance(node, ir.ShapeAsConstantBuffer)
1315
+ ]
1316
+
1317
+ def is_unspec_arg(self, name: str):
1318
+ # dynamo wraps unspec variable as 0d CPU tensor,
1319
+ # need to convert to scalar during codegen (triton only)
1320
+ return (
1321
+ name in self.graph_inputs.keys()
1322
+ and self.graph_inputs[name].get_numel() == 1
1323
+ and self.graph_inputs[name].get_device().type == "cpu"
1324
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py ADDED
@@ -0,0 +1,1524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import functools
5
+ import inspect
6
+ import itertools
7
+ import logging
8
+ import operator
9
+ import os
10
+ import re
11
+ from collections import defaultdict
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ DefaultDict,
16
+ Dict,
17
+ Iterable,
18
+ List,
19
+ NoReturn,
20
+ Optional,
21
+ Set,
22
+ Union,
23
+ )
24
+
25
+ from typing_extensions import TypeGuard
26
+
27
+ import torch
28
+ import torch._guards
29
+ import torch.fx
30
+ import torch.utils._pytree as pytree
31
+ from torch._dispatch.python import enable_python_dispatcher
32
+ from torch._dynamo.utils import counters
33
+ from torch._prims_common import is_integer_dtype
34
+ from torch.fx import Node
35
+ from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
36
+ from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
37
+ from torch.fx.immutable_collections import immutable_dict, immutable_list
38
+
39
+ from .._functorch import config as functorch_config
40
+ from .._functorch.aot_autograd import aot_function, make_boxed_func
41
+ from .._functorch.partitioners import default_partition
42
+ from .._subclasses import FakeTensorMode
43
+ from ..fx import Transformer
44
+ from . import config
45
+ from .decomposition import select_decomp_table
46
+ from .lowering import fallback_node_due_to_unsupported_type
47
+
48
+ log = logging.getLogger(__name__)
49
+ aten = torch.ops.aten
50
+ prims = torch.ops.prims
51
+
52
+ Constant = Any
53
+ NodeOrConstant = Union[Constant, torch.fx.Node]
54
+
55
+
56
+ class Multiple:
57
+ pass
58
+
59
+
60
+ # Sentinel indicating multiple quantities can be matched
61
+ MULTIPLE = Multiple()
62
+
63
+
64
+ class Match:
65
+ """
66
+ Represents a successfully matched pattern.
67
+ """
68
+
69
+ def __init__(self, pattern: PatternExpr, args=None, kwargs=None):
70
+ super().__init__()
71
+ self.pattern = pattern
72
+ # The input nodes that must be passed in to the result
73
+ self.args = args or []
74
+ self.kwargs = kwargs or {}
75
+ # The nodes matched in this expression
76
+ self.nodes: List[torch.fx.Node] = []
77
+ # Mapping CallFunction to the node.target
78
+ self.targets: Dict[_TargetExpr, torch.fx.node.Target] = {}
79
+ self.ctx: Optional[MatchContext] = None
80
+ self.replacement_graph: Optional[torch.fx.Graph] = None
81
+
82
+ @property
83
+ def graph(self) -> torch.fx.Graph:
84
+ assert self.ctx
85
+ return self.ctx.graph
86
+
87
+ def extend(self, other: Match):
88
+ if self.kwargs:
89
+ for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
90
+ if self.kwargs[key] != other.kwargs[key]:
91
+ raise FailedMatch("kwarg mismatch: {}", key)
92
+ self.args.extend(other.args)
93
+ self.nodes.extend(other.nodes)
94
+ self.kwargs.update(other.kwargs)
95
+ self.targets.update(other.targets)
96
+
97
+ def bundle(self) -> Match:
98
+ # Wrap args in an extra list
99
+ self.args = [tuple(self.args)] if self.args else []
100
+ return self
101
+
102
+ def __repr__(self):
103
+ return f"Match(..., {self.args}, {self.kwargs})"
104
+
105
+ def erase_nodes(self, graph: torch.fx.Graph):
106
+ for n in reversed(self.nodes):
107
+ if not n._erased:
108
+ graph.erase_node(n)
109
+
110
+ def output_nodes(self) -> List[Optional[torch.fx.Node]]:
111
+ assert self.ctx
112
+ return [
113
+ (self.ctx.pattern_to_node[p] if p is not None else None)
114
+ for p in self.ctx.outputs
115
+ ]
116
+
117
+ def output_node(self) -> torch.fx.Node:
118
+ return next(p for p in self.output_nodes() if p)
119
+
120
+ def replace_with_graph(self, replacement_graph, args):
121
+ assert self.ctx
122
+ ReplacementPatternEntry.replace_with_graph(
123
+ self, self.ctx.graph, replacement_graph, args
124
+ )
125
+
126
+ def replace_by_example(self, replacement_fn, args, trace_fn=None, run_dce=True):
127
+ assert self.ctx
128
+ if trace_fn is None:
129
+ trace_fn = functools.partial(fwd_only, run_dce=run_dce)
130
+ replacement = trace_fn(
131
+ replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
132
+ )
133
+ ReplacementPatternEntry.replace_with_graph(
134
+ self,
135
+ self.ctx.graph,
136
+ replacement,
137
+ args,
138
+ )
139
+
140
+
141
+ class FailedMatch(RuntimeError):
142
+ def __init__(self, format_string, *args, **kwargs):
143
+ self.format_string = format_string
144
+ # We want to construct error messages lazily instead of eagerly, as
145
+ # constructing them eagerly can significantly worsen compile times.
146
+ if len(format_string) > 200:
147
+ raise RuntimeError(
148
+ f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}"
149
+ )
150
+ self.args = args
151
+ self.kwargs = kwargs
152
+
153
+ def __str__(self):
154
+ return self.format_string.format(*self.args, **self.kwargs)
155
+
156
+ def __bool__(self):
157
+ return False
158
+
159
+
160
+ def is_match(m: Union[Match, FailedMatch]) -> TypeGuard[Match]:
161
+ """
162
+ TypeGuards cannot act on `self`. Thus this function exists to let mypy
163
+ recognize FailedMatch.__bool__ as a TypeGuard.
164
+ """
165
+ return bool(m)
166
+
167
+
168
+ class MatchContext:
169
+ """
170
+ State needed while running PatternExpr._match().
171
+ """
172
+
173
+ def __init__(
174
+ self,
175
+ outputs: List[Optional[PatternExpr]],
176
+ pattern_to_node: Optional[Dict[PatternExpr, Node]] = None,
177
+ *,
178
+ graph: torch.fx.Graph,
179
+ ):
180
+ self.outputs = outputs
181
+ self.pattern_to_node = {} if pattern_to_node is None else pattern_to_node
182
+ self.graph = graph
183
+ self.exclusive_node_set: List[NodeOrConstant] = []
184
+
185
+ def match(self, pattern, node):
186
+ """wrapper to check reused nodes in patterns"""
187
+ if pattern in self.pattern_to_node:
188
+ if self.pattern_to_node[pattern] == node:
189
+ return Match(pattern) # already checked this node
190
+ else:
191
+ return FailedMatch("repeated pattern differs")
192
+ m = pattern._match(node, self)
193
+ assert pattern not in self.pattern_to_node
194
+ self.pattern_to_node[pattern] = node if m else None
195
+ m.ctx = self
196
+ return m
197
+
198
+ def filter_multi_user_patterns(self):
199
+ return {
200
+ pattern: node
201
+ for pattern, node in self.pattern_to_node.items()
202
+ if pattern.has_multiple_users() and node is not None
203
+ }
204
+
205
+
206
+ class PatternExpr:
207
+ """
208
+ Base class for types of patterns
209
+ """
210
+
211
+ def _match(
212
+ self, node: torch.fx.Node, ctx: MatchContext
213
+ ) -> Union[Match, FailedMatch]:
214
+ raise NotImplementedError()
215
+
216
+ def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]:
217
+ try:
218
+ return MatchContext([self], graph=node.graph).match(self, node)
219
+ except FailedMatch as e:
220
+ return e
221
+
222
+ def has_multiple_users(self) -> bool:
223
+ return False
224
+
225
+ def __repr__(self):
226
+ return self.__class__.__name__ + "()"
227
+
228
+ def find_anchor_nodes(self, ctx: MatchContext, searched):
229
+ if self in ctx.pattern_to_node:
230
+ yield ctx.pattern_to_node[self]
231
+
232
+
233
+ class Arg(PatternExpr):
234
+ """
235
+ Capture an arg which will become an input to the handler. Args are
236
+ passed in depth first order.
237
+ """
238
+
239
+ def _match(self, node: NodeOrConstant, ctx: MatchContext):
240
+ return Match(self, args=[node]) # matches anything
241
+
242
+
243
+ class Ignored(PatternExpr):
244
+ """
245
+ Match an arg, but don't pass it to handler
246
+ """
247
+
248
+ def _match(self, node: NodeOrConstant, ctx: MatchContext):
249
+ return Match(self) # matches anything
250
+
251
+ def __repr__(self):
252
+ return "*"
253
+
254
+ def pretty_print(self, pp: PatternPrettyPrinter):
255
+ return "Ignored()"
256
+
257
+
258
+ class KeywordArg(PatternExpr):
259
+ """
260
+ Capture a kwarg which will become an input to the handler.
261
+ """
262
+
263
+ def __init__(self, name: str):
264
+ super().__init__()
265
+ self.name = name
266
+
267
+ def __repr__(self):
268
+ return f"KeywordArg({self.name!r})"
269
+
270
+ def _match(self, node: NodeOrConstant, ctx: MatchContext):
271
+ return Match(self, kwargs={self.name: node}) # matches anything
272
+
273
+
274
+ class ExclusiveKeywordArg(PatternExpr):
275
+ """
276
+ Capture a kwarg which will become an input to the handler.
277
+ """
278
+
279
+ def __init__(self, name):
280
+ super().__init__()
281
+ self.name = name
282
+
283
+ def __repr__(self):
284
+ return f"ExclusiveKeywordArg({self.name!r})"
285
+
286
+ def _match(self, node: NodeOrConstant, ctx: MatchContext):
287
+ if node in ctx.exclusive_node_set:
288
+ return FailedMatch("exclusive arg appears twice")
289
+
290
+ ctx.exclusive_node_set.append(node)
291
+ return Match(self, kwargs={self.name: node}) # matches anything
292
+
293
+
294
class _TargetExpr(PatternExpr):
    """
    Base class for filtering match by node.target (and user count).

    Subclasses must set ``op`` to the FX opcode they match
    ("call_function", "call_method", or "call_module"); this class is
    abstract and cannot be instantiated directly.
    """

    op: Optional[str] = None

    def __init__(self, fns, users=1):
        if not self.op:
            # Fix: the original message referenced a stale class name
            # ("_BaseNodeMatch") that does not exist in this file.
            raise NotImplementedError("Shouldn't directly use _TargetExpr")
        super().__init__()
        # Accept a single callable/string or an iterable of them.
        fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
        for fn in list(fns):
            if isinstance(fn, torch._ops.OpOverloadPacket):
                # Expand an overload packet into its concrete overloads so
                # that any of them will match.
                fns.extend([getattr(fn, overload) for overload in fn.overloads()])

        self.fns: List[Union[Callable[..., Any], str]] = fns
        self.fns_set: Set[Union[Callable[..., Any], str]] = set(fns)
        self.users: Union[int, Multiple] = users

    def fns_repr(self) -> str:
        """Short human-readable name for the matched target(s)."""
        first_repr = self.fns[0]
        if not isinstance(first_repr, str):
            first_repr = first_repr.__name__

        if len(self.fns) > 1:
            return f"[{first_repr}, ...]"
        elif self.fns[0] is getattr(torch, first_repr, None):
            return f"torch.{first_repr}"
        elif isinstance(self.fns[0], torch._ops.OpOverload):
            return str(self.fns[0])
        else:
            return first_repr

    def __repr__(self):
        return f"{self.__class__.__name__}({self.fns_repr()})"

    def has_multiple_users(self) -> bool:
        return isinstance(self.users, Multiple) or self.users > 1

    def find_anchor_nodes(self, ctx: MatchContext, searched):
        raise NotImplementedError()

    def _match_fns(self, node: torch.fx.Node):
        # True when `node` is an FX node of our opcode whose target is one
        # of the registered fns.
        return (
            isinstance(node, torch.fx.Node)
            and node.op == self.op
            and extract_target(node) in self.fns_set
        )

    def _match_users(self, node: torch.fx.Node, ctx: MatchContext):
        # The user-count constraint is waived for pattern outputs and for
        # MULTIPLE; otherwise the node must have exactly `self.users` users.
        return (
            self in ctx.outputs
            or self.users is MULTIPLE
            or len(node.users) == self.users
        )
350
+
351
+
352
class _TargetArgsExpr(_TargetExpr):
    """
    Base class for filtering match by node.{target,args,kwargs}
    """

    def __init__(self, fns, *args, _users=1, **kwargs):
        super().__init__(fns, _users)
        self.args = tuple(args)
        self.kwargs = dict(kwargs)
        # Nested containers need a full pytree flatten; otherwise a cheap
        # tuple-based flatten is enough.
        if any(
            isinstance(x, (dict, list, tuple))
            for x in itertools.chain(args, kwargs.values())
        ):
            self.flatten = self.pytree_flatten
        else:
            self.flatten = self.simple_flatten
        self.flat_args_kwargs = self.flatten(self.args, self.kwargs)

    @staticmethod
    def simple_flatten(args, kwargs: Dict[Any, Any]):
        # Returns (flat values, spec); spec = (positional count, *kwarg names).
        return (*args, *kwargs.values()), (len(args), *kwargs.keys())

    @staticmethod
    def pytree_flatten(args, kwargs: Dict[Any, Any]):
        def norm_spec(s: pytree.TreeSpec):
            # Normalize container types so list/tuple and the FX immutable
            # variants produce comparable specs.
            if s.type is None:
                return s
            mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
            return pytree.TreeSpec(
                mapping.get(s.type, s.type),
                s.context,
                list(map(norm_spec, s.children_specs)),
            )

        flat, spec = pytree.tree_flatten([args, kwargs])
        spec = norm_spec(spec)
        return flat, spec

    def __repr__(self):
        args = [
            self.fns_repr(),
            *map(repr, self.args),
            *[f"{k}={v}" for k, v in self.kwargs.items()],
        ]
        return f"{self.__class__.__name__}({', '.join(args)})"

    def pretty_print(self, pp: PatternPrettyPrinter):
        args = [
            self.fns_repr(),
            *(pp.pretty_print(x) for x in self.args),
            *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
        ]
        if isinstance(self.users, Multiple):
            args.append("_users=MULTIPLE")
        elif self.users > 1:
            args.append(f"_users={self.users}")

        joiner_str = ", "
        return f"{self.__class__.__name__}({joiner_str.join(args)})"

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        """
        Match `node` against target/args/kwargs, recursing into child
        patterns.  kwargs the pattern does not mention are ignored; if the
        node carries fewer kwargs than the pattern, it is first normalized
        via the op schema.
        """
        if not self._match_fns(node) or len(node.args) != len(self.args):
            return FailedMatch("function_mismatch: node={}, pattern={}", node, self)

        if not self._match_users(node, ctx):
            return FailedMatch("multiple_users {}", self)

        _args = node.args
        _kwargs = node.kwargs
        if len(_kwargs) < len(self.kwargs):
            from torch.fx.operator_schemas import normalize_function

            normalized_args_and_kwargs = normalize_function(
                node.target, node.args, node.kwargs
            )

            if normalized_args_and_kwargs is None:
                return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
            else:
                _args, _kwargs = normalized_args_and_kwargs
                if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs):
                    # Drop kwargs the pattern does not constrain.
                    _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
                else:
                    return FailedMatch(
                        "function_mismatch: node={}, pattern={}", node, self
                    )
        else:
            _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}

        node_items, node_spec = self.flatten(_args, _kwargs)
        self_items, self_spec = self.flat_args_kwargs
        if node_spec != self_spec:
            return FailedMatch("args_structure {} {}", node_spec, self_spec)
        assert len(node_items) == len(self_items)

        m = Match(self)
        # (dropped an unused `itertools.count()` index from this loop)
        for pattern, child_node in zip(self_items, node_items):
            if isinstance(pattern, PatternExpr):
                child_match = ctx.match(pattern, child_node)
                if not child_match:
                    return child_match
                m.extend(child_match)
            elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
                # Fix: the format string previously used a named field
                # ({pattern!r}) with only positional args, which made
                # str(FailedMatch) raise KeyError('pattern').
                return FailedMatch(
                    "constant_args: {} {!r}!={!r}", node, child_node, pattern
                )
        m.nodes.append(node)
        m.targets[self] = node.target
        return m

    def find_anchor_nodes(self, ctx: MatchContext, searched):
        """
        This is used when we are matching a pattern with multiple outputs.
        There is a partial match (stored in ctx) and we want to walk
        this pattern to find a connection to an already-matched node.

        Yields candidate nodes that `self._match` might like.
        """
        if self in ctx.pattern_to_node:
            yield ctx.pattern_to_node[self]
            return

        for pattern in self.flat_args_kwargs[0]:
            if isinstance(pattern, PatternExpr):
                for other_node in pattern.find_anchor_nodes(ctx, searched):
                    if not isinstance(other_node, torch.fx.Node):
                        continue
                    for node in other_node.users:
                        if node not in searched:
                            if self._match_fns(node):
                                yield node
                            searched.add(node)
484
+
485
+
486
class CallFunction(_TargetArgsExpr):
    """
    Pattern node for an FX ``call_function`` node: ``fns[i](*args, **kwargs)``.
    """

    op = "call_function"
492
+
493
+
494
class CallMethod(_TargetArgsExpr):
    """
    Pattern node for an FX ``call_method`` node: ``fns[i].method(*args, **kwargs)``.
    """

    op = "call_method"
500
+
501
+
502
class CallModule(_TargetArgsExpr):
    """
    Pattern node for an FX ``call_module`` node: ``module(*args, **kwargs)``.
    """

    op = "call_module"
508
+
509
+
510
class _TargetExprVarArgs(_TargetExpr):
    """
    Matches a node by target only (the opcode comes from the subclass);
    all of the node's args/kwargs are captured and passed through to the
    pattern handler unchanged.
    """

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        if not self._match_fns(node):
            return FailedMatch("function_mismatch")

        if not self._match_users(node, ctx):
            return FailedMatch("multiple_users")

        m = Match(self)
        m.nodes.append(node)
        m.targets[self] = node.target
        # Forward the node's actual args/kwargs to the handler as captures.
        m.args.extend(node.args)
        m.kwargs.update(node.kwargs)
        return m
528
+
529
+
530
class CallFunctionVarArgs(_TargetExprVarArgs):
    """Matches any ``call_function`` node by target, capturing all args/kwargs."""

    op = "call_function"
532
+
533
+
534
class CallMethodVarArgs(_TargetExprVarArgs):
    """Matches any ``call_method`` node by target, capturing all args/kwargs."""

    op = "call_method"
536
+
537
+
538
class CallModuleVarArgs(_TargetExprVarArgs):
    """Matches any ``call_module`` node by target, capturing all args/kwargs."""

    op = "call_module"
540
+
541
+
542
class ListOf(PatternExpr):
    """
    Matches a repeated pattern

    Applied to a list/tuple of nodes: every element must match
    ``self.pattern`` (or, with ``partial=True``, at least one element).
    """

    def __init__(self, pattern: PatternExpr, partial=False):
        super().__init__()
        assert isinstance(pattern, PatternExpr)
        self.pattern = pattern
        # partial=True tolerates non-matching elements as long as one matches.
        self.partial = partial

    def __repr__(self):
        return f"{self.__class__.__name__}({self.pattern})"

    def _match(self, node: List[torch.fx.Node], ctx: MatchContext):  # type: ignore[override]
        if not isinstance(node, (list, tuple)) or len(node) == 0:
            return FailedMatch("non_list")
        m = Match(self)
        # Propagating patterns with multiple users will ensure we don't revisit
        # the same nodes
        pattern_to_node = ctx.filter_multi_user_patterns()
        matched = False
        for i, child_node in enumerate(node):
            # Each element gets its own child context seeded with the
            # multi-user bindings accumulated so far.
            child_ctx = MatchContext(
                ctx.outputs, pattern_to_node, graph=child_node.graph
            )
            child_match = child_ctx.match(self.pattern, child_node)
            # Carry forward any new multi-user bindings to the next element.
            pattern_to_node = child_ctx.filter_multi_user_patterns()
            if not child_match:
                if not self.partial:
                    return FailedMatch("list[{}]: {}", i, child_match)
                continue
            matched = True
            m.extend(child_match.bundle())
        if not matched:
            return FailedMatch("list: no_match")
        return m.bundle()
579
+
580
+
581
class MultiOutputPattern(PatternExpr):
    """
    Matches a subgraph with several output nodes: the first entry of
    ``outputs`` is matched directly, the remaining entries are matched by
    searching anchors connected to already-matched nodes.  ``None`` entries
    are skipped.
    """

    def __init__(self, outputs):
        super().__init__()
        assert all(isinstance(x, (PatternExpr, type(None))) for x in outputs), outputs
        self.outputs: List[Optional[PatternExpr]] = outputs

    @property
    def fns(self):
        # Delegate to the primary output's target fns.
        assert self.outputs[0] and hasattr(self.outputs[0], "fns")
        return self.outputs[0].fns

    def __repr__(self):
        return f"{self.__class__.__name__}({self.outputs})"

    def pretty_print(self, pp: PatternPrettyPrinter):
        args = [pp.pretty_print(x) for x in self.outputs]
        joiner_str = f",\n{' '}"
        str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}"
        str_out = f"{str_out}\n])"
        return str_out

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        # Match the primary output first; it seeds ctx.pattern_to_node for
        # the anchor search of the remaining outputs.
        m = ctx.match(self.outputs[0], node)
        if not m:
            return m

        for pattern in self.outputs[1:]:
            if pattern is None:
                continue
            child_match = self._match_from_anchors(pattern, ctx)
            if not child_match:
                return child_match
            m.extend(child_match)

        return m

    def _match_from_anchors(self, pattern, ctx):
        # Try each candidate anchor node; on failure, restore the bindings
        # snapshot so a failed attempt leaves no partial state behind.
        prior = dict(ctx.pattern_to_node)
        m = FailedMatch("no anchor found")
        for node in pattern.find_anchor_nodes(ctx, set()):
            m = ctx.match(pattern, node)
            if m:
                return m
            # revert any partial matches
            ctx.pattern_to_node = dict(prior)
        return m

    def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]:
        try:
            return MatchContext(self.outputs, graph=node.graph).match(self, node)
        except FailedMatch as e:
            return e
633
+
634
+
635
class RepeatedExpr(PatternExpr):
    """
    Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind`
    """

    def __init__(self, inner_pattern: PatternExpr):
        super().__init__()
        assert hasattr(inner_pattern, "fns")
        self.inner_pattern = inner_pattern

    @property
    def fns(self):
        return self.inner_pattern.fns

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        m = ctx.match(self.inner_pattern, node)
        if not m:
            return m
        # Drop the binding so the anchor search below re-discovers sibling
        # occurrences of the inner pattern instead of reusing this one.
        ctx.pattern_to_node.pop(
            self.inner_pattern,
        )
        # Check all anchor nodes match the pattern
        for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()):
            # Each anchor is matched in a fresh context.
            anchor_m = MatchContext([self], graph=node.graph).match(
                self.inner_pattern, anchor_node
            )
            if not anchor_m:
                return anchor_m
            m.extend(anchor_m)
        return m
665
+
666
+
667
class PatternPrettyPrinter:
    """
    Serializes Patterns to executable python.
    XXX: currently only used and tested for fuse attention patterns. May not cover
    all patterns.
    """

    def __init__(self):
        self.namespace = torch.fx.graph._Namespace()
        self.memoized_objs_names: Dict[PatternExpr, str] = {}
        self.memoized_objs_pp: Dict[PatternExpr, str] = {}

    @staticmethod
    def run(obj: PatternExpr, output_name="output"):
        """
        Serializes obj to python code with obj written out to `output_name`
        """

        pp = PatternPrettyPrinter()
        assert hasattr(obj, "pretty_print")
        out_str = obj.pretty_print(pp=pp)

        # One assignment per memoized sub-pattern (insertion order), then
        # the final assignment to `output_name`.
        lines = [
            f"{name} = {pp.memoized_objs_pp[key]}"
            for key, name in pp.memoized_objs_names.items()
        ]
        lines.append(f"{output_name} = {out_str}")
        return "\n".join(lines)

    def pretty_print(self, obj):
        # Target-args patterns are memoized so shared sub-patterns print once.
        if isinstance(obj, _TargetArgsExpr):
            memoized_name = self.memoized_objs_names.get(obj)
            return memoized_name if memoized_name else self.memoize(obj)
        if hasattr(obj, "pretty_print"):
            return obj.pretty_print(self)

        return repr(obj)

    def memoize(self, obj):
        # Render first (this may recursively memoize children), then mint a
        # fresh variable name derived from the target's short name.
        rendered = obj.pretty_print(self)
        base_name = obj.fns_repr()
        for prefix in ("aten.", "torch.", "prims."):
            base_name = base_name.replace(prefix, "")

        fresh_name = self.namespace.create_name(base_name, None)
        self.memoized_objs_names[obj] = fresh_name
        self.memoized_objs_pp[obj] = rendered
        return fresh_name
718
+
719
+
720
@dataclasses.dataclass
class PatternEntry:
    """A registered pattern plus the predicate that gates its application."""

    pattern: PatternExpr
    extra_check: Callable[[Match], bool]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
        raise NotImplementedError()

    def register(self, pass_dicts, target=None, prepend=False):
        # No target given: register once per target fn of the pattern.
        if target is None:
            assert hasattr(self.pattern, "fns")
            for fn in self.pattern.fns:
                self.register(pass_dicts, fn, prepend=prepend)
            return
        # A single pass dict: insert this entry under `target`.
        if isinstance(pass_dicts, (dict, PatternMatcherPass)):
            entries = pass_dicts[target]
            if prepend:
                entries.insert(0, self)
            else:
                entries.append(self)
            return
        # Otherwise: a collection of pass dicts; recurse into each.
        for sub in pass_dicts:
            self.register(sub, target, prepend=prepend)
741
+
742
+
743
@dataclasses.dataclass
class LoweringPatternEntry(PatternEntry):
    """
    A pattern that, on match, replaces the matched output node with a call
    to `handler` (with the Match pre-bound); the handler then runs at
    lowering time to produce inductor IR directly.
    """

    handler: Callable[..., Any]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
        # Bind the match as the first argument and keep the handler's metadata.
        handler = functools.wraps(self.handler)(functools.partial(self.handler, match))
        with graph.inserting_before(node):
            replacement = graph.call_function(handler, tuple(match.args), match.kwargs)
            replacement.meta.update(node.meta)
            node.replace_all_uses_with(replacement)
        # `node` must be the last matched node so erase_nodes removes the
        # whole matched subgraph safely.
        assert match.nodes[-1] is node
        match.erase_nodes(graph)
755
+
756
+
757
@dataclasses.dataclass
class GraphPatternEntry(PatternEntry):
    """
    A pattern that runs a function on the FX graph
    """

    handler: Callable[..., Any]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
        # The handler performs its own graph mutation; we only position the
        # insertion point at the matched node.
        with graph.inserting_before(node):
            self.handler(match, *match.args, **match.kwargs)
768
+
769
+
770
@dataclasses.dataclass
class ReplacementPatternEntry(PatternEntry):
    """
    A pattern whose match is replaced by inlining a traced replacement
    graph (match.replacement_graph) into the target graph.
    """

    # Maps the match's captured args/kwargs to the positional inputs of the
    # replacement graph.
    normalize_args: Callable[..., List[Any]]

    @staticmethod
    def replace_with_graph(
        match: Match,
        graph: torch.fx.Graph,
        replacement_graph: torch.fx.Graph,
        args: List[Any],
    ):
        output_nodes = match.output_nodes()
        first_node = output_nodes[0]

        class Replacer(torch.fx.Interpreter):
            # Only call_function (plus placeholder/output) is supported when
            # inlining; other opcodes are disabled.
            call_method = None  # type: ignore[assignment]
            call_module = None  # type: ignore[assignment]
            get_attr = None  # type: ignore[assignment]

            def run_node(self, node) -> Any:
                if node.op in ("placeholder", "output"):
                    return super().run_node(node)
                if node.op == "call_function":
                    target = node.target
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    # Emit the call into the *target* graph, copying fake-value
                    # metadata from the replacement node when absent.
                    result = graph.call_function(target, args, kwargs)
                    if "val" in node.meta and "val" not in result.meta:
                        result.meta["val"] = node.meta["val"]
                        if isinstance(node.meta["val"], torch.Tensor):
                            assert "tensor_meta" in node.meta
                            result.meta["tensor_meta"] = node.meta["tensor_meta"]
                    return result
                raise NotImplementedError(f"unhandled {node}")

        output_nodes = match.output_nodes()

        if len(output_nodes) == 1:
            last_node = output_nodes[0]
        else:
            # Insert before the earliest output node in graph order.
            assert output_nodes[0]
            nodes = list(output_nodes[0].graph.nodes)
            indices = [
                (nodes.index(n), n)
                for n in output_nodes
                if isinstance(n, torch.fx.Node)
            ]
            last_node = min(indices, key=lambda tup: tup[0])[1]

        def percolate_tags(node, recompute_tag, input_stops):
            # Breadth-unordered walk from `node` back through its inputs,
            # stamping the recompute tag until an input boundary is hit.
            queue = [node]
            visited = set()

            while queue:
                arg = queue.pop()
                if (
                    arg not in visited
                    and arg not in input_stops
                    and hasattr(arg, "meta")
                ):
                    visited.add(arg)
                    arg.meta["recompute"] = recompute_tag
                    queue.extend(arg.all_input_nodes)

        with graph.inserting_before(last_node):
            # Inline the replacement graph; its outputs become the values
            # substituted for the matched output nodes.
            replacement = Replacer(replacement_graph).run(*args)
            if isinstance(replacement, torch.fx.Node):
                replacement = [replacement]

            def maybe_getitem(node):
                # Returns the getitem index if `node` is `operator.getitem`,
                # else None.
                if node.op != "call_function":
                    return None
                if node.target != operator.getitem:
                    return None
                assert len(node.args) == 2
                return node.args[1]

            def replace(old, new):
                if old is None:
                    assert new is None
                    return
                assert isinstance(old, torch.fx.Node)
                if new is None:
                    old.replace_all_uses_with(None)
                    graph.erase_node(old)
                    return
                if isinstance(new, torch.fx.Node):
                    if "val" not in new.meta:
                        new.meta.update(old.meta)

                    # Preserve the recompute tags in the replacement graph. We
                    # look at the recompute tags of the original output node to
                    # propagate the tag from the output all the way to the input
                    # args (named as args in the replace_with_graph).
                    # Note that this is best effort. Since patterns are from
                    # many to many, there is no easy way to correctly map the
                    # recomputable tags. It is possible in some scenarios that we
                    # incorrectly tag some nodes as recomputables.
                    if "recompute" in old.meta:
                        percolate_tags(new, old.meta["recompute"], args)

                    old.replace_all_uses_with(new)
                    graph.erase_node(old)
                    return

                # `new` is not a node: it's a list of nodes.
                #
                # This happens when we want to replace a node that has a single
                # packed return with multiple unpacked returns. We need to do
                # some graph surgery here.
                #
                # Example:
                #   def original_graph(x):
                #      a = op(x)
                #      b = a[0]
                #      c = a[1]
                #      ...
                #
                # Assume that we want to replace op(x) with the graph
                #   def new_op(x):
                #      w = x + 1
                #      z = x + 2
                #      return (w, z)
                #
                # We need to replace `op` with the contents of `new_op`,
                # and then rewrite a[0] to be w and a[1] to be z, as so:
                #   def new_graph(x):
                #     w = x + 1
                #     z = x + 2
                #     b = w
                #     c = z
                #     ...
                old_uses = list(old.users.keys())
                for user in old_uses:
                    idx = maybe_getitem(user)
                    if idx is None:
                        raise AssertionError("can't handle")
                    replace(user, new[idx])
                graph.erase_node(old)

            if len(output_nodes) == len(replacement):
                for old, new in zip(output_nodes, replacement):
                    replace(old, new)
            else:
                assert len(output_nodes) == 1
                replace(output_nodes[0], replacement)

        match.erase_nodes(graph)

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
        self.replace_with_graph(
            match,
            graph,
            match.replacement_graph,  # type: ignore[arg-type]
            self.normalize_args(*match.args, **match.kwargs),
        )
925
+
926
+
927
+ def _return_true(match):
928
+ return True
929
+
930
+
931
def log_trace_failure(search_fn, e):
    # Record (at INFO level) that a replacement pattern could not be
    # re-traced with the real shapes from the user program; the match is
    # then simply skipped by the caller.
    log.info(
        "Replacement pattern %s failed to apply due to shape mismatch: %s",
        search_fn.__name__,
        e,
    )
937
+
938
+
939
def register_replacement(
    search_fn,
    replace_fn,
    example_inputs: Iterable[Any],
    trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule],
    pass_dicts,
    extra_check=_return_true,
    scalar_workaround=(),
    exclusive_arg_names=(),
    search_fn_pattern=None,
):
    """
    Create a replacement rule based on example functions that get traced
    to create patterns. This supports both training and inference when
    run on a joint forward+backward graph.

    Args:
        search_fn: traced to give original pattern
        replace_fn: traced to give replacement graph
        example_inputs: example inputs for initial trace
        trace_fn: fwd_only or joint_fwd_bwd
        pass_dict: dict of passes to register to
        extra_check: additional check to run on match(using real shapes)
    """
    argnames_static = [*inspect.signature(search_fn).parameters.keys()]

    def check_fn(match: Match):
        """
        Often shapes get burned into the pattern, so our initial match ran with
        `ignore_types=(int, ...)`.

        Recheck the match with the correct shapes.
        """
        # NOTE: this closure reads `requires_grad`, which is assigned further
        # down in register_replacement before any match can run (late binding).
        argnames = list(argnames_static)
        for name in argnames:
            if name not in match.kwargs:
                raise RuntimeError(
                    f"Not all inputs to pattern found in match.kwargs. Perhaps one "
                    f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}"
                )

        # Pull the fake values recorded on the matched nodes.
        args = list(
            torch.fx.map_arg(
                [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
            )
        )
        sym_args: List[torch.SymInt] = []
        with torch._dynamo.utils.detect_fake_mode(args):
            for i, grad in enumerate(requires_grad):
                if isinstance(args[i], torch.Tensor):
                    # Integer tensors cannot require grad; bail out early.
                    if grad and is_integer_dtype(args[i].dtype):
                        return False

                    # Rebuild each tensor with the requires_grad flag the
                    # pattern was traced with.
                    args[i] = torch.empty_strided(
                        args[i].size(),
                        args[i].stride(),
                        dtype=args[i].dtype,
                        device=args[i].device,
                        requires_grad=grad,
                    )
                    # Collect distinct symbolic sizes/strides.
                    for v in itertools.chain(args[i].shape, args[i].stride()):
                        if isinstance(v, torch.SymInt) and all(
                            guard_size_oblivious(v != a) for a in sym_args
                        ):
                            sym_args.append(v)

            if sym_args:
                # AOT Autograd and make fx will dedupe symbolic shape size
                # accesses of sym ints that appear as inputs
                # We don't want the sym_size uses to interfere with pattern matching
                # so we provide them as inputs.
                # Later, when we actually do the replacement, the symbolic shape
                # sizes will get re-traced and added to the graph.

                def search_fn_new(*args_new):
                    return search_fn(*args_new[len(args_new) - len(args) :])

                try:
                    specific_graph = trace_fn(search_fn_new, sym_args + args)
                except RuntimeError as e:
                    log_trace_failure(search_fn, e)
                    return False

                # correct argnames in the graph
                sym_arg_names = []
                for i, placeholder in zip(
                    range(len(sym_args) + len(args)),
                    specific_graph.graph.nodes,
                ):
                    if i < len(sym_args):
                        sym_arg_names.append(placeholder.target)
                        continue

                    # Re-create the placeholder under its user-facing name.
                    with specific_graph.graph.inserting_after(placeholder):
                        new_node = specific_graph.graph.placeholder(
                            argnames[i - len(sym_args)]
                        )
                        new_node.target = new_node.name
                    placeholder.replace_all_uses_with(new_node)
                    specific_graph.graph.erase_node(placeholder)

                argnames = sym_arg_names + argnames
            else:
                try:
                    specific_graph = trace_fn(search_fn, args)
                except RuntimeError as e:
                    log_trace_failure(search_fn, e)
                    return False

            # Re-run the match against a pattern traced with the *real* shapes.
            specific_pattern = fx_to_pattern(
                specific_graph,
                argnames=argnames,
                exclusive_arg_names=exclusive_arg_names,
                scalar_workaround=scalar_workaround,
            )
            specific_pattern_match = specific_pattern.match(match.output_nodes()[0])  # type: ignore[arg-type]
            if specific_pattern_match and extra_check(specific_pattern_match):
                # trace the pattern using the shapes from the user program
                match.replacement_graph = trace_fn(replace_fn, args)  # type: ignore[assignment]
                return True
            return False

    def normalize_args(**kwargs):
        # Reorder the matched kwargs into the positional order search_fn
        # expects, followed by any tangents_N inputs (training patterns).
        args = []
        for name in argnames_static:
            args.append(kwargs.pop(name))
        for i in range(1, len(kwargs) + 1):
            if f"tangents_{i}" not in kwargs:
                break
            args.append(kwargs.pop(f"tangents_{i}"))
        assert not kwargs, f"leftover kwargs: {kwargs!r}"
        return args

    if trace_fn is joint_fwd_bwd:
        # If inference mode is enabled during compilation, assume that we don't
        # want to match on any training graph patterns
        if torch.is_inference_mode_enabled():
            return False

    # TODO: Revisit the functionalize_rng_ops for lowmem dropout
    with functorch_config.patch(functionalize_rng_ops=False):
        requires_grad: List[bool] = [
            isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
        ]
        if search_fn_pattern is None:
            pattern = gen_pattern(
                search_fn,
                example_inputs,
                trace_fn,
                scalar_workaround,
                exclusive_arg_names,
            )
        else:
            pattern = search_fn_pattern

        # Guard against accidentally registering the same pattern twice.
        pattern_repr = PatternPrettyPrinter.run(pattern)
        assert pattern_repr not in _seen_patterns
        _seen_patterns.add(pattern_repr)
        pattern = ReplacementPatternEntry(
            pattern=pattern,
            extra_check=check_fn,
            normalize_args=normalize_args,
        )
        pattern.register(pass_dicts)
        return pattern.pattern
1104
+
1105
+
1106
@functorch_config.patch(functionalize_rng_ops=False)
def gen_pattern(
    search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=()
) -> PatternExpr:
    """
    Trace `search_fn` with example inputs and convert the resulting graph
    into a PatternExpr, ignoring burned-in scalar/shape values.
    """
    argnames = [*inspect.signature(search_fn).parameters.keys()]

    if scalar_workaround == ():
        scalar_workaround = {}

    # Build the flat input list: scalars come from scalar_workaround by arg
    # name, everything else is consumed positionally from example_inputs.
    example_iter = iter(example_inputs)
    flat_inputs = [
        scalar_workaround[argname] if argname in scalar_workaround else next(example_iter)
        for argname in argnames
    ]

    search_gm = trace_fn(search_fn, flat_inputs)
    return fx_to_pattern(
        search_gm,
        ignore_types=(int, float, list, torch.device, torch.dtype),
        argnames=argnames,
        scalar_workaround=scalar_workaround,
        exclusive_arg_names=exclusive_arg_names,
    )
1132
+
1133
+
1134
def register_lowering_pattern(
    pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False
):
    """
    Register an aten to inductor IR replacement pattern. The decorated
    function is saved and then called a lowering time allowing direct
    pattern to inductor IR conversion.
    """

    def decorator(handler):
        assert callable(handler)
        entry = LoweringPatternEntry(
            pattern=pattern, extra_check=extra_check, handler=handler
        )
        entry.register(pass_dict, prepend=prepend)
        # Mark the handler so lowering machinery can recognize it.
        handler._inductor_lowering_function = True
        return handler

    return decorator
1152
+
1153
+
1154
def register_graph_pattern(
    pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False
):
    """
    Register a pattern that runs a function on the FX graph, allowing
    custom transformation code.
    """

    def decorator(handler):
        assert callable(handler)
        entry = GraphPatternEntry(
            pattern=pattern, extra_check=extra_check, handler=handler
        )
        entry.register(pass_dict, prepend=prepend)
        return handler

    return decorator
1170
+
1171
+
1172
def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
    """Return True iff `node` is the first node of `graph`."""
    first = next(iter(graph.nodes))
    return node is first
1175
+
1176
+
1177
# match: copy_, relu_, _set_grad_enabled, manual_seed, enter_functional_autocast, etc
# i.e. target names ending in "_", containing "_.", or containing a
# set/enter/exit/seed token -- heuristically treated as mutating ops.
_mutation_op_re = re.compile(r"_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_)")
1179
+
1180
+
1181
def is_mutation_op(node: torch.fx.Node) -> bool:
    """
    Heuristically classify `node` as mutating: its target name matches
    `_mutation_op_re`, or the call writes through an `out=` kwarg.
    """
    target_name = None
    if node.op == "call_function":
        target_name = node.target.__name__  # type: ignore[union-attr]
    elif node.op == "call_method":
        target_name = node.target  # type: ignore[assignment]
    if target_name is not None and _mutation_op_re.search(target_name):  # type: ignore[arg-type]
        return True
    return node.kwargs.get("out") is not None
1189
+
1190
+
1191
def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
    """
    Return the mutation-region id of `node`, lazily filling in ids for any
    nodes between the nearest already-annotated predecessor and `node`.
    """
    n = node
    # Walk backwards to the closest node that already has an id (or the
    # start of the graph).
    while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
        n = n.prev
    mutation_region_id = n.meta.get("mutation_region_id", 0)
    # Walk forward again, bumping the id at each mutation op and annotating
    # every node along the way.
    while n is not node:
        n = n.next
        if is_mutation_op(n):
            mutation_region_id += 1
        n.meta["mutation_region_id"] = mutation_region_id
    return mutation_region_id
1202
+
1203
+
1204
def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
    """Return True if mutation-region ids have not yet been annotated on `graph`."""
    first_node = next(iter(graph.nodes))
    return "mutation_region_id" not in first_node.meta
1206
+
1207
+
1208
def compute_mutation_region_ids(graph: torch.fx.GraphModule):
    """
    Annotate every node with a "mutation_region_id": the count of mutation
    ops seen so far (each mutation op starts a new region).
    """
    region = 0
    for node in graph.nodes:
        if is_mutation_op(node):
            region += 1
        node.meta["mutation_region_id"] = region
1214
+
1215
+
1216
+ class PatternMatcherPass:
1217
+ def __init__(
1218
+ self, prevent_match_across_mutations=False, pass_name: Optional[str] = None
1219
+ ):
1220
+ super().__init__()
1221
+ self.patterns: DefaultDict[
1222
+ torch.fx.node.Target, List[PatternEntry]
1223
+ ] = defaultdict(list)
1224
+ self.prevent_match_across_mutations = prevent_match_across_mutations
1225
+ self.pass_name = pass_name
1226
+
1227
+ def __getitem__(self, item: torch.fx.node.Target) -> List[PatternEntry]:
1228
+ return self.patterns[item]
1229
+
1230
    def apply(self, graph: torch.fx.GraphModule) -> int:
        """Run every registered pattern over ``graph``; return the match count.

        Accepts either a GraphModule or a raw Graph.  Nodes are visited in
        reverse order so that replacements performed by earlier matches do
        not invalidate the traversal.
        """
        if not self.patterns:
            return 0
        if isinstance(graph, torch.fx.GraphModule):
            graph = graph.graph
        if self.prevent_match_across_mutations:
            # Mutation region ids are computed lazily and cached on the graph.
            if should_compute_mutation_region_ids(graph):
                compute_mutation_region_ids(graph)
            get_mutation_region_id_partial = functools.partial(
                get_mutation_region_id, graph
            )
        count = 0
        for node in reversed(graph.nodes):
            target = extract_target(node)
            if (
                node.op in ["call_function", "call_method", "call_module"]
                and target in self.patterns
            ):
                # conservatively not applying pattern for cpu input,
                # since some of the patterns induce codegen and split nodes.
                # Note: we will only skip cpu compute if disable_cpp_codegen=True
                if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False):
                    continue

                for entry in self.patterns[target]:
                    # A previous entry's replacement may have erased this node.
                    if node._erased:
                        break
                    m = entry.pattern.match(node)
                    # pattern match crosses mutation barrier - discard
                    if (
                        self.prevent_match_across_mutations
                        and is_match(m)
                        and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1  # type: ignore[possibly-undefined]
                    ):
                        continue
                    # Debug hook: set TORCHINDUCTOR_PATTERN_MATCH_DEBUG to a
                    # node name to log every match attempt against it.
                    if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
                        log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
                    if is_match(m) and entry.extra_check(m):
                        count += 1
                        entry.apply(m, graph, node)  # type: ignore[arg-type]
                        counters["inductor"]["pattern_matcher_count"] += 1
                        counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
        return count
1273
+
1274
    def clear(self):
        # Drop every registered pattern, leaving the pass empty.
        self.patterns.clear()
1276
+
1277
+
1278
def _not_implemented(*args, **kwargs) -> NoReturn:
    # Placeholder handler assigned to FX node kinds that fx_to_pattern's
    # Converter deliberately does not support (call_method, call_module,
    # get_attr).
    raise NotImplementedError()
1280
+
1281
+
1282
def fx_to_pattern(
    gm,
    ignore_types=(),
    argnames=(),
    scalar_workaround=(),
    exclusive_arg_names=(),
) -> PatternExpr:
    """
    Convert an FX graph into a PatternExpr. This is useful for simple
    patterns that can only match single functions and fixed-length lists.

    Args:
        gm: the traced graph (module) to convert.
        ignore_types: python types whose argument values become ``Ignored()``
            wildcards in the resulting pattern.
        argnames: names to assign to the first placeholders, in order.
        scalar_workaround: mapping of capture-name -> scalar value; matching
            scalars in the graph are turned back into ``KeywordArg``s.
        exclusive_arg_names: placeholder names that become
            ``ExclusiveKeywordArg`` instead of ``KeywordArg``.
    """
    # scalar_workaround is a hack to capture dropout_p
    # see https://github.com/pytorch/pytorch/issues/97894
    scalar_workaround = scalar_workaround or {}
    inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()}
    # The inversion must be lossless: duplicate scalar values would silently
    # collapse two capture names into one.
    assert len(inv_scalar_workaround) == len(scalar_workaround)

    def process_arg(x):
        # Rewrite a single burned-in argument into its pattern form.
        if isinstance(x, (float, int)) and x in inv_scalar_workaround:
            return KeywordArg(inv_scalar_workaround[x])
        if type(x) in ignore_types:
            return Ignored()
        # A non-empty list made entirely of Ignored collapses to one Ignored.
        if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x:
            return Ignored()
        return x

    argnum = itertools.count()

    class Converter(torch.fx.Interpreter):
        # Only placeholder/call_function/output graphs are convertible.
        call_method = _not_implemented
        call_module = _not_implemented
        get_attr = _not_implemented

        def placeholder(self, target, args, kwargs):
            n = next(argnum)
            if n < len(argnames):
                name = argnames[n]
            elif argnames:
                # Extra placeholders past argnames are expected to be
                # autograd tangents (joint graphs) and keep their own name.
                assert target.startswith("tangent")
                name = target
            else:
                target = re.sub(r"_\d+$", "", target)  # de-mangle arg name
                name = target
            if name in exclusive_arg_names:
                return ExclusiveKeywordArg(name)
            else:
                return KeywordArg(name)

        def call_function(self, target, args, kwargs):
            args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
            if list in ignore_types:
                # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...]
                args = [process_arg(a) for a in args]
                kwargs = {k: process_arg(a) for k, a in kwargs.items()}
            return CallFunction(target, *args, **kwargs)

        def run_node(self, n):
            rv = super().run_node(n)
            # Propagate user counts onto the pattern nodes so matching can
            # check fan-out.
            if n.op == "output" and isinstance(rv, tuple):
                assert len(rv) == len(n.args[0])
                for r, arg in zip(rv, n.args[0]):
                    r.users = len(arg.users)
            else:
                rv.users = len(n.users)
            return rv

    pattern = Converter(gm).run()
    if not isinstance(pattern, PatternExpr):
        # Multiple outputs: wrap the flattened leaves.
        return MultiOutputPattern(pytree.tree_leaves(pattern))
    return pattern
1352
+
1353
+
1354
@torch.no_grad()
def fwd_only(fn, args, *, run_dce=True) -> torch.fx.GraphModule:
    """Build a normalized inference graph, for use with fx_to_pattern.

    Traces ``fn(*args)`` under make_fx with the inductor decomposition
    table; tracing mode is symbolic when any arg is symbolic, real
    otherwise.  Dead code is eliminated unless ``run_dce`` is False.
    """
    # TODO - look into using aot autograd, asserting no mutating ops here
    with enable_python_dispatcher():
        mode = (
            "real" if not torch._inductor.utils.any_is_symbolic(*args) else "symbolic"
        )
        gm = make_fx(fn, select_decomp_table(), tracing_mode=mode)(*args)
    if run_dce:
        gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
1367
+
1368
+
1369
@torch.enable_grad()
def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule:
    """Build a normalized training graph, for use with fx_to_pattern.

    Runs aot_function over ``fn`` and captures the joint forward/backward
    graph via the partition_fn hook, then strips pointless views and the
    pytree in/out codegen so the graph is a plain pattern source.
    """
    gm: Optional[torch.fx.GraphModule] = None

    def record_joint_graph(joint_graph, inputs, **kwargs):
        # aot_function only exposes the joint graph through this hook;
        # clone it before partitioning mutates/consumes it.
        nonlocal gm
        assert not gm
        gm = clone_graph(joint_graph)
        return default_partition(joint_graph, inputs, **kwargs)

    with torch._guards.tracing(None):
        aot_function(
            fn,
            lambda g, i: make_boxed_func(g),
            partition_fn=record_joint_graph,
            decompositions=select_decomp_table(),
            keep_inference_input_mutations=True,
            enable_log=False,
        )(*args)
    assert gm

    from .fx_passes.joint_graph import pointless_view

    matcher_pass = PatternMatcherPass()

    # Remove no-op aten.view calls so traced patterns are not polluted by
    # views that reshape to the same size.
    pattern = CallFunction(
        torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
    )
    GraphPatternEntry(
        pattern=pattern, handler=pointless_view, extra_check=_return_true
    ).register(matcher_pass.patterns)
    matcher_pass.apply(gm.graph)  # type: ignore[arg-type]

    # remove in/out specs
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
1408
+
1409
+
1410
+ def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
1411
+ args: List[torch.fx.node.Argument] = list()
1412
+ torch.fx.map_arg((n.args, n.kwargs), args.append)
1413
+ return args
1414
+
1415
+
1416
def stable_topological_sort(graph: torch.fx.Graph):
    """Reorder ``graph`` in place into a valid topological order.

    The sort is stable: nodes that are already correctly placed keep their
    relative order, and each node is emitted as soon as all of its inputs
    have been emitted.
    """
    # Nodes are in exactly one of these three collections:

    # - Nodes in `pending` are waiting to be processed (in reverse order):
    pending = list(reversed(graph.nodes))

    # - Nodes in `ready` have been processed and are already in the correct
    #   order.
    ready = set()

    # - `waiting` is a mapping from a dependency to nodes which depend on that
    #   dependency.
    waiting = defaultdict(list)

    # The cursor indicates the last processed node so we can add new nodes
    # after it.
    cursor = None
    while pending:
        node = pending.pop()
        waiting_for = [x for x in _args(node) if x not in ready]
        if waiting_for:
            # We have unprocessed input nodes. Might as well wait for the last
            # arg so an already sorted list will only recheck this node once.
            waiting[waiting_for[-1]].append(node)
        else:
            ready.add(node)
            # Physically move the node right after the cursor when it is not
            # already there.
            if cursor and cursor.next is not node:
                cursor.append(node)
            cursor = node
            # Mark the nodes that have been waiting for this node to finish as
            # ready to check again.
            pending.extend(reversed(waiting.pop(node, ())))

    # Every node must have been emitted and nothing may still be blocked;
    # a failure here means the graph had a cycle or dangling dependency.
    assert not waiting and len(ready) == len(graph.nodes)
1450
+
1451
+
1452
def init_once_fakemode(fn: Callable[..., Any]):
    """Wrapper around lazy init functions in fx_passes/.

    Returns a zero-arg callable that runs ``fn`` exactly once (lru_cache)
    under a fresh FakeTensorMode with tracing disabled, so pattern
    registration can trace example graphs without touching real tensors.
    """

    @functools.lru_cache(None)
    @functools.wraps(fn)
    def lazy_init():
        # Snapshot counters so matches fired while tracing init graphs do
        # not leak into user-visible statistics.
        counters_ref = counters["inductor"].copy()

        with torch._guards.tracing(
            None
        ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
            result = fn()

        # clear view matches encountered during tracing
        counters["inductor"] = counters_ref

        return result

    return lazy_init
1471
+
1472
+
1473
def config_flag(name):
    """Build an ``extra_check`` predicate gating a pattern behind a config flag.

    The returned callable ignores the match it is given and simply reports
    the current value of ``config.<name>`` at call time.
    """

    def flag_check(match):
        return getattr(config, name)

    return flag_check
1480
+
1481
+
1482
def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Deep-copy an FX GraphModule, preserving node metadata and names.

    A plain ``Transformer`` pass would drop ``meta`` and rename nodes; this
    subclass copies ``meta`` across and re-registers each node's original
    name in the new graph's namespace.
    """

    class CopyGraph(Transformer):
        def run_node(self, old_node):
            result = super().run_node(old_node)
            if isinstance(result, torch.fx.Proxy):
                result.node.meta.update(old_node.meta)
                preserved = self.new_graph._graph_namespace.create_name(
                    old_node.name, None
                )
                result.node.name = preserved
            return result

    return CopyGraph(input_graph).transform()
1494
+
1495
+
1496
+ _seen_patterns: Set[str] = set()
1497
+
1498
+
1499
def get_arg_value(
    node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
):
    """Fetch a call argument by position, falling back to its keyword form.

    Returns ``node.args[arg_number]`` when enough positional args exist,
    otherwise ``node.kwargs.get(kwarg_name)`` (``None`` when absent).
    """
    if arg_number < len(node.args):
        return node.args[arg_number]
    return node.kwargs.get(kwarg_name)  # type: ignore[arg-type]
1507
+
1508
+
1509
def filter_nodes(nodes: Iterable[torch.fx.Node], fn) -> List[torch.fx.Node]:
    """Keep only the nodes whose target is ``fn`` (or one of its overloads).

    When ``fn`` is an ``OpOverloadPacket``, every concrete overload (e.g.
    ``aten.add.Tensor`` for ``aten.add``) is accepted as well.
    """
    accepted = [fn]
    if isinstance(fn, torch._ops.OpOverloadPacket):
        accepted += [getattr(fn, name) for name in fn.overloads()]

    return [candidate for candidate in nodes if candidate.target in accepted]
1515
+
1516
+
1517
def extract_target(node: Node):
    """For call_function and call_method, we directly use the target function;
    For call_module, the target is string, and we treat the module class
    as a function.
    """
    if node.op != "call_module":
        return node.target
    submodule = getattr(node.graph.owning_module, node.target)  # type: ignore[arg-type]
    return submodule.__class__
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py ADDED
@@ -0,0 +1,2445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import dataclasses
3
+ import functools
4
+ import itertools
5
+ import logging
6
+ import math
7
+ import operator
8
+ import os
9
+ import pprint
10
+ import textwrap
11
+ from typing import (
12
+ Any,
13
+ Counter,
14
+ DefaultDict,
15
+ Dict,
16
+ Generic,
17
+ List,
18
+ Optional,
19
+ Sequence,
20
+ Set,
21
+ Tuple,
22
+ TypeVar,
23
+ Union,
24
+ )
25
+
26
+ import sympy
27
+
28
+ import torch
29
+ from torch._dynamo.utils import dynamo_timed
30
+ from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
31
+ from torch.utils._triton import has_triton
32
+
33
+ from . import comms, config, dependencies, ir, metrics
34
+ from .codegen.common import get_scheduling_for_device, Kernel
35
+ from .comm_analysis import estimate_nccl_collective_runtime
36
+ from .dependencies import Dep, MemoryDep, StarDep, WeakDep
37
+ from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout
38
+ from .sizevars import SimplifyIndexing
39
+ from .utils import (
40
+ cache_on_self,
41
+ cmp,
42
+ free_symbol_has,
43
+ get_device_tflops,
44
+ get_dtype_size,
45
+ get_gpu_dram_gbps,
46
+ green_text,
47
+ is_collective,
48
+ is_wait,
49
+ red_text,
50
+ sympy_product,
51
+ )
52
+ from .virtualized import V
53
+
54
+
55
+ log = logging.getLogger(__name__)
56
+ fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
57
+
58
+
59
class WhyNoFuse:
    """Records and logs why two scheduler nodes could not be fused.

    Construction is cheap; the human-readable message is only formatted
    lazily in ``__str__``, so the fusion hot path pays nothing unless the
    fusion log artifact is enabled.
    """

    # TODO when we drop support for Python < 3.10, we can use
    # @dataclass(slots=True) instead of manually specifying __slots__.
    __slots__ = ["node1", "node2", "reason", "args"]
    reason: str
    args: Tuple[Any, ...]

    def __init__(self, node1: "BaseSchedulerNode", node2: "BaseSchedulerNode"):
        self.node1 = node1
        self.node2 = node2

    def __call__(self, reason, *args):
        # Stash the printf-style reason and emit it at debug level.
        self.reason = reason
        self.args = args
        fusion_log.debug(self)

    def __str__(self):
        lhs = self.node1.get_name()
        rhs = self.node2.get_name()
        detail = self.reason % self.args
        return f"cannot fuse {lhs} with {rhs}: " + detail
79
+
80
+
81
def pformat(obj):
    """Pretty-print ``obj`` for debug strings.

    Sets are first converted to lists sorted by string form (pprint has
    trouble with sets of sympy exprs).  A multi-line result is indented by
    four spaces and prefixed with a newline so it nests inside a larger
    debug printout.
    """
    target = sorted(obj, key=str) if isinstance(obj, set) else obj
    text = pprint.pformat(target, indent=4)
    if "\n" not in text:
        return text
    return "\n" + textwrap.indent(text, " " * 4)
89
+
90
+
91
class OutputNode:
    """Sentinel user representing a graph output.

    Wraps a single dependency so that output buffers always appear to have
    a downstream user; ``can_free`` checks for this type to keep outputs
    alive.
    """

    def __init__(self, dep):
        self.unmet_dependencies = {dep}
        self.inverse_users = []

    def get_name(self):
        return "OUTPUT"

    def is_reduction(self):
        # An output marker never computes anything itself.
        return False

    def get_alias_names(self):
        # Outputs alias nothing.
        return ()

    __repr__ = get_name
106
+
107
+
108
def _prune_redundant_deps(node, name_to_fused_node):
    """
    Prunes weakdeps intended for mutation ordering
    on an upstream fused node if after fusion there is another dependency
    on the fused upstream node, making the weakdep redundant

    In essence this enforces an ordering on fusions. As fusions occur, weakdeps will
    be incrementally removed, enabling other fusions, ensuring they are fused in order.
    """
    # Count, per fused producer, how many *real* (non-weak) deps point at it.
    name_to_dep_count: Counter[str] = collections.Counter()

    for dep in node.unmet_dependencies:
        if not isinstance(dep, WeakDep):
            name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1

    def should_prune(dep):
        if isinstance(dep, WeakDep):
            # Redundant when a real dep already orders us after the same
            # fused producer.
            is_redundant = (
                name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0
            )
            # These can occur because fused nodes always gather deps from their snodes
            # If B has a weakdep on A
            # B gets fused with C, then any time BC is fused, the weakdep will reappear
            is_self_dep = name_to_fused_node[dep.name] == node
            return is_redundant or is_self_dep
        else:
            return False

    deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)}

    if deps_to_prune:
        node.unmet_dependencies = node.unmet_dependencies - deps_to_prune
        node.set_read_writes(node.read_writes.remove_reads(deps_to_prune))
141
+
142
+
143
# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel
# Maps an extern kernel's python_kernel_name string to the corresponding
# ATen op, so the scheduler can dry-run the op (under fake tensors) to
# estimate its flop count.
kernel_name_to_op = {
    "extern_kernels.convolution": torch.ops.aten.convolution,
    "extern_kernels.mm": torch.ops.aten.mm,
    "extern_kernels.bmm": torch.ops.aten.bmm,
    "extern_kernels.addmm": torch.ops.aten.addmm,
}
150
+
151
+
152
+ class BaseSchedulerNode:
153
    def __init__(self, scheduler: "Scheduler", node: ir.Buffer):
        """Wrap an IR buffer for scheduling.

        Dependencies are seeded from the buffer's reads/writes; user lists,
        ancestors and orders are filled in later by the Scheduler.
        """
        self.scheduler: Scheduler = scheduler
        self.node: ir.Buffer = node
        self.users: List[NodeUser] = []
        self.inverse_users: List[BaseSchedulerNode] = []
        self.node_users: List[BaseSchedulerNode] = []
        self.set_read_writes(node.get_read_writes())
        self.ancestors: Set[str] = set()
        # Topological-order bounds, assigned by the scheduler later
        # (declarations only — intentionally unset here).
        self.min_order: int
        self.max_order: int
        self.last_usage: Set[
            str
        ] = set()  # buffers that won't be used after this kernel
        # Whether origin comments have already been emitted for this node.
        self.written = False
167
+
168
    def __repr__(self):
        # Short form used in logs, e.g. SchedulerNode(name='buf0').
        return f"{type(self).__name__}(name={self.get_name()!r})"

    def debug_str(self) -> str:
        """Longer form printout for trace logs"""
        name = self.get_name()
        lines = [
            f"{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})",
            f"{name}.writes = {pformat(self.read_writes.writes)}",
            f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}",
            f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}",
            f"{name}.users = {self.users}",
        ]
        try:
            lines += [
                self.debug_str_extra(),
            ]
        except Exception:
            # Best-effort: subclass extras must never break trace logging.
            log.warning("Ignoring error in debug_str()", exc_info=True)

        return "\n".join(lines).rstrip()

    def debug_str_extra(self) -> str:
        # Hook for subclasses to append extra debug lines.
        return ""

    def log_details(self):
        # One-line INFO summary of this node's dependency state.
        log.info(
            "%s: unmet_dependencies = %s, writes = %s",
            self,
            self.unmet_dependencies,
            self.read_writes.writes,
        )
200
+
201
    def update_mutated_names(self, renames: Dict[str, str]):
        # Rewrite dependency names after mutations rename their buffers.
        self.set_read_writes(self.read_writes.rename(renames))

    def add_mutation_dep(self, dep):
        # Inject an extra read to order this node after a mutation.
        self.set_read_writes(self.read_writes.with_read(dep))

    def add_fake_dep(self, dep):
        # Same mechanism as add_mutation_dep: an ordering-only read.
        self.set_read_writes(self.read_writes.with_read(dep))

    def set_users(self, users: List["NodeUser"]):
        # deduplicate
        # Keyed by id(node): multiple uses of the same downstream node are
        # merged into one NodeUser.
        result: Dict[int, NodeUser] = {}
        for use in users:
            if id(use.node) in result:
                result[id(use.node)] = use.merge(result[id(use.node)])
            else:
                result[id(use.node)] = use
        self.users = list(result.values())
219
+
220
    def set_last_usage(
        self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str]
    ):
        """Record which buffers die after this kernel.

        ``last_usage`` becomes the buffers this node touches (canonicalized
        through ``mutation_real_name``) that no later node uses.
        """
        used_buffers = self.used_or_aliased_buffer_names()
        used_buffers = {mutation_real_name.get(k, k) for k in used_buffers}
        self.last_usage = used_buffers - future_used_buffers

    def get_aliases(self):
        return self.node.get_alias_names()

    def get_mutations(self):
        return self.node.get_mutation_names()

    def has_aliasing_or_mutation(self):
        # True when the underlying buffer aliases or mutates other buffers.
        return bool(self.get_aliases() or self.get_mutations())

    def set_read_writes(self, rw: dependencies.ReadWrites):
        # All reads start as unmet, then are pruned against buffers the
        # scheduler has already made available.
        self.read_writes: dependencies.ReadWrites = rw
        self.unmet_dependencies = self.read_writes.reads
        self.prune_deps()
240
+
241
    def op_counts(self):
        return self.read_writes.op_counts

    def used_buffer_names(self) -> Set[str]:
        # Names of every buffer this node reads or writes.
        return {
            dep.name
            for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
        }

    def used_or_aliased_buffer_names(self) -> Set[str]:
        """Like used_buffer_names, but also includes buffers that the
        accessed buffers alias (via AliasedLayout)."""
        used_names = set()

        for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes):
            used_names.add(dep.name)
            if V.graph.name_to_buffer.get(dep.name):
                layout = V.graph.name_to_buffer[dep.name].get_layout()
                # needed to avoid deallocating aliased buffer
                # if there are still uses of aliases ahead
                if isinstance(layout, ir.AliasedLayout):
                    used_names.add(layout.view.data.get_name())
        return used_names
262
+
263
    def prune_deps(self):
        # Drop deps on buffers the scheduler has already made available.
        self.unmet_dependencies = {
            dep
            for dep in self.unmet_dependencies
            if dep.name not in self.scheduler.available_buffer_names
        }

    def prune_weak_deps(self):
        # Prune weak dependencies on buffers that have been removed
        def should_prune(dep):
            return isinstance(dep, WeakDep) and dep.name in V.graph.removed_buffers

        to_remove = {dep for dep in self.read_writes.reads if should_prune(dep)}
        self.set_read_writes(self.read_writes.remove_reads(to_remove))

    def prune_redundant_deps(self, name_to_fused_node):
        # Delegates to the module-level helper shared with fused nodes.
        _prune_redundant_deps(self, name_to_fused_node)
280
+
281
    def get_name(self) -> str:
        return self.node.get_name()

    def get_first_name(self) -> str:
        # A single-buffer node has exactly one name.
        return self.get_name()

    def get_names(self) -> Set[str]:
        return {self.get_name()}

    def get_nodes(self) -> Sequence["BaseSchedulerNode"]:
        # Fused subclasses return their constituent snodes instead.
        return [self]

    def get_device(self):
        return self.node.get_device()

    # The predicates below are conservative defaults; subclasses override
    # the ones that apply to them.
    def is_reduction(self):
        return False

    def is_split_scan(self):
        return False

    def is_template(self):
        return False

    def is_extern(self):
        return False

    def is_foreach(self):
        return False

    def can_inplace(self, read_dep: dependencies.MemoryDep):
        return False

    def has_side_effects(self):
        return False
316
+
317
    def decide_inplace_update(self):
        """
        Decide if there should be inplace updates for the node
        and record the decision in the active kernel.
        """
        # Nothing to reuse if this node never allocates its own buffer.
        if not self.node.should_allocate():
            return

        # Aliasing/mutating SchedulerNodes manage their storage elsewhere.
        if isinstance(self, (SchedulerNode,)) and (
            self.node.get_alias_names() or self.node.get_mutation_names()
        ):
            return

        if (
            (
                isinstance(self, (SchedulerNode,))
                # o what have i done. lets make this an api
                or (
                    isinstance(self, ExternKernelSchedulerNode)
                    and isinstance(self.node, (ir.AllReduce, ir.InPlaceHint))
                )
            )
            and config.inplace_buffers
            and (
                not isinstance(V.kernel, torch._inductor.codegen.triton.TritonKernel)
                or getattr(V.kernel, "mutations", None) is not None
            )
        ):
            from .codegen.wrapper import buffer_reuse_key

            # Sort for a deterministic choice of which input to reuse.
            ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name)

            for read in ordered_reads:
                input_node: Optional[
                    BaseSchedulerNode
                ] = self.scheduler.name_to_node.get(read.name)
                if input_node and V.graph.wrapper_code.can_reuse(input_node, self):
                    assert input_node.users is not None
                    remaining_uses = [
                        x
                        for x in input_node.users
                        if x.node.get_name()
                        not in self.scheduler.available_buffer_names
                    ]
                    # Reuse only when we are the sole remaining consumer, the
                    # layouts permit it, and the buffers are interchangeable
                    # per buffer_reuse_key (size/stride/device compatible).
                    if (
                        len(remaining_uses) == 1
                        and remaining_uses[0].can_inplace
                        and remaining_uses[0].node is self
                        and not isinstance(
                            input_node.node.get_layout(),
                            (
                                ir.MultiOutputLayout,
                                ir.MutationLayout,
                                ir.AliasedLayout,
                            ),
                        )
                        and not (
                            isinstance(
                                input_node.node, (ir.FallbackKernel, ir.MultiOutput)
                            )
                            and len(input_node.node.get_alias_names()) > 0
                        )
                        and buffer_reuse_key(input_node.node)
                        == buffer_reuse_key(self.node)
                    ):
                        # hacky check for if V.kernel is a real kernel or NullHandler
                        if hasattr(V.kernel, "args"):
                            # if there isn't a triton kernel, then we don't need to call triton-specific things.
                            # but TODO this might be a convenient place to signal to the Collective kernels to inplace
                            # (and, can we make "kernel" less generic of a name?)
                            V.kernel.args.make_inplace(
                                input_node.get_name(), self.get_name()
                            )
                            # mutations not tracked in cpp kernels
                            if isinstance(
                                V.kernel, torch._inductor.codegen.triton.TritonKernel
                            ):
                                V.kernel.mutations.add(input_node.get_name())
                                V.kernel.mutations.add(self.get_name())

                            # update last usage of reused node
                            self.last_usage.discard(input_node.get_name())

                            V.kernel.inplace_update_buffers[
                                self.get_name()
                            ] = input_node.get_name()
                        # One reuse target is enough.
                        break
404
+
405
    def allocate(self):
        """Emit allocation (or inplace-reuse) wrapper code for this node."""
        if not self.node.should_allocate():
            return

        # Aliasing/mutating nodes always take the plain allocation path.
        if isinstance(self, (SchedulerNode,)) and (
            self.node.get_alias_names() or self.node.get_mutation_names()
        ):
            V.graph.wrapper_code.codegen_allocation(self.node)
            return

        # hacky check for if V.kernel is a real kernel or NullHandler
        if (
            hasattr(V.kernel, "args")
            and self.get_name() in V.kernel.inplace_update_buffers
        ):
            # decide_inplace_update() chose a buffer to reuse; emit the
            # reuse instead of a fresh allocation.
            V.graph.wrapper_code.codegen_inplace_reuse(
                self.scheduler.name_to_node[
                    V.kernel.inplace_update_buffers[self.get_name()]
                ].node,
                self.node,
            )
        else:
            V.graph.wrapper_code.codegen_allocation(self.node)
428
+
429
    def can_free(self):
        """Whether this node's buffer may be freed after its last use."""
        # There's no real allocated buffer, no need to free it
        if isinstance(self.node.layout, ir.NoneLayout):
            return False
        for use in self.users:
            # Graph outputs must stay alive past the end of the program.
            if isinstance(use.node, OutputNode):
                return False
        return True
437
+
438
    def codegen_originating_info(self, buffer, only_once=True):
        """Write ``#pragma CMT`` origin comments for this node into ``buffer``.

        Emits the originating FX op, seq_nr, and the last stack-trace line
        for each origin.  No-op unless ``config.comment_origin`` is set;
        with ``only_once`` the comments are emitted at most once per node.
        """
        if not config.comment_origin:
            return

        if only_once and self.written:
            return
        origins = self.node.origins
        out_lines = []

        for o in origins:
            if o.op == "output":
                # These are boring and samey
                continue

            out_lines.append("")
            # TODO(voz): Should the pragma be constant somewhere?
            out_lines.append("#pragma CMT ORIGIN:")
            op_info_str = f"#pragma CMT {o.op} {o.target}"
            if "seq_nr" in o.meta:
                op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}"
            out_lines.append(op_info_str)
            if "stack_trace" in o.meta:
                stack_trace = f"{o.meta['stack_trace']}"
                stack_trace_last_line = stack_trace.split("|")[-1]
                # Escape braces/newlines so the comment survives later
                # format-string processing of the generated code.
                out_lines.append(
                    "#pragma CMT "
                    + stack_trace_last_line.replace("{", "{{")
                    .replace("}", "}}")
                    .replace("\n", "\\")
                )
            out_lines.append("#pragma CMT END ORIGIN")
            out_lines.append("")

        if len(out_lines) == 0:
            return

        # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
        # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
        buffer.writelines(out_lines)
        self.written = True
478
+
479
    def get_read_write_buffers_sizes(self) -> int:
        """
        Counting the number of bytes accessed for a kernel is
        surprisingly tricky. In particular, there is a differentiation
        between 'theoretical' memory accesses and practical memory
        accesses. For example, a layernorm kernel may actually access an
        input 3 times, but in theory, it only needs to access its input
        once (and may be optimized to do so through say, persistent
        reductions)

        Another example is that even though a buffer is passed in, we may
        not access the entire buffer. This may occur if we are accessing
        a slice of the buffer. Another tricky case is for indirect
        indexing, where the amount of bytes accessed depends on the
        values of the input.

        What this function aims to compute is the memory accesses for
        worst-case inputs, best-case optimization. What this means is
        that for each buffer we compute the amount of potential accesses in two ways and take the minimum.

        1. Numel in ranges multiplied by number of deps the buffer has
        2. The buffer size
        """
        # Nop / MultiOutput wrappers move no data themselves.
        if isinstance(self, NopKernelSchedulerNode):
            return 0
        if isinstance(self, ExternKernelSchedulerNode) and isinstance(
            self.node, MultiOutput
        ):
            return 0

        if isinstance(self, SchedulerNode):
            node_numel = V.graph.sizevars.size_hint(
                sympy_product(self.get_ranges()[0])
                * sympy_product(self.get_ranges()[1])
            )
        else:
            # Unknown iteration space: use a huge numel so the per-buffer
            # min() below always picks the buffer size instead.
            node_numel = int(1e9)
        buf_accesses = collections.defaultdict(list)
        for dep in self.read_writes.reads | self.read_writes.writes:
            buf_accesses[dep.name].append(dep)

        reads = {dep.name for dep in self.read_writes.reads}
        writes = {dep.name for dep in self.read_writes.writes}

        def is_materialized(buf, snodes):
            # A buffer written inside a fused group is only materialized if
            # some user outside the group still needs it.
            users = self.scheduler.name_to_node[buf].users
            buf_uses = {user.node for user in users}
            return len(buf_uses - set(snodes)) > 0

        if isinstance(self, FusedSchedulerNode):
            removed_buffers = {
                dep for dep in writes if not is_materialized(dep, self.snodes)
            }
            writes = writes - removed_buffers
            reads = reads - removed_buffers
        node_bytes = 0

        for buf_name in reads | writes:
            buf_accessed_elems = sum([node_numel for dep in buf_accesses[buf_name]])
            buf: Union[ir.Buffer, ir.TensorBox]
            if buf_name in V.graph.name_to_buffer:
                buf = V.graph.name_to_buffer[buf_name]
            elif buf_name in V.graph.graph_inputs:
                buf = V.graph.graph_inputs[buf_name]
            else:
                # Not a known buffer or graph input; skip it.
                continue

            def get_buf_elems(buf):
                return V.graph.sizevars.size_hint(sympy_product(buf.get_size()))

            # Kind of a lazy way to get the MultiOutput nodes corresponding to
            # a MultiOutputLayout
            if isinstance(buf.layout, MultiOutputLayout):
                users = self.scheduler.name_to_node[buf.get_name()].users
                buf_elems = sum(get_buf_elems(user.node.node) for user in users)
            else:
                buf_elems = get_buf_elems(buf)

            # Worst-case input, best-case optimization: the cheaper of
            # "every dep touches every element" and "whole buffer once".
            node_bytes += min(buf_elems, buf_accessed_elems) * get_dtype_size(
                buf.get_dtype()
            )

        return node_bytes
562
+
563
+ def get_estimated_runtime(self) -> float:
564
+ """
565
+ Returns estimated op runtime in nanoseconds (ns)
566
+ """
567
+ layout = None
568
+ dtype = None
569
+ if not hasattr(self, "node") or not self.node:
570
+ assert isinstance(
571
+ self, (FusedSchedulerNode, ForeachKernelSchedulerNode)
572
+ ), f"{type(self)=}"
573
+ assert self.snodes
574
+ if not self.snodes[0].node:
575
+ return 0
576
+ layout = self.snodes[0].node.get_layout()
577
+ dtype = self.snodes[0].node.get_dtype()
578
+ else:
579
+ layout = self.node.get_layout()
580
+ dtype = self.node.get_dtype()
581
+
582
+ if "cuda" != layout.device.type:
583
+ # default to no reordering based on runtime
584
+ return 0
585
+
586
+ # Collective kernels
587
+ if is_collective(self.node):
588
+ return estimate_nccl_collective_runtime(self.node)
589
+ elif is_wait(self.node):
590
+ # ir.Wait is only used for collective ops.
591
+ # The time needed for the collective op is already estimated and considered
592
+ # when we are processing the collective op IR node, so ir.Wait takes 0 time
593
+ # since it doesn't take extra time to get the result after the collective is completed.
594
+ return 0
595
+
596
+ try:
597
+ gpu_memory_bandwidth = get_gpu_dram_gbps()
598
+ gpu_flops = get_device_tflops(dtype) * 10**12
599
+ except Exception:
600
+ return 0
601
+
602
+ if isinstance(self, ExternKernelSchedulerNode):
603
+ assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}"
604
+ op = kernel_name_to_op.get(
605
+ getattr(self.node, "python_kernel_name", ""), None
606
+ )
607
+
608
+ # if there is a resolved op, dry-run using fake mode and record flop count
609
+ if op is not None:
610
+ from torch._subclasses.fake_tensor import FakeTensorMode
611
+ from torch.utils.flop_counter import FlopCounterMode
612
+
613
+ with FakeTensorMode(), FlopCounterMode(
614
+ display=False
615
+ ) as flop_counter_mode:
616
+ from .ir import ir_node_to_tensor
617
+
618
+ fake_inputs = [
619
+ ir_node_to_tensor(input, guard_shape=False)
620
+ for input in self.node.inputs
621
+ ]
622
+ cls = self.node.__class__
623
+ cls.process_kernel(op, *fake_inputs, **self.node.kwargs)
624
+
625
+ # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
626
+ factor = 1.0
627
+ counted_flops = flop_counter_mode.get_total_flops()
628
+ counted_bytes = self.get_read_write_buffers_sizes()
629
+ compute_time = (factor * counted_flops / gpu_flops) * 1e9
630
+ transfer_time = counted_bytes / gpu_memory_bandwidth
631
+
632
+ # Return estimated runtime in nanoseconds
633
+ return max(compute_time, transfer_time)
634
+
635
+ elif isinstance(self, FusedSchedulerNode) or isinstance(
636
+ self.node, ComputedBuffer
637
+ ):
638
+ # Return estimated runtime in nanoseconds (bytes / gbps)
639
+ return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
640
+
641
+ return 0
642
+
643
+
644
+ class ExternKernelSchedulerNode(BaseSchedulerNode):
645
+ def debug_str_extra(self) -> str:
646
+ return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}"
647
+
648
+ def is_extern(self):
649
+ return True
650
+
651
+ def has_side_effects(self):
652
+ return hasattr(self.node, "has_side_effects") and self.node.has_side_effects()
653
+
654
+ def can_inplace(self, read_dep: dependencies.MemoryDep):
655
+ if self.get_aliases() or self.is_template():
656
+ return False
657
+
658
+ if read_dep.name not in self.scheduler.name_to_node:
659
+ # don't allow reuse of an 'input' buffer, we don't own it
660
+ # (would this have been fixed if I tracked mutations properly above?)
661
+ return False
662
+ if not isinstance(
663
+ self.node, (torch._inductor.ir.AllReduce, torch._inductor.ir.InPlaceHint)
664
+ ):
665
+ # TODO make this a property of the IR
666
+ return False
667
+
668
+ if len(self.read_writes.writes) == 1:
669
+ write_dep = next(iter(self.read_writes.writes))
670
+ numel_diff = read_dep.get_numel() - write_dep.get_numel()
671
+ return V.graph.sizevars.simplify(numel_diff) == 0
672
+
673
+ return False
674
+
675
+
676
+ class NopKernelSchedulerNode(BaseSchedulerNode):
677
+ pass
678
+
679
+
680
+ class SchedulerNode(BaseSchedulerNode):
681
+ def __init__(
682
+ self,
683
+ scheduler: "Scheduler",
684
+ node: Union[ir.ComputedBuffer, ir.TemplateBuffer],
685
+ ):
686
+ super().__init__(scheduler, node)
687
+ self._compute_attrs()
688
+
689
+ def _compute_attrs(
690
+ self,
691
+ extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None,
692
+ ):
693
+ assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer))
694
+ self._sizes, self._body = self.node.simplify_and_reorder(
695
+ extra_indexing_constraints=extra_indexing_constraints
696
+ )
697
+
698
+ group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn
699
+ self.group = (self.node.get_device(), group_fn(self._sizes))
700
+
701
+ if isinstance(self.node, ir.TemplateBuffer):
702
+ self.set_read_writes(self.node.normalized_read_writes())
703
+ else:
704
+ self.set_read_writes(
705
+ dependencies.extract_read_writes(
706
+ self._body, *self._sizes, normalize=True
707
+ )
708
+ )
709
+
710
+ def recompute_size_and_body(
711
+ self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]]
712
+ ):
713
+ self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints)
714
+
715
+ def debug_str_extra(self) -> str:
716
+ name = self.get_name()
717
+ lines = [
718
+ f"{name}.group.device = {self.group[0]}",
719
+ f"{name}.group.iteration = {self.group[1]}",
720
+ f"{name}.sizes = {self._sizes}",
721
+ ]
722
+ if self.get_aliases():
723
+ lines.append(f"{name}.aliases = {pformat(self.get_aliases())}")
724
+ if self.get_mutations():
725
+ lines.append(f"{name}.mutations = {pformat(self.get_mutations())}")
726
+ if isinstance(self._body, ir.LoopBody):
727
+ lines.append(f"class {name}_loop_body:")
728
+ lines.append(textwrap.indent(self._body.debug_str(), " "))
729
+ return "\n".join(lines)
730
+
731
+ def get_ranges(self):
732
+ return self._sizes
733
+
734
+ def is_reduction(self):
735
+ assert isinstance(
736
+ self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
737
+ ), f"{type(self.node)=}"
738
+ return bool(self.node.get_reduction_type())
739
+
740
+ def is_split_scan(self):
741
+ assert isinstance(
742
+ self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
743
+ ), f"{type(self.node)=}"
744
+ return isinstance(self.node, ir.ComputedBuffer) and isinstance(
745
+ self.node.data, ir.SplitScan
746
+ )
747
+
748
+ def is_template(self):
749
+ return isinstance(self.node, ir.TemplateBuffer)
750
+
751
+ def get_template_node(self):
752
+ return self.node if self.is_template() else None
753
+
754
+ def run(self, *index_vars):
755
+ self.decide_inplace_update()
756
+ self.mark_run()
757
+ self.codegen(index_vars)
758
+
759
+ def mark_run(self):
760
+ self.allocate()
761
+
762
+ def ranges_from_index_vars(self, index_vars):
763
+ sizes = self._sizes
764
+ assert sum(map(len, sizes)) == sum(map(len, index_vars))
765
+ var_ranges = dict(
766
+ zip(
767
+ itertools.chain.from_iterable(index_vars),
768
+ itertools.chain.from_iterable(sizes),
769
+ )
770
+ )
771
+ return var_ranges
772
+
773
+ def codegen(self, index_vars):
774
+ var_ranges = self.ranges_from_index_vars(index_vars)
775
+ try:
776
+ with V.set_ops_handler(
777
+ SimplifyIndexing(V.get_ops_handler(), var_ranges)
778
+ ), V.kernel.set_current_node(self):
779
+ self._body(*index_vars)
780
+ except Exception:
781
+ log.fatal("Error in codegen for %s", self.node)
782
+ raise
783
+
784
+ def pointwise_read_writes(self):
785
+ """
786
+ Get the memory dependencies in the non-reduction axis.
787
+ """
788
+ sizes, reduction_sizes = self._sizes
789
+
790
+ def fn(index):
791
+ return self._body(index, [sympy.Integer(0) for _ in reduction_sizes])
792
+
793
+ return dependencies.extract_read_writes(fn, sizes)
794
+
795
+ def can_inplace(self, read_dep: dependencies.MemoryDep):
796
+ if self.get_aliases() or self.is_template():
797
+ return False
798
+ if len(self.read_writes.writes) == 1 and isinstance(
799
+ read_dep, dependencies.MemoryDep
800
+ ):
801
+ write_dep = next(iter(self.read_writes.writes))
802
+ assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}"
803
+ return read_dep.index == write_dep.index and read_dep.size == write_dep.size
804
+ return False
805
+
806
+ @cache_on_self
807
+ def _get_atomic_add_buffers(self) -> Set[str]:
808
+ buffers_store_as_atomic_add = set()
809
+ if isinstance(self._body, ir.LoopBody):
810
+ for node in self._body.get_nodes():
811
+ if (
812
+ node.op == "call_method"
813
+ and node.target == "store"
814
+ and (
815
+ ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add")
816
+ or (len(node.args) == 5 and node.args[4] == "atomic_add")
817
+ )
818
+ ):
819
+ buffers_store_as_atomic_add.add(
820
+ node.kwargs["name"]
821
+ if "name" in node.kwargs
822
+ else (node.args[1] if len(node.args) >= 2 else "")
823
+ )
824
+ return buffers_store_as_atomic_add
825
+
826
+ def has_atomic_add(self, check_buf):
827
+ return check_buf in self._get_atomic_add_buffers()
828
+
829
+
830
+ class FusedSchedulerNode(BaseSchedulerNode):
831
+ """
832
+ This is a "fake" scheduler node that represents a group of scheduler nodes
833
+ that are meant to be fused together. The way it does this is by maintaining
834
+ its unmet dependencies as the union of its constituent nodes.
835
+ """
836
+
837
+ @classmethod
838
+ def fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
839
+ assert node1.scheduler is node2.scheduler
840
+ assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(
841
+ node2, (SchedulerNode, FusedSchedulerNode)
842
+ )
843
+ return cls(node1.scheduler, list(node1.get_nodes()) + list(node2.get_nodes())) # type: ignore[arg-type]
844
+
845
+ def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]):
846
+ # NB: No need to call super().__init__() because we don't need to re-use any of its logic.
847
+ self.snodes = snodes
848
+ self.scheduler = scheduler
849
+ self.node: ir.Buffer = None # type: ignore[assignment]
850
+ self.users: List[NodeUser] = []
851
+ self.inverse_users = []
852
+ self.node_users = []
853
+ self.group = max(snodes, key=lambda x: int(x.is_reduction())).group
854
+ self.ancestors = set.union(
855
+ *[x.ancestors for x in snodes if x.ancestors is not None]
856
+ )
857
+
858
+ self.set_read_writes(
859
+ dependencies.ReadWrites.merge_list([x.read_writes for x in snodes])
860
+ )
861
+
862
+ self.unmet_dependencies = {
863
+ dep
864
+ for dep in set.union(*[x.unmet_dependencies for x in snodes])
865
+ if dep.name not in self.get_names()
866
+ } - self.read_writes.writes
867
+ self.min_order = min([x.min_order for x in self.snodes])
868
+ self.max_order = max([x.max_order for x in self.snodes])
869
+
870
+ @cache_on_self
871
+ def get_name(self) -> str:
872
+ return "_".join([x.get_name() for x in self.snodes])
873
+
874
+ def get_first_name(self) -> str:
875
+ return self.snodes[0].get_name()
876
+
877
+ @cache_on_self
878
+ def get_names(self) -> Set[str]:
879
+ return set.union(*[x.get_names() for x in self.snodes])
880
+
881
+ def debug_str_extra(self) -> str:
882
+ lines = [
883
+ f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}"
884
+ for i, node in enumerate(self.snodes)
885
+ ]
886
+ return textwrap.indent("\n".join(lines).rstrip(), " ")
887
+
888
+ def set_last_usage(
889
+ self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str]
890
+ ):
891
+ # Set self.last_usage using the global information
892
+ # This will be used for inter-kernel optimisations
893
+ super().set_last_usage(future_used_buffers, mutation_real_name)
894
+ # Set self.last_usage on the snodes
895
+ # This will be used for optimisations within the kernel
896
+ future_used_buffers: Set[str] = set()
897
+ for node in reversed(self.snodes):
898
+ node.set_last_usage(future_used_buffers, mutation_real_name)
899
+ future_used_buffers.update(node.last_usage) # type: ignore[arg-type]
900
+
901
+ @cache_on_self
902
+ def used_buffer_names(self) -> Set[str]:
903
+ return set.union(*[x.used_buffer_names() for x in self.snodes])
904
+
905
+ @cache_on_self
906
+ def used_or_aliased_buffer_names(self) -> Set[str]:
907
+ return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes])
908
+
909
+ def get_nodes(self) -> List[SchedulerNode]:
910
+ return self.snodes
911
+
912
+ def __repr__(self):
913
+ return f"{type(self).__name__}(nodes={self.get_name()})"
914
+
915
+ @cache_on_self
916
+ def is_reduction(self):
917
+ return any(x.is_reduction() for x in self.snodes)
918
+
919
+ @cache_on_self
920
+ def is_split_scan(self):
921
+ return any(x.is_split_scan() for x in self.snodes)
922
+
923
+ @cache_on_self
924
+ def is_template(self):
925
+ return any(x.is_template() for x in self.snodes)
926
+
927
+ @cache_on_self
928
+ def get_template_node(self):
929
+ for node in self.snodes:
930
+ if node.is_template():
931
+ return node
932
+ return None
933
+
934
+ def get_device(self):
935
+ return self.group[0]
936
+
937
+ @cache_on_self
938
+ def has_aliasing_or_mutation(self):
939
+ return any(x.has_aliasing_or_mutation() for x in self.snodes)
940
+
941
+ @cache_on_self
942
+ def op_counts(self):
943
+ op_counts: Counter[str] = collections.Counter()
944
+ for node in self.snodes:
945
+ op_counts.update(node.op_counts())
946
+ return op_counts
947
+
948
+ def has_atomic_add(self, check_buf):
949
+ return any(
950
+ (
951
+ isinstance(sub_schedule_node1, SchedulerNode)
952
+ and sub_schedule_node1.has_atomic_add(check_buf)
953
+ )
954
+ for sub_schedule_node1 in self.get_nodes()
955
+ )
956
+
957
+ # None of these need to be implemented, as a FusedSchedulerNode is just an
958
+ # abstraction for scheduling purposes
959
+ def update_mutated_names(self, renames: Dict[str, str]):
960
+ raise NotImplementedError
961
+
962
+ def add_mutation_dep(self, name):
963
+ raise NotImplementedError
964
+
965
+ def set_users(self, users: List["NodeUser"]):
966
+ raise NotImplementedError
967
+
968
+ def get_aliases(self):
969
+ raise NotImplementedError
970
+
971
+ def get_mutations(self):
972
+ raise NotImplementedError
973
+
974
+ def can_inplace(self, read_dep: dependencies.MemoryDep):
975
+ raise NotImplementedError
976
+
977
+ def allocate(self):
978
+ raise NotImplementedError
979
+
980
+ def can_free(self):
981
+ raise NotImplementedError
982
+
983
+ def debug_str(self) -> str:
984
+ """Longer form printout for trace logs"""
985
+ name = self.get_name()
986
+ node_typestr = ",".join(type(n).__name__ for n in self.snodes)
987
+ lines = [
988
+ f"{name}: {type(self).__name__}({node_typestr})",
989
+ f"{name}.writes = {pformat(self.read_writes.writes)}",
990
+ f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}",
991
+ f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}",
992
+ f"{name}.users = {self.users}",
993
+ ]
994
+ try:
995
+ lines += [
996
+ self.debug_str_extra(),
997
+ ]
998
+ except Exception:
999
+ log.warning("Ignoring error in debug_str()", exc_info=True)
1000
+
1001
+ return "\n".join(lines).rstrip()
1002
+
1003
+
1004
+ class ForeachKernelSchedulerNode(FusedSchedulerNode):
1005
+ """Scheduler node which consists of a list of scheduler nodes that each operate on a
1006
+ distinct tensor in a list of tensors."""
1007
+
1008
+ def get_consumer_subnode_for(self, producer):
1009
+ if producer.get_name() in self.read_to_node:
1010
+ return self.read_to_node[producer.get_name()]
1011
+
1012
+ return None
1013
+
1014
+ def get_producer_subnode_for(self, consumer):
1015
+ for rd in consumer.read_writes.reads:
1016
+ if rd.name in self.name_to_node:
1017
+ return self.name_to_node[rd.name]
1018
+
1019
+ return None
1020
+
1021
+ @classmethod
1022
+ def can_fuse(cls, producer, consumer):
1023
+ why = WhyNoFuse(producer, consumer)
1024
+ if producer.is_foreach() and consumer.is_foreach():
1025
+ foreach_match = len(producer.snodes) == len(consumer.snodes)
1026
+ if not foreach_match:
1027
+ why("foreach do not have same length")
1028
+ return foreach_match and all(
1029
+ producer.scheduler.can_fuse(l, r)
1030
+ for l, r in zip(producer.snodes, consumer.snodes)
1031
+ )
1032
+ elif consumer.is_foreach():
1033
+ consumer_subnode = consumer.get_consumer_subnode_for(producer)
1034
+ if consumer_subnode is not None:
1035
+ return consumer.scheduler.can_fuse(producer, consumer_subnode)
1036
+
1037
+ why("candidate producer is not dep of any foreach consumer")
1038
+ return False
1039
+
1040
+ elif producer.is_foreach():
1041
+ producer_subnode = producer.get_producer_subnode_for(consumer)
1042
+ if producer_subnode is not None:
1043
+ return producer.scheduler.can_fuse(producer_subnode, consumer)
1044
+
1045
+ why("candidate consumer has no dep in any foreach producer")
1046
+ return False
1047
+
1048
+ raise AssertionError(
1049
+ "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node"
1050
+ )
1051
+
1052
+ @classmethod
1053
+ def fuse(cls, producer, consumer):
1054
+ assert producer.is_foreach() or consumer.is_foreach()
1055
+ prev_node_1 = None
1056
+ prev_node_2 = None
1057
+ if producer.is_foreach() and consumer.is_foreach():
1058
+ fused_nodes = [
1059
+ FusedSchedulerNode.fuse(l, r)
1060
+ for l, r in zip(producer.snodes, consumer.snodes)
1061
+ ]
1062
+ elif producer.is_foreach():
1063
+ producer_subnode = producer.get_producer_subnode_for(consumer)
1064
+ fused_nodes = []
1065
+ prev_node_1 = producer
1066
+ prev_node_2 = None
1067
+ for node in producer.snodes:
1068
+ if node is producer_subnode:
1069
+ new_node = FusedSchedulerNode.fuse(node, consumer)
1070
+ prev_node_2 = new_node
1071
+ fused_nodes.append(new_node)
1072
+ else:
1073
+ fused_nodes.append(node)
1074
+
1075
+ elif consumer.is_foreach():
1076
+ consumer_subnode = consumer.get_consumer_subnode_for(producer)
1077
+ fused_nodes = []
1078
+ prev_node_1 = consumer
1079
+ prev_node_2 = None
1080
+
1081
+ for node in consumer.snodes:
1082
+ if node is consumer_subnode:
1083
+ new_node = FusedSchedulerNode.fuse(producer, node)
1084
+ prev_node_2 = new_node
1085
+ fused_nodes.append(new_node)
1086
+ else:
1087
+ fused_nodes.append(node)
1088
+
1089
+ return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) # type: ignore[possibly-undefined]
1090
+
1091
+ def __init__(
1092
+ self,
1093
+ scheduler: "Scheduler",
1094
+ nodes: List[SchedulerNode],
1095
+ prev_node_1=None,
1096
+ prev_node_2=None,
1097
+ ):
1098
+ self.read_to_node = {}
1099
+ self.name_to_node = {}
1100
+
1101
+ if prev_node_1 is None or prev_node_2 is None:
1102
+ super().__init__(scheduler, nodes)
1103
+
1104
+ for node in nodes:
1105
+ for read in node.read_writes.reads:
1106
+ self.read_to_node[read.name] = node
1107
+
1108
+ for name in node.get_names():
1109
+ self.name_to_node[name] = node
1110
+ else:
1111
+ self.scheduler = scheduler
1112
+ self.snodes = nodes
1113
+ self.node: ir.Buffer = None # type: ignore[assignment]
1114
+ self.users: List[NodeUser] = []
1115
+
1116
+ self.set_read_writes(
1117
+ dependencies.ReadWrites.merge_list(
1118
+ [prev_node_1.read_writes, prev_node_2.read_writes]
1119
+ )
1120
+ )
1121
+
1122
+ self.unmet_dependencies = {
1123
+ dep
1124
+ for dep in set.union(
1125
+ prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies
1126
+ )
1127
+ if dep.name not in self.get_names()
1128
+ } - self.read_writes.writes
1129
+
1130
+ self.min_order = min([prev_node_1.min_order, prev_node_2.min_order])
1131
+ self.max_order = max([prev_node_1.max_order, prev_node_2.max_order])
1132
+
1133
+ foreach_node = prev_node_1 if prev_node_1.is_foreach() else prev_node_2
1134
+ other_node = prev_node_2 if prev_node_1.is_foreach() else prev_node_1
1135
+
1136
+ self.ancestors = foreach_node.ancestors
1137
+ self.ancestors.update(other_node.ancestors)
1138
+
1139
+ self.name_to_node = foreach_node.name_to_node
1140
+ for name in other_node.get_names():
1141
+ self.name_to_node[name] = other_node
1142
+
1143
+ self.group = (nodes[0].get_device(), "foreach")
1144
+
1145
+ self.origins: Set[torch.fx.Node] = set()
1146
+
1147
+ def mark_run(self):
1148
+ raise NotImplementedError
1149
+
1150
+ def codegen(self):
1151
+ assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}"
1152
+ self.node.get_store_function()(self.node.make_loader()())
1153
+
1154
+ def can_free(self):
1155
+ return NotImplementedError
1156
+
1157
+ def is_foreach(self):
1158
+ return True
1159
+
1160
+ def get_subkernel_nodes(self):
1161
+ """Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists.
1162
+ These nodes may be vertically fused."""
1163
+ return list(self.snodes)
1164
+
1165
+ def get_nodes(self):
1166
+ """Returns all nodes contained in this kernel, unpacking fused nodes into their constituent scheduler nodes."""
1167
+ return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes))
1168
+
1169
+ def get_first_name(self):
1170
+ return self.snodes[0].get_first_name()
1171
+
1172
+ def prune_redundant_deps(self, name_to_fused_node):
1173
+ _prune_redundant_deps(self, name_to_fused_node)
1174
+
1175
+ for node in self.snodes:
1176
+ node.prune_redundant_deps(name_to_fused_node)
1177
+
1178
+
1179
+ def pick_loop_order(stride_lengths, sizes, priority_idx=()):
1180
+ """
1181
+ A heuristic to decide loop iteration orders. This has not been well
1182
+ tuned and may be something we should autotune.
1183
+ """
1184
+
1185
+ @functools.cmp_to_key
1186
+ def index_cmp(a, b):
1187
+ if sizes[a] == 1 or sizes[b] == 1:
1188
+ # 1-sizes don't matter, just move them to the end
1189
+ return cmp(sizes[a] == 1, sizes[b] == 1)
1190
+
1191
+ stride_len_a = [sl[a] for sl in stride_lengths]
1192
+ stride_len_b = [sl[b] for sl in stride_lengths]
1193
+
1194
+ # equivalent to
1195
+ # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all()
1196
+ a_first = sum(
1197
+ sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b)
1198
+ )
1199
+ b_first = sum(
1200
+ sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b)
1201
+ )
1202
+ if a_first > b_first:
1203
+ return -1
1204
+ if b_first > a_first:
1205
+ return 1
1206
+
1207
+ # otherwise contiguous
1208
+ return cmp(b, a)
1209
+
1210
+ order = list(reversed(range(len(stride_lengths[0]))))
1211
+ if len(priority_idx) > 0:
1212
+ # if we have priority node, only use that node's order
1213
+ stride_lengths = [stride_lengths[pi] for pi in priority_idx]
1214
+ if config.pick_loop_orders:
1215
+ order.sort(key=index_cmp)
1216
+ return order
1217
+
1218
+
1219
+ @dataclasses.dataclass
1220
+ class NodeUser:
1221
+ node: BaseSchedulerNode
1222
+ can_inplace: bool = False
1223
+
1224
+ # A weak user must be scheduled after a given node, but doesn't actually
1225
+ # use the result
1226
+ is_weak: bool = False
1227
+
1228
+ def __hash__(self):
1229
+ return hash((self.node.get_name(), self.can_inplace, self.is_weak))
1230
+
1231
+ def __eq__(self, other):
1232
+ return (
1233
+ self.get_name() == other.get_name()
1234
+ and self.can_inplace == other.can_inplace
1235
+ and self.is_weak == other.is_weak
1236
+ )
1237
+
1238
+ def get_name(self):
1239
+ return self.node.get_name()
1240
+
1241
+ def merge(self, other: "NodeUser") -> "NodeUser":
1242
+ assert self.node is other.node
1243
+ return NodeUser(
1244
+ self.node,
1245
+ self.can_inplace and other.can_inplace,
1246
+ self.is_weak and other.is_weak,
1247
+ )
1248
+
1249
+
1250
+ _post_grad_graph_counter = itertools.count()
1251
+
1252
+
1253
+ class Scheduler:
1254
+ @dynamo_timed
1255
+ def __init__(self, nodes):
1256
+ super().__init__()
1257
+ self.backends = {}
1258
+ self.fuse_cache = {}
1259
+ self.post_grad_graph_id = next(_post_grad_graph_counter)
1260
+
1261
+ self.nodes = []
1262
+ self.available_buffer_names = {
1263
+ *V.graph.graph_inputs.keys(),
1264
+ *V.graph.constants.keys(),
1265
+ }
1266
+
1267
+ self.nodes = [self.create_scheduler_node(n) for n in nodes]
1268
+
1269
+ # some new constants could have been created above
1270
+ self.available_buffer_names.update(V.graph.constants.keys())
1271
+ for node in self.nodes:
1272
+ node.prune_deps()
1273
+
1274
+ self.name_to_node: Dict[str, BaseSchedulerNode] = {
1275
+ n.get_name(): n for n in self.nodes
1276
+ }
1277
+ self.name_to_fused_node: Dict[
1278
+ str, BaseSchedulerNode
1279
+ ] = dict() # set in fuse_nodes()
1280
+
1281
+ # mutation_real_name: Maps back to the original name for codegen
1282
+ # Example:
1283
+ # If you mutate buf0 inside of buf1's kernel, then:
1284
+ # mutation_real_name = {"buf0" : "buf1"}
1285
+ # all subsequent uses of buf0 become buf1's usage in dependency graph
1286
+ self.mutation_real_name = {}
1287
+
1288
+ # We handle mutation by renaming modified versions of the same
1289
+ # buffer in the dependency graph to prevent cycles.
1290
+ # mutation_renames: tracks the current name for a given buffer
1291
+ # (changed once per mutation)
1292
+ # Example:
1293
+ # If you mutate buf0 inside of buf1's kernel, then:
1294
+ # mutation_renames = {"buf1" : "buf0"}
1295
+ # in codegen we only use buf0, never buf1
1296
+ self.mutation_renames = {}
1297
+
1298
+ self.compute_dependencies()
1299
+ self.topological_sort_schedule()
1300
+ self.dead_node_elimination()
1301
+ if config.reorder_for_compute_comm_overlap:
1302
+ comms.decide_global_ordering_of_comms(self.nodes)
1303
+ self.compute_ancestors()
1304
+
1305
+ metrics.ir_nodes_pre_fusion += len(self.nodes)
1306
+ V.debug.ir_pre_fusion(self.nodes)
1307
+ self.num_orig_nodes = len(self.nodes)
1308
+ self.name_to_fused_node = {n.get_name(): n for n in self.nodes}
1309
+ self.create_foreach_nodes()
1310
+ self.topological_sort_schedule()
1311
+ self.logged_slow_fusion = set()
1312
+ self.fuse_nodes()
1313
+ if config.reorder_for_compute_comm_overlap:
1314
+ # Refresh node_users and inverse_users to reflect fused nodes
1315
+ self.compute_node_users()
1316
+ self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
1317
+ self.compute_last_usage()
1318
+ V.debug.ir_post_fusion(self.nodes)
1319
+ V.debug.graph_diagram(self.nodes)
1320
+ self.debug_draw_graph()
1321
+
1322
+ # used during codegen:
1323
+ self.current_device: torch.device = None # type: ignore[assignment]
1324
+ self.buffer_names_to_free = set()
1325
+
1326
+ # fx graph node to the position it appears in the graph
1327
+ # for debug attribution
1328
+ self.origin_to_index = {}
1329
+
1330
+ get_metric_table("graph_stats").add_row(
1331
+ lambda: {
1332
+ "graph_id": self.post_grad_graph_id,
1333
+ "num_nodes_before_fusion": self.num_orig_nodes,
1334
+ "num_nodes_after_fusion": len(self.nodes),
1335
+ }
1336
+ )
1337
+
1338
+ def debug_draw_graph(self):
1339
+ """Generate an image of the graph for debugging"""
1340
+ if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1":
1341
+ from .debug import draw_buffers
1342
+
1343
+ draw_buffers(self.nodes, print_graph=True)
1344
+
1345
+ def debug_print_nodes(self, label):
1346
+ if log.isEnabledFor(logging.INFO):
1347
+ log.info("%s:", label)
1348
+ for node in self.nodes:
1349
+ node.log_details()
1350
+
1351
+ def create_scheduler_node(self, node):
1352
+ assert (
1353
+ node.origins is not None
1354
+ ), "All nodes passed to scheduling must have an origin"
1355
+ if node.is_no_op():
1356
+ return NopKernelSchedulerNode(self, node)
1357
+ elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
1358
+ return SchedulerNode(self, node)
1359
+ elif isinstance(node, ir.ExternKernel):
1360
+ return ExternKernelSchedulerNode(self, node)
1361
+ else:
1362
+ raise NotImplementedError(node)
1363
+
1364
+ def create_foreach_nodes(self):
1365
+ removed_node_names = set()
1366
+ fe_nodes = []
1367
+ kept_node_names = self.name_to_fused_node.keys()
1368
+
1369
+ for names in V.graph.lists.values():
1370
+ names = [
1371
+ name
1372
+ for name in names
1373
+ if name in kept_node_names
1374
+ and not isinstance(self.name_to_node[name], NopKernelSchedulerNode)
1375
+ ]
1376
+ if not names:
1377
+ # All nodes eliminated
1378
+ continue
1379
+
1380
+ removed_node_names.update(names)
1381
+ snodes = [self.name_to_node[name] for name in names]
1382
+
1383
+ fe_node = ForeachKernelSchedulerNode(self, snodes) # type: ignore[arg-type]
1384
+
1385
+ fe_nodes.append(fe_node)
1386
+
1387
+ for name in names:
1388
+ self.name_to_fused_node[name] = fe_node
1389
+
1390
+ self.nodes = [
1391
+ node for node in self.nodes if node.get_name() not in removed_node_names
1392
+ ] + fe_nodes
1393
+
1394
+ def compute_dependencies(self):
1395
+ """
1396
+ Create dependency edges between nodes, handling aliasing and
1397
+ mutation properly.
1398
+ """
1399
+
1400
+ T = TypeVar("T")
1401
+
1402
+ class DedupList(Generic[T]):
1403
+ """
1404
+ This data structure behaves like a list except it makes sure the
1405
+ elements remain unique.
1406
+ Normally one could use a set/dict for this purpose however
1407
+ the list in question gets elements appended as it is being
1408
+ iterated over which means that we need to keep the list
1409
+ semantics.
1410
+ """
1411
+
1412
+ def __init__(self, items=None, membership=None):
1413
+ self.items = items or list()
1414
+ self.membership = membership or set()
1415
+
1416
+ def append(self, node_user: T) -> None:
1417
+ if node_user in self.membership:
1418
+ return
1419
+ self.items.append(node_user)
1420
+ self.membership.add(node_user)
1421
+
1422
+ def __add__(self, other: "DedupList[T]") -> "DedupList[T]":
1423
+ new_membership = set.union(self.membership, other.membership)
1424
+ new_items = self.items + [
1425
+ x for x in other.items if x not in self.membership
1426
+ ]
1427
+ return DedupList(new_items, new_membership)
1428
+
1429
+ name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict(
1430
+ DedupList
1431
+ )
1432
+
1433
+ # handle aliasing by using python aliasing in name_to_users
1434
+ # if foo aliases bar then we will make name_to_users["foo"] point
1435
+ # to the same python list as name_to_users["bar"]
1436
+ for node1 in self.nodes:
1437
+ node1_name = node1.get_name()
1438
+ for node2_name in node1.get_aliases():
1439
+ if node1_name in name_to_users and node2_name in name_to_users:
1440
+ # merge the two
1441
+ list1 = name_to_users[node1_name]
1442
+ list2 = name_to_users[node2_name]
1443
+ combined = list1 + list2
1444
+ for key in name_to_users.keys():
1445
+ if name_to_users[key] is list1 or name_to_users[key] is list2:
1446
+ name_to_users[key] = combined
1447
+ elif node1_name in name_to_users:
1448
+ name_to_users[node2_name] = name_to_users[node1_name]
1449
+ else:
1450
+ name_to_users[node1_name] = name_to_users[node2_name]
1451
+
1452
+ def rename(n):
1453
+ if n in self.mutation_renames:
1454
+ return rename(self.mutation_renames[n])
1455
+ return n
1456
+
1457
+ def dep_closure(node_name):
1458
+ reachable_names = {node_name}
1459
+ node = self.name_to_node[node_name]
1460
+ write_dep = next(iter(node.read_writes.writes))
1461
+ for read_dep in node.read_writes.reads:
1462
+ if (
1463
+ read_dep.name in self.name_to_node
1464
+ and isinstance(read_dep, dependencies.MemoryDep)
1465
+ and isinstance(write_dep, dependencies.MemoryDep)
1466
+ and read_dep.index == write_dep.index
1467
+ and read_dep.size == write_dep.size
1468
+ ):
1469
+ reachable_names.update(dep_closure(read_dep.name))
1470
+ return reachable_names
1471
+
1472
+ def add_user(used_by_name, user_node, can_inplace=False, is_weak=False):
1473
+ name_to_users[rename(used_by_name)].append(
1474
+ NodeUser(user_node, can_inplace, is_weak)
1475
+ )
1476
+
1477
+ unbacked_symbol_to_origin_node = {}
1478
+
1479
+ for node in self.nodes:
1480
+ log.debug("scheduling %s", node.node)
1481
+
1482
+ # unbacked symbols don't follow ordinary buffer dependencies, so
1483
+ # we track their def/uses separately
1484
+ unbacked_symbol_defs = sorted(
1485
+ node.node.get_unbacked_symbol_defs(), key=lambda x: x.name
1486
+ )
1487
+ for s in unbacked_symbol_defs:
1488
+ assert isinstance(s, sympy.Symbol)
1489
+ # Pick the first definer as canonical. There may be multiple
1490
+ # because if a MultiOutputLayout buffer propagates an unbacked
1491
+ # symint to multiple outputs, they will all claim to def it.
1492
+ if s not in unbacked_symbol_to_origin_node:
1493
+ unbacked_symbol_to_origin_node[s] = node
1494
+
1495
+ unbacked_symbol_uses = sorted(
1496
+ node.node.get_unbacked_symbol_uses(), key=lambda x: x.name
1497
+ )
1498
+ # if a kernel takes unbacked symints, register dependencies
1499
+ for s in unbacked_symbol_uses:
1500
+ assert (
1501
+ s in unbacked_symbol_to_origin_node
1502
+ ), f"{s} not in {unbacked_symbol_to_origin_node}"
1503
+ node.add_fake_dep(StarDep(unbacked_symbol_to_origin_node[s].get_name()))
1504
+
1505
+ # a node will mutate either 0 or 1 buffers
1506
+ assert len(node.get_mutations()) <= 1
1507
+ for alt_name in node.get_mutations():
1508
+ alt_name = rename(alt_name)
1509
+ # this node must run after the prior writer
1510
+ add_user(alt_name, node)
1511
+ node.add_mutation_dep(StarDep(alt_name))
1512
+ for other_node in name_to_users[alt_name].items:
1513
+ # this node must run after all prior readers
1514
+ other_name = rename(other_node.get_name())
1515
+ known_dep_node_names = dep_closure(node.get_name())
1516
+ if other_name not in known_dep_node_names:
1517
+ # If this node already directly or indirectly depends on other_node,
1518
+ # we don't need to insert an extra dep.
1519
+ node.add_mutation_dep(WeakDep(other_name))
1520
+ add_user(other_name, node, is_weak=True)
1521
+
1522
+ # add normal non-mutation dependencies
1523
+ for read in node.read_writes.reads:
1524
+ is_weak = isinstance(read, WeakDep)
1525
+ add_user(read.name, node, node.can_inplace(read), is_weak)
1526
+
1527
+ node.update_mutated_names(self.mutation_renames)
1528
+
1529
+ # update our renaming scheme for the next iteration
1530
+ for alt_name in node.get_mutations():
1531
+ self.mutation_renames[rename(alt_name)] = node.get_name()
1532
+ self.mutation_renames[alt_name] = node.get_name()
1533
+ self.mutation_real_name[node.get_name()] = self.mutation_real_name.get(
1534
+ alt_name, alt_name
1535
+ )
1536
+
1537
+ # make sure outputs aren't dead-code-eliminated
1538
+ for node_name in V.graph.get_output_names():
1539
+ log.debug("scheduling output %s", node_name)
1540
+ add_user(node_name, OutputNode(StarDep(node_name)))
1541
+
1542
+ # make sure unbacked symints aren't dead-code-eliminated
1543
+ for node in V.graph.graph_outputs:
1544
+ for s in node.get_unbacked_symbol_uses():
1545
+ assert (
1546
+ s in unbacked_symbol_to_origin_node
1547
+ ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
1548
+ node_name = unbacked_symbol_to_origin_node[s].node.name
1549
+ log.debug("scheduling output %s for unbacked symint %s", node_name, s)
1550
+ add_user(node_name, OutputNode(StarDep(node_name)))
1551
+
1552
+ # make sure input mutation isn't dead-code-eliminated
1553
+ for name in self.mutation_renames:
1554
+ if name in V.graph.graph_inputs:
1555
+ add_user(name, OutputNode(StarDep(name)))
1556
+ V.graph.mutated_inputs.add(name)
1557
+
1558
+ inp_names = {
1559
+ name: index for index, name in enumerate(V.graph.graph_inputs.keys())
1560
+ }
1561
+ V.graph.mutated_input_idxs = [
1562
+ inp_names[name] for name in V.graph.mutated_inputs
1563
+ ]
1564
+
1565
+ # copy users information onto the nodes
1566
+ for node in self.nodes:
1567
+ node.set_users(name_to_users[node.get_name()].items)
1568
+
1569
+ # populate inverse_users
1570
+ for node in self.nodes:
1571
+ for user in node.users:
1572
+ user.node.inverse_users.append(node)
1573
+
1574
+ def compute_node_users(self):
1575
+ # set up buffer name to (fused)snode mapping
1576
+ buf_to_snode = {}
1577
+ for node in self.nodes:
1578
+ if isinstance(node, FusedSchedulerNode):
1579
+ for x in node.snodes:
1580
+ buf_to_snode[x.get_name()] = node
1581
+ buf_to_snode[node.get_name()] = node
1582
+
1583
+ for node in self.nodes:
1584
+ node.node_users = []
1585
+ node.inverse_users = []
1586
+
1587
+ # compute inverse_users
1588
+ for node in self.nodes:
1589
+ inverse_users = []
1590
+ for dep in node.unmet_dependencies:
1591
+ assert dep.name in buf_to_snode
1592
+ dep_node = buf_to_snode[dep.name]
1593
+ inverse_users.append(dep_node)
1594
+ node.inverse_users = inverse_users
1595
+
1596
+ # compute node_users
1597
+ # TODO: ideally, we should deduplicate .users and .node_users,
1598
+ # but currently .users contains extra information that's difficult to
1599
+ # extract into a standalone container.
1600
+ node_to_users: Dict[BaseSchedulerNode, List[BaseSchedulerNode]] = {}
1601
+ for node in self.nodes:
1602
+ for inverse_user in node.inverse_users:
1603
+ node_to_users.setdefault(inverse_user, []).append(node)
1604
+ for node, users in node_to_users.items():
1605
+ node.node_users = users
1606
+
1607
    def dead_node_elimination(self):
        """
        Remove any nodes without users
        """
        again = True  # repeat until a fixed point
        while again:
            updated_nodes = []
            for node in self.nodes:

                def can_eliminate_user(user: NodeUser):
                    # weak users only enforce ordering, and users that were
                    # already removed no longer keep this node alive
                    return user.is_weak or user.get_name() in V.graph.removed_buffers

                can_eliminate = not node.has_side_effects() and all(
                    can_eliminate_user(u) for u in node.users
                )

                if not can_eliminate:
                    updated_nodes.append(node)
                else:
                    # dead code
                    log.debug("removed dead node: %s", node.get_name())
                    V.graph.removed_buffers.add(node.get_name())

            # removing a node may make its producers dead too, so repeat
            # until no node was removed in a pass
            again = len(self.nodes) > len(updated_nodes)
            self.nodes = updated_nodes

        # Prune any WeakDeps no longer needed
        for node in self.nodes:
            node.prune_weak_deps()
1636
+
1637
    def topological_sort_schedule(self):
        """
        Ensure self.nodes is in topologically sorted order
        """
        # NOTE: these containers hold scheduler nodes (not ir.Buffer);
        # annotations corrected accordingly.
        seen: Set[BaseSchedulerNode] = set()
        name_to_node: Dict[str, BaseSchedulerNode] = dict()
        result: List[BaseSchedulerNode] = []

        def visit(n):
            # Post-order DFS: a node is emitted only after the producers of
            # all its unmet dependencies. Deps are visited in name order so
            # the resulting schedule is deterministic.
            if n not in seen:
                seen.add(n)
                for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
                    visit(name_to_node[dep.name])
                result.append(n)

        for node in self.nodes:
            for name in node.get_names():
                name_to_node[name] = node
        for node in self.nodes:
            visit(node)
        self.nodes = result
1658
+
1659
+ def compute_ancestors(self):
1660
+ """
1661
+ Populate each node.ancestors
1662
+ """
1663
+ # note self.nodes is topologically sorted
1664
+ name_to_ancestors: Dict[str, Set[str]] = {}
1665
+ for node in self.nodes:
1666
+ ancestors = set()
1667
+ for dep in node.unmet_dependencies:
1668
+ ancestors.add(dep.name)
1669
+ ancestors |= name_to_ancestors[dep.name]
1670
+ name_to_ancestors[node.get_name()] = ancestors
1671
+ node.ancestors = ancestors
1672
+
1673
+ for order, node in enumerate(self.nodes):
1674
+ node.min_order = order
1675
+ node.max_order = order
1676
+
1677
+ def fuse_nodes(self):
1678
+ """
1679
+ Mutates self.nodes to combine nodes into FusedSchedulerNodes.
1680
+ """
1681
+ for i in range(10):
1682
+ old_len = len(self.nodes)
1683
+ fusion_log.debug(
1684
+ "===== attempting fusion (%d/10): %d nodes =====", i + 1, old_len
1685
+ )
1686
+ self.fuse_nodes_once()
1687
+ new_len = len(self.nodes)
1688
+ fusion_log.debug(
1689
+ "completed fusion round (%d/10): fused %d nodes into %d nodes\n",
1690
+ i + 1,
1691
+ old_len,
1692
+ new_len,
1693
+ )
1694
+ if new_len == old_len or new_len == 1:
1695
+ fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1)
1696
+ break
1697
+
1698
    def benchmark_fused_nodes(self, nodes):
        """
        Benchmark fused list of nodes and return the execution time
        in milliseconds on randomly generated inputs.
        """
        assert len(nodes) > 0
        device = nodes[0].get_device()
        # benchmarking may trigger codegen, which requires the scheduler and
        # the current device to be installed as globals first
        V.graph.scheduler = self
        self.current_device = device
        backend = self.get_backend(device)
        return backend.benchmark_fused_nodes(nodes)
1709
+
1710
    def speedup_by_fusion(self, node1, node2):
        """
        If config.benchmark_fusion is False, always return True.
        Otherwise, return True if fusion can brings speedup.
        """
        if not config.benchmark_fusion:
            return True

        if (
            node1.is_template()
            and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer)
            or node1.is_foreach()
            or node2.is_foreach()
        ):
            # TODO support benchmarking epilogue fusion
            return True

        node_list_1 = node1.get_nodes()
        device = node_list_1[0].get_device()

        # don't support benchmark fusion for CPU right now.
        if device.type == "cpu":
            return True

        node_list_2 = node2.get_nodes()
        node_list_fused = node_list_1 + node_list_2

        # We can not accurately benchmark kernel using atomic_add
        # due to how we generate random integer inputs.
        # Skip benchmarking them by allowing fusion.
        if any(
            hasattr(n.node, "data")
            and hasattr(n.node.data, "scatter_mode")
            and n.node.data.scatter_mode == "atomic_add"
            for n in node_list_fused
        ):
            return True

        from triton.compiler.errors import CompilationError

        why = WhyNoFuse(node1, node2)

        try:
            # benchmark each kernel separately, then the fused one;
            # inf signals register spilling, which we never fuse into
            ms1, path1 = self.benchmark_fused_nodes(node_list_1)
            if math.isinf(ms1):
                why("register spilling of the first kernel")
                return False
            ms2, path2 = self.benchmark_fused_nodes(node_list_2)
            if math.isinf(ms2):
                why("register spilling of the second kernel")
                return False
            ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused)
            if math.isinf(ms_fused):
                why("register spilling of the fused kernel")
                return False
        except CompilationError as e:
            # workaround triton issue: https://github.com/openai/triton/issues/2151
            if "Loop-carried variable" in str(e):
                return True  # allow fusion
            else:
                raise

        if fusion_log.isEnabledFor(logging.DEBUG):
            if ms_fused < ms1 + ms2:
                fusion_log.debug(
                    "can fuse (benchmark): fusing %s with %s cause %sx speedup",
                    node1.get_names(),
                    node2.get_names(),
                    green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
                )
            else:
                fusion_log.debug(
                    "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown",
                    node1.get_names(),
                    node2.get_names(),
                    red_text(f"{ms_fused / (ms1 + ms2):.3f}"),
                )

        # record slow fusions once per kernel pair for offline analysis
        if (
            is_metric_table_enabled("slow_fusion")
            and ms_fused >= ms1 + ms2
            and (path1, path2) not in self.logged_slow_fusion
        ):
            self.logged_slow_fusion.add((path1, path2))
            get_metric_table("slow_fusion").add_row(
                lambda: {
                    "kernel1_path": path1,
                    "kernel1_latency": ms1,
                    "kernel2_path": path2,
                    "kernel2_latency": ms2,
                    "fused_kernel_path": path_fused,
                    "fused_kernel_latency": ms_fused,
                    "slow_down_ratio": ms_fused / (ms1 + ms2),
                }
            )
        return ms_fused < ms1 + ms2
1806
+
1807
    def fuse_nodes_once(self):
        """
        Mutates self.nodes to combine nodes into FusedSchedulerNodes.

        This relies on two key functions to control the logic:
            - self.can_fuse(): checks if a fusion is legal
            - self.score_fusion(): assigns priority to a given fusion
        """
        fused_nodes = set(self.nodes)
        for node1, node2 in self.get_possible_fusions():
            # a candidate may already have been fused earlier in this loop;
            # resolve each to its current (possibly fused) node
            node1 = self.name_to_fused_node[node1.get_first_name()]
            node2 = self.name_to_fused_node[node2.get_first_name()]
            if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle(
                node1, node2
            ):
                if not self.speedup_by_fusion(node1, node2):
                    continue
                fusion_log.debug(
                    "fusing %s with %s", node1.get_name(), node2.get_name()
                )

                # above can_fuse asserts that node2 has the same device
                device = node1.get_device()
                node3 = self.get_backend(device).fuse(node1, node2)
                fused_nodes.remove(node1)
                fused_nodes.remove(node2)
                fused_nodes.add(node3)
                # every constituent buffer now maps to the new fused node
                self.name_to_fused_node.update(
                    {n.get_name(): node3 for n in node3.get_nodes()}
                )
        # restore a deterministic, dependency-respecting schedule order
        self.nodes = sorted(fused_nodes, key=lambda x: x.min_order)
        self.topological_sort_schedule()
        self.prune_redundant_deps()
1840
+
1841
+ def prune_redundant_deps(self):
1842
+ for node in self.nodes:
1843
+ node.prune_redundant_deps(self.name_to_fused_node)
1844
+
1845
    def get_possible_fusions(self):
        """
        Helper to find all legal fusion opportunities, sorted by self.score_fusion()
        """
        possible_fusions = []
        seen = set()

        def check_all_pairs(nodes):
            # try every ordered pair (node1 earlier in the list than node2)
            for node1_index, node1 in enumerate(nodes):
                for node2 in nodes[node1_index + 1 :]:
                    key = (node1, node2)
                    if key in seen:
                        continue
                    seen.add(key)

                    if self.can_fuse(node1, node2):
                        possible_fusions.append(key)
                    elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
                        node2, node1
                    ):
                        # foreach fusions and epilogue fusions are order dependent
                        possible_fusions.append((node2, node1))

        # only nodes touching a common buffer can save memory traffic by
        # fusing, so group candidates by the buffers they use
        buffer_names_grouping = collections.defaultdict(list)
        for node in self.nodes:
            for buf in node.used_buffer_names():
                buffer_names_grouping[buf].append(node)
        for node_grouping in buffer_names_grouping.values():
            check_all_pairs(node_grouping)

        if config.aggressive_fusion:
            # additionally consider pairs that share the same iteration group
            group_grouping = collections.defaultdict(list)
            for node in self.nodes:
                group = getattr(node, "group", None)
                if group:
                    group_grouping[group].append(node)
            for node_grouping in group_grouping.values():
                check_all_pairs(node_grouping)

        possible_fusions.sort(key=self.score_fusion_key, reverse=True)
        fusion_log.debug("found %d possible fusions", len(possible_fusions))
        return possible_fusions
1887
+
1888
    def will_fusion_create_cycle(self, node1, node2):
        """
        Finds whether there's a path from node1 to node2 (or vice-versa)
        caused indirectly by other fusions.

        Returns True when fusing node1 and node2 would create a cycle in
        the dependency graph.
        """

        def found_path(node):
            # only fused nodes can introduce new ancestors.
            if isinstance(node, FusedSchedulerNode) and node not in visited:
                visited.add(node)
                if node.get_names().issubset(combined_ancestors):
                    # All fusion outputs are in ancestors of node1 and node2, thus
                    # cannot introduce new path:
                    #
                    # 1. if output is neither descendent of node1 or node2, the
                    #    output cannot introduce a path
                    # 2. due to [can_fuse]: if WLOG output is descendent of node1, it cannot be
                    #    on path(node1->node2), hence it cannot be ancestor of node2
                    # 3. due to [acyclic]: if WLOG output is descendent of node1, it cannot be
                    #    ancestor of node1
                    return False
                else:
                    # continue DFS of new ancestors introduced by the fusion
                    return bool(combined_names & node.ancestors) or any(
                        found_path(self.name_to_fused_node[n])
                        for n in node.ancestors - combined_ancestors
                    )
            return False

        visited = set()
        combined_names = node1.get_names() | node2.get_names()
        combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names
        cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors)
        if cycle:
            WhyNoFuse(node1, node2)("will create cycle")
        return cycle
1924
+
1925
+ def can_fusion_increase_peak_memory(
1926
+ self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
1927
+ ):
1928
+ """
1929
+ This function prevents fusion for nodes that can increase memory
1930
+ footprint. This problem is more common in horizontal fusion, where nodes
1931
+ that are far apart in the original order get fused, lengthening the live
1932
+ intervals of tensors. This is very evident in models with activation
1933
+ checkpointing, where the recomputed nodes from different checkpointed
1934
+ regions get fused and significantly increase the memory footprint.
1935
+
1936
+ The current attempt is a quick, possibly hacky, heuristic to prevent the
1937
+ fusion of nodes that are far away in the original order.
1938
+
1939
+ A better but difficult to implement heurisitic would be to use live
1940
+ intervals of the buffers, find region of peak pressure in the original
1941
+ program and prevent fusion that crosses that peak region. We might need
1942
+ special care or good approximation in this implementation, as fusion of
1943
+ node changes live intervals, and re-computing live intervals and peak
1944
+ memory after each fusion can introduce large compilation overhead.
1945
+ """
1946
+ proximity_score = max(
1947
+ abs(node1.min_order - node2.max_order),
1948
+ abs(node2.min_order - node1.max_order),
1949
+ )
1950
+ return proximity_score > 64
1951
+
1952
    def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        """
        Determine if it is possible to combine node1 and node2 into a
        single fused node.
        """

        if node1 is node2:
            return False

        # records a human-readable reason whenever we bail out
        why = WhyNoFuse(node1, node2)

        if (
            isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
            and not node1.is_template()
        ):
            why("node1 is extern or nop")
            return False
        if (
            isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
            and not node2.is_template()
        ):
            why("node2 is extern or nop")
            return False

        if node2.get_names() & node1.ancestors:
            why("node1 must go before node2")
            return False

        if (
            isinstance(node1, (FusedSchedulerNode, SchedulerNode))
            and isinstance(node2, SchedulerNode)
            and isinstance(node2._body, ir.LoopBody)
        ):
            # Fix issue: https://github.com/pytorch/pytorch/issues/108963
            # Check:
            # If node2 reads a buf which is a mutation buf of node1(SchedulerNode) or among nodes in node1(FusedSchedulerNode),
            # we will get the corresponding mutation buf and check if this mutation buf is stored by atomic_add mode.
            # If True, we will disable the fusion of node1 and node2.
            if any(
                (
                    node2_used_buf in self.mutation_renames
                    and node1.has_atomic_add(self.mutation_renames[node2_used_buf])
                )
                for node2_used_buf in node2._body.reads_name2expr.keys()
            ):
                return False

        if node2.is_template():
            why("templates can only fuse epilogues")
            return False
        if node1.is_template() and (
            node2.has_aliasing_or_mutation()
            or node2.is_reduction()
            or not config.epilogue_fusion
        ):
            why("template epilogue not satisfied")
            return False

        device = node1.get_device()
        device2 = node2.get_device()
        if device != device2:
            why("device mismatch (%s vs %s)", device, device2)
            return False
        del device2

        no_shared_data = self.score_fusion_memory(node1, node2) == 0
        if no_shared_data and (
            not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction()
        ):
            why("no shared data")
            return False  # heuristic not needed for correctness

        if (
            not node1.is_foreach()
            and not node2.is_foreach()
            and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size
        ):
            why("exceeds max fusion")
            return False  # heuristic not needed for correctness

        if node1.get_names() & node2.ancestors:
            # node2 depends on node1 outputs
            if not self.can_fuse_vertical(node1, node2):
                return False
            return self.get_backend(device).can_fuse_vertical(node1, node2)
        else:  # nodes don't depend on each other, but may have common reads
            if self.can_fusion_increase_peak_memory(node1, node2):
                why("will increase peak memory")
                return False
            return self.get_backend(device).can_fuse_horizontal(node1, node2)
2042
+
2043
    def can_fuse_vertical(self, node1, node2):
        """
        Check if it is legal to fuse a consumer (node2) into a producer (node1).

        We can fuse them if all the reads of node2 either match
        corresponding writes in node1, or are written by nodes that can
        be scheduled before the fusion of node1 and node2.

        We also disable fusion of a write subsequent to a read if the reads
        and writes do not align.
        """
        node1_names = node1.get_names()
        computed_deps = set()
        why = WhyNoFuse(node1, node2)

        # StarDep doesn't match MemoryDep, different indices don't match
        # However, broadcasting sometimes strips dimensions, and if that's the case
        # we still can match unmet dep
        # if there's indirect indexing, don't match it
        def fusable_read_and_write(read: Dep, write: Dep):
            return (
                self.mutation_renames.get(read.name, read.name) == write.name
                and (isinstance(read, MemoryDep) and isinstance(write, MemoryDep))
                and not free_symbol_has(read.index, "tmp")
                and not free_symbol_has(write.index, "tmp")
                and read.index == write.index
                and len(read.size) >= len(write.size)
                and read.size[: len(write.size)] == write.size
            )

        # collect node2 reads that are satisfied by node1 writes
        for rd in node2.unmet_dependencies:
            for cd in node1.read_writes.writes:
                if fusable_read_and_write(rd, cd):
                    computed_deps.add(rd)

        remaining_deps = {dep.name for dep in node2.unmet_dependencies - computed_deps}
        if remaining_deps & node1_names:
            # MemoryDeps didn't match and read different locations of the same buffer.
            # Examples here include:
            #   - MemoryDep("foo", x) != MemoryDep("foo", x + 1)
            #   - MemoryDep("foo", x) != StarDep("foo")
            why("memory deps did not match")
            return False
        for name in remaining_deps:
            if node1_names & self.name_to_fused_node[name].ancestors:
                why("intermediate nodes between node1 & node2")
                return False

        # similar to can_inplace, if we are going to fuse a write subsequent to a read
        # require that the indexing and size is the same
        for write in node2.read_writes.writes:
            for read in node1.read_writes.reads:
                if write.name != self.mutation_renames.get(read.name, read.name):
                    continue

                # bail on StarDep
                if not fusable_read_and_write(read=read, write=write):
                    why("fusing a write into a read with different indexing formula")
                    return False

        return True
2104
+
2105
+ def score_fusion(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
2106
+ """
2107
+ Assign a score (higher comes first) to the fusion of node1
2108
+ and node2. When different fusions conflict with each other,
2109
+ this is the way we decide what order to run them in.
2110
+
2111
+ Our current score is based on:
2112
+ - Estimate of the saved memory operations
2113
+ - Fusions closer together in original order
2114
+ """
2115
+ memory_score = self.score_fusion_memory(node1, node2)
2116
+ proximity_score = -max(
2117
+ abs(node1.min_order - node2.max_order),
2118
+ abs(node2.min_order - node1.max_order),
2119
+ )
2120
+ return (
2121
+ node1.is_template() == config.epilogue_fusion_first and memory_score > 0,
2122
+ node1.is_reduction() == node2.is_reduction() and memory_score > 0,
2123
+ memory_score,
2124
+ proximity_score,
2125
+ )
2126
+
2127
+ def score_fusion_memory(self, node1, node2):
2128
+ """
2129
+ The first term in our fusion score that estimates number of saved memory operations.
2130
+ """
2131
+ common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
2132
+ node2.read_writes.reads | node2.read_writes.writes
2133
+ )
2134
+ common_memory_deps = {
2135
+ dep for dep in common_memory_deps if not dep.has_unbacked_symbols()
2136
+ }
2137
+ return sum(dep.numbytes_hint() for dep in common_memory_deps)
2138
+
2139
+ def score_fusion_key(self, nodes):
2140
+ """
2141
+ Shim for list.sort(key=...)
2142
+ """
2143
+ node1, node2 = nodes
2144
+ return self.score_fusion(node1, node2)
2145
+
2146
+ def compute_last_usage(self):
2147
+ """
2148
+ Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
2149
+ """
2150
+
2151
+ future_used_buffers = set()
2152
+ for node_name in V.graph.get_output_names():
2153
+ future_used_buffers.add(node_name)
2154
+
2155
+ for node in reversed(self.nodes):
2156
+ node.set_last_usage(future_used_buffers, self.mutation_real_name)
2157
+ future_used_buffers.update(node.last_usage)
2158
+
2159
    def free_buffers(self):
        """Free any buffers that are no longer needed"""
        # sorted() keeps the generated code deterministic across runs
        for name in sorted(
            self.buffer_names_to_free
            - V.graph.removed_buffers
            - V.graph.wrapper_code.freed
        ):
            if name in self.name_to_node:
                node = self.name_to_node[name]
                if node.can_free():
                    V.graph.wrapper_code.codegen_free(node.node)
            elif name in V.graph.graph_inputs:
                # graph inputs are freed through their underlying storage
                storage = V.graph.graph_inputs[name].data
                assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer()
                V.graph.wrapper_code.codegen_free(storage.data)

        self.buffer_names_to_free.clear()
2176
+
2177
    def remove_kernel_local_buffers(self):
        """
        Any buffers that are both created and have a last use in the
        same kernel can be removed.
        """

        # V.kernel.store_buffer_names should represent the set of nodes
        # get fused
        fused_node_names = V.kernel.store_buffer_names
        names_to_remove = []
        for out_buf in V.kernel.store_buffer_names:
            users = self.name_to_node[out_buf].users
            assert users is not None
            users = {user.get_name() for user in users if not user.is_weak}
            # a buffer is removable only when every (strong) user is fused
            # into this same kernel
            if users.issubset(fused_node_names):
                names_to_remove.append(out_buf)

        def remove_filter(n):
            # never remove buffers the kernel must keep, kernel inputs, or
            # buffers participating in mutation renaming
            return (
                n not in V.kernel.must_keep_buffers
                and n not in V.kernel.args.input_buffers
                and n not in self.mutation_renames
                and n not in self.mutation_real_name
            )

        names_to_remove = list(filter(remove_filter, names_to_remove))

        for name in names_to_remove:
            if name in V.kernel.args.inplace_buffers:
                buf = V.kernel.args.inplace_buffers[name]
                if isinstance(buf, str) and buf.startswith("REMOVED"):
                    continue
                # an inplace buffer can only go once all of its aliases can
                remove = all(n in names_to_remove for n in buf.other_names)
                if remove:
                    self.remove_inplace_buffer(name)
                V.kernel.inplaced_to_remove.add(name)
            else:
                self.remove_buffer(name)
2216
+ def remove_buffer(self, name):
2217
+ # Assign a special value instead of deleting the entry
2218
+ # because we still rely on output_buffers's length to
2219
+ # generate unique arg name.
2220
+ log.debug("remove_buffer(%r)", name)
2221
+ V.kernel.args.output_buffers[name] = "REMOVED"
2222
+ V.kernel.removed_buffers.add(name)
2223
+
2224
+ def remove_inplace_buffer(self, name):
2225
+ log.debug("removing_inplace_buffer(%r)", name)
2226
+ inner_name = V.kernel.args.inplace_buffers[name].inner_name
2227
+ V.kernel.args.inplace_buffers[name] = inner_name.replace(
2228
+ "in_out_ptr", "REMOVED"
2229
+ )
2230
+ V.kernel.removed_buffers.add(name)
2231
+
2232
+ def flush(self):
2233
+ for backend in self.backends.values():
2234
+ backend.flush()
2235
+ self.free_buffers()
2236
+
2237
    def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode):
        """Emit the wrapper-code call for a single extern kernel node."""
        assert isinstance(scheduler_node, ExternKernelSchedulerNode)
        # 'decide_inplace_update' stores the inplace update decisions in
        # the current kernel from where 'allocate' retrieve those decisions.
        # We have to make sure there is a non-NULL kernel handler to store
        # those inplace update decisions.
        with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
            scheduler_node.decide_inplace_update()
            scheduler_node.allocate()
        node = scheduler_node.node
        assert isinstance(node, ir.ExternKernel), f"{type(node)=}"
        node.codegen(V.graph.wrapper_code)
        self.free_buffers()
2250
+
2251
    def create_backend(self, device: torch.device):
        """Instantiate the codegen backend for `device`, validating support."""
        assert (
            device.type != "cuda" or device.index is not None
        ), f"{device} should have been normalized in lowering"
        V.graph.add_device_info(device)

        device_scheduling = get_scheduling_for_device(device.type)
        if device_scheduling is None:
            raise RuntimeError(f"Unsupported device type: {device.type}")

        if device.type == "cuda" and not has_triton():
            # give a targeted error: GPU too old vs. triton missing entirely
            device_props = torch.cuda.get_device_properties(device)
            if device_props.major < 7:
                raise RuntimeError(
                    f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}"  # noqa: B950
                )
            else:
                raise RuntimeError(
                    "Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton"  # noqa: B950
                )

        return device_scheduling(self)
2273
+
2274
+ def get_backend(self, device: torch.device):
2275
+ if device not in self.backends:
2276
+ self.backends[device] = self.create_backend(device)
2277
+ return self.backends[device]
2278
+
2279
    def enter_context(self, node):
        """Enter the wrapper-code context of the latest FX origin of `node`."""

        def get_order(n):
            # lazily build an fx-node -> position index, one graph at a time
            if n not in self.origin_to_index:
                self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
            return self.origin_to_index[n]

        # Use a dict to have ordering
        origins = {
            (get_order(e), e): None for n in node.get_nodes() for e in n.node.origins
        }
        origins = list(origins.keys())
        if origins:
            # the origin appearing latest in its fx graph wins
            _, last = max(origins, key=operator.itemgetter(0))
            V.graph.wrapper_code.enter_context(last)
2293
+
2294
+ @dynamo_timed
2295
+ def codegen(self):
2296
+ for node in self.nodes:
2297
+ try:
2298
+ log.debug(
2299
+ "Generating code for node %s with estimated runtime %f",
2300
+ node.get_name(),
2301
+ node.get_estimated_runtime(),
2302
+ )
2303
+ except Exception as e:
2304
+ log.debug(
2305
+ "Generating code for node %s with estimated runtime 0.0",
2306
+ node.get_name(),
2307
+ )
2308
+
2309
+ self.enter_context(node)
2310
+
2311
+ if not isinstance(node, NopKernelSchedulerNode):
2312
+ device = node.get_device()
2313
+ if (
2314
+ device != self.current_device
2315
+ or node.is_extern()
2316
+ or node.is_template()
2317
+ ):
2318
+ self.flush()
2319
+ if device != self.current_device:
2320
+ if device.type == "cuda":
2321
+ if self.current_device and self.current_device.type == "cuda":
2322
+ V.graph.wrapper_code.codegen_device_guard_exit()
2323
+ assert device.index is not None, "device should have an index"
2324
+ V.graph.wrapper_code.codegen_device_guard_enter(device.index)
2325
+ elif self.current_device and self.current_device.type == "cuda":
2326
+ V.graph.wrapper_code.codegen_device_guard_exit()
2327
+ self.current_device = device
2328
+
2329
+ self.buffer_names_to_free.update(node.last_usage)
2330
+
2331
+ if node.is_template():
2332
+ node, *epilogue = node.get_nodes()
2333
+ self.get_backend(device).codegen_template(node, epilogue) # type: ignore[possibly-undefined]
2334
+ elif node.is_extern():
2335
+ self.codegen_extern_call(node)
2336
+ elif node.is_foreach():
2337
+ self.get_backend(device).codegen_foreach(node) # type: ignore[possibly-undefined]
2338
+ elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
2339
+ self.get_backend(device).codegen_nodes(node.get_nodes()) # type: ignore[possibly-undefined]
2340
+ else:
2341
+ assert isinstance(node, NopKernelSchedulerNode)
2342
+ node.allocate()
2343
+
2344
+ if config.debug_check_inf_and_nan:
2345
+ V.graph.wrapper_code.generate_inf_and_nan_checker(node)
2346
+
2347
+ if config.triton.debug_sync_kernel:
2348
+ self.get_backend(device).codegen_sync() # type: ignore[possibly-undefined]
2349
+
2350
+ self.available_buffer_names.update(node.get_names())
2351
+
2352
+ if not isinstance(node, NopKernelSchedulerNode):
2353
+ device = node.get_device()
2354
+ if self.get_backend(device).ready_to_flush():
2355
+ self.flush()
2356
+
2357
+ if self.current_device and self.current_device.type == "cuda":
2358
+ # exit the outermost CUDA device guard. this is
2359
+ # important for nested indentation codegen-ing.
2360
+ V.graph.wrapper_code.codegen_device_guard_exit()
2361
+
2362
+ self.flush()
2363
+
2364
+ def is_unaligned_buffer(self, buf_name):
2365
+ if buf_name in V.graph.graph_inputs or buf_name in V.graph.constants:
2366
+ # all graph inputs or constants are assumed to be aligned
2367
+ return False
2368
+ node = self.name_to_node[buf_name]
2369
+ layout = node.node.get_layout()
2370
+ if isinstance(layout, ir.AliasedLayout):
2371
+ return not layout.maybe_guard_aligned()
2372
+ else:
2373
+ return False
2374
+
2375
+
2376
class BaseScheduling:
    """
    Abstract interface that per-device codegen backends implement for the
    Scheduler. Subclasses decide fusion legality and generate kernels.
    """

    def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        """
        Check whether node1 and node2 can be vertically fused or not.
        """
        raise NotImplementedError()

    def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        """
        Check whether node1 and node2 can be horizontally fused or not.
        """
        raise NotImplementedError()

    def fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        """
        Fuse two nodes
        """
        if node1.is_foreach() or node2.is_foreach():
            return ForeachKernelSchedulerNode.fuse(node1, node2)
        else:
            return FusedSchedulerNode.fuse(node1, node2)

    def group_fn(self, sizes):
        """
        Process the iteration sizes in case a transformation needs to be applied.
        """
        raise NotImplementedError()

    def codegen_template(
        self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]
    ):
        """
        Given a template node, generate a kernel.

        This function is only available for triton now. If the third-party backend behaves as a sub-class
        of TritonScheduling, it can override it or reuse it.
        """
        raise NotImplementedError()

    def codegen_nodes(self, nodes: List[SchedulerNode]):
        """
        Generate a kernel given a list of pre-fused nodes.
        """
        raise NotImplementedError()

    def codegen_sync(self):
        """
        Generate synchronization code for the kernel. This method depends on the hardware characteristics.
        """
        raise NotImplementedError()

    def ready_to_flush(self) -> bool:
        """
        Check whether the backend is requesting the scheduler to flush the generated kernel.
        If not supported, please return False.
        """
        return False

    def flush(self):
        """
        Flush the generated kernel and python wrapper code to the source code file.
        """
        raise NotImplementedError()

    def benchmark_fused_nodes(self, nodes):
        """
        Benchmark fused list of nodes and return the execution time
        in milliseconds on randomly generated inputs.
        """
        raise NotImplementedError()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py ADDED
@@ -0,0 +1,1156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import builtins
2
+ import functools
3
+ import inspect
4
+ import itertools
5
+ import logging
6
+ import operator
7
+ import sys
8
+ import textwrap
9
+ import time
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from io import StringIO
12
+
13
+ from typing import Any, Callable, Dict, List, Optional, Union
14
+ from unittest.mock import patch
15
+
16
+ import sympy
17
+
18
+ import torch
19
+ from torch._dynamo.testing import rand_strided
20
+ from torch._dynamo.utils import counters, identity, preserve_rng_state
21
+
22
+ from . import config, ir
23
+ from .autotune_process import TensorMeta, TritonBenchmarkRequest
24
+ from .codecache import code_hash, PersistentCache, PyCodeCache
25
+ from .codegen.common import (
26
+ ChoiceCaller,
27
+ IndentedBuffer,
28
+ KernelTemplate,
29
+ PrimitiveInfoType,
30
+ )
31
+ from .codegen.triton import (
32
+ gen_common_triton_imports,
33
+ texpr,
34
+ TritonKernel,
35
+ TritonPrinter,
36
+ TritonScheduling,
37
+ )
38
+ from .codegen.triton_utils import config_of, signature_to_meta
39
+ from .exc import CUDACompileError
40
+ from .utils import (
41
+ do_bench,
42
+ get_dtype_size,
43
+ Placeholder,
44
+ sympy_dot,
45
+ sympy_product,
46
+ unique,
47
+ )
48
+ from .virtualized import V
49
+
50
+ log = logging.getLogger(__name__)
51
+
52
+ # correctness checks struggle with fp16/tf32
53
+ VERIFY: Dict[str, Any] = dict()
54
+ PRINT_AUTOTUNE = True
55
+ DEBUG = False
56
+
57
+
58
+ class KernelNamespace:
59
+ pass
60
+
61
+
62
+ # these objects are imported from the generated wrapper code
63
+ extern_kernels = KernelNamespace()
64
+
65
+
66
+ class PartialRender:
67
+ """
68
+ Some parts of a template need to be generated at the end, but
69
+ inserted into the template at the start. This allows doing a bunch
70
+ of replacements after the initial render.
71
+ """
72
+
73
+ def __init__(self, code, replacement_hooks):
74
+ super().__init__()
75
+ self.code = code
76
+ self.replacement_hooks = replacement_hooks
77
+
78
+ def finalize(self):
79
+ code = self.code
80
+ assert code is not None, "can only be called once"
81
+ self.code = None
82
+ for key, fn in self.replacement_hooks.items():
83
+ code = code.replace(key, fn())
84
+ return code
85
+
86
+
87
+ class TritonTemplateKernel(TritonKernel):
88
+ def __init__(
89
+ self,
90
+ kernel_name,
91
+ input_nodes,
92
+ output_node,
93
+ defines,
94
+ num_stages,
95
+ num_warps,
96
+ grid_fn,
97
+ meta,
98
+ call_sizes,
99
+ use_jit=True,
100
+ prefix_args=0,
101
+ suffix_args=0,
102
+ epilogue_fn=identity,
103
+ *,
104
+ index_dtype,
105
+ ):
106
+ super().__init__(
107
+ sympy_product(output_node.get_size()),
108
+ sympy.Integer(1),
109
+ index_dtype=index_dtype,
110
+ )
111
+ self.input_nodes = input_nodes
112
+ self.output_node = output_node
113
+ self.named_input_nodes = {}
114
+ self.defines = defines
115
+ self.kernel_name = kernel_name
116
+ self.template_mask = None
117
+ self.use_jit = use_jit
118
+ self.num_stages = num_stages
119
+ self.num_warps = num_warps
120
+ self.grid_fn = grid_fn
121
+ self.meta = meta
122
+ self.call_sizes = call_sizes
123
+ # for templates with fixed epilogues
124
+ self.prefix_args = prefix_args
125
+ self.suffix_args = suffix_args
126
+ self.epilogue_fn = epilogue_fn
127
+ self.render_hooks = dict()
128
+ self.triton_meta: Optional[Dict[str, object]] = None
129
+
130
+ def need_numel_args(self):
131
+ return False
132
+
133
+ def estimate_kernel_num_bytes(self):
134
+ """
135
+ Estimate the total number of bytes this kernel takes.
136
+ For in/out nodes, sizes are counted twice: once for reading and
137
+ once for writing.
138
+ """
139
+ ninplace_args = len(unique(self.args.inplace_buffers.values()))
140
+ num_bytes = []
141
+ for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
142
+ size = V.graph.sizevars.size_hints(inp.get_size())
143
+ numel = functools.reduce(operator.mul, size)
144
+ dtype_size = get_dtype_size(inp.get_dtype())
145
+ num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
146
+ return sum(num_bytes)
147
+
148
+ def jit_lines(self):
149
+ if self.use_jit:
150
+ return "@triton.jit"
151
+
152
+ argdefs, _, signature = self.args.python_argdefs()
153
+ triton_meta = {
154
+ "signature": signature_to_meta(signature, size_dtype=self.index_dtype),
155
+ "device": V.graph.scheduler.current_device.index,
156
+ "device_type": V.graph.scheduler.current_device.type,
157
+ "constants": {},
158
+ }
159
+ triton_meta["configs"] = [config_of(signature)]
160
+ for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index]
161
+ triton_meta["constants"][arg_num] = 1 # type: ignore[index]
162
+ self.triton_meta = triton_meta
163
+
164
+ inductor_meta = {
165
+ "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
166
+ "backend_hash": torch.utils._triton.triton_hash_with_backend(),
167
+ }
168
+ if config.profile_bandwidth or config.benchmark_kernel:
169
+ num_gb = self.estimate_kernel_num_bytes() / 1e9
170
+ inductor_meta["kernel_num_gb"] = num_gb
171
+ return f"""
172
+ @triton_heuristics.template(
173
+ num_stages={self.num_stages},
174
+ num_warps={self.num_warps},
175
+ triton_meta={triton_meta!r},
176
+ inductor_meta={inductor_meta!r},
177
+ )
178
+ @triton.jit
179
+ """
180
+
181
+ def def_kernel(self, *argnames):
182
+ """
183
+ Hook called from template code to generate function def and
184
+ needed args.
185
+ """
186
+ assert all(isinstance(x, str) for x in argnames)
187
+ renames = IndentedBuffer(initial_indent=1)
188
+
189
+ named_args = self.input_nodes[
190
+ self.prefix_args : len(self.input_nodes) - self.suffix_args
191
+ ]
192
+
193
+ assert len(argnames) == len(named_args), (
194
+ len(argnames),
195
+ len(named_args),
196
+ self.prefix_args,
197
+ len(self.input_nodes),
198
+ )
199
+
200
+ for input_node in self.input_nodes[: self.prefix_args]:
201
+ # get args in correct order
202
+ self.args.input(input_node.get_name())
203
+
204
+ for name, input_node in zip(argnames, named_args):
205
+ arg_name = f"arg_{name}"
206
+ self.named_input_nodes[name] = input_node
207
+ self.args.input_buffers[input_node.get_name()] = arg_name
208
+
209
+ # The args may be duplicated, so renaming must be after args are de-duplicated.
210
+ for name in argnames:
211
+ input_node = self.named_input_nodes[name]
212
+ arg_name = self.args.input_buffers[input_node.get_name()]
213
+ if input_node.get_layout().offset == 0:
214
+ renames.writeline(f"{name} = {arg_name}")
215
+ else:
216
+ offset = texpr(self.rename_indexing(input_node.get_layout().offset))
217
+ renames.writeline(f"{name} = {arg_name} + {offset}")
218
+
219
+ for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
220
+ # get args in correct order
221
+ self.args.input(input_node.get_name())
222
+
223
+ def hook():
224
+ # python_argdefs() cannot be run until after the rest of the template lazily adds more args
225
+ arg_defs, *_ = self.args.python_argdefs()
226
+ code = IndentedBuffer()
227
+ code.splice(gen_common_triton_imports())
228
+ code.splice(self.jit_lines())
229
+ code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
230
+ with code.indent():
231
+ code.splice(self.defines)
232
+ code.splice(renames.getvalue())
233
+ return code.getvalue()
234
+
235
+ assert "<DEF_KERNEL>" not in self.render_hooks
236
+ self.render_hooks["<DEF_KERNEL>"] = hook
237
+ return "<DEF_KERNEL>"
238
+
239
+ def size(self, name: str, index: int):
240
+ """
241
+ Hook called from template code to get the size of an arg.
242
+ Will add needed args to pass it in if it is dynamic.
243
+ """
244
+ assert isinstance(index, int)
245
+ if name is None:
246
+ val = self.output_node.get_size()[index]
247
+ else:
248
+ assert isinstance(name, str)
249
+ val = self.named_input_nodes[name].get_size()[index]
250
+ return texpr(self.rename_indexing(val))
251
+
252
+ def stride(self, name, index):
253
+ """
254
+ Hook called from template code to get the stride of an arg.
255
+ Will add needed args to pass it in if it is dynamic.
256
+ """
257
+ assert isinstance(index, int)
258
+ if name is None:
259
+ val = self.output_node.get_stride()[index]
260
+ else:
261
+ assert isinstance(name, str)
262
+ val = self.named_input_nodes[name].get_stride()[index]
263
+ return texpr(self.rename_indexing(val))
264
+
265
+ def store_output(self, indices, val, mask):
266
+ """
267
+ Hook called from template code to store the final output
268
+ (if the buffer hasn't been optimized away), then append any
269
+ epilogue fusions.
270
+ """
271
+ assert isinstance(indices, (list, tuple))
272
+ assert isinstance(val, str)
273
+ assert isinstance(mask, str)
274
+ assert self.template_mask is None
275
+ indices = list(map(TritonPrinter.paren, indices))
276
+ index_symbols = [sympy.Symbol(x) for x in indices]
277
+ lengths = [V.graph.sizevars.simplify(s) for s in self.output_node.get_size()]
278
+ assert len(indices) == len(lengths)
279
+
280
+ # glue to make generated code use same indexing from template
281
+ for name, range_tree_entry in zip(
282
+ indices, self.range_trees[0].construct_entries(lengths)
283
+ ):
284
+ range_tree_entry.set_name(name)
285
+ contiguous_index = sympy_dot(
286
+ ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
287
+ )
288
+ contiguous_index = self.rename_indexing(contiguous_index)
289
+ self.body.writeline("xindex = " + texpr(contiguous_index))
290
+ self.range_trees[0].lookup(sympy.Integer(1), sympy_product(lengths)).set_name(
291
+ "xindex"
292
+ )
293
+ self.template_mask = mask
294
+ self.template_indices = indices
295
+ output_index = self.output_node.get_layout().make_indexer()(index_symbols)
296
+ output_index = self.rename_indexing(output_index)
297
+ if output_index == contiguous_index:
298
+ output_index = sympy.Symbol("xindex")
299
+
300
+ epilogue_args = [val]
301
+ for input_node in itertools.chain(
302
+ self.input_nodes[: self.prefix_args],
303
+ self.input_nodes[len(self.input_nodes) - self.suffix_args :],
304
+ ):
305
+ input_node.freeze_layout()
306
+ epilogue_args.append(input_node.make_loader()(index_symbols))
307
+
308
+ V.ops.store(
309
+ self.output_node.get_name(),
310
+ output_index,
311
+ self.epilogue_fn(*epilogue_args),
312
+ )
313
+ self.codegen_body()
314
+
315
+ def hook():
316
+ # more stuff might have been added since the codegen_body above
317
+ self.codegen_body()
318
+ return textwrap.indent(self.body.getvalue(), " ").strip()
319
+
320
+ assert "<STORE_OUTPUT>" not in self.render_hooks
321
+ self.render_hooks["<STORE_OUTPUT>"] = hook
322
+ return "<STORE_OUTPUT>"
323
+
324
+ def render(self, template, kwargs):
325
+ return PartialRender(
326
+ template.render(**self.template_env(), **kwargs),
327
+ self.render_hooks,
328
+ )
329
+
330
+ def make_load(self, name, indices, mask):
331
+ """
332
+ Optional helper called from template code to generate the code
333
+ needed to load from an tensor.
334
+ """
335
+ assert isinstance(indices, (list, tuple))
336
+ assert isinstance(name, str)
337
+ assert isinstance(mask, str)
338
+ stride = self.named_input_nodes[name].get_stride()
339
+ indices = list(map(TritonPrinter.paren, indices))
340
+ assert len(indices) == len(stride)
341
+ index = " + ".join(
342
+ f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
343
+ )
344
+ return f"tl.load({name} + ({index}), {mask})"
345
+
346
+ def template_env(self):
347
+ """
348
+ Generate the namespace visible in the template.
349
+ """
350
+ return {
351
+ fn.__name__: fn
352
+ for fn in [
353
+ self.def_kernel,
354
+ self.size,
355
+ self.stride,
356
+ self.store_output,
357
+ self.make_load,
358
+ ]
359
+ }
360
+
361
+ def indexing(
362
+ self,
363
+ index: sympy.Expr,
364
+ *,
365
+ dense_indexing=False,
366
+ copy_shape=None,
367
+ override_mask=None,
368
+ block_ptr=False,
369
+ ):
370
+ """
371
+ Override the default indexing to use our custom mask and force
372
+ dense indexing.
373
+ """
374
+ return super().indexing(
375
+ index,
376
+ dense_indexing=False,
377
+ copy_shape=self.template_mask,
378
+ override_mask=self.template_mask,
379
+ block_ptr=block_ptr,
380
+ )
381
+
382
+ def initialize_range_tree(self, pid_cache):
383
+ super().initialize_range_tree(pid_cache)
384
+ # ignore default codegen
385
+ self.body.clear()
386
+ self.indexing_code.clear()
387
+
388
+ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
389
+ wrapper = V.graph.wrapper_code
390
+ _, call_args, _ = self.args.python_argdefs()
391
+ call_args = [str(a) for a in call_args]
392
+
393
+ for i in range(len(call_args)):
394
+ if V.graph.is_unspec_arg(call_args[i]):
395
+ call_args[i] = call_args[i] + ".item()"
396
+ if isinstance(call_args[i], sympy.Symbol):
397
+ call_args[i] = texpr(call_args[i])
398
+
399
+ if V.graph.cpp_wrapper:
400
+ # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime
401
+ # if any dynamic dimension is involved. We rely on the Python version
402
+ # of the grid function to generate those grid configs, which may contain
403
+ # symbolic values. The wrapper will use cexpr to print out C++ code
404
+ # appropriately for the grid configs.
405
+ grid_args = [V.graph.sizevars.simplify(s) for s in self.call_sizes] + [
406
+ self.meta
407
+ ]
408
+ grid = self.grid_fn(*grid_args)
409
+
410
+ wrapper.generate_kernel_call(
411
+ name,
412
+ call_args,
413
+ device_index=V.graph.scheduler.current_device.index,
414
+ grid=grid,
415
+ triton_meta=self.triton_meta,
416
+ )
417
+ else:
418
+ stream_name = wrapper.write_get_raw_stream(
419
+ V.graph.scheduler.current_device.index
420
+ )
421
+
422
+ wrapper.add_import_once(f"import {self.grid_fn.__module__}")
423
+ meta = wrapper.add_meta_once(self.meta)
424
+
425
+ grid_call = [
426
+ texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes
427
+ ] + [meta]
428
+ grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})"
429
+ wrapper.writeline(
430
+ f"{name}.run({', '.join(call_args)}, grid={grid_call}, stream={stream_name})"
431
+ )
432
+
433
+
434
+ @functools.lru_cache(None)
435
+ def _jinja2_env():
436
+ try:
437
+ import jinja2
438
+
439
+ return jinja2.Environment(
440
+ undefined=jinja2.StrictUndefined,
441
+ )
442
+ except ImportError:
443
+ return None
444
+
445
+
446
+ class TritonTemplate(KernelTemplate):
447
+ index_counter = itertools.count()
448
+ all_templates: Dict[str, "TritonTemplate"] = dict()
449
+
450
+ def __init__(self, name: str, grid: Any, source: str, debug=False):
451
+ super().__init__(name)
452
+ self.grid = grid
453
+ self.template = self._template_from_string(source)
454
+ assert name not in self.all_templates, "duplicate template name"
455
+ self.all_templates[name] = self
456
+ self.debug = debug
457
+
458
+ def generate(
459
+ self,
460
+ input_nodes,
461
+ layout,
462
+ num_stages,
463
+ num_warps,
464
+ prefix_args=0,
465
+ suffix_args=0,
466
+ epilogue_fn=identity,
467
+ **kwargs,
468
+ ):
469
+ assert self.template, "requires jinja2"
470
+ defines = StringIO()
471
+ for name, val in kwargs.items():
472
+ defines.write(f" {name} : tl.constexpr = {val}\n")
473
+ defines = defines.getvalue()
474
+
475
+ fake_out = ir.Buffer("buf_out", layout)
476
+ kernel_name = f"triton_{self.name}"
477
+
478
+ numel = sympy_product(layout.size)
479
+ buffers = itertools.chain(input_nodes, (fake_out,))
480
+ if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
481
+ raise NotImplementedError(
482
+ "64-bit indexing is not yet implemented for triton templates"
483
+ )
484
+
485
+ kernel_options = dict(
486
+ input_nodes=input_nodes,
487
+ defines=defines,
488
+ num_stages=num_stages,
489
+ num_warps=num_warps,
490
+ grid_fn=self.grid,
491
+ meta=kwargs,
492
+ call_sizes=layout.size,
493
+ prefix_args=prefix_args,
494
+ suffix_args=suffix_args,
495
+ epilogue_fn=epilogue_fn,
496
+ index_dtype="tl.int32",
497
+ )
498
+ with patch.object(
499
+ V.graph, "get_dtype", self._fake_get_dtype(fake_out)
500
+ ), TritonTemplateKernel(
501
+ kernel_name=kernel_name,
502
+ output_node=fake_out,
503
+ use_jit=True,
504
+ **kernel_options,
505
+ ) as kernel:
506
+ try:
507
+ code = kernel.render(self.template, kwargs).finalize()
508
+ except ZeroDivisionError:
509
+ # TODO(nmacchioni): fix sympy division by zero
510
+ return None
511
+ if self.debug:
512
+ print("Generated Code:\n", code)
513
+ extra = (
514
+ "-".join(
515
+ [
516
+ *[
517
+ f"{kwarg}={repr(kwargs[kwarg])}"
518
+ for kwarg in sorted(kwargs.keys())
519
+ ],
520
+ f"num_stages={num_stages}",
521
+ f"num_warps={num_warps}",
522
+ ]
523
+ )
524
+ + "-"
525
+ )
526
+ mod = PyCodeCache.load(code, extra)
527
+ _, call_args, _ = kernel.args.python_argdefs()
528
+
529
+ expected_args = list(unique(x.get_name() for x in input_nodes))
530
+ expected_args.extend([fake_out.get_name()])
531
+ assert list(call_args)[: len(expected_args)] == expected_args, (
532
+ call_args,
533
+ expected_args,
534
+ )
535
+ extra_args = V.graph.sizevars.size_hints(
536
+ map(sympy.expand, call_args[len(expected_args) :]),
537
+ fallback=config.unbacked_symint_fallback,
538
+ )
539
+
540
+ kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
541
+
542
+ def make_kernel_render(out_node):
543
+ kernel = TritonTemplateKernel(
544
+ kernel_name=str(Placeholder.KERNEL_NAME),
545
+ output_node=out_node,
546
+ use_jit=False,
547
+ **kernel_options,
548
+ )
549
+ render = functools.partial(
550
+ kernel.render,
551
+ self.template,
552
+ kwargs,
553
+ )
554
+ return kernel, render
555
+
556
+ # create the BenchmarkRequest
557
+ assert mod.__file__ is not None
558
+ grid = self.grid(
559
+ *V.graph.sizevars.size_hints(
560
+ layout.size,
561
+ fallback=config.unbacked_symint_fallback,
562
+ ),
563
+ kwargs,
564
+ )
565
+ bmreq = TritonBenchmarkRequest(
566
+ module_path=mod.__file__,
567
+ module_cache_key=mod.key,
568
+ kernel_name=kernel_name,
569
+ grid=grid,
570
+ extra_args=extra_args,
571
+ num_stages=num_stages,
572
+ num_warps=num_warps,
573
+ matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
574
+ input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
575
+ output_tensor_meta=TensorMeta.from_irnodes(layout),
576
+ )
577
+
578
+ return TritonTemplateCaller(
579
+ kernel_hash_name,
580
+ input_nodes,
581
+ layout,
582
+ make_kernel_render,
583
+ extra.strip("-").replace("-", ", "),
584
+ bmreq,
585
+ log_info={
586
+ "tile_shape": str(
587
+ (
588
+ kwargs.get("BLOCK_M", -1),
589
+ kwargs.get("BLOCK_K", -1),
590
+ kwargs.get("BLOCK_N", -1),
591
+ )
592
+ ),
593
+ "num_stages": num_stages,
594
+ "num_warps": num_warps,
595
+ "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
596
+ "acc_type": str(kwargs.get("ACC_TYPE", None)),
597
+ },
598
+ )
599
+
600
+
601
+ class ExternKernelChoice:
602
+ def __init__(
603
+ self,
604
+ kernel,
605
+ cpp_kernel=None,
606
+ *,
607
+ name=None,
608
+ has_out_variant=True,
609
+ op_overload=None,
610
+ use_fallback_kernel=False,
611
+ ):
612
+ super().__init__()
613
+ name = name or kernel.__name__
614
+ assert callable(kernel)
615
+ assert not hasattr(extern_kernels, name), "duplicate extern kernel"
616
+ self.name = name
617
+ self.cpp_kernel_name = cpp_kernel
618
+ self.has_out_variant = has_out_variant
619
+ setattr(extern_kernels, name, kernel)
620
+ self.op_overload = op_overload
621
+ self.use_fallback_kernel = use_fallback_kernel
622
+
623
+ def to_callable(self):
624
+ return getattr(extern_kernels, self.name)
625
+
626
+ def call_name(self):
627
+ return f"extern_kernels.{self.name}"
628
+
629
+ @functools.lru_cache(None)
630
+ def hash_key(self):
631
+ fn = self.to_callable()
632
+ parts = [
633
+ self.name,
634
+ getattr(fn, "__name__", ""),
635
+ getattr(fn, "__module__", ""),
636
+ ]
637
+ try:
638
+ parts.append(inspect.getsource(fn))
639
+ except Exception:
640
+ pass
641
+ return code_hash("-".join(parts))
642
+
643
+ def bind(
644
+ self,
645
+ input_nodes,
646
+ layout,
647
+ ordered_kwargs_for_cpp_kernel=(),
648
+ **kwargs,
649
+ ):
650
+ self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
651
+ return ExternKernelCaller(
652
+ self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant
653
+ )
654
+
655
+
656
+ class TritonTemplateCaller(ChoiceCaller):
657
+ def __init__(
658
+ self,
659
+ name,
660
+ input_nodes,
661
+ layout,
662
+ make_kernel_render,
663
+ debug_extra,
664
+ bmreq,
665
+ log_info: Optional[
666
+ Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]
667
+ ] = None,
668
+ ):
669
+ super().__init__(name, input_nodes, layout)
670
+ self.make_kernel_render = make_kernel_render
671
+ self.debug_extra = debug_extra
672
+ self.bmreq: TritonBenchmarkRequest = bmreq
673
+ if log_info is None:
674
+ log_info = {}
675
+ self.log_info: Dict[str, Any] = log_info
676
+ self.log_info.update(
677
+ {
678
+ "backend": "Triton",
679
+ "grid": str(self.bmreq.grid),
680
+ "num_stages": self.bmreq.num_stages,
681
+ "num_warps": self.bmreq.num_warps,
682
+ }
683
+ )
684
+
685
+ def benchmark(self, *args, out):
686
+ assert self.bmreq is not None
687
+ return self.bmreq.benchmark(*args, output_tensor=out)
688
+
689
+ def __str__(self):
690
+ return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"
691
+
692
+ def call_name(self):
693
+ return f"template_kernels.{self.name}"
694
+
695
+ def hash_key(self):
696
+ return "-".join(
697
+ [
698
+ self.name.rsplit("_", 1)[0],
699
+ self.bmreq.module_cache_key,
700
+ ]
701
+ )
702
+
703
+ def output_node(self):
704
+ return ir.TensorBox.create(
705
+ ir.TritonTemplateBuffer(
706
+ layout=self.layout,
707
+ inputs=self.input_nodes,
708
+ make_kernel_render=self.make_kernel_render,
709
+ )
710
+ )
711
+
712
+ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
713
+ """Information returned here is logged to the autotune log file when that is enabled."""
714
+ return self.log_info
715
+
716
+
717
+ class ExternKernelCaller(ChoiceCaller):
718
+ def __init__(
719
+ self,
720
+ choice: ExternKernelChoice,
721
+ input_nodes,
722
+ layout,
723
+ kwargs=None,
724
+ *,
725
+ has_out_variant=True,
726
+ ):
727
+ super().__init__(choice.name, input_nodes, layout)
728
+ self.choice = choice
729
+ self.kwargs = kwargs or {}
730
+ self.has_out_variant = has_out_variant
731
+
732
+ def __str__(self):
733
+ return f"ExternKernelCaller({self.choice.call_name()})"
734
+
735
+ def benchmark(self, *args, out):
736
+ if self.has_out_variant:
737
+ return super().benchmark(*args, out=out)
738
+ else:
739
+ algo = self.to_callable()
740
+ out_new = algo(*args)
741
+ torch._C._dynamo.guards.assert_size_stride(
742
+ out_new, tuple(out.size()), tuple(out.stride())
743
+ )
744
+ out.copy_(out_new) # for correctness checking
745
+ return do_bench(lambda: algo(*args))
746
+
747
+ def to_callable(self):
748
+ fn = self.choice.to_callable()
749
+ if self.kwargs:
750
+ return functools.partial(fn, **self.kwargs)
751
+ else:
752
+ return fn
753
+
754
+ def hash_key(self):
755
+ return "-".join(
756
+ [
757
+ self.choice.name,
758
+ *[
759
+ f"{kwarg}={repr(self.kwargs[kwarg])}"
760
+ for kwarg in sorted(self.kwargs.keys())
761
+ ],
762
+ self.choice.hash_key(),
763
+ ]
764
+ )
765
+
766
+ def output_node(self):
767
+ if config.abi_compatible and self.choice.use_fallback_kernel:
768
+ assert (
769
+ self.choice.op_overload is not None
770
+ ), "Please provide an op_overload to use ir.FallbackKernel"
771
+ inner = ir.FallbackKernel.create(
772
+ self.choice.op_overload, *self.input_nodes, **self.kwargs
773
+ )
774
+ else:
775
+ cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc
776
+ inner = cls(
777
+ layout=self.layout,
778
+ inputs=self.input_nodes,
779
+ python_kernel_name=self.choice.call_name(),
780
+ cpp_kernel_name=self.choice.cpp_kernel_name,
781
+ ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel,
782
+ op_overload=self.choice.op_overload,
783
+ kwargs=self.kwargs,
784
+ )
785
+
786
+ return ir.TensorBox.create(inner)
787
+
788
+ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
789
+ """Information returned here is logged to the autotune log file when that is enabled."""
790
+ return {
791
+ "backend": "extern",
792
+ "kernel_call_name": self.choice.call_name(),
793
+ }
794
+
795
+
796
+ class ErrorFromChoice(RuntimeError):
797
+ def __init__(self, msg, choice: ChoiceCaller, inputs_str):
798
+ msg += f"\nFrom choice {choice}\n{inputs_str}"
799
+ super().__init__(msg)
800
+ self.choice = choice
801
+
802
+
803
+ class AlgorithmSelectorCache(PersistentCache):
804
+ def __call__(
805
+ self,
806
+ name,
807
+ choices: List[ChoiceCaller],
808
+ input_nodes,
809
+ layout,
810
+ # optional dict mapping arg indices to the functions
811
+ # generating a torch.Tensor for that input from the
812
+ # corresponding ir.Buffer. if passed for a given
813
+ # arg, the function will be called instead of
814
+ # generating a random torch.Tensor for benchmarking.
815
+ input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
816
+ precompilation_timeout_seconds: int = 60 * 60,
817
+ ):
818
+ from .codegen.cuda.cuda_kernel import CUDATemplateCaller
819
+
820
+ # TODO(nmacchioni): remove once CI tests are fixed
821
+ choices = [choice for choice in choices if choice is not None]
822
+ if len(choices) == 0:
823
+ raise RuntimeError(
824
+ "No choices to select, please consider adding ATEN into max_autotune_gemm_backends "
825
+ "config (defined in torch/_inductor/config.py) to allow at least one choice. "
826
+ )
827
+ log.debug("Max autotune selects from %s choices.", str(len(choices)))
828
+
829
+ if len(choices) == 1:
830
+ if not isinstance(choices[0], CUDATemplateCaller):
831
+ # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
832
+ return choices[0].output_node()
833
+
834
+ @functools.lru_cache(None)
835
+ def make_benchmark_fn():
836
+ return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)
837
+
838
+ def precompile(choices):
839
+ if (
840
+ precompilation_timeout_seconds is None
841
+ or precompilation_timeout_seconds <= 0
842
+ ):
843
+ return
844
+ num_workers = min(
845
+ config.compile_threads,
846
+ torch.get_num_threads(),
847
+ len(choices),
848
+ )
849
+ if num_workers <= 0:
850
+ return
851
+ log.info(
852
+ "Multithreaded precompilation for %d choices using %d worker threads",
853
+ len(choices),
854
+ num_workers,
855
+ )
856
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
857
+ futures = executor.map(
858
+ lambda c: c.precompile(),
859
+ [c for c in choices if hasattr(c, "precompile")],
860
+ timeout=precompilation_timeout_seconds,
861
+ )
862
+ try:
863
+ iterator = iter(futures)
864
+ while True:
865
+ try:
866
+ next(iterator)
867
+ except CUDACompileError:
868
+ log.error( # noqa: G201
869
+ "CUDA Compilation error", exc_info=True
870
+ )
871
+ except TimeoutError:
872
+ log.warning(
873
+ f"Precompilation timed out after {precompilation_timeout_seconds} seconds." # noqa: G004
874
+ )
875
+ except StopIteration:
876
+ pass
877
+ executor.shutdown(wait=True)
878
+
879
+ def autotune(choices):
880
+ try:
881
+ precompile(choices)
882
+ except TimeoutError:
883
+ log.warning(
884
+ "Precompilation phase took longer than timeout allowed. Continuing"
885
+ )
886
+ pass
887
+ return make_benchmark_fn()(choices)
888
+
889
+ if config.autotune_in_subproc:
890
+ from .autotune_process import tuning_pool
891
+
892
+ # do the optional warmup
893
+ tuning_pool.initialize()
894
+
895
+ autotune_start_ts = time.time()
896
+ timings = self.lookup(
897
+ choices,
898
+ name,
899
+ repr([self.key_of(x) for x in input_nodes]),
900
+ autotune,
901
+ )
902
+ autotune_elapse = time.time() - autotune_start_ts
903
+ if timings == {} or choices[0] not in timings:
904
+ return choices[0].output_node()
905
+
906
+ if make_benchmark_fn.cache_info().currsize:
907
+ counters["inductor"]["select_algorithm_autotune"] += 1
908
+ if (
909
+ make_benchmark_fn.cache_info().currsize
910
+ or log.getEffectiveLevel() == logging.DEBUG
911
+ or config.trace.log_autotuning_results
912
+ ):
913
+ self.log_results(name, input_nodes, timings, autotune_elapse)
914
+ selected_choice = builtins.min(timings, key=timings.__getitem__).output_node()
915
+ log.debug("selected choice: %s", str(selected_choice))
916
+ return selected_choice
917
+
918
    @classmethod
    def make_benchmark_fn(
        cls,
        choices,
        input_nodes,
        layout,
        input_gen_fns=None,
    ):
        """
        Build the benchmarking callable used to time autotuning candidates.

        Creates concrete example tensors for ``input_nodes`` (via
        ``input_gen_fns`` overrides or ``benchmark_example_value``), an output
        buffer from ``layout``, and returns a function mapping a list of
        choices to a ``{choice: time}`` dict. Depending on
        ``config.autotune_in_subproc``, Triton candidates may be benchmarked
        in a subprocess while extern/ATen candidates always run in-process.
        """
        if input_gen_fns is None:
            input_gen_fns = {}

        # de-duplicate args
        # Keyed by buffer name so that the same IR buffer appearing multiple
        # times among input_nodes materializes only one example tensor.
        unique_example_inputs = {
            x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
            for i, x in enumerate(input_nodes)
        }
        example_inputs = list(unique_example_inputs.values())
        # Re-view each base tensor with the node's exact size/stride/offset;
        # extern (ATen) kernels expect the sliced view, not the base storage.
        example_inputs_extern = [
            torch.as_strided(
                unique_example_inputs[input_node.get_name()],
                V.graph.sizevars.size_hints(
                    input_node.get_size(),
                    fallback=config.unbacked_symint_fallback,
                ),
                V.graph.sizevars.size_hints(
                    input_node.get_stride(),
                    fallback=config.unbacked_symint_fallback,
                ),
                V.graph.sizevars.size_hint(
                    input_node.get_layout().offset,
                    fallback=config.unbacked_symint_fallback,
                ),
            )
            for input_node in input_nodes
        ]

        out = cls.benchmark_example_value(layout)
        out_extern = torch.as_strided(
            out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
        )
        if VERIFY:
            # Use the first choice as the reference implementation for
            # numerical verification of all subsequent choices.
            choices[0].benchmark(*example_inputs_extern, out=out_extern)
            expected = out_extern.clone()

        if DEBUG:
            print(f"{len(choices)} tuning requests:")

        def debug_str():
            # Render the example inputs/output as reproducible
            # torch.empty_strided(...) source lines for error reports.
            def tensor_repr(x):
                return (
                    f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, "
                    f"dtype={x.dtype!r}, device={x.device.type!r})"
                )

            lines = [
                "inputs = [",
            ]
            for x in example_inputs:
                lines.append(f" {tensor_repr(x)},")
            lines += ["]", f"out = {tensor_repr(out)}", ""]
            return "\n".join(lines)

        def benchmark_choice_in_current_process(choice):
            # Zero the output so VERIFY comparisons are not polluted by a
            # previous candidate's result.
            out.zero_()
            if isinstance(choice, ExternKernelCaller):
                # aten kernels want the offset baked in for sliced tensors
                result = choice.benchmark(*example_inputs_extern, out=out_extern)
            else:
                # triton templates want the base pointer for sliced tensors
                result = choice.benchmark(*example_inputs, out=out)
            if VERIFY:
                torch.testing.assert_close(out_extern, expected, **VERIFY)
            torch.cuda.synchronize()  # shake out any CUDA errors
            return result

        def benchmark_in_current_process(choices):
            # Time every choice; compile failures and GPU-capacity errors are
            # soft failures (inf time) so autotuning can continue, while
            # other errors are escalated with a reproducer string.
            timings = {}
            for choice in choices:
                try:
                    timing = benchmark_choice_in_current_process(choice)
                except CUDACompileError as e:
                    log.warning(
                        "CUDA compilation error: \n%s. \nIgnore this choice.", str(e)
                    )
                    timing = float("inf")
                except RuntimeError as e:
                    msg = str(e)
                    if "invalid argument" in msg:
                        msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n"
                        log.warning(msg)
                        timing = float("inf")
                    else:
                        if "illegal memory access" in msg:
                            msg += "\n\nEither error in template or triton bug.\n"
                        raise ErrorFromChoice(msg, choice, debug_str())  # noqa: TRY200
                except AssertionError as e:
                    raise AssertionError(  # noqa: TRY200
                        f"Incorrect result from choice {choice}\n\n{e}"
                    )

                timings[choice] = timing

            return timings

        def benchmark_in_sub_process(choices):
            from . import autotune_process

            # only benchmark triton kernel in sub process for now.
            # ATen/Extern kernel are still benchmarked in the current process.
            extern = [c for c in choices if isinstance(c, ExternKernelCaller)]
            triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]

            timings = benchmark_in_current_process(extern)
            timings.update(autotune_process.benchmark_in_sub_process(triton))
            return timings

        benchmark = (
            benchmark_in_sub_process
            if config.autotune_in_subproc
            else benchmark_in_current_process
        )

        return benchmark
1041
+
1042
    @staticmethod
    def log_results(
        name: str,
        input_nodes: List[ir.IRNode],
        timings: Dict[ChoiceCaller, float],
        elapse: float,
    ):
        """
        Record and optionally print autotuning results.

        Always forwards the results to the debug tracer; additionally prints
        a human-readable leaderboard to stderr when max-autotune is enabled
        and PRINT_AUTOTUNE is set.

        Args:
            name: kernel/op name being autotuned.
            input_nodes: IR inputs, used only to render their sizes.
            timings: per-choice benchmark time in milliseconds.
            elapse: total wall-clock seconds spent autotuning.
        """
        V.debug.log_autotuning_results(name, input_nodes, timings, elapse)
        if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE:
            return
        # e.g. "64x128, 128x32" — one "AxBxC" entry per input node.
        sizes = ", ".join(
            [
                "x".join(
                    map(
                        str,
                        V.graph.sizevars.size_hints(
                            n.get_size(), fallback=config.unbacked_symint_fallback
                        ),
                    )
                )
                for n in input_nodes
            ]
        )
        # Show all entries at DEBUG level, otherwise only the top 10.
        n = None if log.getEffectiveLevel() == logging.DEBUG else 10
        top_k = sorted(timings, key=timings.__getitem__)[:n]
        best = top_k[0]
        best_time = timings[best]
        sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
        for choice in top_k:
            result = timings[choice]
            if result:
                # Relative column: best_time/result, 100% == the winner.
                sys.stderr.write(
                    f" {choice.name} {result:.4f} ms {best_time/result:.1%}\n"
                )
            else:
                # Guard against a zero timing (would divide by zero above).
                sys.stderr.write(
                    f" {choice.name} {result:.4f} ms <DIVIDED BY ZERO ERROR>\n"
                )

        autotune_type_str = (
            "SubProcess" if config.autotune_in_subproc else "SingleProcess"
        )
        sys.stderr.write(f"{autotune_type_str} AUTOTUNE takes {elapse:.4f} seconds\n")
1085
+
1086
    @staticmethod
    def benchmark_example_value(node):
        """
        Convert an ir.Buffer into a concrete torch.Tensor we can use for
        benchmarking.
        """
        # A bare Layout may be passed (e.g. for the output); wrap it in a
        # throwaway buffer so the size/stride accessors below work uniformly.
        if isinstance(node, ir.Layout):
            node = ir.Buffer("fake", node)
        # triton templates want the base tensor.
        if isinstance(node, ir.BaseView):
            node = node.unwrap_view()
        # preserve rng states to avoid the rand_strided call below changes
        # the rng states for the real model code.
        with preserve_rng_state():
            return rand_strided(
                V.graph.sizevars.size_hints(
                    node.get_size(),
                    fallback=config.unbacked_symint_fallback,
                ),
                V.graph.sizevars.size_hints(
                    node.get_stride(),
                    fallback=config.unbacked_symint_fallback,
                ),
                device=node.get_device(),
                dtype=node.get_dtype(),
                # Allocate extra room so the layout's storage offset is valid.
                extra_size=node.layout.offset,
            )
1113
+
1114
    @staticmethod
    def key_of(node):
        """
        Extract the pieces of an ir.Buffer that we should invalidate cached
        autotuning results on.
        """
        sizevars = V.graph.sizevars
        # Flat tuple: device type, dtype, then concrete size hints, stride
        # hints, and storage offset.  Any change to these invalidates a
        # previously cached autotune selection.
        return (
            node.get_device().type,
            str(node.get_dtype()),
            *sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            *sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            sizevars.size_hint(
                node.get_layout().offset,
                fallback=config.unbacked_symint_fallback,
            ),
        )
1137
+
1138
+
1139
# Process-wide singleton; created lazily on the first autotune request.
_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None


def autotune_select_algorithm(*args, **kwargs):
    """Module-level entry point for algorithm selection.

    Lazily instantiates a single shared AlgorithmSelectorCache and forwards
    all positional/keyword arguments to it, so repeated autotune calls reuse
    previously benchmarked results.
    """
    global _ALGORITHM_SELECTOR_CACHE
    cache = _ALGORITHM_SELECTOR_CACHE
    if cache is None:
        cache = AlgorithmSelectorCache()
        _ALGORITHM_SELECTOR_CACHE = cache
    return cache(*args, **kwargs)
1147
+
1148
+
1149
def realize_inputs(*args):
    """Realize IR inputs and force a stride-1 innermost dimension.

    A single argument is realized and returned directly; multiple arguments
    are processed element-wise, yielding a list of realized nodes.
    """
    if len(args) != 1:
        return [realize_inputs(x) for x in args]
    (node,) = args
    return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(node))
1153
+
1154
+
1155
+ # ensure lowering is imported so that `extern_kernels.*` is populated
1156
+ from . import lowering # noqa: F401
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import itertools
3
+ import logging
4
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
5
+
6
+ import sympy
7
+ from sympy import Expr
8
+
9
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
10
+ from torch.utils._sympy.functions import FloorDiv, ModularIndexing
11
+ from torch.utils._sympy.value_ranges import bound_sympy
12
+
13
+ from .utils import sympy_index_symbol, sympy_subs, VarRanges
14
+ from .virtualized import V
15
+
16
# Module-level logger used for size-var simplification diagnostics.
log = logging.getLogger(__name__)
17
+
18
+
19
+ # This class is a little awkward, because ShapeEnv is doing most of the heavy
20
+ # lifting and in some cases we should be directly passing through to ShapeEnv,
21
+ # but there is some extra inductor logic that needs to be handled here
22
class SizeVarAllocator:
    """Inductor-side wrapper over ShapeEnv for symbolic size reasoning.

    Adds caching, index-expression simplification, precomputed-size ("ps")
    replacement bookkeeping, and guard/size-hint helpers on top of the
    underlying ShapeEnv.
    """

    def __init__(self, shape_env=None):
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        # Live views into the ShapeEnv's state (shared, not copied).
        self.var_to_val = self.shape_env.var_to_val
        self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
        # The basic idea is if we have some complicated sympy expression
        # f(s0), we may choose to precompute it on the host and then replace
        # all occurrences of that sympy expression with ps0, so that when we
        # codegen we simply reference ps0 directly without repeating
        # f(s0). Unlike regular size variables, ps variables cannot be
        # guarded upon; so if we are asked to guard on a Sympy expression
        # which potentially could have already had a precomputed replacement
        # on it, we are obligated to invert the precomputed replacements
        # (inv_precomputed_replacements).
        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = dict()
        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = dict()
        # Memoized entry points (caches invalidate when replacements change).
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()

    def simplify(self, expr: Expr):
        # Expand then apply all known symbol replacements.
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]:
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: Dict[Tuple[Any, ...], Expr] = dict()
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: Dict[Tuple[Any, ...], Any] = dict()
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (*index_vars, *sizes, *index_formulas)
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.
        """

        expr = join_dimensions(self.simplify(expr))
        original_expr = expr

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if self.statically_known_leq(var_ranges[v], divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            # Try to prove the modulus never triggers so ModularIndexing can
            # be replaced by a plain FloorDiv.
            base = remove_zero_terms(base, divisor)
            base_pos = True
            if isinstance(base, ModularIndexing):
                # for modular indexing, biggest values from the ranges don't necessarily result in
                # the biggest result, the biggest result is modulus - 1
                base_s = base.args[2] - 1
            elif not base.has(ModularIndexing):
                # actual iteration range is to size-1
                iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
                base_lowest = sympy_subs(base, iter_ranges_zero)
                if self.statically_known_leq(0, base_lowest):  # type: ignore[arg-type]
                    # can't replace with indexing div if base can be negative
                    base_pos = True
                else:
                    base_pos = False
                iter_ranges = {k: v - 1 for k, v in var_ranges.items()}
                base_s = sympy_subs(base, iter_ranges)
            else:
                base_s = base
            if self.statically_known_lt(base_s, modulus * divisor) and base_pos:
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                    sympy.Wild("modulus"),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                ),
                visit_indexing_div,
            )

        # Iterate to a fixed point: a rewrite may expose further rewrites.
        if expr != original_expr:
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(
        self, index_vars: List[sympy.Symbol], sizes, index_formulas
    ):
        """
        Try to remove as many axis from loop iterations as possible, by:
        1) removing size==1 dimensions
        2) fuse contiguous dimensions into a single loop

        Returns (new_sizes, reindex, prune) where removed dims are marked
        with None internally, `reindex` maps new index vars back to the old
        rank, and `prune` drops removed dims from an old-rank index list.
        """
        sizes = list(map(self.simplify, sizes))

        strides = [self.stride_vars(x, index_vars) for x in index_formulas]
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

        for i in range(len(sizes)):
            if sizes[i] == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            # Dims a and b are fusable iff, in every index formula, dim b is
            # exactly "dim a scaled by size[a]" (i.e. they are contiguous).
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    v = sympy_index_symbol("_merge_tester")
                    expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            # Map an index over the surviving dims back to the original
            # rank, inserting 0 for removed dims.
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.Integer(0))
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            # Drop entries corresponding to removed dims.
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    # Note - [On Statically Known]
    #
    # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
    # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
    # true, we add a guard and return True, otherwise, False.
    #
    # def maybe_guard_foo(args):
    #   if size_hinted_check(args):
    #     return False # No guard, no optim
    #   guard(args) # Make a guard
    #   return True # Safe to apply optimization
    #
    # The prior system incurred a guard, and green lit an optimization.
    #
    # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
    # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
    # return False.
    #
    # def maybe_guard_foo(args):
    #   if all_static(args):
    #     return True # Safe to apply optimization
    #   else:
    #     return False # No guard, no optim

    # See Note - [On Statically Known]

    def is_expr_static_and_true(self, expr: Union[Expr, int]) -> bool:
        # True only when the expression statically evaluates to a truthy
        # value; never adds a guard.
        if expr in (True, False):
            return bool(expr)

        try:
            simplified = self.shape_env._maybe_evaluate_static(expr)
            if simplified is not None:
                return bool(simplified)
        except Exception:
            log.debug("Could not simplify %s", expr)

        return False

    def statically_known_equals(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right are equal.
        """
        return self.is_expr_static_and_true(sympy.Eq(left, right))  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
        """
        if len(left) != len(right):
            return False
        if all(self.statically_known_equals(l, r) for l, r in zip(left, right)):
            return True
        return False

    # See Note - [On Statically Known]
    def statically_known_leq(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
        """
        expr = left <= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_lt(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than right.
        """
        expr = left < right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_multiple_of(self, numerator: Expr, denominator: Expr) -> bool:
        """
        Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
        """
        expr = sympy.Eq(numerator % denominator, 0)
        return self.is_expr_static_and_true(expr)  # type: ignore[arg-type]

    # The guard functions require you to ALREADY KNOW that a particular
    # condition holds. If you don't know (you want to guard on an expression
    # being a particular value, and then get access to that value), use
    # the evaluate functions.

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        # ps variables cannot be guarded on, so invert them first.
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
        return left

    def guard_leq(self, left: Expr, right: Expr) -> None:
        # left <= right  <=>  left < right + 1 (integer domain assumed).
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        assert self.shape_env.evaluate_expr(sympy.Lt(left, right))

    def expect_true(self, expr: Expr, *, msg: str) -> None:
        # Defer the check to a runtime assert instead of a compile-time guard.
        expr = sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        self.shape_env.defer_runtime_assert(expr, msg, fx_node=None)

    def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr:
        # Prefer returning the expression without unbacked symints
        if self.shape_env.is_unbacked_symint(left):
            self.expect_true(sympy.Eq(left, right), msg=msg)  # type: ignore[arg-type]
            return right
        elif self.shape_env.is_unbacked_symint(right):
            self.expect_true(sympy.Eq(left, right), msg=msg)  # type: ignore[arg-type]
            return left
        else:
            return self.guard_equals(left, right)

    def guarded_order(self, seq):
        """
        Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
        Used for generating block_ptrs.
        """
        seq = [*map(self.remove_precomputed_replacements, seq)]
        # Sort by concrete hint, remembering each element's original slot.
        seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)]
        seq.sort()
        order = [-1] * len(seq)
        last_var = None
        for new_index, (_, orig_index, var) in enumerate(seq):
            order[orig_index] = new_index
            if last_var is not None:
                # Guard that consecutive sorted elements keep their ordering.
                self.guard_leq(last_var, var)
            last_var = var
        return order

    # The evaluate functions evaluate some symbolic sympy expression
    # (NB: not necessarily an Expr) and return what the concrete result
    # is, guarding on the expression being that result

    # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
    # as this will ensure that you actually have a sympy'ified expression,
    # and will prevent you from incorrectly writing evaluate_expr(a == b)
    # which does the wrong thing if a or b is a sympy expression
    def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
        assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
        return self.shape_env.evaluate_expr(sympy.sympify(left))

    def evaluate_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice"""
        lv = self.size_hint(left)
        rv = self.size_hint(right)
        if lv <= rv:
            self.guard_leq(left, right)
            return left
        else:
            self.guard_leq(right, left)
            return right

    def evaluate_max(self, left: Expr, right: Expr) -> Expr:
        """return the larger of left and right, and guard on that choice"""
        # Always choose the opposite of eval min for consistency
        # This means min(a, b) and max(a, b) produce the same guards
        min_val = self.evaluate_min(left, right)
        return right if min_val is left else left

    def evaluate_static_shape(self, left: Expr) -> int:
        # Pin the symbolic size to its current hint, guarding on equality.
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def evaluate_static_shapes(self, left: List[Expr]) -> List[int]:
        return [self.evaluate_static_shape(x) for x in left]

    def remove_precomputed_replacements(self, expr: Expr) -> Expr:
        # Replace ps* symbols with the host expressions they stand for.
        if any(s.name.startswith("ps") for s in expr.free_symbols):  # type: ignore[attr-defined]
            return sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        return expr

    def symbolic_hint(self, expr: Expr) -> Expr:
        # Substitute all hints into expr, but leave unbacked symints alone
        if not isinstance(expr, Expr):
            assert isinstance(expr, int)
            return expr
        free_symbols = expr.free_symbols
        if not free_symbols:
            return int(expr)  # type: ignore[return-value]
        expr = self.remove_precomputed_replacements(expr)
        return sympy_subs(expr, self.var_to_val)

    def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int:
        """Return a concrete int for expr; for unbacked symints use
        `fallback` clamped into the expression's known value range."""
        out = self.symbolic_hint(expr)
        if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
            # Use the provided heuristic fallback hint
            sym_vrs = {
                s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols
            }
            if all(vr is not None for vr in sym_vrs.values()):
                expr_vr = bound_sympy(expr, sym_vrs)  # type: ignore[arg-type]
                lower = self.size_hint(expr_vr.lower)  # type: ignore[arg-type]
                upper = self.size_hint(expr_vr.upper)  # type: ignore[arg-type]
                fallback = min(max(fallback, lower), upper)
            return fallback
        try:
            return int(out)
        except Exception:
            log.debug("failed on: %s", out)
            raise

    def size_hints(
        self,
        exprs: Iterable[Expr],
        *,
        fallback: Optional[int] = None,
    ) -> Tuple[int, ...]:
        return tuple(self.size_hint(x, fallback=fallback) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        """
        fn_cache = functools.lru_cache(maxsize)(fn)
        prior_len = len(self.replacements)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal prior_len
            if prior_len != len(self.replacements):
                prior_len = len(self.replacements)
                fn_cache.cache_clear()
            return fn_cache(*args, **kwargs)

        return wrapper

    def make_stride_vars_cache(self):
        cache = self._lru_cache(self._stride_vars)

        def stride_vars(
            index: Expr,
            vars: List[sympy.Symbol],
            support_vars: Optional[List[sympy.Symbol]] = None,
        ) -> List[Expr]:
            # Tuples make the args hashable for the lru cache.
            if not support_vars:
                support_vars = vars
            return cache(index, tuple(vars), tuple(support_vars))

        return stride_vars

    def _stride_vars(
        self, index: Expr, vars: List[sympy.Symbol], support_vars: List[sympy.Symbol]
    ) -> List[Expr]:
        """Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element

        """
        strides = []
        index = self.simplify(index)
        # remove any offset
        index = index - sympy_subs(
            index, {v: sympy.Integer(0) for v in support_vars if v != 0}
        )
        for i in range(len(vars)):
            # drop all the other dims
            index_dim = sympy_subs(
                index,
                {
                    support_vars[j]: sympy.Integer(0)
                    for j in range(len(support_vars))
                    if vars[i] != support_vars[j] and support_vars[j] != 0
                },
            )
            v = vars[i]
            if v == 0:
                strides.append(sympy.Integer(0))
            else:
                # TODO(jansel): should we use sympy.diff here?
                strides.append(
                    sympy_subs(index_dim, {v: sympy.Integer(1)})
                    - sympy_subs(index_dim, {v: sympy.Integer(0)})
                )
        return strides

    def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
        """Extract offset part of an indexing expression"""
        index = self.simplify(index)
        return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})

    def stride_hints(
        self,
        index: Expr,
        vars: List[sympy.Symbol],
        support_vars: Optional[List[sympy.Symbol]] = None,
    ) -> List[int]:
        # Indirect-indexing symbols have no meaningful stride; zero them out.
        for v in index.free_symbols:
            if v.name.startswith("indirect"):  # type: ignore[attr-defined]
                index = sympy_subs(index, {v: 0})  # type: ignore[dict-item]
        result = []
        for s in self.stride_vars(index, vars, support_vars):
            try:
                result.append(self.size_hint(s))
            except TypeError:
                # Non-integer (symbolic) stride: report 0 as "unknown".
                result.append(0)
        return result

    def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
        # Order dims by |stride|, placing zero strides last.
        strides = tuple(map(abs, self.stride_hints(index, vars)))
        order = list(range(len(strides)))
        order.sort(key=lambda x: (strides[x] == 0, strides[x]))
        return order

    def lookup_precomputed_size(self, expr: Expr) -> Expr:
        # Trivial expressions don't need precomputation.
        if (
            isinstance(expr, (int, sympy.Symbol, sympy.Number))
            or expr.is_number
            or expr.is_symbol
        ):
            return expr
        expr = self.remove_precomputed_replacements(expr)
        if expr not in self.precomputed_replacements:
            sym = sympy_index_symbol(f"ps{len(self.precomputed_replacements)}")
            self.precomputed_replacements[expr] = sym
            self.inv_precomputed_replacements[sym] = expr
        return self.precomputed_replacements[expr]

    def free_symbols(self) -> Set[sympy.Symbol]:
        # Symbols with hints, minus the ones already eliminated by
        # replacements.
        return set(self.var_to_val.keys()) - set(self.replacements.keys())
557
+
558
+
559
def join_dimensions(expr: Expr) -> Expr:
    """Fuse split ModularIndexing/FloorDiv terms back into one dimension.

    Cheap pre-filter: only a sympy.Add containing ModularIndexing can match
    the fusion patterns, so everything else is returned untouched without
    paying for the cached rewrite.
    """
    if isinstance(expr, sympy.Add) and expr.has(ModularIndexing):
        return _join_dimensions_cached(expr)
    return expr  # fast exit path
563
+
564
+
565
@functools.lru_cache(256)
def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
    becomes
    ModularIndexing(i0, 1, 128)
    ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
    becomes i0


    This type of pattern can come from view operations
    """
    assert isinstance(expr, sympy.Add)

    # Wildcards for the pattern matching below; scale excludes 0 so a zero
    # coefficient never matches.
    scale = sympy.Wild("scale", exclude=[0])
    base = sympy.Wild("base")
    divisor = sympy.Wild("divisor")
    mod1 = sympy.Wild("modulus")
    mod2 = sympy.Wild("modulus2")
    # Pass 1: fuse two ModularIndexing terms covering adjacent ranges into a
    # single ModularIndexing with the combined modulus.
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                m2 = term2.match(
                    m1[scale]
                    * m1[mod1]
                    * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2)
                )
                if m2 and term1 != term2:
                    # Recurse on the rewritten sum: more fusions may now apply.
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale]
                        * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2])
                    )
                    return expr
    # Pass 2: fuse a ModularIndexing low part with its FloorDiv high part
    # into a plain FloorDiv (i.e. drop the now-redundant modulus).
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                m2 = term2.match(
                    m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1])
                )
                if m2 is not None:  # in case of success we get an empty dict here
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale] * FloorDiv(m1[base], m1[divisor])
                    )
                    return expr
    return expr
618
+
619
+
620
class SimplifyIndexing(V.WrapperHandler):  # type: ignore[name-defined]
    """
    Ops-handler wrapper that simplifies every indexing expression
    (ModularIndexing/FloorDiv) using the known iteration-variable ranges
    before delegating to the wrapped handler.
    """

    def __init__(self, inner, var_ranges: VarRanges):
        super().__init__(inner)
        self.name = "SimplifyIndexing"

        def _simplify(index: Expr) -> Expr:
            # Capture var_ranges once; all ops below route through here.
            return V.graph.sizevars.simplify_with_ranges(index, var_ranges)

        self._simplify = _simplify

    def load(self, name: str, index: sympy.Expr):
        simplified = self._simplify(index)
        return self._inner.load(name, simplified)

    def store(self, name, index, value, mode=None):
        simplified = self._simplify(index)
        return self._inner.store(name, simplified, value, mode=mode)

    def store_reduction(self, name, index, value):
        simplified = self._simplify(index)
        return self._inner.store_reduction(name, simplified, value)

    def index_expr(self, index, dtype):
        simplified = self._simplify(index)
        return self._inner.index_expr(simplified, dtype)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py ADDED
@@ -0,0 +1,1428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import collections
4
+ import contextlib
5
+ import dataclasses
6
+ import enum
7
+ import functools
8
+ import getpass
9
+ import inspect
10
+ import io
11
+ import itertools
12
+ import logging
13
+ import math
14
+ import operator
15
+ import os
16
+ import platform
17
+ import re
18
+ import shutil
19
+ import sys
20
+ import tempfile
21
+ import textwrap
22
+ import time
23
+ import unittest
24
+ from dataclasses import fields
25
+ from datetime import datetime
26
+ from io import StringIO
27
+ from typing import (
28
+ Any,
29
+ Callable,
30
+ Dict,
31
+ Generic,
32
+ Iterable,
33
+ List,
34
+ NamedTuple,
35
+ Optional,
36
+ Protocol,
37
+ Set,
38
+ TypeVar,
39
+ Union,
40
+ ValuesView,
41
+ )
42
+ from unittest import mock
43
+
44
+ import sympy
45
+ from typing_extensions import Concatenate, ParamSpec
46
+
47
+ import torch
48
+ from torch._dynamo.device_interface import get_interface_for_device
49
+ from torch.autograd import DeviceType
50
+ from torch.autograd.profiler_util import EventList
51
+ from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
52
+ from . import config
53
+
54
+ log = logging.getLogger(__name__)
55
+
56
+ _T = TypeVar("_T")
57
+ VarRanges = Dict[sympy.Expr, sympy.Expr]
58
+
59
+
60
def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:
    """
    Returns benchmark results by examining torch profiler events.
    This could be more accurate as it doesn't count CPU side overhead.
    However, this also requires manually excluding irrelevant event, e.g.
    vectorized_elementwise_kernel which is used to fill L2 cache,
    various CUDA events, etc, so could also be fragile.

    Args:
        fn: zero-argument callable that launches the CUDA work to measure.
        warmup: target warmup time, in milliseconds.
        rep: target total measurement time, in milliseconds.

    Returns:
        Mean CUDA time per invocation of ``fn``, in milliseconds.
    """

    fn()
    torch.cuda.synchronize()
    # ~256MB buffer zeroed between runs to flush the L2 cache.
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        # Benchmark
        for i in range(n_repeat):
            # we clear the L2 cache before each run
            cache.zero_()
            # record time of `fn`
            fn()
        # Record clocks
        torch.cuda.synchronize()

    log.debug("raw events")
    log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

    filtered_events = EventList(
        [
            event
            for event in p.events()
            if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
        ]
    )
    if len(filtered_events) % n_repeat != 0:
        # BUGFIX: the message was previously passed logging-style
        # (format string plus bare args) to RuntimeError, which does not
        # interpolate; format it explicitly instead.
        raise RuntimeError(
            "Failed to divide all profiling events into #repeat groups. "
            f"#CUDA events: {len(filtered_events)}, #repeats: {n_repeat}"
        )
    # Exact integer division (divisibility checked above); avoids the float
    # modulo that true division would force in the filter below.
    num_event_per_group = len(filtered_events) // n_repeat
    actual_events = EventList(
        [
            event
            for i, event in enumerate(filtered_events)
            if i % num_event_per_group != 0
        ]
    )
    actual_events._build_tree()
    actual_events = actual_events.key_averages()

    log.debug("profiling time breakdown")
    log.debug(actual_events.table(row_limit=-1))

    res = sum(event.cuda_time_total for event in actual_events) / 1000.0 / n_repeat
    log.debug("profiling results: %s ms", res)
    return res
140
+
141
+
142
def do_bench(*args, **kwargs):
    """Wrapper over triton's do_bench that normalizes the quantile API."""

    @functools.lru_cache(None)
    def load_triton():
        try:
            # NB: Lazily load triton, as importing triton is slow
            # see https://github.com/openai/triton/issues/1599
            from triton.testing import do_bench as triton_do_bench
        except ImportError as exc:
            raise NotImplementedError("requires Triton") from exc

        # triton PR https://github.com/openai/triton/pull/1513 change the
        # quantile fields name from 'percentiles' to 'quantiles'
        # and change the default value from (0.5, 0.2, 0.8) to None.
        # This may break inductor since a caller expects a tuple may get a item.
        #
        # Add a wrapper to maintain the same behavior for inductor.
        # Maybe we should have own implementation of this function?
        params = inspect.signature(triton_do_bench).parameters
        field = "quantiles" if params.get("quantiles") is not None else "percentiles"
        return triton_do_bench, field

    triton_do_bench, quantile_field_name = load_triton()

    kwargs.setdefault(quantile_field_name, (0.5, 0.2, 0.8))
    return triton_do_bench(*args, **kwargs)[0]
171
+
172
+
173
@functools.lru_cache(None)
def has_torchvision_roi_align() -> bool:
    """True iff torchvision is importable and registers its roi_align op."""
    try:
        from torchvision.ops import roi_align  # noqa: F401
    except ImportError:
        return False
    torchvision_ns = getattr(torch.ops, "torchvision", None)
    return roi_align is not None and hasattr(torchvision_ns, "roi_align")
183
+
184
+
185
def conditional_product(*args):
    """Product of all truthy arguments (falsy values are skipped)."""
    factors = (value for value in args if value)
    return functools.reduce(operator.mul, factors)
187
+
188
+
189
+ def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
190
+ if device is None:
191
+ return torch.tensor(0.0).device # default device
192
+ if isinstance(device, str):
193
+ device = torch.device(device)
194
+ if device.type != "cpu" and device.index is None:
195
+ device_interface = get_interface_for_device(device.type)
196
+ return torch.device(device.type, index=device_interface.Worker.current_device())
197
+ return device
198
+
199
+
200
def sympy_product(it):
    """Multiply a sequence of sympy expressions; 1 for an empty sequence."""
    result = sympy.Integer(1)
    for term in it:
        result = result * term
    return result
202
+
203
+
204
def sympy_dot(seq1, seq2):
    """Expanded inner product of two equal-length sequences."""
    assert len(seq1) == len(seq2)
    total = sum(a * b for a, b in zip(seq1, seq2))
    return sympy.expand(total)
207
+
208
+
209
def unique(it: Iterable[_T]) -> ValuesView[_T]:
    """Deduplicate by object identity, preserving first-seen order."""
    seen = {}
    for obj in it:
        seen.setdefault(id(obj), obj)
    return seen.values()
211
+
212
+
213
+ def ceildiv(
214
+ numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
215
+ ) -> Union[int, sympy.Expr]:
216
+ if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr):
217
+ return CeilDiv(numer, denom)
218
+ # TODO: There is a bug in a call to this function, to repro:
219
+ # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
220
+ # --amp --only YituTechConvBert --dynamic-shapes
221
+ assert isinstance(numer, int) and isinstance(
222
+ denom, int
223
+ ), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
224
+ return -(numer // -denom)
225
+
226
+
227
def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    # Smear the highest set bit into every lower position, then add one.
    for shift in (1, 2, 4, 8, 16, 32):
        n |= n >> shift
    return n + 1
238
+
239
+
240
+ def _type_of(key):
241
+ # Use the function here to get rid of dependencies on the Triton during the codegen.
242
+ # Refer to Triton implementation here:
243
+ # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
244
+ # `None` is nullptr. Implicitly convert to *i8.
245
+ if key is None:
246
+ return "*i8"
247
+ dtype_str = str(key).split(".")[-1]
248
+ tys = {
249
+ "bool": "i1",
250
+ "float8e4nv": "fp8e4nv",
251
+ "float8e5": "fp8e5",
252
+ "float8e4b15": "fp8e4b15",
253
+ "float8e4b15x4": "fp8e4b15x4",
254
+ "float8_e4m3fn": "fp8e4nv",
255
+ "float8_e5m2": "fp8e5",
256
+ "float16": "fp16",
257
+ "bfloat16": "bf16",
258
+ "float32": "fp32",
259
+ "float64": "fp64",
260
+ "int8": "i8",
261
+ "int16": "i16",
262
+ "int32": "i32",
263
+ "int64": "i64",
264
+ "uint8": "u8",
265
+ "uint16": "u16",
266
+ "uint32": "u32",
267
+ "uint64": "u64",
268
+ }
269
+ # reinterpret can create triton type
270
+ for v in list(tys.values()):
271
+ tys[v] = v
272
+ return key if isinstance(key, str) else f"*{tys[dtype_str]}"
273
+
274
+
275
def convert_shape_to_inductor(
    lst: Iterable[Union[int, torch.SymInt]]
) -> List[sympy.Expr]:
    """
    Gets the shape and stride of a tensor. For non-symbolic tensors, this is
    trivial. But for symbolic tensors, we need to map from SymIntNode into
    sympy.Expr.
    """
    result = []
    for i in lst:
        if isinstance(i, torch.SymInt):
            result.append(i.node.expr)
        else:
            result.append(sympy.Integer(i))
    return result
286
+
287
+
288
def convert_shape_to_symint(
    lst: Iterable[Union[int, sympy.Expr]]
) -> List[Union[int, torch.SymInt]]:
    """
    Takes a list of shapes from Inductor and converts them into symints (or just
    ints if all shapes are static).
    """
    from .virtualized import V

    result: List[Union[int, torch.SymInt]] = []
    for i in lst:
        if isinstance(i, int):
            result.append(i)
        elif isinstance(i, sympy.Integer):
            result.append(int(i))
        else:
            # Symbolic: rebuild a SymInt through the graph's shape env.
            result.append(V.graph.sizevars.shape_env.create_symintnode(i, hint=None))
    return result
305
+
306
+
307
def is_view(op: torch._ops.OpOverload):
    """
    Does this op overload have aliasing
    """
    assert isinstance(op, torch._ops.OpOverload)
    for arg in op._schema.arguments:
        if arg.alias_info is not None:
            return True
    return False
313
+
314
+
315
def is_pointwise_use(use):
    """Recursively decide whether an FX use is (transitively) pointwise."""
    if use.op != "call_function":
        return False

    target = use.target
    is_op_overload = isinstance(target, torch._ops.OpOverload)
    if not (is_op_overload or target is operator.getitem):
        return False

    # getitem and view ops are transparent: recurse into their users.
    if target is operator.getitem or is_view(target):
        return all(is_pointwise_use(u) for u in use.users)

    return torch.Tag.pointwise in target.tags
328
+
329
+
330
def gen_gm_and_inputs(target, args, kwargs):
    """Build a one-node GraphModule that calls `target`, plus its tensor args."""
    graph = torch.fx.Graph()
    graph_args = []
    tensor_args = []
    for idx, arg in enumerate(args):
        if not isinstance(arg, torch.Tensor):
            graph_args.append(arg)
            continue
        graph_args.append(graph.placeholder(f"arg{idx}"))
        tensor_args.append(arg)
    assert all(not isinstance(v, torch.Tensor) for v in kwargs.values())
    node = graph.call_function(target, tuple(graph_args), kwargs)
    returns = target._schema.returns
    if len(returns) == 1 and str(returns[0].type) == "Tensor":
        # Normalize a single-tensor return into a 1-tuple output.
        node = (node,)
    graph.output(node)

    return torch.fx.GraphModule({}, graph), tensor_args
351
+
352
+
353
+ def synchronize(device: str = "cuda"):
354
+ if device == "cpu":
355
+ return
356
+ device_interface = get_interface_for_device(device)
357
+ if device_interface.is_available():
358
+ device_interface.synchronize()
359
+
360
+
361
def timed(
    model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda"
) -> float:
    """Wall-clock seconds for `times` calls of model(*example_inputs)."""
    synchronize(device)
    torch.manual_seed(1337)
    start = time.perf_counter()
    for _ in range(times):
        result = model(*example_inputs)
        synchronize(device)  # include kernel completion in the measurement
    end = time.perf_counter()
    # GC the result after timing
    assert result is not None  # type: ignore[possibly-undefined]
    return end - start
374
+
375
+
376
def print_performance(
    fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda"
):
    """Print median per-call runtime of `fn` relative to `baseline`; return it."""
    samples = [timed(fn, args, times, device) for _ in range(repeat)]
    took = torch.median(torch.tensor(samples)) / times
    print(f"{took/baseline:.6f}")
    return took
383
+
384
+
385
def precompute_method(obj: Any, method: str):
    """Replace obj.method() with a new method that returns a precomputed constant."""
    value = getattr(obj, method)()
    setattr(obj, method, lambda: value)
389
+
390
+
391
def precompute_methods(obj: Any, methods: List[str]):
    """Replace each named zero-arg method with a precomputed-constant stub."""
    for name in methods:
        precompute_method(obj, name)
395
+
396
+
397
def cmp(a, b) -> int:
    """Three-way comparison: -1 if a < b, 1 if a > b, else 0."""
    if a < b:
        return -1
    if a > b:
        return 1
    return 0
399
+
400
+
401
def pad_listlike(x, size):
    """Broadcast a length-1 list-like to `size` elements; pass through otherwise."""
    if len(x) != 1:
        return x
    return type(x)([x[0]]) * size
406
+
407
+
408
+ # Used to ensure that iterating over a set is deterministic
409
def tuple_sorted(x):
    """Deterministically sort strings or scheduler nodes (used for set iteration)."""
    if not x:
        return []

    def key_of(elem):
        if isinstance(elem, str):
            return elem
        # We expect `elem` to be `scheduler.BaseSchedulerNode` type here,
        # but we are not able to do isinstance assert because of circular dependency
        return elem.get_name()

    return sorted(x, key=key_of)
422
+
423
+
424
+ P = ParamSpec("P")
425
+ RV = TypeVar("RV", covariant=True)
426
+
427
+
428
class CachedMethod(Generic[P, RV], Protocol):
    # Protocol describing what `cache_on_self` returns: a callable method
    # that also exposes `clear_cache(self)` for dropping the memoized value.
    # NOTE(review): `clear_cache` is declared @staticmethod yet takes `self`;
    # this mirrors how cache_on_self attaches a plain function as an
    # attribute on the wrapper — callers invoke it as `obj.method.clear_cache(obj)`.
    @staticmethod
    def clear_cache(self) -> None:
        ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
        ...
435
+
436
+
437
# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
    """Memoize a zero-arg method's result on the instance itself."""
    attr = f"__{fn.__name__}_cache"

    @functools.wraps(fn)
    def wrapper(self):
        try:
            return getattr(self, attr)
        except AttributeError:
            value = fn(self)
            setattr(self, attr, value)
            return value

    def clear_cache(self):
        # Drop the memoized value so the next call recomputes.
        if hasattr(self, attr):
            delattr(self, attr)

    wrapper.clear_cache = clear_cache  # type: ignore[attr-defined]
    return wrapper  # type: ignore[return-value]
453
+
454
+
455
def aggregate_origins(node_schedule):
    """Union the FX origin nodes of everything in a node schedule."""
    from . import ir

    if isinstance(node_schedule, ir.ExternKernel):
        return node_schedule.origins
    if not isinstance(node_schedule, list):
        return set()
    origin_sets = [
        n.node.origins for n in node_schedule if hasattr(n, "node") and n.node
    ]
    return functools.reduce(operator.or_, origin_sets, set())
472
+
473
+
474
def get_fused_kernel_name(node_schedule, descriptive_names):
    """
    Build a descriptive name ("fused_<src>_<src>...") for a fused kernel.

    Args:
        node_schedule: schedule whose origins are aggregated for naming.
        descriptive_names: one of "original_aten", "torch", "inductor_node",
            selecting which metadata level supplies the name parts.

    Raises:
        NotImplementedError: for any other `descriptive_names` value.
    """
    all_origins = aggregate_origins(node_schedule)
    if descriptive_names == "original_aten":
        # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
        sources = [
            origin.meta["original_aten"]._overloadpacket.__name__
            for origin in all_origins
            if origin.op == "call_function"
            and "original_aten" in origin.meta
            and origin.meta["original_aten"] is not None
        ]
        sources = sorted(set(sources))
    elif descriptive_names == "torch":
        # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
        sources = []
        for origin in all_origins:
            if origin.op == "call_function" and "source_fn_stack" in origin.meta:
                source_fn = origin.meta["source_fn_stack"][-1]
                if isinstance(source_fn[1], str):
                    sources.append(source_fn[1])
                else:
                    sources.append(source_fn[1].__name__)
        sources = sorted(set(sources))
    elif descriptive_names == "inductor_node":
        sources = [
            origin.name for origin in all_origins if origin.op == "call_function"
        ]
    else:
        raise NotImplementedError
    # NOTE: removed dead no-op statement `sources = sources` from the original.
    return "_".join(["fused"] + sources)
505
+
506
+
507
def get_kernel_metadata(node_schedule, wrapper):
    """Build a summary comment plus per-node detail lines for a kernel."""
    all_origins = aggregate_origins(node_schedule)
    inductor_nodes = [o for o in all_origins if o.op == "call_function"]

    from_node_dict = collections.defaultdict(list)
    original_aten_dict = collections.defaultdict(list)
    for node in inductor_nodes:
        meta = node.meta
        if meta.get("original_aten") is not None:
            key = str(meta["original_aten"]._overloadpacket)
            original_aten_dict[key].append(node.name)
        if "from_node" in meta:
            from_node_dict[meta["from_node"][0][0]].append(node.name)
    metadata = (
        f"{wrapper.comment} Source Nodes: [{', '.join(sorted(from_node_dict.keys()))}], "
        f"Original ATen: [{', '.join(sorted(original_aten_dict.keys()))}]"
    )
    # trace back to original node here
    detailed_metadata = [
        f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
        for original_node, nodes in sorted(from_node_dict.items())
    ]
    return metadata, "\n".join(detailed_metadata)
531
+
532
+
533
def dominated_nodes(
    initial_queue: Iterable[torch.fx.Node], skip_filter=None
) -> Set[torch.fx.Node]:
    """Returns the set of nodes whose values depend on those within initial_queue"""
    worklist = list(initial_queue)
    dominated = set(worklist)

    while worklist:
        current = worklist.pop()
        for user in current.users:
            if skip_filter is not None and skip_filter(user):
                continue
            if user in dominated:
                continue
            dominated.add(user)
            worklist.append(user)

    return dominated
550
+
551
+
552
def gather_origins(args, kwargs):
    """
    Collect the FX origins of all not-yet-realized (Pointwise) IR arguments.

    NOTE: removed the redundant function-local `import itertools`; the module
    already imports itertools at top level.
    """
    from . import ir

    def is_unrealized_node(n):
        # Unwrap TensorBox/StorageBox wrappers down to the underlying IR node.
        if isinstance(n, ir.TensorBox):
            return is_unrealized_node(n.data)
        if isinstance(n, ir.StorageBox):
            return is_unrealized_node(n.data)
        return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)

    kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
    arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
    return set(itertools.chain(*arg_origins, *kwarg_origins))
567
+
568
+
569
+ def sympy_str(expr: sympy.Expr) -> str:
570
+ """
571
+ Normal sympy str is very slow, this is a lot faster. The result are
572
+ somewhat worse, as it doesn't do as much simplification. So don't
573
+ use this for final codegen.
574
+ """
575
+ if isinstance(expr, sympy.Symbol):
576
+ return expr.name
577
+ if isinstance(expr, sympy.Add):
578
+ return " + ".join(map(sympy_str, expr.args))
579
+ if isinstance(expr, sympy.Mul):
580
+ return " * ".join(map(sympy_str, expr.args))
581
+
582
+ if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)):
583
+ return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
584
+ return str(expr)
585
+
586
+
587
def sympy_index_symbol(name: str) -> sympy.Symbol:
    """
    Used to generate an integer-nonnegative symbol.
    """
    # Shape/stride symbols ("s...") are allocated before Inductor, never here.
    assert name[0] != "s"
    # NOTE: shape symbols are positive (> 0), but index variables are only
    # non-negative (>= 0).
    return sympy.Symbol(name, integer=True, nonnegative=True)
597
+
598
+
599
def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
    """
    When the passed replacement symbol v is a string, it is converted to a symbol with name v that
    have the same replaced expression integer and nonnegative properties.
    """

    def as_replacement(old, new):
        assert isinstance(old, sympy.Expr)
        if not isinstance(new, str):
            return new
        # String replacements become symbols inheriting integrality /
        # nonnegativity from the expression they replace.
        return sympy.Symbol(
            new,
            integer=old.is_integer,  # type: ignore[attr-defined]
            nonnegative=old.is_nonnegative,  # type: ignore[attr-defined]
        )

    # xreplace is faster than subs, but is way more picky
    mapping = {k: as_replacement(k, v) for k, v in replacements.items()}
    return sympy.sympify(expr).xreplace(mapping)
620
+
621
+
622
def free_symbol_startswith(index: sympy.Expr, prefix: str):
    """True if any free symbol of `index` has a name starting with `prefix`."""
    for sym in index.free_symbols:
        if sym.name.startswith(prefix):  # type: ignore[attr-defined]
            return True
    return False
624
+
625
+
626
def free_symbol_has(index: sympy.Expr, pattern: str):
    """True if any free symbol of `index` has `pattern` in its name."""
    for sym in index.free_symbols:
        if pattern in sym.name:  # type: ignore[attr-defined]
            return True
    return False
628
+
629
+
630
def is_symbolic(a: Any) -> bool:
    """True for a SymInt, or a tensor with any symbolic size/stride."""
    if isinstance(a, torch.SymInt):
        return True
    if not isinstance(a, torch.Tensor):
        return False
    dims = itertools.chain(a.size(), a.stride())
    return any(is_symbolic(d) for d in dims)
635
+
636
+
637
def any_is_symbolic(*args: Any) -> bool:
    """True if at least one argument is symbolic (see is_symbolic)."""
    return any(map(is_symbolic, args))
639
+
640
+
641
+ def has_incompatible_cudagraph_ops(gm):
642
+ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
643
+
644
+ forbidden_set = {
645
+ "aten._fused_moving_avg_obs_fq_helper.default",
646
+ "aten._fused_moving_avg_obs_fq_helper_functional.default",
647
+ "aten.multinomial.default",
648
+ "fbgemm.dense_to_jagged.default",
649
+ "fbgemm.jagged_to_padded_dense.default",
650
+ "run_and_save_rng_state",
651
+ "run_with_rng_state",
652
+ "aten._local_scalar_dense",
653
+ # Technically, it's not necessary to ban this, because an
654
+ # assert_scalar with constant arguments can be validly run
655
+ # with CUDA graphs, but the operator is also pointless with
656
+ # constant arguments, so might as well ban
657
+ "aten._assert_scalar",
658
+ }
659
+ if torch.are_deterministic_algorithms_enabled():
660
+ forbidden_set.update(
661
+ {
662
+ "aten._unsafe_index_put.default",
663
+ "aten.index_put.default",
664
+ "aten.index_put_.default",
665
+ "aten.scatter.src",
666
+ "aten.scatter.reduce",
667
+ "aten.scatter.value_reduce",
668
+ "aten.scatter_add_",
669
+ "aten.scatter_add.default",
670
+ "aten.scatter_reduce.two",
671
+ "aten.scatter_reduce_.two",
672
+ "aten.scatter_reduce.two_out",
673
+ }
674
+ )
675
+ for node in gm.graph.nodes:
676
+ if str(node.target) in forbidden_set:
677
+ return True
678
+ if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
679
+ return True
680
+ return False
681
+
682
+
683
def output_node(gm: torch.fx.GraphModule):
    """Get the output node from an FX graph"""
    *_, last_node = gm.graph.nodes
    assert last_node.op == "output"
    return last_node
688
+
689
+
690
# Attempt to import AttrsDescriptor from Triton
# (the class and its fields vary across triton versions, so feature-detect).
try:
    from triton.compiler.compiler import AttrsDescriptor

    attrs_descriptor_available = True
    # Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor
    attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
    ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
    divisible_by_8_available = "divisible_by_8" in attr_desc_fields
except ImportError:
    attrs_descriptor_available = False

# Define `instance_descriptor` function with clear conditional handling
if attrs_descriptor_available:

    def instance_descriptor(
        divisible_by_16=None,
        equal_to_1=None,
        ids_of_folded_args=None,
        divisible_by_8=None,
    ):
        # Prepare the arguments for AttrsDescriptor
        kwargs = {
            "divisible_by_16": divisible_by_16,
            "equal_to_1": equal_to_1,
        }

        # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor
        if ids_of_folded_args_available:
            kwargs["ids_of_folded_args"] = ids_of_folded_args
        if divisible_by_8_available:
            kwargs["divisible_by_8"] = divisible_by_8

        # Instantiate AttrsDescriptor with the prepared arguments
        return AttrsDescriptor(**kwargs)

else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    # (keeps the same four fields, all defaulting to empty tuples).
    instance_descriptor = collections.namedtuple(  # type: ignore[no-redef]
        "instance_descriptor",
        ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
        defaults=[tuple(), tuple(), tuple(), tuple()],
    )
733
+
734
+
735
@functools.lru_cache(None)
def cache_dir() -> str:
    """Return (creating if needed) the inductor on-disk cache directory."""
    path = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if path is None:
        # Per-user temp directory; sanitize the username for filesystem safety.
        sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
        path = os.path.join(
            tempfile.gettempdir(),
            "torchinductor_" + sanitized_username,
        )
    os.makedirs(path, exist_ok=True)
    return path
746
+
747
+
748
@contextlib.contextmanager
def fresh_inductor_cache(cache_entries=None):
    """
    Contextmanager that provides a clean tmp cachedir for inductor.

    Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
    generated with this cache instance.
    """
    with tempfile.TemporaryDirectory() as inductor_cache_dir:
        triton_cache_dir = os.path.join(inductor_cache_dir, "triton")
        inductor_patch = {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
        triton_patch = {"TRITON_CACHE_DIR": triton_cache_dir}
        with mock.patch.dict(os.environ, inductor_patch):
            with mock.patch.dict(os.environ, triton_patch):
                yield
                if isinstance(cache_entries, dict):
                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
                    if os.path.exists(triton_cache_dir):
                        # Report every generated file (ignoring lock files).
                        sizes = {
                            f: os.path.getsize(os.path.join(triton_cache_dir, f))
                            for f in os.listdir(triton_cache_dir)
                            if ".lock" not in f
                        }
                        cache_entries.update(sizes)
774
+
775
+
776
def argsort(seq) -> List[int]:
    """Indices that sort `seq` ascending."""
    # preserve original order for equal strides
    indices = range(len(seq))
    descending = sorted(indices, key=seq.__getitem__, reverse=True)
    return list(reversed(descending))  # noqa: C413
781
+
782
+
783
@functools.lru_cache(8)
def get_dtype_size(dtype):
    """Size in bytes of one element of `dtype`."""
    return torch.empty(0, dtype=dtype).element_size()
786
+
787
+
788
class LineContext(NamedTuple):
    # Marker stored among IndentedBuffer._lines: never rendered as text, but
    # recorded in the line map so output lines can be traced back to `context`.
    context: Any
790
+
791
+
792
class IndentedBuffer:
    """Accumulates generated source lines, tracking an indentation level.

    Lines may be plain strings, DeferredLineBase instances (resolved lazily,
    possibly to nothing), or LineContext markers (emitted only into the line
    map, never into the text).
    """

    tabwidth = 4  # spaces per indentation level

    def __init__(self, initial_indent=0):
        self._lines = []
        self._indent = initial_indent

    def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]:
        # Render to text while collecting (line-number, context) pairs for
        # every LineContext marker encountered.
        buf = StringIO()
        p = 1
        linemap = []
        for line in self._lines:
            if isinstance(line, DeferredLineBase):
                line = line()
                if line is None:
                    # Deferred line was "unwritten": emit nothing.
                    continue
            elif isinstance(line, LineContext):
                linemap.append((p, line.context))
                continue
            assert isinstance(line, str)
            buf.write(line)
            buf.write("\n")
            # Account for embedded newlines so the line map stays accurate.
            p += 1 + line.count("\n")
        return buf.getvalue(), linemap

    def getvalue(self) -> str:
        # Rendered text only, discarding the line map.
        v, _ = self.getvaluewithlinemap()
        return v

    def getrawvalue(self) -> str:
        # Like getvalue(), but honors trailing-backslash line continuations
        # and drops LineContext markers entirely.
        buf = StringIO()
        for line in self._lines:
            if isinstance(line, DeferredLineBase):
                line = line()
                if line is None:
                    continue
            elif isinstance(line, LineContext):
                continue
            assert isinstance(line, str)
            # backslash implies line continuation
            if line.endswith("\\"):
                buf.write(line[:-1])
            else:
                buf.write(line)
                buf.write("\n")
        return buf.getvalue()

    def clear(self):
        # Drop all buffered lines (indent level is kept).
        self._lines.clear()

    def __bool__(self):
        # Truthy iff any line has been buffered.
        return bool(self._lines)

    def prefix(self):
        # Whitespace prefix for the current indent level.
        return " " * (self._indent * self.tabwidth)

    def newline(self):
        self.writeline("\n")

    def writeline(self, line):
        # Append one line, applying the current indentation to text lines;
        # LineContext passes through untouched and deferred lines capture
        # the prefix for later resolution.
        if isinstance(line, LineContext):
            self._lines.append(line)
        elif isinstance(line, DeferredLineBase):
            self._lines.append(line.with_prefix(self.prefix()))
        elif line.strip():
            self._lines.append(f"{self.prefix()}{line}")
        else:
            # Whitespace-only lines are normalized to empty.
            self._lines.append("")

    def writelines(self, lines):
        for line in lines:
            self.writeline(line)

    def indent(self, offset=1):
        # Context manager that raises (or lowers, for negative offset) the
        # indent level for the duration of the `with` block.
        @contextlib.contextmanager
        def ctx():
            self._indent += offset
            try:
                yield
            finally:
                self._indent -= offset

        return ctx()

    def do_indent(self, offset=1):
        self._indent += offset

    def do_unindent(self, offset=1):
        self._indent -= offset

    def splice(self, other_code, strip=False):
        # Append another buffer's (or string's) lines, re-indenting them to
        # this buffer's current level.
        if isinstance(other_code, IndentedBuffer):
            # Strip the other buffer's common leading indentation first.
            dedent = float("inf")
            for line in other_code._lines:
                if not isinstance(line, LineContext) and line:
                    dedent = min(dedent, len(line) - len(line.lstrip()))
            if math.isinf(dedent):
                dedent = 0
            for line in other_code._lines:
                if isinstance(line, LineContext):
                    self._lines.append(line)
                else:
                    IndentedBuffer.writeline(self, line[int(dedent) :])
        else:
            other_code = textwrap.dedent(other_code)
            if strip:
                other_code = other_code.lstrip()
            if not other_code:
                return
            other_code = other_code.rstrip()
            for line in other_code.split("\n"):
                self.writeline(line)

    def __repr__(self):
        return f"{type(self)}({self.getvalue()})"
907
+
908
+
909
class DeferredLineBase:
    """A line that can be 'unwritten' at a later time"""

    def __init__(self, line):
        # Whitespace-only lines collapse to the empty string.
        self.line = line if line.strip() else ""

    def __call__(self) -> Optional[str]:
        """Returns either self.line or None to indicate the line has been 'unwritten'"""
        raise NotImplementedError()

    def _new_line(self, line: str) -> DeferredLineBase:
        """Returns a new deferred line with the same condition"""
        raise NotImplementedError()

    def with_prefix(self, prefix):
        return self._new_line(f"{prefix}{self.line}")

    def lstrip(self):
        return self._new_line(self.line.lstrip())

    def __getitem__(self, index):
        return self._new_line(self.line[index])

    def __bool__(self):
        return bool(self.line)

    def __len__(self):
        return len(self.line)
939
+
940
+
941
@functools.lru_cache(None)
def is_big_gpu(index):
    # A GPU is "big" when it has at least 80 streaming multiprocessors
    # (the V100's count); smaller devices are rejected for
    # max_autotune_gemm.  Cached per device index — device properties do
    # not change within a process.
    sms = torch.cuda.get_device_properties(index).multi_processor_count
    if sms < 80:  # V100
        log.warning("not enough SMs to use max_autotune_gemm mode")
        return False
    return True
948
+
949
+
950
def use_max_autotune() -> bool:
    # True when any of the autotune-enabling config flags is set.  Note the
    # ``or`` chain returns the first truthy flag value itself (the flags are
    # presumably booleans, matching the annotation — confirm in config).
    return (
        config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache
    )
954
+
955
+
956
def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
    # A codegen template may be used only when autotuning is enabled, the
    # output layout lives on a CUDA device with an allowed dtype, and the
    # device is large enough (see is_big_gpu).  ``layout.device.index or 0``
    # maps an unset (None) device index to device 0.
    return (
        use_max_autotune()
        and layout.device.type == "cuda"
        and layout.dtype in allowed_layout_dtypes
        and is_big_gpu(layout.device.index or 0)
    )
963
+
964
+
965
def _use_autotune_backend(backend: str) -> bool:
    """Whether ``backend`` appears in config.max_autotune_gemm_backends,
    a comma-separated, case-insensitive backend list."""
    enabled = [
        name.strip() for name in config.max_autotune_gemm_backends.upper().split(",")
    ]
    return backend.upper() in enabled
969
+
970
+
971
def use_triton_template(layout, *, enable_int32=False):
    """Whether the Triton GEMM template may be used for ``layout``
    (optionally also allowing int32 outputs)."""
    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    if enable_int32:
        layout_dtypes.append(torch.int32)
    if not _use_template_for_cuda(layout, layout_dtypes):
        return False
    return _use_autotune_backend("TRITON")
978
+
979
+
980
def use_cutlass_template(layout):
    """Whether the CUTLASS GEMM template may be used for ``layout``."""
    from .codegen.cuda.cutlass_utils import try_import_cutlass

    # Do not use cutlass template on ROCm
    if torch.version.hip:
        return False

    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
        "CUTLASS"
    )

    if res:
        # CUTLASS is an optional dependency: degrade gracefully with a
        # warning when its Python package cannot be imported.
        if not try_import_cutlass():
            log.warning(
                "Failed to import CUTLASS lib. Please check whether "
                "_inductor.config.cuda.cutlass_dir is set correctly. "
                "Skipping CUTLASS backend for now."
            )
            return False
    return res
1001
+
1002
+
1003
def use_aten_gemm_kernels():
    # ATen GEMM kernels are always available unless autotuning restricts
    # the backend list and "ATEN" is not in it.
    return not use_max_autotune() or _use_autotune_backend("ATEN")
1005
+
1006
+
1007
class DebugDirManager:
    """Context manager that redirects torch._dynamo's debug output into a
    unique temporary directory and, on exit, removes that directory and
    restores the previous setting.

    Each instance draws a distinct id from a shared counter so concurrent
    or nested uses get non-colliding directory names.
    """

    counter = itertools.count(0)
    # Previous value of torch._dynamo.config.debug_dir_root, captured on enter.
    prev_debug_name: str

    def __init__(self):
        self.id = next(DebugDirManager.counter)

    def __enter__(self):
        self.prev_debug_name = torch._dynamo.config.debug_dir_root
        self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
        torch._dynamo.config.debug_dir_root = self.new_name
        return self

    def __exit__(self, *args):
        # ignore_errors: the directory is only created if something actually
        # wrote debug output, so it may legitimately not exist on exit
        # (previously this raised FileNotFoundError in that case).
        shutil.rmtree(self.new_name, ignore_errors=True)
        torch._dynamo.config.debug_dir_root = self.prev_debug_name
1022
+
1023
+
1024
def run_and_get_code(fn, *args, **kwargs):
    """Run ``fn`` and capture the generated source of every module compiled
    by inductor during the call, returning ``(result, source_codes)``."""
    from .graph import GraphLowering

    compile_to_module = GraphLowering.compile_to_module
    source_codes = []

    def patched_compile_to_module(self):
        # Compile as usual, then read back the on-disk source of the module
        # that was just produced.
        mod = compile_to_module(self)
        with open(mod.__file__) as f:
            source_codes.append(f.read())
        return mod

    # If FX code caching is enabled, a hit prevents getting the code.
    with config.patch({"fx_graph_cache": False}):
        with mock.patch.object(
            GraphLowering, "compile_to_module", patched_compile_to_module
        ):
            # Reset dynamo so fn is recompiled (and its code captured)
            # rather than served from a previous compilation.
            torch._dynamo.reset()
            result = fn(*args, **kwargs)
    return result, source_codes
1044
+
1045
+
1046
def run_and_get_triton_code(fn, *args, **kwargs):
    # Convenience wrapper around run_and_get_code that returns only the
    # first generated module's source.
    _, source_codes = run_and_get_code(fn, *args, **kwargs)
    # Can have two outputs if backwards was eagerly compiled
    assert (
        1 <= len(source_codes) <= 2
    ), f"expected one or two code outputs got {len(source_codes)}"
    return source_codes[0]
1053
+
1054
+
1055
@contextlib.contextmanager
def override_lowering(aten_op, override_fn):
    """
    Override the lowering of aten_op with override_fn.
    The first argument of override_fn is the original lowering fn.
    """
    from torch._inductor import lowering

    orig_fn = lowering.lowerings[aten_op]
    try:
        # partial binds the original lowering as override_fn's first argument.
        lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn)
        yield
    finally:
        # Always restore the original lowering, even if the body raised.
        lowering.lowerings[aten_op] = orig_fn
1069
+
1070
+
1071
def add_scheduler_init_hook(pre_fn, post_fn=None):
    """
    Add hook functions to be called at the beginning and end of Scheduler.__init__.
    Used for unit tests.

    Returns the (not yet started) mock.patch object; use it as a context
    manager or call .start()/.stop().
    """
    from torch._inductor.scheduler import Scheduler

    orig_fn = Scheduler.__init__

    def wrapper(scheduler, nodes):
        # pre_fn runs before the real __init__; post_fn (if given) after.
        pre_fn(scheduler, nodes)
        out = orig_fn(scheduler, nodes)
        if post_fn:
            post_fn(scheduler, nodes)
        return out

    return unittest.mock.patch.object(Scheduler, "__init__", wrapper)
1088
+
1089
+
1090
def developer_warning(msg):
    """
    Warnings that will be actionable for PyTorch developers, but not
    end users. Allows us to easily disable them in stable releases but
    keep them on for nightly builds.
    """
    # config.developer_warnings toggles between warning and info severity.
    if config.developer_warnings:
        log.warning(msg)
    else:
        log.info(msg)
1100
+
1101
+
1102
def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
    """
    Return the total number of bytes the arguments of tensor type takes.

    For in/out args, tensor sizes are counted twice: once for reading and
    once for writing.

    The first num_in_out_args arguments are in out tensors.
    """
    total = 0
    for position, arg in enumerate(args):
        if not isinstance(arg, torch.Tensor):
            # Non-tensor arguments contribute nothing.
            continue
        multiplier = 2 if position < num_in_out_args else 1
        total += arg.numel() * arg.element_size() * multiplier
    return total
1116
+
1117
+
1118
def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
    """Format one benchmark result line; when ``color`` is on, lines that
    look slow (>0.012ms and <650 GB/s) are rendered in red."""
    info_str = f"{prefix}{ms:.3f}ms    \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
    is_slow = ms > 0.012 and gb_per_s < 650
    if color and is_slow:
        return red_text(info_str)
    return info_str
1122
+
1123
+
1124
def get_benchmark_name():
    """
    An experimental API used only when config.benchmark_kernel is true.

    The benchmark name is only available at codegen time. So we can not
    directly call it in benchmark_all_kernels which is run after codegen.

    The function assumes the argument after --only is the benchmark name.
    It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc
    scripts, this function may return None.

    There are 2 flavors of --only argument we need handle:
    1. --only model_name
    2. --only=model_name
    """
    # Flavor 1: "--only model_name".
    try:
        only_idx = sys.argv.index("--only")
    except ValueError:
        only_idx = -1
    if only_idx >= 0 and only_idx + 1 < len(sys.argv):
        candidate = sys.argv[only_idx + 1]
        # The following token must be a real value, not another flag.
        if candidate and not candidate.startswith("-"):
            return candidate

    # Flavor 2: "--only=model_name".
    for token in sys.argv:
        if token.startswith("--only="):
            return token[len("--only=") :]
1153
+
1154
+
1155
def is_ones(items):
    """True when every element equals 1 (vacuously true for empty input)."""
    return not any(x != 1 for x in items)
1157
+
1158
+
1159
def is_zeros(items):
    """True when every element equals 0 (vacuously true for empty input)."""
    return not any(x != 0 for x in items)
1161
+
1162
+
1163
def is_cpu_device(inputs):
    """True when every tensor in ``inputs`` lives on the CPU device;
    non-tensor entries are ignored."""
    cpu = torch.device("cpu")
    for item in inputs:
        if isinstance(item, torch.Tensor) and item.device != cpu:
            return False
    return True
1169
+
1170
+
1171
def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
    """Map a sympy expression to a torch dtype: int64 for expressions known
    to be integer, float64 otherwise (including "unknown")."""
    assert isinstance(
        val, sympy.Expr
    ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
    return torch.int64 if val.is_integer else torch.float64  # type: ignore[attr-defined]
1179
+
1180
+
1181
@contextlib.contextmanager
def maybe_profile(should_profile, *args, **kwargs):
    # Yield a torch.profiler profile object when should_profile is truthy,
    # otherwise yield None with no profiling overhead; extra args are
    # forwarded to torch.profiler.profile.
    if should_profile:
        with torch.profiler.profile(*args, **kwargs) as p:
            yield p
    else:
        yield
1188
+
1189
+
1190
def triton_config_to_hashable(cfg):
    """
    Convert triton config to a tuple that can uniquely identify it. We can use
    the return value as a dictionary key.
    """
    # kwargs are sorted for a deterministic order; num_warps/num_stages are
    # appended last.
    key = sorted(cfg.kwargs.items())
    key += [("num_warps", cfg.num_warps), ("num_stages", cfg.num_stages)]
    return tuple(key)
1199
+
1200
+
1201
def parallel_num_threads():
    """Thread count for CPP kernels: config.cpp.threads when set (>= 1),
    otherwise torch's current intra-op thread count."""
    requested = config.cpp.threads
    if requested >= 1:
        return requested
    return torch.get_num_threads()
1206
+
1207
+
1208
+ HAS_COLORAMA = True
1209
+ try:
1210
+ import colorama
1211
+ except ImportError:
1212
+ HAS_COLORAMA = False
1213
+
1214
+
1215
def _color_text(msg, color):
    """Wrap ``msg`` in the named colorama foreground color; return it
    unchanged when colorama is not installed."""
    if not HAS_COLORAMA:
        return msg
    ansi = getattr(colorama.Fore, color.upper())
    return ansi + msg + colorama.Fore.RESET
1220
+
1221
+
1222
# Convenience wrappers over _color_text for the standard ANSI colors.
def green_text(msg):
    return _color_text(msg, "green")


def yellow_text(msg):
    return _color_text(msg, "yellow")


def red_text(msg):
    return _color_text(msg, "red")


def blue_text(msg):
    return _color_text(msg, "blue")
1236
+
1237
+
1238
@functools.lru_cache(None)
def get_device_tflops(dtype):
    """Peak TFLOPS of the current GPU for ``dtype`` via triton's helpers:
    tensor-core rate for fp16/bf16 (and for fp32 when TF32 is allowed),
    SIMD rate otherwise.  Cached per dtype."""
    from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops

    assert dtype in (torch.float16, torch.bfloat16, torch.float32)

    if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
        # Triton API change in https://github.com/openai/triton/pull/2293
        from torch._utils_internal import max_clock_rate

        sm_clock = max_clock_rate()
        if dtype in (torch.float16, torch.bfloat16):
            return get_max_tensorcore_tflops(dtype, sm_clock)

        if torch.backends.cuda.matmul.allow_tf32:
            return get_max_tensorcore_tflops(torch.float32, sm_clock)
        else:
            return get_max_simd_tflops(torch.float32, sm_clock)
    else:
        # Older triton API: the helpers take no clock_rate argument.
        if dtype in (torch.float16, torch.bfloat16):
            return get_max_tensorcore_tflops(dtype)

        if torch.backends.cuda.matmul.allow_tf32:
            return get_max_tensorcore_tflops(torch.float32)
        else:
            return get_max_simd_tflops(torch.float32)
1264
+
1265
+
1266
@functools.lru_cache(None)
def get_gpu_dram_gbps():
    # Peak DRAM bandwidth (GB/s) of the current GPU via triton; cached
    # since it cannot change within a process.
    from triton.testing import get_dram_gbps

    return get_dram_gbps()
1271
+
1272
+
1273
def is_welford_reduction(reduction_type):
    """Whether this reduction type is one of the "welford*" variants."""
    return reduction_type.startswith("welford")


def reduction_num_outputs(reduction_type):
    """Welford reductions produce 3 outputs; every other reduction produces 1."""
    if is_welford_reduction(reduction_type):
        return 3
    return 1
1279
+
1280
+
1281
def get_max_y_grid():
    # Maximum launch-grid size along the Y dimension.  65535 matches CUDA's
    # documented gridDim.y limit — presumably why this constant was chosen;
    # the code itself only returns the number.
    return 65535
1283
+
1284
+
1285
def is_linux() -> bool:
    # True when the current platform reports itself as Linux.
    return platform.system() == "Linux"
1287
+
1288
+
1289
def has_free_symbols(itr: Iterable[Any]):
    # True if any element is a symbolic sympy expression (a sympy.Expr that
    # is not a concrete number) — i.e. the sizes/strides are dynamic.
    return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr)
1291
+
1292
+
1293
def is_dynamic(*args):
    """Whether any argument has a symbolic (dynamic) size or stride.
    Non-IRNode values are ignored; IRNode subtypes not handled below raise
    TypeError."""
    from . import ir

    for t in args:
        if isinstance(t, ir.TensorBox):
            # TensorBox wraps its data; some wrapped nodes have no
            # get_stride, hence the hasattr guard.
            if has_free_symbols(t.data.get_size()) or (
                hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride())
            ):
                return True
        elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)):
            assert hasattr(t, "get_size") and hasattr(t, "get_stride")
            if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()):
                return True
        elif not isinstance(t, ir.IRNode):
            # Plain (non-IR) values can never be dynamic.
            continue
        else:
            raise TypeError(f"unexpected type for is_dynamic {type(t)}")

    return False
1312
+
1313
+
1314
# Placeholder strings used in triton codegen.
class Placeholder(enum.Enum):
    """Sentinel strings substituted into generated Triton source before the
    final kernel name/description is known."""

    # The placeholder for the actual name of a triton kernel.
    # e.g. for "def triton_" it would be "triton_"
    KERNEL_NAME = "KERNEL_NAME"

    # The descriptive name of the triton kernel; when unique_kernel_names = False, this
    # placeholder will be replaced with a string with more information.
    DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
1323
+
1324
+
1325
def pass_execution_and_save(func, gm, msg):
    """Run graph pass ``func`` on ``gm.graph``, dump the before/after graph
    text to a temp file, recompile ``gm``, and log whether the pass changed
    the graph plus how long it took."""
    from .pattern_matcher import stable_topological_sort

    # delete=False so the dump survives for post-mortem inspection; its path
    # is reported in the log line below.
    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        delete=False,
    ) as f:
        before_io = io.StringIO()
        after_io = io.StringIO()
        print(f"Before:\n{gm.graph}", file=f)
        print(gm.graph, file=before_io)
        start_time = datetime.now()
        func(gm.graph)
        time_elapsed = datetime.now() - start_time
        # recompile graph
        stable_topological_sort(gm.graph)
        gm.graph.lint()
        gm.recompile()

        print(f"After:\n{gm.graph}", file=f)
        print(gm.graph, file=after_io)
        # Compare textual dumps to detect whether the pass was a no-op.
        t = before_io.getvalue() == after_io.getvalue()
        log.info(
            "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
            msg,
            f.name,
            t,
            time_elapsed,
        )
1355
+
1356
+
1357
def is_collective(node):
    # Whether the IR node is a collective-communication kernel: either a
    # CollectiveKernel subclass or exactly the newer _CollectiveKernel type.
    from . import ir

    return isinstance(node, ir.CollectiveKernel) or type(node) == ir._CollectiveKernel


def is_wait(node):
    # Whether the IR node waits on a collective: either an ir.Wait subclass
    # or exactly the newer _WaitKernel type.
    from . import ir

    return isinstance(node, ir.Wait) or type(node) == ir._WaitKernel
1367
+
1368
+
1369
+ def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int):
1370
+ "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
1371
+ num_rng_seed_offset_inputs = (
1372
+ 2 if torch._functorch.config.functionalize_rng_ops else 0
1373
+ )
1374
+ return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
1375
+
1376
+
1377
def count_tangents(fx_g: torch.fx.GraphModule):
    """
    Infers which inputs are static for a backwards graph
    """

    def is_saved_tensor(x):
        # "Static" inputs are everything except tangents and the RNG
        # seed/offset placeholders, identified by placeholder name.
        return (
            "tangents" not in x.name
            and "bwd_seed" not in x.name
            and "bwd_base_offset" not in x.name
        )

    arg_count = 0
    static_arg_idxs = []
    for n in fx_g.graph.nodes:
        if n.op == "placeholder":
            if is_saved_tensor(n):
                static_arg_idxs.append(arg_count)
            arg_count += 1

    # All static inputs are expected to form a contiguous prefix of the
    # placeholder list.
    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)
1399
+
1400
+
1401
@dataclasses.dataclass
class BoxedBool:
    """A mutable boolean box: truthiness follows ``value``, and ``disable``
    flips an existing box to False in place."""

    value: bool

    def __bool__(self):
        return self.value

    @staticmethod
    def disable(obj):
        """Set a BoxedBool to False in place and return it; any non-box
        argument simply maps to False."""
        if not isinstance(obj, BoxedBool):
            return False
        obj.value = False
        return obj
1414
+
1415
+
1416
@contextlib.contextmanager
def collect_defined_kernels(kernel_list):
    """Context manager that appends the source code of every kernel defined
    via WrapperCodeGen.define_kernel inside the block to ``kernel_list``."""
    from .codegen.wrapper import WrapperCodeGen

    orig_define_kernel = WrapperCodeGen.define_kernel

    def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs):
        # Record the kernel source, then delegate to the real define_kernel.
        nonlocal kernel_list
        kernel_list.append(kernel_code)
        return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs)

    with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel):
        yield
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import tempfile
3
+ from collections import defaultdict
4
+
5
+ import torch
6
+ from torch.autograd import DeviceType
7
+ from .utils import create_bandwidth_info_str, do_bench, get_num_bytes
8
+
9
# Names of the triton_heuristics decorators that mark a kernel's category.
_kernel_category_choices = [
    "foreach",
    "persistent_reduction",
    "pointwise",
    "reduction",
    "split_scan",
    "template",
]


def get_kernel_category_by_source_code(src_code):
    """
    Similar to get_kernel_category but use the source code. Call this API
    if we have not compile the src_code to module yet.
    """
    matched = [
        ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code
    ]
    # Exactly one decorator match pins the category; anything else is unknown.
    return matched[0] if len(matched) == 1 else "unknown"
31
+
32
+
33
def get_kernel_category(kernel_mod):
    """
    Given the module defining a triton kernel, return the category of the kernel.
    Category can be one of:
    - pointwise
    - reduction
    - persistent_reduction

    Currently we simply decide the category depending on what decorator is imported
    by the kernel.
    """
    matched = [name for name in _kernel_category_choices if name in kernel_mod.__dict__]
    return matched[0] if len(matched) == 1 else "unknown"
49
+
50
+
51
def get_triton_kernel(mod):
    # Return the single CachingAutotuner defined in a compiled kernel
    # module; asserts there is exactly one "triton_*" kernel object.
    from torch._inductor.triton_heuristics import CachingAutotuner

    cand_list = [
        v
        for k, v in mod.__dict__.items()
        if k.startswith("triton_") and isinstance(v, CachingAutotuner)
    ]
    assert len(cand_list) == 1
    return cand_list[0]
61
+
62
+
63
def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
    """
    An experimental API used only when config.benchmark_kernel is true.

    Run the kernel benchmarks for all the kernels cached in PyCodeCache.
    Used in the compiled modules.

    Put this method here rather than codegen it for convenience since its implementation
    does not change based on different graph modules being compiled.
    """
    from torch._inductor.codecache import PyCodeCache

    nfound = 0
    for kernel_key, kernel_mod in PyCodeCache.cache.items():
        # Only compiled kernel modules expose get_args/call; skip the rest.
        if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
            continue

        triton_kernel = get_triton_kernel(kernel_mod)
        kernel_category = get_kernel_category(kernel_mod)
        args = kernel_mod.get_args()
        # in_out_ptr arguments are both read and written, so they count
        # twice toward moved bytes (see get_num_bytes).
        num_in_out_ptrs = len(
            [
                arg_name
                for arg_name in triton_kernel.fn.arg_names
                if arg_name.startswith("in_out_ptr")
            ]
        )
        # Prefer the precomputed GB count from codegen metadata; otherwise
        # derive it from the actual argument tensors.
        num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None)
        if num_gb is None:
            num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9

        def get_info_str(ms, n_regs, n_spills, shared, prefix=""):
            # Format one result line; register/spill/shared-memory details
            # are appended only when all three are known.
            if not any(x is None for x in [n_regs, n_spills, shared]):
                kernel_detail_str = (
                    f"  {n_regs:3} regs  {n_spills:3} spills  {shared:8} shared mem"
                )
            else:
                kernel_detail_str = ""

            gb_per_s = num_gb / (ms / 1e3)
            return create_bandwidth_info_str(
                ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str
            )

        kernel_desc = (
            f"{benchmark_name:20} {kernel_category[:3].upper()}  {kernel_key[:10]}"
        )
        if benchmark_all_configs:
            # Benchmark every autotune config, one line per launcher.
            assert hasattr(kernel_mod, "benchmark_all_configs")
            bench_result = kernel_mod.benchmark_all_configs(args)
            print(kernel_desc)
            for launcher, ms in bench_result.items():
                print(
                    f"  {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
                )
        else:
            # Benchmark only the autotuned winner.
            ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)
            assert (
                len(triton_kernel.launchers) == 1
            ), "Autotuner should have selected the best config"
            launcher = triton_kernel.launchers[0]
            print(
                get_info_str(
                    ms,
                    launcher.n_regs,
                    launcher.n_spills,
                    launcher.shared,
                    prefix=f"{kernel_desc} ",
                )
            )

        nfound += 1
    if nfound == 0:
        print(
            "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel being True"
        )
139
+
140
+
141
@dataclasses.dataclass
class ProfileEvent:
    """One aggregated profiler row for a GPU-side event, with times and
    counts normalized per benchmark run."""

    category: str
    key: str
    self_cuda_time_ms: float
    # the benchmark is run multiple times and we average the count across all the
    # runs. It should be an integer but define a float just in case.
    count: float
149
+
150
+
151
def parse_profile_event_list(benchmark_name, event_list, wall_time_ms, nruns):
    """Aggregate a profiler event list into per-category GPU-time tables,
    print a per-category report plus a machine-tabulatable summary line."""

    def get_self_cuda_time(ev):
        """
        ev.self_cuda_time_total is in microsecond. Convert to millisecond.
        """
        return ev.self_cuda_time_total / 1000 / nruns

    all_events = defaultdict(list)

    def add_event(ev, category):
        profile_ev = ProfileEvent(
            category=category,
            key=ev.key,
            self_cuda_time_ms=get_self_cuda_time(ev),
            count=ev.count / nruns,  # average across all runs
        )
        all_events[category].append(profile_ev)

    for ev in event_list:
        assert not ev.is_legacy, "Don't support the legacy profiler"
        if ev.device_type == DeviceType.CPU:
            # ignore the event on CPU side
            continue

        # Categorize GPU events by the inductor kernel-name prefix
        # (triton_poi/red/per); anything else triton is "triton_unknown",
        # non-triton kernels are "unknown".
        category = "unknown"
        if ev.key.startswith("triton_"):
            if ev.key.startswith("triton_poi"):
                category = "triton_pointwise"
            elif ev.key.startswith("triton_red"):
                category = "triton_reduction"
            elif ev.key.startswith("triton_per"):
                category = "triton_persistent_reduction"
            else:
                category = "triton_unknown"

        add_event(ev, category)

    def report_category(category, profile_events):
        # Print one table for a category (sorted by self CUDA time,
        # descending) and return that category's total time in ms.
        from tabulate import tabulate

        profile_events.sort(key=lambda ev: ev.self_cuda_time_ms, reverse=True)

        rows = []
        total_time = 0.0
        print(f"\n  == {category} category kernels == ")
        for ev in profile_events:
            total_time += ev.self_cuda_time_ms
            percent = f"{ev.self_cuda_time_ms / wall_time_ms * 100:.2f}%"
            rows.append([ev.key[:120], ev.self_cuda_time_ms, ev.count, percent])
        rows.append(
            ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"]
        )
        print(
            tabulate(
                rows, headers=["Kernel", "Self CUDA TIME (ms)", "Count", "Percent"]
            )
        )
        return total_time

    def report():
        category_list = [
            "triton_pointwise",
            "triton_reduction",
            "triton_persistent_reduction",
            "triton_unknown",
            "unknown",
        ]
        # Every observed category must be one we know how to report.
        assert set(all_events.keys()).issubset(
            set(category_list)
        ), f"{list(all_events.keys())}"

        per_category_wall_time = {}
        total_cuda_ms = 0.0
        for category in category_list:
            if category in all_events:
                _time = report_category(category, all_events[category])
                per_category_wall_time[category] = _time
                total_cuda_ms += _time

        gpu_busy_percent = f"{total_cuda_ms / wall_time_ms * 100:.2f}%"
        print(f"\nPercent of time when GPU is busy: {gpu_busy_percent}")
        print(f"Total wall time {wall_time_ms:.3f} ms")

        # output such a line so we can gather such line from all compiled modules from all
        # benchmarks and tabulate it!
        # Columns: benchmark_name, pointwise_percent, reduction_percent, persistent_reduction_percent,
        # unknown_category_percent, GPU_busy_percent, wall_time_ms
        tabulate_line = f"Output for tabulate: {benchmark_name}"
        for category in category_list:
            percent = (
                f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%"
            )
            tabulate_line += f", {percent}"
        tabulate_line += f", {gpu_busy_percent}, {wall_time_ms:.3f}ms"

        print(tabulate_line)

    report()
249
+
250
+
251
def compiled_module_main(benchmark_name, benchmark_compiled_module_fn):
    """
    This is the function called in __main__ block of a compiled module.

    Depending on CLI flags it either benchmarks each cached kernel
    individually (optionally every config), or times the whole module and
    optionally captures and summarizes a profiler trace.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--benchmark-kernels",
        "-k",
        action="store_true",
        help="Whether to benchmark each individual kernels",
    )
    parser.add_argument(
        "--benchmark-all-configs",
        "-c",
        action="store_true",
        help="Whether to benchmark each individual config for a kernel",
    )
    parser.add_argument(
        "--profile",
        "-p",
        action="store_true",
        help="Whether to profile the compiled module",
    )
    args = parser.parse_args()

    if args.benchmark_kernels:
        benchmark_all_kernels(benchmark_name, args.benchmark_all_configs)
    else:
        times = 10
        repeat = 10
        # benchmark_compiled_module_fn's result is scaled by 1000 —
        # presumably it returns seconds and this converts to ms; confirm
        # against the codegen'd benchmark function.
        wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000

        if not args.profile:
            return

        # Re-run under the profiler and dump a Chrome trace plus a summary.
        with torch.profiler.profile(record_shapes=True) as p:
            benchmark_compiled_module_fn(times=times, repeat=repeat)

        path = f"{tempfile.gettempdir()}/compiled_module_profile.json"
        p.export_chrome_trace(path)
        print(f"Profiling result for a compiled module of benchmark {benchmark_name}:")
        print(f"Chrome trace for the profile is written to {path}")
        event_list = p.key_averages(group_by_input_shape=True)
        print(event_list.table(sort_by="self_cuda_time_total", row_limit=10))
        parse_profile_event_list(
            benchmark_name, event_list, wall_time_ms, times * repeat
        )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/DimVector.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <ATen/core/Dimname.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DynamicLibrary.h ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Utils.h>
4
+ #include <c10/macros/Export.h>
5
+ #include <c10/util/Exception.h>
6
+
7
+ namespace c10 {
8
+
9
// Error type for dynamic-library failures; inherits all of c10::Error's
// constructors unchanged.
class DynamicLibraryError : public Error {
  using Error::Error;
};
12
+
13
+ } // namespace c10
14
+
15
+ namespace at {
16
+
17
// RAII handle for a dynamically loaded shared library: the constructor loads
// the library, sym() resolves symbols from it, and the destructor releases
// the handle (presumably skipped when leak_handle is set — the
// implementation lives in the .cpp, confirm there).  Copying is disallowed.
struct DynamicLibrary {
  AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);

  // Load the library `name`; `alt_name` looks like a fallback name and
  // leak_handle requests that the handle never be closed — NOTE(review):
  // exact semantics are defined in the .cpp, verify before relying on them.
  TORCH_API DynamicLibrary(
      const char* name,
      const char* alt_name = nullptr,
      bool leak_handle = false);

  // Resolve the address of symbol `name` from the loaded library.
  TORCH_API void* sym(const char* name);

  TORCH_API ~DynamicLibrary();

 private:
  bool leak_handle;
  void* handle = nullptr;  // platform library handle; null until loaded
};
33
+
34
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <ATen/core/Formatting.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
12
+ #error This change adds a dependency on all pytorch operators, meaning the \
13
+ file will need to be re-compiled every time an operator is changed or added. \
14
+ Consider including a specific operator from \
15
+ <ATen/ops/{my_operator}_meta_dispatch.h>. \
16
+ See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
17
+ #endif
18
+
19
+ #include <ATen/ops/_add_relu_meta_dispatch.h>
20
+ #include <ATen/ops/_addmm_activation_meta_dispatch.h>
21
+ #include <ATen/ops/_amp_update_scale_meta_dispatch.h>
22
+ #include <ATen/ops/_coalesced_meta_dispatch.h>
23
+ #include <ATen/ops/_convert_indices_from_coo_to_csr_meta_dispatch.h>
24
+ #include <ATen/ops/_convert_indices_from_csr_to_coo_meta_dispatch.h>
25
+ #include <ATen/ops/_ctc_loss_meta_dispatch.h>
26
+ #include <ATen/ops/_efficientzerotensor_meta_dispatch.h>
27
+ #include <ATen/ops/_fill_mem_eff_dropout_mask_meta_dispatch.h>
28
+ #include <ATen/ops/_fused_sdp_choice_meta_dispatch.h>
29
+ #include <ATen/ops/_index_put_impl_meta_dispatch.h>
30
+ #include <ATen/ops/_linalg_det_meta_dispatch.h>
31
+ #include <ATen/ops/_linalg_eigh_meta_dispatch.h>
32
+ #include <ATen/ops/_linalg_slogdet_meta_dispatch.h>
33
+ #include <ATen/ops/_linalg_solve_ex_meta_dispatch.h>
34
+ #include <ATen/ops/_linalg_svd_meta_dispatch.h>
35
+ #include <ATen/ops/_log_softmax_meta_dispatch.h>
36
+ #include <ATen/ops/_log_softmax_backward_data_meta_dispatch.h>
37
+ #include <ATen/ops/_mkldnn_transpose_meta_dispatch.h>
38
+ #include <ATen/ops/_reshape_alias_meta_dispatch.h>
39
+ #include <ATen/ops/_resize_output_meta_dispatch.h>
40
+ #include <ATen/ops/_softmax_meta_dispatch.h>
41
+ #include <ATen/ops/_softmax_backward_data_meta_dispatch.h>
42
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_meta_dispatch.h>
43
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta_dispatch.h>
44
+ #include <ATen/ops/_upsample_bicubic2d_aa_meta_dispatch.h>
45
+ #include <ATen/ops/_upsample_bicubic2d_aa_backward_meta_dispatch.h>
46
+ #include <ATen/ops/_upsample_bilinear2d_aa_meta_dispatch.h>
47
+ #include <ATen/ops/_upsample_bilinear2d_aa_backward_meta_dispatch.h>
48
+ #include <ATen/ops/_upsample_nearest_exact1d_meta_dispatch.h>
49
+ #include <ATen/ops/_upsample_nearest_exact1d_backward_meta_dispatch.h>
50
+ #include <ATen/ops/_upsample_nearest_exact2d_meta_dispatch.h>
51
+ #include <ATen/ops/_upsample_nearest_exact2d_backward_meta_dispatch.h>
52
+ #include <ATen/ops/_upsample_nearest_exact3d_meta_dispatch.h>
53
+ #include <ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h>
54
+ #include <ATen/ops/acos_meta_dispatch.h>
55
+ #include <ATen/ops/acosh_meta_dispatch.h>
56
+ #include <ATen/ops/adaptive_max_pool2d_meta_dispatch.h>
57
+ #include <ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h>
58
+ #include <ATen/ops/adaptive_max_pool3d_meta_dispatch.h>
59
+ #include <ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h>
60
+ #include <ATen/ops/add_meta_dispatch.h>
61
+ #include <ATen/ops/addbmm_meta_dispatch.h>
62
+ #include <ATen/ops/addcdiv_meta_dispatch.h>
63
+ #include <ATen/ops/addcmul_meta_dispatch.h>
64
+ #include <ATen/ops/addmm_meta_dispatch.h>
65
+ #include <ATen/ops/addmv_meta_dispatch.h>
66
+ #include <ATen/ops/all_meta_dispatch.h>
67
+ #include <ATen/ops/amax_meta_dispatch.h>
68
+ #include <ATen/ops/amin_meta_dispatch.h>
69
+ #include <ATen/ops/aminmax_meta_dispatch.h>
70
+ #include <ATen/ops/any_meta_dispatch.h>
71
+ #include <ATen/ops/arange_meta_dispatch.h>
72
+ #include <ATen/ops/argmax_meta_dispatch.h>
73
+ #include <ATen/ops/argmin_meta_dispatch.h>
74
+ #include <ATen/ops/as_strided_meta_dispatch.h>
75
+ #include <ATen/ops/asin_meta_dispatch.h>
76
+ #include <ATen/ops/asinh_meta_dispatch.h>
77
+ #include <ATen/ops/atan_meta_dispatch.h>
78
+ #include <ATen/ops/atan2_meta_dispatch.h>
79
+ #include <ATen/ops/atanh_meta_dispatch.h>
80
+ #include <ATen/ops/avg_pool2d_meta_dispatch.h>
81
+ #include <ATen/ops/avg_pool2d_backward_meta_dispatch.h>
82
+ #include <ATen/ops/avg_pool3d_meta_dispatch.h>
83
+ #include <ATen/ops/avg_pool3d_backward_meta_dispatch.h>
84
+ #include <ATen/ops/baddbmm_meta_dispatch.h>
85
+ #include <ATen/ops/bernoulli_meta_dispatch.h>
86
+ #include <ATen/ops/bitwise_and_meta_dispatch.h>
87
+ #include <ATen/ops/bitwise_left_shift_meta_dispatch.h>
88
+ #include <ATen/ops/bitwise_not_meta_dispatch.h>
89
+ #include <ATen/ops/bitwise_or_meta_dispatch.h>
90
+ #include <ATen/ops/bitwise_right_shift_meta_dispatch.h>
91
+ #include <ATen/ops/bitwise_xor_meta_dispatch.h>
92
+ #include <ATen/ops/bmm_meta_dispatch.h>
93
+ #include <ATen/ops/cat_meta_dispatch.h>
94
+ #include <ATen/ops/cauchy_meta_dispatch.h>
95
+ #include <ATen/ops/ceil_meta_dispatch.h>
96
+ #include <ATen/ops/clamp_meta_dispatch.h>
97
+ #include <ATen/ops/clamp_max_meta_dispatch.h>
98
+ #include <ATen/ops/clamp_min_meta_dispatch.h>
99
+ #include <ATen/ops/copy_sparse_to_sparse_meta_dispatch.h>
100
+ #include <ATen/ops/copysign_meta_dispatch.h>
101
+ #include <ATen/ops/cos_meta_dispatch.h>
102
+ #include <ATen/ops/cosh_meta_dispatch.h>
103
+ #include <ATen/ops/cumprod_meta_dispatch.h>
104
+ #include <ATen/ops/cumsum_meta_dispatch.h>
105
+ #include <ATen/ops/digamma_meta_dispatch.h>
106
+ #include <ATen/ops/div_meta_dispatch.h>
107
+ #include <ATen/ops/elu_meta_dispatch.h>
108
+ #include <ATen/ops/elu_backward_meta_dispatch.h>
109
+ #include <ATen/ops/embedding_renorm_meta_dispatch.h>
110
+ #include <ATen/ops/empty_meta_dispatch.h>
111
+ #include <ATen/ops/empty_strided_meta_dispatch.h>
112
+ #include <ATen/ops/eq_meta_dispatch.h>
113
+ #include <ATen/ops/erf_meta_dispatch.h>
114
+ #include <ATen/ops/erfc_meta_dispatch.h>
115
+ #include <ATen/ops/erfinv_meta_dispatch.h>
116
+ #include <ATen/ops/exp_meta_dispatch.h>
117
+ #include <ATen/ops/exp2_meta_dispatch.h>
118
+ #include <ATen/ops/expm1_meta_dispatch.h>
119
+ #include <ATen/ops/exponential_meta_dispatch.h>
120
+ #include <ATen/ops/eye_meta_dispatch.h>
121
+ #include <ATen/ops/fill_meta_dispatch.h>
122
+ #include <ATen/ops/floor_meta_dispatch.h>
123
+ #include <ATen/ops/floor_divide_meta_dispatch.h>
124
+ #include <ATen/ops/fmax_meta_dispatch.h>
125
+ #include <ATen/ops/fmin_meta_dispatch.h>
126
+ #include <ATen/ops/fmod_meta_dispatch.h>
127
+ #include <ATen/ops/frac_meta_dispatch.h>
128
+ #include <ATen/ops/fractional_max_pool2d_meta_dispatch.h>
129
+ #include <ATen/ops/fractional_max_pool2d_backward_meta_dispatch.h>
130
+ #include <ATen/ops/fractional_max_pool3d_meta_dispatch.h>
131
+ #include <ATen/ops/gather_meta_dispatch.h>
132
+ #include <ATen/ops/gcd_meta_dispatch.h>
133
+ #include <ATen/ops/ge_meta_dispatch.h>
134
+ #include <ATen/ops/gelu_meta_dispatch.h>
135
+ #include <ATen/ops/gelu_backward_meta_dispatch.h>
136
+ #include <ATen/ops/geometric_meta_dispatch.h>
137
+ #include <ATen/ops/glu_meta_dispatch.h>
138
+ #include <ATen/ops/gt_meta_dispatch.h>
139
+ #include <ATen/ops/hardshrink_meta_dispatch.h>
140
+ #include <ATen/ops/hardshrink_backward_meta_dispatch.h>
141
+ #include <ATen/ops/hardsigmoid_meta_dispatch.h>
142
+ #include <ATen/ops/hardsigmoid_backward_meta_dispatch.h>
143
+ #include <ATen/ops/hardswish_meta_dispatch.h>
144
+ #include <ATen/ops/hardtanh_meta_dispatch.h>
145
+ #include <ATen/ops/heaviside_meta_dispatch.h>
146
+ #include <ATen/ops/hypot_meta_dispatch.h>
147
+ #include <ATen/ops/i0_meta_dispatch.h>
148
+ #include <ATen/ops/igamma_meta_dispatch.h>
149
+ #include <ATen/ops/igammac_meta_dispatch.h>
150
+ #include <ATen/ops/index_meta_dispatch.h>
151
+ #include <ATen/ops/index_add_meta_dispatch.h>
152
+ #include <ATen/ops/index_copy_meta_dispatch.h>
153
+ #include <ATen/ops/index_fill_meta_dispatch.h>
154
+ #include <ATen/ops/index_reduce_meta_dispatch.h>
155
+ #include <ATen/ops/isin_meta_dispatch.h>
156
+ #include <ATen/ops/isneginf_meta_dispatch.h>
157
+ #include <ATen/ops/isposinf_meta_dispatch.h>
158
+ #include <ATen/ops/lcm_meta_dispatch.h>
159
+ #include <ATen/ops/le_meta_dispatch.h>
160
+ #include <ATen/ops/leaky_relu_meta_dispatch.h>
161
+ #include <ATen/ops/leaky_relu_backward_meta_dispatch.h>
162
+ #include <ATen/ops/lerp_meta_dispatch.h>
163
+ #include <ATen/ops/lgamma_meta_dispatch.h>
164
+ #include <ATen/ops/linalg_cholesky_ex_meta_dispatch.h>
165
+ #include <ATen/ops/linalg_cross_meta_dispatch.h>
166
+ #include <ATen/ops/linalg_inv_ex_meta_dispatch.h>
167
+ #include <ATen/ops/linalg_ldl_factor_ex_meta_dispatch.h>
168
+ #include <ATen/ops/linalg_ldl_solve_meta_dispatch.h>
169
+ #include <ATen/ops/linalg_lu_meta_dispatch.h>
170
+ #include <ATen/ops/linalg_lu_factor_ex_meta_dispatch.h>
171
+ #include <ATen/ops/linalg_lu_solve_meta_dispatch.h>
172
+ #include <ATen/ops/linalg_qr_meta_dispatch.h>
173
+ #include <ATen/ops/linalg_vector_norm_meta_dispatch.h>
174
+ #include <ATen/ops/linspace_meta_dispatch.h>
175
+ #include <ATen/ops/log_meta_dispatch.h>
176
+ #include <ATen/ops/log10_meta_dispatch.h>
177
+ #include <ATen/ops/log1p_meta_dispatch.h>
178
+ #include <ATen/ops/log2_meta_dispatch.h>
179
+ #include <ATen/ops/log_normal_meta_dispatch.h>
180
+ #include <ATen/ops/logaddexp_meta_dispatch.h>
181
+ #include <ATen/ops/logaddexp2_meta_dispatch.h>
182
+ #include <ATen/ops/logit_meta_dispatch.h>
183
+ #include <ATen/ops/logit_backward_meta_dispatch.h>
184
+ #include <ATen/ops/logspace_meta_dispatch.h>
185
+ #include <ATen/ops/lshift_meta_dispatch.h>
186
+ #include <ATen/ops/lt_meta_dispatch.h>
187
+ #include <ATen/ops/lu_unpack_meta_dispatch.h>
188
+ #include <ATen/ops/masked_fill_meta_dispatch.h>
189
+ #include <ATen/ops/masked_scatter_meta_dispatch.h>
190
+ #include <ATen/ops/max_meta_dispatch.h>
191
+ #include <ATen/ops/max_pool2d_with_indices_meta_dispatch.h>
192
+ #include <ATen/ops/max_pool2d_with_indices_backward_meta_dispatch.h>
193
+ #include <ATen/ops/maximum_meta_dispatch.h>
194
+ #include <ATen/ops/mean_meta_dispatch.h>
195
+ #include <ATen/ops/min_meta_dispatch.h>
196
+ #include <ATen/ops/minimum_meta_dispatch.h>
197
+ #include <ATen/ops/mish_meta_dispatch.h>
198
+ #include <ATen/ops/mm_meta_dispatch.h>
199
+ #include <ATen/ops/mse_loss_meta_dispatch.h>
200
+ #include <ATen/ops/mul_meta_dispatch.h>
201
+ #include <ATen/ops/ne_meta_dispatch.h>
202
+ #include <ATen/ops/neg_meta_dispatch.h>
203
+ #include <ATen/ops/nextafter_meta_dispatch.h>
204
+ #include <ATen/ops/nll_loss_backward_meta_dispatch.h>
205
+ #include <ATen/ops/nll_loss_forward_meta_dispatch.h>
206
+ #include <ATen/ops/norm_meta_dispatch.h>
207
+ #include <ATen/ops/normal_meta_dispatch.h>
208
+ #include <ATen/ops/polygamma_meta_dispatch.h>
209
+ #include <ATen/ops/pow_meta_dispatch.h>
210
+ #include <ATen/ops/prod_meta_dispatch.h>
211
+ #include <ATen/ops/put_meta_dispatch.h>
212
+ #include <ATen/ops/random_meta_dispatch.h>
213
+ #include <ATen/ops/range_meta_dispatch.h>
214
+ #include <ATen/ops/reciprocal_meta_dispatch.h>
215
+ #include <ATen/ops/reflection_pad1d_meta_dispatch.h>
216
+ #include <ATen/ops/reflection_pad1d_backward_meta_dispatch.h>
217
+ #include <ATen/ops/reflection_pad3d_meta_dispatch.h>
218
+ #include <ATen/ops/reflection_pad3d_backward_meta_dispatch.h>
219
+ #include <ATen/ops/relu_meta_dispatch.h>
220
+ #include <ATen/ops/remainder_meta_dispatch.h>
221
+ #include <ATen/ops/renorm_meta_dispatch.h>
222
+ #include <ATen/ops/replication_pad1d_meta_dispatch.h>
223
+ #include <ATen/ops/replication_pad1d_backward_meta_dispatch.h>
224
+ #include <ATen/ops/replication_pad2d_meta_dispatch.h>
225
+ #include <ATen/ops/replication_pad3d_meta_dispatch.h>
226
+ #include <ATen/ops/resize_meta_dispatch.h>
227
+ #include <ATen/ops/resize_as_sparse_meta_dispatch.h>
228
+ #include <ATen/ops/round_meta_dispatch.h>
229
+ #include <ATen/ops/rrelu_with_noise_meta_dispatch.h>
230
+ #include <ATen/ops/rshift_meta_dispatch.h>
231
+ #include <ATen/ops/rsqrt_meta_dispatch.h>
232
+ #include <ATen/ops/scatter_meta_dispatch.h>
233
+ #include <ATen/ops/scatter_add_meta_dispatch.h>
234
+ #include <ATen/ops/scatter_reduce_meta_dispatch.h>
235
+ #include <ATen/ops/set_meta_dispatch.h>
236
+ #include <ATen/ops/sgn_meta_dispatch.h>
237
+ #include <ATen/ops/sigmoid_meta_dispatch.h>
238
+ #include <ATen/ops/sigmoid_backward_meta_dispatch.h>
239
+ #include <ATen/ops/sign_meta_dispatch.h>
240
+ #include <ATen/ops/signbit_meta_dispatch.h>
241
+ #include <ATen/ops/silu_meta_dispatch.h>
242
+ #include <ATen/ops/silu_backward_meta_dispatch.h>
243
+ #include <ATen/ops/sin_meta_dispatch.h>
244
+ #include <ATen/ops/sinc_meta_dispatch.h>
245
+ #include <ATen/ops/sinh_meta_dispatch.h>
246
+ #include <ATen/ops/slow_conv_transpose2d_meta_dispatch.h>
247
+ #include <ATen/ops/smooth_l1_loss_meta_dispatch.h>
248
+ #include <ATen/ops/softplus_meta_dispatch.h>
249
+ #include <ATen/ops/softplus_backward_meta_dispatch.h>
250
+ #include <ATen/ops/softshrink_meta_dispatch.h>
251
+ #include <ATen/ops/softshrink_backward_meta_dispatch.h>
252
+ #include <ATen/ops/sort_meta_dispatch.h>
253
+ #include <ATen/ops/sparse_resize_meta_dispatch.h>
254
+ #include <ATen/ops/sparse_resize_and_clear_meta_dispatch.h>
255
+ #include <ATen/ops/special_airy_ai_meta_dispatch.h>
256
+ #include <ATen/ops/special_bessel_j0_meta_dispatch.h>
257
+ #include <ATen/ops/special_bessel_j1_meta_dispatch.h>
258
+ #include <ATen/ops/special_bessel_y0_meta_dispatch.h>
259
+ #include <ATen/ops/special_bessel_y1_meta_dispatch.h>
260
+ #include <ATen/ops/special_chebyshev_polynomial_t_meta_dispatch.h>
261
+ #include <ATen/ops/special_chebyshev_polynomial_u_meta_dispatch.h>
262
+ #include <ATen/ops/special_chebyshev_polynomial_v_meta_dispatch.h>
263
+ #include <ATen/ops/special_chebyshev_polynomial_w_meta_dispatch.h>
264
+ #include <ATen/ops/special_entr_meta_dispatch.h>
265
+ #include <ATen/ops/special_erfcx_meta_dispatch.h>
266
+ #include <ATen/ops/special_hermite_polynomial_h_meta_dispatch.h>
267
+ #include <ATen/ops/special_hermite_polynomial_he_meta_dispatch.h>
268
+ #include <ATen/ops/special_i0e_meta_dispatch.h>
269
+ #include <ATen/ops/special_i1_meta_dispatch.h>
270
+ #include <ATen/ops/special_i1e_meta_dispatch.h>
271
+ #include <ATen/ops/special_laguerre_polynomial_l_meta_dispatch.h>
272
+ #include <ATen/ops/special_legendre_polynomial_p_meta_dispatch.h>
273
+ #include <ATen/ops/special_log_ndtr_meta_dispatch.h>
274
+ #include <ATen/ops/special_modified_bessel_i0_meta_dispatch.h>
275
+ #include <ATen/ops/special_modified_bessel_i1_meta_dispatch.h>
276
+ #include <ATen/ops/special_modified_bessel_k0_meta_dispatch.h>
277
+ #include <ATen/ops/special_modified_bessel_k1_meta_dispatch.h>
278
+ #include <ATen/ops/special_ndtri_meta_dispatch.h>
279
+ #include <ATen/ops/special_scaled_modified_bessel_k0_meta_dispatch.h>
280
+ #include <ATen/ops/special_scaled_modified_bessel_k1_meta_dispatch.h>
281
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta_dispatch.h>
282
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta_dispatch.h>
283
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta_dispatch.h>
284
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta_dispatch.h>
285
+ #include <ATen/ops/special_spherical_bessel_j0_meta_dispatch.h>
286
+ #include <ATen/ops/special_xlog1py_meta_dispatch.h>
287
+ #include <ATen/ops/special_zeta_meta_dispatch.h>
288
+ #include <ATen/ops/sqrt_meta_dispatch.h>
289
+ #include <ATen/ops/sub_meta_dispatch.h>
290
+ #include <ATen/ops/sum_meta_dispatch.h>
291
+ #include <ATen/ops/tan_meta_dispatch.h>
292
+ #include <ATen/ops/tanh_meta_dispatch.h>
293
+ #include <ATen/ops/tanh_backward_meta_dispatch.h>
294
+ #include <ATen/ops/threshold_meta_dispatch.h>
295
+ #include <ATen/ops/threshold_backward_meta_dispatch.h>
296
+ #include <ATen/ops/topk_meta_dispatch.h>
297
+ #include <ATen/ops/triangular_solve_meta_dispatch.h>
298
+ #include <ATen/ops/tril_meta_dispatch.h>
299
+ #include <ATen/ops/triu_meta_dispatch.h>
300
+ #include <ATen/ops/trunc_meta_dispatch.h>
301
+ #include <ATen/ops/unfold_meta_dispatch.h>
302
+ #include <ATen/ops/uniform_meta_dispatch.h>
303
+ #include <ATen/ops/upsample_bicubic2d_meta_dispatch.h>
304
+ #include <ATen/ops/upsample_bicubic2d_backward_meta_dispatch.h>
305
+ #include <ATen/ops/upsample_bilinear2d_meta_dispatch.h>
306
+ #include <ATen/ops/upsample_bilinear2d_backward_meta_dispatch.h>
307
+ #include <ATen/ops/upsample_linear1d_meta_dispatch.h>
308
+ #include <ATen/ops/upsample_linear1d_backward_meta_dispatch.h>
309
+ #include <ATen/ops/upsample_nearest1d_meta_dispatch.h>
310
+ #include <ATen/ops/upsample_nearest1d_backward_meta_dispatch.h>
311
+ #include <ATen/ops/upsample_nearest2d_meta_dispatch.h>
312
+ #include <ATen/ops/upsample_nearest2d_backward_meta_dispatch.h>
313
+ #include <ATen/ops/upsample_nearest3d_meta_dispatch.h>
314
+ #include <ATen/ops/upsample_nearest3d_backward_meta_dispatch.h>
315
+ #include <ATen/ops/upsample_trilinear3d_meta_dispatch.h>
316
+ #include <ATen/ops/upsample_trilinear3d_backward_meta_dispatch.h>
317
+ #include <ATen/ops/view_meta_dispatch.h>
318
+ #include <ATen/ops/view_as_complex_meta_dispatch.h>
319
+ #include <ATen/ops/view_as_real_meta_dispatch.h>
320
+ #include <ATen/ops/xlogy_meta_dispatch.h>
321
+ #include <ATen/ops/zero_meta_dispatch.h>
322
+
323
+
324
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/List.h>
3
+ #include <ATen/core/Tensor.h>
4
+ #include <c10/core/impl/TorchDispatchModeTLS.h>
5
+
6
+ #ifndef AT_PER_OPERATOR_HEADERS
7
+ #include <ATen/Functions.h>
8
+ #else
9
+ #include <ATen/ops/equal.h>
10
+ #endif
11
+
12
+ namespace at {
13
+
14
+ // Note [Tensor-subclass-like Tensors]
15
+ // Tensor-subclass-like is defined as:
16
+ // - a Tensor subclass (via __torch_dispatch__ in Python or extending
17
+ // TensorImpl in C++)
18
+ // - anything else that shares the same perils as Tensor subclasses.
19
+ // For example, many Tensor subclasses do not have storage and meta Tensors
20
+ // do not have storage either, so meta Tensors belong here.
21
+ //
22
+ // We should ensure that PyTorch internals supports Tensor-subclass-like
23
+ // objects. In particular, Tensor-subclass-like objects struggle with two
24
+ // classes of operations that are problematic for Tensor subclasses:
25
+ // 1. Because some Tensor subclasses do not have storage, .item() or
26
+ // .data_ptr() calls are not good.
27
+ // 2. Certain in-place operations can eliminate the typing of the Tensor
28
+ // subclass. For example:
29
+ // >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
30
+ // If input is a Tensor subclass, then the above ends up either erroring out
31
+ // or returning a regular non-Tensor-subclass Tensor!
32
+
33
+ constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
34
+ {DispatchKey::FuncTorchGradWrapper,
35
+ DispatchKey::FuncTorchBatched,
36
+ DispatchKey::Functionalize});
37
+
38
+ constexpr auto kTensorSubclassLike =
39
+ kFunctorchWrappedTensors |
40
+ DispatchKeySet(
41
+ {// WARNING: DO NOT put combined backend component + functionality keys
42
+ // here, you will incorrectly always match on the functionality key
43
+ // no matter the backend component
44
+ DispatchKey::Batched,
45
+ DispatchKey::Sparse,
46
+ DispatchKey::SparseCsr,
47
+ DispatchKey::Python}) |
48
+ DispatchKeySet(BackendComponent::MetaBit);
49
+
50
+ inline bool isTensorSubclassLike(const Tensor& tensor) {
51
+ if (c10::impl::dispatch_mode_enabled())
52
+ return true;
53
+ auto key_set = tensor.unsafeGetTensorImpl()->key_set();
54
+ return !(key_set & kTensorSubclassLike).empty();
55
+ }
56
+
57
+ inline bool areAnyTensorSubclassLike(TensorList tensors) {
58
+ if (c10::impl::dispatch_mode_enabled())
59
+ return true;
60
+ return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
61
+ }
62
+
63
+ inline bool areAnyOptionalTensorSubclassLike(
64
+ const c10::List<c10::optional<Tensor>>& tensors) {
65
+ if (c10::impl::dispatch_mode_enabled())
66
+ return true;
67
+ return std::any_of(
68
+ tensors.begin(), tensors.end(), [](const optional<Tensor>& opt_tensor) {
69
+ return (
70
+ opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
71
+ });
72
+ }
73
+
74
+ // Helper function to deal testing truthfulness of a scalar tensor
75
+ // in a Composite Compliant manner.
76
+ // NOTE: This function expects a scalar tensor of boolean dtype.
77
+ // Eg.
78
+ // Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
79
+ // Composite Compliant Patter : is_salar_tensor_true((t == 0).all())
80
+ inline bool is_scalar_tensor_true(const Tensor& t) {
81
+ TORCH_INTERNAL_ASSERT(t.dim() == 0)
82
+ TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool)
83
+ return at::equal(t, t.new_ones({}, t.options()));
84
+ }
85
+
86
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_runtime.h>
5
+ #include <cuda_fp16.h>
6
+
7
+ #include <c10/macros/Export.h>
8
+
9
+ // Use TORCH_CUDA_CPP_API or TORCH_CUDA_CU_API for exports from this folder
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContext.h ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/cuda/CUDAContextLight.h>
4
+
5
+ // Preserved for BC, as many files depend on these includes
6
+ #include <ATen/Context.h>
7
+ #include <c10/cuda/CUDAStream.h>
8
+ #include <c10/util/Logging.h>
9
+ #include <ATen/cuda/Exceptions.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/ScalarType.h>
4
+
5
+ #include <cuda.h>
6
+ #include <library_types.h>
7
+
8
+ namespace at::cuda {
9
+
10
+ template <typename scalar_t>
11
+ cudaDataType getCudaDataType() {
12
+ TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to cudaDataType.")
13
+ }
14
+
15
+ template<> inline cudaDataType getCudaDataType<at::Half>() {
16
+ return CUDA_R_16F;
17
+ }
18
+ template<> inline cudaDataType getCudaDataType<float>() {
19
+ return CUDA_R_32F;
20
+ }
21
+ template<> inline cudaDataType getCudaDataType<double>() {
22
+ return CUDA_R_64F;
23
+ }
24
+ template<> inline cudaDataType getCudaDataType<c10::complex<c10::Half>>() {
25
+ return CUDA_C_16F;
26
+ }
27
+ template<> inline cudaDataType getCudaDataType<c10::complex<float>>() {
28
+ return CUDA_C_32F;
29
+ }
30
+ template<> inline cudaDataType getCudaDataType<c10::complex<double>>() {
31
+ return CUDA_C_64F;
32
+ }
33
+
34
+ // HIP doesn't define integral types
35
+ #ifndef USE_ROCM
36
+ template<> inline cudaDataType getCudaDataType<uint8_t>() {
37
+ return CUDA_R_8U;
38
+ }
39
+ template<> inline cudaDataType getCudaDataType<int8_t>() {
40
+ return CUDA_R_8I;
41
+ }
42
+ template<> inline cudaDataType getCudaDataType<int>() {
43
+ return CUDA_R_32I;
44
+ }
45
+ #endif
46
+
47
+ #if !defined(USE_ROCM)
48
+ template<> inline cudaDataType getCudaDataType<int16_t>() {
49
+ return CUDA_R_16I;
50
+ }
51
+ template<> inline cudaDataType getCudaDataType<int64_t>() {
52
+ return CUDA_R_64I;
53
+ }
54
+ template<> inline cudaDataType getCudaDataType<at::BFloat16>() {
55
+ return CUDA_R_16BF;
56
+ }
57
+ #endif
58
+
59
+ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) {
60
+ switch (scalar_type) {
61
+ // HIP doesn't define integral types
62
+ #ifndef USE_ROCM
63
+ case c10::ScalarType::Byte:
64
+ return CUDA_R_8U;
65
+ case c10::ScalarType::Char:
66
+ return CUDA_R_8I;
67
+ case c10::ScalarType::Int:
68
+ return CUDA_R_32I;
69
+ #endif
70
+ case c10::ScalarType::Half:
71
+ return CUDA_R_16F;
72
+ case c10::ScalarType::Float:
73
+ return CUDA_R_32F;
74
+ case c10::ScalarType::Double:
75
+ return CUDA_R_64F;
76
+ case c10::ScalarType::ComplexHalf:
77
+ return CUDA_C_16F;
78
+ case c10::ScalarType::ComplexFloat:
79
+ return CUDA_C_32F;
80
+ case c10::ScalarType::ComplexDouble:
81
+ return CUDA_C_64F;
82
+ #if !defined(USE_ROCM)
83
+ case c10::ScalarType::Short:
84
+ return CUDA_R_16I;
85
+ case c10::ScalarType::Long:
86
+ return CUDA_R_64I;
87
+ case c10::ScalarType::BFloat16:
88
+ return CUDA_R_16BF;
89
+ #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
90
+ case c10::ScalarType::Float8_e4m3fn:
91
+ return CUDA_R_8F_E4M3;
92
+ case c10::ScalarType::Float8_e5m2:
93
+ return CUDA_R_8F_E5M2;
94
+ #endif
95
+ #else // USE_ROCM
96
+ case c10::ScalarType::BFloat16:
97
+ return CUDA_R_16BF;
98
+ #if defined(HIP_NEW_TYPE_ENUMS)
99
+ case c10::ScalarType::Float8_e4m3fnuz:
100
+ return HIP_R_8F_E4M3_FNUZ;
101
+ case c10::ScalarType::Float8_e5m2fnuz:
102
+ return HIP_R_8F_E5M2_FNUZ;
103
+ #else
104
+ case c10::ScalarType::Float8_e4m3fnuz:
105
+ return static_cast<hipDataType>(1000);
106
+ case c10::ScalarType::Float8_e5m2fnuz:
107
+ return static_cast<hipDataType>(1001);
108
+ #endif
109
+ #endif
110
+ default:
111
+ TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
112
+ }
113
+ }
114
+
115
+ } // namespace at::cuda
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Exceptions.h ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cublas_v2.h>
4
+ #include <cusparse.h>
5
+ #include <c10/macros/Export.h>
6
+
7
+ #ifdef CUDART_VERSION
8
+ #include <cusolver_common.h>
9
+ #endif
10
+
11
+ #include <ATen/Context.h>
12
+ #include <c10/util/Exception.h>
13
+ #include <c10/cuda/CUDAException.h>
14
+
15
+
16
+ namespace c10 {
17
+
18
+ class CuDNNError : public c10::Error {
19
+ using Error::Error;
20
+ };
21
+
22
+ } // namespace c10
23
+
24
+ #define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) \
25
+ do { \
26
+ auto error_object = EXPR; \
27
+ if (!error_object.is_good()) { \
28
+ TORCH_CHECK_WITH(CuDNNError, false, \
29
+ "cuDNN Frontend error: ", error_object.get_message()); \
30
+ } \
31
+ } while (0) \
32
+
33
+ #define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)
34
+
35
+ // See Note [CHECK macro]
36
+ #define AT_CUDNN_CHECK(EXPR, ...) \
37
+ do { \
38
+ cudnnStatus_t status = EXPR; \
39
+ if (status != CUDNN_STATUS_SUCCESS) { \
40
+ if (status == CUDNN_STATUS_NOT_SUPPORTED) { \
41
+ TORCH_CHECK_WITH(CuDNNError, false, \
42
+ "cuDNN error: ", \
43
+ cudnnGetErrorString(status), \
44
+ ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \
45
+ } else { \
46
+ TORCH_CHECK_WITH(CuDNNError, false, \
47
+ "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__); \
48
+ } \
49
+ } \
50
+ } while (0)
51
+
52
+ namespace at::cuda::blas {
53
+ C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
54
+ } // namespace at::cuda::blas
55
+
56
+ #define TORCH_CUDABLAS_CHECK(EXPR) \
57
+ do { \
58
+ cublasStatus_t __err = EXPR; \
59
+ TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS, \
60
+ "CUDA error: ", \
61
+ at::cuda::blas::_cublasGetErrorEnum(__err), \
62
+ " when calling `" #EXPR "`"); \
63
+ } while (0)
64
+
65
+ const char *cusparseGetErrorString(cusparseStatus_t status);
66
+
67
+ #define TORCH_CUDASPARSE_CHECK(EXPR) \
68
+ do { \
69
+ cusparseStatus_t __err = EXPR; \
70
+ TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS, \
71
+ "CUDA error: ", \
72
+ cusparseGetErrorString(__err), \
73
+ " when calling `" #EXPR "`"); \
74
+ } while (0)
75
+
76
+ // cusolver related headers are only supported on cuda now
77
+ #ifdef CUDART_VERSION
78
+
79
+ namespace at::cuda::solver {
80
+ C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
81
+
82
+ constexpr const char* _cusolver_backend_suggestion = \
83
+ "If you keep seeing this error, you may use " \
84
+ "`torch.backends.cuda.preferred_linalg_library()` to try " \
85
+ "linear algebra operators with other supported backends. " \
86
+ "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
87
+
88
+ } // namespace at::cuda::solver
89
+
90
+ // When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
91
+ // When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
92
+ #define TORCH_CUSOLVER_CHECK(EXPR) \
93
+ do { \
94
+ cusolverStatus_t __err = EXPR; \
95
+ if ((CUDA_VERSION < 11500 && \
96
+ __err == CUSOLVER_STATUS_EXECUTION_FAILED) || \
97
+ (CUDA_VERSION >= 11500 && \
98
+ __err == CUSOLVER_STATUS_INVALID_VALUE)) { \
99
+ TORCH_CHECK_LINALG( \
100
+ false, \
101
+ "cusolver error: ", \
102
+ at::cuda::solver::cusolverGetErrorMessage(__err), \
103
+ ", when calling `" #EXPR "`", \
104
+ ". This error may appear if the input matrix contains NaN. ", \
105
+ at::cuda::solver::_cusolver_backend_suggestion); \
106
+ } else { \
107
+ TORCH_CHECK( \
108
+ __err == CUSOLVER_STATUS_SUCCESS, \
109
+ "cusolver error: ", \
110
+ at::cuda::solver::cusolverGetErrorMessage(__err), \
111
+ ", when calling `" #EXPR "`. ", \
112
+ at::cuda::solver::_cusolver_backend_suggestion); \
113
+ } \
114
+ } while (0)
115
+
116
+ #else
117
+ #define TORCH_CUSOLVER_CHECK(EXPR) EXPR
118
+ #endif
119
+
120
+ #define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)
121
+
122
+ // For CUDA Driver API
123
+ //
124
+ // This is here instead of in c10 because NVRTC is loaded dynamically via a stub
125
+ // in ATen, and we need to use its nvrtcGetErrorString.
126
+ // See NOTE [ USE OF NVRTC AND DRIVER API ].
127
+ #if !defined(USE_ROCM)
128
+
129
+ #define AT_CUDA_DRIVER_CHECK(EXPR) \
130
+ do { \
131
+ CUresult __err = EXPR; \
132
+ if (__err != CUDA_SUCCESS) { \
133
+ const char* err_str; \
134
+ CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \
135
+ if (get_error_str_err != CUDA_SUCCESS) { \
136
+ AT_ERROR("CUDA driver error: unknown error"); \
137
+ } else { \
138
+ AT_ERROR("CUDA driver error: ", err_str); \
139
+ } \
140
+ } \
141
+ } while (0)
142
+
143
+ #else
144
+
145
+ #define AT_CUDA_DRIVER_CHECK(EXPR) \
146
+ do { \
147
+ CUresult __err = EXPR; \
148
+ if (__err != CUDA_SUCCESS) { \
149
+ AT_ERROR("CUDA driver error: ", static_cast<int>(__err)); \
150
+ } \
151
+ } while (0)
152
+
153
+ #endif
154
+
155
+ // For CUDA NVRTC
156
+ //
157
+ // Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
158
+ // incorrectly produces the error string "NVRTC unknown error."
159
+ // The following maps it correctly.
160
+ //
161
+ // This is here instead of in c10 because NVRTC is loaded dynamically via a stub
162
+ // in ATen, and we need to use its nvrtcGetErrorString.
163
+ // See NOTE [ USE OF NVRTC AND DRIVER API ].
164
+ #define AT_CUDA_NVRTC_CHECK(EXPR) \
165
+ do { \
166
+ nvrtcResult __err = EXPR; \
167
+ if (__err != NVRTC_SUCCESS) { \
168
+ if (static_cast<int>(__err) != 7) { \
169
+ AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \
170
+ } else { \
171
+ AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \
172
+ } \
173
+ } \
174
+ } while (0)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+ #include <c10/core/Allocator.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/util/Exception.h>
7
+
8
+
9
+ #ifdef __OBJC__
10
+ #include <Foundation/Foundation.h>
11
+ #include <Metal/Metal.h>
12
+ #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
13
+ typedef id<MTLDevice> MTLDevice_t;
14
+ typedef id<MTLLibrary> MTLLibrary_t;
15
+ typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
16
+ typedef id<MTLLibrary> MTLLibrary_t;
17
+ #else
18
+ typedef void* MTLDevice;
19
+ typedef void* MTLDevice_t;
20
+ typedef void* MTLLibrary_t;
21
+ typedef void* MTLComputePipelineState_t;
22
+ typedef void* MTLLibrary_t;
23
+ #endif
24
+
25
+ using namespace std;
26
+
27
+ namespace at::mps {
28
+
29
+ // Helper enum to check if a MPSGraph op is supported in a given macOS version
30
+ enum class MacOSVersion : uint32_t {
31
+ MACOS_VER_13_0_PLUS = 0,
32
+ MACOS_VER_13_1_PLUS,
33
+ MACOS_VER_13_2_PLUS,
34
+ MACOS_VER_13_3_PLUS,
35
+ MACOS_VER_14_0_PLUS,
36
+ };
37
+
38
+ //-----------------------------------------------------------------
39
+ // MPSDevice
40
+ //
41
+ // MPSDevice is a singleton class that returns the default device
42
+ //-----------------------------------------------------------------
43
+
44
+ class TORCH_API MPSDevice {
45
+ public:
46
+ /**
47
+ * MPSDevice should not be cloneable.
48
+ */
49
+ MPSDevice(MPSDevice& other) = delete;
50
+ /**
51
+ * MPSDevice should not be assignable.
52
+ */
53
+ void operator=(const MPSDevice&) = delete;
54
+ /**
55
+ * Gets single instance of the Device.
56
+ */
57
+ static MPSDevice* getInstance();
58
+ /**
59
+ * Returns the single device.
60
+ */
61
+ MTLDevice_t device() {
62
+ return _mtl_device;
63
+ }
64
+ /**
65
+ * Returns whether running on Ventura or newer
66
+ */
67
+ bool isMacOS13Plus(MacOSVersion version) const;
68
+
69
+ MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel);
70
+ MTLLibrary_t getMetalIndexingLibrary();
71
+
72
+ ~MPSDevice();
73
+
74
+ private:
75
+ static MPSDevice* _device;
76
+ MTLDevice_t _mtl_device;
77
+ MTLLibrary_t _mtl_indexing_library;
78
+ MPSDevice();
79
+ };
80
+
81
+ TORCH_API bool is_available();
82
+ TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
83
+ TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
84
+
85
+ } // namespace at::mps
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <ATen/core/ATen_fwd.h>
5
+
6
+ namespace at {
7
+ class Tensor;
8
+
9
+ namespace native {
10
+
11
+ using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
12
+ TensorList,
13
+ Tensor&,
14
+ const Tensor&);
15
+
16
+ using _amp_update_scale_cpu__fn = Tensor& (*)(
17
+ Tensor&,
18
+ Tensor&,
19
+ const Tensor&,
20
+ double,
21
+ double,
22
+ int64_t);
23
+
24
+ DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
25
+ DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
26
+
27
+ } // namespace native
28
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/OpMathType.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <ATen/native/TransposeType.h>
6
+ #include <c10/util/complex.h>
7
+ #include <c10/core/ScalarType.h>
8
+ #include <c10/core/Scalar.h>
9
+
10
+ namespace at::native::cpublas {
11
+
12
+ namespace internal {
13
+ void normalize_last_dims(
14
+ TransposeType transa, TransposeType transb,
15
+ int64_t m, int64_t n, int64_t k,
16
+ int64_t *lda, int64_t *ldb, int64_t *ldc);
17
+ } // namespace internal
18
+
19
+ using gemm_fn = void(*)(
20
+ at::ScalarType type,
21
+ TransposeType transa, TransposeType transb,
22
+ int64_t m, int64_t n, int64_t k,
23
+ const Scalar& alpha,
24
+ const void *a, int64_t lda,
25
+ const void *b, int64_t ldb,
26
+ const Scalar& beta,
27
+ void *c, int64_t ldc);
28
+
29
+ DECLARE_DISPATCH(gemm_fn, gemm_stub);
30
+
31
+ template <typename scalar_t>
32
+ void gemm(
33
+ TransposeType transa, TransposeType transb,
34
+ int64_t m, int64_t n, int64_t k,
35
+ at::opmath_type<scalar_t> alpha,
36
+ const scalar_t *a, int64_t lda,
37
+ const scalar_t *b, int64_t ldb,
38
+ at::opmath_type<scalar_t> beta,
39
+ scalar_t *c, int64_t ldc) {
40
+ internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
41
+ gemm_stub(
42
+ kCPU, c10::CppTypeToScalarType<scalar_t>::value,
43
+ transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
44
+ }
45
+
46
+ void gemm(
47
+ TransposeType transa, TransposeType transb,
48
+ int64_t m, int64_t n, int64_t k,
49
+ double alpha,
50
+ const double *a, int64_t lda,
51
+ const double *b, int64_t ldb,
52
+ double beta,
53
+ double *c, int64_t ldc);
54
+
55
+ void gemm(
56
+ TransposeType transa, TransposeType transb,
57
+ int64_t m, int64_t n, int64_t k,
58
+ float alpha,
59
+ const float *a, int64_t lda,
60
+ const float *b, int64_t ldb,
61
+ float beta,
62
+ float *c, int64_t ldc);
63
+
64
+ void gemm(
65
+ TransposeType transa, TransposeType transb,
66
+ int64_t m, int64_t n, int64_t k,
67
+ float alpha,
68
+ const at::BFloat16 *a, int64_t lda,
69
+ const at::BFloat16 *b, int64_t ldb,
70
+ float beta,
71
+ at::BFloat16 *c, int64_t ldc);
72
+
73
+ void gemm(
74
+ TransposeType transa, TransposeType transb,
75
+ int64_t m, int64_t n, int64_t k,
76
+ const float alpha,
77
+ const at::BFloat16 *a, int64_t lda,
78
+ const at::BFloat16 *b, int64_t ldb,
79
+ const float beta,
80
+ float *c, int64_t ldc);
81
+
82
+ void gemm(
83
+ TransposeType transa, TransposeType transb,
84
+ int64_t m, int64_t n, int64_t k,
85
+ float alpha,
86
+ const at::Half *a, int64_t lda,
87
+ const at::Half *b, int64_t ldb,
88
+ float beta,
89
+ at::Half *c, int64_t ldc);
90
+
91
+ void gemm(
92
+ TransposeType transa, TransposeType transb,
93
+ int64_t m, int64_t n, int64_t k,
94
+ const float alpha,
95
+ const at::Half *a, int64_t lda,
96
+ const at::Half *b, int64_t ldb,
97
+ const float beta,
98
+ float *c, int64_t ldc);
99
+
100
+ void gemm(
101
+ TransposeType transa, TransposeType transb,
102
+ int64_t m, int64_t n, int64_t k,
103
+ c10::complex<double> alpha,
104
+ const c10::complex<double> *a, int64_t lda,
105
+ const c10::complex<double> *b, int64_t ldb,
106
+ c10::complex<double> beta,
107
+ c10::complex<double> *c, int64_t ldc);
108
+
109
+ void gemm(
110
+ TransposeType transa, TransposeType transb,
111
+ int64_t m, int64_t n, int64_t k,
112
+ c10::complex<float> alpha,
113
+ const c10::complex<float> *a, int64_t lda,
114
+ const c10::complex<float> *b, int64_t ldb,
115
+ c10::complex<float> beta,
116
+ c10::complex<float> *c, int64_t ldc);
117
+
118
+ void gemm(
119
+ TransposeType transa, TransposeType transb,
120
+ int64_t m, int64_t n, int64_t k,
121
+ int64_t alpha,
122
+ const int64_t *a, int64_t lda,
123
+ const int64_t *b, int64_t ldb,
124
+ int64_t beta,
125
+ int64_t *c, int64_t ldc);
126
+
127
+ template <typename scalar_t>
128
+ void gemm_batched(
129
+ TransposeType transa, TransposeType transb,
130
+ int64_t batch_size, int64_t m, int64_t n, int64_t k,
131
+ scalar_t alpha,
132
+ const scalar_t * const *a, int64_t lda,
133
+ const scalar_t * const *b, int64_t ldb,
134
+ const scalar_t beta,
135
+ scalar_t * const *c, int64_t ldc);
136
+
137
+ template <typename scalar_t>
138
+ void gemm_batched_with_stride(
139
+ TransposeType transa, TransposeType transb,
140
+ int64_t batch_size, int64_t m, int64_t n, int64_t k,
141
+ scalar_t alpha,
142
+ const scalar_t *a, int64_t lda, int64_t batch_stride_a,
143
+ const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
144
+ scalar_t beta,
145
+ scalar_t *c, int64_t ldc, int64_t batch_stride_c);
146
+
147
+ using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
148
+
149
+ DECLARE_DISPATCH(axpy_fn, axpy_stub);
150
+
151
+ template<typename scalar_t>
152
+ void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
153
+ if(n == 1)
154
+ {
155
+ incx = 1;
156
+ incy = 1;
157
+ }
158
+ axpy_stub(
159
+ kCPU, c10::CppTypeToScalarType<scalar_t>::value,
160
+ n, a, x, incx, y, incy);
161
+ }
162
+
163
+ void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
164
+ void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
165
+ void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
166
+ void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
167
+
168
+ using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
169
+
170
+ DECLARE_DISPATCH(copy_fn, copy_stub);
171
+
172
+ template<typename scalar_t>
173
+ void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
174
+ if(n == 1)
175
+ {
176
+ incx = 1;
177
+ incy = 1;
178
+ }
179
+ copy_stub(
180
+ kCPU, c10::CppTypeToScalarType<scalar_t>::value,
181
+ n, x, incx, y, incy);
182
+ }
183
+
184
+ void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
185
+ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
186
+ void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
187
+ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
188
+
189
+ } // namespace at::native::cpublas
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <utility>
2
+
3
+ #pragma once
4
+
5
+ namespace at::native {
6
+
7
+ namespace {
8
+
9
+ // operator_brackets_proxy is used in
10
+ // CompositeRandomAccessor in place of operator[].
11
+ // For some iterators, references returned by operator[]
12
+ // could become invalid, operator_brackets_proxy tries to
13
+ // resolve that by making accessor[n] to be equivalent to
14
+ // *(accessor + n).
15
+ template <typename Accessor>
16
+ class operator_brackets_proxy {
17
+ using reference = typename std::iterator_traits<Accessor>::reference;
18
+ using value_type = typename std::iterator_traits<Accessor>::value_type;
19
+
20
+ public:
21
+ C10_HOST_DEVICE
22
+ operator_brackets_proxy(Accessor const& accessor)
23
+ : accessor(accessor)
24
+ {}
25
+
26
+ C10_HOST_DEVICE
27
+ operator reference() {
28
+ return *accessor;
29
+ }
30
+
31
+ C10_HOST_DEVICE
32
+ reference operator*() {
33
+ return *accessor;
34
+ }
35
+
36
+ C10_HOST_DEVICE
37
+ operator_brackets_proxy& operator=(value_type const& val) {
38
+ *accessor = val;
39
+ return *this;
40
+ }
41
+
42
+ private:
43
+ Accessor accessor;
44
+ };
45
+
46
+ }
47
+
48
+ // references_holder is used as a surrogate for the
49
+ // references type from std::iterator_traits in CompositeRandomAccessor.
50
+ // It is assumed in CompositeRandomAccessor that
51
+ // References = tuple<Types&...>,
52
+ // Values = tuple<Types...> by default,
53
+ // but they could be anything as long as References could be
54
+ // cast to Values.
55
+ // If you plan to use it with STL, for example, you will need to
56
+ // define 'swap` and `get`(aka std::get) methods.
57
+ template <typename Values, typename References>
58
+ class references_holder {
59
+ public:
60
+ using values = Values;
61
+ using references = References;
62
+
63
+ C10_HOST_DEVICE
64
+ references_holder(references refs)
65
+ : refs{std::move(refs)}
66
+ {}
67
+
68
+ C10_HOST_DEVICE
69
+ operator references() {
70
+ return refs;
71
+ }
72
+
73
+ C10_HOST_DEVICE
74
+ operator values() {
75
+ return refs;
76
+ }
77
+
78
+ C10_HOST_DEVICE
79
+ references_holder& operator=(values vals) {
80
+ refs = vals;
81
+ return *this;
82
+ }
83
+
84
+ C10_HOST_DEVICE
85
+ references& data() {
86
+ return refs;
87
+ }
88
+
89
+ protected:
90
+ references refs;
91
+ };
92
+
93
+ // CompositeRandomAccessor is essentially a simplified version of
94
+ // a random access iterator over two random access iterators.
95
+ // TupleInfo should contain a variadic type `tuple`, and a method `tie`,
96
+ // which constructs a tuple of references from a variadic list of arguments.
97
+ template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo>
98
+ class CompositeRandomAccessor {
99
+ using self_type = CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfo>;
100
+
101
+ using key_accessor_value_type =
102
+ typename std::iterator_traits<KeyAccessor>::value_type;
103
+ using value_accessor_value_type =
104
+ typename std::iterator_traits<ValueAccessor>::value_type;
105
+ using key_accessor_reference_type =
106
+ typename std::iterator_traits<KeyAccessor>::reference;
107
+ using value_accessor_reference_type =
108
+ typename std::iterator_traits<ValueAccessor>::reference;
109
+
110
+ using composite_value_type = typename TupleInfo::template tuple<
111
+ key_accessor_value_type,
112
+ value_accessor_value_type>;
113
+ using composite_reference = typename TupleInfo::template tuple<
114
+ key_accessor_reference_type,
115
+ value_accessor_reference_type>;
116
+
117
+ public:
118
+ using value_type = composite_value_type;
119
+ using reference = references_holder<composite_value_type, composite_reference>;
120
+ // Note that CompositeRandomAccessor does not hold key and values
121
+ // in a specific datastructure, which means that a pointer to a (key, value)
122
+ // is not defined. Hence we just use a pointer type of the KeyAccessor.
123
+ using pointer = typename std::iterator_traits<KeyAccessor>::pointer;
124
+ using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
125
+ using iterator_category = std::random_access_iterator_tag;
126
+
127
+ C10_HOST_DEVICE
128
+ CompositeRandomAccessor() = default;
129
+
130
+ C10_HOST_DEVICE
131
+ CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
132
+ : keys(keys), values(values)
133
+ {}
134
+
135
+ // Pointer-like operations {
136
+ C10_HOST_DEVICE
137
+ reference operator*() const {
138
+ return TupleInfo::tie(*keys, *values);
139
+ }
140
+
141
+ // operator->() is supposed to return a pointer type.
142
+ // Since CompositeRandomAccessor does not hold pointers to pairs,
143
+ // we just return a pointer to a key.
144
+ C10_HOST_DEVICE
145
+ auto* operator->() const {
146
+ return keys.operator->();
147
+ }
148
+
149
+ C10_HOST_DEVICE
150
+ reference operator[](difference_type idx) {
151
+ return operator_brackets_proxy<self_type>(
152
+ CompositeRandomAccessor(keys + idx, values + idx)
153
+ );
154
+ }
155
+ // }
156
+
157
+ // Prefix/postfix increment/decrement {
158
+ C10_HOST_DEVICE
159
+ CompositeRandomAccessor& operator++() {
160
+ ++keys;
161
+ ++values;
162
+ return *this;
163
+ }
164
+
165
+ C10_HOST_DEVICE
166
+ CompositeRandomAccessor operator++(int) {
167
+ CompositeRandomAccessor copy(*this);
168
+ ++*this;
169
+ return copy;
170
+ }
171
+
172
+ C10_HOST_DEVICE
173
+ CompositeRandomAccessor& operator--() {
174
+ --keys;
175
+ --values;
176
+ return *this;
177
+ }
178
+
179
+ C10_HOST_DEVICE
180
+ CompositeRandomAccessor operator--(int) {
181
+ CompositeRandomAccessor copy(*this);
182
+ --*this;
183
+ return copy;
184
+ }
185
+ // }
186
+
187
+ // Arithmetic operations {
188
+ C10_HOST_DEVICE
189
+ CompositeRandomAccessor& operator+=(difference_type offset) {
190
+ keys += offset;
191
+ values += offset;
192
+ return *this;
193
+ }
194
+
195
+ C10_HOST_DEVICE
196
+ CompositeRandomAccessor operator+(difference_type offset) const {
197
+ return CompositeRandomAccessor(keys + offset, values + offset);
198
+ }
199
+
200
+ C10_HOST_DEVICE
201
+ friend CompositeRandomAccessor operator+(
202
+ difference_type offset,
203
+ const CompositeRandomAccessor& accessor
204
+ ) {
205
+ return accessor + offset;
206
+ }
207
+
208
+ C10_HOST_DEVICE
209
+ CompositeRandomAccessor& operator-=(difference_type offset) {
210
+ keys -= offset;
211
+ values -= offset;
212
+ return *this;
213
+ }
214
+
215
+ C10_HOST_DEVICE
216
+ CompositeRandomAccessor operator-(difference_type offset) const {
217
+ return CompositeRandomAccessor(keys - offset, values - offset);
218
+ }
219
+
220
+ C10_HOST_DEVICE
221
+ difference_type operator-(const CompositeRandomAccessor& other) const {
222
+ return keys - other.keys;
223
+ }
224
+ // }
225
+
226
+ // Comparison operators {
227
+ C10_HOST_DEVICE
228
+ bool operator==(const CompositeRandomAccessor& other) const {
229
+ return keys == other.keys;
230
+ }
231
+
232
+ C10_HOST_DEVICE
233
+ bool operator!=(const CompositeRandomAccessor& other) const {
234
+ return keys != other.keys;
235
+ }
236
+
237
+ C10_HOST_DEVICE
238
+ bool operator<(const CompositeRandomAccessor& other) const {
239
+ return keys < other.keys;
240
+ }
241
+
242
+ C10_HOST_DEVICE
243
+ bool operator<=(const CompositeRandomAccessor& other) const {
244
+ return keys <= other.keys;
245
+ }
246
+
247
+ C10_HOST_DEVICE
248
+ bool operator>(const CompositeRandomAccessor& other) const {
249
+ return keys > other.keys;
250
+ }
251
+
252
+ C10_HOST_DEVICE
253
+ bool operator>=(const CompositeRandomAccessor& other) const {
254
+ return keys >= other.keys;
255
+ }
256
+ // }
257
+
258
+ protected:
259
+ KeyAccessor keys;
260
+ ValueAccessor values;
261
+ };
262
+
263
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+
3
+ namespace at::native {
4
+
5
+ std::tuple<Tensor, Tensor, Tensor> slow_conv3d_backward_cpu(
6
+ const Tensor& grad_output,
7
+ const Tensor& self,
8
+ const Tensor& weight,
9
+ IntArrayRef kernel_size,
10
+ IntArrayRef stride,
11
+ IntArrayRef padding,
12
+ std::array<bool, 3> output_mask);
13
+
14
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ namespace at {
6
+
7
+ class Tensor;
8
+ struct TensorIterator;
9
+ class TensorBase;
10
+
11
+ namespace native {
12
+
13
+ using copy_fn = void (*)(TensorIterator&, bool non_blocking);
14
+
15
+ DECLARE_DISPATCH(copy_fn, copy_stub);
16
+
17
+ TORCH_API void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src);
18
+
19
+ } // namespace native
20
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <algorithm>
4
+ #include <vector>
5
+
6
+ #include <ATen/div_rtn.h>
7
+ #include <ATen/core/Tensor.h>
8
+ #include <c10/util/irange.h>
9
+
10
+ #define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
11
+ TORCH_CHECK( \
12
+ T.dim() == DIM && T.size(DIM_SIZE) == SIZE, \
13
+ "Need " #T " of dimension ", \
14
+ DIM, \
15
+ " and " #T ".size[", \
16
+ DIM_SIZE, \
17
+ "] == ", \
18
+ SIZE, \
19
+ " but got input to be of shape ", \
20
+ T.sizes())
21
+
22
+ namespace at::native::internal {
23
+ namespace {
24
+ inline bool all_positive(IntArrayRef& arr) {
25
+ return std::all_of(
26
+ arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
27
+ }
28
+
29
+ inline bool all_nonnegative(std::vector<int64_t>& arr) {
30
+ return std::all_of(
31
+ arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
32
+ }
33
+
34
+ } // namespace
35
+
36
+ // calculate the rear part of output tensor sizes
37
+ template <int64_t dim>
38
+ std::vector<int64_t> get_output_size(
39
+ const Tensor& input,
40
+ IntArrayRef kernel_size,
41
+ IntArrayRef stride_size,
42
+ IntArrayRef pad_size,
43
+ IntArrayRef dilation_size) {
44
+ std::vector<int64_t> sizes;
45
+ for (const auto index : c10::irange(dim)) {
46
+ sizes.push_back(
47
+ div_rtn<int64_t>(
48
+ input.size(index + input.dim() - dim) + 2 * pad_size[index] -
49
+ (dilation_size[index] * (kernel_size[index] - 1) + 1),
50
+ stride_size[index]) +
51
+ 1);
52
+ }
53
+ return sizes;
54
+ }
55
+
56
+ // calculate the sizes of output tensor
57
+ template <int64_t dim>
58
+ std::vector<int64_t> get_output_size(
59
+ const Tensor& input,
60
+ const Tensor& weight,
61
+ IntArrayRef kernel_size,
62
+ IntArrayRef stride_size,
63
+ IntArrayRef pad_size,
64
+ IntArrayRef dilation_size) {
65
+ auto output_size = get_output_size<dim>(
66
+ input, kernel_size, stride_size, pad_size, dilation_size);
67
+ output_size.insert(output_size.begin(), weight.size(0));
68
+ if (input.dim() == dim + 2) {
69
+ output_size.insert(output_size.begin(), input.size(0));
70
+ }
71
+ return output_size;
72
+ }
73
+ /*
74
+ slow_conv_dilated_shape_check - check user-input to dilated convolution
75
+ forward and backward functions.
76
+ */
77
+ template <int64_t dim>
78
+ void slow_conv_dilated_shape_check(
79
+ const Tensor& input,
80
+ const Tensor& weight,
81
+ const Tensor& bias,
82
+ const Tensor& grad_output,
83
+ IntArrayRef kernel_size,
84
+ IntArrayRef stride_size,
85
+ IntArrayRef pad_size,
86
+ IntArrayRef dilation_size) {
87
+ /*
88
+ When the following tensors are defined:
89
+
90
+ bias, grad_weight, grad_output
91
+
92
+ then these are assumed to be contiguous without checking
93
+ because of these tensors are made contiguous by calling
94
+ .contiguous() method or by resizing of zero-sized tensors in
95
+ forward/backward functions.
96
+
97
+ When grad_weight is defined then it is assumed without
98
+ checking to have the same shape as weight, see backward
99
+ functions.
100
+ */
101
+ // Check size arguments
102
+ TORCH_CHECK(
103
+ kernel_size.size() == dim,
104
+ "kernel sizes length should be ",
105
+ dim,
106
+ ", but got ",
107
+ kernel_size.size());
108
+ TORCH_CHECK(
109
+ stride_size.size() == dim,
110
+ "strides length should be ",
111
+ dim,
112
+ ", but got ",
113
+ stride_size.size());
114
+ TORCH_CHECK(
115
+ dilation_size.size() == dim,
116
+ "dilations length should be ",
117
+ dim,
118
+ ", but got ",
119
+ dilation_size.size());
120
+ TORCH_CHECK(
121
+ pad_size.size() == dim,
122
+ "pads length should be ",
123
+ dim,
124
+ ", but got ",
125
+ pad_size.size());
126
+
127
+ TORCH_CHECK(
128
+ all_positive(kernel_size),
129
+ "kernel size should be greater than zero, but got ",
130
+ kernel_size);
131
+ TORCH_CHECK(
132
+ all_positive(stride_size),
133
+ "stride should be greater than zero, but got ",
134
+ stride_size);
135
+ TORCH_CHECK(
136
+ all_positive(dilation_size),
137
+ "dilation should be greater than zero, but got ",
138
+ dilation_size);
139
+
140
+ // check input
141
+ TORCH_CHECK(input.defined(), "input must be defined");
142
+ bool is_batch = input.dim() == dim + 2;
143
+ int64_t n = (is_batch ? 2 : 1);
144
+ int64_t ndim = n + dim;
145
+ if (!is_batch) {
146
+ // input dim has to be dim + 1 if not batched
147
+ TORCH_CHECK(
148
+ input.dim() == dim + 1,
149
+ "input must be 4D or 5D tensor but got ",
150
+ input.dim(),
151
+ "D tensor");
152
+ }
153
+
154
+ // check output sizes
155
+ auto output_size = get_output_size<dim>(
156
+ input, kernel_size, stride_size, pad_size, dilation_size);
157
+
158
+ TORCH_CHECK(
159
+ all_nonnegative(output_size),
160
+ "calculated output size ",
161
+ output_size,
162
+ " is too small (all sizes must be non-negative)");
163
+
164
+ // check weight
165
+ TORCH_CHECK(weight.defined(), "weight must be defined");
166
+ TORCH_CHECK(
167
+ weight.dim() == dim + 2,
168
+ "weight must be ",
169
+ dim + 2,
170
+ "D tensor but got ",
171
+ weight.dim(),
172
+ "D tensor dim=",
173
+ dim);
174
+ TORCH_CHECK(
175
+ weight.sizes().slice(2) == kernel_size,
176
+ "weight[2:] shape ",
177
+ weight.sizes().slice(2),
178
+ " must be equal to kernel_size ",
179
+ kernel_size);
180
+
181
+ TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));
182
+
183
+ // check bias when present
184
+ if (bias.defined()) {
185
+ TORCH_CHECK(
186
+ bias.dim() == 1,
187
+ "bias must be 1D tensor but got ",
188
+ bias.dim(),
189
+ "D tensor");
190
+ TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
191
+ }
192
+
193
+ // check grad_output when present
194
+ if (grad_output.defined()) {
195
+ TORCH_CHECK(
196
+ grad_output.dim() == ndim,
197
+ "grad_output must be ",
198
+ ndim,
199
+ "D tensor but got ",
200
+ grad_output.dim(),
201
+ "D tensor");
202
+ if (is_batch) {
203
+ TORCH_CHECK(
204
+ grad_output.size(0) == input.size(0),
205
+ "grad_output.size(0)=",
206
+ grad_output.size(0),
207
+ " must be input.size(0)=",
208
+ input.size(0));
209
+ }
210
+ TORCH_CHECK(
211
+ grad_output.size(n - 1) == weight.size(0),
212
+ "grad_output.size(",
213
+ n - 1,
214
+ ")=",
215
+ grad_output.size(n - 1),
216
+ " must be weight.size(0)=",
217
+ weight.size(0));
218
+ TORCH_CHECK(
219
+ grad_output.sizes().slice(n) == output_size,
220
+ "grad_output[",
221
+ n,
222
+ ":] shape",
223
+ grad_output.sizes().slice(n),
224
+ " must be equal to output size ",
225
+ output_size);
226
+ }
227
+ }
228
+
229
+ } // namespace at::native::internal
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Device.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/ScalarType.h>
6
+ #include <ATen/core/Tensor.h>
7
+ #include <ATen/native/utils/ParamsHash.h>
8
+ #include <c10/util/Exception.h>
9
+ #include <c10/util/irange.h>
10
+
11
+ #ifndef AT_PER_OPERATOR_HEADERS
12
+ #include <ATen/NativeFunctions.h>
13
+ #else
14
+ #include <ATen/ops/result_type_native.h>
15
+ #endif
16
+
17
+ #include <unordered_map>
18
+ #include <vector>
19
+
20
+ namespace at::native {
21
+ namespace {
22
+ // Check if tensor list has either a boolean tensor or a integer tensor
23
+ inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
24
+ return std::any_of(
25
+ tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
26
+ return at::isIntegralType(t.scalar_type(), includeBool);
27
+ });
28
+ }
29
+ // check if tensor list has bool tensors
30
+ inline bool has_bool_tensor(TensorList tensors) {
31
+ return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
32
+ return t.scalar_type() == ScalarType::Bool;
33
+ });
34
+ }
35
+
36
+ // Check foreach API restrictions
37
+ // - Tensor lists must be non-empty.
38
+ // - All TensorLists and ScalarLists must have the same number of elements.
39
+ // - Corresponding tensors must have the same size.
40
+ inline void check_foreach_api_restrictions(TensorList tensors) {
41
+ TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
42
+ }
43
+
44
+ inline void check_foreach_api_restrictions(
45
+ TensorList tensors,
46
+ ArrayRef<Scalar> scalars) {
47
+ check_foreach_api_restrictions(tensors);
48
+ TORCH_CHECK(
49
+ tensors.size() == scalars.size(),
50
+ "Tensor list must have same number of elements as scalar list.");
51
+ }
52
+
53
+ inline void check_foreach_api_restrictions(
54
+ TensorList tensors1,
55
+ TensorList tensors2) {
56
+ TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
57
+ TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
58
+ TORCH_CHECK(
59
+ tensors1.size() == tensors2.size(),
60
+ "Tensor lists must have the same number of tensors, got ",
61
+ tensors1.size(),
62
+ " and ",
63
+ tensors2.size());
64
+ }
65
+
66
+ inline void check_foreach_api_restrictions(
67
+ TensorList tensors1,
68
+ TensorList tensors2,
69
+ TensorList tensors3) {
70
+ TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
71
+ TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
72
+ TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
73
+ TORCH_CHECK(
74
+ tensors1.size() == tensors2.size(),
75
+ "Tensor lists must have the same number of tensors, got ",
76
+ tensors1.size(),
77
+ " and ",
78
+ tensors2.size());
79
+ TORCH_CHECK(
80
+ tensors1.size() == tensors3.size(),
81
+ "Tensor lists must have the same number of tensors, got ",
82
+ tensors1.size(),
83
+ " and ",
84
+ tensors3.size());
85
+ }
86
+
87
+ inline void check_foreach_api_restrictions(
88
+ TensorList tensors1,
89
+ TensorList tensors2,
90
+ TensorList tensors3,
91
+ ArrayRef<Scalar> scalars) {
92
+ check_foreach_api_restrictions(tensors1, tensors2, tensors3);
93
+ TORCH_CHECK(
94
+ tensors1.size() == scalars.size(),
95
+ "Tensor list must have same number of elements as scalar list, got ",
96
+ tensors1.size(),
97
+ " and ",
98
+ scalars.size());
99
+ }
100
+
101
+ // Helper function called in check_fast_path_restrictions to check whether all
102
+ // corresponding tensors (aligning in index across the tensorLists) share the
103
+ // same device and dtype.
104
+ inline bool _check_tensors_share_device_and_dtype(
105
+ ArrayRef<TensorList> tensorLists) {
106
+ const auto expected_dtype = tensorLists[0][0].dtype();
107
+ const auto expected_device = tensorLists[0][0].device();
108
+
109
+ auto is_tensor_okay = [&](const Tensor& tensor) {
110
+ return tensor.dtype() == expected_dtype &&
111
+ tensor.device() == expected_device && tensor.layout() == at::kStrided &&
112
+ tensor.is_non_overlapping_and_dense();
113
+ };
114
+
115
+ for (const auto& tensorList : tensorLists) {
116
+ for (const auto& tensor : tensorList) {
117
+ if (!is_tensor_okay(tensor)) {
118
+ return false;
119
+ }
120
+ }
121
+ }
122
+
123
+ return true;
124
+ }
125
+
126
+ // Helper function called in check_fast_path_restrictions to check if
127
+ // corresponding tensors in tensor lists have the same sizes and strides.
128
+ inline bool _check_tensors_share_sizes_and_strides(
129
+ ArrayRef<TensorList> tensorLists) {
130
+ for (const auto i : c10::irange(1, tensorLists.size())) {
131
+ for (const auto j : c10::irange(tensorLists[0].size())) {
132
+ if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
133
+ tensorLists[0][j].strides() != tensorLists[i][j].strides()) {
134
+ return false;
135
+ }
136
+ }
137
+ }
138
+
139
+ return true;
140
+ }
141
+
142
+ // Helper function called in check_fast_path_restrictions to check whether
143
+ // all tensors type promote properly with the scalars in scalarList. This
144
+ // function assumes that _check_tensors_share_device_and_dtype has already been
145
+ // called so that all corresponding tensors in tensorLists have the same dtype.
146
+ // Then, it is sufficient to check the type promotion with just one tensorList.
147
+ inline bool _check_tensors_do_type_promotion_with_scalars(
148
+ TensorList tensorList,
149
+ ArrayRef<Scalar> scalarList = {},
150
+ bool does_op_promote_integer_inputs_to_float = false) {
151
+ for (const auto i : c10::irange(tensorList.size())) {
152
+ // For division, integer inputs will result in float.
153
+ if (does_op_promote_integer_inputs_to_float) {
154
+ if (at::isIntegralType(
155
+ tensorList[i].scalar_type(), /*includeBool*/ true)) {
156
+ return false;
157
+ }
158
+ }
159
+ if (!scalarList.empty()) {
160
+ const auto& scalar =
161
+ scalarList.size() == 1 ? scalarList[0] : scalarList[i];
162
+ const auto& tensor = tensorList[i];
163
+ // note(mkozuki): This check might be responsible for
164
+ // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
165
+ if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
166
+ return false;
167
+ }
168
+ }
169
+ }
170
+
171
+ return true;
172
+ }
173
+
174
+ // To go via 'fast' path, several conditions must be satisfied
175
+ // - All tensors in all lists must have the same dtype.
176
+ // - All tensors must be on the same device
177
+ // - All tensors must have strided layout
178
+ // - All tensors must be non-overlapping and dense
179
+ // - Resulting tensor must have the same dtype as the input one
180
+
181
+ // Please, make sure to call check_foreach_api_restrictions before calling this
182
+ // method. There is a set of preconditions that have to be satisfied.
183
+ inline bool check_fast_path_restrictions(
184
+ ArrayRef<TensorList> tensorLists,
185
+ ArrayRef<Scalar> scalarList = {},
186
+ bool does_op_promote_integer_inputs_to_float = false) {
187
+ return _check_tensors_share_device_and_dtype(tensorLists) &&
188
+ _check_tensors_share_sizes_and_strides(tensorLists) &&
189
+ _check_tensors_do_type_promotion_with_scalars(
190
+ tensorLists[0],
191
+ scalarList,
192
+ does_op_promote_integer_inputs_to_float);
193
+ }
194
+
195
// Unpacks a 1-D CPU tensor of scalars into a std::vector<c10::Scalar>.
// Requirements (enforced via TORCH_CHECK): the tensor is on CPU, contiguous,
// 1-dimensional, and has exactly `expect_length` elements. Dispatches over
// all dtypes (incl. half/bfloat16/bool/complex-half) to read the raw data.
inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
    const Tensor& scalarList_,
    int64_t expect_length) {
  std::vector<c10::Scalar> scalarList;
  TORCH_CHECK(
      scalarList_.device() == c10::kCPU,
      "Expected scalars to be on CPU, got ",
      scalarList_.device(),
      " instead.");
  TORCH_CHECK(
      scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
  TORCH_CHECK(
      scalarList_.dim() == 1,
      "Expected packed scalar Tensor to be of dimension 1. Got ",
      scalarList_.dim(),
      " instead.");
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf,
      kHalf,
      kBool,
      kBFloat16,
      scalarList_.scalar_type(),
      "convert_tensor_to_scalar_list",
      [&]() {
        const scalar_t* scalar_data = scalarList_.data_ptr<scalar_t>();
        // Length check lives inside the dispatch so failures report after
        // dtype validation succeeds.
        TORCH_CHECK(
            (expect_length == scalarList_.size(0)),
            "Expected length of scalars to match input of length ",
            expect_length,
            " but got ",
            scalarList_.size(0),
            " instead.");
        for (int64_t i = 0; i < scalarList_.size(0); i++) {
          scalarList.emplace_back(scalar_data[i]);
        }
      });
  return scalarList;
}
233
+
234
// Thin public entry point: a foreach op may take the fused fast path iff
// check_fast_path_restrictions passes for the given lists/scalars.
inline bool can_use_fast_route(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  return check_fast_path_restrictions(
      tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
}
241
+
242
// Convenience overload for binary foreach ops: wraps the two lists and
// forwards with an empty scalar list.
inline bool can_use_fast_route(
    TensorList tensors1,
    TensorList tensors2,
    bool does_op_promote_integer_inputs_to_float = false) {
  return can_use_fast_route(
      {tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
}
249
+
250
// Aliases used by _group_tensors_by_first_tensors_device_and_dtype below.
// Bucket key: (device, dtype) of a tensor.
using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
// Original positions of the tensors placed into a bucket.
using IndicesT = std::vector<size_t>;
// Outer vector: one entry per input list; inner vector: optional tensors.
using nested_optional_tensorvec_t =
    std::vector<std::vector<c10::optional<at::Tensor>>>;
using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
using FlatMap = std::unordered_map<
    DeviceDtypeKey,
    TensorsAndIndicesT,
    ParamsHash<DeviceDtypeKey>>;
259
+
260
// Buckets the tensors of `nested_tensorlist` by the (device, dtype) of the
// corresponding tensor in the FIRST inner list. Inner lists must either be
// empty or have the same length as the first list. When `with_indices` is
// true, each bucket also records the original tensor indices.
// Tensors in non-first lists may deviate from the key's dtype/device only in
// the float32/float64 "step"-tensor cases checked below.
inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
    const nested_optional_tensorvec_t& nested_tensorlist,
    const bool with_indices) {
  FlatMap grouped_tensors_with_indices;

  TORCH_CHECK(!nested_tensorlist.empty());
  TORCH_CHECK(!nested_tensorlist[0].empty());
  const auto num_lists = nested_tensorlist.size();
  const auto num_tensors = nested_tensorlist[0].size();

  TORCH_CHECK(std::all_of(
      nested_tensorlist.cbegin(),
      nested_tensorlist.cend(),
      [&](const auto& tensorlist) -> bool {
        // note(crcrpar): Allow empty tensorlists following
        // ref:
        // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
        return tensorlist.size() == num_tensors || tensorlist.size() == 0;
      }));

  for (const auto& tensor_index : c10::irange(num_tensors)) {
    // The grouping key comes from the first list's tensor at this index,
    // which therefore must be defined.
    const auto key = [&]() -> DeviceDtypeKey {
      const auto t = nested_tensorlist[0][tensor_index];
      TORCH_CHECK(
          t.has_value(),
          "Tensors of the first list of nested Tensor lists are supposed to be defined but ",
          "the ",
          tensor_index,
          "-th Tensor is not.");
      return {t->device(), t->scalar_type()};
    }();
    TORCH_CHECK(
        std::all_of(
            nested_tensorlist.cbegin(),
            nested_tensorlist.cend(),
            [&](const auto& tensorlist) -> bool {
              if (tensorlist.size() == 0) {
                return true;
              }
              const auto& tensor = tensorlist[tensor_index];
              // note(crcrpar): Currently the scope of this function is
              // optimizers so there could be `state_steps` and other scalars
              // whose elements are float tensors no matter what the parameter's
              // dtype is.
              if (!tensor.has_value()) {
                return true;
              } else {
                const auto s = tensor->scalar_type();
                const auto d = tensor->device();
                // Note: `step` or `state_step` is float32 by default.
                if (key.first == d) {
                  return key.second == s || s == at::ScalarType::Float ||
                      s == at::ScalarType::Double;
                } else if (d.is_cpu()) {
                  // note(crcrpar): There are some test cases (e.g.
                  // TestOptim::test_adam) where state_steps are on CPU and the
                  // others are on CUDA. Currently a state_step Tensor has the
                  // dtype of float.
                  return s == at::ScalarType::Float ||
                      s == at::ScalarType::Double;
                } else {
                  return false;
                }
              }
            }),
        "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
    if (!grouped_tensors_with_indices.count(key)) {
      // First occurrence of this (device, dtype): create the bucket with one
      // (possibly empty) inner vector per input list, reserving up front.
      grouped_tensors_with_indices.insert(
          {key,
           TensorsAndIndicesT{
               [&]() -> nested_optional_tensorvec_t {
                 nested_optional_tensorvec_t nested_tensorvec;
                 nested_tensorvec.reserve(num_lists);
                 for (const auto& i : c10::irange(num_lists)) {
                   std::vector<c10::optional<at::Tensor>> tensors;
                   if (!nested_tensorlist[i].empty()) {
                     // NB: num_tensors is the max possible length for any of
                     // the inner lists of tensor references. Reserving the max
                     // trades memory for perf. This should not have significant
                     // impact.
                     tensors.reserve(num_tensors);
                   }
                   nested_tensorvec.emplace_back(tensors);
                 }
                 return nested_tensorvec;
               }(),
               [&]() -> IndicesT {
                 if (!with_indices) {
                   return {};
                 } else {
                   IndicesT indices;
                   indices.reserve(num_tensors);
                   return indices;
                 }
               }()}});
    }
    // Append this index's tensor from every non-empty list into the bucket.
    for (const auto& list_index : c10::irange(num_lists)) {
      if (!nested_tensorlist[list_index].empty()) {
        grouped_tensors_with_indices[key].first[list_index].emplace_back(
            nested_tensorlist[list_index][tensor_index]);
      }
    }
    if (with_indices) {
      grouped_tensors_with_indices[key].second.emplace_back(tensor_index);
    }
  }

  return grouped_tensors_with_indices;
}
369
+
370
+ } // namespace
371
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <c10/util/Optional.h>
5
+
6
+ namespace c10 {
7
+ class Scalar;
8
+ }
9
+
10
+ namespace at {
11
+ struct TensorIterator;
12
+ }
13
+
14
+ namespace at::native {
15
+
16
+ using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha);
17
+ DECLARE_DISPATCH(addr_fn, addr_stub);
18
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/NumericUtils.h>
4
+ #include <ATen/native/Resize.h>
5
+ #include <c10/util/irange.h>
6
+
7
+ #ifndef AT_PER_OPERATOR_HEADERS
8
+ #include <ATen/Functions.h>
9
+ #else
10
+ #include <ATen/ops/empty.h>
11
+ #endif
12
+
13
+ namespace at::native {
14
+
15
+ // ensure we get good values and indices for kthvalue, mode
16
+ // this will always be with the reducing dim as 1-d
17
// ensure we get good values and indices for kthvalue, mode
// this will always be with the reducing dim as 1-d
// Allocates `values`/`indices` if undefined, otherwise validates and resizes
// them to self's shape with size 1 along the (wrapped) reduction dim.
inline void _reduction_with_indices_allocate_or_resize_output(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim_,
    bool keepdim) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  auto result_sizes = self.sizes().vec();
  if (!result_sizes.empty()) {
    result_sizes[dim] = 1;
  }
  if (values.defined()) {
    TORCH_CHECK(
        self.options().type_equal(values.options()),
        "output values must be of same type as input");
    if (!keepdim && values.dim() == self.dim() - 1) {
      // unsqueeze to preserve passed in noncontiguous tensor in resize
      values.unsqueeze_(dim);
    }
    resize_output(values, result_sizes);
  } else {
    values = at::empty(result_sizes, self.options());
  }
  if (indices.defined()) {
    TORCH_CHECK(
        indices.dtype() == kLong, "output indices must be of scalar type Long");
    TORCH_CHECK(
        indices.device() == self.device(),
        "output indices must be on same device as input");
    if (!keepdim && indices.dim() == self.dim() - 1) {
      // unsqueeze to preserve passed in noncontiguous tensor in resize
      indices.unsqueeze_(dim);
    }
    resize_output(indices, result_sizes);
  } else {
    indices = at::empty(result_sizes, self.options().dtype(kLong));
  }
}
55
+
56
+ // ensure we get good values and indices for topk
57
+ inline void _allocate_or_resize_output_with_indices(
58
+ Tensor& values,
59
+ Tensor& indices,
60
+ const Tensor& self,
61
+ int64_t dim_,
62
+ int64_t k) {
63
+ int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
64
+ auto result_sizes = self.sizes().vec();
65
+ if (!result_sizes.empty()) {
66
+ result_sizes[dim] = k;
67
+ }
68
+ if (values.defined()) {
69
+ TORCH_CHECK(
70
+ self.options().type_equal(values.options()),
71
+ "output values must be of same type as input");
72
+ values.resize_(result_sizes);
73
+ } else {
74
+ values = at::empty(result_sizes, self.options());
75
+ }
76
+ if (indices.defined()) {
77
+ TORCH_CHECK(
78
+ indices.dtype() == kLong, "output indices must be of scalar type Long");
79
+ TORCH_CHECK(
80
+ indices.device() == self.device(),
81
+ "output indices must be on same device as input");
82
+ indices.resize_(result_sizes);
83
+ } else {
84
+ indices = at::empty(result_sizes, self.options().dtype(kLong));
85
+ }
86
+ }
87
+
88
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <ATen/Generator.h>
5
+ #include <c10/core/Scalar.h>
6
+ #include <stdexcept>
7
+
8
+ namespace at {
9
+ class Tensor;
10
+ class TensorBase;
11
+ struct TensorIteratorBase;
12
+ }
13
+
14
+ namespace at::native {
15
+
16
+ using unary_fn = void(*)(TensorIteratorBase&);
17
+ using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a);
18
+
19
+ inline namespace CPU_CAPABILITY {
20
+ void conj_kernel(TensorIteratorBase &iter);
21
+ void neg_kernel(TensorIteratorBase &iter);
22
+ void reciprocal_kernel(TensorIteratorBase &iter);
23
+ void rsqrt_kernel(TensorIteratorBase& iter);
24
+ void sqrt_kernel(TensorIteratorBase& iter);
25
+ } // namespace CPU_CAPABILITY
26
+
27
+ DECLARE_DISPATCH(unary_fn, abs_stub);
28
+ DECLARE_DISPATCH(unary_fn, angle_stub);
29
+ DECLARE_DISPATCH(unary_fn, conj_physical_stub);
30
+ DECLARE_DISPATCH(unary_fn, acos_stub);
31
+ DECLARE_DISPATCH(unary_fn, acosh_stub);
32
+ DECLARE_DISPATCH(unary_fn, asinh_stub);
33
+ DECLARE_DISPATCH(unary_fn, atanh_stub);
34
+ DECLARE_DISPATCH(unary_fn, asin_stub);
35
+ DECLARE_DISPATCH(unary_fn, atan_stub);
36
+ DECLARE_DISPATCH(unary_fn, bitwise_not_stub);
37
+ DECLARE_DISPATCH(unary_fn, logical_not_stub);
38
+ DECLARE_DISPATCH(unary_fn, ceil_stub);
39
+ DECLARE_DISPATCH(unary_fn, cos_stub);
40
+ DECLARE_DISPATCH(unary_fn, cosh_stub);
41
+ DECLARE_DISPATCH(unary_fn, digamma_stub);
42
+ DECLARE_DISPATCH(unary_fn, special_entr_stub);
43
+ DECLARE_DISPATCH(unary_fn, special_erfcx_stub);
44
+ DECLARE_DISPATCH(unary_fn, erf_stub);
45
+ DECLARE_DISPATCH(unary_fn, erfc_stub);
46
+ DECLARE_DISPATCH(unary_fn, erfinv_stub);
47
+ DECLARE_DISPATCH(unary_fn, exp_stub);
48
+ DECLARE_DISPATCH(unary_fn, exp2_stub);
49
+ DECLARE_DISPATCH(unary_fn, expm1_stub);
50
+ DECLARE_DISPATCH(unary_fn, floor_stub);
51
+ DECLARE_DISPATCH(unary_fn, frac_stub);
52
+ DECLARE_DISPATCH(unary_fn, frexp_stub);
53
+ DECLARE_DISPATCH(unary_fn, i0_stub);
54
+ DECLARE_DISPATCH(unary_fn, special_i0e_stub);
55
+ DECLARE_DISPATCH(unary_fn, special_i1_stub);
56
+ DECLARE_DISPATCH(unary_fn, special_i1e_stub);
57
+ DECLARE_DISPATCH(unary_fn, log_stub);
58
+ DECLARE_DISPATCH(unary_fn, log10_stub);
59
+ DECLARE_DISPATCH(unary_fn, log1p_stub);
60
+ DECLARE_DISPATCH(unary_fn, log2_stub);
61
+ DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
62
+ DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
63
+ DECLARE_DISPATCH(unary_fn, neg_stub);
64
+
65
+ DECLARE_DISPATCH(unary_fn, reciprocal_stub);
66
+ DECLARE_DISPATCH(unary_fn, round_stub);
67
+ DECLARE_DISPATCH(unary_fn, rsqrt_stub);
68
+ DECLARE_DISPATCH(unary_fn, sigmoid_stub);
69
+ DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
70
+ DECLARE_DISPATCH(unary_fn, sign_stub);
71
+ DECLARE_DISPATCH(unary_fn, signbit_stub);
72
+ DECLARE_DISPATCH(unary_fn, sgn_stub);
73
+ DECLARE_DISPATCH(unary_fn, sin_stub);
74
+ DECLARE_DISPATCH(unary_fn, sinc_stub);
75
+ DECLARE_DISPATCH(unary_fn, sinh_stub);
76
+ DECLARE_DISPATCH(unary_fn, sqrt_stub);
77
+ DECLARE_DISPATCH(unary_fn, tan_stub);
78
+ DECLARE_DISPATCH(unary_fn, tanh_stub);
79
+ DECLARE_DISPATCH(unary_fn, trigamma_stub);
80
+ DECLARE_DISPATCH(unary_fn, trunc_stub);
81
+ DECLARE_DISPATCH(unary_fn, lgamma_stub);
82
+ DECLARE_DISPATCH(unary_fn, special_airy_ai_stub);
83
+ DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
84
+ DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
85
+ DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
86
+ DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
87
+ DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
88
+ DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
89
+ DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
90
+ DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
91
+ DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub);
92
+ DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub);
93
+ DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub);
94
+
95
+ // NB: these are actually defined in Distribution
96
+ DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, c10::optional<Generator>), bernoulli_tensor_stub);
97
+ DECLARE_DISPATCH(void(*)(const TensorBase&, const double, c10::optional<Generator>), bernoulli_scalar_stub);
98
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional<Generator>), cauchy_stub);
99
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, c10::optional<Generator>), exponential_stub);
100
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, c10::optional<Generator>), geometric_stub);
101
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional<Generator>), log_normal_stub);
102
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional<Generator>), uniform_stub);
103
+ DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, c10::optional<Generator>), normal_stub);
104
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, c10::optional<Generator>), random_from_to_stub);
105
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, c10::optional<Generator>), random_full_64_bits_range_stub);
106
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, c10::optional<Generator>), random_stub);
107
+
108
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub);
109
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub);
110
+ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub);
111
+ DECLARE_DISPATCH(
112
+ void (*)(Tensor&, const Tensor&, int64_t, c10::optional<Generator>),
113
+ multinomial_with_replacement_stub);
114
+ DECLARE_DISPATCH(
115
+ void (*)(
116
+ TensorIteratorBase&,
117
+ c10::optional<double>,
118
+ c10::optional<double>,
119
+ c10::optional<double>),
120
+ nan_to_num_stub);
121
+ DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub);
122
+
123
+ // Missing unary functions
124
+ // digamma
125
+ // lgamma
126
+ // erfinv
127
+ // clone
128
+ // contiguous
129
+ // zero
130
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef ATOMIC_ADD_FLOAT
2
+ #define ATOMIC_ADD_FLOAT
3
+
4
+ #if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__))
5
+ #include <ATen/native/cpu/Intrinsics.h>
6
+ #else
7
+ #define _mm_pause()
8
+ #endif
9
+
10
+ #include <atomic>
11
+
12
// Atomically adds `fvalue` to the float at `dst` via a CAS loop on the 32-bit
// object representation.
// Fixes two defects in the previous version: (1) it re-read *dst with a plain
// non-atomic load inside the retry loop, which is a data race under the C++
// memory model for a function intended for concurrent use; (2) it type-punned
// through an inactive union member, which is UB in C++. This version performs
// every read through the std::atomic view (compare_exchange refreshes the
// expected value on failure) and converts bits with std::memcpy.
// NOTE(review): reinterpreting float* as std::atomic<unsigned>* is inherited
// from the original interface; it assumes unsigned and float share size and
// a compatible representation on the supported platforms.
static inline void cpu_atomic_add_float(float* dst, float fvalue)
{
  static_assert(sizeof(unsigned) == sizeof(float),
                "cpu_atomic_add_float requires 32-bit unsigned and float");

  std::atomic<unsigned>* dst_intV = (std::atomic<unsigned>*)(dst);

  // Seed the expected bits with an atomic load; on each failed CAS,
  // compare_exchange_strong stores the freshly observed bits back into
  // old_bits, so no separate re-read is needed.
  unsigned old_bits = dst_intV->load();
  while (true) {
    float old_float;
    std::memcpy(&old_float, &old_bits, sizeof(old_float));
    const float new_float = old_float + fvalue;
    unsigned new_bits;
    std::memcpy(&new_bits, &new_float, sizeof(new_bits));
    if (std::atomic_compare_exchange_strong(dst_intV, &old_bits, new_bits)) {
      return;
    }
    // Back off briefly before retrying under contention.
#ifdef __aarch64__
    __asm__ __volatile__("yield;" : : : "memory");
#else
    _mm_pause();
#endif
  }
}
36
+
37
+ #endif
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <ATen/core/IListRef.h>
6
+
7
+ namespace at { namespace native {
8
+
9
+ using cat_serial_fn = void(*)(const Tensor &, const MaterializedITensorListRef&, int64_t);
10
+ DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub);
11
+
12
+ }} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/DispatchStub.h>
3
+ #include <cstdint>
4
+
5
+ namespace at {
6
+ class TensorBase;
7
+ }
8
+
9
+ namespace at { namespace native {
10
+
11
+ using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
12
+ DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel);
13
+
14
+ }} // at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <c10/util/ArrayRef.h>
5
+
6
+ /*
7
+ Depthwise 3x3 Winograd convolution operator
8
+ */
9
+
10
+ namespace at {
11
+ class Tensor;
12
+
13
+ namespace native {
14
+
15
+ using convolution_depthwise3x3_winograd_fn =
16
+ Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t);
17
+
18
+ DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub);
19
+
20
+ } // namespace native
21
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #if defined(__clang__) && (defined(__x86_64__) || defined(__i386__))
4
+ /* Clang-compatible compiler, targeting x86/x86-64 */
5
+ #include <x86intrin.h>
6
+ #elif defined(_MSC_VER)
7
+ /* Microsoft C/C++-compatible compiler */
8
+ #include <intrin.h>
9
+ #if _MSC_VER <= 1900
10
+ #define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y])
11
+ #endif
12
+ #elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
13
+ /* GCC-compatible compiler, targeting x86/x86-64 */
14
+ #include <x86intrin.h>
15
+ #elif defined(__GNUC__) && defined(__ARM_NEON__)
16
+ /* GCC-compatible compiler, targeting ARM with NEON */
17
+ #include <arm_neon.h>
18
+ #elif defined(__GNUC__) && defined(__IWMMXT__)
19
+ /* GCC-compatible compiler, targeting ARM with WMMX */
20
+ #include <mmintrin.h>
21
+ #elif (defined(__GNUC__) || defined(__xlC__)) && \
22
+ (defined(__VEC__) || defined(__ALTIVEC__))
23
+ /* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
24
+ #include <altivec.h>
25
+ /* We need to undef those tokens defined by <altivec.h> to avoid conflicts
26
+ with the C++ types. => Can still use __bool/__vector */
27
+ #undef bool
28
+ #undef vector
29
+ #undef pixel
30
+ #elif defined(__GNUC__) && defined(__SPE__)
31
+ /* GCC-compatible compiler, targeting PowerPC with SPE */
32
+ #include <spe.h>
33
+ #endif
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // This file provides two functions to help write elementwise kernels:
4
+ //
5
+ // cpu_kernel(TensorIterator iter, <lambda>)
6
+ // cpu_kernel_vec(TensorIterator iter, <lambda>, <vec_lambda>)
7
+ //
8
+ // Both functions may generate vectorized code. The cpu_kernel implementation
9
+ // relies on the compiler's auto-vectorization. The cpu_kernel_vec
10
+ // implementation uses x86 SIMD intrinsics when available. These functions
11
+ // are only intended to be used in the ATen/native/cpu subdirectory, since files
12
+ // in other directories are not compiled with AVX/AVX2 enabled. See README.md
13
+ // for more details.
14
+ //
15
+ // For example, to write a multiplication kernel for float:
16
+ //
17
+ // cpu_kernel(iter, [](float a, float b) { return a * b; });
18
+ //
19
+ // Or you may write:
20
+ //
21
+ // cpu_kernel_vec(iter,
22
+ // [](float a, float b) { return a * b; },
23
+ // [](Vectorized<float> a, Vectorized<float> b) { return a * b; });
24
+ //
25
+ // See BinaryOpsKernel.cpp for the complete implementation
26
+ //
27
+ //
28
+
29
+ #include <stdint.h>
30
+ #include <c10/util/C++17.h>
31
+ #include <c10/util/Load.h>
32
+ #include <c10/util/irange.h>
33
+ #include <ATen/detail/FunctionTraits.h>
34
+ #include <ATen/native/cpu/IsContiguous.h>
35
+ #include <ATen/native/TensorIterator.h>
36
+ #include <ATen/native/TensorIteratorDynamicCasting.h>
37
+ #include <ATen/cpu/vec/vec.h>
38
+
39
+ #include <utility>
40
+
41
+ namespace at { namespace native { inline namespace CPU_CAPABILITY {
42
+
43
+ using namespace vec;
44
+
45
// Loads one scalar argument per input: for the INDEX-th input, reads a value
// of the op's INDEX-th parameter type from data[INDEX] + i * strides[INDEX]
// (through c10::load) and packs them all into traits::ArgsTuple.
template <typename traits, std::size_t... INDEX>
typename traits::ArgsTuple
dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
                 std::index_sequence<INDEX...>) {
  return std::make_tuple(
      c10::load<typename traits::template arg<INDEX>::type>(
          data[INDEX] + i * strides[INDEX])...);
}
53
+
54
// Convenience wrapper over dereference_impl: generates the index sequence
// 0..arity-1 from the op's function traits.
template <typename traits>
typename traits::ArgsTuple
dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return dereference_impl<traits>(data, strides, i, Indices{});
}
60
+
61
// Vectorized analogue of dereference_impl: loads one Vec per input with
// Vec::loadu, except that the input at 1-based position S is replaced by the
// pre-broadcast scalar vector `opt_scalar` (S == 0 means no scalar input).
template <typename traits, std::size_t... INDEX>
typename traits::ArgsTuple
dereference_vec_impl(char* C10_RESTRICT data[],
                     const typename traits::result_type& opt_scalar,
                     size_t S,
                     int64_t i,
                     std::index_sequence<INDEX...>) {
  using Vec = typename traits::result_type;
  using scalar_t = typename Vec::value_type;
  return std::make_tuple(
      S == INDEX + 1 ?
      opt_scalar :
      Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...);
}
75
+
76
// Convenience wrapper over dereference_vec_impl: generates the index
// sequence 0..arity-1 from the vectorized op's function traits.
template <typename traits>
typename traits::ArgsTuple
dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
}
82
+
83
// Scalar loop for ops that return a value: data[0]/strides[0] describe the
// output tensor, data[1..arity]/strides[1..arity] the inputs. Selected only
// when the op's result type is non-void.
template <typename func_t,
  typename std::enable_if<!std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
static inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;
  using result_type = typename traits::result_type;
  for (; i < n; i++) {
    result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
    *out_ptr = c10::guts::apply(std::forward<func_t>(op), dereference<traits>(
        &data[1],
        &strides[1],
        i));
  }
}
97
+
98
// Scalar loop for ops with a void result: all tensors, including any the op
// writes through, are passed straight to the lambda starting at data[0].
template <typename func_t,
  typename std::enable_if<std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
static inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;
  for (; i < n; i++) {
    c10::guts::apply(std::forward<func_t>(op), dereference<traits>(
        &data[0],
        &strides[0],
        i));
  }
}
110
+
111
+ // Basic loop operation (one output, N inputs). May be auto-vectorized
112
+ // by the compiler. Supports inputs and outputs of different types.
113
// Basic loop operation (one output, N inputs). May be auto-vectorized
// by the compiler. Supports inputs and outputs of different types.
template <typename func_t>
static inline void
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;
  constexpr int ntensors = traits::arity + 1;

  // Copying strides to temporary array helps auto vectorization in older GCC
  // versions.
  int64_t strides[ntensors];
  for (const auto arg : c10::irange(ntensors)) {
    strides[arg] = strides_[arg];
  }

  // Dispatches to the void / non-void execute_op overload via SFINAE.
  execute_op(data, strides, i, n, std::forward<func_t>(op));
}
128
+
129
+ // the recursive variadic template for iterating over the returned tuple
130
// the recursive variadic template for iterating over the returned tuple
// Writes tuple element N-1 to output tensor N-1 after recursing on the
// first N-1 elements, so outputs are stored in tuple order.
template<class T, size_t N>
struct TupleOutput {
  static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
                     const T &tuple) {
    TupleOutput<T, N - 1>::handle(data, strides, i, tuple);

    auto output = std::get<N - 1>(tuple);
    using output_type = decltype(output);
    output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
    *out_ptr = output;
  }
};
142
+
143
+ // Base case for the above recursive template
144
// Base case: writes tuple element 0 to output tensor 0, ending the recursion.
template<class T>
struct TupleOutput<T, 1> {
  static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
                     const T &tuple) {
    auto output = std::get<0>(tuple);
    using output_type = decltype(output);
    output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
    *out_ptr = output;
  }
};
154
+
155
// Entry point for scattering a tuple of results into the output tensors;
// delegates to the TupleOutput recursion sized by the tuple arity.
template<class... Args>
void handle_tuple_outputs(char* C10_RESTRICT data[],
                          const int64_t* strides,
                          int64_t i,
                          const std::tuple<Args...> &tuple) {
  TupleOutput<decltype(tuple), sizeof...(Args)>::handle(data, strides, i, tuple);
}
162
+
163
+ // Loop operation for `cpu_kernel_multiple_outputs`.
164
+ // 1. Use `c10::guts::apply` to make dynamic method invocation
165
+ // for the lambda passed in `cpu_kernel_multiple_outputs`.
166
+ // 2. Iterate over the members of the returned tuple, set the corresponding
167
+ // output tensor by the tuple member in `handle_tuple_outputs` function.
168
// Loop operation for `cpu_kernel_multiple_outputs`.
// 1. Use `c10::guts::apply` to make dynamic method invocation
//    for the lambda passed in `cpu_kernel_multiple_outputs`.
// 2. Iterate over the members of the returned tuple, set the corresponding
//    output tensor by the tuple member in `handle_tuple_outputs` function.
// Layout convention: data[0..num_outputs-1] are outputs, the rest inputs.
template <typename func_t>
static inline void
multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;

  using result_type = typename traits::result_type;
  constexpr int num_outputs = std::tuple_size<result_type>::value;
  constexpr int ntensors = traits::arity + num_outputs;

  // Copying strides to temporary array helps auto vectorization in older GCC
  // versions.
  int64_t strides[ntensors];
  for (const auto arg : c10::irange(ntensors)) {
    strides[arg] = strides_[arg];
  }

  for (; i < n; i++) {
    auto output = c10::guts::apply(op, dereference<traits>(
        &data[num_outputs],
        &strides[num_outputs],
        i));
    handle_tuple_outputs(data, strides, i, output);
  }
}
192
+
193
+ // Explicitly vectorized loop implementation. All inputs and outputs must be
194
+ // the same type and contiguous with one exception: a single input may be
195
+ // a scalar (stride 0). It's position is indicated by the argument `S`. If `S`
196
+ // is 0, then there are no scalar inputs.
197
+ template <typename func_t, typename vec_func_t>
198
+ static inline void
199
+ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
200
+ using traits = function_traits<vec_func_t>;
201
+ using scalar_t = typename function_traits<func_t>::result_type;
202
+ using Vec = Vectorized<scalar_t>;
203
+ constexpr int ntensors = traits::arity + 1;
204
+
205
+ char* C10_RESTRICT data[ntensors];
206
+ for (const auto arg : c10::irange(ntensors)) {
207
+ data[arg] = data_[arg];
208
+ }
209
+
210
+ Vec opt_scalar = Vec(S > 0 ? *(scalar_t*)data[S] : scalar_t(0));
211
+ int64_t i = 0;
212
+ for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
213
+ auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
214
+ auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
215
+ auto out1 = c10::guts::apply(std::forward<vec_func_t>(vop), std::move(args1));
216
+ auto out2 = c10::guts::apply(std::forward<vec_func_t>(vop), std::move(args2));
217
+ out1.store(data[0] + i * sizeof(scalar_t));
218
+ out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
219
+ }
220
+ if (i < n) {
221
+ int64_t strides[ntensors];
222
+ for (const auto arg : c10::irange(ntensors)) {
223
+ strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t);
224
+ }
225
+ basic_loop(data, strides, i, n, std::forward<func_t>(op));
226
+ }
227
+ }
228
+
229
+
230
+ template <typename traits, typename cb_t>
231
+ static inline void unroll_contiguous_scalar_checks(
232
+ const int64_t* /*strides*/,
233
+ std::index_sequence<>,
234
+ cb_t&& cb) {
235
+ cb(0);
236
+ }
237
+
238
+ template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
239
+ static inline void unroll_contiguous_scalar_checks(
240
+ const int64_t* strides,
241
+ std::index_sequence<INDEX0, INDEX...>,
242
+ cb_t&& cb) {
243
+ if (is_contiguous_scalar<traits, INDEX0 + 1>(strides)) {
244
+ cb(INDEX0 + 1);
245
+ } else {
246
+ unroll_contiguous_scalar_checks<traits>(strides, std::index_sequence<INDEX...>{}, std::forward<cb_t>(cb));
247
+ }
248
+ }
249
+
250
+ template <typename op_t, typename vop_t>
251
+ struct VectorizedLoop2d {
252
+ op_t op;
253
+ vop_t vop;
254
+
255
+ using traits = function_traits<op_t>;
256
+ static constexpr int ntensors = traits::arity + 1;
257
+ using data_t = std::array<char*, ntensors>;
258
+
259
+ VectorizedLoop2d(const op_t &op, vop_t vop):
260
+ op(op), vop(std::move(vop)) {}
261
+
262
+ static void advance(data_t &data, const int64_t *outer_strides) {
263
+ for (const auto arg : c10::irange(data.size())) {
264
+ data[arg] += outer_strides[arg];
265
+ }
266
+ }
267
+
268
+ void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) {
269
+ data_t data;
270
+ std::copy_n(base, ntensors, data.data());
271
+ const int64_t *outer_strides = &strides[ntensors];
272
+
273
+ if (is_contiguous<traits>(strides)) {
274
+ for (const auto i C10_UNUSED : c10::irange(size1)) {
275
+ vectorized_loop(data.data(), size0, 0, op, vop);
276
+ advance(data, outer_strides);
277
+ }
278
+ } else {
279
+ using Indices = std::make_index_sequence<traits::arity>;
280
+ unroll_contiguous_scalar_checks<traits>(strides, Indices{}, [&](size_t idx) {
281
+ if (idx) {
282
+ for (const auto i C10_UNUSED : c10::irange(size1)) {
283
+ vectorized_loop(data.data(), size0, idx, op, vop);
284
+ advance(data, outer_strides);
285
+ }
286
+ } else {
287
+ for (const auto i C10_UNUSED : c10::irange(size1)) {
288
+ basic_loop(data.data(), strides, 0, size0, op);
289
+ advance(data, outer_strides);
290
+ }
291
+ }
292
+ });
293
+ }
294
+ }
295
+ };
296
+
297
+ template <typename op_t, typename vop_t>
298
+ VectorizedLoop2d<op_t, vop_t> make_vectorized_loop2d(
299
+ const op_t &op, const vop_t &vop) {
300
+ return VectorizedLoop2d<op_t, vop_t>(op, vop);
301
+ }
302
+
303
+ template <typename func_t>
304
+ void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
305
+ using traits = function_traits<func_t>;
306
+ // this could be extended to work with void return types
307
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
308
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
309
+ // dynamic casting not currently supported on CPU
310
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
311
+
312
+ iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
313
+ // basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that
314
+ // iter.for_each is ever sending to the loop lambda
315
+ basic_loop(data, strides, 0, n, std::forward<func_t>(op));
316
+ }, grain_size);
317
+ iter.cast_outputs();
318
+ }
319
+
320
+ // This function helps write elementwise kernels that requires multiple outputs.
321
+ // It follows the similar structure of cpu_kernel.
322
+ // Instead of `basic_loop` function, a new `multiple_outputs_loop` function is
323
+ // manipulated to handle multiple return values.
324
+ // For now `needs_dynamic_casting` check is not added as the passed lambda (`func_t`)
325
+ // of `multiple_outputs_loop` returns `std::tuple` instead of `scalar_t`.
326
+ // The `gpu_kernel_multiple_outputs` is also implemented without this check,
327
+ // We could extend `needs_dynamic_casting` to support both `std::tuple` and
328
+ // `thrust::tuple` in the future.
329
+ template <typename func_t>
330
+ void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
331
+ using traits = function_traits<func_t>;
332
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
333
+
334
+ iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
335
+ multiple_outputs_loop(data, strides, 0, n, std::forward<func_t>(op));
336
+ }, grain_size);
337
+ iter.cast_outputs();
338
+ }
339
+
340
+ template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
341
+ void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) {
342
+ using traits = function_traits<func_t>;
343
+ // this could be extended to work with void return types
344
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
345
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
346
+ // dynamic casting not currently supported on CPU, but some kernels (like Fill)
347
+ // explicitly dynamic_cast, so we give the opt-out of checking.
348
+ if constexpr (check_dynamic_cast) {
349
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
350
+ }
351
+
352
+ iter.for_each(make_vectorized_loop2d(op, vop), grain_size);
353
+ iter.cast_outputs();
354
+ }
355
+
356
+ template <typename func_t>
357
+ void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) {
358
+ using traits = function_traits<func_t>;
359
+ constexpr bool result_void = std::is_void<typename traits::result_type>::value;
360
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity &&
361
+ ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1)));
362
+ // dynamic casting not currently supported on CPU
363
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
364
+
365
+ iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
366
+ basic_loop(data, strides, 0, n, std::forward<func_t>(op));
367
+ }, range);
368
+ iter.cast_outputs();
369
+ }
370
+
371
+ template <typename func_t>
372
+ void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) {
373
+ cpu_serial_kernel(iter, op, {0, iter.numel()});
374
+ }
375
+
376
+ template <typename func_t, typename vec_func_t>
377
+ void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) {
378
+ using traits = function_traits<func_t>;
379
+ // this could be extended to work with void return types
380
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
381
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
382
+ // dynamic casting not currently supported on CPU
383
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
384
+
385
+ iter.serial_for_each(make_vectorized_loop2d(op, vop), range);
386
+ iter.cast_outputs();
387
+ }
388
+
389
+ template <typename func_t, typename vec_func_t>
390
+ void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
391
+ cpu_serial_kernel_vec(iter, op, vop, {0, iter.numel()});
392
+ }
393
+
394
+ }}} // namespace at::native::<anonymous>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/DispatchStub.h>
3
+
4
+ namespace at {
5
+ class Tensor;
6
+
7
+ namespace native {
8
+
9
+ using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
10
+
11
+ DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel);
12
+ DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel);
13
+
14
+ }} // at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Parallel.h>
4
+ #include <ATen/NumericUtils.h>
5
+ #include <ATen/cpu/vec/vec.h>
6
+ #include <ATen/cpu/vec/functional.h>
7
+ #include <ATen/native/ReductionType.h>
8
+ #include <c10/util/irange.h>
9
+ #include <ATen/OpMathType.h>
10
+ #include <ATen/native/cpu/utils.h>
11
+ #include <ATen/OpMathType.h>
12
+
13
+ namespace at::native {
14
+ inline namespace CPU_CAPABILITY {
15
+
16
+ using namespace vec;
17
+
18
+ #define AT_DISPATCH_REDUCTION_TYPES(op, ...) \
19
+ [&] { \
20
+ switch (op) { \
21
+ case ReductionType::SUM: { \
22
+ static constexpr auto reduce = ReductionType::SUM; \
23
+ return __VA_ARGS__(); \
24
+ } \
25
+ case ReductionType::MEAN: { \
26
+ static constexpr auto reduce = ReductionType::MEAN; \
27
+ return __VA_ARGS__(); \
28
+ } \
29
+ case ReductionType::MIN: { \
30
+ static constexpr auto reduce = ReductionType::MIN; \
31
+ return __VA_ARGS__(); \
32
+ } \
33
+ case ReductionType::MAX: { \
34
+ static constexpr auto reduce = ReductionType::MAX; \
35
+ return __VA_ARGS__(); \
36
+ } \
37
+ case ReductionType::PROD: { \
38
+ static constexpr auto reduce = ReductionType::PROD; \
39
+ return __VA_ARGS__(); \
40
+ } \
41
+ } \
42
+ }()
43
+
44
+ template <typename scalar_t, ReductionType reduce>
45
+ inline vec_scalar_t<scalar_t> init_value() {
46
+ using acc_t = vec_scalar_t<scalar_t>;
47
+ acc_t val;
48
+ if (reduce == ReductionType::SUM ||
49
+ reduce == ReductionType::MEAN) {
50
+ val = static_cast<acc_t>(0);
51
+ } else if (reduce == ReductionType::PROD) {
52
+ val = static_cast<acc_t>(1);
53
+ } else if (reduce == ReductionType::MAX) {
54
+ val = -std::numeric_limits<acc_t>::infinity();
55
+ } else {
56
+ TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
57
+ val = std::numeric_limits<acc_t>::infinity();
58
+ }
59
+ return val;
60
+ }
61
+
62
+ template <typename scalar_t, ReductionType reduce>
63
+ inline vec_scalar_t<scalar_t> init_value(const c10::optional<Scalar>& initial) {
64
+ using acc_t = vec_scalar_t<scalar_t>;
65
+ if (initial.has_value()) {
66
+ return initial.value().to<acc_t>();
67
+ } else {
68
+ return init_value<scalar_t, reduce>();
69
+ }
70
+ }
71
+
72
+ template <typename scalar_t>
73
+ inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
74
+ using Vec = Vectorized<vec_scalar_t<scalar_t>>;
75
+ map<scalar_t>(
76
+ [val](Vec x) { return Vec(val); },
77
+ out,
78
+ out,
79
+ size);
80
+ }
81
+
82
+ template <typename scalar_t, ReductionType reduce>
83
+ inline void init(scalar_t* out, int64_t size, const c10::optional<Scalar>& initial) {
84
+ using acc_t = vec_scalar_t<scalar_t>;
85
+ acc_t val = init_value<scalar_t, reduce>(initial);
86
+ init(out, size, val);
87
+ }
88
+
89
+ // overload with `include_self`, used by scatter_reduce
90
+ template <typename scalar_t, ReductionType reduce>
91
+ inline void init(scalar_t* out, int64_t size, bool include_self = false) {
92
+ using acc_t = vec_scalar_t<scalar_t>;
93
+ if (!include_self) {
94
+ acc_t val = init_value<scalar_t, reduce>();
95
+ init(out, size, val);
96
+ }
97
+ }
98
+
99
+ template <typename scalar_t, ReductionType reduce>
100
+ inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
101
+ if (!include_self) {
102
+ init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
103
+ } else {
104
+ vec::convert(self_ptr, buffer_ptr, size);
105
+ }
106
+ }
107
+
108
+ template <typename scalar_t>
109
+ inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
110
+ _max(const scalar_t& x, const scalar_t& y) {
111
+ return at::_isnan(y) ? y : std::max(x, y);
112
+ }
113
+
114
+ template <typename scalar_t>
115
+ inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
116
+ // vec::maximum propagates NaN
117
+ return vec::maximum(x, y);
118
+ }
119
+
120
+ template <typename vec_t>
121
+ inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
122
+ _max(const vec_t& x, const vec_t& y) {
123
+ // vec::maximum propagates NaN
124
+ return maximum(x, y);
125
+ }
126
+
127
+ template <typename scalar_t>
128
+ inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
129
+ _min(const scalar_t& x, const scalar_t& y) {
130
+ return at::_isnan(y) ? y : std::min(x, y);
131
+ }
132
+
133
+ template <typename scalar_t>
134
+ inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
135
+ // vec::minimum propagates NaN
136
+ return vec::minimum(x, y);
137
+ }
138
+
139
+ template <typename vec_t>
140
+ inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
141
+ _min(const vec_t& x, const vec_t& y) {
142
+ // vec::minimum propagates NaN
143
+ return minimum(x, y);
144
+ }
145
+
146
+ template <typename scalar_t, typename accumut, typename Op,
147
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
148
+ inline void map_acc(
149
+ const Op& vec_fun,
150
+ accumut* output_data,
151
+ const accumut* input_data,
152
+ const scalar_t* input_data2,
153
+ int64_t size) {
154
+ using Vec = vec::Vectorized<scalar_t>;
155
+ using aVec = vec::Vectorized<accumut>;
156
+ int64_t d = 0;
157
+ constexpr int64_t kVecSize = Vec::size();
158
+ constexpr int64_t kaVecSize = aVec::size();
159
+ for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
160
+ Vec data2_vec = Vec::loadu(input_data2 + d);
161
+ auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
162
+ aVec input_vec0 = aVec::loadu(input_data + d);
163
+ aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
164
+ vec_fun(input_vec0, data2_avec0).store(output_data + d);
165
+ vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
166
+ }
167
+ if (size - d > 0) {
168
+ int64_t tail_size = size - d;
169
+ Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
170
+ auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
171
+ if (tail_size > kaVecSize) {
172
+ aVec input_vec0 = aVec::loadu(input_data + d);
173
+ aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
174
+ vec_fun(input_vec0, data2_avec0).store(output_data + d);
175
+ vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
176
+ } else {
177
+ aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
178
+ vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
179
+ }
180
+ }
181
+ }
182
+
183
+ // for Max and Min, propagate NaN:
184
+ template <typename T, ReductionType reduce>
185
+ inline T update(const T& x, const T& y) {
186
+ if (reduce == ReductionType::SUM ||
187
+ reduce == ReductionType::MEAN) {
188
+ return x + y;
189
+ } else if (reduce == ReductionType::PROD) {
190
+ return x * y;
191
+ } else if (reduce == ReductionType::MAX) {
192
+ return _max(x, y);
193
+ } else {
194
+ TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
195
+ return _min(x, y);
196
+ }
197
+ }
198
+
199
+ template <typename scalar_t, ReductionType reduce>
200
+ inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
201
+ using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
202
+ map2<scalar_t>(
203
+ [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
204
+ out,
205
+ out,
206
+ data,
207
+ K);
208
+ }
209
+
210
+ template <typename scalar_t, ReductionType reduce,
211
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
212
+ inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
213
+ using opmath_t = at::opmath_type<scalar_t>;
214
+ using Vec = vec::Vectorized<opmath_t>;
215
+ map_acc<scalar_t, opmath_t>(
216
+ [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
217
+ out,
218
+ out,
219
+ data,
220
+ K);
221
+ }
222
+
223
+ template <typename scalar_t, ReductionType reduce>
224
+ inline void write(scalar_t* out, int64_t count, int64_t K) {
225
+ using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
226
+ if (reduce == ReductionType::MEAN) {
227
+ if (count > 0) {
228
+ vec::map<scalar_t>(
229
+ [count](Vec x) { return x / Vec(count); },
230
+ out,
231
+ out,
232
+ K);
233
+ }
234
+ }
235
+ }
236
+
237
+ } // namespace CPU_CAPABILITY
238
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <ATen/native/ReductionType.h>
6
+
7
+ namespace at::native {
8
+
9
+ using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
10
+ using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
11
+ using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
12
+ using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
13
+ using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
14
+
15
+ DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub);
16
+ DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub);
17
+ DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub);
18
+ DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub);
19
+ DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub);
20
+ DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub);
21
+
22
+ } // at::native