koichi12 committed on
Commit
1105a93
·
verified ·
1 Parent(s): 4aab8df

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py +537 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py +213 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py +212 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py +635 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py +256 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py +182 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_dimV_ops.h +28 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_exp.h +44 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_slogdet_meta_dispatch.h +25 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_tensor_from_mask_compositeexplicitautograd_dispatch.h +24 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_print_native.h +21 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sample_dirichlet.h +39 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_csr_sum.h +39 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_softmax_backward_data_ops.h +39 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_standard_gamma_grad_compositeexplicitautograd_dispatch.h +24 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_warn_in_autograd_native.h +22 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_nearest_exact2d_cuda_dispatch.h +28 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h +28 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h +26 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/all_cuda_dispatch.h +31 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atanh_native.h +29 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_3d_ops.h +39 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy.h +39 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy_with_logits_native.h +22 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_xor_ops.h +105 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/clip_ops.h +83 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cosine_similarity_compositeimplicitautograd_dispatch.h +23 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cumprod_backward_native.h +21 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/digamma.h +39 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/divide_native.h +30 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/embedding_dense_backward_compositeexplicitautograd_dispatch.h +26 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/expm1_compositeexplicitautogradnonfunctional_dispatch.h +24 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h +30 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fft_ihfftn_compositeimplicitautograd_dispatch.h +28 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/gcd_cuda_dispatch.h +26 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/grid_sampler_2d.h +39 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hardswish_backward.h +39 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hstack.h +39 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/huber_loss_backward.h +39 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift_fresh.h +30 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_eigh_ops.h +39 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_lu_cpu_dispatch.h +25 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_pinv_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linear.h +39 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/log_sigmoid_forward_cpu_dispatch.h +25 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/logaddexp2_ops.h +39 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Tuple
5
+
6
+ import torch
7
+ from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_functional
8
+ from torch._inductor import inductor_prims
9
+ from torch._inductor.fx_utils import get_node_storage, is_node_realized
10
+ from torch._inductor.lowering import (
11
+ inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
12
+ )
13
+ from torch._inductor.virtualized import V
14
+ from torch.fx.immutable_collections import immutable_dict
15
+ from torch.fx.passes.reinplace import _is_view_op
16
+ from torch.utils import _pytree as pytree
17
+
18
+ aten = torch.ops.aten
19
+
20
+
21
@dataclass(frozen=True)
class InplaceableOp:
    """Describes how to rewrite a functional op as its in-place variant.

    Attributes:
        inplace_op: the mutating op substituted for the functional one.
        mutated_arg: index into ``node.args`` of the argument the in-place
            variant mutates.
        extra_check: predicate on the fx node; reinplacing only happens when
            it returns True (defaults to always-True).
    """

    inplace_op: Callable[..., Any]
    mutated_arg: int
    extra_check: Callable[[torch.fx.Node], bool] = lambda node: True
26
+
27
+
28
# Maps each view-scatter op to the view op it writes through (e.g.
# slice_scatter updates the region that slice would read).
_SCATTER_OP_TO_VIEW = {
    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
# Inverse mapping: view op -> its scatter counterpart.
_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()}
35
+
36
+
37
def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
    """Insert a call_function node into ``graph`` and populate its fake value.

    ``args``/``kwargs`` may mix fx Nodes and plain values; Nodes are replaced
    by their ``meta["val"]`` fake tensors so ``fn`` can be run under the fake
    mode to compute the output metadata without real compute.
    """
    fake_args, fake_kwargs = pytree.tree_map(
        lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
        (args, kwargs),
    )
    # Execute on fake tensors only, to derive the result's metadata.
    with V.fake_mode:
        fake_result = fn(*fake_args, **fake_kwargs)

    node = graph.call_function(fn, args, kwargs)
    node.meta["val"] = fake_result
    return node
48
+
49
+
50
@dataclass
class ViewOp:
    """A single view call recorded as (op, args, kwargs), minus its input
    tensor, so a chain of views can be replayed on another base tensor."""

    target: torch._ops.OpOverload
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]
55
+
56
+
57
def _inplace_generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    """Mutating generalized scatter: apply the chain of views to ``inp`` and
    copy ``src`` into the resulting view. Returns ``inp`` (now mutated)."""
    tmp = inp
    for view in view_ops:
        # View args/kwargs may still contain fx Nodes (e.g. symbolic sizes);
        # map them to their fake values before calling the real op.
        fake_args, fake_kwargs = pytree.tree_map(
            lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
            (view.args, view.kwargs),
        )
        tmp = view.target(tmp, *fake_args, **fake_kwargs)
    tmp.copy_(src)
    return inp
69
+
70
+
71
def _generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    """Functional generalized scatter: write ``src`` through the chain of
    views on a clone of ``inp``, leaving ``inp`` itself untouched."""
    result = inp.clone()
    _inplace_generalized_scatter(result, src, view_ops)
    return result
76
+
77
+
78
def _decompose_scatter_functional_helper(
    graph: torch.fx.Graph,
    inp: torch.Tensor,
    src: torch.Tensor,
    view_ops: List[ViewOp],
) -> torch.fx.Node:
    # Recursively rebuild a chain of view ops as nested *_scatter calls:
    # the innermost view becomes the innermost scatter, and each level
    # scatters the recursively-updated src back into its view of inp.
    view_op, view_ops_tail = view_ops[0], view_ops[1:]

    if view_ops_tail:
        # Take the view of inp that the remaining ops apply to, update it
        # recursively, and use the updated view as this level's src.
        view = graph_call_function(
            graph, view_op.target, inp, *view_op.args, **view_op.kwargs
        )
        src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:])  # type: ignore[assignment]

    return graph_call_function(
        graph,
        _VIEW_OP_TO_SCATTER[view_op.target],
        inp,
        src,
        *view_op.args,
        **view_op.kwargs,
    )
100
+
101
+
102
def _decompose_scatter_functional(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter to a sequence of view_scatter operations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

    view = aten.slice(inp, 0, 0, 10)
    view_updated = aten.slice_scatter(view, src, 1, 10, -10)
    inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10)
    """
    assert node.target is _generalized_scatter
    # Unpacking both names the arguments and asserts the expected
    # (inp, src, view_ops) arity; pass the unpacked names instead of
    # re-splatting node.args so the locals are actually used.
    inp, src, view_ops = node.args
    return _decompose_scatter_functional_helper(graph, inp, src, view_ops)  # type: ignore[arg-type]
118
+
119
+
120
def _decompose_scatter_mutating(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter using mutations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

    inp_updated = aten.clone(inp)
    slice1 = aten.slice(inp_updated, 0, 0, 10)
    slice2 = aten.slice(slice1, 1, 10, -10)
    slice2.copy_(src)

    """
    assert node.target in (_generalized_scatter, _inplace_generalized_scatter)
    inp, src, view_ops = node.args
    assert not node.kwargs

    # The functional variant must not mutate its input, so write into a
    # clone; the in-place variant mutates inp directly.
    if node.target is _generalized_scatter:
        inp = graph_call_function(graph, aten.clone, inp)

    # Chain the views, then copy src into the final view.
    tmp = inp
    for view in view_ops:  # type: ignore[union-attr]
        tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs)  # type: ignore[union-attr]

    graph_call_function(graph, aten.copy_.default, tmp, src)
    return inp  # type: ignore[return-value]
148
+
149
+
150
# View ops whose view_scatter op is lowered into mutations anyway,
# so is never a pessimisation to decompose.
_ALWAYS_MUTATING_SCATTER_OPS = {
    aten.as_strided.default,
    aten.diagonal.default,
}
156
+
157
+
158
def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
    """Whether decomposing this generalized scatter will use mutation no
    matter what (because some view in its chain has a scatter counterpart
    that is lowered to mutations anyway)."""
    _, _, view_ops = node.args
    for view in view_ops:  # type: ignore[union-attr]
        if view.target in _ALWAYS_MUTATING_SCATTER_OPS:
            return True
    return False
