koichi12 commited on
Commit
4aab8df
·
verified ·
1 Parent(s): c169562

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py +277 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +221 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +157 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py +212 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py +210 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py +567 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py +1980 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py +139 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py +232 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py +142 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py +236 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py +202 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py +206 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py +213 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc +0 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py +128 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py +312 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py +262 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py +235 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h +29 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h +630 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h +174 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h +369 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h +445 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h +335 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h +24 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h +42 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h +104 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h +23 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h +23 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_from_bin_tensors_ops.h +39 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward.h +30 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_fused_gru_cell_backward.h +39 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h +26 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h +39 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_or.h +67 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc ADDED
Binary file (6.95 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (226 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc ADDED
Binary file (59.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc ADDED
Binary file (61.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc ADDED
Binary file (7.12 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import itertools
3
+
4
+ import torch
5
+ from ..._dynamo.utils import counters
6
+
7
+ from ..pattern_matcher import Arg, CallFunction, KeywordArg
8
+ from .freezing_patterns import register_binary_folding_pattern
9
+
10
+ aten = torch.ops.aten
11
+ prims = torch.ops.prims
12
+
13
+
14
+ def mark_mixed_dtype_conv(conv):
15
+ conv_dtype = conv.meta["val"].dtype
16
+ if conv_dtype not in (torch.float16, torch.bfloat16):
17
+ return
18
+
19
+ if not len(conv.users) == 1:
20
+ return
21
+
22
+ conv_user = next(iter(conv.users.keys()))
23
+ if not isinstance(conv_user.meta["val"], torch.Tensor):
24
+ return
25
+
26
+ if not conv_user.meta["val"].dtype == torch.float32:
27
+ return
28
+
29
+ while conv_user.target in _binary_ops:
30
+ if not len(conv_user.users) == 1:
31
+ return
32
+
33
+ conv_user = next(iter(conv_user.users.keys()))
34
+
35
+ if not (
36
+ conv_user.target == prims.convert_element_type.default
37
+ and conv_user.args[1] == conv_dtype
38
+ ):
39
+ return
40
+
41
+ conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype
42
+
43
+
44
+ def mark_mixed_dtype_allowed_convs(gm):
45
+ """
46
+ Mark convolutions which we will binary fold even with mixed precision constants. We constant fold in the higher precision
47
+ for better accuracy and then recover the original precision after.
48
+ """
49
+ for node in gm.graph.nodes:
50
+ if node.target is aten.convolution.default:
51
+ mark_mixed_dtype_conv(node)
52
+
53
+
54
+ def recover_original_precision_folded_convs(gm):
55
+ """
56
+ After binary folding conv weights and biases to a higher dtype, recover the original precision they were in.
57
+ """
58
+ graph = gm.graph
59
+ convs = [node for node in graph.nodes if node.target is aten.convolution.default]
60
+ for node in convs:
61
+ orig_dtype = node.meta.get("_allow_conv_mixed_dtype_folding", None)
62
+ if orig_dtype is None:
63
+ continue
64
+
65
+ with graph.inserting_before(node):
66
+ for idx in [1, 2]:
67
+ old_input = node.args[idx]
68
+ if old_input is None:
69
+ continue
70
+
71
+ new_input = graph.create_node(
72
+ "call_function",
73
+ prims.convert_element_type.default,
74
+ (old_input, orig_dtype),
75
+ )
76
+ node.replace_input_with(old_input, new_input)
77
+
78
+
79
+ _binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]
80
+
81
+
82
+ @functools.lru_cache(None)
83
+ def binary_folding_init():
84
+ _conv_args = [Arg() for _ in range(9)]
85
+ _computation_ops = [aten.convolution.default]
86
+ _computation_calls = [CallFunction(aten.convolution.default, *_conv_args, _users=1)]
87
+
88
+ """
89
+ In order to fuse add/sub/mul/div with conv, the dimensions of its
90
+ constant tensor must satisfy the following:
91
+ - with resizing, broadcast to w/ weight/bias tensor shape
92
+ - broadcast to the conv output shape
93
+ It needs to have a shape that can resize to weight/bias
94
+ tensor shape because we need to run the op with the conv
95
+ weights/bias without changing their sizes.
96
+ It needs to broadcast to the conv output shape so that we do
97
+ accidentally change the shape of op output by pre-fusing it
98
+ compared to eager.
99
+ The only dimension value shared by weight/bias/conv output
100
+ is they all contain a dim with value = channels-out. In the
101
+ conv output tensor, this is in the second dimension,
102
+ so the pointwise op tensor may have a second dimension of
103
+ value == channels-out, but all the other dimensions have to be 1
104
+ """
105
+
106
+ def _op_not_broadcasting_with_conv(weight_tensor, other_tensor):
107
+ # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp
108
+ weight_shape = weight_tensor.shape
109
+ other_shape = other_tensor.shape
110
+ if len(weight_shape) < len(other_shape):
111
+ return False
112
+ if len(weight_shape) == len(other_shape) + 1:
113
+ # weight shape is [o, i, *], other_shape is [o, 1...].
114
+ for i in reversed(range(len(other_shape))):
115
+ if i == 0 and weight_shape[0] == other_shape[i]:
116
+ continue
117
+ if other_shape[i] != 1:
118
+ return False
119
+ else:
120
+ # weight shape is [o, i, *], other_shape is [1, i, *]
121
+ for i in reversed(range(len(other_shape))):
122
+ if i == 1 and weight_shape[0] == other_shape[i]:
123
+ continue
124
+ if other_shape[i] != 1:
125
+ return False
126
+ return True
127
+
128
+ def _check_conv_and_broadcast_op(conv_node, other):
129
+ # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp.
130
+ # conv.weight
131
+ if conv_node.args[1].op != "get_attr":
132
+ return False
133
+ # conv.bias
134
+ if conv_node.args[1] is not None and conv_node.args[1].op != "get_attr":
135
+ return False
136
+ if (
137
+ not isinstance(other, int)
138
+ and not isinstance(other, float)
139
+ and other.op != "get_attr"
140
+ ):
141
+ return False
142
+
143
+ if not len(conv_node.args[1].users) == 1:
144
+ return False
145
+
146
+ weight_meta_value = conv_node.args[1].meta.get("val")
147
+ if weight_meta_value is None:
148
+ return False
149
+ # Avoid fusing op that causes type promotion
150
+ # restricting to float avoids int/float difficulties with scalar overload
151
+ if not weight_meta_value.is_floating_point():
152
+ return False
153
+ if isinstance(other, torch.fx.Node) and other.op == "get_attr":
154
+ other_meta_value = other.meta.get("val")
155
+ if not other_meta_value.is_floating_point():
156
+ return False
157
+ if (
158
+ torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)
159
+ != weight_meta_value.dtype
160
+ ):
161
+ if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
162
+ return False
163
+
164
+ if (
165
+ other_meta_value.dtype != torch.float
166
+ and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
167
+ ):
168
+ return False
169
+
170
+ if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value):
171
+ return False
172
+ else:
173
+ # TODO: support scalar case
174
+ return False
175
+
176
+ return True
177
+
178
+ def _is_foldable_pattern(match):
179
+ binary_node = match.output_node()
180
+ computation_node = binary_node.args[0]
181
+ other = binary_node.args[1]
182
+ if binary_node.args[0].target not in _computation_ops:
183
+ computation_node = binary_node.args[1]
184
+ other = binary_node.args[0]
185
+ if binary_node.args[0].target == aten.convolution.default:
186
+ return _check_conv_and_broadcast_op(computation_node, other)
187
+
188
+ return False
189
+
190
+ def resize_scalar_or_tensor_to_shape(graph, other, shape):
191
+ # TODO: support scalar case
192
+ if other.meta.get("val").numel() == 1:
193
+ # expand errors if the shape input has less # dims than the tensor input
194
+ res = graph.create_node(
195
+ "call_function",
196
+ aten.reshape.default,
197
+ (other, (1,)),
198
+ )
199
+ res = graph.create_node(
200
+ "call_function",
201
+ aten.expand.default,
202
+ (res, shape),
203
+ )
204
+ else:
205
+ res = graph.create_node(
206
+ "call_function",
207
+ aten.reshape.default,
208
+ (other, shape),
209
+ )
210
+ return res
211
+
212
+ def _create_new_conv_node(graph, conv_node, binary_node, other):
213
+ assert conv_node.target == aten.convolution.default
214
+ conv_args = list(conv_node.args)
215
+ weight_meta_value = conv_node.args[1].meta.get("val")
216
+ bias = conv_args[2]
217
+ if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]:
218
+ other_reshape = resize_scalar_or_tensor_to_shape(
219
+ graph, other, (weight_meta_value.size(0),)
220
+ )
221
+ new_bias = graph.create_node(
222
+ "call_function",
223
+ binary_node.target,
224
+ (0 if bias is None else bias, other_reshape),
225
+ )
226
+ conv_args[2] = new_bias
227
+ else:
228
+ assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor]
229
+ weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))]
230
+ weight_broadcast_shape[0] = weight_meta_value.size(0)
231
+ other_reshape1 = resize_scalar_or_tensor_to_shape(
232
+ graph, other, tuple(weight_broadcast_shape)
233
+ )
234
+ new_weight = graph.create_node(
235
+ "call_function", binary_node.target, (conv_args[1], other_reshape1)
236
+ )
237
+ new_weight.meta.update(conv_args[1].meta)
238
+ conv_args[1] = new_weight
239
+ if bias is not None:
240
+ other_reshape = resize_scalar_or_tensor_to_shape(
241
+ graph, other, (weight_meta_value.size(0),)
242
+ )
243
+ new_bias = graph.create_node(
244
+ "call_function", binary_node.target, (bias, other_reshape)
245
+ )
246
+ new_bias.meta.update(bias.meta)
247
+ conv_args[2] = new_bias
248
+ return graph.create_node("call_function", conv_node.target, tuple(conv_args))
249
+
250
+ for _computation_call, binary_op in itertools.product(
251
+ _computation_calls, _binary_ops
252
+ ):
253
+
254
+ @register_binary_folding_pattern(
255
+ CallFunction(binary_op, _computation_call, KeywordArg("other")),
256
+ extra_check=_is_foldable_pattern,
257
+ )
258
+ def folded_op(match, *args, **kwargs):
259
+ counters["inductor"]["binary_folding"] += 1
260
+ other = kwargs.get("other")
261
+ binary_node = match.output_node()
262
+ computation_node = (
263
+ binary_node.args[0]
264
+ if binary_node.args[0].target in _computation_ops
265
+ else binary_node.args[1]
266
+ )
267
+ graph = match.graph
268
+ with graph.inserting_before(binary_node):
269
+ # TODO: support linear?
270
+ assert computation_node.target == aten.convolution.default
271
+ new_computation_node = _create_new_conv_node(
272
+ graph, computation_node, binary_node, other
273
+ )
274
+ binary_node.replace_all_uses_with(new_computation_node)
275
+ new_computation_node.meta.update(computation_node.meta)
276
+ graph.erase_node(binary_node)
277
+ graph.erase_node(computation_node)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch._dynamo.utils import counters
7
+ from torch._inductor import utils
8
+
9
+ from ..pattern_matcher import (
10
+ Arg,
11
+ CallFunction,
12
+ config_flag,
13
+ Ignored,
14
+ Match,
15
+ register_graph_pattern,
16
+ )
17
+ from .post_grad import decompose_mm_pass
18
+
19
+ aten = torch.ops.aten
20
+ log = logging.getLogger(__name__)
21
+
22
+ # TODO: need a better strategy for decomposing mm
23
+ MIN_FIRST_DIMENSION_DECOMPOSITION = 10240
24
+ MAX_OTHER_DIMENSION_DECOMPOSITION = 32
25
+
26
+
27
+ def check_device(a: Tensor, b: Tensor) -> bool:
28
+ return a.is_cuda and b.is_cuda
29
+
30
+
31
+ def should_decompose_common(
32
+ mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
33
+ ) -> bool:
34
+ return (
35
+ torch._inductor.config.decompose_mem_bound_mm
36
+ and check_device(mat1, mat2)
37
+ and not utils.any_is_symbolic(mat1, mat2, input)
38
+ )
39
+
40
+
41
+ def should_decompose_bmm(mat1, mat2) -> bool:
42
+ if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
43
+ mat1 = mat1.meta["val"]
44
+ mat2 = mat2.meta["val"]
45
+ else:
46
+ return False
47
+ if not should_decompose_common(mat1, mat2):
48
+ return False
49
+ else:
50
+ if len(mat1.shape) != 3 or len(mat2.shape) != 3:
51
+ return False
52
+ if mat1.shape[0] < MIN_FIRST_DIMENSION_DECOMPOSITION:
53
+ return False
54
+ # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION
55
+ if (mat1.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION) + (
56
+ mat1.shape[2] < MAX_OTHER_DIMENSION_DECOMPOSITION
57
+ ) + (mat2.shape[2] < MAX_OTHER_DIMENSION_DECOMPOSITION) < 2:
58
+ return False
59
+ return True
60
+
61
+
62
+ def should_decompose_mm(mat1, mat2) -> bool:
63
+ if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
64
+ mat1 = mat1.meta["val"]
65
+ mat2 = mat2.meta["val"]
66
+ else:
67
+ return False
68
+ return (
69
+ should_decompose_common(mat1, mat2)
70
+ and len(mat1.shape) == 2
71
+ and len(mat2.shape) == 2
72
+ and mat1.shape[0] >= MIN_FIRST_DIMENSION_DECOMPOSITION
73
+ and mat2.shape[0] < MAX_OTHER_DIMENSION_DECOMPOSITION
74
+ and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION
75
+ )
76
+
77
+
78
+ def should_decompose_mmt(mat1, mat2) -> bool:
79
+ if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
80
+ mat1 = mat1.meta["val"]
81
+ mat2 = mat2.meta["val"]
82
+ else:
83
+ return False
84
+ return (
85
+ should_decompose_common(mat1, mat2)
86
+ and len(mat1.shape) == 2
87
+ and len(mat2.shape) == 2
88
+ and mat1.shape[0] >= MIN_FIRST_DIMENSION_DECOMPOSITION
89
+ and mat1.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION
90
+ and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION
91
+ )
92
+
93
+
94
+ def should_decompose_mm_largek(mat1, mat2) -> bool:
95
+ if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
96
+ mat1 = mat1.meta["val"]
97
+ mat2 = mat2.meta["val"]
98
+ else:
99
+ return False
100
+ return (
101
+ should_decompose_common(mat1, mat2)
102
+ and len(mat1.shape) == 2
103
+ and len(mat2.shape) == 2
104
+ and mat1.shape[1] >= MIN_FIRST_DIMENSION_DECOMPOSITION
105
+ and mat1.shape[0] < MAX_OTHER_DIMENSION_DECOMPOSITION
106
+ and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION
107
+ )
108
+
109
+
110
+ def is_node_meta_valid(node: torch.fx.Node):
111
+ return "val" in node.meta
112
+
113
+
114
+ def print_decompose_pattern(match: Match, inputs: List[torch.fx.Node]):
115
+ node = match.nodes[-1]
116
+ log.debug(
117
+ "Decompose %s with input shape: %s",
118
+ node.target,
119
+ ", ".join(
120
+ str(input.meta["val"].shape) if "val" in input.meta else "None"
121
+ for input in inputs
122
+ ),
123
+ )
124
+
125
+
126
+ @register_graph_pattern(
127
+ CallFunction(aten.bmm, Arg(), Arg()),
128
+ pass_dict=decompose_mm_pass,
129
+ extra_check=config_flag("decompose_mem_bound_mm"),
130
+ )
131
+ def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node):
132
+ def repl(mat1, mat2):
133
+ return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2)
134
+
135
+ if should_decompose_bmm(mat1, mat2):
136
+ counters["inductor"]["decompose_bmm"] += 1
137
+ match.replace_by_example(repl, [mat1, mat2])
138
+ print_decompose_pattern(match, [mat1, mat2])
139
+ return
140
+
141
+
142
+ @register_graph_pattern(
143
+ CallFunction(aten.addmm, Arg(), Arg(), Arg()),
144
+ pass_dict=decompose_mm_pass,
145
+ extra_check=config_flag("decompose_mem_bound_mm"),
146
+ )
147
+ def decompose_addmm(
148
+ match: Match,
149
+ mat1: torch.fx.Node,
150
+ mat2: torch.fx.Node,
151
+ mat3: torch.fx.Node,
152
+ ):
153
+ def repl(mat1, mat2, mat3):
154
+ return torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2) + mat1
155
+
156
+ if should_decompose_mm(mat2, mat3):
157
+ counters["inductor"]["decompose_addmm"] += 1
158
+ match.replace_by_example(repl, [mat1, mat2, mat3])
159
+ print_decompose_pattern(match, [mat1, mat2, mat3])
160
+ return
161
+
162
+
163
+ @register_graph_pattern(
164
+ CallFunction(aten.mm, CallFunction(aten.permute, Arg(), Ignored()), Arg()),
165
+ pass_dict=decompose_mm_pass,
166
+ extra_check=config_flag("decompose_mem_bound_mm"),
167
+ )
168
+ def decompose_mmt(
169
+ match: Match,
170
+ mat1: torch.fx.Node,
171
+ mat2: torch.fx.Node,
172
+ ):
173
+ def repl(mat1, mat2):
174
+ return torch.sum(mat1[:, :, None] * mat2[:, None, :], dim=0)
175
+
176
+ if should_decompose_mmt(mat1, mat2):
177
+ counters["inductor"]["decompose_mmt"] += 1
178
+ match.replace_by_example(repl, [mat1, mat2])
179
+ print_decompose_pattern(match, [mat1, mat2])
180
+ return
181
+
182
+
183
+ @register_graph_pattern(
184
+ CallFunction(aten.mm, Arg(), Arg()),
185
+ pass_dict=decompose_mm_pass,
186
+ extra_check=config_flag("decompose_mem_bound_mm"),
187
+ )
188
+ def decompose_mm(
189
+ match: Match,
190
+ mat1: torch.fx.Node,
191
+ mat2: torch.fx.Node,
192
+ ):
193
+ def repl(mat1, mat2):
194
+ return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2)
195
+
196
+ if should_decompose_mm(mat1, mat2):
197
+ counters["inductor"]["decompose_mm"] += 1
198
+ match.replace_by_example(repl, [mat1, mat2])
199
+ print_decompose_pattern(match, [mat1, mat2])
200
+ return
201
+
202
+
203
+ @register_graph_pattern(
204
+ CallFunction(aten.mm, Arg(), Arg()),
205
+ pass_dict=decompose_mm_pass,
206
+ extra_check=config_flag("decompose_mem_bound_mm"),
207
+ )
208
+ def decompose_mm_large_k(
209
+ match: Match,
210
+ mat1: torch.fx.Node,
211
+ mat2: torch.fx.Node,
212
+ ):
213
+ def repl(mat1, mat2):
214
+ mat1 = mat1.permute(1, 0)
215
+ return torch.sum(mat1[:, :, None] * mat2[:, None, :], dim=0)
216
+
217
+ if should_decompose_mm_largek(mat1, mat2):
218
+ counters["inductor"]["decompose_mm_large_k"] += 1
219
+ match.replace_by_example(repl, [mat1, mat2])
220
+ print_decompose_pattern(match, [mat1, mat2])
221
+ return
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch._dynamo.utils import counters
5
+ from torch._inductor import config as inductor_config
6
+ from torch.func import functional_call
7
+
8
+ from ..pattern_matcher import CallModuleVarArgs, Match, register_graph_pattern
9
+
10
+ from .pre_grad import efficient_conv_bn_eval_pass
11
+
12
+
13
+ def efficient_conv_bn_eval(
14
+ bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor
15
+ ):
16
+ """
17
+ Implementation based on https://arxiv.org/abs/2305.11624
18
+ "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
19
+ It leverages the associative law between convolution and affine transform,
20
+ i.e., normalize (weight conv feature) = (normalize weight) conv feature.
21
+ It works for Eval mode of ConvBN blocks during validation, and can be used
22
+ for **training** as well, but only if one sets `bn.training=False`. It
23
+ reduces memory footprint and computation cost, at the cost of slightly
24
+ reduced numerical stability.
25
+ Args:
26
+ bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module.
27
+ conv (nn.modules.conv._ConvNd): a conv module
28
+ x (torch.Tensor): Input feature map.
29
+ """
30
+
31
+ assert bn.running_var is not None
32
+
33
+ # These lines of code are designed to deal with various cases
34
+ # like bn without affine transform, and conv without bias
35
+ weight_on_the_fly = conv.weight
36
+ if conv.bias is not None:
37
+ bias_on_the_fly = conv.bias
38
+ else:
39
+ bias_on_the_fly = torch.zeros_like(bn.running_var)
40
+
41
+ if bn.weight is not None:
42
+ bn_weight = bn.weight
43
+ else:
44
+ bn_weight = torch.ones_like(bn.running_var)
45
+
46
+ if bn.bias is not None:
47
+ bn_bias = bn.bias
48
+ else:
49
+ bn_bias = torch.zeros_like(bn.running_var)
50
+
51
+ # shape of [C_out, 1, 1, 1] in Conv2d
52
+ target_shape = [-1] + [1] * (conv.weight.ndim - 1)
53
+ if isinstance(conv, nn.modules.conv._ConvTransposeNd):
54
+ # for transposed conv, the C_out dimension should at index 1.
55
+ target_shape[:2] = [target_shape[1], target_shape[0]]
56
+ weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(target_shape)
57
+ # shape of [C_out, 1, 1, 1] in Conv2d
58
+ coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
59
+
60
+ # shape of [C_out, C_in, k, k] in Conv2d
61
+ weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
62
+ # shape of [C_out] in Conv2d
63
+ bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * (
64
+ bias_on_the_fly - bn.running_mean
65
+ )
66
+
67
+ input = x
68
+ params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly}
69
+ output = functional_call(conv, params, input)
70
+ return output
71
+
72
+
73
+ @register_graph_pattern(
74
+ CallModuleVarArgs(
75
+ [
76
+ nn.modules.batchnorm._BatchNorm,
77
+ nn.BatchNorm1d,
78
+ nn.BatchNorm2d,
79
+ nn.BatchNorm3d,
80
+ nn.SyncBatchNorm,
81
+ ],
82
+ ),
83
+ pass_dict=efficient_conv_bn_eval_pass,
84
+ extra_check=lambda match: not inductor_config.freezing
85
+ and inductor_config.efficient_conv_bn_eval_fx_passes,
86
+ )
87
+ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
88
+ # We matched a BN node
89
+ bn_node = match.nodes[0]
90
+ graph = match.graph
91
+ gm = graph.owning_module
92
+ bn_mod = getattr(gm, bn_node.target) # type: ignore[arg-type]
93
+
94
+ # We can only use efficient conv-bn for eval mode with track_running_stats
95
+ if not bn_mod.track_running_stats or bn_mod.training:
96
+ return
97
+
98
+ # Check if the input is Conv
99
+ if bn_node.args:
100
+ input_node = bn_node.args[0]
101
+ else:
102
+ input_node = bn_node.kwargs["input"]
103
+ if input_node.op != "call_module": # type: ignore[union-attr]
104
+ return
105
+ if not hasattr(gm, input_node.target): # type: ignore[arg-type, union-attr]
106
+ return
107
+ input_mod = getattr(gm, input_node.target) # type: ignore[arg-type, union-attr]
108
+ supported_convs = [
109
+ nn.Linear,
110
+ nn.Conv1d,
111
+ nn.Conv2d,
112
+ nn.Conv3d,
113
+ nn.ConvTranspose1d,
114
+ nn.ConvTranspose2d,
115
+ nn.ConvTranspose3d,
116
+ ]
117
+ if not any(isinstance(input_mod, cls) for cls in supported_convs):
118
+ return
119
+ conv_node = input_node
120
+ # Output of conv is used by other nodes, cannot optimize
121
+ if len(conv_node.users) > 1: # type: ignore[union-attr]
122
+ return
123
+
124
+ # Find a pair of conv and bn computation nodes to optimize.
125
+ counters["inductor"]["efficient_conv_bn_eval"] += 1
126
+
127
+ with graph.inserting_before(conv_node):
128
+ # create `get_attr` node to access modules
129
+ # note that we directly call `create_node` to fill the `name`
130
+ # argument. `graph.get_attr` and
131
+ # `graph.call_function` does not allow the `name` argument.
132
+ conv_get_node = graph.create_node(
133
+ op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr]
134
+ )
135
+ bn_get_node = graph.create_node(
136
+ op="get_attr", target=bn_node.target, name="get_bn"
137
+ )
138
+ if conv_node.args: # type: ignore[union-attr]
139
+ conv_input = conv_node.args[0] # type: ignore[union-attr]
140
+ else:
141
+ conv_input = conv_node.kwargs["input"] # type: ignore[union-attr]
142
+ # prepare args for the fused function
143
+ args = (bn_get_node, conv_get_node, conv_input)
144
+ # create a new node
145
+ new_node = graph.create_node(
146
+ op="call_function",
147
+ target=efficient_conv_bn_eval,
148
+ args=args,
149
+ name="efficient_conv_bn_eval",
150
+ )
151
+ # this node replaces the original conv + bn, and therefore
152
+ # should replace the uses of bn_node
153
+ bn_node.replace_all_uses_with(new_node)
154
+ # take care of the deletion order:
155
+ # delete bn_node first, and then conv_node
156
+ graph.erase_node(bn_node)
157
+ graph.erase_node(conv_node)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import torch
4
+ from torch._inductor.compile_fx import fake_tensor_prop
5
+ from ..._dynamo.utils import counters
6
+
7
+ from .. import config
8
+ from ..pattern_matcher import (
9
+ _return_true,
10
+ CallFunction,
11
+ fwd_only,
12
+ Ignored,
13
+ init_once_fakemode,
14
+ KeywordArg,
15
+ Match,
16
+ PatternMatcherPass,
17
+ register_graph_pattern,
18
+ register_replacement,
19
+ stable_topological_sort,
20
+ )
21
+
22
+ aten = torch.ops.aten
23
+
24
+ # First pass_patterns[0] are applied, then [1], then [2]
25
+ pass_patterns = [
26
+ PatternMatcherPass(),
27
+ PatternMatcherPass(),
28
+ PatternMatcherPass(),
29
+ ]
30
+
31
+ binary_folding_pass = PatternMatcherPass()
32
+
33
+
34
def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
    """
    Passes that are applied to the graph to freeze pass.
    """

    from ..freezing import constant_fold

    lazy_init()
    # Binary folding can expose new folding opportunities (conv+binary+binary
    # chains), so iterate a bounded number of rounds; a better way to pick the
    # round count is still an open question.
    prev_folding_count = counters["inductor"]["binary_folding"]
    fake_tensor_prop(gm, aot_example_inputs, True)

    torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_convs(gm)
    for _ in range(4):
        constant_fold(gm)
        # Make sure meta['val'] is properly set for all nodes
        fake_tensor_prop(gm, aot_example_inputs, True)
        binary_folding_pass.apply(gm.graph)  # type: ignore[arg-type]
        # Stop once an iteration folds nothing new.
        # TODO: remove the need to run fake_tensor_prop on the whole model.
        if counters["inductor"]["binary_folding"] == prev_folding_count:
            break
        prev_folding_count = counters["inductor"]["binary_folding"]

    torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_convs(gm)

    constant_fold(gm)
    fake_tensor_prop(gm, aot_example_inputs, True)

    # Staged pattern passes: pass_patterns[0], then [1], then [2].
    for pattern_pass in pass_patterns:
        pattern_pass.apply(gm.graph)  # type: ignore[arg-type]

    # The CPU weight packing always assume the conv's weight is channels last,
    # So make sure the layout_optimization is on when doing it.
    if (
        torch._C._has_mkldnn
        and config.cpp.weight_prepack
        and config.layout_optimization
    ):
        from .mkldnn_fusion import _eliminate_duplicate_packed_nodes

        _eliminate_duplicate_packed_nodes(gm)

    stable_topological_sort(gm.graph)
    gm.recompile()
    gm.graph.lint()
82
+
83
+
84
@init_once_fakemode
def lazy_init():
    """One-time (fake-mode) registration of freezing-related pattern passes."""
    # Weight prepacking is only applicable when torch was built with mkldnn
    # and the config enables it.
    if torch._C._has_mkldnn and config.cpp.weight_prepack:
        from .mkldnn_fusion import _mkldnn_weight_pack_init

        _mkldnn_weight_pack_init()

    from .binary_folding import binary_folding_init

    addmm_patterns_init()
    binary_folding_init()
95
+
96
+
97
def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
    """Register ``pattern`` with one of the staged freezing pattern passes."""
    target_pass = pass_patterns[pass_number]
    return register_graph_pattern(
        pattern,
        extra_check=extra_check,
        pass_dict=target_pass,
    )


def register_binary_folding_pattern(pattern, extra_check=_return_true):
    """Register ``pattern`` with the dedicated binary-folding pass."""
    return register_graph_pattern(
        pattern,
        extra_check=extra_check,
        pass_dict=binary_folding_pass,
    )
111
+
112
+
113
@functools.lru_cache(None)
def addmm_patterns_init():
    """Register replacements that fuse parallel matmuls/addmms sharing one
    left-hand input into a single GEMM over concatenated weights."""
    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"
    # Dummy tensors used only for the initial pattern trace; real shapes are
    # re-checked by `check_concat_weights` when a match is found.
    val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False)

    def check_concat_weights(match):
        # All candidate weights must be graph constants (get_attr) with the
        # same shape so that concatenation along dim=1 is valid.
        weights = [
            match.kwargs["w1"],
            match.kwargs["w2"],
        ]
        if "w3" in match.kwargs:
            weights.append(match.kwargs["w3"])

        return all(
            w.op == "get_attr" and w.meta["val"].shape == weights[0].meta["val"].shape
            for w in weights
        )

    def matmul_fuse_pattern(inp, w1, w2, w3):
        # Three matmuls that share the same left operand.
        return (inp @ w1, inp @ w2, inp @ w3)

    def matmul_replacement(inp, w1, w2, w3):
        # One GEMM against concatenated weights, split back into three results.
        cat_t = torch.cat((w1, w2, w3), dim=1)
        mm = inp @ cat_t
        return mm.chunk(3, dim=1)

    register_replacement(
        matmul_fuse_pattern,
        matmul_replacement,
        [val(), val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3"),
    )

    def matmul_fuse_pattern_two(inp, w1, w2):
        # Two-matmul variant of the fusion above.
        return (inp @ w1, inp @ w2)

    def matmul_replacement_two(inp, w1, w2):
        cat_t = torch.cat((w1, w2), dim=1)
        mm = inp @ cat_t
        return mm.chunk(2, dim=1)

    register_replacement(
        matmul_fuse_pattern_two,
        matmul_replacement_two,
        [val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2"),
    )

    def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3):
        # Three addmms sharing the same input; biases fused as well.
        return (
            aten.addmm(b1, inp, w1),
            aten.addmm(b2, inp, w2),
            aten.addmm(b3, inp, w3),
        )

    def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3):
        cat_w = torch.cat((w1, w2, w3), dim=1)
        cat_b = torch.cat((b1, b2, b3))
        return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1)

    register_replacement(
        addmm_fuse_pattern_second,
        addmm_fuse_replacement_second,
        [val() for _ in range(7)],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"),
    )
192
+
193
+
194
def same_dtype(match):
    """True when the convert_element_type target dtype equals the input's dtype."""
    converted_input = match.output_node().args[0]
    return converted_input.meta["val"].dtype == match.kwargs["dtype"]