161
+
162
+
163
def should_reinplace_scatter(node: torch.fx.Node) -> bool:
    """Choose between mutating and functional scatter decompositions

    Reinplacing view scatter ops can be pessimising as it blocks fusion with the
    input or output tensor computations. However, it is still profitable if the
    input and output would have been realized anyway.

    """
    # NOTE(review): src/view_ops are unpacked only to assert the expected
    # arity of node.args; only inp is used below.
    inp, src, view_ops = node.args

    # Mutating scatter ops unconditionally realize input and output
    if scatter_always_uses_mutation(node):
        return True

    if is_node_realized(inp) and is_node_realized(node):  # type: ignore[arg-type]
        return True

    # If the output is copied back into the input, this forces both to be
    # realized as the output is a user of the input
    if inp.op == "placeholder" and any(  # type: ignore[union-attr]
        user.target is aten.copy_.default and user.args[0] is inp for user in node.users
    ):
        return True

    # Otherwise, assume fusions will make functional variants profitable
    return False
189
+
190
+
191
def decompose_generalized_scatter(graph: torch.fx.Graph) -> None:
    """Replace _generalized_scatter with normal aten ops"""
    for node in graph.nodes:
        if node.target not in (_generalized_scatter, _inplace_generalized_scatter):
            continue

        # The in-place variant (chosen earlier by the reinplacing pass) must
        # stay mutating; the functional variant is still forced to mutate
        # when one of its views only has a mutating scatter lowering.
        use_mutation = (
            node.target is _inplace_generalized_scatter
            or scatter_always_uses_mutation(node)
        )

        with graph.inserting_before(node):
            if use_mutation:
                new_node = _decompose_scatter_mutating(graph, node)
            else:
                new_node = _decompose_scatter_functional(graph, node)

        node.replace_all_uses_with(new_node)
        graph.erase_node(node)
210
+
211
+
212
def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
    """
    This canonicalizes view scatter ops into a generalized form, defined as:
      def scatter(inp, src, views):
        tmp = inp.clone()
        for view in views:
          tmp = view(tmp)
        tmp.copy_(src)

    We also fuse consecutive view scatter ops of the form
        a = scatter(view2(self), src, [view1])
        b = scatter(self, a, [view2])
    which can be rewritten as
        b = scatter(self, src, [view2, view1])
        a = view2(b)

    This is both more efficient as we only do a single scatter, and also
    easier to reinplace since there is only one use of `self`
    """

    # For every view node seen so far: the non-view tensor it is a view of,
    # and the chain of ViewOps that produces it from that base.
    node_to_view_base: Dict[torch.fx.Node, torch.fx.Node] = {}
    node_to_view_op: Dict[torch.fx.Node, List[ViewOp]] = defaultdict(list)

    def handle_views(node: torch.fx.Node):
        # Record this view's base and extend its parent's view chain by one.
        inp = node.args[0]
        node_to_view_base[node] = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
        node_to_view_op[node] = [
            *node_to_view_op[inp],  # type: ignore[index]
            ViewOp(
                node.target,  # type: ignore[arg-type]
                args=node.args[1:],
                kwargs=node.kwargs,
            ),
        ]

    def handle_view_scatter(node: torch.fx.Node):
        # Rewrite a single *_scatter op into _generalized_scatter, fusing
        # with a producing _generalized_scatter when their view chains line up.
        assert len(node.args) >= 2
        inp, src = node.args[:2]

        scatter_view_op = ViewOp(
            _SCATTER_OP_TO_VIEW[node.target],
            args=node.args[2:],
            kwargs=node.kwargs,
        )

        def can_fuse():
            # Fusable iff src is itself a generalized scatter whose input is
            # exactly this scatter's view of the same base tensor.
            if src.target is not _generalized_scatter:  # type: ignore[union-attr]
                return False
            src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]

            inp_base = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
            src_base = node_to_view_base.get(src_inp, src_inp)  # type: ignore[arg-type]
            return inp_base is src_base and node_to_view_op[src_inp] == [  # type: ignore[index]
                *node_to_view_op[inp],  # type: ignore[index]
                scatter_view_op,
            ]

        if not can_fuse():
            # Plain canonicalization: wrap this scatter as a one-view
            # _generalized_scatter.
            with graph.inserting_before(node):
                new_node = graph_call_function(
                    graph,
                    _generalized_scatter,
                    inp,
                    src,
                    [scatter_view_op],
                )
                node.replace_all_uses_with(new_node)
                graph.erase_node(node)
            return

        # Fused case: fold src's scatter into this one, prepending this
        # scatter's view to src's view chain.
        src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]
        with graph.inserting_before(src):
            new_node = graph_call_function(
                graph,
                _generalized_scatter,
                inp,
                src_src,
                [scatter_view_op, *src_scatter_view_op],  # type: ignore[misc]
            )
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)

            if src.users:  # type: ignore[union-attr]
                # Other users of the old intermediate scatter get the
                # equivalent view of the fused result instead.
                new_src = graph_call_function(
                    graph,
                    _SCATTER_OP_TO_VIEW[node.target],
                    new_node,
                    *node.args[2:],
                    **node.kwargs,
                )

                handle_views(new_src)
                src.replace_all_uses_with(new_src)  # type: ignore[union-attr]

            graph.erase_node(src)

    for node in graph.nodes:
        if _is_view_op(node.target):
            handle_views(node)
        elif node.target in _SCATTER_OP_TO_VIEW:
            handle_view_scatter(node)
313
+
314
+
315
# Functional ops that have a known in-place counterpart, keyed by the
# functional overload. mutated_arg is 0 for all of them (the first argument
# is the tensor the in-place variant mutates).
inplaceable_ops = {
    aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
    aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
    _generalized_scatter: InplaceableOp(
        _inplace_generalized_scatter,
        0,
        # Only reinplace scatters when profitable (see should_reinplace_scatter).
        extra_check=should_reinplace_scatter,
    ),
}

try:
    c10d_functional = torch.ops._c10d_functional
    inplaceable_collective_ops = {
        c10d_functional.all_reduce.default: InplaceableOp(
            c10d_functional.all_reduce_.default, 0
        ),
        c10d_functional.all_reduce_coalesced.default: InplaceableOp(
            c10d_functional.all_reduce_coalesced_.default, 0
        ),
    }
    inplaceable_ops.update(inplaceable_collective_ops)
except AttributeError:
    # _c10d_functional ops are only available when torch
    # is built with USE_DISTRIBUTED=1.
    pass
340
+
341
# Derive the foreach-op table from the lowering's outplace -> inplace map;
# the mutated argument is the leading tensor list in every case.
inplaceable_foreach_ops: Dict[torch._ops.OpOverload, InplaceableOp] = {}
for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items():
    inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0)


# User-defined Triton kernels wrapped functionally; handled by a dedicated
# branch that trims their tensors_to_clone list instead of swapping targets.
inplaceable_triton_ops = {triton_kernel_wrapper_functional}


# Operators that don't depend on the tensor data
META_ONLY_OPS = {
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.sym_numel.default,
    aten.sym_storage_offset.default,
}
356
+
357
+
358
def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
    """
    Reinplaces in-placeable operations.
    If there are no uses of a view of the mutated arg after the current node,
    it is possible to inplace the op.
    This above algorithm could be justified by observing side effects. While
    we traverse the graph in forwards direction, only latter nodes could view
    side effects of the current node. If the current node is not used later as
    well as no view of this node is used later in the graph, then it is safe to
    inplace as there would be no way to observe the side effects.
    This condition is slightly different for graph inputs where they can only
    be inplaced if the above condition is true and there's a copy_ in the
    epilogue that signals that the caller wants to observe the mutation.
    """

    # (dst placeholder, src node) -> the epilogue copy_ node writing src
    # back into dst.
    copy_args_to_copy_nodes = {}
    mutated_inputs = set()
    # storage -> all nodes sharing that storage (i.e. views of each other).
    storage_to_nodes = defaultdict(list)
    # node -> its position in the graph, for "is this use after that node"
    # queries below.
    node_order: Dict[Any, int] = {}
    for i, node in enumerate(reversed(graph.nodes)):
        node_order[node] = len(graph.nodes) - i - 1
        storage_to_nodes[get_node_storage(node)].append(node)
        if node.target == aten.copy_.default and node.args[0].op == "placeholder":
            dst = node.args[0]
            src = node.args[1]
            # If the target is a getitem and it indexes a possible clone,
            # then skip over it
            if src.target == operator.getitem and (
                (
                    src.args[0].target == triton_kernel_wrapper_functional
                    and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0]
                )
                or (src.args[0].target in inplaceable_foreach_ops)
                or (src.args[0].target == torch.ops.higher_order.auto_functionalized)
            ):
                src = src.args[0]

            copy_args_to_copy_nodes[(dst, src)] = node

            mutated_inputs.add(node.args[0])

    def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node):
        # True if any view sharing storage with node's mutated arg is used
        # after node (which would observe the in-place mutation).
        node_loc = node_order[node]
        copy_node_loc = node_order[copy_node] if copy_node is not None else None

        def is_meta_only_user(node):
            # Users that only read metadata (sizes/strides) never observe
            # the data mutation; views of such users count too.
            if _is_view_op(node.target):
                return all(is_meta_only_user(u) for u in node.users)
            return node.target in META_ONLY_OPS

        for view in shared_view_nodes:
            for user in view.users:
                user_loc = node_order[user]
                # Skip all users before node
                if user_loc <= node_loc:
                    continue
                # Ignore uses after the copy_ epilogue node, where the input
                # has already been mutated anyway
                if copy_node_loc is not None and copy_node_loc <= user_loc:
                    continue
                # Reinplacing does not change shape metadata
                if is_meta_only_user(user):
                    continue
                return True
        return False

    def can_inplace(node, mutated_arg):
        # Decide whether node may safely mutate mutated_arg in place.
        if isinstance(mutated_arg, (list, tuple)):
            return all(can_inplace(node, arg) for arg in mutated_arg)

        if get_node_storage(mutated_arg) is None:
            return False
        shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]
        if mutated_arg.op == "placeholder":
            # Graph inputs may only be mutated when the caller asked for it
            # via an epilogue copy_ back into the input.
            if not (
                copy_node := copy_args_to_copy_nodes.get((mutated_arg, node), False)
            ):
                return False

            if any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=copy_node
            ):
                return False

            return True
        elif any(view.op == "placeholder" for view in shared_view_nodes):
            # If mutated arg is view of any of the inputs of the graph,
            # do not allow for inplacing.
            # This would require more sophisticated algorithm to handle
            return False
        else:
            return not any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=None
            )

    # Deferred node replacements, applied after the main loop so iteration
    # order is not disturbed.
    replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {}

    def reinplace_and_refine_tensors_to_clone(old_tensors_to_clone, kwargs):
        # Keep only the kwargs that still need cloning; for the rest, drop
        # the epilogue copy_ and redirect getitem users to the mutated arg.
        tensors_to_clone: List[str] = []
        for arg in old_tensors_to_clone:
            assert arg in kwargs
            mutated_arg = kwargs[arg]
            if can_inplace(node, mutated_arg):
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                for user in node.users:
                    if user.target == operator.getitem and user.args[1] == arg:
                        replace_dict[user] = mutated_arg
            else:
                tensors_to_clone.append(arg)
        return tensors_to_clone

    for node in graph.nodes:
        if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
            mutated_arg = node.args[inplaceable_op.mutated_arg]
            if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node):
                # TODO(yifu): this doesn't properly remove copy epilogues for
                # ops that mutate multiple inputs. Need to revise the copy
                # node tracking logic to support the case.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                node.target = inplaceable_op.inplace_op
        elif node.target == torch.ops.higher_order.auto_functionalized:
            _mutable_op = node.args[0]
            from torch._higher_order_ops.auto_functionalize import get_mutable_arg_names

            tensors_to_clone = get_mutable_arg_names(_mutable_op)
            # Don't try to reinplace Optional[Tensor] args that are None.
            tensors_to_clone = [
                t for t in tensors_to_clone if node.kwargs[t] is not None
            ]
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                tensors_to_clone, node.kwargs
            )

            # Stash the metadata. There is a pass later on where we decompose
            # auto_functionalized into clones + a mutable op; this metadata
            # tells the decomp to only clone the following inputs
            node.meta["only_clone_these_tensors"] = tensors_to_clone
        elif node.target in inplaceable_triton_ops:
            # inplaceable_triton_ops take an additional argument called
            # tensors_to_clone which contain a list of tensors to clone
            # This pass iterates over them and sees which ones are safe
            # to eliminate (i.e. no longer need the clones)
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                node.kwargs["tensors_to_clone"], node.kwargs["kwargs"]
            )

            kwargs = dict(node.kwargs)
            kwargs["tensors_to_clone"] = tensors_to_clone
            node.kwargs = immutable_dict(kwargs)
        elif (
            inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
        ) is not None:
            mutated_args = node.args[inplaceable_op.mutated_arg]

            # Foreach reinplacing requires every mutated tensor to have an
            # epilogue copy_ back into it.
            if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args):
                continue

            if can_inplace(node, mutated_args):
                for arg in mutated_args:
                    copy_node = copy_args_to_copy_nodes[(arg, node)]
                    replace_dict[copy_node] = copy_node.args[0]

                node.target = inplaceable_op.inplace_op
    # Apply the deferred replacements, following chains so that a node
    # replaced by another replaced node resolves to the final target.
    for node, replacement in replace_dict.items():
        while replacement in replace_dict:
            replacement = replace_dict[replacement]
        replace_dict[node] = replacement

        node.replace_all_uses_with(replacement)
        graph.erase_node(node)