@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        Ignored(),
        KeywordArg("dtype"),
    ),
    pass_dict=pass_patterns[0],
    extra_check=same_dtype,
)
def unnecessary_dtype_convert(match: Match, **kwargs):
    """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
    graph = match.graph
    conversion_node = match.output_node()
    # The conversion is an identity here, so rewire users to its input.
    conversion_node.replace_all_uses_with(conversion_node.args[0])
    graph.erase_node(conversion_node)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import logging
3
+ import os
4
+ import random
5
+ import traceback
6
+
7
+ import numpy
8
+
9
+ import torch
10
+ import torch.optim as optim
11
+
12
+ from .. import config
13
+
14
logger: logging.Logger = logging.getLogger(__name__)

MAIN_RANDOM_SEED = 1337

# Set the CUBLAS_WORKSPACE_CONFIG environment variable
# (required by torch.use_deterministic_algorithms for deterministic cuBLAS).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


# If the two forward functions involve any non-deterministic operations,
# such as certain types of parallelism or asynchronous execution,
# this can also lead to different outputs.
def set_deterministic() -> None:
    """Make torch manual seed deterministic."""

    torch.manual_seed(MAIN_RANDOM_SEED)
    random.seed(MAIN_RANDOM_SEED)
    numpy.random.seed(MAIN_RANDOM_SEED)
    torch.use_deterministic_algorithms(True)


def clean_memory() -> None:
    """Clean memory to avoid OOM."""
    gc.collect()
    torch.cuda.empty_cache()


# We compare the numerical results before and after pre/post grad fx passes
# transformation to make sure the numerical results are the same.
def compare_dict_tensors(dict_base, dict_control, precision):
    """Return True iff both dicts have matching keys and allclose tensor values.

    ``None`` entries are skipped (not every param has a valid ``.grad``).
    Bug fix: a key present in ``dict_base`` but missing from ``dict_control``
    previously fell through after the warning and raised ``KeyError`` on
    ``dict_control[key]``; it now records the mismatch and continues.
    """
    if len(set(dict_base.keys())) != len(set(dict_control.keys())):
        logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
        logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
        logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
        return False
    is_allclose = True
    for key in dict_base.keys():
        if key not in dict_control:
            logger.warning(
                "Mismatch parameter name %s does not exist after pre/post grad fx passes",
                key,
            )
            # Skip the missing key instead of indexing dict_control below
            # (which used to raise KeyError); count it as a mismatch.
            is_allclose = False
            continue
        # Some parameters have `None`, and not every param has a valid .grad field, we skip them
        if dict_base[key] is None or dict_control[key] is None:
            continue
        if not torch.allclose(
            dict_base[key],
            dict_control[key],
            rtol=precision,
            atol=precision,
            equal_nan=True,
        ):
            logger.warning(
                "Mismatch parameter values found before and after pre/post grad fx passes."
            )
            logger.debug("value before pre/post grad fx passes %s", dict_base[key])
            logger.debug("value after pre/post grad fx passes %s", dict_control[key])
            is_allclose = False
    return is_allclose
72
+
73
+
74
def compare_tuple_tensors(tuple_base, tuple_control, precision):
    """Compare two sequences of forward outputs element-wise with allclose."""
    if len(tuple_base) != len(tuple_control):
        logger.warning(
            "Mismatch fw output length. before transformation: %s, after transformation: %s",
            len(tuple_base),
            len(tuple_control),
        )
        return False
    is_allclose = True
    for base_out, control_out in zip(tuple_base, tuple_control):
        # Some parameters have `None`, we skip them
        if base_out is None or control_out is None:
            continue
        matches = torch.allclose(
            base_out,
            control_out,
            rtol=precision,
            atol=precision,
            equal_nan=True,
        )
        if not matches:
            logger.debug(
                "forward output before pre/post grad fx passes %s", base_out
            )
            logger.debug(
                "forward output after pre/post grad fx passes %s", control_out
            )
            is_allclose = False
    return is_allclose
102
+
103
+
104
def compare_parameters(model_base, model_control, precision):
    """Compare all named parameters of the two models."""
    params_base = dict(model_base.named_parameters())
    params_control = dict(model_control.named_parameters())
    return compare_dict_tensors(params_base, params_control, precision)


def compare_forward_output(pred_base, pred_control, precision):
    """Compare the forward outputs of the two models element-wise."""
    return compare_tuple_tensors(pred_base, pred_control, precision)


def compare_gradients(model_base, model_control, precision):
    """Compare parameter gradients of the two models by name."""
    grad_base = {name: param.grad for name, param in model_base.named_parameters()}
    grad_pt2 = {name: param.grad for name, param in model_control.named_parameters()}
    return compare_dict_tensors(grad_base, grad_pt2, precision)
128
+
129
+
130
def run_model(
    model_base, model_control, model_input, num_iterations=10, precision=1e-4
):
    """Run both models on the same input for several iterations, comparing
    parameters, forward outputs and (best-effort) gradients each time.

    Results are only logged; this function returns nothing.
    """
    clean_memory()
    for i in range(num_iterations):
        logger.info("start %s iteration", i)
        # Re-seed before each forward so both models see identical randomness.
        set_deterministic()
        pred_base = model_base(*model_input)
        set_deterministic()
        pred_control = model_control(*model_input)

        res = compare_parameters(model_base, model_control, precision)
        logger.info("compare parameters. Numerical result : %s", res)

        res = compare_forward_output(pred_base, pred_control, precision)
        logger.info("compare loss/predict. Numerical result : %s", res)
        # tensor may not have a grad_fn
        try:
            _ = pred_base[0].sum().backward(retain_graph=True)
            _ = pred_control[0].sum().backward(retain_graph=True)
            res = compare_gradients(model_base, model_control, precision)
            logger.info("compare param grad. Numerical result : %s", res)
        except Exception as e:
            logger.exception("Exception %s when compare gradients", e)
            traceback.print_exc()

        if config.fx_passes_numeric_check["requires_optimizer"]:
            try:
                # Take one SGD step on each model and re-compare parameters to
                # check that updates stay numerically aligned.
                optimizer_base = optim.SGD(
                    [param for name, param in model_base.named_parameters()], lr=0.01
                )
                optimizer_base.step()

                optimizer_control = optim.SGD(
                    [param for name, param in model_control.named_parameters()], lr=0.01
                )
                optimizer_control.step()

                res = compare_parameters(model_base, model_control, precision)
                logger.info(
                    "compare parameters with optimizer added. Numerical result : %s",
                    res,
                )
            except Exception as e:
                logger.exception(
                    "Exception %s when optimizer is added to check parameter names", e
                )
                traceback.print_exc()
        else:
            logger.warning(
                "no parameter with optimizer to compare with length %s before transformation"
                " and the length %s after transformation",
                len(dict(model_base.named_parameters())),
                len(dict(model_control.named_parameters())),
            )
185
+
186
+
187
def numeric_check_if_enabled(
    gm_before_fx_passes,
    gm_after_fx_passes,
    example_inputs,
    num_iterations,
    precision,
):
    """Entry point for the runtime numeric check; catches all failures so the
    check can never block the actual model run."""
    # need to topo-sort graphmodule before we run the model,
    # otherwise it may fail as refer before def
    # fail silently in order not to block the model run
    try:
        # Anomaly detection yields better stack traces for backward errors.
        with torch.autograd.set_detect_anomaly(True):
            run_model(
                gm_before_fx_passes,
                gm_after_fx_passes,
                example_inputs,
                num_iterations=num_iterations,
                precision=precision,
            )
    except Exception as e:
        logger.warning(
            "Runtime numeric check failed in pre grad fx passes with error: %s", e
        )
        traceback.print_exc()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import List, Optional, Set, Union
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch._inductor import utils
7
+ from torch._subclasses.fake_tensor import FakeTensor
8
+ from torch.utils._mode_utils import no_dispatch
9
+ from torch.utils._triton import has_triton
10
+
11
+ from ..pattern_matcher import (
12
+ fwd_only,
13
+ joint_fwd_bwd,
14
+ Match,
15
+ MatchContext,
16
+ register_replacement,
17
+ )
18
+ from ..utils import is_view
19
+
20
+ aten = torch.ops.aten
21
+
22
+
23
+ # This flag is only used for testing purpose.
24
+ # Changing it to True will ignore comparing do_bench times
25
+ # between original pattern and padded one.
26
+ _skip_do_bench_times = False
27
+
28
+
29
def fetch_fake_tensors(match, kwarg_names) -> List[Tensor]:
    """Pull the fake tensors stored under meta['val'] for the given kwarg names."""
    return [match.kwargs[name].meta["val"] for name in kwarg_names]


def unwrap_fake_args(*arg_names):
    """Decorator factory: call ``func`` with the match's fake tensors instead of
    the Match object itself."""

    def decorator(func):
        def wrapper(match):
            return func(*fetch_fake_tensors(match, arg_names))

        return wrapper

    return decorator
43
+
44
+
45
def get_alignment_size(x: Tensor) -> int:
    """Alignment (in elements) to pad toward for x's dtype; 0 disables padding."""
    if x.dtype in (torch.float16, torch.half, torch.bfloat16):
        return 8
    if x.dtype in (torch.float32, torch.float):
        return 4
    return 0


def check_device(a: Tensor, b: Tensor) -> bool:
    """Padding only applies when both operands live on CUDA."""
    return a.is_cuda and b.is_cuda


def check_dtype(a: Tensor, b: Tensor) -> bool:
    """Padding only applies to floating point matmuls."""
    return a.is_floating_point() and b.is_floating_point()
60
+
61
+
62
def _result_layout_affects_graph_output(match: Match) -> bool:
    """
    Check if the matched GEMM operation potentially affects the graph output strides.
    returns True if the matched op's output buffer does not pass through functions which certainly
    redefine the memory layout before being part of the graph output.
    """

    if match.ctx is not None:
        assert isinstance(match.ctx, MatchContext)
        search_node: torch.fx.Node = match.output_node()
    else:
        # No context to search from: conservatively assume the layout matters.
        return True

    assert search_node is not None
    seen: Set[torch.fx.Node] = set()

    def find_output(node: torch.fx.Node, is_start_node=False):
        # DFS over users: return True iff the output node is reachable through
        # view-like ops only (which preserve the underlying layout).
        if not isinstance(node, torch.fx.Node):
            return False
        if node in seen:
            return False
        seen.add(node)
        if node.op == "output":
            return True
        if node.op != "call_function":
            return False
        # Past the start node, only follow view-like OpOverloads; anything
        # else certainly redefines the memory layout.
        if not is_start_node and (
            (not isinstance(node.target, torch._ops.OpOverload))
            or (not is_view(node.target))
        ):
            return False
        if node.users is not None and len(node.users) > 0:
            for n in node.users:
                if find_output(n):
                    return True
        return False

    return find_output(search_node, True)
100
+
101
+
102
def should_pad_common(
    mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
) -> bool:
    """Cheap preconditions for shape padding: config enabled, CUDA float
    operands, and hint-backed (non-fully-symbolic) shapes/strides."""

    # It's fine we have symbolic shapes or strides as long as they
    # have hints. Later, we will make sure we only pad non-symbolic dimensions.
    def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
        if t is None:
            return True

        symbolic_cnt = 0
        for dim in t.size():
            if isinstance(dim, int):
                continue
            if utils.is_symbolic(dim):
                if not dim.node.has_hint():
                    return False
                symbolic_cnt += 1
            else:
                return False
        # filter out cases where all dimentions are symbolic
        if symbolic_cnt == len(t.size()):
            return False
        return all(
            isinstance(s, int) or (utils.is_symbolic(s) and s.node.has_hint())
            for s in t.stride()
        )

    return (
        torch._inductor.config.shape_padding
        and check_device(mat1, mat2)
        and check_dtype(mat1, mat2)
        and all(valid_shape_and_stride(t) for t in (mat1, mat2, input))
    )
135
+
136
+
137
def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int:
    """Elements to add so ``x`` becomes a multiple of ``alignment_size``.

    Symbolic sizes, disabled alignment (0), and already-aligned sizes all
    return 0 (no padding).
    """
    if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0:
        return 0
    remainder = x % alignment_size
    return alignment_size - remainder


def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor:
    """Append ``padded_length`` zeros to ``x`` along ``dim`` (no-op for 0)."""
    if padded_length == 0:
        return x
    zeros = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
    return torch.cat([x, zeros], dim=dim)
149
+
150
+
151
def addmm_pattern(
    input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float
) -> Tensor:
    """Pattern matched against the graph: a plain aten.addmm call."""
    return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)


def should_pad_addmm(match: Match) -> bool:
    """Decide whether this addmm match is worth padding (benchmark-driven)."""
    # Padding can change output strides; respect keep_output_stride.
    if (
        torch._inductor.config.keep_output_stride
        and _result_layout_affects_graph_output(match)
    ):
        return False
    mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input"))
    if not should_pad_common(mat1, mat2, input):
        return False
    return should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input)


def addmm_replace(
    input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
) -> Tensor:
    """Replacement for addmm: pad operands when any dim needs alignment."""
    pad_m = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
    pad_k = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    pad_n = get_padded_length(mat2.shape[1], get_alignment_size(mat2))

    if pad_m == 0 and pad_k == 0 and pad_n == 0:
        return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
    return pad_addmm(input, mat1, mat2, pad_m, pad_k, pad_n, beta, alpha)


def pad_addmm(
    input: Optional[Tensor],
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    beta=1.0,
    alpha=1.0,
):
    """Pad one dimension per call (k first, then n, then m), recursing through
    addmm_replace until no padding is left, then slice the result back."""
    if k_padded_length != 0:
        # Zero-padding k leaves the product unchanged, so no slicing is needed.
        mat1 = pad_dim(mat1, k_padded_length, 1)
        mat2 = pad_dim(mat2, k_padded_length, 0)
    elif n_padded_length != 0:
        mat2 = pad_dim(mat2, n_padded_length, 1)
    elif m_padded_length != 0:
        mat1 = pad_dim(mat1, m_padded_length, 0)

    # the add broadcasts, so we only pad if the dimension != 1
    if input is not None and k_padded_length == 0:
        if n_padded_length != 0:
            if input.dim() == 2 and input.shape[1] != 1:
                input = pad_dim(input, n_padded_length, 1)
            elif input.dim() == 1 and input.shape[0] != 1:
                input = pad_dim(input, n_padded_length, 0)
        elif m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1:
            input = pad_dim(input, m_padded_length, 0)

    if k_padded_length != 0:
        return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)
    if n_padded_length != 0:
        return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)[
            :, :-n_padded_length
        ]
    return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)[
        :-m_padded_length, :
    ]
230
+
231
+
232
def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
    """Heuristic: is an (M,K)x(K,N) GEMM compute-bound on this device?"""
    memory_terms = M * K + N * K + M * N
    if memory_terms == 0:
        return False
    arithmetic_intensity = (M * N * K) / memory_terms

    # Fails with AMD
    try:
        machine_balance = (
            1000 * utils.get_device_tflops(dtype)
        ) / utils.get_gpu_dram_gbps()
    except Exception:
        # Device introspection unavailable: assume compute-bound.
        return True

    # dram_gbps might be underestimating bandwidth because of cache.
    # if we estimate machine balance too low we might miss some speedups,
    # if we extimate too high there will be unnecessary compilation time increase.
    # TODO - finetune coefficient here. As a reference point, Triton mm model assumes
    # 80% of reads are in cache and cache is 4x faster than dram_gbps
    machine_balance = machine_balance * 0.5

    return arithmetic_intensity > machine_balance
254
+
255
+
256
@functools.lru_cache(None)
def get_pad_cache():
    """Process-wide cache (file-backed) memoizing padding decisions."""
    return torch._inductor.codecache.LocalCache()


def get_cached_should_pad(key):
    """Look up a previously benchmarked padding decision, or None."""
    return get_pad_cache().lookup(key)


def set_cached_should_pad(key, value):
    """Persist a benchmarked padding decision under ``key``."""
    return get_pad_cache().set_value(key, value=value)


def should_pad_bench_key(
    mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> str:
    """Build a string cache key from operand metadata, op, and the tf32 flag."""

    def tensor_key(t):
        return (t.shape, t.stride(), t.dtype)

    # tf32 affects fp32 matmul codegen/timing, so it is part of the key.
    tf32_key = (
        None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32
    )
    return str(
        (
            tensor_key(mat1),
            tensor_key(mat2),
            op,
            input if input is None else tensor_key(input),
            tf32_key,
        )
    )
287
+
288
+
289
def should_pad_bench(
    mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> bool:
    """Benchmark the original op against its padded variant on real tensors
    and return True when padding is a clear (>1.1x) win; results are cached."""
    if not has_triton():
        return False

    do_bench = functools.partial(
        utils.do_bench,
        warmup=5,
    )

    with no_dispatch():
        # Extract (m, k, n) and per-dimension padding for the supported ops.
        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
            m = mat1.shape[0]
            k = mat1.shape[1]
            n = mat2.shape[1]

            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
        elif op is torch.ops.aten.bmm:
            m = mat1.shape[1]
            k = mat1.shape[2]
            n = mat2.shape[2]

            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
        else:
            return False

        if m_padded_length == k_padded_length == n_padded_length == 0:
            return False

        if not is_mm_compute_bound(m, k, n, mat1.dtype):
            return False

        # We don't want to look up the cache for cases that are trivially false
        # since it does file io
        key = should_pad_bench_key(mat1, mat2, op, input)

        cached_pad = get_cached_should_pad(key)
        if cached_pad is not None:
            return cached_pad

        def realize_symbols(ds):
            # Replace SymInts with their hints so we can build real tensors.
            return [d if isinstance(d, int) else d.node.hint for d in ds]

        def realize_tensor(t):
            # Materialize a real tensor matching the fake tensor's size/stride.
            if isinstance(t, FakeTensor):
                size_hints = realize_symbols(t.size())
                stride_hint = realize_symbols(t.stride())
                real_size = (
                    sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1
                )
                real_t = torch.randn(real_size, dtype=t.dtype, device=t.device)
                return torch.as_strided(real_t, size_hints, stride_hint)
            else:
                return torch.randn_like(t)

        mat1 = realize_tensor(mat1)
        mat2 = realize_tensor(mat2)
        # Time the unpadded op as the baseline.
        if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
            ori_time = do_bench(
                lambda: op(mat1, mat2),
            )
        else:
            if input is not None:
                input = realize_tensor(input)
            ori_time = do_bench(
                lambda: op(input, mat1, mat2),
            )

        mat1_pad = torch.randn_like(mat1)
        mat2_pad = torch.randn_like(mat2)

        # Time the padded variant corresponding to the op.
        if op is torch.ops.aten.addmm:
            input_pad = None
            if input is not None and input.is_cuda:
                input_pad = torch.randn_like(input)
            pad_time = do_bench(
                lambda: pad_addmm(
                    input_pad,
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                ),
            )
        elif op is torch.ops.aten.mm:
            pad_time = do_bench(
                lambda: pad_mm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                ),
            )
        else:
            pad_time = do_bench(
                lambda: pad_bmm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                ),
            )

        # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable
        # tradeoff between performance improvement from shape padding and overhead from additional memory ops
        # TODO: Build a learned model which would be better than this heuristic
        should_pad = _skip_do_bench_times or ori_time > pad_time * 1.1
        set_cached_should_pad(key, should_pad)

        return should_pad
407
+
408
+
409
def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Pattern matched against the graph: a plain aten.mm call."""
    return aten.mm(mat1, mat2)


def should_pad_mm(match: Match) -> bool:
    """Decide whether this mm match is worth padding (benchmark-driven)."""
    # Padding can change output strides; respect keep_output_stride.
    if (
        torch._inductor.config.keep_output_stride
        and _result_layout_affects_graph_output(match)
    ):
        return False
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    return should_pad_common(mat1, mat2) and should_pad_bench(
        mat1, mat2, torch.ops.aten.mm
    )


def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Replacement for mm: pad operands when any dimension needs alignment.

    Bug fix: previously pad_mm was called unconditionally; with all padded
    lengths 0, pad_mm's m-branch sliced with ``[:-0, :]`` which yields an
    EMPTY tensor. Guard like the sibling addmm_replace/bmm_replace and fall
    back to a plain mm when no padding is needed.
    """
    m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
    k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))

    if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
        return pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length)

    return aten.mm(mat1, mat2)


def pad_mm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
) -> Tensor:
    """Pad exactly one dimension (k first, then n, then m) and slice back.

    mm_replace will go through pad_mm multiple times if multiple dimensions
    are needed to be padded.
    """
    if k_padded_length != 0:
        # Zero-padding k leaves the product unchanged: no slicing needed.
        mat1 = pad_dim(mat1, k_padded_length, 1)
        mat2 = pad_dim(mat2, k_padded_length, 0)
        return torch.ops.aten.mm(mat1, mat2)
    elif n_padded_length != 0:
        mat2 = pad_dim(mat2, n_padded_length, 1)
        return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
    else:
        mat1 = pad_dim(mat1, m_padded_length, 0)
        return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
451
+
452
+
453
def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Pattern matched against the graph: a plain aten.bmm call."""
    return aten.bmm(mat1, mat2)


def should_pad_bmm(match: Match) -> bool:
    """Decide whether this bmm match is worth padding (benchmark-driven)."""
    # Padding can change output strides; respect keep_output_stride.
    if (
        torch._inductor.config.keep_output_stride
        and _result_layout_affects_graph_output(match)
    ):
        return False
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    if not should_pad_common(mat1, mat2):
        return False
    return should_pad_bench(mat1, mat2, torch.ops.aten.bmm)


def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Replacement for bmm: pad batched operands when any dim needs alignment."""
    pad_m = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    pad_k = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
    pad_n = get_padded_length(mat2.shape[2], get_alignment_size(mat2))

    if pad_m == 0 and pad_k == 0 and pad_n == 0:
        return aten.bmm(mat1, mat2)
    return pad_bmm(mat1, mat2, pad_m, pad_k, pad_n)


def pad_bmm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
) -> Tensor:
    """Pad one dimension per call (k first, then n, then m); repeated pattern
    application handles any remaining dimensions."""
    if k_padded_length != 0:
        # Zero-padding k leaves the product unchanged: no slicing needed.
        mat1 = pad_dim(mat1, k_padded_length, 2)
        mat2 = pad_dim(mat2, k_padded_length, 1)
        return aten.bmm(mat1, mat2)
    if n_padded_length != 0:
        mat2 = pad_dim(mat2, n_padded_length, 2)
        return aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
    mat1 = pad_dim(mat1, m_padded_length, 1)
    return aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()
499
+
500
+
501
@functools.lru_cache(None)
def _pad_mm_init():
    """Register mm/bmm/addmm shape-padding replacement patterns (runs once)."""
    from .joint_graph import patterns

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # sizes/values dont actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds

    dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)
    dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)

    dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)
    dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)

    dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True)

    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    rep = {"beta": 0.213377, "alpha": 0.113377}

    for pattern, replacement, args, workaround, extra_check in [
        (
            mm_pattern,
            mm_replace,
            [dim2a(), dim2b()],
            {},
            should_pad_mm,
        ),
        (
            bmm_pattern,
            bmm_replace,
            [dim3a(), dim3b()],
            {},
            should_pad_bmm,
        ),
        (
            addmm_pattern,
            addmm_replace,
            [dim1a(), dim2a(), dim2b()],
            rep,
            should_pad_addmm,
        ),
    ]:
        assert isinstance(workaround, dict)  # mypy is unable to infer the type properly
        # Register each pattern twice: for the joint fwd+bwd (training) trace
        # and for the fwd-only (inference) trace.
        register_replacement(
            pattern,
            replacement,
            args,
            joint_fwd_bwd,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
        )
        register_replacement(
            pattern,
            replacement,
            args,
            fwd_only,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
        )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py ADDED