532
+
533
+
534
def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None:
    """Top-level reinplacing pass: canonicalize view-scatter ops into the
    generalized form, reinplace whatever is safe, then lower any remaining
    generalized scatters back to plain aten ops. Order matters."""
    pipeline = (
        canonicalize_view_scatter_ops,
        reinplace_inplaceable_ops_core,
        decompose_generalized_scatter,
    )
    for run_pass in pipeline:
        run_pass(graph)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (246 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc ADDED
Binary file (21.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
35
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
36
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
37
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
38
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
39
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
40
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
41
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
42
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
43
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
44
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
45
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
46
+ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
47
+ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
48
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
49
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
50
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
51
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
52
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
53
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
54
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
55
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
56
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
57
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
58
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
59
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
60
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
61
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
62
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
63
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
64
+ view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
65
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
66
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
67
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
68
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
69
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
70
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
71
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
72
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
73
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
74
+ view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2)
75
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
76
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
77
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
78
+ div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
79
+ permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
80
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
81
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
82
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
83
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
84
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
85
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
86
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
87
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
88
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
89
+ _sfdp_pattern_10_training = MultiOutputPattern([view_default_5,
90
+ permute_default_6,
91
+ permute_default_9,
92
+ permute_default_11
93
+ ])
94
+
95
+
96
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
97
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
98
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
99
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
100
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
101
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
102
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
103
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
104
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
105
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
106
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
107
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
108
+ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
109
+ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
110
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
111
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
112
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
113
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
114
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
115
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
116
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
117
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
118
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
119
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
120
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
121
+ _sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
122
+
123
+
124
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
125
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
126
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
127
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
128
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
129
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
130
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
131
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
132
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
133
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
134
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
135
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
136
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
137
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
138
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
139
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
140
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
141
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
142
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
143
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
144
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
145
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
146
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
147
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
148
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
149
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
150
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
151
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
152
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
153
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
154
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
155
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
156
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
157
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
158
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
159
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
160
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
161
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
162
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
163
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
164
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
165
+ view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2)
166
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
167
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
168
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
169
+ div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
170
+ permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
171
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
172
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
173
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
174
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
175
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
176
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
177
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
178
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
179
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
180
+ _sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5,
181
+ permute_default_6,
182
+ permute_default_9,
183
+ permute_default_11
184
+ ])
185
+
186
+
187
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
188
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
189
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
190
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
191
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
192
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
193
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
194
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
195
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
196
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
197
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
198
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
199
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
200
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
201
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
202
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
203
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
204
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
205
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
206
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
207
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
208
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
209
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
210
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
211
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
212
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
213
+ _sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
35
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
36
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
37
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
38
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
39
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
40
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
41
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
42
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
43
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
44
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
45
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
46
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
47
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
48
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
49
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
50
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
51
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
52
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
53
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
54
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
55
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
56
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
57
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
58
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
59
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
60
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
61
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
62
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
63
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
64
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
65
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
66
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
67
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
68
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
69
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
70
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
71
+ div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
72
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
73
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
74
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
75
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
76
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
77
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
78
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
79
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
80
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
81
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
82
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
83
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
84
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
85
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
86
+ _sfdp_pattern_11_training = MultiOutputPattern([view_default_5,
87
+ permute_default_6,
88
+ permute_default_9,
89
+ permute_default_11,
90
+ None
91
+ ])
92
+
93
+
94
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
95
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
96
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
97
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
98
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
99
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
100
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
101
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
102
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
103
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
104
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
105
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
106
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
107
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
108
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
109
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
110
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
111
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
112
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
113
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
114
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
115
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
116
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
117
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
118
+ _sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
119
+
120
+
121
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
122
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
123
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
124
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
125
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
126
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
127
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
128
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
129
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
130
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
131
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
132
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
133
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
134
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
135
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
136
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
137
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
138
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
139
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
140
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
141
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
142
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
143
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
144
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
145
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
146
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
147
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
148
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
149
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
150
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
151
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
152
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
153
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
154
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
155
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
156
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
157
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
158
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
159
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
160
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
161
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
162
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
163
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
164
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
165
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
166
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
167
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
168
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
169
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
170
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
171
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
172
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
173
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
174
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
175
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
176
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
177
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
178
+ _sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5,
179
+ permute_default_6,
180
+ permute_default_9,
181
+ permute_default_11,
182
+ None
183
+ ])
184
+
185
+
186
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
187
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
188
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
189
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
190
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
191
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
192
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
193
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
194
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
195
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
196
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
197
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
198
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
199
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
200
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
201
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
202
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
203
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
204
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
205
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
206
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
207
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
208
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
209
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
210
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
211
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
212
+ _sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
# Auto-generated pattern (regenerate via gen_attention_patterns.py; do not hand-edit).
# _sfdp_pattern_16_training: joint forward+backward graph of scaled-dot-product
# attention with an additive mask and dropout, fp32 variant.
# Dropout keep-mask: keep elements where rand() > dropout_p.
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
# Query: permute -> expand -> contiguous clone -> view into bmm layout.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
# Key: permuted twice (transposed for Q @ K^T), then same expand/clone/view chain.
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
# Attention scores: (Q @ K^T) / inv_scale + attn_mask.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
# Numerically stable softmax: exp(x - amax(x)) / sum(exp(x - amax(x))).
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
# Apply dropout mask to the probabilities (mul by mask, then by scale factor).
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
# Value: permute -> expand -> contiguous clone -> view, then probs @ V.
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward graph: tangents_1 is the incoming gradient of the attention output.
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
# Backprop through dropout: grad * bool-mask (cast) * scale.
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
# alias chain is how autograd's saved softmax output appears in the traced graph.
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
# Softmax backward: p * (g - sum(g * p)), then undo the scaling division.
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
# Gradients w.r.t. query / key / value via transposed bmms of the saved operands.
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
# Outputs: attention output plus the three permuted gradients; trailing Nones pad
# outputs the replacement does not produce (presumably grad_q/k/v ordering — TODO confirm).
_sfdp_pattern_16_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])
103
+
104
+
105
# Auto-generated pattern (regenerate via gen_attention_patterns.py; do not hand-edit).
# _sfdp_pattern_16_inference: forward-only masked SDPA, fp32 variant (no dropout
# nodes and no _users multiplicities from a backward graph).
# Query: permute -> expand -> contiguous clone -> view into bmm layout.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
# Key: permuted twice (transposed), then expand/clone/view.
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
# Scores: (Q @ K^T) / inv_scale + attn_mask.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
# Numerically stable softmax.
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
# probs @ V produces the attention output.
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
132
+
133
+
134
# Auto-generated pattern (regenerate via gen_attention_patterns.py; do not hand-edit).
# _sfdp_pattern_16_bs1_training: same masked-SDPA-with-dropout fwd+bwd graph as
# _sfdp_pattern_16_training, but the q/k/v chains have no contiguous clone between
# expand and view (the shape the generator traces for this variant needs none).
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
# Query / key chains (key permuted twice, i.e. transposed for Q @ K^T).
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
# Scores: (Q @ K^T) / inv_scale + attn_mask, then stable softmax.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
# Dropout on the probabilities, then probs @ V.
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward: incoming output gradient, dropout-mask backprop, softmax backward.
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
# alias chain: autograd's saved softmax output as seen in the traced graph.
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
# Softmax backward: p * (g - sum(g * p)); undo scaling division.
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
# Gradients w.r.t. query / key / value via transposed bmms.
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
# Output tuple mirrors _sfdp_pattern_16_training (see note there on the Nones).
_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])
200
+
201
+
202
# Auto-generated pattern (regenerate via gen_attention_patterns.py; do not hand-edit).
# _sfdp_pattern_16_bs1_inference: forward-only masked SDPA, fp32, bs1 variant
# (no contiguous clones between expand and view on the q/k/v chains).
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
# Key permuted twice (transposed for Q @ K^T).
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
# Scores: (Q @ K^T) / inv_scale + attn_mask.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
# Numerically stable softmax.
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
# probs @ V.
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
226
+
227
+
228
# Auto-generated pattern (regenerate via gen_attention_patterns.py; do not hand-edit).
# _sfdp_pattern_16_half_training: masked SDPA with dropout, fwd+bwd, reduced-precision
# variant — convert_element_type nodes bracket the softmax (upcast before, downcast
# after), as autograd traces it for half/bfloat16 inputs (dtype itself is Ignored()).
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
# Query / key chains (key permuted twice = transposed), with contiguous clones.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
# Scores: (Q @ K^T) / inv_scale + attn_mask; cast before the stable softmax.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
# Cast probabilities back, apply dropout, then probs @ V.
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward: output gradient, dropout-mask backprop (cast mask, mul by scale).
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
# alias chain: autograd's saved softmax output; cast for the backward math.
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
# Softmax backward: p * (g - sum(g * p)); cast back down, undo scaling division.
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
# Gradients w.r.t. query / key / value via transposed bmms.
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
# Output tuple mirrors _sfdp_pattern_16_training (see note there on the Nones).
_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])
302
+
303
+
304
# Auto-generated pattern (regenerate via gen_attention_patterns.py; do not hand-edit).
# _sfdp_pattern_16_half_inference: forward-only masked SDPA, reduced-precision
# variant — convert_element_type upcasts scores before softmax and downcasts the
# probabilities after (target dtype is Ignored()).
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
# Key permuted twice (transposed for Q @ K^T).
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
# Scores: (Q @ K^T) / inv_scale + attn_mask; upcast; stable softmax.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
# Downcast probabilities, then probs @ V.
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
333
+
334
+
335
# Auto-generated pattern (regenerate via gen_attention_patterns.py; do not hand-edit).
# _sfdp_pattern_16_half_bs1_training: reduced-precision fwd+bwd masked SDPA with
# dropout — combines the half variant's convert_element_type casts around softmax
# with the bs1 variant's clone-free q/k/v expand->view chains.
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
# Query / key chains (key permuted twice = transposed), no contiguous clones.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
# Scores: (Q @ K^T) / inv_scale + attn_mask; upcast; stable softmax.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
# Downcast probabilities, apply dropout, then probs @ V.
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward: output gradient, dropout-mask backprop (cast mask, mul by scale).
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
# alias chain: autograd's saved softmax output; cast for the backward math.
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
# Softmax backward: p * (g - sum(g * p)); cast back down, undo scaling division.
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
# Gradients w.r.t. query / key / value via transposed bmms.
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
# Output tuple mirrors _sfdp_pattern_16_training (see note there on the Nones).
_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])
406
+
407
+
408
# Auto-generated pattern (regenerate via gen_attention_patterns.py; do not hand-edit).
# _sfdp_pattern_16_half_bs1_inference: forward-only masked SDPA, reduced-precision
# bs1 variant — convert_element_type casts around softmax, no contiguous clones on
# the q/k/v expand->view chains.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
# Key permuted twice (transposed for Q @ K^T).
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
# Scores: (Q @ K^T) / inv_scale + attn_mask; upcast; stable softmax.
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
# Downcast probabilities, then probs @ V.
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
434
+
435
+
436
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
437
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
438
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
439
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
440
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
441
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
442
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
443
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
444
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
445
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
446
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
447
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
448
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
449
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
450
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
451
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
452
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
453
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
454
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
455
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
456
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
457
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
458
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
459
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
460
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
461
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
462
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
463
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
464
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
465
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
466
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
467
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
468
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
469
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
470
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
471
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
472
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
473
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
474
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
475
+ clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
476
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
477
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
478
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
479
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
480
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
481
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
482
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
483
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
484
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
485
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale'))
486
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
487
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
488
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
489
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
490
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
491
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
492
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
493
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
494
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
495
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
496
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
497
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
498
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
499
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
500
+ _sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5,
501
+ permute_default_6,
502
+ permute_default_9,
503
+ permute_default_11,
504
+ None,
505
+ None,
506
+ None
507
+ ])
508
+
509
+
510
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
511
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
512
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
513
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
514
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
515
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
516
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
517
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
518
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
519
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
520
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
521
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
522
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
523
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
524
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
525
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
526
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
527
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
528
+ clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
529
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
530
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
531
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
532
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
533
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
534
+ clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
535
+ view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
536
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
537
+ _sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
538
+
539
+
540
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
541
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
542
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
543
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
544
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
545
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
546
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
547
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
548
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
549
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
550
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
551
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
552
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
553
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
554
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
555
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
556
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
557
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
558
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
559
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
560
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
561
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
562
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
563
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
564
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
565
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
566
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
567
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
568
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
569
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
570
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
571
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
572
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
573
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
574
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
575
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
576
+ clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
577
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
578
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
579
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
580
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
581
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
582
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
583
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
584
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
585
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
586
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale'))
587
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
588
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
589
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
590
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
591
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
592
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
593
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
594
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
595
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
596
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
597
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
598
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
599
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
600
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
601
+ _sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5,
602
+ permute_default_6,
603
+ permute_default_9,
604
+ permute_default_11,
605
+ None,
606
+ None,
607
+ None
608
+ ])
609
+
610
+
611
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
612
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
613
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
614
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
615
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
616
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
617
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
618
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
619
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
620
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
621
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
622
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
623
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
624
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
625
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
626
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
627
+ clone_default = CallFunction(aten.clone.default, div_Tensor_1)
628
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
629
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
630
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
631
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
632
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
633
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
634
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
635
+ _sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
35
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
36
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
37
+ expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
38
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
39
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
40
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
41
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
42
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
43
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
44
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
45
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
46
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
47
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
48
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
49
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
50
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
51
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
52
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
53
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
54
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
55
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
56
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
57
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
58
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
59
+ expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
60
+ view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
61
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
62
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
63
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
64
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
65
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
66
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
67
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
68
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
69
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
70
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
71
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
72
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
73
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
74
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
75
+ clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
76
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
77
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
78
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
79
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
80
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
81
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
82
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
83
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
84
+ where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, sub_Tensor_1)
85
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
86
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
87
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
88
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
89
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
90
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
91
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
92
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
93
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
94
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
95
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
96
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
97
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
98
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
99
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
100
+ _sfdp_pattern_17_training = MultiOutputPattern([view_default_5,
101
+ permute_default_6,
102
+ permute_default_9,
103
+ permute_default_11,
104
+ None,
105
+ None,
106
+ None
107
+ ])
108
+
109
+
110
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
111
+ view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
112
+ expand_default = CallFunction(aten.expand.default, view_default, Ignored())
113
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
114
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
115
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
116
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
117
+ view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
118
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
119
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
120
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
121
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
122
+ view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
123
+ bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
124
+ view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
125
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
126
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
127
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
128
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
129
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
130
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
131
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
132
+ clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
133
+ expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored())
134
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
135
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
136
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
137
+ clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
138
+ view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored())
139
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
140
+ _sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
141
+
142
+
143
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
144
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
145
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
146
+ expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
147
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
148
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
149
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
150
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
151
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
152
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
153
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
154
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
155
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
156
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
157
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
158
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
159
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
160
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
161
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
162
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
163
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
164
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
165
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
166
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
167
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
168
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
169
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
170
+ expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
171
+ view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
172
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
173
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
174
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
175
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
176
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
177
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
178
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
179
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
180
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
181
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
182
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
183
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
184
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
185
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
186
+ clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
187
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
188
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
189
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
190
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
191
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
192
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
193
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
194
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
195
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
196
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
197
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
198
+ where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5)
199
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
200
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
201
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
202
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
203
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
204
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
205
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
206
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
207
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
208
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
209
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
210
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
211
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
212
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
213
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
214
+ _sfdp_pattern_17_half_training = MultiOutputPattern([view_default_5,
215
+ permute_default_6,
216
+ permute_default_9,
217
+ permute_default_11,
218
+ None,
219
+ None,
220
+ None
221
+ ])
222
+
223
+
224
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
225
+ view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
226
+ expand_default = CallFunction(aten.expand.default, view_default, Ignored())
227
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
228
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
229
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
230
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
231
+ view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
232
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
233
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
234
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
235
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
236
+ view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
237
+ bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
238
+ view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
239
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
240
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
241
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
242
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
243
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
244
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
245
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
246
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
247
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
248
+ clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
249
+ expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored())
250
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
251
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
252
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
253
+ clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
254
+ view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored())
255
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
256
+ _sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
35
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
36
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
37
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
38
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
39
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
40
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
41
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
42
+ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
43
+ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
44
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
45
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
46
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
47
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
48
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
49
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
50
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
51
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
52
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
53
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
54
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
55
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
56
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
57
+ alias_default = CallFunction(aten.alias.default, div_Tensor)
58
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
59
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
60
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
61
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
62
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
63
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
64
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_1, mul_Tensor_2)
65
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, sub_Tensor_1, KeywordArg('scale_factor'))
66
+ view_default_8 = CallFunction(aten.view.default, mul_Tensor_3, Ignored(), _users=2)
67
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
68
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
69
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
70
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
71
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
72
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
73
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
74
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
75
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
76
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
77
+ _sfdp_pattern_2_training = MultiOutputPattern([view_default_5,
78
+ view_default_9,
79
+ permute_default_4,
80
+ view_default_11,
81
+ None
82
+ ])
83
+
84
+
85
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
86
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
87
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
88
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
89
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
90
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
91
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
92
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
93
+ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
94
+ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
95
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
96
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
97
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
98
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
99
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
100
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
101
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
102
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
103
+ _sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
104
+
105
+
106
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
107
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
108
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
109
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
110
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
111
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
112
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
113
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
114
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
115
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
116
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
117
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
118
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
119
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
120
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
121
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
122
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
123
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
124
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
125
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
126
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
127
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
128
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
129
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
130
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
131
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
132
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
133
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
134
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
135
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
136
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
137
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
138
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
139
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
140
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_1, mul_Tensor_2)
141
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
142
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor'))
143
+ view_default_8 = CallFunction(aten.view.default, mul_Tensor_3, Ignored(), _users=2)
144
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
145
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
146
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
147
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
148
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
149
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
150
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
151
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
152
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
153
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
154
+ _sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5,
155
+ view_default_9,
156
+ permute_default_4,
157
+ view_default_11,
158
+ None
159
+ ])
160
+
161
+
162
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
163
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
164
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
165
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
166
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
167
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
168
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
169
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
170
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
171
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
172
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
173
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
174
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
175
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
176
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
177
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
178
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
179
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
180
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
181
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
182
+ _sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_dimV_ops.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _dimV {
18
+ using schema = int64_t (const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_dimV")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_dimV(Tensor self) -> int")
24
+ static int64_t call(const at::Tensor & self);
25
+ static int64_t redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
26
+ };
27
+
28
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_exp.h ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_foreach_exp_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_foreach_exp(Tensor[] self) -> Tensor[]
26
+ inline ::std::vector<at::Tensor> _foreach_exp(at::TensorList self) {
27
+ return at::_ops::_foreach_exp::call(self);
28
+ }
29
+
30
+ // aten::_foreach_exp_(Tensor(a!)[] self) -> ()
31
+ inline void _foreach_exp_(at::TensorList self) {
32
+ return at::_ops::_foreach_exp_::call(self);
33
+ }
34
+
35
+ // aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
36
+ inline void _foreach_exp_out(at::TensorList out, at::TensorList self) {
37
+ return at::_ops::_foreach_exp_out::call(self, out);
38
+ }
39
+ // aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
40
+ inline void _foreach_exp_outf(at::TensorList self, at::TensorList out) {
41
+ return at::_ops::_foreach_exp_out::call(self, out);
42
+ }
43
+
44
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_slogdet_meta_dispatch.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace meta {
19
+
20
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _linalg_slogdet(const at::Tensor & A);
21
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _linalg_slogdet_out(at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A);
22
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _linalg_slogdet_outf(const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots);
23
+
24
+ } // namespace meta
25
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_tensor_from_mask_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & _nested_tensor_from_mask_out(at::Tensor & out, const at::Tensor & t, const at::Tensor & mask, bool mask_check=true);
21
+ TORCH_API at::Tensor & _nested_tensor_from_mask_outf(const at::Tensor & t, const at::Tensor & mask, bool mask_check, at::Tensor & out);
22
+
23
+ } // namespace compositeexplicitautograd
24
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_print_native.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API void _print(c10::string_view s);
20
+ } // namespace native
21
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sample_dirichlet.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_sample_dirichlet_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
26
+ inline at::Tensor _sample_dirichlet(const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
27
+ return at::_ops::_sample_dirichlet::call(self, generator);
28
+ }
29
+
30
+ // aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
31
+ inline at::Tensor & _sample_dirichlet_out(at::Tensor & out, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
32
+ return at::_ops::_sample_dirichlet_out::call(self, generator, out);
33
+ }
34
+ // aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & _sample_dirichlet_outf(const at::Tensor & self, c10::optional<at::Generator> generator, at::Tensor & out) {
36
+ return at::_ops::_sample_dirichlet_out::call(self, generator, out);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_csr_sum.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_sparse_csr_sum_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
26
+ inline at::Tensor _sparse_csr_sum(const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
27
+ return at::_ops::_sparse_csr_sum_dim_dtype::call(self, dim, keepdim, dtype);
28
+ }
29
+
30
+ // aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
31
+ inline at::Tensor & _sparse_csr_sum_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
32
+ return at::_ops::_sparse_csr_sum_dim_dtype_out::call(self, dim, keepdim, dtype, out);
33
+ }
34
+ // aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & _sparse_csr_sum_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
36
+ return at::_ops::_sparse_csr_sum_dim_dtype_out::call(self, dim, keepdim, dtype, out);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_softmax_backward_data_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _sparse_softmax_backward_data {
18
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, int64_t, const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_softmax_backward_data")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self);
26
+ };
27
+
28
+ struct TORCH_API _sparse_softmax_backward_data_out {
29
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, int64_t, const at::Tensor &, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_softmax_backward_data")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out);
37
+ };
38
+
39
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_standard_gamma_grad_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & _standard_gamma_grad_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & output);
21
+ TORCH_API at::Tensor & _standard_gamma_grad_outf(const at::Tensor & self, const at::Tensor & output, at::Tensor & out);
22
+
23
+ } // namespace compositeexplicitautograd
24
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_warn_in_autograd_native.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor _test_warn_in_autograd(const at::Tensor & self);
20
+ TORCH_API at::Tensor & _test_warn_in_autograd_out(const at::Tensor & self, at::Tensor & out);
21
+ } // namespace native
22
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_nearest_exact2d_cuda_dispatch.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API at::Tensor _upsample_nearest_exact2d(const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
21
+ TORCH_API at::Tensor _upsample_nearest_exact2d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
22
+ TORCH_API at::Tensor & _upsample_nearest_exact2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
23
+ TORCH_API at::Tensor & _upsample_nearest_exact2d_outf(const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out);
24
+ TORCH_API at::Tensor & _upsample_nearest_exact2d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
25
+ TORCH_API at::Tensor & _upsample_nearest_exact2d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out);
26
+
27
+ } // namespace cuda
28
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace meta {
19
+
20
+ TORCH_API at::Tensor _upsample_nearest_exact3d_backward(const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
21
+ TORCH_API at::Tensor _upsample_nearest_exact3d_backward_symint(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
22
+ TORCH_API at::Tensor & _upsample_nearest_exact3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
23
+ TORCH_API at::Tensor & _upsample_nearest_exact3d_backward_outf(const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input);
24
+ TORCH_API at::Tensor & _upsample_nearest_exact3d_backward_symint_out(at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
25
+ TORCH_API at::Tensor & _upsample_nearest_exact3d_backward_symint_outf(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input);
26
+
27
+ } // namespace meta
28
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace meta {
19
+
20
+ TORCH_API at::Tensor acosh(const at::Tensor & self);
21
+ TORCH_API at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self);
22
+ TORCH_API at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out);
23
+ TORCH_API at::Tensor & acosh_(at::Tensor & self);
24
+
25
+ } // namespace meta
26
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/all_cuda_dispatch.h ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API at::Tensor all(const at::Tensor & self, int64_t dim, bool keepdim=false);
21
+ TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false);
22
+ TORCH_API at::Tensor & all_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out);
23
+ TORCH_API at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false);
24
+ TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false);
25
+ TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out);
26
+ TORCH_API at::Tensor all(const at::Tensor & self);
27
+ TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self);
28
+ TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::Tensor & out);
29
+
30
+ } // namespace cuda
31
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atanh_native.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+ #include <ATen/ops/atanh_meta.h>
16
+
17
+ namespace at {
18
+ namespace native {
19
+ struct TORCH_API structured_atanh_out : public at::meta::structured_atanh {
20
+ void impl(const at::Tensor & self, const at::Tensor & out);
21
+ };
22
+ TORCH_API at::Tensor atanh_sparse(const at::Tensor & self);
23
+ TORCH_API at::Tensor & atanh_sparse_out(const at::Tensor & self, at::Tensor & out);
24
+ TORCH_API at::Tensor & atanh_sparse_(at::Tensor & self);
25
+ TORCH_API at::Tensor atanh_sparse_csr(const at::Tensor & self);
26
+ TORCH_API at::Tensor & atanh_sparse_csr_out(const at::Tensor & self, at::Tensor & out);
27
+ TORCH_API at::Tensor & atanh_sparse_csr_(at::Tensor & self);
28
+ } // namespace native
29
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_3d_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API atleast_3d {
18
+ using schema = at::Tensor (const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::atleast_3d")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "atleast_3d(Tensor self) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
26
+ };
27
+
28
+ struct TORCH_API atleast_3d_Sequence {
29
+ using schema = ::std::vector<at::Tensor> (at::TensorList);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::atleast_3d")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Sequence")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]")
35
+ static ::std::vector<at::Tensor> call(at::TensorList tensors);
36
+ static ::std::vector<at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors);
37
+ };
38
+
39
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/binary_cross_entropy_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
26
+ inline at::Tensor binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
27
+ return at::_ops::binary_cross_entropy::call(self, target, weight, reduction);
28
+ }
29
+
30
+ // aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
31
+ inline at::Tensor & binary_cross_entropy_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
32
+ return at::_ops::binary_cross_entropy_out::call(self, target, weight, reduction, out);
33
+ }
34
+ // aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & binary_cross_entropy_outf(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, at::Tensor & out) {
36
+ return at::_ops::binary_cross_entropy_out::call(self, target, weight, reduction, out);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy_with_logits_native.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor binary_cross_entropy_with_logits(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, const c10::optional<at::Tensor> & pos_weight={}, int64_t reduction=at::Reduction::Mean);
20
+ TORCH_API at::Tensor & binary_cross_entropy_with_logits_out(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & pos_weight, int64_t reduction, at::Tensor & out);
21
+ } // namespace native
22
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_xor_ops.h ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API bitwise_xor_Tensor_out {
18
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor_out")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)")
24
+ static at::Tensor & call(const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
25
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
26
+ };
27
+
28
+ struct TORCH_API bitwise_xor_Scalar_out {
29
+ using schema = at::Tensor & (const at::Tensor &, const at::Scalar &, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar_out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(const at::Tensor & self, const at::Scalar & other, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out);
37
+ };
38
+
39
+ struct TORCH_API bitwise_xor_Scalar {
40
+ using schema = at::Tensor (const at::Tensor &, const at::Scalar &);
41
+ using ptr_schema = schema*;
42
+ // See Note [static constexpr char* members for windows NVCC]
43
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
44
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar")
45
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor")
46
+ static at::Tensor call(const at::Tensor & self, const at::Scalar & other);
47
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other);
48
+ };
49
+
50
+ struct TORCH_API bitwise_xor_Scalar_Tensor {
51
+ using schema = at::Tensor (const at::Scalar &, const at::Tensor &);
52
+ using ptr_schema = schema*;
53
+ // See Note [static constexpr char* members for windows NVCC]
54
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
55
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar_Tensor")
56
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor")
57
+ static at::Tensor call(const at::Scalar & self, const at::Tensor & other);
58
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other);
59
+ };
60
+
61
+ struct TORCH_API bitwise_xor_Tensor {
62
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &);
63
+ using ptr_schema = schema*;
64
+ // See Note [static constexpr char* members for windows NVCC]
65
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
66
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
67
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor")
68
+ static at::Tensor call(const at::Tensor & self, const at::Tensor & other);
69
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other);
70
+ };
71
+
72
+ struct TORCH_API bitwise_xor__Scalar {
73
+ using schema = at::Tensor & (at::Tensor &, const at::Scalar &);
74
+ using ptr_schema = schema*;
75
+ // See Note [static constexpr char* members for windows NVCC]
76
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor_")
77
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar")
78
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)")
79
+ static at::Tensor & call(at::Tensor & self, const at::Scalar & other);
80
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other);
81
+ };
82
+
83
+ struct TORCH_API bitwise_xor__Tensor {
84
+ using schema = at::Tensor & (at::Tensor &, const at::Tensor &);
85
+ using ptr_schema = schema*;
86
+ // See Note [static constexpr char* members for windows NVCC]
87
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor_")
88
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
89
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)")
90
+ static at::Tensor & call(at::Tensor & self, const at::Tensor & other);
91
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other);
92
+ };
93
+
94
+ struct TORCH_API bitwise_xor_Scalar_Tensor_out {
95
+ using schema = at::Tensor & (const at::Scalar &, const at::Tensor &, at::Tensor &);
96
+ using ptr_schema = schema*;
97
+ // See Note [static constexpr char* members for windows NVCC]
98
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
99
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar_Tensor_out")
100
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)")
101
+ static at::Tensor & call(const at::Scalar & self, const at::Tensor & other, at::Tensor & out);
102
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out);
103
+ };
104
+
105
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/clip_ops.h ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API clip {
18
+ using schema = at::Tensor (const at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Scalar> &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max);
26
+ };
27
+
28
+ struct TORCH_API clip_Tensor {
29
+ using schema = at::Tensor (const at::Tensor &, const c10::optional<at::Tensor> &, const c10::optional<at::Tensor> &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor")
35
+ static at::Tensor call(const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max);
36
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max);
37
+ };
38
+
39
+ struct TORCH_API clip_ {
40
+ using schema = at::Tensor & (at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Scalar> &);
41
+ using ptr_schema = schema*;
42
+ // See Note [static constexpr char* members for windows NVCC]
43
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip_")
44
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
45
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)")
46
+ static at::Tensor & call(at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max);
47
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max);
48
+ };
49
+
50
+ struct TORCH_API clip__Tensor {
51
+ using schema = at::Tensor & (at::Tensor &, const c10::optional<at::Tensor> &, const c10::optional<at::Tensor> &);
52
+ using ptr_schema = schema*;
53
+ // See Note [static constexpr char* members for windows NVCC]
54
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip_")
55
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
56
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)")
57
+ static at::Tensor & call(at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max);
58
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max);
59
+ };
60
+
61
+ struct TORCH_API clip_out {
62
+ using schema = at::Tensor & (const at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Scalar> &, at::Tensor &);
63
+ using ptr_schema = schema*;
64
+ // See Note [static constexpr char* members for windows NVCC]
65
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip")
66
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
67
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)")
68
+ static at::Tensor & call(const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max, at::Tensor & out);
69
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max, at::Tensor & out);
70
+ };
71
+
72
+ struct TORCH_API clip_Tensor_out {
73
+ using schema = at::Tensor & (const at::Tensor &, const c10::optional<at::Tensor> &, const c10::optional<at::Tensor> &, at::Tensor &);
74
+ using ptr_schema = schema*;
75
+ // See Note [static constexpr char* members for windows NVCC]
76
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip")
77
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor_out")
78
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)")
79
+ static at::Tensor & call(const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max, at::Tensor & out);
80
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max, at::Tensor & out);
81
+ };
82
+
83
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cosine_similarity_compositeimplicitautograd_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeimplicitautograd {
19
+
20
+ TORCH_API at::Tensor cosine_similarity(const at::Tensor & x1, const at::Tensor & x2, int64_t dim=1, double eps=1e-08);
21
+
22
+ } // namespace compositeimplicitautograd
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cumprod_backward_native.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor cumprod_backward(const at::Tensor & grad, const at::Tensor & input, int64_t dim, const at::Tensor & output);
20
+ } // namespace native
21
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/digamma.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/digamma_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
26
+ inline at::Tensor & digamma_out(at::Tensor & out, const at::Tensor & self) {
27
+ return at::_ops::digamma_out::call(self, out);
28
+ }
29
+ // aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
30
+ inline at::Tensor & digamma_outf(const at::Tensor & self, at::Tensor & out) {
31
+ return at::_ops::digamma_out::call(self, out);
32
+ }
33
+
34
+ // aten::digamma(Tensor self) -> Tensor
35
+ inline at::Tensor digamma(const at::Tensor & self) {
36
+ return at::_ops::digamma::call(self);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/divide_native.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor divide(const at::Tensor & self, const at::Tensor & other);
20
+ TORCH_API at::Tensor & divide_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
21
+ TORCH_API at::Tensor & divide_(at::Tensor & self, const at::Tensor & other);
22
+ TORCH_API at::Tensor divide(const at::Tensor & self, const at::Scalar & other);
23
+ TORCH_API at::Tensor & divide_(at::Tensor & self, const at::Scalar & other);
24
+ TORCH_API at::Tensor divide(const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode);
25
+ TORCH_API at::Tensor & divide_out(const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode, at::Tensor & out);
26
+ TORCH_API at::Tensor & divide_(at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode);
27
+ TORCH_API at::Tensor divide(const at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode);
28
+ TORCH_API at::Tensor & divide_(at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode);
29
+ } // namespace native
30
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/embedding_dense_backward_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & embedding_dense_backward_out(at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq);
21
+ TORCH_API at::Tensor & embedding_dense_backward_outf(const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, at::Tensor & out);
22
+ TORCH_API at::Tensor & embedding_dense_backward_symint_out(at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq);
23
+ TORCH_API at::Tensor & embedding_dense_backward_symint_outf(const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, at::Tensor & out);
24
+
25
+ } // namespace compositeexplicitautograd
26
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/expm1_compositeexplicitautogradnonfunctional_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautogradnonfunctional {
19
+
20
+ TORCH_API at::Tensor expm1(const at::Tensor & self);
21
+ TORCH_API at::Tensor & expm1_(at::Tensor & self);
22
+
23
+ } // namespace compositeexplicitautogradnonfunctional
24
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
26
+ inline at::Tensor fake_quantize_per_channel_affine_cachemask_backward(const at::Tensor & grad, const at::Tensor & mask) {
27
+ return at::_ops::fake_quantize_per_channel_affine_cachemask_backward::call(grad, mask);
28
+ }
29
+
30
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fft_ihfftn_compositeimplicitautograd_dispatch.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeimplicitautograd {
19
+
20
+ TORCH_API at::Tensor fft_ihfftn(const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt);
21
+ TORCH_API at::Tensor fft_ihfftn_symint(const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt);
22
+ TORCH_API const at::Tensor & fft_ihfftn_out(const at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt);
23
+ TORCH_API const at::Tensor & fft_ihfftn_outf(const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out);
24
+ TORCH_API const at::Tensor & fft_ihfftn_symint_out(const at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt);
25
+ TORCH_API const at::Tensor & fft_ihfftn_symint_outf(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out);
26
+
27
+ } // namespace compositeimplicitautograd
28
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/gcd_cuda_dispatch.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API at::Tensor gcd(const at::Tensor & self, const at::Tensor & other);
21
+ TORCH_API at::Tensor & gcd_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other);
22
+ TORCH_API at::Tensor & gcd_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
23
+ TORCH_API at::Tensor & gcd_(at::Tensor & self, const at::Tensor & other);
24
+
25
+ } // namespace cuda
26
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/grid_sampler_2d.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/grid_sampler_2d_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
26
+ inline at::Tensor grid_sampler_2d(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
27
+ return at::_ops::grid_sampler_2d::call(input, grid, interpolation_mode, padding_mode, align_corners);
28
+ }
29
+
30
+ // aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
31
+ inline at::Tensor & grid_sampler_2d_out(at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
32
+ return at::_ops::grid_sampler_2d_out::call(input, grid, interpolation_mode, padding_mode, align_corners, out);
33
+ }
34
+ // aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & grid_sampler_2d_outf(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) {
36
+ return at::_ops::grid_sampler_2d_out::call(input, grid, interpolation_mode, padding_mode, align_corners, out);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hardswish_backward.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/hardswish_backward_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
26
+ inline at::Tensor hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self) {
27
+ return at::_ops::hardswish_backward::call(grad_output, self);
28
+ }
29
+
30
+ // aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
31
+ inline at::Tensor & hardswish_backward_out(at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) {
32
+ return at::_ops::hardswish_backward_out::call(grad_output, self, out);
33
+ }
34
+ // aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & hardswish_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) {
36
+ return at::_ops::hardswish_backward_out::call(grad_output, self, out);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hstack.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/hstack_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::hstack(Tensor[] tensors) -> Tensor
26
+ inline at::Tensor hstack(at::TensorList tensors) {
27
+ return at::_ops::hstack::call(tensors);
28
+ }
29
+
30
+ // aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
31
+ inline at::Tensor & hstack_out(at::Tensor & out, at::TensorList tensors) {
32
+ return at::_ops::hstack_out::call(tensors, out);
33
+ }
34
+ // aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & hstack_outf(at::TensorList tensors, at::Tensor & out) {
36
+ return at::_ops::hstack_out::call(tensors, out);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/huber_loss_backward.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/huber_loss_backward_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)
26
+ inline at::Tensor & huber_loss_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) {
27
+ return at::_ops::huber_loss_backward_out::call(grad_output, self, target, reduction, delta, grad_input);
28
+ }
29
+ // aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)
30
+ inline at::Tensor & huber_loss_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & grad_input) {
31
+ return at::_ops::huber_loss_backward_out::call(grad_output, self, target, reduction, delta, grad_input);
32
+ }
33
+
34
+ // aten::huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor
35
+ inline at::Tensor huber_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) {
36
+ return at::_ops::huber_loss_backward::call(grad_output, self, target, reduction, delta);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift_fresh.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/lift_fresh_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::lift_fresh(Tensor(a) self) -> Tensor(a)
26
+ inline at::Tensor lift_fresh(const at::Tensor & self) {
27
+ return at::_ops::lift_fresh::call(self);
28
+ }
29
+
30
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_eigh_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API linalg_eigh {
18
+ using schema = ::std::tuple<at::Tensor,at::Tensor> (const at::Tensor &, c10::string_view);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::linalg_eigh")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "linalg_eigh(Tensor self, str UPLO=\"L\") -> (Tensor eigenvalues, Tensor eigenvectors)")
24
+ static ::std::tuple<at::Tensor,at::Tensor> call(const at::Tensor & self, c10::string_view UPLO);
25
+ static ::std::tuple<at::Tensor,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO);
26
+ };
27
+
28
+ struct TORCH_API linalg_eigh_eigvals {
29
+ using schema = ::std::tuple<at::Tensor &,at::Tensor &> (const at::Tensor &, c10::string_view, at::Tensor &, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::linalg_eigh")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "eigvals")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "linalg_eigh.eigvals(Tensor self, str UPLO=\"L\", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)")
35
+ static ::std::tuple<at::Tensor &,at::Tensor &> call(const at::Tensor & self, c10::string_view UPLO, at::Tensor & eigvals, at::Tensor & eigvecs);
36
+ static ::std::tuple<at::Tensor &,at::Tensor &> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO, at::Tensor & eigvals, at::Tensor & eigvecs);
37
+ };
38
+
39
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_lu_cpu_dispatch.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor> linalg_lu(const at::Tensor & A, bool pivot=true);
21
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_lu_out(at::Tensor & P, at::Tensor & L, at::Tensor & U, const at::Tensor & A, bool pivot=true);
22
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_lu_outf(const at::Tensor & A, bool pivot, at::Tensor & P, at::Tensor & L, at::Tensor & U);
23
+
24
+ } // namespace cpu
25
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_pinv_compositeexplicitautogradnonfunctional_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautogradnonfunctional {
19
+
20
+ TORCH_API at::Tensor linalg_pinv(const at::Tensor & self, const c10::optional<at::Tensor> & atol={}, const c10::optional<at::Tensor> & rtol={}, bool hermitian=false);
21
+
22
+ } // namespace compositeexplicitautogradnonfunctional
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linear.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/linear_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
26
+ inline at::Tensor linear(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}) {
27
+ return at::_ops::linear::call(input, weight, bias);
28
+ }
29
+
30
+ // aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)
31
+ inline at::Tensor & linear_out(at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}) {
32
+ return at::_ops::linear_out::call(input, weight, bias, out);
33
+ }
34
+ // aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & linear_outf(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::Tensor & out) {
36
+ return at::_ops::linear_out::call(input, weight, bias, out);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/log_sigmoid_forward_cpu_dispatch.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor> log_sigmoid_forward(const at::Tensor & self);
21
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &> log_sigmoid_forward_out(at::Tensor & output, at::Tensor & buffer, const at::Tensor & self);
22
+ TORCH_API ::std::tuple<at::Tensor &,at::Tensor &> log_sigmoid_forward_outf(const at::Tensor & self, at::Tensor & output, at::Tensor & buffer);
23
+
24
+ } // namespace cpu
25
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/logaddexp2_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API logaddexp2_out {
18
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::logaddexp2")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)")
24
+ static at::Tensor & call(const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
25
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
26
+ };
27
+
28
+ struct TORCH_API logaddexp2 {
29
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::logaddexp2")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "logaddexp2(Tensor self, Tensor other) -> Tensor")
35
+ static at::Tensor call(const at::Tensor & self, const at::Tensor & other);
36
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other);
37
+ };
38
+
39
+ }} // namespace at::_ops