@@ -0,0 +1,1980 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import functools
3
+ import itertools
4
+ import math
5
+ import operator
6
+ from typing import Any, Tuple
7
+
8
+ import torch
9
+ from torch._dynamo.utils import counters
10
+ from torch.fx.experimental.symbolic_shapes import has_free_symbols
11
+ from ..lowering import lowerings as L, require_channels_last
12
+ from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
13
+ from ..utils import pad_listlike
14
+ from .freezing_patterns import register_freezing_graph_pattern
15
+ from .post_grad import register_lowering_pattern
16
+
17
+ aten = torch.ops.aten
18
+ prims = torch.ops.prims
19
+ quantized_decomposed = torch.ops.quantized_decomposed
20
+ quantized = torch.ops.quantized
21
+
22
+ """
23
+ The quantization.py file primarily incorporates passes related to quantization fusion
24
+ in inductor, includes:
25
+ 1. Dequant Promotion;
26
+ 2. Conv/GEMM weight prepack with oneDNN Library;
27
+ 3. Conv/GEMM quantization fusion with output quant node (if have);
28
+ 4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more;
29
+
30
+ It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference
31
+ of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is
32
+ 1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM.
33
+ 2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node.
34
+ Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16
35
+ quantization.
36
+ """
37
+
38
+
39
+ def _may_generate_pattern_with_dtype_convert(pattern, dtype=Arg(), dtype_convert=True):
40
+ if dtype_convert:
41
+ return CallFunction(
42
+ prims.convert_element_type.default,
43
+ pattern,
44
+ dtype,
45
+ )
46
+ else:
47
+ return pattern
48
+
49
+
50
+ def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True):
51
+ if with_reshape:
52
+ return CallFunction(
53
+ torch.ops.aten.reshape.default,
54
+ pattern,
55
+ reshape_size,
56
+ )
57
+ else:
58
+ return pattern
59
+
60
+
61
+ def _generate_linear_t_pattern(
62
+ _dequant_per_channel_pattern,
63
+ dtype,
64
+ ):
65
+ assert dtype in [torch.float32, torch.bfloat16]
66
+ t_pattern = CallFunction(
67
+ aten.permute.default,
68
+ _may_generate_pattern_with_dtype_convert(
69
+ _dequant_per_channel_pattern,
70
+ KeywordArg("autocast_wgt_dtype"),
71
+ dtype == torch.bfloat16,
72
+ ),
73
+ KeywordArg("permute_axes"),
74
+ )
75
+ return t_pattern
76
+
77
+
78
+ """
79
+ dequantize activation:
80
+ x = x.to(fp32)
81
+ x = x - zero_point
82
+ x = x * scale
83
+ """
84
+ dequantize_per_tensor_activation_pattern = CallFunction(
85
+ aten.mul.Tensor,
86
+ CallFunction(
87
+ aten.sub.Tensor,
88
+ CallFunction(
89
+ prims.convert_element_type.default,
90
+ KeywordArg("x"),
91
+ KeywordArg("x_dq_dtype"),
92
+ ),
93
+ KeywordArg("x_zp"),
94
+ ),
95
+ KeywordArg("x_scale"),
96
+ )
97
+
98
+ dequantize_per_channel_weight_pattern = CallFunction(
99
+ quantized_decomposed.dequantize_per_channel.default,
100
+ KeywordArg("q_weight"),
101
+ KeywordArg("w_scale"),
102
+ KeywordArg("w_zp"),
103
+ KeywordArg("w_axis"),
104
+ KeywordArg("w_quant_min"),
105
+ KeywordArg("w_quant_max"),
106
+ KeywordArg("w_dtype"),
107
+ )
108
+
109
+ dequantize_per_channel_to_bf16_weight_pattern = (
110
+ _may_generate_pattern_with_dtype_convert(
111
+ dequantize_per_channel_weight_pattern,
112
+ KeywordArg("autocast_wgt_dtype"),
113
+ )
114
+ )
115
+
116
+ dequantize_per_channel_clone_weight_pattern = CallFunction(
117
+ aten.clone.default,
118
+ dequantize_per_channel_weight_pattern,
119
+ memory_format=KeywordArg("memory_format"),
120
+ )
121
+
122
+ dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
123
+ aten.clone.default,
124
+ dequantize_per_channel_to_bf16_weight_pattern,
125
+ memory_format=KeywordArg("memory_format"),
126
+ )
127
+
128
+
129
+ def get_dequantize_qconv_pt2e_pattern(users=1):
130
+ return CallFunction(
131
+ torch.ops.onednn.qconv2d_pointwise.default,
132
+ KeywordArg("x"),
133
+ KeywordArg("x_scale"), # x_scale
134
+ KeywordArg("x_zp"), # x_zp
135
+ KeywordArg("packed_weight"), # packed_weight
136
+ KeywordArg("w_scale"), # w_scale
137
+ KeywordArg("w_zp"), # w_zp
138
+ KeywordArg("b"), # bias
139
+ KeywordArg("stride"),
140
+ KeywordArg("padding"),
141
+ KeywordArg("dilation"),
142
+ KeywordArg("groups"),
143
+ KeywordArg("inv_output_scale"), # inv_output_scale = 1.0
144
+ KeywordArg("output_zero_point"), # output_zero_point = 0
145
+ KeywordArg("output_dtype"), # output_dtype = None
146
+ KeywordArg("attr"), # attr = "none"
147
+ Arg(), # scalars
148
+ Arg(), # algorithm
149
+ _users=users,
150
+ )
151
+
152
+
153
+ def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors):
154
+ qlinear_op = (
155
+ torch.ops.onednn.qlinear_pointwise.tensor
156
+ if x_scale_zp_are_tensors
157
+ else torch.ops.onednn.qlinear_pointwise.default
158
+ )
159
+ return CallFunction(
160
+ qlinear_op,
161
+ KeywordArg("x"),
162
+ KeywordArg("x_scale"),
163
+ KeywordArg("x_zp"),
164
+ KeywordArg("packed_weight"),
165
+ KeywordArg("w_scale"),
166
+ KeywordArg("w_zp"),
167
+ KeywordArg("b"),
168
+ KeywordArg("output_scale"),
169
+ KeywordArg("output_zero_point"),
170
+ KeywordArg("output_dtype"),
171
+ KeywordArg("postop_name"),
172
+ KeywordArg("postop_args"),
173
+ KeywordArg("postop_algorithm"),
174
+ )
175
+
176
+
177
+ dequantize_accum_pattern = CallFunction(
178
+ aten.mul.Tensor,
179
+ CallFunction(
180
+ aten.sub.Tensor,
181
+ CallFunction(
182
+ prims.convert_element_type.default,
183
+ KeywordArg("accum"),
184
+ KeywordArg("accum_dq_dtype"),
185
+ ),
186
+ KeywordArg("accum_zp"),
187
+ ),
188
+ KeywordArg("accum_scale"),
189
+ )
190
+
191
+
192
+ def generate_pattern_with_binary(
193
+ binary_post_op,
194
+ computation_call,
195
+ extra_input_pattern,
196
+ int8_mixed_bf16_with_inplace_add=False,
197
+ ):
198
+ binary_pattern = CallFunction(
199
+ binary_post_op,
200
+ computation_call,
201
+ extra_input_pattern,
202
+ )
203
+ return _may_generate_pattern_with_dtype_convert(
204
+ binary_pattern,
205
+ KeywordArg("convert_dtype_after_inplace_add"),
206
+ int8_mixed_bf16_with_inplace_add,
207
+ )
208
+
209
+
210
+ def generate_pattern_with_unary(computation_call, unary_post_op):
211
+ if unary_post_op is not None:
212
+ if unary_post_op == aten.hardtanh.default:
213
+ return CallFunction(
214
+ aten.clamp_max,
215
+ CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")),
216
+ KeywordArg("max_value"),
217
+ )
218
+ if unary_post_op == aten.hardswish.default:
219
+ return CallFunction(
220
+ aten.div,
221
+ CallFunction(
222
+ aten.mul,
223
+ computation_call,
224
+ CallFunction(
225
+ aten.clamp_max,
226
+ CallFunction(
227
+ aten.clamp_min,
228
+ CallFunction(aten.add, computation_call, 3),
229
+ 0,
230
+ ),
231
+ 6,
232
+ ),
233
+ ),
234
+ 6,
235
+ )
236
+ else:
237
+ return CallFunction(
238
+ unary_post_op,
239
+ computation_call,
240
+ )
241
+ return computation_call
242
+
243
+
244
+ def generate_pattern_with_output_quant(computation_call, dtype=torch.float32):
245
+ """
246
+ quantize output:
247
+ output = round(output * o_inv_scale)
248
+ output = output + zero_point
249
+ output = clamp_min(output, 0)
250
+ output = clamp_max(output, 127)
251
+ output = output.to(uint8)
252
+ """
253
+ assert dtype in [torch.float32, torch.bfloat16]
254
+ quantized_op_output_pattern_pt2e = CallFunction(
255
+ prims.convert_element_type.default,
256
+ CallFunction(
257
+ aten.clamp_max.default,
258
+ CallFunction(
259
+ aten.clamp_min.default,
260
+ CallFunction(
261
+ aten.add.Tensor,
262
+ CallFunction(
263
+ aten.round.default,
264
+ CallFunction(
265
+ aten.mul.Tensor,
266
+ _may_generate_pattern_with_dtype_convert(
267
+ computation_call,
268
+ KeywordArg("autocast_output_quant_dtype"),
269
+ dtype == torch.bfloat16,
270
+ ),
271
+ KeywordArg("o_inv_scale"),
272
+ ),
273
+ ),
274
+ KeywordArg("o_zp"),
275
+ ),
276
+ KeywordArg("o_qmin"),
277
+ ),
278
+ KeywordArg("o_qmax"),
279
+ ),
280
+ KeywordArg("o_dtype"),
281
+ )
282
+ return quantized_op_output_pattern_pt2e
283
+
284
+
285
+ def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value):
286
+ if kwarg_name in check_node.kwargs:
287
+ actual_value = check_node.kwargs[kwarg_name]
288
+ return actual_value == expected_value
289
+ else:
290
+ assert len(check_node.args) >= (args_index + 1)
291
+ actual_value = check_node.args[args_index]
292
+ return actual_value == expected_value
293
+
294
+
295
+ def _is_valid_quantized_conv2d_optimization_pattern(output_dtype):
296
+ def fn(match):
297
+ if output_dtype is not None:
298
+ # Only keep matched pattern with same output_dtype
299
+ qconv_node_after_weight_prepack = filter_nodes(
300
+ match.nodes, torch.ops.onednn.qconv2d_pointwise
301
+ )[0]
302
+ return _check_node_kwarg_arg_value(
303
+ qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype
304
+ )
305
+ return True
306
+
307
+ return fn
308
+
309
+
310
+ def _register_quantized_conv_lowering(
311
+ pattern,
312
+ pass_number,
313
+ computation_op,
314
+ output_dtype,
315
+ unary_attr,
316
+ original_pattern_output_dtype=torch.float32,
317
+ ):
318
+ @register_lowering_pattern(
319
+ pattern,
320
+ extra_check=_is_valid_quantized_conv2d_optimization_pattern(output_dtype),
321
+ pass_number=pass_number,
322
+ )
323
+ def qconv(match: Match, *args, **kwargs):
324
+ # Activation QParams
325
+ x, x_scale, x_zp = (
326
+ kwargs["x"],
327
+ kwargs["x_scale"],
328
+ kwargs["x_zp"],
329
+ )
330
+ # Weight QParams
331
+ packed_weight, w_scale, w_zp = (
332
+ kwargs["packed_weight"],
333
+ kwargs["w_scale"],
334
+ kwargs["w_zp"],
335
+ )
336
+ # Conv Params
337
+ b, stride, padding, dilation, groups = (
338
+ kwargs["b"],
339
+ kwargs["stride"],
340
+ kwargs["padding"],
341
+ kwargs["dilation"],
342
+ kwargs["groups"],
343
+ )
344
+ assert output_dtype in [None, torch.float32, torch.bfloat16]
345
+ # Output QParams
346
+ o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
347
+ o_zero_point = kwargs["o_zp"] if output_dtype is None else 0
348
+ assert (
349
+ kwargs["output_dtype"] is original_pattern_output_dtype
350
+ ) # Expected int8-in fp32-out qconv in weight prepack phase
351
+ assert (
352
+ kwargs["attr"] == "none"
353
+ ) # Expected no post op fused in weight prepack phase
354
+ if unary_attr.op_name == "hardtanh":
355
+ min_value = kwargs.get("min_value")
356
+ max_value = kwargs.get("max_value")
357
+ unary_attr.scalars_attr = [min_value, max_value]
358
+
359
+ computation_args = (
360
+ x,
361
+ x_scale,
362
+ x_zp,
363
+ packed_weight,
364
+ w_scale,
365
+ w_zp,
366
+ b,
367
+ stride,
368
+ padding,
369
+ dilation,
370
+ groups,
371
+ o_inv_scale,
372
+ o_zero_point,
373
+ output_dtype,
374
+ unary_attr.op_name,
375
+ unary_attr.scalars_attr,
376
+ unary_attr.algorithm_attr,
377
+ )
378
+ counters["inductor"]["qconv2d_unary_matcher_count"] += 1
379
+ counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
380
+ return L[computation_op](*computation_args)
381
+
382
+ return qconv
383
+
384
+
385
+ def _is_valid_quantized_linear_optimization_pattern(output_dtype):
386
+ def fn(match):
387
+ if output_dtype is not None:
388
+ # Only keep matched pattern with same output_dtype
389
+ qlinear_node_after_weight_prepack = filter_nodes(
390
+ match.nodes, torch.ops.onednn.qlinear_pointwise
391
+ )[0]
392
+ return _check_node_kwarg_arg_value(
393
+ qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype
394
+ )
395
+ return True
396
+
397
+ return fn
398
+
399
+
400
+ def _register_quantized_linear_lowering(
401
+ pattern,
402
+ pass_number,
403
+ computation_op,
404
+ output_dtype,
405
+ unary_attr,
406
+ original_pattern_output_dtype=torch.float32,
407
+ ):
408
+ @register_lowering_pattern(
409
+ pattern,
410
+ extra_check=_is_valid_quantized_linear_optimization_pattern(output_dtype),
411
+ pass_number=pass_number,
412
+ )
413
+ def qlinear(match: Match, *args, **kwargs):
414
+ # Activation QParams
415
+ x, x_scale, x_zp = (
416
+ kwargs["x"],
417
+ kwargs["x_scale"],
418
+ kwargs["x_zp"],
419
+ )
420
+ # Weight QParams
421
+ packed_weight, w_scale, w_zp = (
422
+ kwargs["packed_weight"],
423
+ kwargs["w_scale"],
424
+ kwargs["w_zp"],
425
+ )
426
+
427
+ # bias
428
+ b = kwargs["b"] if "b" in kwargs else None
429
+
430
+ # Output QParams
431
+ o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
432
+ o_zero_point = kwargs["o_zp"] if output_dtype is None else 0
433
+ assert (
434
+ kwargs["output_dtype"] is original_pattern_output_dtype
435
+ ) # Expected int8-in fp32/bf16-out qlinear in weight prepack phase
436
+ assert (
437
+ kwargs["postop_name"] == "none"
438
+ ) # Expected no post op fused in weight prepack phase
439
+
440
+ computation_args = (
441
+ x,
442
+ x_scale,
443
+ x_zp,
444
+ packed_weight,
445
+ w_scale,
446
+ w_zp,
447
+ b,
448
+ o_inv_scale,
449
+ o_zero_point,
450
+ output_dtype,
451
+ unary_attr.op_name,
452
+ unary_attr.scalars_attr,
453
+ unary_attr.algorithm_attr,
454
+ )
455
+ counters["inductor"]["qlinear_unary_matcher_count"] += 1
456
+ counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
457
+ return L[computation_op](*computation_args)
458
+
459
+ return qlinear
460
+
461
+
462
+ def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
463
+ # Check if it's a valid Conv Binary Pattern:
464
+ # * qconv2d_pointwise should only has one users
465
+ # * Extra input of binary node comes from dequant pattern
466
+ # * the two inputs of binary node should have attribute "meta" and should be tensors
467
+ # * the two inputs of binary node should have the same shape
468
+ # * All users of the extra input in this pattern should be
469
+ # ancestor nodes of the compute node, except for the binary node
470
+ # connected to the compute node.
471
+ def fn(match):
472
+ compute_node = filter_nodes(match.nodes, torch.ops.onednn.qconv2d_pointwise)[0]
473
+ # qconv2d_pointwise should only have one user
474
+ if len(compute_node.users) != 1:
475
+ return False
476
+ binary_node_inputs = next(iter(compute_node.users)).args
477
+ assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
478
+ if output_dtype is not None:
479
+ extra_input_of_binary_node = None
480
+ for arg in binary_node_inputs:
481
+ if arg != compute_node:
482
+ extra_input_of_binary_node = arg
483
+ break
484
+ assert extra_input_of_binary_node is not None
485
+ # Extra input of binary node comes from dequant pattern
486
+ if (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or (
487
+ extra_input_of_binary_node.target != aten.mul.Tensor
488
+ ):
489
+ return False
490
+
491
+ # the two inputs of binary node should have attribute "meta" and should be tensors
492
+ if not (
493
+ hasattr(binary_node_inputs[0], "meta")
494
+ and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
495
+ ) or not (
496
+ hasattr(binary_node_inputs[1], "meta")
497
+ and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr]
498
+ ):
499
+ return False
500
+ # the two inputs of binary node should have the same shape
501
+ if (
502
+ binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr]
503
+ != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr]
504
+ ):
505
+ return False
506
+
507
+ # All users of the extra input in this pattern should be
508
+ # ancestor nodes of the compute node, except for the binary node
509
+ # connected to the compute node.
510
+
511
+ from .mkldnn_fusion import _get_remaining_users
512
+
513
+ extra_input_of_pattern = (
514
+ match.kwargs["accum"]
515
+ if output_dtype is None
516
+ else match.kwargs["accum_after_dequant"]
517
+ )
518
+ if (
519
+ len(
520
+ _get_remaining_users(
521
+ extra_input_of_pattern,
522
+ compute_node,
523
+ )
524
+ )
525
+ > 1
526
+ or extra_input_of_pattern == compute_node.args[0]
527
+ ):
528
+ return False
529
+ return True
530
+
531
+ return fn
532
+
533
+
534
+ def _register_quantized_conv_binary_lowering(
535
+ pattern,
536
+ pass_number,
537
+ computation_op,
538
+ output_dtype,
539
+ binary_unary_attr,
540
+ ):
541
+ @register_lowering_pattern(
542
+ pattern,
543
+ extra_check=_is_valid_quantized_conv_binary_optimization_pattern(output_dtype),
544
+ pass_number=pass_number,
545
+ )
546
+ def qconv_binary(match: Match, *args, **kwargs):
547
+ x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
548
+ accum = (
549
+ kwargs["accum"] if output_dtype is None else kwargs["accum_after_dequant"]
550
+ )
551
+ accum_scale = kwargs["accum_scale"] if output_dtype is None else 1.0
552
+ accum_zp = kwargs["accum_zp"] if output_dtype is None else 0
553
+ packed_weight, w_scale, w_zp = (
554
+ kwargs["packed_weight"],
555
+ kwargs["w_scale"],
556
+ kwargs["w_zp"],
557
+ )
558
+ b, stride, padding, dilation, groups = (
559
+ kwargs["b"],
560
+ kwargs["stride"],
561
+ kwargs["padding"],
562
+ kwargs["dilation"],
563
+ kwargs["groups"],
564
+ )
565
+ # Output QParams
566
+ o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
567
+ o_zero_point = kwargs["o_zp"] if output_dtype is None else 0
568
+
569
+ accum.realize()
570
+ from .mkldnn_fusion import _can_be_inplace
571
+
572
+ assert _can_be_inplace(
573
+ accum
574
+ ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation."
575
+
576
+ computation_args = (
577
+ x,
578
+ x_scale,
579
+ x_zp,
580
+ accum,
581
+ accum_scale,
582
+ accum_zp,
583
+ packed_weight,
584
+ w_scale,
585
+ w_zp,
586
+ b,
587
+ stride,
588
+ padding,
589
+ dilation,
590
+ groups,
591
+ o_inv_scale,
592
+ o_zero_point,
593
+ output_dtype,
594
+ binary_unary_attr.binary_op_name,
595
+ binary_unary_attr.alpha,
596
+ binary_unary_attr.unary_op_name,
597
+ binary_unary_attr.scalars_attr,
598
+ binary_unary_attr.algorithm_attr,
599
+ )
600
+ counters["inductor"]["qconv2d_binary_matcher_count"] += 1
601
+ counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes)
602
+ return L[computation_op](*computation_args)
603
+
604
+ return qconv_binary
605
+
606
+
607
+ def _register_quantization_unary_fusion():
608
+ class UnaryAttr:
609
+ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
610
+ self.op_name = op_name
611
+ self.scalars_attr = scalars_attr if scalars_attr else []
612
+ self.algorithm_attr = algorithm_attr if algorithm_attr else ""
613
+
614
+ for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
615
+ # QConv2d
616
+ # Priority 1 to match: QConv2d Unary pattern with int8 output
617
+ # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
618
+ # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
619
+ conv_unary_replace_patterns = {
620
+ UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
621
+ get_dequantize_qconv_pt2e_pattern(1),
622
+ dtype=original_pattern_output_dtype,
623
+ ),
624
+ UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
625
+ generate_pattern_with_unary(
626
+ get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
627
+ ),
628
+ dtype=original_pattern_output_dtype,
629
+ ),
630
+ UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
631
+ generate_pattern_with_unary(
632
+ get_dequantize_qconv_pt2e_pattern(1), aten.hardtanh.default
633
+ ),
634
+ dtype=original_pattern_output_dtype,
635
+ ),
636
+ UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
637
+ generate_pattern_with_unary(
638
+ get_dequantize_qconv_pt2e_pattern(2), aten.hardswish.default
639
+ ),
640
+ dtype=original_pattern_output_dtype,
641
+ ),
642
+ }
643
+
644
+ for unary_attr, patterns in conv_unary_replace_patterns.items():
645
+ # Register qconv2d pattern for ExternKernel Lowering
646
+ _register_quantized_conv_lowering(
647
+ patterns,
648
+ 1, # pass_number
649
+ torch.ops.onednn.qconv2d_pointwise, # computation_op
650
+ None, # output_dtype, None is the default value for int8 output
651
+ unary_attr, # unary_attr
652
+ original_pattern_output_dtype=original_pattern_output_dtype,
653
+ )
654
+
655
+ # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
656
+ conv_unary_replace_float_out_patterns = {
657
+ UnaryAttr("relu", [], ""): generate_pattern_with_unary(
658
+ get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
659
+ ),
660
+ UnaryAttr("hardtanh", [], ""): generate_pattern_with_unary(
661
+ get_dequantize_qconv_pt2e_pattern(1), aten.hardtanh.default
662
+ ),
663
+ UnaryAttr("hardswish", [], ""): generate_pattern_with_unary(
664
+ get_dequantize_qconv_pt2e_pattern(2), aten.hardswish.default
665
+ ),
666
+ }
667
+
668
+ for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
669
+ # Register qconv2d pattern for ExternKernel Lowering
670
+ _register_quantized_conv_lowering(
671
+ patterns,
672
+ 2, # pass_number
673
+ torch.ops.onednn.qconv2d_pointwise, # computation_op
674
+ original_pattern_output_dtype, # output_dtype
675
+ unary_attr, # unary_attr
676
+ original_pattern_output_dtype=original_pattern_output_dtype,
677
+ )
678
+
679
+ # QLinear
680
+ for x_scale_zp_are_tensors in (False, True):
681
+ qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
682
+ # Priority 1 to match: QLinear Unary pattern with int8 output
683
+ linear_unary_replace_patterns = {
684
+ UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
685
+ qlinear_pattern,
686
+ dtype=original_pattern_output_dtype,
687
+ ),
688
+ UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
689
+ generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
690
+ dtype=original_pattern_output_dtype,
691
+ ),
692
+ }
693
+
694
+ for unary_attr, patterns in linear_unary_replace_patterns.items():
695
+ _register_quantized_linear_lowering(
696
+ patterns,
697
+ 1, # pass_number
698
+ torch.ops.onednn.qlinear_pointwise, # computation_op
699
+ None, # output_dtype
700
+ unary_attr, # unary_attr
701
+ original_pattern_output_dtype=original_pattern_output_dtype,
702
+ )
703
+
704
+ # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
705
+ linear_unary_replace_float_out_patterns = {
706
+ UnaryAttr("relu", [], ""): generate_pattern_with_unary(
707
+ qlinear_pattern, aten.relu.default
708
+ ),
709
+ }
710
+
711
+ for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
712
+ _register_quantized_linear_lowering(
713
+ patterns,
714
+ 2, # pass_number
715
+ torch.ops.onednn.qlinear_pointwise, # computation_op
716
+ original_pattern_output_dtype, # output_dtype
717
+ unary_attr, # unary_attr
718
+ original_pattern_output_dtype=original_pattern_output_dtype,
719
+ )
720
+
721
+
722
+ def _register_quantization_binary_fusion():
723
+ class BinaryUnaryAttr:
724
+ def __init__(
725
+ self,
726
+ binary_op_name: str,
727
+ alpha=None,
728
+ unary_op_name: str = "none",
729
+ scalars_attr=None,
730
+ algorithm_attr=None,
731
+ ):
732
+ self.binary_op_name = binary_op_name
733
+ self.alpha = alpha if alpha else 1.0
734
+ self.unary_op_name = unary_op_name
735
+ self.scalars_attr = scalars_attr if scalars_attr else []
736
+ self.algorithm_attr = algorithm_attr if algorithm_attr else ""
737
+
738
+ for int8_mixed_bf16_with_inplace_add in [False, True]:
739
+ # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
740
+ binary_replace_patterns = {
741
+ BinaryUnaryAttr(
742
+ "sum", 1.0, "none", [], ""
743
+ ): generate_pattern_with_output_quant(
744
+ generate_pattern_with_binary(
745
+ aten.add.Tensor,
746
+ get_dequantize_qconv_pt2e_pattern(1),
747
+ dequantize_accum_pattern,
748
+ int8_mixed_bf16_with_inplace_add,
749
+ ),
750
+ dtype=torch.bfloat16
751
+ if int8_mixed_bf16_with_inplace_add
752
+ else torch.float32,
753
+ ),
754
+ BinaryUnaryAttr(
755
+ "sum", 1.0, "relu", [], ""
756
+ ): generate_pattern_with_output_quant(
757
+ generate_pattern_with_unary(
758
+ generate_pattern_with_binary(
759
+ aten.add.Tensor,
760
+ get_dequantize_qconv_pt2e_pattern(1),
761
+ dequantize_accum_pattern,
762
+ int8_mixed_bf16_with_inplace_add,
763
+ ),
764
+ aten.relu.default,
765
+ ),
766
+ dtype=torch.bfloat16
767
+ if int8_mixed_bf16_with_inplace_add
768
+ else torch.float32,
769
+ ),
770
+ }
771
+
772
+ for binary_unary_attr, patterns in binary_replace_patterns.items():
773
+ _register_quantized_conv_binary_lowering(
774
+ patterns,
775
+ 0, # pass_number
776
+ torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
777
+ None, # output_dtype
778
+ binary_unary_attr, # binary_unary_attr
779
+ )
780
+
781
+ # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
782
+ binary_replace_float_out_patterns = {
783
+ BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
784
+ generate_pattern_with_binary(
785
+ aten.add.Tensor,
786
+ get_dequantize_qconv_pt2e_pattern(1),
787
+ KeywordArg("accum_after_dequant"),
788
+ int8_mixed_bf16_with_inplace_add,
789
+ ),
790
+ aten.relu.default,
791
+ ),
792
+ }
793
+
794
+ for (
795
+ binary_unary_attr,
796
+ patterns,
797
+ ) in binary_replace_float_out_patterns.items():
798
+ if int8_mixed_bf16_with_inplace_add:
799
+ _register_quantized_conv_binary_lowering(
800
+ patterns,
801
+ 0, # pass_number
802
+ torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
803
+ # Note that for int8-mixed-bf16 and non-inplace add, because we have
804
+ # q-dq inserted at extra input of add, so the non-inplace add has bf16 and fp32 inputs,
805
+ # the output dtype will be float32.
806
+ # For inplace add, there is a extra to_bf16 node at add output, so the fusion pattern has bfloat16 output.
807
+ torch.bfloat16,
808
+ binary_unary_attr, # binary_unary_attr
809
+ )
810
+ else:
811
+ _register_quantized_conv_binary_lowering(
812
+ patterns,
813
+ 1, # pass_number
814
+ torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
815
+ torch.float32,
816
+ binary_unary_attr, # binary_unary_attr
817
+ )
818
+
819
+ # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
820
+ binary_replace_float_out_patterns = {
821
+ BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary(
822
+ aten.add.Tensor,
823
+ get_dequantize_qconv_pt2e_pattern(1),
824
+ KeywordArg("accum_after_dequant"),
825
+ int8_mixed_bf16_with_inplace_add,
826
+ ),
827
+ }
828
+
829
+ for (
830
+ binary_unary_attr,
831
+ patterns,
832
+ ) in binary_replace_float_out_patterns.items():
833
+ _register_quantized_conv_binary_lowering(
834
+ patterns,
835
+ 1 if int8_mixed_bf16_with_inplace_add else 2, # pass_number
836
+ torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
837
+ # Same output dtype setting as conv-add-relu pattern
838
+ torch.bfloat16 if int8_mixed_bf16_with_inplace_add else torch.float32,
839
+ binary_unary_attr, # binary_unary_attr
840
+ )
841
+
842
+
843
+ def _is_valid_quantized_maxpool2d_optimization_pattern():
844
+ def fn(match):
845
+ # Only match the pattern which max_pool2d_with_indices returns value
846
+ # instead of indices.
847
+ get_item_node = filter_nodes(match.nodes, operator.getitem)[0]
848
+ return get_item_node.args[1] == 0
849
+
850
+ return fn
851
+
852
+
853
+ def _register_quantized_maxpool2d_lowering(
854
+ pattern,
855
+ computation_op,
856
+ ):
857
+ @register_lowering_pattern(
858
+ pattern,
859
+ extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(),
860
+ )
861
+ def qmaxpool2d(match: Match, *args, **kwargs):
862
+ x = kwargs["x"]
863
+ kernel_size = kwargs["kernel_size"]
864
+ stride = kwargs["stride"] if ("stride" in kwargs) else None
865
+ padding = kwargs["padding"] if ("padding" in kwargs) else 0
866
+ dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
867
+ ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False
868
+
869
+ if padding == 0:
870
+ padding = [0, 0]
871
+ if dilation == 1:
872
+ dilation = [1, 1]
873
+ if not stride:
874
+ stride = kernel_size
875
+ kernel_size = pad_listlike(kernel_size, 2)
876
+ stride = pad_listlike(stride, 2)
877
+ padding = pad_listlike(padding, 2)
878
+ dilation = pad_listlike(dilation, 2)
879
+
880
+ assert len(kernel_size) == 2
881
+ assert len(stride) == 2
882
+ assert len(padding) == 2
883
+ assert len(dilation) == 2
884
+
885
+ computation_args = (
886
+ x,
887
+ kernel_size,
888
+ stride,
889
+ padding,
890
+ dilation,
891
+ ceil_mode,
892
+ )
893
+ computation_args, _ = require_channels_last(computation_op, *computation_args)
894
+ return L[computation_op](*computation_args)
895
+
896
+ return qmaxpool2d
897
+
898
+
899
+ def _register_quantization_maxpool2d():
900
+ # Currently, the default parameters are not in FX Graph generated by Dynamo export.
901
+ # So, if user defines nn.MaxPool2d with different assignment of default parameter,
902
+ # it will generate graph with different number of input nodes and hence
903
+ # different pattern to be matched.
904
+ # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901
905
+ max_pool2d_args_list = [
906
+ [
907
+ KeywordArg("stride"),
908
+ ],
909
+ [
910
+ KeywordArg("stride"),
911
+ KeywordArg("padding"),
912
+ ],
913
+ [
914
+ KeywordArg("stride"),
915
+ KeywordArg("padding"),
916
+ KeywordArg("dilation"),
917
+ ],
918
+ [
919
+ KeywordArg("stride"),
920
+ KeywordArg("padding"),
921
+ KeywordArg("dilation"),
922
+ KeywordArg("ceil_mode"),
923
+ ],
924
+ ]
925
+
926
+ for max_pool2d_args in max_pool2d_args_list:
927
+ dequantize_maxpool2d_pattern = CallFunction(
928
+ aten.max_pool2d_with_indices.default,
929
+ dequantize_per_tensor_activation_pattern,
930
+ KeywordArg("kernel_size"),
931
+ *max_pool2d_args,
932
+ )
933
+ dequantize_maxpool2d_get_item_pattern = CallFunction(
934
+ operator.getitem,
935
+ dequantize_maxpool2d_pattern,
936
+ Arg(),
937
+ )
938
+ _register_quantized_maxpool2d_lowering(
939
+ generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
940
+ quantized.max_pool2d.default,
941
+ )
942
+
943
+
944
+ def _is_input_output_same_scale_zp(check_node):
945
+ def fn(match):
946
+ # Ensure all the inputs and output has same scale and zero point
947
+ # Step 1: Check inputs/output zero point
948
+ sub_nodes = filter_nodes(match.nodes, aten.sub.Tensor)
949
+ zero_points = [node.args[1] for node in sub_nodes]
950
+ add_nodes = filter_nodes(match.nodes, aten.add.Tensor)
951
+ assert len(add_nodes) == 1, "expect only 1 add node at output quant pattern"
952
+ zero_points.append(add_nodes[0].args[1])
953
+ if not all(zero_point == zero_points[0] for zero_point in zero_points):
954
+ return False
955
+
956
+ # Step 2: Check inputs/output scale
957
+ mul_nodes = filter_nodes(match.nodes, aten.mul.Tensor)
958
+ # We need to find mul node at output since the scale value is reciprocal to input scale.
959
+ # Mul node at output should connect to cat node directly.
960
+ scales = [
961
+ (
962
+ mul_node.args[1]
963
+ if mul_node.args[0].target is check_node # type: ignore[union-attr]
964
+ else 1.0 / mul_node.args[1] # type: ignore[operator]
965
+ )
966
+ for mul_node in mul_nodes
967
+ ]
968
+ if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type]
969
+ return False
970
+
971
+ return True
972
+
973
+ return fn
974
+
975
+
976
+ def _register_quantized_cat_lowering(
977
+ pattern,
978
+ computation_op,
979
+ ):
980
+ @register_lowering_pattern(
981
+ pattern,
982
+ extra_check=_is_input_output_same_scale_zp(aten.cat.default),
983
+ )
984
+ def qcat(match: Match, inputs, dim, **kwargs):
985
+ # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...]
986
+ uint8_inputs = [input[0] for input in inputs]
987
+ return L[computation_op](uint8_inputs, dim)
988
+
989
+ return qcat
990
+
991
+
992
+ _raw_dequantize_per_tensor_activation_pattern = CallFunction(
993
+ aten.mul.Tensor,
994
+ CallFunction(
995
+ aten.sub.Tensor,
996
+ CallFunction(
997
+ prims.convert_element_type.default,
998
+ Arg(),
999
+ Arg(),
1000
+ ),
1001
+ Arg(),
1002
+ ),
1003
+ Arg(),
1004
+ )
1005
+
1006
+
1007
+ def _register_quantization_cat():
1008
+ dequantize_cat_pattern = CallFunction(
1009
+ aten.cat.default,
1010
+ ListOf(_raw_dequantize_per_tensor_activation_pattern),
1011
+ KeywordArg("dim"),
1012
+ )
1013
+ _register_quantized_cat_lowering(
1014
+ generate_pattern_with_output_quant(dequantize_cat_pattern),
1015
+ aten.cat,
1016
+ )
1017
+
1018
+
1019
+ def _register_quantized_reshape_lowering(
1020
+ pattern,
1021
+ computation_op,
1022
+ ):
1023
+ @register_lowering_pattern(
1024
+ pattern,
1025
+ extra_check=_is_input_output_same_scale_zp(aten.reshape.default),
1026
+ )
1027
+ def qreshape(match: Match, *args, **kwargs):
1028
+ qx = kwargs["x"]
1029
+ shape = kwargs["shape"]
1030
+ counters["inductor"]["qreshape_matcher_count"] += 1
1031
+ counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes)
1032
+ return L[computation_op](qx, shape)
1033
+
1034
+ return qreshape
1035
+
1036
+
1037
+ def _register_quantization_reshape():
1038
+ dequantize_reshape_pattern = CallFunction(
1039
+ torch.ops.aten.reshape.default,
1040
+ dequantize_per_tensor_activation_pattern,
1041
+ KeywordArg("shape"),
1042
+ )
1043
+ _register_quantized_reshape_lowering(
1044
+ generate_pattern_with_output_quant(dequantize_reshape_pattern),
1045
+ aten.reshape,
1046
+ )
1047
+
1048
+
1049
+ def _register_quantization_lowerings():
1050
+ _register_quantization_unary_fusion()
1051
+ _register_quantization_binary_fusion()
1052
+ _register_quantization_maxpool2d()
1053
+ _register_quantization_cat()
1054
+ _register_quantization_reshape()
1055
+
1056
+
1057
+ def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
1058
+ def _inner(match):
1059
+ assert dtype in [torch.float32, torch.bfloat16]
1060
+ dequant_pattern_end_node = match.output_node()
1061
+ if dequant_pattern_end_node.target not in [
1062
+ aten.mul.Tensor,
1063
+ prims.convert_element_type.default,
1064
+ aten.reshape.default,
1065
+ ]:
1066
+ return False
1067
+
1068
+ if dequant_pattern_end_node.target is aten.reshape.default:
1069
+ mul_node = (
1070
+ dequant_pattern_end_node.args[0] # pattern: linear <- reshape <- mul
1071
+ if dtype == torch.float32
1072
+ else dequant_pattern_end_node.args[0].args[
1073
+ 0
1074
+ ] # pattern: linear <- reshape <- to_bf16 <- mul
1075
+ )
1076
+ else:
1077
+ mul_node = (
1078
+ dequant_pattern_end_node # pattern: linear <- mul
1079
+ if dtype == torch.float32
1080
+ else dequant_pattern_end_node.args[
1081
+ 0
1082
+ ] # pattern: linear <- to_bf16 <- mul
1083
+ )
1084
+
1085
+ sub_node = mul_node.args[0]
1086
+ to_fp32_node = sub_node.args[0]
1087
+ if (
1088
+ mul_node.target is aten.mul.Tensor
1089
+ and sub_node.target is aten.sub.Tensor
1090
+ and to_fp32_node.target is prims.convert_element_type.default
1091
+ and len(list(dequant_pattern_end_node.users)) > 1
1092
+ ):
1093
+ # If dequant pattern has more than 1 users, then do dequant promoted
1094
+ return True
1095
+ return False
1096
+
1097
+ return _inner
1098
+
1099
+
1100
+ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
1101
+ @register_freezing_graph_pattern(
1102
+ pattern,
1103
+ extra_check=_is_valid_dequant_promotion_pattern(dtype),
1104
+ pass_number=pass_number,
1105
+ )
1106
+ def dequant_promotion(match: Match, *args, **kwargs):
1107
+ # Dequant_promotion will transform
1108
+ # graph 1:
1109
+ # quant
1110
+ # + - - - | - - - +
1111
+ # | dequant |
1112
+ # | / \ |
1113
+ # | node1 node2 |
1114
+ # + - | - - - | - +
1115
+ # quant quant
1116
+ # into:
1117
+ # graph 2:
1118
+ # quant
1119
+ # + - - / - \ - - +
1120
+ # |dequant dequant|
1121
+ # | | | |
1122
+ # | node1 node2 |
1123
+ # + - | - - - | - +
1124
+ # quant quant
1125
+ # In graph 1, the dequant node is shared by node1 and node2,
1126
+ # as a result, neither node1 nor node2 could form an int8
1127
+ # fusion pattern.
1128
+ # After this transformation, the graph 2 could hit the int8
1129
+ # fusion pattern: dequant-node-quant, respectively for
1130
+ # node1 and node2.
1131
+ assert dtype in [torch.float32, torch.bfloat16]
1132
+
1133
+ def clone_to_new_node(graph, source_node, user_node):
1134
+ # Clone the source_node to a new node
1135
+ # Replace user_node's input from source_node to new_node
1136
+ assert (
1137
+ source_node.op == "call_function"
1138
+ ), "clone_to_new_node only support node.op call_function"
1139
+ with graph.inserting_before(user_node):
1140
+ new_node = graph.call_function(
1141
+ source_node.target,
1142
+ args=source_node.args,
1143
+ kwargs=source_node.kwargs,
1144
+ )
1145
+ new_node.meta = copy.copy(source_node.meta)
1146
+ user_node.replace_input_with(source_node, new_node)
1147
+ return new_node
1148
+
1149
+ # Find the start node and end node of a dequant pattern
1150
+ # * End node should be the match.output_node()
1151
+ # * Start node should be the node of dtype convert to float32
1152
+ dequant_pattern_end_node = match.output_node()
1153
+ assert dequant_pattern_end_node.target in [
1154
+ aten.mul.Tensor,
1155
+ prims.convert_element_type.default,
1156
+ aten.reshape.default,
1157
+ ]
1158
+
1159
+ # For a dequant pattern, we should expect see the node list as:
1160
+ # * OPT(aten.reshape.default)
1161
+ # * OPT(prims.convert_element_type.default) (to_bf16)
1162
+ # * aten.mul
1163
+ # * aten.sub
1164
+ # * prims.convert_element_type.default (to_fp32)
1165
+ def _find_first_node_in_dequant_pattern(_node):
1166
+ if (
1167
+ _node.target is prims.convert_element_type.default
1168
+ and _node.args[1] == torch.float32
1169
+ ):
1170
+ # For a dequant pattern, we expect the start node is a to_fp32 node
1171
+ return _node
1172
+ else:
1173
+ assert (
1174
+ len(_node.args) >= 1
1175
+ ), "In in dequant pattern, each node should have more than 1 arg."
1176
+ return _find_first_node_in_dequant_pattern(_node.args[0])
1177
+
1178
+ dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
1179
+ dequant_pattern_end_node
1180
+ )
1181
+
1182
+ # Clone the dequant pattern for each user node
1183
+ graph = match.graph
1184
+ user_node_list = list(dequant_pattern_end_node.users)
1185
+ for user_node in user_node_list[1:]:
1186
+ _source_node = dequant_pattern_end_node
1187
+ _user_node = user_node
1188
+ while _source_node != dequant_pattern_start_node.args[0]:
1189
+ _user_node = clone_to_new_node(graph, _source_node, _user_node)
1190
+ _source_node = _source_node.args[0] # type: ignore[assignment]
1191
+
1192
+ counters["inductor"]["dequant_promotion_matcher_count"] += 1
1193
+ counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
1194
+
1195
+
1196
+ def _is_valid_dequant_conv2d_pattern(dtype):
1197
+ def _inner(match):
1198
+ # Here we do some further check to ensure:
1199
+ # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now.
1200
+ # 2. The dequant pattern has only 1 user of conv2d node.
1201
+ # If these conditions don't meet, we will not
1202
+ # insert weight prepack node into the matched pattern.
1203
+ conv_node = match.output_node()
1204
+ assert conv_node.target is aten.convolution.default
1205
+ input_meta_value = conv_node.args[0].meta.get("val")
1206
+ weight_meta_value = conv_node.args[1].meta.get("val")
1207
+ for meta_value in [input_meta_value, weight_meta_value]:
1208
+ if (
1209
+ meta_value is None
1210
+ or meta_value.device.type != "cpu"
1211
+ or meta_value.dim() != 4
1212
+ ):
1213
+ # Only support conv2d now
1214
+ return False
1215
+
1216
+ assert dtype in [torch.float32, torch.bfloat16]
1217
+ if dtype == torch.float32:
1218
+ mul_node = conv_node.args[0]
1219
+ else:
1220
+ convert_to_bf16 = conv_node.args[0]
1221
+ mul_node = convert_to_bf16.args[0]
1222
+ sub_node = mul_node.args[0]
1223
+ to_fp32_node = sub_node.args[0]
1224
+
1225
+ assert to_fp32_node.target is prims.convert_element_type.default
1226
+ assert sub_node.target is aten.sub.Tensor
1227
+ assert mul_node.target is aten.mul.Tensor
1228
+ if (
1229
+ len(list(to_fp32_node.users)) != 1
1230
+ or len(list(sub_node.users)) != 1
1231
+ or len(list(mul_node.users)) != 1
1232
+ ):
1233
+ # Ensure the dequant pattern only has 1 user
1234
+ # since we will delete the dequant pattern here
1235
+ return False
1236
+ return True
1237
+
1238
+ return _inner
1239
+
1240
+
1241
+ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
1242
+ @register_freezing_graph_pattern(
1243
+ pattern,
1244
+ extra_check=_is_valid_dequant_conv2d_pattern(dtype),
1245
+ pass_number=pass_number,
1246
+ )
1247
+ def qconv_weight_prepack(match: Match, *args, **kwargs):
1248
+ """
1249
+ Match the pattern:
1250
+ int8 activation
1251
+ |
1252
+ dequant_per_tensor
1253
+ |
1254
+ Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
1255
+
1256
+ Insert weight prepack node and change the pattern to:
1257
+ int8 activation
1258
+ |
1259
+ onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight
1260
+ """
1261
+ assert dtype in [torch.float32, torch.bfloat16]
1262
+ conv_node = match.output_node()
1263
+ assert conv_node.target is aten.convolution.default
1264
+ if dtype == torch.float32:
1265
+ mul_node = conv_node.args[0]
1266
+ else:
1267
+ convert_to_bf16 = conv_node.args[0]
1268
+ mul_node = convert_to_bf16.args[0] # type: ignore[union-attr]
1269
+ sub_node = mul_node.args[0] # type: ignore[union-attr]
1270
+ to_fp32_node = sub_node.args[0] # type: ignore[union-attr]
1271
+ has_clone_to_channel_last_node_in_pattern = (
1272
+ conv_node.args[1].target is aten.clone.default # type: ignore[union-attr]
1273
+ )
1274
+ clone_node = (
1275
+ conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
1276
+ )
1277
+
1278
+ if dtype == torch.float32:
1279
+ dequant_per_channel = (
1280
+ clone_node.args[0] # type: ignore[union-attr]
1281
+ if has_clone_to_channel_last_node_in_pattern
1282
+ else conv_node.args[1]
1283
+ )
1284
+ else:
1285
+ weight_to_bf16_node = (
1286
+ clone_node.args[0] # type: ignore[union-attr]
1287
+ if has_clone_to_channel_last_node_in_pattern
1288
+ else conv_node.args[1]
1289
+ )
1290
+ dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr]
1291
+
1292
+ assert (
1293
+ dequant_per_channel.target # type: ignore[union-attr]
1294
+ is quantized_decomposed.dequantize_per_channel.default
1295
+ )
1296
+
1297
+ # Activation QParams
1298
+ qx, x_zp, x_scale = (
1299
+ kwargs["x"],
1300
+ kwargs["x_zp"],
1301
+ kwargs["x_scale"],
1302
+ )
1303
+
1304
+ # Weight QParams
1305
+ qw, w_scale, w_zp = (
1306
+ kwargs["q_weight"],
1307
+ kwargs["w_scale"],
1308
+ kwargs["w_zp"],
1309
+ )
1310
+
1311
+ # Conv Params
1312
+ bias, stride, padding, dilation, groups = (
1313
+ kwargs["b"],
1314
+ kwargs["stride"],
1315
+ kwargs["padding"],
1316
+ kwargs["dilation"],
1317
+ kwargs["groups"],
1318
+ )
1319
+
1320
+ x_shape = qx.meta.get("tensor_meta").shape
1321
+ if has_free_symbols(x_shape):
1322
+ # For dynamic shape case, we can't get activation shape ahead of runtime.
1323
+ x_shape = None
1324
+ graph = match.graph
1325
+ with graph.inserting_before(conv_node):
1326
+ # Insert weight prepack node and the QConv node
1327
+ packed_weight_inputs = (
1328
+ qw,
1329
+ w_scale,
1330
+ x_scale,
1331
+ x_zp,
1332
+ stride,
1333
+ padding,
1334
+ dilation,
1335
+ groups,
1336
+ x_shape,
1337
+ )
1338
+ packed_weight_op = torch.ops.onednn.qconv_prepack
1339
+ prepack_weight_node = graph.call_function(
1340
+ packed_weight_op, args=packed_weight_inputs
1341
+ )
1342
+
1343
+ new_args: Tuple[Any, ...] = (
1344
+ qx,
1345
+ x_scale,
1346
+ x_zp,
1347
+ prepack_weight_node,
1348
+ w_scale,
1349
+ w_zp,
1350
+ bias,
1351
+ stride,
1352
+ padding,
1353
+ dilation,
1354
+ groups,
1355
+ 1.0, # inv_output_scale
1356
+ 0, # output_zero_point
1357
+ dtype, # output_dtype
1358
+ "none", # attr
1359
+ [], # scalars
1360
+ "", # algorithm
1361
+ )
1362
+ new_conv_node = graph.call_function(
1363
+ torch.ops.onednn.qconv2d_pointwise.default, args=new_args
1364
+ )
1365
+ conv_node.replace_all_uses_with(new_conv_node)
1366
+ new_conv_node.meta.update(conv_node.meta)
1367
+
1368
+ # Erase the original conv node
1369
+ graph.erase_node(conv_node)
1370
+ # Erase the dequant pattern
1371
+ if dtype == torch.bfloat16:
1372
+ graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined]
1373
+ # Erase the dequant pattern
1374
+ graph.erase_node(mul_node)
1375
+ graph.erase_node(sub_node)
1376
+ graph.erase_node(to_fp32_node)
1377
+ # Erase the dequant per channel pattern
1378
+ if clone_node is not None:
1379
+ graph.erase_node(clone_node)
1380
+ if dtype == torch.bfloat16:
1381
+ graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
1382
+ graph.erase_node(dequant_per_channel)
1383
+ counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
1384
+ counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
1385
+ match.nodes
1386
+ )
1387
+
1388
+
1389
+ def _generate_dequant_convolution_node_pattern(
1390
+ _dequant_per_channel_pattern, dtype=torch.float32
1391
+ ):
1392
+ assert dtype in [torch.float32, torch.bfloat16]
1393
+ dequant_convolution_node_pattern = CallFunction(
1394
+ aten.convolution.default,
1395
+ _may_generate_pattern_with_dtype_convert(
1396
+ dequantize_per_tensor_activation_pattern,
1397
+ KeywordArg("autocast_act_dtype"),
1398
+ dtype == torch.bfloat16,
1399
+ ),
1400
+ _dequant_per_channel_pattern,
1401
+ KeywordArg("b"),
1402
+ KeywordArg("stride"),
1403
+ KeywordArg("padding"),
1404
+ KeywordArg("dilation"),
1405
+ KeywordArg("is_transposed"),
1406
+ KeywordArg("out_padding"),
1407
+ KeywordArg("groups"),
1408
+ )
1409
+ return dequant_convolution_node_pattern
1410
+
1411
+
1412
+ def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
1413
+ assert dtype in [torch.float32, torch.bfloat16]
1414
+ return (
1415
+ _generate_dequant_convolution_node_pattern(
1416
+ dequantize_per_channel_weight_pattern
1417
+ if dtype == torch.float32
1418
+ else dequantize_per_channel_to_bf16_weight_pattern,
1419
+ dtype,
1420
+ ),
1421
+ # There is another pattern due to the pass of convert_conv_weights_to_channels_last
1422
+ # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
1423
+ # Depend on some heuristics, it may or may not insert to(channel_last) node
1424
+ # between convolution and dequant_per_channel node
1425
+ _generate_dequant_convolution_node_pattern(
1426
+ dequantize_per_channel_clone_weight_pattern
1427
+ if dtype == torch.float32
1428
+ else dequantize_per_channel_to_bf16_clone_weight_pattern,
1429
+ dtype,
1430
+ ),
1431
+ )
1432
+
1433
+
1434
+ def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
1435
+ output_reshape_node = None
1436
+ if input_dim_exceeds_two:
1437
+ if input_contiguous:
1438
+ output_reshape_node = match.output_node()
1439
+ assert output_reshape_node.target is aten.reshape.default
1440
+ linear_node = output_reshape_node.args[0]
1441
+ else:
1442
+ linear_nodes = filter_nodes(match.nodes, aten.bmm.default)
1443
+ assert len(linear_nodes) == 1
1444
+ linear_node = linear_nodes[0]
1445
+ else:
1446
+ linear_node = match.output_node()
1447
+
1448
+ assert linear_node.target in (
1449
+ aten.addmm.default,
1450
+ aten.mm.default,
1451
+ aten.bmm.default,
1452
+ )
1453
+ return linear_node, output_reshape_node
1454
+
1455
+
1456
+ def _get_linear_dq_mul_node(
1457
+ linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
1458
+ ):
1459
+ act_reshape_node = None
1460
+ activation_to_bf16_node = None
1461
+ act_expand_node = None
1462
+ if input_dim_exceeds_two:
1463
+ if input_contiguous:
1464
+ act_reshape_node = linear_node.args[input_index]
1465
+ assert act_reshape_node.target is aten.reshape.default
1466
+ if dtype == torch.float32:
1467
+ # pattern: linear -> reshape -> mul
1468
+ mul_node = act_reshape_node.args[0]
1469
+ else:
1470
+ # pattern: linear -> reshape -> to_bf16 -> mul
1471
+ activation_to_bf16_node = act_reshape_node.args[0]
1472
+ mul_node = activation_to_bf16_node.args[0]
1473
+ else:
1474
+ # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous
1475
+ act_expand_node = linear_node.args[input_index]
1476
+ assert act_expand_node.target is aten.expand.default
1477
+ if dtype == torch.float32:
1478
+ mul_node = act_expand_node.args[0]
1479
+ else:
1480
+ activation_to_bf16_node = act_expand_node.args[0]
1481
+ mul_node = activation_to_bf16_node.args[0]
1482
+ else:
1483
+ if dtype == torch.float32:
1484
+ # pattern: linear -> mul
1485
+ mul_node = linear_node.args[input_index]
1486
+ else:
1487
+ # pattern: linear -> to_bf16 -> mul
1488
+ activation_to_bf16_node = linear_node.args[input_index]
1489
+ mul_node = activation_to_bf16_node.args[0]
1490
+ return mul_node, act_reshape_node, activation_to_bf16_node, act_expand_node
1491
+
1492
+
1493
+ def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
1494
+ def _inner(match):
1495
+ # Check dequant pattern has only 1 user.
1496
+ (
1497
+ linear_node,
1498
+ _,
1499
+ ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
1500
+
1501
+ input_index = 1 if linear_node.target is aten.addmm.default else 0
1502
+ assert dtype in [torch.float32, torch.bfloat16]
1503
+
1504
+ (
1505
+ mul_node,
1506
+ _,
1507
+ _,
1508
+ _,
1509
+ ) = _get_linear_dq_mul_node(
1510
+ linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
1511
+ )
1512
+
1513
+ sub_node = mul_node.args[0]
1514
+ to_fp32_node = sub_node.args[0]
1515
+
1516
+ assert to_fp32_node.target is prims.convert_element_type.default
1517
+ assert sub_node.target is aten.sub.Tensor
1518
+ assert mul_node.target is aten.mul.Tensor
1519
+ if (
1520
+ len(list(to_fp32_node.users)) != 1
1521
+ or len(list(sub_node.users)) != 1
1522
+ or len(list(mul_node.users)) != 1
1523
+ ):
1524
+ # Ensure the dequant pattern only has 1 user
1525
+ # since we will delete the dequant pattern here
1526
+ return False
1527
+
1528
+ # Extra check for bmm pattern
1529
+ if input_dim_exceeds_two and not input_contiguous:
1530
+ # Check for act
1531
+ # Act expand size should be exactly same as act size
1532
+ act_expand_size = match.kwargs["act_expand_size"]
1533
+ act_node = match.kwargs["x"]
1534
+ if not (
1535
+ hasattr(act_node, "meta")
1536
+ and isinstance(act_node.meta.get("val", None), torch.Tensor)
1537
+ and (act_node.meta["val"].size() == torch.Size(act_expand_size))
1538
+ ):
1539
+ return False
1540
+
1541
+ # Check for wgt
1542
+ # wgt permute dims should be [1, 0]
1543
+ wgt_permute_dims = match.kwargs["permute_axes"]
1544
+ if wgt_permute_dims != [1, 0]:
1545
+ return False
1546
+
1547
+ # Check below wgt size items:
1548
+ # wgt before expand should with dim 2
1549
+ # Expand size should with dim 3
1550
+ # Expand size[0] should same as act size[0]
1551
+ # Expand size[1] should same as wgt size[1]
1552
+ # Expand size[2] should same as wgt size[0]
1553
+ qweight_node = match.kwargs["q_weight"]
1554
+ wgt_expand_size = match.kwargs["wgt_expand_size"]
1555
+ if not (
1556
+ hasattr(qweight_node, "meta")
1557
+ and isinstance(qweight_node.meta.get("val", None), torch.Tensor)
1558
+ and len(qweight_node.meta["val"].size()) == 2
1559
+ and len(wgt_expand_size) == 3
1560
+ and wgt_expand_size[0] == act_node.meta["val"].size()[0]
1561
+ and wgt_expand_size[1] == qweight_node.meta["val"].size()[1]
1562
+ and wgt_expand_size[2] == qweight_node.meta["val"].size()[0]
1563
+ ):
1564
+ return False
1565
+
1566
+ return True
1567
+
1568
+ return _inner
1569
+
1570
+
1571
+ def _register_qlinear_weight_prepack_pass(
1572
+ pattern,
1573
+ pass_number,
1574
+ dtype=torch.float32,
1575
+ input_dim_exceeds_two=False,
1576
+ input_contiguous=True,
1577
+ ):
1578
+ @register_freezing_graph_pattern(
1579
+ pattern,
1580
+ extra_check=_is_valid_dequant_linear_pattern(
1581
+ dtype, input_dim_exceeds_two, input_contiguous
1582
+ ),
1583
+ pass_number=pass_number,
1584
+ )
1585
+ def qlinear_weight_prepack(match: Match, *args, **kwargs):
1586
+ """
1587
+ Match the pattern:
1588
+ int8 activation
1589
+ |
1590
+ dequant_per_tensor
1591
+ |
1592
+ mm/addmm <- t <- dequant_per_channel <- int8_weight
1593
+
1594
+ Insert weight prepack node and change the pattern to:
1595
+ int8 activation
1596
+ |
1597
+ onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight
1598
+ """
1599
+ assert dtype in [torch.float32, torch.bfloat16]
1600
+ (
1601
+ linear_node,
1602
+ output_reshape_node,
1603
+ ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
1604
+ input_index = 1 if linear_node.target is aten.addmm.default else 0
1605
+ weight_index = input_index + 1
1606
+
1607
+ (
1608
+ mul_node,
1609
+ act_reshape_node,
1610
+ activation_to_bf16_node,
1611
+ act_expand_node,
1612
+ ) = _get_linear_dq_mul_node(
1613
+ linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
1614
+ )
1615
+
1616
+ sub_node = mul_node.args[0]
1617
+ to_fp32_node = sub_node.args[0]
1618
+
1619
+ if input_dim_exceeds_two and not input_contiguous:
1620
+ wgt_expand_node = linear_node.args[weight_index]
1621
+ assert wgt_expand_node.target is aten.expand.default
1622
+ t_node = wgt_expand_node.args[0]
1623
+ else:
1624
+ t_node = linear_node.args[weight_index]
1625
+
1626
+ if dtype == torch.float32:
1627
+ dequant_per_channel = t_node.args[0]
1628
+ else:
1629
+ weight_to_bf16_node = t_node.args[0]
1630
+ dequant_per_channel = weight_to_bf16_node.args[0]
1631
+ assert (
1632
+ dequant_per_channel.target
1633
+ is quantized_decomposed.dequantize_per_channel.default
1634
+ )
1635
+
1636
+ # Activation QParams
1637
+ qx, x_zp, x_scale = (
1638
+ kwargs["x"],
1639
+ kwargs["x_zp"],
1640
+ kwargs["x_scale"],
1641
+ )
1642
+
1643
+ # Weight QParams
1644
+ qw, w_scale, w_zp = (
1645
+ kwargs["q_weight"],
1646
+ kwargs["w_scale"],
1647
+ kwargs["w_zp"],
1648
+ )
1649
+
1650
+ # Params
1651
+ bias = kwargs["b"] if "b" in kwargs else None
1652
+
1653
+ x_shape = qx.meta.get("tensor_meta").shape
1654
+ if has_free_symbols(x_shape):
1655
+ # For dynamic shape case, we can't get activation shape ahead of runtime.
1656
+ x_shape = None
1657
+ graph = match.graph
1658
+ with graph.inserting_before(linear_node):
1659
+ # Insert weight prepack node and the qlinear node
1660
+ packed_weight_inputs = (
1661
+ qw,
1662
+ x_shape,
1663
+ )
1664
+ packed_weight_op = torch.ops.onednn.qlinear_prepack
1665
+ prepack_weight_node = graph.call_function(
1666
+ packed_weight_op, args=packed_weight_inputs
1667
+ )
1668
+
1669
+ new_args: Tuple[Any, ...] = (
1670
+ qx,
1671
+ x_scale,
1672
+ x_zp,
1673
+ prepack_weight_node,
1674
+ w_scale,
1675
+ w_zp,
1676
+ bias,
1677
+ 1.0, # output_scale
1678
+ 0, # output_zero_point
1679
+ dtype, # output_dtype
1680
+ "none", # post op name
1681
+ [], # post op args
1682
+ "", # post op algorithm
1683
+ )
1684
+ Node = torch.fx.node.Node
1685
+ if isinstance(x_scale, Node) and isinstance(x_zp, Node):
1686
+ new_linear_node = graph.call_function(
1687
+ torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
1688
+ )
1689
+ else:
1690
+ new_linear_node = graph.call_function(
1691
+ torch.ops.onednn.qlinear_pointwise.default, args=new_args
1692
+ )
1693
+ if input_dim_exceeds_two:
1694
+ if input_contiguous:
1695
+ output_reshape_node.replace_all_uses_with(new_linear_node)
1696
+ new_linear_node.meta.update(output_reshape_node.meta)
1697
+ else:
1698
+ if bias:
1699
+ output_add_node_for_bias = match.output_node()
1700
+ assert output_add_node_for_bias.target is aten.add.Tensor
1701
+ output_add_node_for_bias.replace_all_uses_with(new_linear_node)
1702
+ new_linear_node.meta.update(output_add_node_for_bias.meta)
1703
+ else:
1704
+ linear_node.replace_all_uses_with(new_linear_node)
1705
+ new_linear_node.meta.update(linear_node.meta)
1706
+ else:
1707
+ linear_node.replace_all_uses_with(new_linear_node)
1708
+ new_linear_node.meta.update(linear_node.meta)
1709
+
1710
+ # Erase the original linear node
1711
+ if input_dim_exceeds_two:
1712
+ if input_contiguous:
1713
+ graph.erase_node(output_reshape_node)
1714
+ elif not input_contiguous and bias:
1715
+ graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined]
1716
+ graph.erase_node(linear_node)
1717
+ if input_dim_exceeds_two:
1718
+ if input_contiguous:
1719
+ graph.erase_node(act_reshape_node)
1720
+ else:
1721
+ graph.erase_node(act_expand_node)
1722
+ graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined]
1723
+ if dtype == torch.bfloat16:
1724
+ graph.erase_node(activation_to_bf16_node)
1725
+ # Erase the dequant pattern
1726
+ graph.erase_node(mul_node)
1727
+ graph.erase_node(sub_node)
1728
+ graph.erase_node(to_fp32_node)
1729
+ # Erase the dequant per channel pattern
1730
+ graph.erase_node(t_node)
1731
+ if dtype == torch.bfloat16:
1732
+ graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
1733
+ graph.erase_node(dequant_per_channel)
1734
+
1735
+ counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
1736
+ counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
1737
+ match.nodes
1738
+ )
1739
+
1740
+
1741
+ def _generate_dequant_linear_node_pattern(
1742
+ _dequant_per_channel_pattern, dtype=torch.float32, input_dim_exceeds_two=False
1743
+ ):
1744
+ assert dtype in [torch.float32, torch.bfloat16]
1745
+ t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
1746
+ dequant_linear_bias_pattern = _may_generate_pattern_with_reshape(
1747
+ CallFunction(
1748
+ aten.addmm.default,
1749
+ KeywordArg("b"),
1750
+ _may_generate_pattern_with_reshape(
1751
+ _may_generate_pattern_with_dtype_convert(
1752
+ dequantize_per_tensor_activation_pattern,
1753
+ KeywordArg("autocast_act_dtype"),
1754
+ dtype == torch.bfloat16,
1755
+ ),
1756
+ KeywordArg("act_reshape_size"),
1757
+ input_dim_exceeds_two,
1758
+ ),
1759
+ t_pattern,
1760
+ ),
1761
+ KeywordArg("output_reshape_size"),
1762
+ input_dim_exceeds_two,
1763
+ )
1764
+ dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
1765
+ CallFunction(
1766
+ aten.mm.default,
1767
+ _may_generate_pattern_with_reshape(
1768
+ _may_generate_pattern_with_dtype_convert(
1769
+ dequantize_per_tensor_activation_pattern,
1770
+ KeywordArg("autocast_act_dtype"),
1771
+ dtype == torch.bfloat16,
1772
+ ),
1773
+ KeywordArg("act_reshape_size"),
1774
+ input_dim_exceeds_two,
1775
+ ),
1776
+ t_pattern,
1777
+ ),
1778
+ KeywordArg("output_reshape_size"),
1779
+ input_dim_exceeds_two,
1780
+ )
1781
+ return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
1782
+
1783
+
1784
+ def _generate_dequant_bmm_node_pattern(
1785
+ _dequant_per_channel_pattern,
1786
+ dtype=torch.float32,
1787
+ with_bias=False,
1788
+ ):
1789
+ # When activation of linear dim exceed 2 and not contiguous
1790
+ t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
1791
+
1792
+ assert dtype in [torch.float32, torch.bfloat16]
1793
+ dequant_bmm_pattern = CallFunction(
1794
+ aten.bmm.default,
1795
+ CallFunction(
1796
+ aten.expand.default,
1797
+ _may_generate_pattern_with_dtype_convert(
1798
+ dequantize_per_tensor_activation_pattern,
1799
+ KeywordArg("autocast_act_dtype"),
1800
+ dtype == torch.bfloat16,
1801
+ ),
1802
+ KeywordArg("act_expand_size"),
1803
+ ),
1804
+ CallFunction(
1805
+ aten.expand.default,
1806
+ t_pattern,
1807
+ KeywordArg("wgt_expand_size"),
1808
+ ),
1809
+ )
1810
+
1811
+ def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias):
1812
+ if _with_bias:
1813
+ return CallFunction(
1814
+ aten.add.Tensor,
1815
+ _dequant_bmm_pattern,
1816
+ KeywordArg("b"),
1817
+ )
1818
+ else:
1819
+ return _dequant_bmm_pattern
1820
+
1821
+ return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias)
1822
+
1823
+
1824
+ def _generate_qlinear_weight_prepack_patterns(
1825
+ dtype=torch.float32,
1826
+ input_dim_exceeds_two=False,
1827
+ input_contiguous=True,
1828
+ with_bias=False,
1829
+ ):
1830
+ if input_dim_exceeds_two and not input_contiguous:
1831
+ return _generate_dequant_bmm_node_pattern(
1832
+ dequantize_per_channel_weight_pattern,
1833
+ dtype,
1834
+ with_bias,
1835
+ )
1836
+ else:
1837
+ return _generate_dequant_linear_node_pattern(
1838
+ dequantize_per_channel_weight_pattern, dtype, input_dim_exceeds_two
1839
+ )
1840
+
1841
+
1842
+ def _register_dequant_promotion():
1843
+ dequant_pattern_cases = itertools.product(
1844
+ [torch.float32, torch.bfloat16], [True, False]
1845
+ )
1846
+ for dtype, input_dim_exceeds_two in dequant_pattern_cases:
1847
+ # 4 dequantization patterns will be matched based on the dtype and input dimension size.
1848
+ # Case 1: int8-mixed-fp32, input dim size is 2
1849
+ # Case 2: int8-mixed-fp32, input dim size exceeds 2
1850
+ # Case 3: int8-mixed-bf16, input dim size is 2
1851
+ # Case 4: int8-mixed-bf16, input dim size exceeds 2
1852
+ # quant
1853
+ # + - - - - | - - - - +
1854
+ # | dequant |
1855
+ # | | |
1856
+ # | OPT(to_bf16) |
1857
+ # | | |
1858
+ # | OPT(reshape) |
1859
+ # | / \ |
1860
+ # | node1 node2 |
1861
+ # + - - | - - - | - - +
1862
+ # OPT(reshape) OPT(reshape)
1863
+ # + - - | - - - | - - +
1864
+ # OPT(to_fp32) OPT(to_fp32)
1865
+ # + - - | - - - | - - +
1866
+ # quant quant
1867
+ _register_dequant_promotion_pass(
1868
+ _may_generate_pattern_with_reshape(
1869
+ _may_generate_pattern_with_dtype_convert(
1870
+ dequantize_per_tensor_activation_pattern,
1871
+ KeywordArg("autocast_act_dtype"),
1872
+ dtype == torch.bfloat16,
1873
+ ),
1874
+ KeywordArg("act_reshape_size"),
1875
+ with_reshape=input_dim_exceeds_two,
1876
+ ),
1877
+ pass_number=0,
1878
+ dtype=dtype,
1879
+ ) # pass_number=0 to run before weight prepack
1880
+
1881
+
1882
+ def _register_qconv_weight_prepack():
1883
+ for dtype in [torch.float32, torch.bfloat16]:
1884
+ weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
1885
+ for weight_prepack_pattern in weight_prepack_patterns:
1886
+ # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
1887
+ _register_qconv_weight_prepack_pass(
1888
+ weight_prepack_pattern, pass_number=1, dtype=dtype
1889
+ )
1890
+
1891
+
1892
+ def _register_qlinear_weight_prepack():
1893
+ # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous.
1894
+ # Then convert the pattern into a QLinear node with int8_fp32/bf16.
1895
+ # Case 1: int8-mixed-fp32, input dim size is 2
1896
+ # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous
1897
+ # Case 3: int8-mixed-bf16, input dim size is 2
1898
+ # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous
1899
+
1900
+ # + - - - - | - - - - - - | - - - - - +
1901
+ # | dq_per_tensor dq_per_channel |
1902
+ # | | | |
1903
+ # | OPT(to_bf16) OPT(to_bf16) |
1904
+ # | | | |
1905
+ # | OPT(reshape) permute |
1906
+ # | \ / |
1907
+ # | addmm/mm |
1908
+ # | | |
1909
+ # | OPT(reshape) |
1910
+
1911
+ # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous
1912
+ # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous
1913
+
1914
+ # + - - - - | - - - - - - | - - - - - +
1915
+ # | dq_per_tensor dq_per_channel |
1916
+ # | | | |
1917
+ # | OPT(to_bf16) OPT(to_bf16) |
1918
+ # | | | |
1919
+ # | expand permute |
1920
+ # | \ | |
1921
+ # | expand |
1922
+ # | / |
1923
+ # | bmm |
1924
+ # | | |
1925
+ # | OPT(add) |
1926
+
1927
+ linear_weight_prepack_cases = itertools.product(
1928
+ [torch.float32, torch.bfloat16], [True, False]
1929
+ )
1930
+
1931
+ # Step 1: register patterns from mm and addmm
1932
+ for dtype, input_dim_exceeds_two in linear_weight_prepack_cases:
1933
+ weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns(
1934
+ dtype, input_dim_exceeds_two
1935
+ )
1936
+ for weight_prepack_pattern in weight_prepack_patterns:
1937
+ # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
1938
+ _register_qlinear_weight_prepack_pass(
1939
+ weight_prepack_pattern,
1940
+ pass_number=1,
1941
+ dtype=dtype,
1942
+ input_dim_exceeds_two=input_dim_exceeds_two,
1943
+ )
1944
+
1945
+ # Step 2: register patterns from bmm
1946
+ # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous
1947
+ # refer to:
1948
+ # https://github.com/pytorch/pytorch/blob/
1949
+ # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
1950
+ # in this case, we can convert it back to qlinear
1951
+ for dtype, with_bias in itertools.product(
1952
+ [torch.float32, torch.bfloat16], [True, False]
1953
+ ):
1954
+ bmm_pattern = _generate_qlinear_weight_prepack_patterns(
1955
+ dtype=dtype,
1956
+ input_dim_exceeds_two=True,
1957
+ input_contiguous=False,
1958
+ with_bias=with_bias,
1959
+ )
1960
+ _register_qlinear_weight_prepack_pass(
1961
+ bmm_pattern,
1962
+ pass_number=1
1963
+ if with_bias
1964
+ else 2, # if with_bias, there is an output add, so we should try to match it firstly
1965
+ dtype=dtype,
1966
+ input_dim_exceeds_two=True,
1967
+ input_contiguous=False,
1968
+ )
1969
+
1970
+
1971
+ @functools.lru_cache(None)
1972
+ def _register_quantization_weight_pack_pass():
1973
+ # Step 1: Dequant promotion for int8-mixed-fp32/bf16
1974
+ _register_dequant_promotion()
1975
+
1976
+ # Step 2: QConv weight prepack
1977
+ _register_qconv_weight_prepack()
1978
+
1979
+ # Step 3: QLinear weight prepack
1980
+ _register_qlinear_weight_prepack()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+
4
+ import torch
5
+
6
+ from torch.fx.passes.shape_prop import _extract_tensor_metadata
7
+ from .. import config, inductor_prims
8
+ from ..pattern_matcher import (
9
+ CallFunctionVarArgs,
10
+ Match,
11
+ PatternMatcherPass,
12
+ register_graph_pattern,
13
+ )
14
+ from ..virtualized import V
15
+
16
+ log = logging.getLogger(__name__)
17
+ patterns = PatternMatcherPass()
18
+ aten = torch.ops.aten
19
+
20
+
21
+ def replace_random_passes(gm: torch.fx.GraphModule):
22
+ """Modify the given FX graph to use backend-native random ops"""
23
+ if config.fallback_random:
24
+ return 0
25
+
26
+ count = patterns.apply(gm)
27
+ count += fuse_seed_creation_pass(gm.graph)
28
+
29
+ return count
30
+
31
+
32
+ def fuse_seed_creation_pass(graph: torch.fx.Graph):
33
+ """
34
+ Horizontally fuse all the seed generation on each device
35
+
36
+ a = inductor_seed(dev)
37
+ b = inductor_seed(dev)
38
+
39
+ Becomes:
40
+ seeds = inductor_seeds(2, dev)
41
+ a = inductor_lookup_seed(seeds, 0)
42
+ b = inductor_lookup_seed(seeds, 1)
43
+
44
+ We do this because seed creation is entirely launch overhead bound.
45
+ """
46
+ device_seeds = collections.defaultdict(list)
47
+ for node in graph.nodes:
48
+ if CallFunctionVarArgs(inductor_prims.seed).match(node):
49
+ device_seeds[node.args[0]].append(node)
50
+
51
+ if not device_seeds:
52
+ return 0
53
+
54
+ for device, seeds in device_seeds.items():
55
+ with graph.inserting_before(seeds[0]):
56
+ combined = graph.call_function(inductor_prims.seeds, (len(seeds), device))
57
+ with V.fake_mode:
58
+ combined.meta["val"] = torch.empty(
59
+ [len(seeds)], device=device, dtype=torch.int64
60
+ )
61
+ combined.meta["tensor_meta"] = _extract_tensor_metadata(
62
+ combined.meta["val"]
63
+ )
64
+
65
+ for idx, seed in enumerate(seeds):
66
+ with graph.inserting_before(seed):
67
+ new_seed = graph.call_function(
68
+ inductor_prims.lookup_seed, (combined, idx)
69
+ )
70
+ seed.replace_all_uses_with(new_seed)
71
+ new_seed.meta.update(seed.meta)
72
+ graph.erase_node(seed)
73
+
74
+ return len(device_seeds)
75
+
76
+
77
+ def default_kwargs(device):
78
+ return {}
79
+
80
+
81
+ def get_device(device):
82
+ if device is not None:
83
+ return device
84
+ return torch.empty([]).device # default device
85
+
86
+
87
+ @register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns)
88
+ @register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns)
89
+ @register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns)
90
+ @register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns)
91
+ def replace_random(
92
+ match: Match,
93
+ size,
94
+ *,
95
+ generator=None,
96
+ dtype=None,
97
+ device=None,
98
+ layout=None,
99
+ pin_memory=None,
100
+ ):
101
+ if generator is not None:
102
+ return
103
+
104
+ def replacement(size):
105
+ result = inductor_prims.random(
106
+ size, inductor_prims.seed(device), mode, **default_kwargs(device)
107
+ )
108
+ if dtype is not None:
109
+ result = result.to(dtype)
110
+ return result
111
+
112
+ mode = {
113
+ aten.rand: "rand",
114
+ aten.randn: "randn",
115
+ }[
116
+ match.output_node().target.overloadpacket # type: ignore[union-attr]
117
+ ] # type: ignore[union-attr]
118
+ device = get_device(device)
119
+ match.replace_by_example(replacement, [size])
120
+
121
+
122
+ @register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns)
123
+ def replace_randint(
124
+ match: Match,
125
+ low,
126
+ high,
127
+ size,
128
+ *,
129
+ dtype=torch.int64,
130
+ device=None,
131
+ layout=None,
132
+ pin_memory=None,
133
+ ):
134
+ def replacement(size):
135
+ result = inductor_prims.randint(low, high, size, inductor_prims.seed(device))
136
+ return result.to(dtype)
137
+
138
+ device = get_device(device)
139
+ match.replace_by_example(replacement, [size])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc ADDED
Binary file (19.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc ADDED
Binary file (16.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc ADDED
Binary file (16.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-311.pyc ADDED
Binary file (7.36 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
35
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
36
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
37
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
38
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
39
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
40
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
41
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
42
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
43
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
44
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
45
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
46
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
47
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
48
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
49
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
50
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
51
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
52
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
53
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
54
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
55
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
56
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
57
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
58
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
59
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
60
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
61
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
62
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
63
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
64
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
65
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
66
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
67
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
68
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
69
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
70
+ clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
71
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
72
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
73
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
74
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
75
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
76
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
77
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
78
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
79
+ div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale_factor'))
80
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
81
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
82
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
83
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
84
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
85
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
86
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
87
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
88
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
89
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
90
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
91
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
92
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
93
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
94
+ _sfdp_pattern_12_training = MultiOutputPattern([view_default_5,
95
+ permute_default_6,
96
+ permute_default_9,
97
+ permute_default_11,
98
+ None,
99
+ None
100
+ ])
101
+
102
+
103
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
104
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
105
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
106
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
107
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
108
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
109
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
110
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
111
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
112
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
113
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
114
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
115
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
116
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
117
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
118
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
119
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
120
+ clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
121
+ expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
122
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
123
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
124
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
125
+ clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
126
+ view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
127
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
128
+ _sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
129
+
130
+
131
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
132
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
133
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
134
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
135
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
136
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
137
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
138
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
139
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
140
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
141
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
142
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
143
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
144
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
145
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
146
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
147
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
148
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
149
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
150
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
151
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
152
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
153
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
154
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
155
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
156
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
157
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
158
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
159
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
160
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
161
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
162
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
163
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
164
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
165
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
166
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
167
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
168
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
169
+ clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
170
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
171
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
172
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
173
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
174
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
175
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
176
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
177
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
178
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
179
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
180
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
181
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
182
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
183
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
184
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
185
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
186
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
187
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
188
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
189
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
190
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
191
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
192
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
193
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
194
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
195
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
196
+ _sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5,
197
+ permute_default_6,
198
+ permute_default_9,
199
+ permute_default_11,
200
+ None,
201
+ None
202
+ ])
203
+
204
+
205
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
206
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
207
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
208
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
209
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
210
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
211
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
212
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
213
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
214
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
215
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
216
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
217
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
218
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
219
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
220
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
221
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
222
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
223
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
224
+ clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
225
+ expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
226
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
227
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
228
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
229
+ clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
230
+ view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
231
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
232
+ _sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
35
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
36
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
37
+ bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2)
38
+ amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
39
+ sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
40
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
41
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
42
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
43
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
44
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
45
+ bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
46
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
47
+ bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
48
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
49
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
50
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
51
+ clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
52
+ alias_default = CallFunction(aten.alias.default, div_Tensor)
53
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
54
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
55
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
56
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
57
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
58
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
59
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5, _users=2)
60
+ permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
61
+ bmm_default_3 = CallFunction(aten.bmm.default, sub_Tensor_1, permute_default_2)
62
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
63
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, sub_Tensor_1)
64
+ permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored())
65
+ permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored())
66
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1'))
67
+ _sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1,
68
+ bmm_default_3,
69
+ permute_default_4,
70
+ bmm_default_5,
71
+ None
72
+ ])
73
+
74
+
75
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
76
+ bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2)
77
+ amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
78
+ sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
79
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
80
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
81
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
82
+ clone_default = CallFunction(aten.clone.default, div_Tensor)
83
+ _sfdp_pattern_13_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'))
84
+
85
+
86
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
87
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
88
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
89
+ bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default)
90
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2)
91
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
92
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
93
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
94
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
95
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
96
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
97
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
98
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
99
+ bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
100
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
101
+ bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
102
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
103
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
104
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
105
+ clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
106
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
107
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
108
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
109
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
110
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
111
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
112
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
113
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
114
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
115
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
116
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored(), _users=2)
117
+ permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
118
+ bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2)
119
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
120
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5)
121
+ permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored())
122
+ permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored())
123
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1'))
124
+ _sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1,
125
+ bmm_default_3,
126
+ permute_default_4,
127
+ bmm_default_5,
128
+ None
129
+ ])
130
+
131
+
132
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
133
+ bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default)
134
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2)
135
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
136
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
137
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
138
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
139
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
140
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
141
+ clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
142
+ _sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
35
+ expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
36
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
37
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
38
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
39
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
40
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
41
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
42
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
43
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
44
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
45
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
46
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
47
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
48
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
49
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
50
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
51
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
52
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
53
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
54
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
55
+ expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
56
+ view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
57
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
58
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
59
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
60
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
61
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
62
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
63
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
64
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
65
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
66
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
67
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
68
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
69
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
70
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
71
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
72
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
73
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
74
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
75
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
76
+ where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, sub_Tensor_1)
77
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
78
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
79
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
80
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
81
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
82
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
83
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
84
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
85
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
86
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
87
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
88
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
89
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
90
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
91
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
92
+ _sfdp_pattern_15_training = MultiOutputPattern([view_default_5,
93
+ permute_default_6,
94
+ permute_default_9,
95
+ permute_default_11,
96
+ None,
97
+ None
98
+ ])
99
+
100
+
101
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
102
+ view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
103
+ expand_default = CallFunction(aten.expand.default, view_default, Ignored())
104
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
105
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
106
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
107
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
108
+ view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
109
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
110
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
111
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
112
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
113
+ view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
114
+ bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
115
+ view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
116
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
117
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
118
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
119
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
120
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
121
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
122
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
123
+ expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
124
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
125
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
126
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
127
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
128
+ view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
129
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
130
+ _sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
131
+
132
+
133
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
134
+ expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
135
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
136
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
137
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
138
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
139
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
140
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
141
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
142
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
143
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
144
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
145
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
146
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
147
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
148
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
149
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
150
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
151
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
152
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
153
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
154
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
155
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
156
+ expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
157
+ view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
158
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
159
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
160
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
161
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
162
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
163
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
164
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
165
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
166
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
167
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
168
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
169
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
170
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
171
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
172
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
173
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
174
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
175
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
176
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
177
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
178
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
179
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
180
+ where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4)
181
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
182
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
183
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
184
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
185
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
186
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
187
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
188
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
189
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
190
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
191
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
192
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
193
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
194
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
195
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
196
+ _sfdp_pattern_15_half_training = MultiOutputPattern([view_default_5,
197
+ permute_default_6,
198
+ permute_default_9,
199
+ permute_default_11,
200
+ None,
201
+ None
202
+ ])
203
+
204
+
205
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
206
+ view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
207
+ expand_default = CallFunction(aten.expand.default, view_default, Ignored())
208
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
209
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
210
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
211
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
212
+ view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
213
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
214
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
215
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
216
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
217
+ view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
218
+ bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
219
+ view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
220
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
221
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
222
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
223
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
224
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
225
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
226
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
227
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
228
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
229
+ expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
230
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
231
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
232
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
233
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
234
+ view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
235
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
236
+ _sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
35
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
36
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
37
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
38
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
39
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
40
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
41
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
42
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
43
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
44
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
45
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
46
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
47
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
48
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
49
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
50
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
51
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
52
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
53
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
54
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
55
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
56
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
57
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
58
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
59
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
60
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
61
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
62
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
63
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
64
+ clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
65
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
66
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
67
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
68
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
69
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
70
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
71
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
72
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
73
+ div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale_factor'))
74
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
75
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
76
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
77
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
78
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
79
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
80
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
81
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
82
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
83
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
84
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
85
+ _sfdp_pattern_3_training = MultiOutputPattern([view_default_5,
86
+ view_default_9,
87
+ permute_default_4,
88
+ view_default_11,
89
+ None,
90
+ None
91
+ ])
92
+
93
+
94
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
95
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
96
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
97
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
98
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
99
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
100
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
101
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
102
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
103
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
104
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
105
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
106
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
107
+ clone_default = CallFunction(aten.clone.default, div_Tensor_1)
108
+ expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
109
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
110
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
111
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
112
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
113
+ _sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
114
+
115
+
116
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
117
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
118
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
119
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
120
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
121
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
122
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
123
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
124
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
125
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
126
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
127
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
128
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
129
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
130
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
131
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
132
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
133
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
134
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
135
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
136
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
137
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
138
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
139
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
140
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
141
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
142
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
143
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
144
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
145
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
146
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
147
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
148
+ clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
149
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
150
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
151
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
152
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
153
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
154
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
155
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
156
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
157
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
158
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
159
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
160
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
161
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
162
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
163
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
164
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
165
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
166
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
167
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
168
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
169
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
170
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
171
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
172
+ _sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5,
173
+ view_default_9,
174
+ permute_default_4,
175
+ view_default_11,
176
+ None,
177
+ None
178
+ ])
179
+
180
+
181
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
182
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
183
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
184
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
185
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
186
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
187
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
188
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
189
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
190
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
191
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
192
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
193
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
194
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
195
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
196
+ clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
197
+ expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
198
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
199
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
200
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
201
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
202
+ _sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
35
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
36
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
37
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
38
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
39
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
40
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
41
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
42
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
43
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
44
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
45
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
46
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
47
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
48
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
49
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
50
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
51
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
52
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
53
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
54
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
55
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
56
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
57
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
58
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
59
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
60
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
61
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
62
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
63
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
64
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
65
+ clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
66
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
67
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
68
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
69
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
70
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
71
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
72
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
73
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
74
+ div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
75
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
76
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
77
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
78
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
79
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
80
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
81
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
82
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
83
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
84
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
85
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
86
+ _sfdp_pattern_6_training = MultiOutputPattern([view_default_5,
87
+ view_default_9,
88
+ permute_default_4,
89
+ view_default_11,
90
+ None,
91
+ None
92
+ ])
93
+
94
+
95
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
96
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
97
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
98
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
99
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
100
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
101
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
102
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
103
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
104
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
105
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
106
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
107
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
108
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
109
+ clone_default = CallFunction(aten.clone.default, div_Tensor_1)
110
+ expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
111
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
112
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
113
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
114
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
115
+ _sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
116
+
117
+
118
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
119
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
120
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
121
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
122
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
123
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
124
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
125
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
126
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
127
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
128
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
129
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
130
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
131
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
132
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
133
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
134
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
135
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
136
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
137
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
138
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
139
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
140
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
141
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
142
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
143
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
144
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
145
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
146
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
147
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
148
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
149
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
150
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
151
+ clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
152
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
153
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
154
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
155
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
156
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
157
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
158
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
159
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
160
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
161
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
162
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
163
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored())
164
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
165
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
166
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
167
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
168
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
169
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
170
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
171
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
172
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
173
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
174
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
175
+ _sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5,
176
+ view_default_9,
177
+ permute_default_4,
178
+ view_default_11,
179
+ None,
180
+ None
181
+ ])
182
+
183
+
184
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
185
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
186
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
187
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
188
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
189
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
190
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
191
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
192
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
193
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
194
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
195
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
196
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
197
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
198
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
199
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
200
+ clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
201
+ expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
202
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
203
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
204
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
205
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
206
+ _sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
35
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
36
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
37
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
38
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
39
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
40
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
41
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
42
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
43
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
44
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
45
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
46
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
47
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
48
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
49
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
50
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
51
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
52
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
53
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
54
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
55
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
56
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
57
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
58
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
59
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
60
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
61
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
62
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
63
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
64
+ view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
65
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
66
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
67
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
68
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
69
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
70
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
71
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
72
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
73
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
74
+ div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
75
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
76
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
77
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
78
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
79
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
80
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
81
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
82
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
83
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
84
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
85
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
86
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
87
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
88
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
89
+ _sfdp_pattern_8_training = MultiOutputPattern([view_default_5,
90
+ permute_default_6,
91
+ permute_default_9,
92
+ permute_default_11
93
+ ])
94
+
95
+
96
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
97
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
98
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
99
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
100
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
101
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
102
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
103
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
104
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
105
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
106
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
107
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
108
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
109
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
110
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
111
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
112
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
113
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
114
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
115
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
116
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
117
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
118
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
119
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
120
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
121
+ _sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
122
+
123
+
124
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
125
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
126
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
127
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
128
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
129
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
130
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
131
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
132
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
133
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
134
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
135
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
136
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
137
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
138
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
139
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
140
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
141
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
142
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
143
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
144
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
145
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
146
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
147
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
148
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
149
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
150
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
151
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
152
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
153
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
154
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
155
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
156
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
157
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
158
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
159
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
160
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
161
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
162
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
163
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
164
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
165
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored())
166
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
167
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
168
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
169
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
170
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
171
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
172
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
173
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
174
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
175
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
176
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
177
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
178
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
179
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
180
+ _sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5,
181
+ permute_default_6,
182
+ permute_default_9,
183
+ permute_default_11
184
+ ])
185
+
186
+
187
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
188
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
189
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
190
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
191
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
192
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
193
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
194
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
195
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
196
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
197
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
198
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
199
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
200
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
201
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
202
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
203
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
204
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
205
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
206
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
207
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
208
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
209
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
210
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
211
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
212
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
213
+ _sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc ADDED
Binary file (20.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc ADDED
Binary file (10.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ..lowering import register_lowering
4
+ from ..select_algorithm import (
5
+ autotune_select_algorithm,
6
+ ExternKernelChoice,
7
+ TritonTemplate,
8
+ )
9
+ from ..utils import ceildiv as cdiv, use_aten_gemm_kernels, use_triton_template
10
+
11
+ from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options
12
+
13
+ aten = torch.ops.aten
14
+
15
+
16
+ def bmm_grid(b, m, n, meta):
17
+ return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
18
+
19
+
20
+ bmm_template = TritonTemplate(
21
+ name="bmm",
22
+ grid=bmm_grid,
23
+ source=r"""
24
+ {{def_kernel("A", "B")}}
25
+ M = {{size("A", -2)}}
26
+ N = {{size("B", -1)}}
27
+ K = {{size("A", -1)}}
28
+
29
+ stride_aq = {{stride("A", 0)}}
30
+ stride_am = {{stride("A", 1)}}
31
+ stride_ak = {{stride("A", 2)}}
32
+
33
+ stride_bq = {{stride("B", 0)}}
34
+ stride_bk = {{stride("B", 1)}}
35
+ stride_bn = {{stride("B", 2)}}
36
+
37
+ # based on triton.ops.matmul
38
+ pid = tl.program_id(0)
39
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
40
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
41
+
42
+ # re-order program ID for better L2 performance
43
+ width = GROUP_M * grid_n
44
+ group_id = pid // width
45
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
46
+ pid_m = group_id * GROUP_M + (pid % group_size)
47
+ pid_n = (pid % width) // (group_size)
48
+
49
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
50
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
51
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
52
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
53
+ rk = tl.arange(0, BLOCK_K)
54
+
55
+ idx_q = tl.program_id(1) # batch dimension for BMM
56
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
57
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
58
+
59
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
60
+ for k in range(K, 0, -BLOCK_K):
61
+ if EVEN_K:
62
+ a = tl.load(A)
63
+ b = tl.load(B)
64
+ else:
65
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
66
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
67
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
68
+ A += BLOCK_K * stride_ak
69
+ B += BLOCK_K * stride_bk
70
+
71
+ # rematerialize rm and rn to save registers
72
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
73
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
74
+ idx_q = tl.program_id(1) # batch dimension for BMM
75
+ idx_m = rm[:, None]
76
+ idx_n = rn[None, :]
77
+ mask = (idx_m < M) & (idx_n < N)
78
+
79
+ # inductor generates a suffix
80
+ {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
81
+ """,
82
+ )
83
+
84
+ aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
85
+ aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")
86
+
87
+
88
+ @register_lowering(aten.bmm)
89
+ def tuned_bmm(mat1, mat2, *, layout=None):
90
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
91
+
92
+ # options to tune from
93
+ choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
94
+ if use_triton_template(layout):
95
+ for config in mm_configs(m, n, k):
96
+ bmm_template.maybe_append_choice(
97
+ choices,
98
+ input_nodes=(mat1, mat2),
99
+ layout=layout,
100
+ **mm_options(config, m, n, k, layout),
101
+ )
102
+
103
+ return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
104
+
105
+
106
+ # Don't register this since it is slower than decomposing it
107
+ # @register_lowering(aten.baddbmm)
108
+ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
109
+ m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
110
+
111
+ # options to tune from
112
+ choices = (
113
+ [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
114
+ if use_aten_gemm_kernels()
115
+ else []
116
+ )
117
+ if use_triton_template(layout):
118
+ for config in mm_configs(m, n, k):
119
+ bmm_template.maybe_append_choice(
120
+ choices,
121
+ input_nodes=(inp, mat1, mat2),
122
+ layout=layout,
123
+ **mm_options(config, m, n, k, layout),
124
+ prefix_args=1,
125
+ epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
126
+ )
127
+
128
+ return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import torch
6
+ from torch._inductor.virtualized import V
7
+ from .. import config as inductor_config
8
+ from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate
9
+ from ..lowering import register_lowering
10
+ from ..select_algorithm import (
11
+ autotune_select_algorithm,
12
+ ExternKernelChoice,
13
+ TritonTemplate,
14
+ )
15
+ from ..utils import (
16
+ use_aten_gemm_kernels,
17
+ use_cutlass_template,
18
+ use_max_autotune,
19
+ use_triton_template,
20
+ )
21
+ from .mm_common import (
22
+ addmm_epilogue,
23
+ int8_mm_configs,
24
+ mm_args,
25
+ mm_configs,
26
+ mm_grid,
27
+ mm_options,
28
+ )
29
+
30
+ log = logging.getLogger(__name__)
31
+ aten = torch.ops.aten
32
+
33
+ mm_template = TritonTemplate(
34
+ name="mm",
35
+ grid=mm_grid,
36
+ source=r"""
37
+ {{def_kernel("A", "B")}}
38
+ M = {{size("A", 0)}}
39
+ N = {{size("B", 1)}}
40
+ K = {{size("A", 1)}}
41
+ if M * N == 0:
42
+ # early exit due to zero-size input(s)
43
+ return
44
+ stride_am = {{stride("A", 0)}}
45
+ stride_ak = {{stride("A", 1)}}
46
+ stride_bk = {{stride("B", 0)}}
47
+ stride_bn = {{stride("B", 1)}}
48
+
49
+ # based on triton.ops.matmul
50
+ pid = tl.program_id(0)
51
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
52
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
53
+
54
+ # re-order program ID for better L2 performance
55
+ width = GROUP_M * grid_n
56
+ group_id = pid // width
57
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
58
+ pid_m = group_id * GROUP_M + (pid % group_size)
59
+ pid_n = (pid % width) // (group_size)
60
+
61
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
62
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
63
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
64
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
65
+ rk = tl.arange(0, BLOCK_K)
66
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
67
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
68
+
69
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
70
+ for k in range(K, 0, -BLOCK_K):
71
+ if EVEN_K:
72
+ a = tl.load(A)
73
+ b = tl.load(B)
74
+ else:
75
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
76
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
77
+ if B_PROLOGUE_CAST_TYPE is not None:
78
+ b = b.to(B_PROLOGUE_CAST_TYPE)
79
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
80
+ A += BLOCK_K * stride_ak
81
+ B += BLOCK_K * stride_bk
82
+
83
+ # rematerialize rm and rn to save registers
84
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
85
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
86
+ idx_m = rm[:, None]
87
+ idx_n = rn[None, :]
88
+ mask = (idx_m < M) & (idx_n < N)
89
+
90
+ # inductor generates a suffix
91
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
92
+ """,
93
+ )
94
+
95
+ aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
96
+
97
+
98
+ aten_addmm = ExternKernelChoice(
99
+ torch.addmm, "at::addmm_out", op_overload=aten.addmm.default
100
+ )
101
+
102
+ aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm")
103
+
104
+
105
+ def _is_int8_mat(mat):
106
+ return mat.get_dtype() in (torch.int8, torch.uint8)
107
+
108
+
109
+ def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
110
+ """
111
+ Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
112
+ kernel under the hood. There are a few shapes where this is slower,
113
+ but they are rare.
114
+ """
115
+ if inp.stride(0) == 0 or inp.size(0) == 1:
116
+ return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta)
117
+ return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta)
118
+
119
+
120
+ aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
121
+
122
+
123
+ @register_lowering(aten.mm, type_promotion_kind=None)
124
+ def tuned_mm(mat1, mat2, *, layout=None):
125
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
126
+
127
+ # options to tune from
128
+ choices = [aten_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
129
+
130
+ if m * n != 0 and use_triton_template(layout):
131
+ for config in mm_configs(m, n, k):
132
+ mm_template.maybe_append_choice(
133
+ choices,
134
+ input_nodes=(mat1, mat2),
135
+ layout=layout,
136
+ **mm_options(config, m, n, k, layout),
137
+ )
138
+
139
+ if m * n != 0 and use_cutlass_template(layout):
140
+ CUTLASSGemmTemplate.add_cutlass_gemm_choices(
141
+ choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
142
+ )
143
+
144
+ from torch._inductor.ir import FixedLayout, FlexibleLayout
145
+
146
+ if (
147
+ len(choices) == 1
148
+ and use_aten_gemm_kernels()
149
+ and isinstance(layout, FixedLayout)
150
+ ):
151
+ # If we are not autotuning, we can swap to a FlexibleLayout
152
+ # in order to get fusion optimizations to kick in, e.g. ConcatFusion
153
+ layout = FlexibleLayout(
154
+ device=layout.device, dtype=layout.dtype, size=layout.size
155
+ )
156
+ choices = [aten_mm.bind((mat1, mat2), layout)]
157
+
158
+ return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
159
+
160
+
161
+ @register_lowering(aten._int_mm, type_promotion_kind=None)
162
+ def tuned_int_mm(mat1, mat2, *, layout=None):
163
+ m, n, k, layout, mat1, mat2 = mm_args(
164
+ mat1, mat2, layout=layout, out_dtype=torch.int32
165
+ )
166
+ choices = (
167
+ [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
168
+ )
169
+ if m * n != 0 and use_triton_template(layout, enable_int32=True):
170
+ # TODO: Re-enable eager mode implementation once cuBLAS is fixed
171
+ choices = []
172
+ for config in int8_mm_configs(m, n, k):
173
+ mm_template.maybe_append_choice(
174
+ choices,
175
+ input_nodes=(mat1, mat2),
176
+ layout=layout,
177
+ **mm_options(config, m, n, k, layout),
178
+ )
179
+ return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
180
+
181
+
182
+ @register_lowering(aten.addmm, type_promotion_kind=None)
183
+ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
184
+ m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
185
+ if m * n == 0 or not use_max_autotune():
186
+ choices = (
187
+ [
188
+ aten_addmm.bind(
189
+ (inp, mat1, mat2),
190
+ layout,
191
+ alpha=alpha,
192
+ beta=beta,
193
+ )
194
+ ]
195
+ if use_aten_gemm_kernels()
196
+ else []
197
+ )
198
+ return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)
199
+
200
+ choices = (
201
+ [
202
+ aten_addmm.bind(
203
+ (inp_expanded, mat1, mat2),
204
+ layout,
205
+ alpha=alpha,
206
+ beta=beta,
207
+ )
208
+ ]
209
+ if use_aten_gemm_kernels()
210
+ else []
211
+ )
212
+
213
+ if (
214
+ use_aten_gemm_kernels()
215
+ and inp_expanded.get_stride()[0] == 0
216
+ and inp_expanded.get_device().type == "cuda"
217
+ and inductor_config.triton.autotune_cublasLt
218
+ ):
219
+ # unexpand inp to make sure fused addmm from cublasLt is used
220
+ choices.insert(
221
+ 0,
222
+ aten_bias_addmm.bind(
223
+ (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
224
+ ),
225
+ )
226
+
227
+ if use_triton_template(layout):
228
+ for config in mm_configs(m, n, k):
229
+ mm_template.maybe_append_choice(
230
+ choices,
231
+ input_nodes=(inp_expanded, mat1, mat2),
232
+ layout=layout,
233
+ **mm_options(config, m, n, k, layout),
234
+ prefix_args=1,
235
+ epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
236
+ )
237
+
238
+ if use_cutlass_template(layout):
239
+ CUTLASSGemmTemplate.add_cutlass_gemm_choices(
240
+ choices,
241
+ layout,
242
+ [mat1, mat2, inp_expanded],
243
+ alpha=alpha,
244
+ beta=beta,
245
+ input_reorder=[2, 0, 1],
246
+ fuseable=False,
247
+ )
248
+
249
+ return autotune_select_algorithm(
250
+ "addmm", choices, [inp_expanded, mat1, mat2], layout
251
+ )
252
+
253
+
254
+ def fallback_mixed_mm(mat1, mat2, *, out):
255
+ return torch.mm(mat1, mat2.to(mat1.dtype), out=out)
256
+
257
+
258
+ aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None)
259
+
260
+
261
+ @functools.lru_cache(None)
262
+ def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
263
+ props = torch.cuda.get_device_properties(index or 0)
264
+ return props.major <= 7
265
+
266
+
267
+ def tuned_mixed_mm(mat1, mat2, mat2_dtype):
268
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
269
+ choices = [aten_fallback_mixed_mm.bind((mat1, mat2), layout)]
270
+ if (
271
+ mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous()
272
+ ) or _is_sm7x_or_older_gpu(layout.device.index):
273
+ # can't use triton kernel unless one of these is true or if running on v100 (numerical issues)
274
+ return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout)
275
+ if inductor_config.force_mixed_mm:
276
+ choices = []
277
+ b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
278
+ has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
279
+ for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
280
+ mm_template.maybe_append_choice(
281
+ choices,
282
+ input_nodes=(mat1, mat2),
283
+ layout=layout,
284
+ **mm_options(config, m, n, k, layout, b_prologue_cast_type),
285
+ )
286
+ return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout)
287
+
288
+
289
+ # This op is a special case of the int_mm op which we use based on the pattern
290
+ # _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent
291
+ # realization of the int32 _int_mm output by forcing fusion with the mul op.
292
+ # This is only used when config.force_fuse_int_mm_with_mul = True
293
+ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
294
+ out_dtype = (
295
+ torch.promote_types(mat3.get_dtype(), torch.int32)
296
+ if out_dtype is None
297
+ else out_dtype
298
+ )
299
+ m, n, k, layout, mat1, mat2, mat3 = mm_args(
300
+ mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
301
+ )
302
+ choices: List[Dict[Any, Any]] = []
303
+ for config in int8_mm_configs(m, n, k):
304
+ mm_template.maybe_append_choice(
305
+ choices,
306
+ input_nodes=(mat1, mat2, mat3),
307
+ layout=layout,
308
+ **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"),
309
+ suffix_args=1,
310
+ epilogue_fn=V.ops.mul,
311
+ )
312
+ return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+ from typing import cast, List, Tuple
4
+
5
+ import sympy
6
+
7
+ import torch
8
+ from torch._inductor.select_algorithm import realize_inputs
9
+ from torch._inductor.virtualized import V
10
+
11
+ from .. import config as inductor_config
12
+ from ..utils import ceildiv as cdiv, next_power_of_2
13
+
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
+ def triton_config(num_stages, num_warps, **kwargs):
18
+ from triton import Config
19
+
20
+ return Config(kwargs, num_stages=num_stages, num_warps=num_warps)
21
+
22
+
23
+ def filtered_configs(
24
+ m: int,
25
+ n: int,
26
+ k: int,
27
+ configs: List[Tuple[int, int, int, int, int]],
28
+ has_int8_tensor=False,
29
+ ):
30
+ """Heuristic to shrink configs when they are bigger than the input size"""
31
+
32
+ # According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424
33
+ # it's safer to use at least [32, 32] block size for int8/uint8
34
+ # tensors
35
+ min_block_size = 32 if has_int8_tensor else 16
36
+ m = max(
37
+ next_power_of_2(
38
+ V.graph.sizevars.size_hint(
39
+ m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
40
+ )
41
+ ),
42
+ min_block_size,
43
+ )
44
+ n = max(
45
+ next_power_of_2(
46
+ V.graph.sizevars.size_hint(
47
+ n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
48
+ )
49
+ ),
50
+ min_block_size,
51
+ )
52
+ k = max(
53
+ next_power_of_2(
54
+ V.graph.sizevars.size_hint(
55
+ k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
56
+ )
57
+ ),
58
+ min_block_size,
59
+ )
60
+ used = set()
61
+ for block_m, block_n, block_k, num_stages, num_warps in configs:
62
+ # shrink configs for small sizes
63
+ block_m = max(min(block_m, m), min_block_size)
64
+ block_n = max(min(block_n, n), min_block_size)
65
+ block_k = max(min(block_k, k), min_block_size)
66
+ # each warp computes 16x16 tile = 256
67
+ num_warps = min(num_warps, block_m * block_n // 256)
68
+ if torch.version.hip:
69
+ for matrix_instr_nonkdim in [0, 16]:
70
+ if matrix_instr_nonkdim != 0 and (
71
+ block_m % matrix_instr_nonkdim != 0
72
+ or block_n % matrix_instr_nonkdim != 0
73
+ ):
74
+ # block_m and block_n must be a multiple of matrix_instr_nonkdim
75
+ continue
76
+ if (
77
+ block_m,
78
+ block_n,
79
+ block_k,
80
+ num_stages,
81
+ num_warps,
82
+ matrix_instr_nonkdim,
83
+ ) not in used:
84
+ used.add(
85
+ (
86
+ block_m,
87
+ block_n,
88
+ block_k,
89
+ num_stages,
90
+ num_warps,
91
+ matrix_instr_nonkdim,
92
+ )
93
+ )
94
+ yield triton_config(
95
+ BLOCK_M=block_m,
96
+ BLOCK_N=block_n,
97
+ BLOCK_K=block_k,
98
+ num_stages=num_stages,
99
+ num_warps=num_warps,
100
+ matrix_instr_nonkdim=matrix_instr_nonkdim,
101
+ )
102
+ else:
103
+ if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
104
+ used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
105
+ yield triton_config(
106
+ BLOCK_M=block_m,
107
+ BLOCK_N=block_n,
108
+ BLOCK_K=block_k,
109
+ num_stages=num_stages,
110
+ num_warps=num_warps,
111
+ )
112
+
113
+
114
+ # List of dictionaries to store the kernel configs. Configs that evaluate to true
115
+ # will be utilised on the target platform
116
+ mm_kernel_configs = [
117
+ # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
118
+ {"config": (64, 64, 32, 2, 4), "cond": True},
119
+ {"config": (64, 128, 32, 3, 4), "cond": True},
120
+ {"config": (128, 64, 32, 3, 4), "cond": True},
121
+ {"config": (64, 128, 32, 4, 8), "cond": True},
122
+ {"config": (128, 64, 32, 4, 8), "cond": True},
123
+ {"config": (64, 32, 32, 5, 8), "cond": True},
124
+ {"config": (32, 64, 32, 5, 8), "cond": True},
125
+ {"config": (128, 128, 32, 2, 8), "cond": True},
126
+ {"config": (64, 64, 64, 3, 8), "cond": True},
127
+ {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
128
+ {"config": (64, 64, 16, 2, 4), "cond": True},
129
+ {"config": (32, 32, 16, 1, 2), "cond": True},
130
+ ]
131
+
132
+ int8_mm_kernel_configs = [
133
+ {"config": (64, 64, 32, 2, 4), "cond": True},
134
+ {"config": (64, 128, 32, 3, 4), "cond": True},
135
+ {"config": (128, 64, 32, 3, 4), "cond": True},
136
+ {"config": (64, 128, 32, 4, 8), "cond": True},
137
+ {"config": (128, 64, 32, 4, 8), "cond": True},
138
+ {"config": (64, 32, 32, 5, 8), "cond": True},
139
+ {"config": (32, 64, 32, 5, 8), "cond": True},
140
+ {"config": (128, 128, 32, 2, 8), "cond": True},
141
+ {"config": (64, 64, 64, 3, 8), "cond": True},
142
+ # {"config": (32, 32, 128, 2, 4), "cond": True},
143
+ # {"config": (64, 64, 16, 2, 4), "cond": True},
144
+ # {"config": (32, 32, 16, 1, 2), "cond": True},
145
+ {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
146
+ {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
147
+ ]
148
+
149
+ # Create filtered list of configs based on cond evaluation
150
+
151
+
152
+ mm_platform_configs = tuple(
153
+ cast(Tuple[int, int, int, int, int], config["config"])
154
+ for config in mm_kernel_configs
155
+ if config["cond"]
156
+ )
157
+ int8_platform_configs = tuple(
158
+ cast(Tuple[int, int, int, int, int], config["config"])
159
+ for config in int8_mm_kernel_configs
160
+ if config["cond"]
161
+ )
162
+
163
+ # On ROCm convert num_stages to 1 as pipelining provides no benefit
164
+ if torch.version.hip:
165
+ mm_platform_configs = tuple(
166
+ (config[0], config[1], config[2], 1, config[4])
167
+ for config in mm_platform_configs
168
+ )
169
+ int8_platform_configs = tuple(
170
+ (config[0], config[1], config[2], 1, config[4])
171
+ for config in mm_platform_configs
172
+ )
173
+
174
+ mm_configs = functools.partial(
175
+ filtered_configs,
176
+ configs=mm_platform_configs,
177
+ )
178
+
179
+ int8_mm_configs = functools.partial(
180
+ filtered_configs,
181
+ configs=int8_platform_configs,
182
+ )
183
+
184
+
185
+ def mm_grid(m, n, meta):
186
+ """
187
+ The CUDA grid size for matmul triton templates.
188
+ """
189
+ return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)
190
+
191
+
192
+ def acc_type(dtype):
193
+ if dtype in (torch.float16, torch.bfloat16):
194
+ return "tl.float32"
195
+ return f"tl.{dtype}".replace("torch.", "")
196
+
197
+
198
+ def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
199
+ """
200
+ Common options to matmul triton templates.
201
+ """
202
+ even_k_symbolic = (
203
+ # it isn't worth guarding on this
204
+ sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
205
+ == config.kwargs["BLOCK_K"]
206
+ )
207
+ allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
208
+ not inductor_config.force_same_precision
209
+ or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
210
+ )
211
+ return dict(
212
+ GROUP_M=8,
213
+ EVEN_K=even_k_symbolic,
214
+ ALLOW_TF32=allow_tf32,
215
+ ACC_TYPE=acc_type(layout.dtype),
216
+ B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
217
+ num_stages=config.num_stages,
218
+ num_warps=config.num_warps,
219
+ **config.kwargs,
220
+ )
221
+
222
+
223
+ def mm_args(mat1, mat2, *others, layout=None, out_dtype=None, use_4x2_dim=False):
224
+ """
225
+ Common arg processing for mm,bmm,addmm,etc
226
+ """
227
+ mat1, mat2 = realize_inputs(mat1, mat2)
228
+ *b1, m, k1 = mat1.get_size()
229
+ *b2, k2, n = mat2.get_size()
230
+ b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
231
+ if use_4x2_dim:
232
+ k2 = k2 * 2
233
+ k = V.graph.sizevars.guard_equals(k1, k2)
234
+ if layout is None:
235
+ from torch._inductor.ir import FixedLayout
236
+
237
+ if out_dtype is None:
238
+ out_dtype = mat1.get_dtype()
239
+ layout = FixedLayout(
240
+ mat1.get_device(),
241
+ out_dtype,
242
+ [*b, m, n],
243
+ )
244
+ else:
245
+ assert out_dtype is None, "out_dtype is ignored if layout is specified."
246
+
247
+ from ..lowering import expand
248
+
249
+ others = [realize_inputs(expand(x, layout.size)) for x in others]
250
+
251
+ return [m, n, k, layout, mat1, mat2, *others]
252
+
253
+
254
+ def addmm_epilogue(dtype, alpha, beta):
255
+ def epilogue(acc, bias):
256
+ if alpha != 1:
257
+ acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
258
+ if beta != 1:
259
+ bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
260
+ return V.ops.add(acc, bias)
261
+
262
+ return epilogue
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import torch
4
+
5
+ from ..lowering import lowerings
6
+ from ..select_algorithm import (
7
+ autotune_select_algorithm,
8
+ ExternKernelChoice,
9
+ TritonTemplate,
10
+ )
11
+ from ..utils import use_aten_gemm_kernels, use_triton_template
12
+ from ..virtualized import V
13
+ from .mm_common import mm_args, mm_grid, mm_options
14
+
15
+ aten = torch.ops.aten
16
+
17
+ aten_mm_plus_mm = ExternKernelChoice(
18
+ torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm"
19
+ )
20
+
21
+ mm_plus_mm_template = TritonTemplate(
22
+ name="mm_plus_mm",
23
+ grid=mm_grid,
24
+ debug=False,
25
+ source=r"""
26
+ {{def_kernel("A", "B", "C", "D")}}
27
+ M = {{size("A", 0)}}
28
+ N = {{size("B", 1)}}
29
+ K1 = {{size("A", 1)}}
30
+ if M * N == 0:
31
+ # early exit due to zero-size input(s)
32
+ return
33
+ # K2 = {{size("C", 1)}}
34
+ stride_am = {{stride("A", 0)}}
35
+ stride_ak = {{stride("A", 1)}}
36
+ stride_bk = {{stride("B", 0)}}
37
+ stride_bn = {{stride("B", 1)}}
38
+ stride_cm = {{stride("C", 0)}}
39
+ stride_ck = {{stride("C", 1)}}
40
+ stride_dk = {{stride("D", 0)}}
41
+ stride_dn = {{stride("D", 1)}}
42
+
43
+ # based on triton.ops.matmul
44
+ pid = tl.program_id(0)
45
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
46
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
47
+
48
+ # re-order program ID for better L2 performance
49
+ width = GROUP_M * grid_n
50
+ group_id = pid // width
51
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
52
+ pid_m = group_id * GROUP_M + (pid % group_size)
53
+ pid_n = (pid % width) // (group_size)
54
+
55
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
56
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
57
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
58
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
59
+ rk = tl.arange(0, BLOCK_K)
60
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
61
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
62
+ C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
63
+ D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)
64
+
65
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
66
+ for k1 in range(K1, 0, -BLOCK_K):
67
+ # First matmul with A @ B
68
+ if EVEN_K:
69
+ a = tl.load(A)
70
+ b = tl.load(B)
71
+ else:
72
+ a = tl.load(A, mask=rk[None, :] < k1, other=0.)
73
+ b = tl.load(B, mask=rk[:, None] < k1, other=0.)
74
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
75
+ A += BLOCK_K * stride_ak
76
+ B += BLOCK_K * stride_bk
77
+
78
+ for k2 in range(K1, 0, -BLOCK_K):
79
+
80
+ # Second matmul with C @ D
81
+ if EVEN_K:
82
+ c = tl.load(C)
83
+ d = tl.load(D)
84
+ else:
85
+ c = tl.load(C, mask=rk[None, :] < k2, other=0.)
86
+ d = tl.load(D, mask=rk[:, None] < k2, other=0.)
87
+ acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
88
+ C += BLOCK_K * stride_ck
89
+ D += BLOCK_K * stride_dk
90
+
91
+
92
+ idx_m = rm[:, None]
93
+ idx_n = rn[None, :]
94
+ mask = (idx_m < M) & (idx_n < N)
95
+
96
+ # inductor generates a suffix
97
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
98
+ """,
99
+ )
100
+
101
+
102
+ @functools.lru_cache(None)
103
+ def mm_configs():
104
+ import triton
105
+
106
+ # List of dictionaries to store the kernel configs. Configs that evaluate to true
107
+ # will be utilised on the target platform
108
+ mm_triton_configs = [
109
+ {
110
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
111
+ "num_stages": 2,
112
+ "num_warps": 4,
113
+ "cond": True,
114
+ },
115
+ {
116
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
117
+ "num_stages": 3,
118
+ "num_warps": 8,
119
+ "cond": True,
120
+ },
121
+ {
122
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
123
+ "num_stages": 4,
124
+ "num_warps": 16,
125
+ "cond": True,
126
+ },
127
+ {
128
+ "config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
129
+ "num_stages": 4,
130
+ "num_warps": 8,
131
+ "cond": True,
132
+ },
133
+ {
134
+ "config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
135
+ "num_stages": 4,
136
+ "num_warps": 8,
137
+ "cond": True,
138
+ },
139
+ {
140
+ "config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
141
+ "num_stages": 1,
142
+ "num_warps": 8,
143
+ "cond": True,
144
+ },
145
+ {
146
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
147
+ "num_stages": 1,
148
+ "num_warps": 8,
149
+ "cond": True,
150
+ },
151
+ {
152
+ "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
153
+ "num_stages": 1,
154
+ "num_warps": 8,
155
+ "cond": torch.version.hip is None,
156
+ },
157
+ {
158
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
159
+ "num_stages": 2,
160
+ "num_warps": 4,
161
+ "cond": True,
162
+ },
163
+ {
164
+ "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
165
+ "num_stages": 1,
166
+ "num_warps": 2,
167
+ "cond": True,
168
+ },
169
+ ]
170
+
171
+ # Filter out configs in which cond evaluates to true
172
+ # On ROCm convert num_stages to 1 as pipelining provides no benefit
173
+ if torch.version.hip:
174
+ filtered_configs = [
175
+ triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
176
+ for c in mm_triton_configs
177
+ if c["cond"]
178
+ ]
179
+ else:
180
+ filtered_configs = [
181
+ triton.Config(
182
+ c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
183
+ )
184
+ for c in mm_triton_configs
185
+ if c["cond"]
186
+ ]
187
+
188
+ return filtered_configs
189
+
190
+
191
+ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
192
+ """
193
+ Computes mm(mat1, mat2) + mm(mat3, mat4)
194
+ """
195
+ m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
196
+ m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
197
+ # Optimization is optional, because we can always just not do the fusion
198
+ if (
199
+ m1 * n1 == 0
200
+ or m2 * n2 == 0
201
+ or not V.graph.sizevars.statically_known_list_equals(
202
+ mat1.get_size(), mat3.get_size()
203
+ )
204
+ or not V.graph.sizevars.statically_known_list_equals(
205
+ mat2.get_size(), mat4.get_size()
206
+ )
207
+ ):
208
+ # TODO(jansel): support different K values when this is fixed:
209
+ # https://github.com/openai/triton/issues/967
210
+ return lowerings[aten.add](
211
+ lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4)
212
+ )
213
+
214
+ assert layout1 == layout2
215
+ # options to tune from
216
+ choices = (
217
+ [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)]
218
+ if use_aten_gemm_kernels()
219
+ else []
220
+ )
221
+ if use_triton_template(layout1):
222
+ for config in mm_configs():
223
+ # see https://github.com/openai/triton/issues/1298
224
+ # BLOCK_K = K causes llvm error
225
+ if config.kwargs["BLOCK_K"] < k1:
226
+ mm_plus_mm_template.maybe_append_choice(
227
+ choices,
228
+ input_nodes=(mat1, mat2, mat3, mat4),
229
+ layout=layout1,
230
+ **mm_options(config, m1, n1, k1, layout1),
231
+ )
232
+
233
+ return autotune_select_algorithm(
234
+ "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1
235
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+ #include <ATen/core/TensorBase.h>
5
+
6
+ namespace at::detail {
7
+
8
+ C10_EXPORT TensorBase empty_mps(
9
+ IntArrayRef size,
10
+ c10::optional<ScalarType> dtype_opt,
11
+ c10::optional<Layout> layout_opt,
12
+ c10::optional<Device> device_opt,
13
+ c10::optional<bool> pin_memory_opt,
14
+ c10::optional<c10::MemoryFormat> memory_format_opt);
15
+ C10_EXPORT TensorBase empty_mps(
16
+ IntArrayRef size, const TensorOptions &options);
17
+
18
+ C10_EXPORT TensorBase empty_strided_mps(
19
+ IntArrayRef size,
20
+ IntArrayRef stride,
21
+ ScalarType dtype,
22
+ c10::optional<Device> device_opt);
23
+
24
+ C10_EXPORT TensorBase empty_strided_mps(
25
+ IntArrayRef size,
26
+ IntArrayRef stride,
27
+ const TensorOptions &options);
28
+
29
+ } // namespace at::detail
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at::mps {
4
+
5
+ static const char * indexing_metal_shaders = R"INDEX_METAL(
6
+ #include <metal_stdlib>
7
+ #include <metal_atomic>
8
+
9
+ using namespace metal;
10
+
11
+ #if __METAL_VERSION__ < 300
12
+ struct IndexAB {
13
+ // Allow up to 16 indices
14
+ metal::array<constant void *, 16> indexArray [[ id(0) ]];
15
+ };
16
+ #else
17
+ struct IndexAB {
18
+ constant int64_t* indexArray;
19
+ };
20
+
21
+ #endif
22
+
23
+ template<typename T, typename OffsetsT>
24
+ kernel void index_select(
25
+ #if __METAL_VERSION__ >= 300
26
+ constant IndexAB * indexAB [[buffer(0)]],
27
+ #else
28
+ constant IndexAB & indexAB [[buffer(0)]],
29
+ #endif
30
+ constant void * indexSizes [[buffer(1)]],
31
+ constant void * indexStrides [[buffer(2)]],
32
+ constant OffsetsT * offsets [[buffer(3)]],
33
+ constant void * inputData [[buffer(4)]],
34
+ device void * outputData [[buffer(5)]],
35
+ constant uint32_t & num_indices [[buffer(6)]],
36
+ uint thread_index [[thread_position_in_grid]]) {
37
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
38
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
39
+ int64_t offset = 0;
40
+ for (uint32_t i = 0; i < num_indices; i++) {
41
+ #if __METAL_VERSION__ >= 300
42
+ constant int64_t* indexArray = indexAB[i].indexArray;
43
+ #else
44
+ constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
45
+ #endif
46
+ int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
47
+ if (index < 0) {
48
+ index += index_sizes[i];
49
+ }
50
+ offset += index * index_strides[i];
51
+ }
52
+ device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
53
+ constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
54
+ *out = *in;
55
+ }
56
+
57
+ template<typename T, typename OffsetsT>
58
+ void index_put_impl(
59
+ #if __METAL_VERSION__ >= 300
60
+ constant IndexAB * indexAB,
61
+ #else
62
+ constant IndexAB & indexAB,
63
+ #endif
64
+ constant int64_t * index_sizes,
65
+ constant int64_t * index_strides,
66
+ constant OffsetsT * offsets,
67
+ constant void * inputData,
68
+ device void * outputData,
69
+ constant uint32_t & num_indices,
70
+ uint thread_index) {
71
+ int64_t offset = 0;
72
+ for (uint32_t i = 0; i < num_indices; i++) {
73
+ #if __METAL_VERSION__ >= 300
74
+ constant int64_t* indexArray = indexAB[i].indexArray;
75
+ #else
76
+ constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
77
+ #endif
78
+ int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
79
+
80
+ if (index < 0) {
81
+ index += index_sizes[i];
82
+ }
83
+ offset += index * index_strides[i];
84
+ }
85
+ device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
86
+ constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
87
+ *out = *in;
88
+ }
89
+
90
+ template<typename T, typename OffsetsT>
91
+ kernel void index_put_serial(
92
+ #if __METAL_VERSION__ >= 300
93
+ constant IndexAB * indexAB [[buffer(0)]],
94
+ #else
95
+ constant IndexAB & indexAB [[buffer(0)]],
96
+ #endif
97
+ constant void * indexSizes [[buffer(1)]],
98
+ constant void * indexStrides [[buffer(2)]],
99
+ constant OffsetsT * offsets [[buffer(3)]],
100
+ constant void * inputData [[buffer(4)]],
101
+ device void * outputData [[buffer(5)]],
102
+ constant uint32_t & num_indices [[buffer(6)]],
103
+ constant uint * numIters [[buffer(7)]],
104
+ uint thread_index [[thread_position_in_grid]]) {
105
+
106
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
107
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
108
+
109
+ for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
110
+ index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
111
+ }
112
+ }
113
+
114
+ template<typename T, typename OffsetsT>
115
+ kernel void index_put(
116
+ #if __METAL_VERSION__ >= 300
117
+ constant IndexAB * indexAB [[buffer(0)]],
118
+ #else
119
+ constant IndexAB & indexAB [[buffer(0)]],
120
+ #endif
121
+ constant void * indexSizes [[buffer(1)]],
122
+ constant void * indexStrides [[buffer(2)]],
123
+ constant OffsetsT * offsets [[buffer(3)]],
124
+ constant void * inputData [[buffer(4)]],
125
+ device void * outputData [[buffer(5)]],
126
+ constant uint32_t & num_indices [[buffer(6)]],
127
+ uint thread_index [[thread_position_in_grid]]) {
128
+
129
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
130
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
131
+ index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
132
+ }
133
+
134
+ #if __METAL_VERSION__ < 300
135
+ #define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
136
+ template \
137
+ [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
138
+ kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
139
+ constant IndexAB & indexAB [[buffer(0)]], \
140
+ constant void * indexSizes [[buffer(1)]], \
141
+ constant void * indexStrides [[buffer(2)]], \
142
+ constant IDX_DTYPE * offsets [[buffer(3)]], \
143
+ constant void * inputData [[buffer(4)]], \
144
+ device void * outputData [[buffer(5)]], \
145
+ constant uint32_t & num_indices [[buffer(6)]], \
146
+ uint thread_index [[thread_position_in_grid]]);
147
+ #else
148
+ #define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
149
+ template \
150
+ [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
151
+ kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
152
+ constant IndexAB * indexAB [[buffer(0)]], \
153
+ constant void * indexSizes [[buffer(1)]], \
154
+ constant void * indexStrides [[buffer(2)]], \
155
+ constant IDX_DTYPE * offsets [[buffer(3)]], \
156
+ constant void * inputData [[buffer(4)]], \
157
+ device void * outputData [[buffer(5)]], \
158
+ constant uint32_t & num_indices [[buffer(6)]], \
159
+ uint thread_index [[thread_position_in_grid]]);
160
+ #endif
161
+
162
+ #define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
163
+ REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
164
+ REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
165
+ REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
166
+ REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
167
+ REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
168
+ REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
169
+ REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
170
+ REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
171
+
172
+ REGISTER_INDEX_OP_ALL_DTYPES(select);
173
+ REGISTER_INDEX_OP_ALL_DTYPES(put);
174
+
175
+ #if __METAL_VERSION__ < 300
176
+ #define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
177
+ template \
178
+ [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
179
+ kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
180
+ constant IndexAB & indexAB [[buffer(0)]], \
181
+ constant void * indexSizes [[buffer(1)]], \
182
+ constant void * indexStrides [[buffer(2)]], \
183
+ constant IDX_DTYPE * offsets [[buffer(3)]], \
184
+ constant void * inputData [[buffer(4)]], \
185
+ device void * outputData [[buffer(5)]], \
186
+ constant uint32_t & num_indices [[buffer(6)]], \
187
+ constant uint * numIters [[buffer(7)]], \
188
+ uint thread_index [[thread_position_in_grid]]);
189
+ #else
190
+ #define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
191
+ template \
192
+ [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
193
+ kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
194
+ constant IndexAB * indexAB [[buffer(0)]], \
195
+ constant void * indexSizes [[buffer(1)]], \
196
+ constant void * indexStrides [[buffer(2)]], \
197
+ constant IDX_DTYPE * offsets [[buffer(3)]], \
198
+ constant void * inputData [[buffer(4)]], \
199
+ device void * outputData [[buffer(5)]], \
200
+ constant uint32_t & num_indices [[buffer(6)]], \
201
+ constant uint * numIters [[buffer(7)]], \
202
+ uint thread_index [[thread_position_in_grid]]);
203
+ #endif
204
+
205
+ #define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
206
+ REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
207
+ REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
208
+ REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
209
+ REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
210
+ REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
211
+ REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
212
+ REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
213
+ REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
214
+
215
+ REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
216
+
217
+ template<typename StridesT, typename DataT>
218
+ kernel void kernel_index_offsets(constant StridesT * strides [[buffer(0)]],
219
+ device DataT * data_offsets [[buffer(1)]],
220
+ constant uint * iter_shape [[buffer(2)]],
221
+ constant uint & num_dimensions [[buffer(3)]],
222
+ uint thread_index [[thread_position_in_grid]]) {
223
+ data_offsets[thread_index] = 0;
224
+ uint32_t idx = thread_index;
225
+ for (uint32_t dim = 0; dim < num_dimensions; dim++) {
226
+ uint32_t remainder = idx % iter_shape[dim];
227
+ idx /= iter_shape[dim];
228
+
229
+ data_offsets[thread_index] += remainder * DataT(strides[dim]);
230
+ }
231
+ }
232
+
233
+ template
234
+ [[host_name("kernel_index_offsets_32")]]
235
+ kernel void kernel_index_offsets<packed_uint3, uint3>(
236
+ constant packed_uint3 * strides [[buffer(0)]],
237
+ device uint3 * data_offsets [[buffer(1)]],
238
+ constant uint * iter_shape [[buffer(2)]],
239
+ constant uint & num_dimensions [[buffer(3)]],
240
+ uint thread_index [[thread_position_in_grid]]);
241
+
242
+ template
243
+ [[host_name("kernel_index_offsets_64")]]
244
+ kernel void kernel_index_offsets<packed_uint3, ulong3>(
245
+ constant packed_uint3 * strides [[buffer(0)]],
246
+ device ulong3 * data_offsets [[buffer(1)]],
247
+ constant uint * iter_shape [[buffer(2)]],
248
+ constant uint & num_dimensions [[buffer(3)]],
249
+ uint thread_index [[thread_position_in_grid]]);
250
+
251
+ template<typename T, typename E, typename OffsetsT>
252
+ kernel void index_put_accumulate_native_dtypes(
253
+ #if __METAL_VERSION__ >= 300
254
+ constant IndexAB * indexAB [[buffer(0)]],
255
+ #else
256
+ constant IndexAB & indexAB [[buffer(0)]],
257
+ #endif
258
+ constant void * indexSizes [[buffer(1)]],
259
+ constant void * indexStrides [[buffer(2)]],
260
+ constant OffsetsT * offsets [[buffer(3)]],
261
+ constant void * inputData [[buffer(4)]],
262
+ device void * outputData [[buffer(5)]],
263
+ constant uint32_t & num_indices [[buffer(6)]],
264
+ uint thread_index [[thread_position_in_grid]]) {
265
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
266
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
267
+ int64_t offset = 0;
268
+ for (uint32_t i = 0; i < num_indices; i++) {
269
+ #if __METAL_VERSION__ >= 300
270
+ constant int64_t* indexArray = indexAB[i].indexArray;
271
+ #else
272
+ constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
273
+ #endif
274
+ int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
275
+ if (index < 0) {
276
+ index += index_sizes[i];
277
+ }
278
+ offset += index * index_strides[i];
279
+ }
280
+ device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
281
+ constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
282
+ atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
283
+ }
284
+
285
+ template<typename T>
286
+ __attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
287
+ device atomic_uint* uintAddr = (device atomic_uint*)addr;
288
+ uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
289
+ T updated = as_type<T>(expected) + value;
290
+ while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
291
+ updated = as_type<T>(expected) + value;
292
+ }
293
+ }
294
+
295
+ template<typename T, typename OffsetsT>
296
+ kernel void atomic_index_put_accumulate(
297
+ #if __METAL_VERSION__ >= 300
298
+ constant IndexAB * indexAB [[buffer(0)]],
299
+ #else
300
+ constant IndexAB & indexAB [[buffer(0)]],
301
+ #endif
302
+ constant void * indexSizes [[buffer(1)]],
303
+ constant void * indexStrides [[buffer(2)]],
304
+ constant OffsetsT * offsets [[buffer(3)]],
305
+ constant void * inputData [[buffer(4)]],
306
+ device void * outputData [[buffer(5)]],
307
+ constant uint32_t & num_indices [[buffer(6)]],
308
+ uint thread_index [[thread_position_in_grid]]) {
309
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
310
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
311
+ int64_t offset = 0;
312
+ for (uint32_t i = 0; i < num_indices; i++) {
313
+ #if __METAL_VERSION__ >= 300
314
+ constant int64_t* indexArray = indexAB[i].indexArray;
315
+ #else
316
+ constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
317
+ #endif
318
+ int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
319
+ if (index < 0) {
320
+ index += index_sizes[i];
321
+ }
322
+ offset += index * index_strides[i];
323
+ }
324
+ device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
325
+ constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
326
+ atomic_fetch_add_relaxed<T>(out, *in);
327
+ }
328
+
329
+ template
330
+ [[host_name("index_put_accumulate_32bit_float_idx32")]]
331
+ kernel void atomic_index_put_accumulate<float, uint3>(
332
+ #if __METAL_VERSION__ >= 300
333
+ constant IndexAB * indexAB [[buffer(0)]],
334
+ #else
335
+ constant IndexAB & indexAB [[buffer(0)]],
336
+ #endif
337
+ constant void * indexSizes [[buffer(1)]],
338
+ constant void * indexStrides [[buffer(2)]],
339
+ constant uint3 * offsets [[buffer(3)]],
340
+ constant void * inputData [[buffer(4)]],
341
+ device void * outputData [[buffer(5)]],
342
+ constant uint32_t & num_indices [[buffer(6)]],
343
+ uint thread_index [[thread_position_in_grid]]);
344
+
345
+ template
346
+ [[host_name("index_put_accumulate_32bit_float_idx64")]]
347
+ kernel void atomic_index_put_accumulate<float, ulong3>(
348
+ #if __METAL_VERSION__ >= 300
349
+ constant IndexAB * indexAB [[buffer(0)]],
350
+ #else
351
+ constant IndexAB & indexAB [[buffer(0)]],
352
+ #endif
353
+ constant void * indexSizes [[buffer(1)]],
354
+ constant void * indexStrides [[buffer(2)]],
355
+ constant ulong3 * offsets [[buffer(3)]],
356
+ constant void * inputData [[buffer(4)]],
357
+ device void * outputData [[buffer(5)]],
358
+ constant uint32_t & num_indices [[buffer(6)]],
359
+ uint thread_index [[thread_position_in_grid]]);
360
+
361
+ template
362
+ [[host_name("index_put_accumulate_32bit_int_idx32")]]
363
+ kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
364
+ #if __METAL_VERSION__ >= 300
365
+ constant IndexAB * indexAB [[buffer(0)]],
366
+ #else
367
+ constant IndexAB & indexAB [[buffer(0)]],
368
+ #endif
369
+ constant void * indexSizes [[buffer(1)]],
370
+ constant void * indexStrides [[buffer(2)]],
371
+ constant uint3 * offsets [[buffer(3)]],
372
+ constant void * inputData [[buffer(4)]],
373
+ device void * outputData [[buffer(5)]],
374
+ constant uint32_t & num_indices [[buffer(6)]],
375
+ uint thread_index [[thread_position_in_grid]]);
376
+
377
+ template
378
+ [[host_name("index_put_accumulate_32bit_int_idx64")]]
379
+ kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
380
+ #if __METAL_VERSION__ >= 300
381
+ constant IndexAB * indexAB [[buffer(0)]],
382
+ #else
383
+ constant IndexAB & indexAB [[buffer(0)]],
384
+ #endif
385
+ constant void * indexSizes [[buffer(1)]],
386
+ constant void * indexStrides [[buffer(2)]],
387
+ constant ulong3 * offsets [[buffer(3)]],
388
+ constant void * inputData [[buffer(4)]],
389
+ device void * outputData [[buffer(5)]],
390
+ constant uint32_t & num_indices [[buffer(6)]],
391
+ uint thread_index [[thread_position_in_grid]]);
392
+ )INDEX_METAL";
393
+
394
+ static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
395
+ struct __attribute__ ((packed)) packed_uint5{{
396
+ uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
397
+ }};
398
+
399
+ template<typename Y, typename X>
400
+ Y cast(const X x);
401
+
402
+ template<>
403
+ {1} cast<{1}, {0}>(const {0} x) {{
404
+ return {2};
405
+ }}
406
+
407
+ kernel void scatter_kernel_5(uint linear_index [[thread_position_in_grid]],
408
+ constant void * src_ [[buffer(0)]],
409
+ device void * dst_ [[buffer(1)]],
410
+ constant packed_uint5 & size [[buffer(2)]],
411
+ constant packed_uint5 & stride [[buffer(3)]],
412
+ constant uint32_t & numel [[buffer(4)]]) {{
413
+ if (linear_index >= numel) return;
414
+
415
+ constant {0} * src = (constant {0} *)src_;
416
+ device {1} * dst = (device {1} *)dst_;
417
+
418
+ packed_uint5 local_index;
419
+ local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
420
+ local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
421
+ local_index.z = linear_index / (size.u * size.w) % size.z;
422
+ local_index.w = linear_index / size.u % size.w;
423
+ local_index.u = linear_index % size.u;
424
+
425
+ packed_uint5 strided_index;
426
+ strided_index.x = local_index.x * stride.x;
427
+ strided_index.y = local_index.y * stride.y;
428
+ strided_index.z = local_index.z * stride.z;
429
+ strided_index.w = local_index.w * stride.w;
430
+ strided_index.u = local_index.u * stride.u;
431
+
432
+ dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
433
+ }}
434
+
435
+ kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
436
+ constant void * src_ [[buffer(0)]],
437
+ device void * dst_ [[buffer(1)]],
438
+ constant packed_uint4 & size [[buffer(2)]],
439
+ constant packed_uint4 & stride [[buffer(3)]],
440
+ constant uint32_t & numel [[buffer(4)]]) {{
441
+ if (linear_index >= numel) return;
442
+
443
+ constant {0} * src = (constant {0} *)src_;
444
+ device {1} * dst = (device {1} *)dst_;
445
+
446
+ packed_uint4 local_index;
447
+ local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
448
+ local_index.y = linear_index / (size[3] * size[2]) % size[1];
449
+ local_index.z = linear_index / size[3] % size[2];
450
+ local_index.w = linear_index % size[3];
451
+
452
+ const packed_uint4 strided_index = local_index * stride;
453
+ dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
454
+ }}
455
+
456
+ kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]],
457
+ constant void * src_ [[buffer(0)]],
458
+ device void * dst_ [[buffer(1)]],
459
+ constant packed_uint3 & size [[buffer(2)]],
460
+ constant packed_uint3 & stride [[buffer(3)]],
461
+ constant uint32_t & numel [[buffer(4)]]) {{
462
+ if (linear_index >= numel) return;
463
+
464
+ constant {0} * src = (constant {0} *)src_;
465
+ device {1} * dst = (device {1} *)dst_;
466
+
467
+ packed_uint3 local_index;
468
+ local_index.x = linear_index / (size[2] * size[1]) % size[0];
469
+ local_index.y = linear_index / size[2] % size[1];
470
+ local_index.z = linear_index % size[2];
471
+
472
+ const packed_uint3 strided_index = local_index * stride;
473
+ dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
474
+ }}
475
+
476
+ kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]],
477
+ constant void * src_ [[buffer(0)]],
478
+ device void * dst_ [[buffer(1)]],
479
+ constant packed_uint2 & size [[buffer(2)]],
480
+ constant packed_uint2 & stride [[buffer(3)]],
481
+ constant uint32_t & numel [[buffer(4)]]) {{
482
+ if (linear_index >= numel) return;
483
+
484
+ constant {0} * src = (constant {0} *)src_;
485
+ device {1} * dst = (device {1} *)dst_;
486
+
487
+ packed_uint2 local_index;
488
+ local_index.x = linear_index / size[1] % size[0];
489
+ local_index.y = linear_index % size[1];
490
+
491
+ const packed_uint2 strided_index = local_index * stride;
492
+ dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
493
+ }}
494
+
495
+ kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]],
496
+ constant void * src_ [[buffer(0)]],
497
+ device void * dst_ [[buffer(1)]],
498
+ constant int & size [[buffer(2)]],
499
+ constant int & stride [[buffer(3)]],
500
+ constant uint32_t & numel [[buffer(4)]]) {{
501
+ if (linear_index >= numel) return;
502
+
503
+ constant {0} * src = (constant {0} *)src_;
504
+ device {1} * dst = (device {1} *)dst_;
505
+
506
+ const int local_index = linear_index % size;
507
+ const int strided_index = local_index * stride;
508
+ dst[strided_index] = cast<{1}>(src[linear_index]);
509
+ }}
510
+ )METAL_SCATTER";
511
+
512
+ static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
513
+ struct __attribute__ ((packed)) packed_uint5{{
514
+ uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
515
+ }};
516
+
517
+ template<typename Y, typename X>
518
+ Y cast(const X x);
519
+
520
+ template<>
521
+ {1} cast<{1}, {0}>(const {0} x) {{
522
+ return {2};
523
+ }}
524
+
525
+ kernel void gather_kernel_5(uint linear_index [[thread_position_in_grid]],
526
+ constant void * src_ [[buffer(0)]],
527
+ device void * dst_ [[buffer(1)]],
528
+ constant packed_uint5 & size [[buffer(2)]],
529
+ constant packed_uint5 & stride [[buffer(3)]],
530
+ constant uint32_t & numel [[buffer(4)]]) {{
531
+ if (linear_index >= numel) return;
532
+
533
+ constant {0} * src = (constant {0} *)src_;
534
+ device {1} * dst = (device {1} *)dst_;
535
+
536
+
537
+ packed_uint5 local_index;
538
+ local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
539
+ local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
540
+ local_index.z = linear_index / (size.u * size.w) % size.z;
541
+ local_index.w = linear_index / size.u % size.w;
542
+ local_index.u = linear_index % size.u;
543
+
544
+ packed_uint5 strided_index;
545
+ strided_index.x = local_index.x * stride.x;
546
+ strided_index.y = local_index.y * stride.y;
547
+ strided_index.z = local_index.z * stride.z;
548
+ strided_index.w = local_index.w * stride.w;
549
+ strided_index.u = local_index.u * stride.u;
550
+
551
+ dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
552
+ }}
553
+
554
+ kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
555
+ constant void * src_ [[buffer(0)]],
556
+ device void * dst_ [[buffer(1)]],
557
+ constant packed_uint4 & size [[buffer(2)]],
558
+ constant packed_uint4 & stride [[buffer(3)]],
559
+ constant uint32_t & numel [[buffer(4)]]) {{
560
+ if (linear_index >= numel) return;
561
+
562
+ constant {0} * src = (constant {0} *)src_;
563
+ device {1} * dst = (device {1} *)dst_;
564
+
565
+ packed_uint4 local_index;
566
+ local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
567
+ local_index.y = linear_index / (size[3] * size[2]) % size[1];
568
+ local_index.z = linear_index / size[3] % size[2];
569
+ local_index.w = linear_index % size[3];
570
+
571
+ const packed_uint4 strided_index = local_index * stride;
572
+ dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
573
+ }}
574
+
575
+ kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]],
576
+ constant void * src_ [[buffer(0)]],
577
+ device void * dst_ [[buffer(1)]],
578
+ constant packed_uint3 & size [[buffer(2)]],
579
+ constant packed_uint3 & stride [[buffer(3)]],
580
+ constant uint32_t & numel [[buffer(4)]]) {{
581
+ if (linear_index >= numel) return;
582
+
583
+ constant {0} * src = (constant {0} *)src_;
584
+ device {1} * dst = (device {1} *)dst_;
585
+
586
+ packed_uint3 local_index;
587
+ local_index.x = linear_index / (size[2] * size[1]) % size[0];
588
+ local_index.y = linear_index / size[2] % size[1];
589
+ local_index.z = linear_index % size[2];
590
+
591
+ const packed_uint3 strided_index = local_index * stride;
592
+ dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
593
+ }}
594
+
595
+ kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]],
596
+ constant void * src_ [[buffer(0)]],
597
+ device void * dst_ [[buffer(1)]],
598
+ constant packed_uint2 & size [[buffer(2)]],
599
+ constant packed_uint2 & stride [[buffer(3)]],
600
+ constant uint32_t & numel [[buffer(4)]]) {{
601
+ if (linear_index >= numel) return;
602
+
603
+ constant {0} * src = (constant {0} *)src_;
604
+ device {1} * dst = (device {1} *)dst_;
605
+
606
+ packed_uint2 local_index;
607
+ local_index.x = linear_index / size[1] % size[0];
608
+ local_index.y = linear_index % size[1];
609
+
610
+ const packed_uint2 strided_index = local_index * stride;
611
+ dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
612
+ }}
613
+
614
+ kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]],
615
+ constant void * src_ [[buffer(0)]],
616
+ device void * dst_ [[buffer(1)]],
617
+ constant int & size [[buffer(2)]],
618
+ constant int & stride [[buffer(3)]],
619
+ constant uint32_t & numel [[buffer(4)]]) {{
620
+ if (linear_index >= numel) return;
621
+
622
+ constant {0} * src = (constant {0} *)src_;
623
+ device {1} * dst = (device {1} *)dst_;
624
+
625
+ const int local_index = linear_index % size;
626
+ const int strided_index = local_index * stride;
627
+ dst[linear_index] = cast<{1}>(src[strided_index]);
628
+ }}
629
+ )METAL_GATHER";
630
+ } // namespace at::mps
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+ #include <c10/core/impl/DeviceGuardImplInterface.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <ATen/Context.h>
8
+ #include <ATen/mps/MPSStream.h>
9
+ #include <ATen/mps/MPSEvent.h>
10
+
11
+ #ifdef __OBJC__
12
+ #include <Foundation/Foundation.h>
13
+ #include <Metal/Metal.h>
14
+ #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
15
+ #endif
16
+
17
+ #include <ATen/Tensor.h>
18
+ #include <c10/core/MemoryFormat.h>
19
+ #include <c10/core/Storage.h>
20
+ #include <c10/core/TensorImpl.h>
21
+ #include <sys/_types/_size_t.h>
22
+ #include <memory>
23
+ #include <c10/core/UndefinedTensorImpl.h>
24
+ #include <c10/util/intrusive_ptr.h>
25
+
26
+
27
+ namespace at::mps {
28
+
29
+ typedef MPSEvent* mpsEvent_t;
30
+
31
+ // TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
32
+ // https://github.com/pytorch/pytorch/issues/77170
33
+ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
34
+ static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
35
+
36
+ // constructor
37
+ MPSGuardImpl() {}
38
+ explicit MPSGuardImpl(c10::DeviceType t) {
39
+ TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
40
+ }
41
+
42
+ // returns the type
43
+ c10::DeviceType type() const override {
44
+ return c10::DeviceType::MPS;
45
+ }
46
+
47
+ Device exchangeDevice(Device d) const override {
48
+ return Device(c10::DeviceType::MPS, 0);
49
+ }
50
+
51
+ Device getDevice() const override {
52
+ return Device(c10::DeviceType::MPS, 0);
53
+ }
54
+
55
+ c10::optional<Device> uncheckedGetDevice() const noexcept {
56
+ return Device(c10::DeviceType::MPS, 0);
57
+ }
58
+
59
+ void setDevice(Device d) const override {
60
+ TORCH_INTERNAL_ASSERT(d.is_mps());
61
+ }
62
+
63
+ void uncheckedSetDevice(Device d) const noexcept override {
64
+ // TODO: Currently setting only device 0
65
+ }
66
+
67
+ Stream getStream(Device d) const noexcept override {
68
+ return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
69
+ }
70
+
71
+ Stream getDefaultStream(Device d) const override {
72
+ return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
73
+ }
74
+
75
+ // NB: These do NOT set the current device
76
+ Stream exchangeStream(Stream s) const noexcept override {
77
+ return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
78
+ }
79
+ DeviceIndex deviceCount() const noexcept override {
80
+ if (at::hasMPS()) {
81
+ //TODO: extend it for multi-device case
82
+ return 1;
83
+ } else {
84
+ return 0;
85
+ }
86
+ }
87
+
88
+ // Event-related functions
89
+ void createEvent(
90
+ mpsEvent_t* event,
91
+ const EventFlag flag) const;
92
+
93
+ void destroyEvent(
94
+ void* event,
95
+ const DeviceIndex device_index) const noexcept override;
96
+
97
+ void record(
98
+ void** event,
99
+ const Stream& stream,
100
+ const DeviceIndex device_index,
101
+ const EventFlag flag) const override;
102
+
103
+ void block(
104
+ void* event,
105
+ const Stream& stream) const override;
106
+
107
+ bool queryEvent(void* event) const override;
108
+
109
+ };
110
+
111
+ /// A variant of OptionalDeviceGuard that is specialized for MPS.
112
+ struct OptionalMPSGuard {
113
+ explicit OptionalMPSGuard() : guard_() {}
114
+
115
+ explicit OptionalMPSGuard(c10::optional<Device> device_opt)
116
+ : guard_(device_opt) {}
117
+
118
+ /// Set the current MPS device to the passed device index, if it is not
119
+ /// nullopt
120
+ explicit OptionalMPSGuard(c10::optional<DeviceIndex> device_index_opt)
121
+ : guard_(device_index_opt) {}
122
+
123
+ // Copy is not allowed
124
+ OptionalMPSGuard(const OptionalMPSGuard&) = delete;
125
+ OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
126
+ OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
127
+ OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
128
+
129
+ /// Sets the MPS device to the given device, initializing the guard if it
130
+ /// is not already initialized. Errors if the given device is not a MPS
131
+ /// device.
132
+ void set_device(Device device) {
133
+ guard_.set_device(device);
134
+ }
135
+
136
+ /// Sets the MPS device to the given device, initializing the guard if it is
137
+ /// not already initialized. Errors if the given device is not a MPS device.
138
+ void reset_device(Device device) {
139
+ guard_.reset_device(device);
140
+ }
141
+
142
+ /// Sets the MPS device to the given device index, initializing the guard if
143
+ /// it is not already initialized.
144
+ void set_index(DeviceIndex device_index) {
145
+ guard_.set_index(device_index);
146
+ }
147
+
148
+ /// Returns the device that was set immediately prior to initialization of the
149
+ /// guard, or nullopt if the guard is uninitialized.
150
+ c10::optional<Device> original_device() const {
151
+ return guard_.original_device();
152
+ }
153
+
154
+ /// Returns the most recent device that was set using this device guard,
155
+ /// either from construction, or via set_device, if the guard is initialized,
156
+ /// or nullopt if the guard is uninitialized.
157
+ c10::optional<Device> current_device() const {
158
+ return guard_.current_device();
159
+ }
160
+
161
+ /// Restore the original MPS device, resetting this guard to uninitialized
162
+ /// state.
163
+ void reset() {
164
+ guard_.reset();
165
+ }
166
+
167
+ private:
168
+ c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
169
+ };
170
+
171
+
172
+ C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);
173
+
174
+ } // namespace at::mps
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/CPUApplyUtils.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/Dispatch_v2.h>
6
+ #include <ATen/ExpandBase.h>
7
+ #include <ATen/core/DistributionsHelper.h>
8
+ #include <ATen/native/TensorIterator.h>
9
+ #include <ATen/native/cpu/Loops.h>
10
+ #include <limits>
11
+ #include <mutex>
12
+
13
+ #ifdef CPU_CAPABILITY_AVX2
14
+ #include <ATen/native/cpu/avx_mathfun.h>
15
+ #include <c10/util/irange.h>
16
+ #endif
17
+
18
+
19
+ namespace at {
20
+ namespace native {
21
+ namespace templates {
22
+ namespace cpu {
23
+ namespace {
24
+
25
+ // ==================================================== Random ========================================================
26
+
27
+ template<typename RNG>
28
+ void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
29
+ AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
30
+ std::lock_guard<std::mutex> lock(generator->mutex_);
31
+ cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
32
+ uniform_int_from_to_distribution<scalar_t> random(range, base);
33
+ return random(generator);
34
+ });
35
+ }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
36
+ }
37
+
38
+ // This is the special kernel to handle single specific case:
39
+ // from(inclusive) = std::numeric_limits<int64_t>::lowest()
40
+ // to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
41
+ template<typename RNG>
42
+ void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
43
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
44
+ if constexpr (std::is_same<scalar_t, int64_t>::value ||
45
+ std::is_same<scalar_t, double>::value ||
46
+ std::is_same<scalar_t, float>::value ||
47
+ std::is_same<scalar_t, at::BFloat16>::value) {
48
+ std::lock_guard<std::mutex> lock(generator->mutex_);
49
+ cpu_serial_kernel(iter, [generator]() -> scalar_t {
50
+ uniform_int_full_range_distribution<scalar_t> random;
51
+ return random(generator);
52
+ });
53
+ } else {
54
+ TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
55
+ }
56
+ });
57
+ }
58
+
59
+ template<typename RNG>
60
+ struct RandomFromToKernel {
61
+ void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, c10::optional<Generator> gen) {
62
+ random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
63
+ }
64
+ void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
65
+ random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
66
+ }
67
+ };
68
+
69
+ template<typename RNG>
70
+ void random_kernel(TensorIteratorBase& iter, RNG generator) {
71
+ std::lock_guard<std::mutex> lock(generator->mutex_);
72
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
73
+ cpu_serial_kernel(iter, [generator]() -> scalar_t {
74
+ uniform_int_distribution<scalar_t> random;
75
+ return random(generator);
76
+ });
77
+ });
78
+ }
79
+
80
+ template<typename RNG>
81
+ struct RandomKernel {
82
+ void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
83
+ random_kernel(iter, check_generator<RNG>(gen));
84
+ }
85
+ };
86
+
87
+ // ==================================================== Normal ========================================================
88
+
89
+ #ifdef CPU_CAPABILITY_AVX2
90
+ static void normal_fill_16_AVX2(float *data,
91
+ const __m256* two_pi,
92
+ const __m256* one,
93
+ const __m256* minus_two,
94
+ const __m256* mean,
95
+ const __m256* std_v) {
96
+ const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
97
+ const __m256 u2 = _mm256_loadu_ps(data + 8);
98
+ // sincos256_ps and log256_ps are from avx_mathfun.h
99
+ const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
100
+ const __m256 theta = _mm256_mul_ps(*two_pi, u2);
101
+ __m256 sintheta, costheta;
102
+ sincos256_ps(theta, &sintheta, &costheta);
103
+ const __m256 n1 = _mm256_mul_ps(radius, costheta);
104
+ const __m256 n2 = _mm256_mul_ps(radius, sintheta);
105
+ _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
106
+ _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
107
+ }
108
+
109
+ template<typename RNG>
110
+ void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
111
+ float *data = self.data_ptr<float>();
112
+ auto size = self.numel();
113
+ std::lock_guard<std::mutex> lock(generator->mutex_);
114
+ for (const auto i : c10::irange(size)) {
115
+ at::uniform_real_distribution<float> uniform(0, 1);
116
+ data[i] = uniform(generator);
117
+ }
118
+ const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi<double>);
119
+ const __m256 one = _mm256_set1_ps(1.0f);
120
+ const __m256 minus_two = _mm256_set1_ps(-2.0f);
121
+ const __m256 mean_v = _mm256_set1_ps(mean);
122
+ const __m256 std_v = _mm256_set1_ps(std);
123
+
124
+ for (int64_t i = 0; i < size - 15; i += 16) {
125
+ normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
126
+ }
127
+
128
+ if (size % 16 != 0) {
129
+ // Recompute the last 16 values.
130
+ data = data + size - 16;
131
+ for (const auto i : c10::irange(16)) {
132
+ at::uniform_real_distribution<float> uniform(0, 1);
133
+ data[i] = uniform(generator);
134
+ }
135
+ normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
136
+ }
137
+ }
138
+ #endif
139
+
140
+ template <typename scalar_t>
141
+ static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
142
+ for (const auto j : c10::irange(8)) {
143
+ const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
144
+ const scalar_t u2 = data[j + 8];
145
+ const scalar_t radius = std::sqrt(-2 * std::log(u1));
146
+ const scalar_t theta = 2.0f * c10::pi<double> * u2;
147
+ data[j] = radius * std::cos(theta) * std + mean;
148
+ data[j + 8] = radius * std::sin(theta) * std + mean;
149
+ }
150
+ }
151
+
152
+ template <typename scalar_t, typename RNG>
153
+ void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
154
+ scalar_t *data = self.data_ptr<scalar_t>();
155
+ auto size = self.numel();
156
+ std::lock_guard<std::mutex> lock(generator->mutex_);
157
+ for (const auto i : c10::irange(size)) {
158
+ at::uniform_real_distribution<scalar_t> uniform(0, 1);
159
+ data[i] = uniform(generator);
160
+ }
161
+
162
+ for (int64_t i = 0; i < size - 15; i += 16) {
163
+ normal_fill_16<scalar_t>(data + i, mean, std);
164
+ }
165
+ if (size % 16 != 0) {
166
+ // Recompute the last 16 values.
167
+ data = data + size - 16;
168
+ for (const auto i : c10::irange(16)) {
169
+ at::uniform_real_distribution<scalar_t> uniform(0, 1);
170
+ data[i] = uniform(generator);
171
+ }
172
+ normal_fill_16<scalar_t>(data, mean, std);
173
+ }
174
+ }
175
+
176
+ template<typename RNG>
177
+ void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
178
+ auto size = self.numel();
179
+ if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
180
+ #ifdef CPU_CAPABILITY_AVX2
181
+ normal_fill_AVX2(self, static_cast<float>(mean), static_cast<float>(std), generator);
182
+ #else
183
+ normal_fill(self, static_cast<float>(mean), static_cast<float>(std), generator);
184
+ #endif
185
+ } else {
186
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
187
+ if (size >= 16 && self.is_contiguous()) {
188
+ normal_fill<scalar_t>(self, static_cast<scalar_t>(mean), static_cast<scalar_t>(std), generator);
189
+ } else {
190
+ auto iter = TensorIterator::borrowing_nullary_op(self);
191
+ std::lock_guard<std::mutex> lock(generator->mutex_);
192
+ cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
193
+ at::normal_distribution<double> normal(mean, std);
194
+ return static_cast<scalar_t>(normal(generator));
195
+ });
196
+ }
197
+ });
198
+ }
199
+ }
200
+
201
+ template<typename RNG>
202
+ struct NormalKernel {
203
+ void operator()(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
204
+ normal_kernel(self, mean, std, check_generator<RNG>(gen));
205
+ }
206
+ };
207
+
208
+ // ==================================================== Uniform =======================================================
209
+
210
+ template<typename RNG>
211
+ void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
212
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
213
+ std::lock_guard<std::mutex> lock(generator->mutex_);
214
+ auto from = static_cast<scalar_t>(from_);
215
+ auto to = static_cast<scalar_t>(to_);
216
+ at::uniform_real_distribution<scalar_t> uniform(from, to);
217
+ cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
218
+ return static_cast<scalar_t>(uniform(generator));
219
+ });
220
+ });
221
+ }
222
+
223
+ template<typename RNG>
224
+ struct UniformKernel {
225
+ void operator()(TensorIteratorBase& iter, double from, double to, c10::optional<Generator> gen) {
226
+ uniform_kernel(iter, from, to, check_generator<RNG>(gen));
227
+ }
228
+ };
229
+
230
+ // ==================================================== Cauchy ========================================================
231
+
232
+ template<typename RNG>
233
+ void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
234
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
235
+ std::lock_guard<std::mutex> lock(generator->mutex_);
236
+ at::cauchy_distribution<double> cauchy(median, sigma);
237
+ cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
238
+ return static_cast<scalar_t>(cauchy(generator));
239
+ });
240
+ });
241
+ }
242
+
243
+ template<typename RNG>
244
+ struct CauchyKernel {
245
+ void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional<Generator> gen) {
246
+ cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
247
+ }
248
+ };
249
+
250
+ // ================================================== LogNormal =======================================================
251
+
252
+ template<typename RNG>
253
+ void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
254
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
255
+ std::lock_guard<std::mutex> lock(generator->mutex_);
256
+ at::lognormal_distribution<double> logNormal(mean, std);
257
+ cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
258
+ return static_cast<scalar_t>(logNormal(generator));
259
+ });
260
+ });
261
+ }
262
+
263
+ template<typename RNG>
264
+ struct LogNormalKernel {
265
+ void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional<Generator> gen) {
266
+ log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
267
+ }
268
+ };
269
+
270
+ // =================================================== Geometric ======================================================
271
+
272
+ template<typename RNG>
273
+ void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
274
+ AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
275
+ std::lock_guard<std::mutex> lock(generator->mutex_);
276
+ at::geometric_distribution<double> geometric(p);
277
+ cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
278
+ return static_cast<scalar_t>(geometric(generator));
279
+ });
280
+ });
281
+ }
282
+
283
+ template<typename RNG>
284
+ struct GeometricKernel {
285
+ void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
286
+ geometric_kernel(iter, p, check_generator<RNG>(gen));
287
+ }
288
+ };
289
+
290
+ // ================================================== Exponential =====================================================
291
+
292
+ template<typename RNG>
293
+ void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
294
+ TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
295
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
296
+ std::lock_guard<std::mutex> lock(generator->mutex_);
297
+ at::exponential_distribution<double> exponential(lambda);
298
+ cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
299
+ return static_cast<scalar_t>(exponential(generator));
300
+ });
301
+ });
302
+ }
303
+
304
+ template<typename RNG>
305
+ struct ExponentialKernel {
306
+ void operator()(TensorIteratorBase& iter, double lambda, c10::optional<Generator> gen) {
307
+ exponential_kernel(iter, lambda, check_generator<RNG>(gen));
308
+ }
309
+ };
310
+
311
+ // ================================================== Bernoulli =======================================================
312
+
313
+ template<typename RNG>
314
+ void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
315
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
316
+ self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
317
+ // See Note [Acquire lock when using random generators]
318
+ std::lock_guard<std::mutex> lock(generator->mutex_);
319
+ using self_t = scalar_t;
320
+ auto p_cpu = p_.to(kCPU);
321
+ auto p = expand_inplace(self, p_cpu);
322
+ auto iter = TensorIteratorConfig()
323
+ .add_output(self)
324
+ .add_input(*p)
325
+ .check_all_same_dtype(false)
326
+ .build();
327
+ if (p->scalar_type() == kDouble) {
328
+ cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
329
+ at::bernoulli_distribution<double> bernoulli(p_val);
330
+ return static_cast<self_t>(bernoulli(generator));
331
+ });
332
+ } else {
333
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
334
+ p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
335
+ using p_t = scalar_t;
336
+ cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
337
+ at::bernoulli_distribution<float> bernoulli(p_val);
338
+ return static_cast<self_t>(bernoulli(generator));
339
+ });
340
+ });
341
+ }
342
+ });
343
+ }
344
+
345
+ template<typename RNG>
346
+ void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
347
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
348
+ self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
349
+ // See Note [Acquire lock when using random generators]
350
+ std::lock_guard<std::mutex> lock(generator->mutex_);
351
+ auto iter = TensorIterator::borrowing_nullary_op(self);
352
+ cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
353
+ at::bernoulli_distribution<double> bernoulli(p);
354
+ return static_cast<scalar_t>(bernoulli(generator));
355
+ });
356
+ });
357
+ }
358
+
359
+ template<typename RNG>
360
+ struct BernoulliKernel {
361
+ void operator()(const TensorBase &self, double p, c10::optional<Generator> gen) {
362
+ bernoulli_kernel(self, p, check_generator<RNG>(gen));
363
+ }
364
+ void operator()(const TensorBase &self, const TensorBase &p_, c10::optional<Generator> gen) {
365
+ bernoulli_kernel(self, p_, check_generator<RNG>(gen));
366
+ }
367
+ };
368
+
369
+ }}}}}
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Config.h>
4
+ #if AT_MKLDNN_ENABLED()
5
+ #include <ATen/Tensor.h>
6
+ #include <ATen/native/quantized/PackedParams.h>
7
+ #include <ideep.hpp>
8
+ #include <cpuinfo.h>
9
+
10
+ #include <c10/util/CallOnce.h>
11
+
12
+ using PrimitiveCacheKey = std::tuple<
13
+ double, // input_scale
14
+ int64_t, // input_zero_point
15
+ std::vector<int64_t>, // input_shape
16
+ double, // output_scale
17
+ int64_t, // output_zero_point
18
+ int64_t, // OMP_number_of_threads
19
+ double, // accum_scale
20
+ int64_t>; // accum_zero_point
21
+
22
+ enum CacheKeyIndex {
23
+ InputScale,
24
+ InputZeroPoint,
25
+ InputShape,
26
+ OutputScale,
27
+ OutputZeroPoint,
28
+ NumOfThreads,
29
+ };
30
+
31
+ // Base class of primitive cache
32
+ struct PrimitiveCache {
33
+ PrimitiveCacheKey key;
34
+
35
+ bool hit(const PrimitiveCacheKey& key) {
36
+ return this->key == key;
37
+ }
38
+ };
39
+
40
+ using LinearParams = ideep::matmul_forward_params;
41
+ using Conv = dnnl::convolution_forward;
42
+ using ConvDesc = dnnl::convolution_forward::primitive_desc;
43
+ using ConvParams = ideep::convolution_forward_params;
44
+ using Deconv = dnnl::deconvolution_forward;
45
+ using DeconvDesc = dnnl::deconvolution_forward::primitive_desc;
46
+ using DeconvParams = ideep::deconv_forward_params;
47
+
48
+ struct LinearPrimitiveCache : PrimitiveCache {
49
+ LinearPrimitiveCache() {}
50
+
51
+ LinearPrimitiveCache(
52
+ const PrimitiveCacheKey& key,
53
+ const LinearParams& param) {
54
+ this->key = key;
55
+ this->param = param;
56
+ }
57
+
58
+ LinearParams param;
59
+
60
+ // For dynamic qlinear, scale and zero point
61
+ // are set at execution time. So we only need to compare
62
+ // the rest part of key.
63
+ bool hit_dynamic(const PrimitiveCacheKey& new_key) {
64
+ auto cached_input_shape = std::get<InputShape>(this->key);
65
+ auto new_input_shape = std::get<InputShape>(new_key);
66
+ return (
67
+ cached_input_shape == new_input_shape &&
68
+ std::get<NumOfThreads>(this->key) == std::get<NumOfThreads>(new_key));
69
+ }
70
+
71
+ LinearParams& get_param() {
72
+ return param;
73
+ }
74
+ };
75
+
76
+ struct ConvPrimitiveCache : PrimitiveCache {
77
+ ConvPrimitiveCache() {}
78
+
79
+ ConvPrimitiveCache(
80
+ const PrimitiveCacheKey& key,
81
+ const ConvParams& params) {
82
+ this->key = key;
83
+ this->params = params;
84
+ }
85
+
86
+ ConvParams params;
87
+
88
+ ConvParams& get_params() {
89
+ return params;
90
+ }
91
+ };
92
+
93
+ struct DeconvPrimitiveCache : PrimitiveCache {
94
+ DeconvPrimitiveCache() {}
95
+
96
+ DeconvPrimitiveCache(
97
+ const PrimitiveCacheKey& key,
98
+ const DeconvParams& params) {
99
+ this->key = key;
100
+ this->params = params;
101
+ }
102
+
103
+ DeconvParams params;
104
+
105
+ DeconvParams& get_params() {
106
+ return params;
107
+ }
108
+ };
109
+
110
+ enum PostOps {
111
+ NoPostOp,
112
+ Relu,
113
+ LeakyRelu,
114
+ Tanh,
115
+ Gelu
116
+ };
117
+
118
+ static std::unordered_map<std::string, PostOps> POST_OP_TABLE = {
119
+ {"none", NoPostOp},
120
+ {"relu", Relu},
121
+ {"leaky_relu", LeakyRelu},
122
+ {"tanh", Tanh},
123
+ {"gelu", Gelu}
124
+ };
125
+
126
+ struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
127
+ PackedLinearWeightsOnednn(
128
+ std::unique_ptr<ideep::tensor> weight,
129
+ c10::optional<ideep::tensor> bias,
130
+ at::Tensor orig_weight,
131
+ c10::optional<at::Tensor> orig_bias)
132
+ : weight_(std::move(weight)),
133
+ bias_(std::move(bias)),
134
+ orig_weight_(std::move(orig_weight)),
135
+ orig_bias_(std::move(orig_bias)) {
136
+ cache_initialized_flag = std::make_unique<c10::once_flag>();
137
+ }
138
+ std::unique_ptr<ideep::tensor> weight_;
139
+ c10::optional<ideep::tensor> bias_;
140
+ at::Tensor orig_weight_;
141
+ c10::optional<at::Tensor> orig_bias_;
142
+
143
+ at::Tensor apply(
144
+ at::Tensor input,
145
+ double output_scale,
146
+ int64_t output_zero_point) override;
147
+ at::Tensor apply_relu(
148
+ at::Tensor input,
149
+ double output_scale,
150
+ int64_t output_zero_point) override;
151
+
152
+ at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
153
+ at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;
154
+
155
+ at::Tensor apply_leaky_relu(
156
+ at::Tensor input,
157
+ double output_scale,
158
+ int64_t output_zero_point,
159
+ double negative_slope);
160
+
161
+ at::Tensor apply_tanh(
162
+ at::Tensor input,
163
+ double output_scale,
164
+ int64_t output_zero_point);
165
+
166
+ std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
167
+
168
+ c10::optional<at::Tensor> bias() override {
169
+ return orig_bias_;
170
+ }
171
+
172
+ static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
173
+ at::Tensor weight,
174
+ c10::optional<at::Tensor> bias);
175
+
176
+ private:
177
+ LinearPrimitiveCache prim_cache;
178
+ std::unique_ptr<c10::once_flag> cache_initialized_flag;
179
+
180
+ template <PostOps post_op>
181
+ at::Tensor apply_impl(
182
+ at::Tensor input,
183
+ double output_scale,
184
+ int64_t output_zero_point,
185
+ torch::List<at::Scalar> post_op_args = torch::List<at::Scalar>());
186
+
187
+ template <bool ReluFused>
188
+ at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);
189
+
190
+ LinearPrimitiveCache& get_cache() {
191
+ return prim_cache;
192
+ }
193
+ };
194
+
195
+ template <int kSpatialDim = 2>
196
+ struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
197
+ PackedConvWeightsOnednn(
198
+ std::unique_ptr<ideep::tensor> weight,
199
+ c10::optional<ideep::tensor> bias,
200
+ at::Tensor orig_weight,
201
+ c10::optional<at::Tensor> orig_bias,
202
+ torch::List<int64_t> stride,
203
+ torch::List<int64_t> padding,
204
+ torch::List<int64_t> output_padding,
205
+ torch::List<int64_t> dilation,
206
+ int64_t groups,
207
+ uint8_t transpose)
208
+ : weight_(std::move(weight)),
209
+ bias_(std::move(bias)),
210
+ orig_weight_(std::move(orig_weight)),
211
+ orig_bias_(std::move(orig_bias)),
212
+ stride_(std::move(stride)),
213
+ padding_(std::move(padding)),
214
+ output_padding_(std::move(output_padding)),
215
+ dilation_(std::move(dilation)),
216
+ groups_(groups),
217
+ transpose_(transpose) {
218
+ cache_initialized_flag = std::make_unique<c10::once_flag>();
219
+ }
220
+
221
+ std::unique_ptr<ideep::tensor> weight_;
222
+ c10::optional<ideep::tensor> bias_;
223
+ at::Tensor orig_weight_;
224
+ c10::optional<at::Tensor> orig_bias_;
225
+ torch::List<int64_t> stride_;
226
+ torch::List<int64_t> padding_;
227
+ torch::List<int64_t> output_padding_;
228
+ torch::List<int64_t> dilation_;
229
+ int64_t groups_;
230
+ uint8_t transpose_;
231
+
232
+ at::Tensor apply(
233
+ const at::Tensor& input,
234
+ double output_scale,
235
+ int64_t output_zero_point) override;
236
+
237
+ at::Tensor apply_relu(
238
+ const at::Tensor& input,
239
+ double output_scale,
240
+ int64_t output_zero_point) override;
241
+
242
+ at::Tensor apply_dynamic(
243
+ const at::Tensor& input,
244
+ bool reduce_range) override;
245
+
246
+ at::Tensor apply_add(
247
+ const at::Tensor& input,
248
+ const at::Tensor& accum,
249
+ double output_scale,
250
+ int64_t output_zero_point);
251
+
252
+ at::Tensor apply_add_relu(
253
+ const at::Tensor& input,
254
+ const at::Tensor& accum,
255
+ double output_scale,
256
+ int64_t output_zero_point);
257
+
258
+ std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
259
+
260
+ static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
261
+ at::Tensor weight,
262
+ c10::optional<at::Tensor> bias,
263
+ torch::List<int64_t> stride,
264
+ torch::List<int64_t> padding,
265
+ torch::List<int64_t> output_padding,
266
+ torch::List<int64_t> dilation,
267
+ int64_t groups,
268
+ bool transpose);
269
+
270
+ torch::List<int64_t> stride() const override {
271
+ return stride_;
272
+ }
273
+
274
+ torch::List<int64_t> padding() const override {
275
+ return padding_;
276
+ }
277
+
278
+ torch::List<int64_t> output_padding() const override {
279
+ return output_padding_;
280
+ }
281
+
282
+ torch::List<int64_t> dilation() const override {
283
+ return dilation_;
284
+ }
285
+
286
+ int64_t groups() const override {
287
+ return groups_;
288
+ }
289
+
290
+ bool transpose() const override {
291
+ return (bool)transpose_;
292
+ }
293
+
294
+ private:
295
+ ConvPrimitiveCache conv_prim_cache;
296
+ DeconvPrimitiveCache deconv_prim_cache;
297
+ std::unique_ptr<c10::once_flag> cache_initialized_flag;
298
+
299
+ template <bool ReluFused>
300
+ at::Tensor apply_impl(
301
+ const at::Tensor& input,
302
+ const c10::optional<at::Tensor>& accum,
303
+ double output_scale,
304
+ int64_t output_zero_point);
305
+
306
+ ConvPrimitiveCache& get_conv_cache() {
307
+ assert(!transpose());
308
+ return conv_prim_cache;
309
+ }
310
+
311
+ DeconvPrimitiveCache& get_deconv_cache() {
312
+ assert(transpose());
313
+ return deconv_prim_cache;
314
+ }
315
+ };
316
+
317
+ namespace onednn_utils {
318
+
319
+ static ideep::attr_t create_attr_by_post_op(
320
+ const std::string& post_op_name,
321
+ const torch::List<c10::optional<at::Scalar>>& post_op_args,
322
+ const dnnl::algorithm post_algorithm) {
323
+ using ideep::tensor;
324
+ PostOps post_op = POST_OP_TABLE[post_op_name];
325
+ if (post_op == Relu) {
326
+ return ideep::attr_t::fuse_relu();
327
+ } else if (post_op == LeakyRelu) {
328
+ return ideep::attr_t::fuse_relu_v2(/*alpha=*/post_op_args[0].value().to<float>());
329
+ } else if (post_op == Tanh) {
330
+ return ideep::attr_t::fuse_tanh();
331
+ } else if (post_op == Gelu) {
332
+ return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm);
333
+ }
334
+ return ideep::attr_t();
335
+ }
336
+
337
+ // Try to reorder tensor to expected desc at runtime
338
+ // Do it in a `try...catch...` manner to avoid oneDNN's errors
339
+ // TODO: Move it to third_party/ideep
340
+ static void try_reorder(
341
+ ideep::tensor& t,
342
+ const ideep::tensor::desc&& desc,
343
+ ideep::scale_t scales) {
344
+ if (t.get_desc() != desc) {
345
+ try {
346
+ t = t.reorder_if_differ_in(desc);
347
+ } catch (...) {
348
+ ideep::tensor&& plain = t.to_public(nullptr, t.get_data_type());
349
+ t = plain.reorder_if_differ_in(desc);
350
+ }
351
+ t.set_scale(scales);
352
+ }
353
+ }
354
+
355
+ // ONEDNN requires symmetric quantization of weight
356
+ // Use this util function to check.
357
+ static bool is_weight_symmetric_quant(
358
+ const at::Tensor& weight,
359
+ bool is_transposed_conv) {
360
+ bool is_symmetric = true;
361
+ const auto qtype = weight.qscheme();
362
+ if (qtype == c10::kPerTensorAffine) {
363
+ is_symmetric &= (weight.q_zero_point() == 0);
364
+ } else if (qtype == c10::kPerChannelAffine) {
365
+ if (is_transposed_conv) {
366
+ // This case is currently not supported in PyTorch
367
+ // but we do not want to raise an error in this util function.
368
+ is_symmetric = false;
369
+ } else {
370
+ auto output_channels = weight.size(0);
371
+ for (int i = 0; i < output_channels; ++i) {
372
+ auto zp = weight.q_per_channel_zero_points()[i].item<int32_t>();
373
+ is_symmetric &= (zp == 0);
374
+ }
375
+ }
376
+ } else {
377
+ // This case is currently not supported in PyTorch
378
+ // but we do not want to raise an error in this util function.
379
+ is_symmetric = false;
380
+ }
381
+ return is_symmetric;
382
+ }
383
+
384
+ // When qengine is x86, use this util func to check if onednn kernel
385
+ // is preferred than fbgemm's to get better performance.
386
+ static bool should_use_onednn_quant(
387
+ const at::Tensor& weight,
388
+ bool is_transposed_conv,
389
+ int groups,
390
+ torch::List<int64_t> output_padding) {
391
+ // Performance of onednn is only validated on Linux right now.
392
+ // Also, the heuristics for dispatching are based on perf data on Linux.
393
+ // So, for x86 qengine, we always use fbgemm kernels if OS is not Linux.
394
+ // TODO Support more OSs.
395
+ #if !defined(__linux__)
396
+ return false;
397
+ #else
398
+ bool vnni_available = cpuinfo_has_x86_avx512vnni();
399
+ bool w_sym_quant =
400
+ is_weight_symmetric_quant(weight, is_transposed_conv);
401
+ bool opad_all_zero =
402
+ std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; });
403
+ return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero;
404
+ #endif
405
+ }
406
+
407
+ } // onednn_utils
408
+
409
+ at::Tensor _qconv_prepack_onednn(
410
+ at::Tensor weight, // from CPU backend instead of QuantizedCPU
411
+ at::Tensor weight_scales, // Weight zero points must be 0 for onednn
412
+ double input_scale,
413
+ int64_t input_zero_point,
414
+ torch::List<int64_t> stride,
415
+ torch::List<int64_t> padding,
416
+ torch::List<int64_t> dilation,
417
+ int64_t groups,
418
+ c10::optional<torch::List<int64_t>> input_shape=c10::nullopt);
419
+
420
+ static at::Tensor _quantized_convolution_onednn(
421
+ at::Tensor act, // contains quantized values but not QTensor
422
+ double act_scale,
423
+ int64_t act_zero_point,
424
+ at::Tensor weight, // MKLDNN tensor with quantized values
425
+ at::Tensor weight_scales,
426
+ at::Tensor weight_zero_points,
427
+ c10::optional<at::Tensor> bias, // Bias is packed if not None
428
+ torch::List<int64_t> stride,
429
+ torch::List<int64_t> padding,
430
+ torch::List<int64_t> dilation,
431
+ bool transposed,
432
+ int64_t groups,
433
+ double inv_output_scale,
434
+ int64_t output_zero_point,
435
+ c10::optional<at::Tensor> accum=c10::nullopt, // accum to fused with conv add
436
+ double accum_scale=1.0,
437
+ int64_t accum_zero_point=0,
438
+ bool fp32_output=false,
439
+ c10::optional<c10::string_view> binary_attr=c10::nullopt,
440
+ c10::optional<at::Scalar> binary_alpha=c10::nullopt,
441
+ c10::optional<c10::string_view> unary_attr=c10::nullopt,
442
+ torch::List<c10::optional<at::Scalar>> unary_scalars=torch::List<c10::optional<at::Scalar>>(),
443
+ c10::optional<c10::string_view> unary_algorithm=c10::nullopt);
444
+
445
+ #endif // #if AT_MKLDNN_ENABLED()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #ifdef USE_XNNPACK
4
+ #include <cstdint>
5
+
6
+ #include <ATen/core/Tensor.h>
7
+ #include <ATen/native/xnnpack/Common.h>
8
+
9
+ using xnnpack_operator = at::native::xnnpack::Operator;
10
+
11
+ namespace at {
12
+ namespace native {
13
+ namespace xnnp_utils {
14
+
15
+ /*
16
+ * Return shape in the same order as the memory format
17
+ * e.g. channels_last will return NHWC instead of NCHW
18
+ */
19
+ std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
20
+
21
+ /*
22
+ * Input is always int8_t, output can be [int8_t, uint8_t].
23
+ * input + offset = output
24
+ * int8_t + 128 = uint8_t
25
+ * int8_t + 0 = int8_t
26
+ */
27
+ template <typename PT>
28
+ void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
29
+
30
+ template <int kSpatialDim>
31
+ Tensor convert_conv_weights_to_channel_last_tensor(
32
+ const at::Tensor& src,
33
+ int groups,
34
+ bool transpose);
35
+
36
+ /*
37
+ * Series of create wrapper functions to call xnn_create_[de]conv* functions.
38
+ */
39
+ C10_ALWAYS_INLINE
40
+ enum xnn_status xnnp_create_convolution2d_nhwc(
41
+ uint32_t pad_top,
42
+ uint32_t pad_right,
43
+ uint32_t pad_bottom,
44
+ uint32_t pad_left,
45
+ uint32_t kernel_h,
46
+ uint32_t kernel_w,
47
+ uint32_t stride_h,
48
+ uint32_t stride_w,
49
+ uint32_t dilation_h,
50
+ uint32_t dilation_w,
51
+ uint32_t groups,
52
+ size_t group_input_channels,
53
+ size_t group_output_channels,
54
+ size_t ip_chan_stride,
55
+ size_t op_chan_stride,
56
+ int8_t izp,
57
+ float ip_scale,
58
+ int8_t kzp,
59
+ const float* k_scales,
60
+ const int8_t* kernel,
61
+ const int32_t* bias,
62
+ int8_t ozp,
63
+ float op_scale,
64
+ int8_t op_min,
65
+ int8_t op_max,
66
+ uint32_t flags,
67
+ xnn_operator_t* op,
68
+ bool per_channel,
69
+ bool transpose) {
70
+ /* Symmetric quantization forces kzp = 0 */
71
+ TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero."
72
+ "But got: ", kzp);
73
+
74
+ if (transpose) {
75
+ TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
76
+ return xnn_create_deconvolution2d_nhwc_qs8(
77
+ pad_top, /* uint32_t output_padding_top */
78
+ pad_right, /* uint32_t output_padding_right */
79
+ pad_bottom, /* uint32_t output_padding_bottom */
80
+ pad_left, /* uint32_t output_padding_left */
81
+ kernel_h, /* uint32_t kernel_height */
82
+ kernel_w, /* uint32_t kernel_width */
83
+ stride_h, /* uint32_t stride_height */
84
+ stride_w, /* uint32_t stride_width */
85
+ dilation_h, /* uint32_t dilation_height */
86
+ dilation_w, /* uint32_t dilation_width */
87
+ groups, /* uint32_t groups */
88
+ group_input_channels, /* size_t group_input_channels */
89
+ group_output_channels, /* size_t group_output_channels */
90
+ ip_chan_stride, /* size_t input_pixel_stride */
91
+ op_chan_stride, /* size_t output_pixel_stride */
92
+ izp, /* int8_t input_zero_point */
93
+ ip_scale, /* float input_scale */
94
+ k_scales[0], /* float kernel_scale */
95
+ kernel, /* const int8_t* kernel */
96
+ bias, /* const int32_t* bias */
97
+ ozp, /* int8_t output_zero_point */
98
+ op_scale, /* float output_scale */
99
+ op_min, /* int8_t output_min */
100
+ op_max, /* int8_t output_max */
101
+ flags, /* uint32_t flags */
102
+ nullptr, /* xnn_caches_t caches */
103
+ nullptr, /* xnn_weights_cache_t weights_cache */
104
+ op); /* xnn_operator_t* deconvolution_op_out */
105
+
106
+ }
107
+
108
+ if (!per_channel) {
109
+ return xnn_create_convolution2d_nhwc_qs8(
110
+ pad_top, /* uint32_t input_padding_top */
111
+ pad_right, /* uint32_t input_padding_right */
112
+ pad_bottom, /* uint32_t input_padding_bottom */
113
+ pad_left, /* uint32_t input_padding_left */
114
+ kernel_h, /* uint32_t kernel_height */
115
+ kernel_w, /* uint32_t kernel_width */
116
+ stride_h, /* uint32_t subsampling_height */
117
+ stride_w, /* uint32_t subsampling_width */
118
+ dilation_h, /* uint32_t dilation_height */
119
+ dilation_w, /* uint32_t dilation_width */
120
+ groups, /* uint32_t groups */
121
+ group_input_channels, /* size_t group_input_channels */
122
+ group_output_channels, /* size_t group_output_channels*/
123
+ ip_chan_stride, /* size_t input_channel_stride */
124
+ op_chan_stride, /* size_t output_channel_stride */
125
+ izp, /* int8_t input_zero_point */
126
+ ip_scale, /* float input_scale */
127
+ k_scales[0], /* float kernel_scale */
128
+ kernel, /* const int8_t* kernel */
129
+ bias, /* const int32_t* bias */
130
+ ozp, /* int8_t output_zero_point */
131
+ op_scale, /* float output_scale */
132
+ op_min, /* int8_t output_min */
133
+ op_max, /* int8_t output_max */
134
+ flags, /* uint32_t flags */
135
+ nullptr, /* xnn_caches_t caches */
136
+ nullptr, /* xnn_weights_cache_t weights_cache */
137
+ op); /* xnn_operator_t* convolution_op_out */
138
+ } else { /* per_channel */
139
+ return xnn_create_convolution2d_nhwc_qs8_qc8w(
140
+ pad_top, /* uint32_t input_padding_top */
141
+ pad_right, /* uint32_t input_padding_right */
142
+ pad_bottom, /* uint32_t input_padding_bottom */
143
+ pad_left, /* uint32_t input_padding_left */
144
+ kernel_h, /* uint32_t kernel_height */
145
+ kernel_w, /* uint32_t kernel_width */
146
+ stride_h, /* uint32_t subsampling_height */
147
+ stride_w, /* uint32_t subsampling_width */
148
+ dilation_h, /* uint32_t dilation_height */
149
+ dilation_w, /* uint32_t dilation_width */
150
+ groups, /* uint32_t groups */
151
+ group_input_channels, /* size_t group_input_channels */
152
+ group_output_channels, /* size_t group_output_channels*/
153
+ ip_chan_stride, /* size_t input_channel_stride */
154
+ op_chan_stride, /* size_t output_channel_stride */
155
+ izp, /* int8_t input_zero_point */
156
+ ip_scale, /* float input_scale */
157
+ k_scales, /* const float* kernel_scale */
158
+ kernel, /* const int8_t* kernel */
159
+ bias, /* const int32_t* bias */
160
+ ozp, /* int8_t output_zero_point */
161
+ op_scale, /* float output_scale */
162
+ op_min, /* int8_t output_min */
163
+ op_max, /* int8_t output_max */
164
+ flags, /* uint32_t flags */
165
+ nullptr, /* xnn_caches_t caches */
166
+ nullptr, /* xnn_weights_cache_t weights_cache */
167
+ op); /* xnn_operator_t* convolution_op_out */
168
+ }
169
+ }
170
+
171
+ /*
172
+ * Series of reshape wrapper functions to call xnn_reshape_[de]conv* functions.
173
+ */
174
+ C10_ALWAYS_INLINE
175
+ enum xnn_status xnnp_reshape_convolution2d_nhwc(
176
+ xnn_operator_t op,
177
+ size_t batch,
178
+ size_t in_h,
179
+ size_t in_w,
180
+ pthreadpool_t pt_pool,
181
+ bool per_channel = false,
182
+ bool transpose = false,
183
+ uint32_t adj_h = 0,
184
+ uint32_t adj_w = 0) {
185
+ if(transpose) {
186
+ TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
187
+ return xnn_reshape_deconvolution2d_nhwc_qs8(
188
+ op, /* xnn_operator_t deconvolution_op */
189
+ batch, /* size_t batch_size */
190
+ in_h, /* size_t input_height */
191
+ in_w, /* size_t input_width */
192
+ adj_h, /* uint32_t adjustment_height */
193
+ adj_w, /* uint32_t adjustment_width */
194
+ nullptr, /* size_t* output_height_out */
195
+ nullptr, /* size_t* output_width_out */
196
+ pt_pool); /* pthreadpool_t threadpool */
197
+ }
198
+
199
+ size_t workspace_size = SIZE_MAX;
200
+ size_t workspace_alignment = SIZE_MAX;
201
+
202
+ if (!per_channel) {
203
+ return xnn_reshape_convolution2d_nhwc_qs8(
204
+ op, /* xnn_operator_t convolution_op */
205
+ batch, /* size_t batch_size */
206
+ in_h, /* size_t input_height */
207
+ in_w, /* size_t input_width */
208
+ &workspace_size, /* size_t* workspace_size */
209
+ &workspace_alignment, /* size_t* workspace_alignment */
210
+ nullptr, /* size_t* output_height_out */
211
+ nullptr, /* size_t* output_width_out */
212
+ pt_pool); /* pthreadpool_t threadpool */
213
+ } else { /* per_channel */
214
+ return xnn_reshape_convolution2d_nhwc_qs8_qc8w(
215
+ op, /* xnn_operator_t convolution_op */
216
+ batch, /* size_t batch_size */
217
+ in_h, /* size_t input_height */
218
+ in_w, /* size_t input_width */
219
+ &workspace_size, /* size_t* workspace_size */
220
+ &workspace_alignment, /* size_t* workspace_alignment */
221
+ nullptr, /* size_t* output_height_out */
222
+ nullptr, /* size_t* output_width_out */
223
+ pt_pool); /* pthreadpool_t threadpool */
224
+ }
225
+ }
226
+
227
+
228
+ /*
229
+ * Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
230
+ */
231
+ C10_ALWAYS_INLINE
232
+ enum xnn_status xnnp_setup_convolution2d_nhwc(
233
+ xnn_operator_t op,
234
+ const int8_t* inp,
235
+ int8_t* outp,
236
+ bool per_channel = false,
237
+ bool transpose = false) {
238
+ if(transpose) {
239
+ TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
240
+
241
+ return xnn_setup_deconvolution2d_nhwc_qs8(
242
+ op, /* xnn_operator_t deconvolution_op */
243
+ inp, /* const int8_t* input */
244
+ outp); /* int8_t* output */
245
+ }
246
+
247
+ if (!per_channel) {
248
+ return xnn_setup_convolution2d_nhwc_qs8(
249
+ op, /* xnn_operator_t deconvolution_op */
250
+ nullptr, /* void workspace */
251
+ inp, /* const int8_t* input */
252
+ outp); /* int8_t* output */
253
+ } else { /* per_channel */
254
+ return xnn_setup_convolution2d_nhwc_qs8_qc8w(
255
+ op, /* xnn_operator_t deconvolution_op */
256
+ nullptr, /* void workspace */
257
+ inp, /* const int8_t* input */
258
+ outp); /* int8_t* output */
259
+ }
260
+ }
261
+
262
+
263
+ /*
264
+ * Series of wrapper functions to call xnn_create* and xnn_setup*
265
+ * functions for linear
266
+ */
267
+ C10_ALWAYS_INLINE
268
+ enum xnn_status xnnp_create_fully_connected_nc(
269
+ size_t input_channels,
270
+ size_t output_channels,
271
+ size_t input_stride,
272
+ size_t output_stride,
273
+ int8_t input_zero_point,
274
+ float input_scale,
275
+ int8_t kernel_zero_point,
276
+ float kernel_scale,
277
+ const int8_t* kernel,
278
+ const int32_t* bias,
279
+ int8_t output_zero_point,
280
+ float output_scale,
281
+ int8_t output_min,
282
+ int8_t output_max,
283
+ uint32_t flags,
284
+ xnn_operator_t* fully_connected_op_out) {
285
+ /* Symmetric quantization forces kzp = 0 */
286
+ TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero."
287
+ "But got: ", kernel_zero_point);
288
+ return xnn_create_fully_connected_nc_qs8(
289
+ input_channels, /* size_t input_channels */
290
+ output_channels, /* size_t output_channels */
291
+ input_stride, /* size_t input_stride */
292
+ output_stride, /* size_t output_stride */
293
+ input_zero_point, /* int8_t input_zero_point */
294
+ input_scale, /* float input_scale */
295
+ kernel_scale, /* float kernel_scale */
296
+ kernel, /* const int8_t* kernel */
297
+ bias, /* const int32_t* bias */
298
+ output_zero_point, /* int8_t output_zero_point */
299
+ output_scale, /* float output_scale */
300
+ output_min, /* int8_t output_min */
301
+ output_max, /* int8_t output_max */
302
+ flags, /* uint32_t flags */
303
+ nullptr, /* xnn_caches_t caches */
304
+ nullptr, /* xnn_weights_cache_t */
305
+ fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
306
+ }
307
+
308
+ C10_ALWAYS_INLINE
309
+ enum xnn_status xnnp_reshape_fully_connected_nc(
310
+ xnn_operator_t fully_connected_op,
311
+ size_t batch_size,
312
+ pthreadpool_t threadpool) {
313
+ return xnn_reshape_fully_connected_nc_qs8(
314
+ fully_connected_op, /* xnn_operator_t fully_connected_op */
315
+ batch_size, /* size_t batch_size */
316
+ threadpool); /* pthreadpool_t threadpool */
317
+ }
318
+
319
+ C10_ALWAYS_INLINE
320
+ enum xnn_status xnnp_setup_fully_connected_nc(
321
+ xnn_operator_t fully_connected_op,
322
+ const int8_t* input,
323
+ int8_t* output) {
324
+ return xnn_setup_fully_connected_nc_qs8(
325
+ fully_connected_op, /* xnn_operator_t fully_connected_op */
326
+ input, /* const int8_t* input */
327
+ output /* int8_t* output */
328
+ );
329
+ }
330
+
331
+ } // namespace xnnp_utils
332
+ } // namespace native
333
+ } // namespace at
334
+
335
+ #endif // USE_XNNPACK
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+
5
+ namespace at {
6
+ namespace native {
7
+ namespace mobile {
8
+
9
+ Tensor allocate_padded_contiguous_if_needed(
10
+ const Tensor& input,
11
+ c10::MemoryFormat memory_format);
12
+
13
+ // TODO: Remove this function when at::native::empty() is modified to accept a
14
+ // custom memory allocator.
15
+
16
+ at::Tensor empty_with_tail_padding(
17
+ IntArrayRef size,
18
+ const caffe2::TypeMeta dtype,
19
+ c10::MemoryFormat memory_format,
20
+ c10::optional<DimnameList> maybe_names);
21
+
22
+ } // namespace mobile
23
+ } // namespace native
24
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/util/ArrayRef.h>
4
+ #include <vector>
5
+
6
+ namespace at {
7
+ namespace native {
8
+
9
+ template <typename T>
10
+ inline std::vector<T> _expand_param_if_needed(
11
+ ArrayRef<T> list_param,
12
+ const char* param_name,
13
+ int64_t expected_dim) {
14
+ if (list_param.size() == 1) {
15
+ return std::vector<T>(expected_dim, list_param[0]);
16
+ } else if ((int64_t)list_param.size() != expected_dim) {
17
+ std::ostringstream ss;
18
+ ss << "expected " << param_name << " to be a single integer value or a "
19
+ << "list of " << expected_dim << " values to match the convolution "
20
+ << "dimensions, but got " << param_name << "=" << list_param;
21
+ AT_ERROR(ss.str());
22
+ } else {
23
+ return list_param.vec();
24
+ }
25
+ }
26
+
27
+ inline std::vector<int64_t> expand_param_if_needed(
28
+ IntArrayRef list_param,
29
+ const char* param_name,
30
+ int64_t expected_dim) {
31
+ return _expand_param_if_needed(list_param, param_name, expected_dim);
32
+ }
33
+
34
+ inline std::vector<c10::SymInt> expand_param_if_needed(
35
+ SymIntArrayRef list_param,
36
+ const char* param_name,
37
+ int64_t expected_dim) {
38
+ return _expand_param_if_needed(list_param, param_name, expected_dim);
39
+ }
40
+
41
+ } // namespace native
42
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/util/irange.h>
4
+ #include <memory>
5
+ #include <mutex>
6
+
7
+ namespace at::native {
8
+
9
+ // Hashing machinery for Params
10
+ // Fowler–Noll–Vo hash function
11
+ // see
12
+ // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
13
+ template <typename Params>
14
+ struct ParamsHash {
15
+ // Params must be a POD because we read out its memory
16
+ // contents as char* when hashing
17
+ static_assert(std::is_standard_layout_v<Params>, "Params is not POD");
18
+
19
+ size_t operator()(const Params& params) const {
20
+ auto ptr = reinterpret_cast<const uint8_t*>(&params);
21
+ uint32_t value = 0x811C9DC5;
22
+ for (const auto i : c10::irange(sizeof(Params))) {
23
+ value ^= ptr[i];
24
+ value *= 0x01000193;
25
+ }
26
+ return (size_t)value;
27
+ }
28
+ };
29
+
30
+ template <typename Params>
31
+ struct ParamsEqual {
32
+ // Params must be a POD because we read out its memory
33
+ // contents as char* when comparing
34
+ static_assert(std::is_standard_layout_v<Params>, "Params is not POD");
35
+
36
+ bool operator()(const Params& a, const Params& b) const {
37
+ auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
38
+ auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
39
+ return memcmp(ptr1, ptr2, sizeof(Params)) == 0;
40
+ }
41
+ };
42
+
43
+ // Provide explicit byte-for-byte constructors to avoid uwittingly leaving
44
+ // padding bytes unitialized (e.g., when passing Params by value)
45
+ template <typename T>
46
+ struct ParamsWrapper {
47
+ T pod;
48
+ static_assert(
49
+ std::is_standard_layout_v<T>,
50
+ "ParamsWrapper cannot wrap non-POD data");
51
+
52
+ ParamsWrapper() {
53
+ memset(&(this->pod), 0, sizeof(this->pod));
54
+ }
55
+
56
+ ParamsWrapper(const ParamsWrapper& other) {
57
+ memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
58
+ }
59
+
60
+ ParamsWrapper(ParamsWrapper&& other) noexcept {
61
+ memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
62
+ }
63
+
64
+ ParamsWrapper& operator=(const ParamsWrapper& other) {
65
+ memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
66
+ return *this;
67
+ }
68
+
69
+ ParamsWrapper& operator=(ParamsWrapper&& other) noexcept {
70
+ memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
71
+ return *this;
72
+ }
73
+
74
+ inline friend bool operator==(
75
+ const ParamsWrapper& lhs,
76
+ const ParamsWrapper& rhs) noexcept {
77
+ auto ptr1 = reinterpret_cast<const uint8_t*>(&(lhs.pod));
78
+ auto ptr2 = reinterpret_cast<const uint8_t*>(&(rhs.pod));
79
+ return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0;
80
+ }
81
+ };
82
+
83
+ // Wrapped version: this allows the outer struct to have custom copy and move
84
+ // constructors for additional safety
85
+ template <typename ParamsWrapper>
86
+ struct ParamsWrapperHash {
87
+ // Params must be a POD because we read out its memory
88
+ // contents as char* when hashing
89
+ static_assert(
90
+ std::is_standard_layout_v<decltype(ParamsWrapper::pod)>,
91
+ "ParamsWrapper cannot wrap non-POD data");
92
+
93
+ size_t operator()(const ParamsWrapper& params_wrapper) const {
94
+ auto ptr = reinterpret_cast<const uint8_t*>(&(params_wrapper.pod));
95
+ uint32_t value = 0x811C9DC5;
96
+ for (const auto i : c10::irange(sizeof(params_wrapper.pod))) {
97
+ value ^= ptr[i];
98
+ value *= 0x01000193;
99
+ }
100
+ return (size_t)value;
101
+ }
102
+ };
103
+
104
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API void _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale);
21
+
22
+ } // namespace cpu
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API ::std::vector<at::Tensor> _histogramdd_bin_edges(const at::Tensor & self, at::IntArrayRef bins, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false);
21
+
22
+ } // namespace cpu
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_from_bin_tensors_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _histogramdd_from_bin_tensors {
18
+ using schema = at::Tensor (const at::Tensor &, at::TensorList, const c10::optional<at::Tensor> &, bool);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_histogramdd_from_bin_tensors")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight, bool density);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight, bool density);
26
+ };
27
+
28
+ struct TORCH_API _histogramdd_from_bin_tensors_out {
29
+ using schema = at::Tensor & (const at::Tensor &, at::TensorList, const c10::optional<at::Tensor> &, bool, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_histogramdd_from_bin_tensors")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight, bool density, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight, bool density, at::Tensor & out);
37
+ };
38
+
39
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_thnn_differentiable_gru_cell_backward_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
26
+ inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _thnn_differentiable_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const c10::optional<at::Tensor> & input_bias, const c10::optional<at::Tensor> & hidden_bias) {
27
+ return at::_ops::_thnn_differentiable_gru_cell_backward::call(grad_hy, input_gates, hidden_gates, hx, input_bias, hidden_bias);
28
+ }
29
+
30
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_fused_gru_cell_backward.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/_thnn_fused_gru_cell_backward_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
26
+ inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _thnn_fused_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) {
27
+ return at::_ops::_thnn_fused_gru_cell_backward::call(grad_hy, workspace, has_bias);
28
+ }
29
+
30
+ // aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
31
+ inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_gru_cell_backward_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) {
32
+ return at::_ops::_thnn_fused_gru_cell_backward_out::call(grad_hy, workspace, has_bias, out0, out1, out2, out3, out4);
33
+ }
34
+ // aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
35
+ inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_gru_cell_backward_outf(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) {
36
+ return at::_ops::_thnn_fused_gru_cell_backward_out::call(grad_hy, workspace, has_bias, out0, out1, out2, out3, out4);
37
+ }
38
+
39
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace meta {
19
+
20
+ TORCH_API at::Tensor asin(const at::Tensor & self);
21
+ TORCH_API at::Tensor & asin_out(at::Tensor & out, const at::Tensor & self);
22
+ TORCH_API at::Tensor & asin_outf(const at::Tensor & self, at::Tensor & out);
23
+ TORCH_API at::Tensor & asin_(at::Tensor & self);
24
+
25
+ } // namespace meta
26
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API batch_norm_backward_reduce {
18
+ using schema = ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, bool, bool, bool);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::batch_norm_backward_reduce")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)")
24
+ static ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g);
25
+ static ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g);
26
+ };
27
+
28
+ struct TORCH_API batch_norm_backward_reduce_out {
29
+ using schema = ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, bool, bool, bool, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::batch_norm_backward_reduce")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))")
35
+ static ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3);
36
+ static ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3);
37
+ };
38
+
39
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_or.h ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <c10/util/Optional.h>
17
+
18
+
19
+
20
+ #include <ATen/ops/bitwise_or_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
26
+ inline at::Tensor & bitwise_or_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
27
+ return at::_ops::bitwise_or_Tensor_out::call(self, other, out);
28
+ }
29
+ // aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
30
+ inline at::Tensor & bitwise_or_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
31
+ return at::_ops::bitwise_or_Tensor_out::call(self, other, out);
32
+ }
33
+
34
+ // aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
35
+ inline at::Tensor & bitwise_or_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
36
+ return at::_ops::bitwise_or_Scalar_out::call(self, other, out);
37
+ }
38
+ // aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
39
+ inline at::Tensor & bitwise_or_outf(const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
40
+ return at::_ops::bitwise_or_Scalar_out::call(self, other, out);
41
+ }
42
+
43
+ // aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
44
+ inline at::Tensor bitwise_or(const at::Tensor & self, const at::Scalar & other) {
45
+ return at::_ops::bitwise_or_Scalar::call(self, other);
46
+ }
47
+
48
+ // aten::bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
49
+ inline at::Tensor bitwise_or(const at::Scalar & self, const at::Tensor & other) {
50
+ return at::_ops::bitwise_or_Scalar_Tensor::call(self, other);
51
+ }
52
+
53
+ // aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
54
+ inline at::Tensor bitwise_or(const at::Tensor & self, const at::Tensor & other) {
55
+ return at::_ops::bitwise_or_Tensor::call(self, other);
56
+ }
57
+
58
+ // aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
59
+ inline at::Tensor & bitwise_or_out(at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
60
+ return at::_ops::bitwise_or_Scalar_Tensor_out::call(self, other, out);
61
+ }
62
+ // aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
63
+ inline at::Tensor & bitwise_or_outf(const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
64
+ return at::_ops::bitwise_or_Scalar_Tensor_out::call(self, other, out);
65
+ }
66
+
67
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautogradnonfunctional {
19
+
20
+ TORCH_API at::Tensor cat(const at::ITensorListRef & tensors, int64_t dim=0);
21
+
22
+ } // namespace compositeexplicitautogradnonfunctional
23
+ } // namespace at