Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py +277 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +221 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +157 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py +212 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py +210 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py +567 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py +1980 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py +139 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py +232 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py +142 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py +236 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py +202 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py +206 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py +213 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py +128 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py +312 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py +262 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py +235 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h +630 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h +174 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h +369 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h +445 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h +335 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h +42 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h +104 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_from_bin_tensors_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward.h +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_fused_gru_cell_backward.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_or.h +67 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc
ADDED
|
Binary file (6.95 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (226 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc
ADDED
|
Binary file (59.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc
ADDED
|
Binary file (61.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc
ADDED
|
Binary file (7.12 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import itertools
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from ..._dynamo.utils import counters
|
| 6 |
+
|
| 7 |
+
from ..pattern_matcher import Arg, CallFunction, KeywordArg
|
| 8 |
+
from .freezing_patterns import register_binary_folding_pattern
|
| 9 |
+
|
| 10 |
+
aten = torch.ops.aten
|
| 11 |
+
prims = torch.ops.prims
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def mark_mixed_dtype_conv(conv):
|
| 15 |
+
conv_dtype = conv.meta["val"].dtype
|
| 16 |
+
if conv_dtype not in (torch.float16, torch.bfloat16):
|
| 17 |
+
return
|
| 18 |
+
|
| 19 |
+
if not len(conv.users) == 1:
|
| 20 |
+
return
|
| 21 |
+
|
| 22 |
+
conv_user = next(iter(conv.users.keys()))
|
| 23 |
+
if not isinstance(conv_user.meta["val"], torch.Tensor):
|
| 24 |
+
return
|
| 25 |
+
|
| 26 |
+
if not conv_user.meta["val"].dtype == torch.float32:
|
| 27 |
+
return
|
| 28 |
+
|
| 29 |
+
while conv_user.target in _binary_ops:
|
| 30 |
+
if not len(conv_user.users) == 1:
|
| 31 |
+
return
|
| 32 |
+
|
| 33 |
+
conv_user = next(iter(conv_user.users.keys()))
|
| 34 |
+
|
| 35 |
+
if not (
|
| 36 |
+
conv_user.target == prims.convert_element_type.default
|
| 37 |
+
and conv_user.args[1] == conv_dtype
|
| 38 |
+
):
|
| 39 |
+
return
|
| 40 |
+
|
| 41 |
+
conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def mark_mixed_dtype_allowed_convs(gm):
|
| 45 |
+
"""
|
| 46 |
+
Mark convolutions which we will binary fold even with mixed precision constants. We constant fold in the higher precision
|
| 47 |
+
for better accuracy and then recover the original precision after.
|
| 48 |
+
"""
|
| 49 |
+
for node in gm.graph.nodes:
|
| 50 |
+
if node.target is aten.convolution.default:
|
| 51 |
+
mark_mixed_dtype_conv(node)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def recover_original_precision_folded_convs(gm):
|
| 55 |
+
"""
|
| 56 |
+
After binary folding conv weights and biases to a higher dtype, recover the original precision they were in.
|
| 57 |
+
"""
|
| 58 |
+
graph = gm.graph
|
| 59 |
+
convs = [node for node in graph.nodes if node.target is aten.convolution.default]
|
| 60 |
+
for node in convs:
|
| 61 |
+
orig_dtype = node.meta.get("_allow_conv_mixed_dtype_folding", None)
|
| 62 |
+
if orig_dtype is None:
|
| 63 |
+
continue
|
| 64 |
+
|
| 65 |
+
with graph.inserting_before(node):
|
| 66 |
+
for idx in [1, 2]:
|
| 67 |
+
old_input = node.args[idx]
|
| 68 |
+
if old_input is None:
|
| 69 |
+
continue
|
| 70 |
+
|
| 71 |
+
new_input = graph.create_node(
|
| 72 |
+
"call_function",
|
| 73 |
+
prims.convert_element_type.default,
|
| 74 |
+
(old_input, orig_dtype),
|
| 75 |
+
)
|
| 76 |
+
node.replace_input_with(old_input, new_input)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@functools.lru_cache(None)
|
| 83 |
+
def binary_folding_init():
|
| 84 |
+
_conv_args = [Arg() for _ in range(9)]
|
| 85 |
+
_computation_ops = [aten.convolution.default]
|
| 86 |
+
_computation_calls = [CallFunction(aten.convolution.default, *_conv_args, _users=1)]
|
| 87 |
+
|
| 88 |
+
"""
|
| 89 |
+
In order to fuse add/sub/mul/div with conv, the dimensions of its
|
| 90 |
+
constant tensor must satisfy the following:
|
| 91 |
+
- with resizing, broadcast to w/ weight/bias tensor shape
|
| 92 |
+
- broadcast to the conv output shape
|
| 93 |
+
It needs to have a shape that can resize to weight/bias
|
| 94 |
+
tensor shape because we need to run the op with the conv
|
| 95 |
+
weights/bias without changing their sizes.
|
| 96 |
+
It needs to broadcast to the conv output shape so that we do
|
| 97 |
+
accidentally change the shape of op output by pre-fusing it
|
| 98 |
+
compared to eager.
|
| 99 |
+
The only dimension value shared by weight/bias/conv output
|
| 100 |
+
is they all contain a dim with value = channels-out. In the
|
| 101 |
+
conv output tensor, this is in the second dimension,
|
| 102 |
+
so the pointwise op tensor may have a second dimension of
|
| 103 |
+
value == channels-out, but all the other dimensions have to be 1
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def _op_not_broadcasting_with_conv(weight_tensor, other_tensor):
|
| 107 |
+
# According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp
|
| 108 |
+
weight_shape = weight_tensor.shape
|
| 109 |
+
other_shape = other_tensor.shape
|
| 110 |
+
if len(weight_shape) < len(other_shape):
|
| 111 |
+
return False
|
| 112 |
+
if len(weight_shape) == len(other_shape) + 1:
|
| 113 |
+
# weight shape is [o, i, *], other_shape is [o, 1...].
|
| 114 |
+
for i in reversed(range(len(other_shape))):
|
| 115 |
+
if i == 0 and weight_shape[0] == other_shape[i]:
|
| 116 |
+
continue
|
| 117 |
+
if other_shape[i] != 1:
|
| 118 |
+
return False
|
| 119 |
+
else:
|
| 120 |
+
# weight shape is [o, i, *], other_shape is [1, i, *]
|
| 121 |
+
for i in reversed(range(len(other_shape))):
|
| 122 |
+
if i == 1 and weight_shape[0] == other_shape[i]:
|
| 123 |
+
continue
|
| 124 |
+
if other_shape[i] != 1:
|
| 125 |
+
return False
|
| 126 |
+
return True
|
| 127 |
+
|
| 128 |
+
def _check_conv_and_broadcast_op(conv_node, other):
|
| 129 |
+
# According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp.
|
| 130 |
+
# conv.weight
|
| 131 |
+
if conv_node.args[1].op != "get_attr":
|
| 132 |
+
return False
|
| 133 |
+
# conv.bias
|
| 134 |
+
if conv_node.args[1] is not None and conv_node.args[1].op != "get_attr":
|
| 135 |
+
return False
|
| 136 |
+
if (
|
| 137 |
+
not isinstance(other, int)
|
| 138 |
+
and not isinstance(other, float)
|
| 139 |
+
and other.op != "get_attr"
|
| 140 |
+
):
|
| 141 |
+
return False
|
| 142 |
+
|
| 143 |
+
if not len(conv_node.args[1].users) == 1:
|
| 144 |
+
return False
|
| 145 |
+
|
| 146 |
+
weight_meta_value = conv_node.args[1].meta.get("val")
|
| 147 |
+
if weight_meta_value is None:
|
| 148 |
+
return False
|
| 149 |
+
# Avoid fusing op that causes type promotion
|
| 150 |
+
# restricting to float avoids int/float difficulties with scalar overload
|
| 151 |
+
if not weight_meta_value.is_floating_point():
|
| 152 |
+
return False
|
| 153 |
+
if isinstance(other, torch.fx.Node) and other.op == "get_attr":
|
| 154 |
+
other_meta_value = other.meta.get("val")
|
| 155 |
+
if not other_meta_value.is_floating_point():
|
| 156 |
+
return False
|
| 157 |
+
if (
|
| 158 |
+
torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)
|
| 159 |
+
!= weight_meta_value.dtype
|
| 160 |
+
):
|
| 161 |
+
if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
|
| 162 |
+
return False
|
| 163 |
+
|
| 164 |
+
if (
|
| 165 |
+
other_meta_value.dtype != torch.float
|
| 166 |
+
and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
|
| 167 |
+
):
|
| 168 |
+
return False
|
| 169 |
+
|
| 170 |
+
if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value):
|
| 171 |
+
return False
|
| 172 |
+
else:
|
| 173 |
+
# TODO: support scalar case
|
| 174 |
+
return False
|
| 175 |
+
|
| 176 |
+
return True
|
| 177 |
+
|
| 178 |
+
def _is_foldable_pattern(match):
|
| 179 |
+
binary_node = match.output_node()
|
| 180 |
+
computation_node = binary_node.args[0]
|
| 181 |
+
other = binary_node.args[1]
|
| 182 |
+
if binary_node.args[0].target not in _computation_ops:
|
| 183 |
+
computation_node = binary_node.args[1]
|
| 184 |
+
other = binary_node.args[0]
|
| 185 |
+
if binary_node.args[0].target == aten.convolution.default:
|
| 186 |
+
return _check_conv_and_broadcast_op(computation_node, other)
|
| 187 |
+
|
| 188 |
+
return False
|
| 189 |
+
|
| 190 |
+
def resize_scalar_or_tensor_to_shape(graph, other, shape):
|
| 191 |
+
# TODO: support scalar case
|
| 192 |
+
if other.meta.get("val").numel() == 1:
|
| 193 |
+
# expand errors if the shape input has less # dims than the tensor input
|
| 194 |
+
res = graph.create_node(
|
| 195 |
+
"call_function",
|
| 196 |
+
aten.reshape.default,
|
| 197 |
+
(other, (1,)),
|
| 198 |
+
)
|
| 199 |
+
res = graph.create_node(
|
| 200 |
+
"call_function",
|
| 201 |
+
aten.expand.default,
|
| 202 |
+
(res, shape),
|
| 203 |
+
)
|
| 204 |
+
else:
|
| 205 |
+
res = graph.create_node(
|
| 206 |
+
"call_function",
|
| 207 |
+
aten.reshape.default,
|
| 208 |
+
(other, shape),
|
| 209 |
+
)
|
| 210 |
+
return res
|
| 211 |
+
|
| 212 |
+
def _create_new_conv_node(graph, conv_node, binary_node, other):
|
| 213 |
+
assert conv_node.target == aten.convolution.default
|
| 214 |
+
conv_args = list(conv_node.args)
|
| 215 |
+
weight_meta_value = conv_node.args[1].meta.get("val")
|
| 216 |
+
bias = conv_args[2]
|
| 217 |
+
if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]:
|
| 218 |
+
other_reshape = resize_scalar_or_tensor_to_shape(
|
| 219 |
+
graph, other, (weight_meta_value.size(0),)
|
| 220 |
+
)
|
| 221 |
+
new_bias = graph.create_node(
|
| 222 |
+
"call_function",
|
| 223 |
+
binary_node.target,
|
| 224 |
+
(0 if bias is None else bias, other_reshape),
|
| 225 |
+
)
|
| 226 |
+
conv_args[2] = new_bias
|
| 227 |
+
else:
|
| 228 |
+
assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor]
|
| 229 |
+
weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))]
|
| 230 |
+
weight_broadcast_shape[0] = weight_meta_value.size(0)
|
| 231 |
+
other_reshape1 = resize_scalar_or_tensor_to_shape(
|
| 232 |
+
graph, other, tuple(weight_broadcast_shape)
|
| 233 |
+
)
|
| 234 |
+
new_weight = graph.create_node(
|
| 235 |
+
"call_function", binary_node.target, (conv_args[1], other_reshape1)
|
| 236 |
+
)
|
| 237 |
+
new_weight.meta.update(conv_args[1].meta)
|
| 238 |
+
conv_args[1] = new_weight
|
| 239 |
+
if bias is not None:
|
| 240 |
+
other_reshape = resize_scalar_or_tensor_to_shape(
|
| 241 |
+
graph, other, (weight_meta_value.size(0),)
|
| 242 |
+
)
|
| 243 |
+
new_bias = graph.create_node(
|
| 244 |
+
"call_function", binary_node.target, (bias, other_reshape)
|
| 245 |
+
)
|
| 246 |
+
new_bias.meta.update(bias.meta)
|
| 247 |
+
conv_args[2] = new_bias
|
| 248 |
+
return graph.create_node("call_function", conv_node.target, tuple(conv_args))
|
| 249 |
+
|
| 250 |
+
for _computation_call, binary_op in itertools.product(
|
| 251 |
+
_computation_calls, _binary_ops
|
| 252 |
+
):
|
| 253 |
+
|
| 254 |
+
@register_binary_folding_pattern(
|
| 255 |
+
CallFunction(binary_op, _computation_call, KeywordArg("other")),
|
| 256 |
+
extra_check=_is_foldable_pattern,
|
| 257 |
+
)
|
| 258 |
+
def folded_op(match, *args, **kwargs):
|
| 259 |
+
counters["inductor"]["binary_folding"] += 1
|
| 260 |
+
other = kwargs.get("other")
|
| 261 |
+
binary_node = match.output_node()
|
| 262 |
+
computation_node = (
|
| 263 |
+
binary_node.args[0]
|
| 264 |
+
if binary_node.args[0].target in _computation_ops
|
| 265 |
+
else binary_node.args[1]
|
| 266 |
+
)
|
| 267 |
+
graph = match.graph
|
| 268 |
+
with graph.inserting_before(binary_node):
|
| 269 |
+
# TODO: support linear?
|
| 270 |
+
assert computation_node.target == aten.convolution.default
|
| 271 |
+
new_computation_node = _create_new_conv_node(
|
| 272 |
+
graph, computation_node, binary_node, other
|
| 273 |
+
)
|
| 274 |
+
binary_node.replace_all_uses_with(new_computation_node)
|
| 275 |
+
new_computation_node.meta.update(computation_node.meta)
|
| 276 |
+
graph.erase_node(binary_node)
|
| 277 |
+
graph.erase_node(computation_node)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torch._dynamo.utils import counters
|
| 7 |
+
from torch._inductor import utils
|
| 8 |
+
|
| 9 |
+
from ..pattern_matcher import (
|
| 10 |
+
Arg,
|
| 11 |
+
CallFunction,
|
| 12 |
+
config_flag,
|
| 13 |
+
Ignored,
|
| 14 |
+
Match,
|
| 15 |
+
register_graph_pattern,
|
| 16 |
+
)
|
| 17 |
+
from .post_grad import decompose_mm_pass
|
| 18 |
+
|
| 19 |
+
aten = torch.ops.aten
|
| 20 |
+
log = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
# TODO: need a better strategy for decomposing mm
|
| 23 |
+
MIN_FIRST_DIMENSION_DECOMPOSITION = 10240
|
| 24 |
+
MAX_OTHER_DIMENSION_DECOMPOSITION = 32
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def check_device(a: Tensor, b: Tensor) -> bool:
|
| 28 |
+
return a.is_cuda and b.is_cuda
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def should_decompose_common(
|
| 32 |
+
mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
|
| 33 |
+
) -> bool:
|
| 34 |
+
return (
|
| 35 |
+
torch._inductor.config.decompose_mem_bound_mm
|
| 36 |
+
and check_device(mat1, mat2)
|
| 37 |
+
and not utils.any_is_symbolic(mat1, mat2, input)
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def should_decompose_bmm(mat1, mat2) -> bool:
|
| 42 |
+
if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
|
| 43 |
+
mat1 = mat1.meta["val"]
|
| 44 |
+
mat2 = mat2.meta["val"]
|
| 45 |
+
else:
|
| 46 |
+
return False
|
| 47 |
+
if not should_decompose_common(mat1, mat2):
|
| 48 |
+
return False
|
| 49 |
+
else:
|
| 50 |
+
if len(mat1.shape) != 3 or len(mat2.shape) != 3:
|
| 51 |
+
return False
|
| 52 |
+
if mat1.shape[0] < MIN_FIRST_DIMENSION_DECOMPOSITION:
|
| 53 |
+
return False
|
| 54 |
+
# 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION
|
| 55 |
+
if (mat1.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION) + (
|
| 56 |
+
mat1.shape[2] < MAX_OTHER_DIMENSION_DECOMPOSITION
|
| 57 |
+
) + (mat2.shape[2] < MAX_OTHER_DIMENSION_DECOMPOSITION) < 2:
|
| 58 |
+
return False
|
| 59 |
+
return True
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def should_decompose_mm(mat1, mat2) -> bool:
|
| 63 |
+
if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
|
| 64 |
+
mat1 = mat1.meta["val"]
|
| 65 |
+
mat2 = mat2.meta["val"]
|
| 66 |
+
else:
|
| 67 |
+
return False
|
| 68 |
+
return (
|
| 69 |
+
should_decompose_common(mat1, mat2)
|
| 70 |
+
and len(mat1.shape) == 2
|
| 71 |
+
and len(mat2.shape) == 2
|
| 72 |
+
and mat1.shape[0] >= MIN_FIRST_DIMENSION_DECOMPOSITION
|
| 73 |
+
and mat2.shape[0] < MAX_OTHER_DIMENSION_DECOMPOSITION
|
| 74 |
+
and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def should_decompose_mmt(mat1, mat2) -> bool:
|
| 79 |
+
if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
|
| 80 |
+
mat1 = mat1.meta["val"]
|
| 81 |
+
mat2 = mat2.meta["val"]
|
| 82 |
+
else:
|
| 83 |
+
return False
|
| 84 |
+
return (
|
| 85 |
+
should_decompose_common(mat1, mat2)
|
| 86 |
+
and len(mat1.shape) == 2
|
| 87 |
+
and len(mat2.shape) == 2
|
| 88 |
+
and mat1.shape[0] >= MIN_FIRST_DIMENSION_DECOMPOSITION
|
| 89 |
+
and mat1.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION
|
| 90 |
+
and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def should_decompose_mm_largek(mat1, mat2) -> bool:
|
| 95 |
+
if is_node_meta_valid(mat1) and is_node_meta_valid(mat2):
|
| 96 |
+
mat1 = mat1.meta["val"]
|
| 97 |
+
mat2 = mat2.meta["val"]
|
| 98 |
+
else:
|
| 99 |
+
return False
|
| 100 |
+
return (
|
| 101 |
+
should_decompose_common(mat1, mat2)
|
| 102 |
+
and len(mat1.shape) == 2
|
| 103 |
+
and len(mat2.shape) == 2
|
| 104 |
+
and mat1.shape[1] >= MIN_FIRST_DIMENSION_DECOMPOSITION
|
| 105 |
+
and mat1.shape[0] < MAX_OTHER_DIMENSION_DECOMPOSITION
|
| 106 |
+
and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def is_node_meta_valid(node: torch.fx.Node):
|
| 111 |
+
return "val" in node.meta
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def print_decompose_pattern(match: Match, inputs: List[torch.fx.Node]):
|
| 115 |
+
node = match.nodes[-1]
|
| 116 |
+
log.debug(
|
| 117 |
+
"Decompose %s with input shape: %s",
|
| 118 |
+
node.target,
|
| 119 |
+
", ".join(
|
| 120 |
+
str(input.meta["val"].shape) if "val" in input.meta else "None"
|
| 121 |
+
for input in inputs
|
| 122 |
+
),
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@register_graph_pattern(
|
| 127 |
+
CallFunction(aten.bmm, Arg(), Arg()),
|
| 128 |
+
pass_dict=decompose_mm_pass,
|
| 129 |
+
extra_check=config_flag("decompose_mem_bound_mm"),
|
| 130 |
+
)
|
| 131 |
+
def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node):
|
| 132 |
+
def repl(mat1, mat2):
|
| 133 |
+
return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2)
|
| 134 |
+
|
| 135 |
+
if should_decompose_bmm(mat1, mat2):
|
| 136 |
+
counters["inductor"]["decompose_bmm"] += 1
|
| 137 |
+
match.replace_by_example(repl, [mat1, mat2])
|
| 138 |
+
print_decompose_pattern(match, [mat1, mat2])
|
| 139 |
+
return
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@register_graph_pattern(
|
| 143 |
+
CallFunction(aten.addmm, Arg(), Arg(), Arg()),
|
| 144 |
+
pass_dict=decompose_mm_pass,
|
| 145 |
+
extra_check=config_flag("decompose_mem_bound_mm"),
|
| 146 |
+
)
|
| 147 |
+
def decompose_addmm(
|
| 148 |
+
match: Match,
|
| 149 |
+
mat1: torch.fx.Node,
|
| 150 |
+
mat2: torch.fx.Node,
|
| 151 |
+
mat3: torch.fx.Node,
|
| 152 |
+
):
|
| 153 |
+
def repl(mat1, mat2, mat3):
|
| 154 |
+
return torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2) + mat1
|
| 155 |
+
|
| 156 |
+
if should_decompose_mm(mat2, mat3):
|
| 157 |
+
counters["inductor"]["decompose_addmm"] += 1
|
| 158 |
+
match.replace_by_example(repl, [mat1, mat2, mat3])
|
| 159 |
+
print_decompose_pattern(match, [mat1, mat2, mat3])
|
| 160 |
+
return
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@register_graph_pattern(
|
| 164 |
+
CallFunction(aten.mm, CallFunction(aten.permute, Arg(), Ignored()), Arg()),
|
| 165 |
+
pass_dict=decompose_mm_pass,
|
| 166 |
+
extra_check=config_flag("decompose_mem_bound_mm"),
|
| 167 |
+
)
|
| 168 |
+
def decompose_mmt(
|
| 169 |
+
match: Match,
|
| 170 |
+
mat1: torch.fx.Node,
|
| 171 |
+
mat2: torch.fx.Node,
|
| 172 |
+
):
|
| 173 |
+
def repl(mat1, mat2):
|
| 174 |
+
return torch.sum(mat1[:, :, None] * mat2[:, None, :], dim=0)
|
| 175 |
+
|
| 176 |
+
if should_decompose_mmt(mat1, mat2):
|
| 177 |
+
counters["inductor"]["decompose_mmt"] += 1
|
| 178 |
+
match.replace_by_example(repl, [mat1, mat2])
|
| 179 |
+
print_decompose_pattern(match, [mat1, mat2])
|
| 180 |
+
return
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@register_graph_pattern(
|
| 184 |
+
CallFunction(aten.mm, Arg(), Arg()),
|
| 185 |
+
pass_dict=decompose_mm_pass,
|
| 186 |
+
extra_check=config_flag("decompose_mem_bound_mm"),
|
| 187 |
+
)
|
| 188 |
+
def decompose_mm(
|
| 189 |
+
match: Match,
|
| 190 |
+
mat1: torch.fx.Node,
|
| 191 |
+
mat2: torch.fx.Node,
|
| 192 |
+
):
|
| 193 |
+
def repl(mat1, mat2):
|
| 194 |
+
return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2)
|
| 195 |
+
|
| 196 |
+
if should_decompose_mm(mat1, mat2):
|
| 197 |
+
counters["inductor"]["decompose_mm"] += 1
|
| 198 |
+
match.replace_by_example(repl, [mat1, mat2])
|
| 199 |
+
print_decompose_pattern(match, [mat1, mat2])
|
| 200 |
+
return
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@register_graph_pattern(
|
| 204 |
+
CallFunction(aten.mm, Arg(), Arg()),
|
| 205 |
+
pass_dict=decompose_mm_pass,
|
| 206 |
+
extra_check=config_flag("decompose_mem_bound_mm"),
|
| 207 |
+
)
|
| 208 |
+
def decompose_mm_large_k(
|
| 209 |
+
match: Match,
|
| 210 |
+
mat1: torch.fx.Node,
|
| 211 |
+
mat2: torch.fx.Node,
|
| 212 |
+
):
|
| 213 |
+
def repl(mat1, mat2):
|
| 214 |
+
mat1 = mat1.permute(1, 0)
|
| 215 |
+
return torch.sum(mat1[:, :, None] * mat2[:, None, :], dim=0)
|
| 216 |
+
|
| 217 |
+
if should_decompose_mm_largek(mat1, mat2):
|
| 218 |
+
counters["inductor"]["decompose_mm_large_k"] += 1
|
| 219 |
+
match.replace_by_example(repl, [mat1, mat2])
|
| 220 |
+
print_decompose_pattern(match, [mat1, mat2])
|
| 221 |
+
return
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from torch._dynamo.utils import counters
|
| 5 |
+
from torch._inductor import config as inductor_config
|
| 6 |
+
from torch.func import functional_call
|
| 7 |
+
|
| 8 |
+
from ..pattern_matcher import CallModuleVarArgs, Match, register_graph_pattern
|
| 9 |
+
|
| 10 |
+
from .pre_grad import efficient_conv_bn_eval_pass
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def efficient_conv_bn_eval(
    bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor
):
    """
    Memory-efficient fused Conv+BN forward for eval mode.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning".
    It exploits the associative law between convolution and affine
    transform: normalize(weight conv feature) == (normalize weight) conv feature.
    Intended for eval-mode ConvBN blocks (also usable in training when
    ``bn.training`` is False). Saves memory and compute at the cost of
    slightly reduced numerical stability.

    Args:
        bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module.
        conv (nn.modules.conv._ConvNd): a conv module.
        x (torch.Tensor): input feature map.

    Returns:
        torch.Tensor: the result of ``bn(conv(x))`` computed with the
        folded on-the-fly weights.
    """

    assert bn.running_var is not None

    # Substitute neutral affine parameters when conv/bn omit them
    # (conv without bias, bn without affine transform).
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_var)
    gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
    beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_var)

    # Broadcast shape [C_out, 1, ..., 1] matching the conv weight rank
    # (e.g. [C_out, 1, 1, 1] for Conv2d).
    broadcast_shape = [-1] + [1] * (conv.weight.ndim - 1)
    if isinstance(conv, nn.modules.conv._ConvTransposeNd):
        # For transposed conv, the C_out dimension sits at index 1.
        broadcast_shape[:2] = [broadcast_shape[1], broadcast_shape[0]]

    inv_std = torch.rsqrt(bn.running_var + bn.eps).reshape(broadcast_shape)
    # Per-channel scale folding gamma and 1/std, shaped for broadcasting.
    scale = gamma.view_as(inv_std) * inv_std

    # Folded weight, shape [C_out, C_in, k, k] in Conv2d.
    fused_weight = conv.weight * scale
    # Folded bias, shape [C_out] in Conv2d.
    fused_bias = beta + scale.flatten() * (conv_bias - bn.running_mean)

    # Run the original conv module with the folded parameters substituted.
    return functional_call(conv, {"weight": fused_weight, "bias": fused_bias}, x)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@register_graph_pattern(
    CallModuleVarArgs(
        [
            nn.modules.batchnorm._BatchNorm,
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            nn.SyncBatchNorm,
        ],
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
    """Rewrite an eval-mode conv->bn call_module pair into one call_function
    node targeting ``efficient_conv_bn_eval``.

    The pattern matches any BatchNorm call_module node; the body then checks
    that its sole input is a supported conv/linear module used by no other
    node before performing the rewrite in place on ``match.graph``.
    """
    # We matched a BN node
    bn_node = match.nodes[0]
    graph = match.graph
    gm = graph.owning_module
    bn_mod = getattr(gm, bn_node.target)  # type: ignore[arg-type]

    # We can only use efficient conv-bn for eval mode with track_running_stats
    if not bn_mod.track_running_stats or bn_mod.training:
        return

    # Check if the input is Conv (positional arg first, then kwarg form)
    if bn_node.args:
        input_node = bn_node.args[0]
    else:
        input_node = bn_node.kwargs["input"]
    if input_node.op != "call_module":  # type: ignore[union-attr]
        return
    if not hasattr(gm, input_node.target):  # type: ignore[arg-type, union-attr]
        return
    input_mod = getattr(gm, input_node.target)  # type: ignore[arg-type, union-attr]
    supported_convs = [
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    ]
    if not any(isinstance(input_mod, cls) for cls in supported_convs):
        return
    conv_node = input_node
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    # Found a pair of conv and bn computation nodes to optimize.
    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(conv_node):
        # create `get_attr` node to access modules
        # note that we directly call `create_node` to fill the `name`
        # argument. `graph.get_attr` and
        # `graph.call_function` does not allow the `name` argument.
        conv_get_node = graph.create_node(
            op="get_attr", target=conv_node.target, name="get_conv"  # type: ignore[union-attr]
        )
        bn_get_node = graph.create_node(
            op="get_attr", target=bn_node.target, name="get_bn"
        )
        if conv_node.args:  # type: ignore[union-attr]
            conv_input = conv_node.args[0]  # type: ignore[union-attr]
        else:
            conv_input = conv_node.kwargs["input"]  # type: ignore[union-attr]
        # prepare args for the fused function
        args = (bn_get_node, conv_get_node, conv_input)
        # create a new node replacing conv + bn with the fused call
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval,
            args=args,
            name="efficient_conv_bn_eval",
        )
        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node (bn consumed conv's output)
        graph.erase_node(bn_node)
        graph.erase_node(conv_node)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch._inductor.compile_fx import fake_tensor_prop
|
| 5 |
+
from ..._dynamo.utils import counters
|
| 6 |
+
|
| 7 |
+
from .. import config
|
| 8 |
+
from ..pattern_matcher import (
|
| 9 |
+
_return_true,
|
| 10 |
+
CallFunction,
|
| 11 |
+
fwd_only,
|
| 12 |
+
Ignored,
|
| 13 |
+
init_once_fakemode,
|
| 14 |
+
KeywordArg,
|
| 15 |
+
Match,
|
| 16 |
+
PatternMatcherPass,
|
| 17 |
+
register_graph_pattern,
|
| 18 |
+
register_replacement,
|
| 19 |
+
stable_topological_sort,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
aten = torch.ops.aten
|
| 23 |
+
|
| 24 |
+
# First pass_patterns[0] are applied, then [1], then [2]
|
| 25 |
+
pass_patterns = [
|
| 26 |
+
PatternMatcherPass(),
|
| 27 |
+
PatternMatcherPass(),
|
| 28 |
+
PatternMatcherPass(),
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
binary_folding_pass = PatternMatcherPass()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
    """
    Passes that are applied to the graph to freeze pass.

    Mutates ``gm`` in place: repeatedly constant-folds and binary-folds
    (conv+binary+binary chains), applies the staged freezing pattern
    passes, optionally deduplicates mkldnn weight-pack nodes, then
    re-sorts, recompiles, and lints the graph.

    Args:
        gm: the graph module being frozen.
        aot_example_inputs: example inputs used to (re)propagate fake
            tensor metadata (``meta['val']``) after each folding round.
    """

    from ..freezing import constant_fold

    lazy_init()
    # We need a few rounds of binary folding to get rid of all the
    # unnecessary nodes, but may need a good method to chose the rounds number.
    # works like: conv+binary+binary.
    binary_folding = counters["inductor"]["binary_folding"]
    fake_tensor_prop(gm, aot_example_inputs, True)

    torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_convs(gm)
    for _ in range(4):
        constant_fold(gm)
        # Make sure meta['val'] is properly set for all nodes
        fake_tensor_prop(gm, aot_example_inputs, True)
        binary_folding_pass.apply(gm.graph)  # type: ignore[arg-type]
        # If we don't have binary folding, we don't need to run the pass again.
        # (counter unchanged means the last round matched nothing)
        # TODO: remove the need to run fake_tensor_prop on the whole model.
        if counters["inductor"]["binary_folding"] == binary_folding:
            break
        binary_folding = counters["inductor"]["binary_folding"]

    torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_convs(gm)

    constant_fold(gm)
    fake_tensor_prop(gm, aot_example_inputs, True)

    # First pass_patterns[0] are applied, then [1], then [2].
    for pattern in pass_patterns:
        pattern.apply(gm.graph)  # type: ignore[arg-type]

    # The CPU weight packing always assume the conv's weight is channels last,
    # So make sure the layout_optimization is on when doing it.
    if (
        torch._C._has_mkldnn
        and config.cpp.weight_prepack
        and config.layout_optimization
    ):
        from .mkldnn_fusion import _eliminate_duplicate_packed_nodes

        _eliminate_duplicate_packed_nodes(gm)

    stable_topological_sort(gm.graph)
    gm.recompile()
    gm.graph.lint()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@init_once_fakemode
def lazy_init():
    # One-time registration of freezing-related patterns; the
    # @init_once_fakemode decorator ensures this runs exactly once,
    # under fake-tensor mode.
    if torch._C._has_mkldnn and config.cpp.weight_prepack:
        # mkldnn weight-prepack patterns only apply when mkldnn is
        # compiled in and prepacking is enabled in config.
        from .mkldnn_fusion import _mkldnn_weight_pack_init

        _mkldnn_weight_pack_init()

    from .binary_folding import binary_folding_init

    addmm_patterns_init()
    binary_folding_init()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
    """Register ``pattern`` into the freezing pass stage ``pass_number``.

    Stages run in order (pass_patterns[0], then [1], then [2]).
    """
    target_pass = pass_patterns[pass_number]
    return register_graph_pattern(
        pattern,
        extra_check=extra_check,
        pass_dict=target_pass,
    )
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def register_binary_folding_pattern(pattern, extra_check=_return_true):
    """Register ``pattern`` into the dedicated binary-folding pass."""
    return register_graph_pattern(
        pattern, extra_check=extra_check, pass_dict=binary_folding_pass
    )
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@functools.lru_cache(None)
def addmm_patterns_init():
    """Register concat-fusion replacements for parallel matmuls/addmms
    that share one input: N separate GEMMs against frozen (get_attr)
    weights become one GEMM against the concatenated weight, chunked
    back into N outputs. Cached so registration happens once."""
    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"
    # Template tensor factory used only to trace the patterns below.
    val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False)

    def check_concat_weights(match):
        # Fusion is only sound when every weight is a frozen constant
        # (get_attr node) and all weights share the same shape.
        weights = [
            match.kwargs["w1"],
            match.kwargs["w2"],
        ]
        if "w3" in match.kwargs:
            weights.append(match.kwargs["w3"])

        return all(
            w.op == "get_attr" and w.meta["val"].shape == weights[0].meta["val"].shape
            for w in weights
        )

    def matmul_fuse_pattern(inp, w1, w2, w3):
        # Three matmuls sharing the same left operand.
        return (inp @ w1, inp @ w2, inp @ w3)

    def matmul_replacement(inp, w1, w2, w3):
        # One matmul against the concatenated weight, split back into 3.
        cat_t = torch.cat((w1, w2, w3), dim=1)
        mm = inp @ cat_t
        return mm.chunk(3, dim=1)

    register_replacement(
        matmul_fuse_pattern,
        matmul_replacement,
        [val(), val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3"),
    )

    def matmul_fuse_pattern_two(inp, w1, w2):
        # Two-matmul variant of the pattern above.
        return (inp @ w1, inp @ w2)

    def matmul_replacement_two(inp, w1, w2):
        cat_t = torch.cat((w1, w2), dim=1)
        mm = inp @ cat_t
        return mm.chunk(2, dim=1)

    register_replacement(
        matmul_fuse_pattern_two,
        matmul_replacement_two,
        [val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2"),
    )

    def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3):
        # Same fusion for bias-carrying addmm triples.
        return (
            aten.addmm(b1, inp, w1),
            aten.addmm(b2, inp, w2),
            aten.addmm(b3, inp, w3),
        )

    def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3):
        # Concatenate weights along columns and biases end-to-end,
        # do one addmm, then split the result back into 3 chunks.
        cat_w = torch.cat((w1, w2, w3), dim=1)
        cat_b = torch.cat((b1, b2, b3))
        return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1)

    register_replacement(
        addmm_fuse_pattern_second,
        addmm_fuse_replacement_second,
        [val() for _ in range(7)],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"),
    )
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def same_dtype(match):
    """True iff the conversion's input already has the requested dtype."""
    source_val = match.output_node().args[0].meta["val"]
    return source_val.dtype == match.kwargs["dtype"]
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        Ignored(),
        KeywordArg("dtype"),
    ),
    pass_dict=pass_patterns[0],
    extra_check=same_dtype,
)
def unnecessary_dtype_convert(match: Match, **kwargs):
    """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
    # extra_check=same_dtype guarantees the convert is a no-op (input dtype
    # equals the target dtype), so forward the input straight to every
    # consumer and drop the node from the graph.
    graph = match.graph
    node = match.output_node()
    node.replace_all_uses_with(node.args[0])
    graph.erase_node(node)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import traceback
|
| 6 |
+
|
| 7 |
+
import numpy
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.optim as optim
|
| 11 |
+
|
| 12 |
+
from .. import config
|
| 13 |
+
|
| 14 |
+
logger: logging.Logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
MAIN_RANDOM_SEED = 1337
|
| 17 |
+
|
| 18 |
+
# Set the CUBLAS_WORKSPACE_CONFIG environment variable
|
| 19 |
+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# If the two forward functions involve any non-deterministic operations,
|
| 23 |
+
# such as certain types of parallelism or asynchronous execution,
|
| 24 |
+
# this can also lead to different outputs.
|
| 25 |
+
def set_deterministic() -> None:
    """Seed all RNG sources with MAIN_RANDOM_SEED and force deterministic kernels."""

    # Seed torch, Python's random, and numpy identically so repeated runs
    # of the two graph modules produce bitwise-comparable results.
    for seed_fn in (torch.manual_seed, random.seed, numpy.random.seed):
        seed_fn(MAIN_RANDOM_SEED)
    # Error out on nondeterministic kernels instead of silently diverging.
    torch.use_deterministic_algorithms(True)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def clean_memory() -> None:
    """Clean memory to avoid OOM."""
    # Run a full Python GC cycle so unreferenced tensors release storage.
    gc.collect()
    # Return PyTorch's cached CUDA allocator blocks to the device.
    torch.cuda.empty_cache()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# We compare the numerical results before and after pre/post grad fx passes
|
| 41 |
+
# transformation to make sure the numerical results are the same.
|
| 42 |
+
def compare_dict_tensors(dict_base, dict_control, precision):
    """Compare two name->tensor dicts (e.g. named parameters before/after
    pre/post grad fx passes) with ``torch.allclose``.

    Args:
        dict_base: mapping captured before the transformation.
        dict_control: mapping captured after the transformation.
        precision: used as both ``rtol`` and ``atol``.

    Returns:
        True when every comparable entry matches; entries that are None on
        either side are skipped. False on any key or value mismatch.
    """
    if len(set(dict_base.keys())) != len(set(dict_control.keys())):
        logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
        logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
        logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
        return False
    is_allclose = True
    for key in dict_base.keys():
        if key not in dict_control:
            logger.warning(
                "Mismatch parameter name %s does not exist after pre/post grad fx passes",
                key,
            )
            # Bug fix: previously this fell through and indexed
            # dict_control[key], raising KeyError (the length check above
            # does not guarantee the key SETS are equal). Treat a missing
            # key as a mismatch and move on.
            is_allclose = False
            continue
        # Some parameters have `None`, and not every param has a valid .grad field, we skip them
        if dict_base[key] is None or dict_control[key] is None:
            continue
        if not torch.allclose(
            dict_base[key],
            dict_control[key],
            rtol=precision,
            atol=precision,
            equal_nan=True,
        ):
            logger.warning(
                "Mismatch parameter values found before and after pre/post grad fx passes."
            )
            logger.debug("value before pre/post grad fx passes %s", dict_base[key])
            logger.debug("value after pre/post grad fx passes %s", dict_control[key])
            is_allclose = False
    return is_allclose
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def compare_tuple_tensors(tuple_base, tuple_control, precision):
    """Elementwise-compare two tuples of tensors (forward outputs) with
    ``torch.allclose``; ``precision`` is used as both rtol and atol.
    None entries on either side are skipped."""
    if len(tuple_base) != len(tuple_control):
        logger.warning(
            "Mismatch fw output length. before transformation: %s, after transformation: %s",
            len(tuple_base),
            len(tuple_control),
        )
        return False
    is_allclose = True
    for base_val, control_val in zip(tuple_base, tuple_control):
        # Some entries may be `None`; skip them rather than comparing.
        if base_val is None or control_val is None:
            continue
        if not torch.allclose(
            base_val,
            control_val,
            rtol=precision,
            atol=precision,
            equal_nan=True,
        ):
            logger.debug(
                "forward output before pre/post grad fx passes %s", base_val
            )
            logger.debug(
                "forward output after pre/post grad fx passes %s", control_val
            )
            is_allclose = False
    return is_allclose
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def compare_parameters(model_base, model_control, precision):
    """Compare the named parameters of two models within ``precision``."""
    base_params = dict(model_base.named_parameters())
    control_params = dict(model_control.named_parameters())
    return compare_dict_tensors(base_params, control_params, precision)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def compare_forward_output(pred_base, pred_control, precision):
    """Compare the two models' forward outputs within ``precision``."""
    return compare_tuple_tensors(pred_base, pred_control, precision)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def compare_gradients(model_base, model_control, precision):
    """Compare ``.grad`` of every named parameter within ``precision``."""
    grad_base = {name: param.grad for name, param in model_base.named_parameters()}
    grad_pt2 = {name: param.grad for name, param in model_control.named_parameters()}
    return compare_dict_tensors(grad_base, grad_pt2, precision)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def run_model(
    model_base, model_control, model_input, num_iterations=10, precision=1e-4
):
    """Run both graph modules on the same input for ``num_iterations``
    rounds, logging parameter / forward-output / gradient comparisons.

    Results are only logged (never raised); callers read the log to judge
    whether the fx transformation preserved numerics.
    """
    clean_memory()
    for i in range(num_iterations):
        logger.info("start %s iteration", i)
        # Reseed before each forward so both models see identical RNG state.
        set_deterministic()
        pred_base = model_base(*model_input)
        set_deterministic()
        pred_control = model_control(*model_input)

        res = compare_parameters(model_base, model_control, precision)
        logger.info("compare parameters. Numerical result : %s", res)

        res = compare_forward_output(pred_base, pred_control, precision)
        logger.info("compare loss/predict. Numerical result : %s", res)
        # tensor may not have a grad_fn
        try:
            _ = pred_base[0].sum().backward(retain_graph=True)
            _ = pred_control[0].sum().backward(retain_graph=True)
            res = compare_gradients(model_base, model_control, precision)
            logger.info("compare param grad. Numerical result : %s", res)
        except Exception as e:
            logger.exception("Exception %s when compare gradients", e)
            traceback.print_exc()

        if config.fx_passes_numeric_check["requires_optimizer"]:
            try:
                # Step an identical SGD on both models and re-compare so
                # optimizer interaction with the transformed graph is checked.
                optimizer_base = optim.SGD(
                    [param for name, param in model_base.named_parameters()], lr=0.01
                )
                optimizer_base.step()

                optimizer_control = optim.SGD(
                    [param for name, param in model_control.named_parameters()], lr=0.01
                )
                optimizer_control.step()

                res = compare_parameters(model_base, model_control, precision)
                logger.info(
                    "compare parameters with optimizer added. Numerical result : %s",
                    res,
                )
            except Exception as e:
                logger.exception(
                    "Exception %s when optimizer is added to check parameter names", e
                )
                traceback.print_exc()
        else:
            logger.warning(
                "no parameter with optimizer to compare with length %s before transformation"
                " and the length %s after transformation",
                len(dict(model_base.named_parameters())),
                len(dict(model_control.named_parameters())),
            )
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def numeric_check_if_enabled(
    gm_before_fx_passes,
    gm_after_fx_passes,
    example_inputs,
    num_iterations,
    precision,
):
    """Run the before/after numeric comparison, swallowing any failure so a
    broken check never blocks the actual model run."""
    # need to topo-sort graphmodule before we run the model,
    # otherwise it may fail as refer before def
    # fail silently in order not to block the model run
    try:
        anomaly_guard = torch.autograd.set_detect_anomaly(True)
        with anomaly_guard:
            run_model(
                gm_before_fx_passes,
                gm_after_fx_passes,
                example_inputs,
                num_iterations=num_iterations,
                precision=precision,
            )
    except Exception as e:
        logger.warning(
            "Runtime numeric check failed in pre grad fx passes with error: %s", e
        )
        traceback.print_exc()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from typing import List, Optional, Set, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torch._inductor import utils
|
| 7 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 8 |
+
from torch.utils._mode_utils import no_dispatch
|
| 9 |
+
from torch.utils._triton import has_triton
|
| 10 |
+
|
| 11 |
+
from ..pattern_matcher import (
|
| 12 |
+
fwd_only,
|
| 13 |
+
joint_fwd_bwd,
|
| 14 |
+
Match,
|
| 15 |
+
MatchContext,
|
| 16 |
+
register_replacement,
|
| 17 |
+
)
|
| 18 |
+
from ..utils import is_view
|
| 19 |
+
|
| 20 |
+
aten = torch.ops.aten
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# This flag is only used for testing purpose.
|
| 24 |
+
# Changing it to True will ignore comparing do_bench times
|
| 25 |
+
# between original pattern and padded one.
|
| 26 |
+
_skip_do_bench_times = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def fetch_fake_tensors(match, kwarg_names) -> List[Tensor]:
    """Return the fake tensor stored in meta["val"] for each named kwarg
    of the match, in the order given by ``kwarg_names``."""
    return [match.kwargs[name].meta["val"] for name in kwarg_names]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def unwrap_fake_args(*arg_names):
    """Decorator factory: the wrapped predicate receives the fake tensors
    for ``arg_names`` instead of the raw Match object."""

    def decorator(func):
        def wrapper(match):
            return func(*fetch_fake_tensors(match, arg_names))

        return wrapper

    return decorator
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_alignment_size(x: Tensor) -> int:
    """Alignment (in elements) to pad GEMM dims to for x's dtype.

    Returns 0 for dtypes we never pad.
    """
    half_precision = (torch.float16, torch.half, torch.bfloat16)
    single_precision = (torch.float32, torch.float)
    if x.dtype in half_precision:
        return 8
    if x.dtype in single_precision:
        return 4
    # Any other dtype: signal "do not pad".
    return 0
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def check_device(a: Tensor, b: Tensor) -> bool:
    """Both operands must live on CUDA for padding to apply."""
    return all(t.is_cuda for t in (a, b))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def check_dtype(a: Tensor, b: Tensor) -> bool:
    """Padding only applies to floating-point GEMM operands."""
    return all(t.is_floating_point() for t in (a, b))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _result_layout_affects_graph_output(match: Match) -> bool:
    """
    Check if the matched GEMM operation potentially affects the graph output strides.
    returns True if the matched op's output buffer does not pass through functions which certainly
    redefine the memory layout before being part of the graph output.
    """

    if match.ctx is not None:
        assert isinstance(match.ctx, MatchContext)
        search_node: torch.fx.Node = match.output_node()
    else:
        # Without a MatchContext we cannot inspect consumers, so be
        # conservative and report that the layout may be observable.
        return True

    assert search_node is not None
    # Visited set guards the DFS against cycles / repeated work.
    seen: Set[torch.fx.Node] = set()

    def find_output(node: torch.fx.Node, is_start_node=False):
        # DFS from the matched op, following users only through view-like
        # call_function nodes; True if an "output" node is reachable.
        if not isinstance(node, torch.fx.Node):
            return False
        if node in seen:
            return False
        seen.add(node)
        if node.op == "output":
            return True
        if node.op != "call_function":
            return False
        if not is_start_node and (
            (not isinstance(node.target, torch._ops.OpOverload))
            or (not is_view(node.target))
        ):
            # A non-view op materializes its own layout, so nothing past it
            # can observe the GEMM's strides.
            return False
        if node.users is not None and len(node.users) > 0:
            for n in node.users:
                if find_output(n):
                    return True
        return False

    return find_output(search_node, True)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def should_pad_common(
    mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
) -> bool:
    """Gate shared by all pad-mm patterns: shape padding must be enabled in
    config, both mats must be CUDA floating point, and every operand must
    have concrete (or hinted-symbolic) sizes and strides."""
    # It's fine we have symbolic shapes or strides as long as they
    # have hints. Later, we will make sure we only pad non-symbolic dimensions.
    def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
        # None (absent optional input) is trivially acceptable.
        if t is None:
            return True

        symbolic_cnt = 0
        for x in t.size():
            if isinstance(x, int):
                continue
            elif utils.is_symbolic(x):
                # A symbolic dim without a hint gives us nothing to pad against.
                if not x.node.has_hint():
                    return False
                symbolic_cnt += 1
            else:
                return False
        # filter out cases where all dimentions are symbolic
        if symbolic_cnt == len(t.size()):
            return False
        return all(
            isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint())
            for x in t.stride()
        )

    return (
        torch._inductor.config.shape_padding
        and check_device(mat1, mat2)
        and check_dtype(mat1, mat2)
        and all(valid_shape_and_stride(t) for t in (mat1, mat2, input))
    )
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int:
    """How many elements must be appended to `x` to reach the next multiple
    of `alignment_size`.  Returns 0 for symbolic sizes (we never pad those),
    for a zero alignment, or when `x` is already aligned."""
    if isinstance(x, torch.SymInt):
        return 0
    if alignment_size == 0:
        return 0
    remainder = x % alignment_size
    if remainder == 0:
        return 0
    return alignment_size - remainder
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor:
    """Append `padded_length` zeros along dimension `dim` of `x`.
    Returns `x` itself (no copy) when no padding is requested."""
    if padded_length == 0:
        return x
    filler = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
    return torch.cat([x, filler], dim=dim)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def addmm_pattern(
    input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float
) -> Tensor:
    """Search pattern traced by the pattern matcher: a plain aten.addmm."""
    return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def should_pad_addmm(match: Match) -> bool:
    """extra_check for the addmm replacement: pad only when the output layout
    cannot leak to the graph output (if stride-preservation is on), the common
    eligibility filter passes, and benchmarking shows a real win."""
    output_layout_is_locked = (
        torch._inductor.config.keep_output_stride
        and _result_layout_affects_graph_output(match)
    )
    if output_layout_is_locked:
        return False
    mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input"))
    if not should_pad_common(mat1, mat2, input):
        return False
    return should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def addmm_replace(
    input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
) -> Tensor:
    """Replacement for addmm: pad any misaligned dimension and delegate to
    pad_addmm, otherwise fall through to the plain aten.addmm."""
    pads = (
        get_padded_length(mat1.shape[0], get_alignment_size(mat1)),  # m
        get_padded_length(mat1.shape[1], get_alignment_size(mat1)),  # k
        get_padded_length(mat2.shape[1], get_alignment_size(mat2)),  # n
    )
    if any(pads):
        m_pad, k_pad, n_pad = pads
        return pad_addmm(input, mat1, mat2, m_pad, k_pad, n_pad, beta, alpha)
    return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def pad_addmm(
    input: Optional[Tensor],
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    beta=1.0,
    alpha=1.0,
):
    """Pad one dimension of an addmm and slice the result back to size.

    Only a single dimension is padded per call (priority k > n > m); the
    recursive round-trip through addmm_replace handles any remaining
    misaligned dimensions.
    """
    # Pad the matmul operands for exactly one dimension.
    if k_padded_length != 0:
        mat1 = pad_dim(mat1, k_padded_length, 1)
        mat2 = pad_dim(mat2, k_padded_length, 0)
    elif n_padded_length != 0:
        mat2 = pad_dim(mat2, n_padded_length, 1)
    elif m_padded_length != 0:
        mat1 = pad_dim(mat1, m_padded_length, 0)

    # The bias add broadcasts, so size-1 dimensions need no padding; padding k
    # does not change the output shape, so the bias is left alone in that case.
    if input is not None and k_padded_length == 0:
        if n_padded_length != 0:
            if input.dim() == 2 and input.shape[1] != 1:
                input = pad_dim(input, n_padded_length, 1)
            elif input.dim() == 1 and input.shape[0] != 1:
                input = pad_dim(input, n_padded_length, 0)
        elif m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1:
            input = pad_dim(input, m_padded_length, 0)

    # Recurse via addmm_replace (which may pad further dims), then slice away
    # the padding we introduced; padding k never changes the output shape.
    if k_padded_length != 0:
        return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)
    if n_padded_length != 0:
        return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)[
            :, :-n_padded_length
        ]
    return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)[
        :-m_padded_length, :
    ]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
    """Heuristically decide whether an (M, K) @ (K, N) matmul is compute bound
    by comparing its arithmetic intensity against the machine balance."""
    memory_accesses = M * K + N * K + M * N
    if memory_accesses == 0:
        return False
    arithmetic_intensity = (M * N * K) / memory_accesses

    # Querying device tflops/bandwidth fails on some backends (e.g. AMD);
    # in that case optimistically assume compute bound.
    try:
        device_tflops = utils.get_device_tflops(dtype)
        dram_gbps = utils.get_gpu_dram_gbps()
        machine_balance = (1000 * device_tflops) / dram_gbps
    except Exception:
        return True

    # dram_gbps might be underestimating bandwidth because of cache.
    # If machine balance is estimated too low we might miss some speedups; too
    # high and compile time increases for nothing.
    # TODO - finetune coefficient here. As a reference point, the Triton mm
    # model assumes 80% of reads hit a cache that is 4x faster than dram_gbps.
    machine_balance = machine_balance * 0.5

    return arithmetic_intensity > machine_balance
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@functools.lru_cache(None)
|
| 257 |
+
def get_pad_cache():
|
| 258 |
+
return torch._inductor.codecache.LocalCache()
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def get_cached_should_pad(key):
    """Look up a previously benchmarked should-pad decision; None on miss."""
    return get_pad_cache().lookup(key)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def set_cached_should_pad(key, value):
    """Persist a benchmarked should-pad decision under `key`."""
    return get_pad_cache().set_value(key, value=value)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def should_pad_bench_key(
    mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> str:
    """Build a stable string cache key for a benchmark decision from the
    operand geometries, the op, and the tf32 setting (fp32 only)."""

    def fingerprint(t):
        # Shape/stride/dtype fully determine the benchmark behavior.
        return (t.shape, t.stride(), t.dtype)

    # tf32 changes fp32 matmul kernels, so it must be part of the key.
    tf32_key = (
        torch.backends.cuda.matmul.allow_tf32
        if mat1.dtype == torch.float32
        else None
    )
    key = (
        fingerprint(mat1),
        fingerprint(mat2),
        op,
        None if input is None else fingerprint(input),
        tf32_key,
    )
    return str(key)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def should_pad_bench(
    mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> bool:
    """Benchmark the padded vs. unpadded kernel and decide whether padding
    `op` (mm / addmm / bmm) pays off.  Decisions are cached on disk keyed by
    operand geometry.  Requires Triton; returns False otherwise."""
    if not has_triton():
        return False

    do_bench = functools.partial(
        utils.do_bench,
        warmup=5,
    )

    with no_dispatch():
        # Extract the GEMM dimensions for the supported ops.
        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
            m, k, n = mat1.shape[0], mat1.shape[1], mat2.shape[1]
        elif op is torch.ops.aten.bmm:
            m, k, n = mat1.shape[1], mat1.shape[2], mat2.shape[2]
        else:
            return False

        m_padded_length = get_padded_length(m, get_alignment_size(mat1))
        k_padded_length = get_padded_length(k, get_alignment_size(mat1))
        n_padded_length = get_padded_length(n, get_alignment_size(mat2))

        # Already aligned: nothing to do.
        if m_padded_length == k_padded_length == n_padded_length == 0:
            return False

        # Memory-bound GEMMs only gain overhead from the extra copies.
        if not is_mm_compute_bound(m, k, n, mat1.dtype):
            return False

        # We don't want to look up the cache for cases that are trivially
        # false, since it does file io.
        key = should_pad_bench_key(mat1, mat2, op, input)
        cached_pad = get_cached_should_pad(key)
        if cached_pad is not None:
            return cached_pad

        def realize_symbols(dims):
            return [d if isinstance(d, int) else d.node.hint for d in dims]

        def realize_tensor(t):
            # Fake tensors are materialized at their hinted size/stride so the
            # benchmark touches the same memory footprint as the real run.
            if isinstance(t, FakeTensor):
                size_hints = realize_symbols(t.size())
                stride_hint = realize_symbols(t.stride())
                real_size = (
                    sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1
                )
                real_t = torch.randn(real_size, dtype=t.dtype, device=t.device)
                return torch.as_strided(real_t, size_hints, stride_hint)
            return torch.randn_like(t)

        mat1 = realize_tensor(mat1)
        mat2 = realize_tensor(mat2)

        # Time the unpadded kernel.
        if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
            ori_time = do_bench(lambda: op(mat1, mat2))
        else:
            if input is not None:
                input = realize_tensor(input)
            ori_time = do_bench(lambda: op(input, mat1, mat2))

        # Fresh buffers for the padded run so caching effects are comparable.
        mat1_pad = torch.randn_like(mat1)
        mat2_pad = torch.randn_like(mat2)

        if op is torch.ops.aten.addmm:
            input_pad = None
            if input is not None and input.is_cuda:
                input_pad = torch.randn_like(input)
            pad_time = do_bench(
                lambda: pad_addmm(
                    input_pad,
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                )
            )
        elif op is torch.ops.aten.mm:
            pad_time = do_bench(
                lambda: pad_mm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                )
            )
        else:
            pad_time = do_bench(
                lambda: pad_bmm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                )
            )

        # Shape padding introduces additional memory ops. Based on
        # microbenchmarks, 1.1x represents a reasonable tradeoff between the
        # padding speedup and the overhead of those extra memory ops.
        # TODO: Build a learned model which would be better than this heuristic.
        should_pad = _skip_do_bench_times or ori_time > pad_time * 1.1
        set_cached_should_pad(key, should_pad)

        return should_pad
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Search pattern traced by the pattern matcher: a plain aten.mm."""
    return aten.mm(mat1, mat2)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def should_pad_mm(match: Match) -> bool:
    """extra_check for the mm replacement; see should_pad_addmm for the
    rationale behind each gate."""
    output_layout_is_locked = (
        torch._inductor.config.keep_output_stride
        and _result_layout_affects_graph_output(match)
    )
    if output_layout_is_locked:
        return False
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    if not should_pad_common(mat1, mat2):
        return False
    return should_pad_bench(mat1, mat2, torch.ops.aten.mm)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Replacement for mm.  Only installed when should_pad_mm fired, which
    guarantees at least one padded length is nonzero."""
    m_pad = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
    k_pad = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    n_pad = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
    return pad_mm(mat1, mat2, m_pad, k_pad, n_pad)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def pad_mm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
) -> Tensor:
    """Pad one dimension of an mm (priority k > n > m) and slice the result
    back to the original shape.  mm_replace goes through pad_mm multiple
    times when several dimensions need padding."""
    if k_padded_length != 0:
        # Padding k with zeros leaves the product unchanged; no slicing needed.
        mat1 = pad_dim(mat1, k_padded_length, 1)
        mat2 = pad_dim(mat2, k_padded_length, 0)
        return torch.ops.aten.mm(mat1, mat2)
    if n_padded_length != 0:
        mat2 = pad_dim(mat2, n_padded_length, 1)
        return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
    mat1 = pad_dim(mat1, m_padded_length, 0)
    return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Search pattern traced by the pattern matcher: a plain aten.bmm."""
    return aten.bmm(mat1, mat2)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def should_pad_bmm(match: Match) -> bool:
    """extra_check for the bmm replacement; see should_pad_addmm for the
    rationale behind each gate."""
    output_layout_is_locked = (
        torch._inductor.config.keep_output_stride
        and _result_layout_affects_graph_output(match)
    )
    if output_layout_is_locked:
        return False
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    if not should_pad_common(mat1, mat2):
        return False
    return should_pad_bench(mat1, mat2, torch.ops.aten.bmm)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    """Replacement for bmm: pad any misaligned dimension and delegate to
    pad_bmm, otherwise fall through to the plain aten.bmm."""
    pads = (
        get_padded_length(mat1.shape[1], get_alignment_size(mat1)),  # m
        get_padded_length(mat1.shape[2], get_alignment_size(mat1)),  # k
        get_padded_length(mat2.shape[2], get_alignment_size(mat2)),  # n
    )
    if any(pads):
        m_pad, k_pad, n_pad = pads
        return pad_bmm(mat1, mat2, m_pad, k_pad, n_pad)
    return aten.bmm(mat1, mat2)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def pad_bmm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
) -> Tensor:
    """Pad one dimension of a bmm (priority k > n > m) and slice the result
    back to the original shape.  bmm_replace goes through pad_bmm multiple
    times when several dimensions need padding."""
    if k_padded_length != 0:
        # Padding k with zeros leaves the product unchanged; no slicing needed.
        mat1 = pad_dim(mat1, k_padded_length, 2)
        mat2 = pad_dim(mat2, k_padded_length, 1)
        return aten.bmm(mat1, mat2)
    if n_padded_length != 0:
        mat2 = pad_dim(mat2, n_padded_length, 2)
        return aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
    mat1 = pad_dim(mat1, m_padded_length, 1)
    return aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
@functools.lru_cache(None)
def _pad_mm_init():
    """Register the mm / bmm / addmm shape-padding pattern replacements into
    the joint-graph pattern set (runs at most once)."""
    from .joint_graph import patterns

    # workaround https://github.com/pytorch/pytorch/issues/97894
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Sizes/values don't actually matter for the initial trace; once we get a
    # possible match we re-trace with the actual values and verify the match
    # still holds.
    def trace_input(shape):
        return functools.partial(
            torch.empty, shape, device=device, requires_grad=True
        )

    dim2a = trace_input((4, 4))
    dim2b = trace_input((4, 4))
    dim3a = trace_input((4, 4, 4))
    dim3b = trace_input((4, 4, 4))
    dim1a = trace_input((4))

    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg
    # relationship
    rep = {"beta": 0.213377, "alpha": 0.113377}

    registrations = [
        (mm_pattern, mm_replace, [dim2a(), dim2b()], {}, should_pad_mm),
        (bmm_pattern, bmm_replace, [dim3a(), dim3b()], {}, should_pad_bmm),
        (
            addmm_pattern,
            addmm_replace,
            [dim1a(), dim2a(), dim2b()],
            rep,
            should_pad_addmm,
        ),
    ]
    for pattern, replacement, args, workaround, extra_check in registrations:
        assert isinstance(workaround, dict)  # mypy is unable to infer the type properly
        # Register for both the training (joint fwd+bwd) and inference traces.
        for trace_fn in (joint_fwd_bwd, fwd_only):
            register_replacement(
                pattern,
                replacement,
                args,
                trace_fn,
                patterns,
                extra_check=extra_check,
                scalar_workaround=workaround,
            )
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py
ADDED
|
@@ -0,0 +1,1980 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import functools
|
| 3 |
+
import itertools
|
| 4 |
+
import math
|
| 5 |
+
import operator
|
| 6 |
+
from typing import Any, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch._dynamo.utils import counters
|
| 10 |
+
from torch.fx.experimental.symbolic_shapes import has_free_symbols
|
| 11 |
+
from ..lowering import lowerings as L, require_channels_last
|
| 12 |
+
from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
|
| 13 |
+
from ..utils import pad_listlike
|
| 14 |
+
from .freezing_patterns import register_freezing_graph_pattern
|
| 15 |
+
from .post_grad import register_lowering_pattern
|
| 16 |
+
|
| 17 |
+
aten = torch.ops.aten
|
| 18 |
+
prims = torch.ops.prims
|
| 19 |
+
quantized_decomposed = torch.ops.quantized_decomposed
|
| 20 |
+
quantized = torch.ops.quantized
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
The quantization.py file primarily incorporates passes related to quantization fusion
|
| 24 |
+
in inductor, includes:
|
| 25 |
+
1. Dequant Promotion;
|
| 26 |
+
2. Conv/GEMM weight prepack with oneDNN Library;
|
| 27 |
+
3. Conv/GEMM quantization fusion with output quant node (if have);
|
| 28 |
+
4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more;
|
| 29 |
+
|
| 30 |
+
It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference
|
| 31 |
+
of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is
|
| 32 |
+
1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM.
|
| 33 |
+
2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node.
|
| 34 |
+
Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16
|
| 35 |
+
quantization.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _may_generate_pattern_with_dtype_convert(pattern, dtype=Arg(), dtype_convert=True):
    """Optionally wrap `pattern` in a prims.convert_element_type node.
    Returns `pattern` unchanged when `dtype_convert` is False."""
    if not dtype_convert:
        return pattern
    return CallFunction(
        prims.convert_element_type.default,
        pattern,
        dtype,
    )
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True):
    """Optionally wrap `pattern` in an `aten.reshape` node.

    When `with_reshape` is False the pattern is returned unchanged; otherwise
    the returned pattern matches `reshape(pattern, reshape_size)`.
    """
    if not with_reshape:
        return pattern
    return CallFunction(torch.ops.aten.reshape.default, pattern, reshape_size)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _generate_linear_t_pattern(
    _dequant_per_channel_pattern,
    dtype,
):
    """Build the transposed-weight pattern for qlinear matching.

    The pattern is `permute(maybe_to_bf16(dequant_per_channel(w)), axes)`;
    for int8-mixed-bf16 (dtype == torch.bfloat16) the dequantized weight is
    first cast via `convert_element_type`.
    """
    assert dtype in [torch.float32, torch.bfloat16]
    maybe_cast_weight = _may_generate_pattern_with_dtype_convert(
        _dequant_per_channel_pattern,
        KeywordArg("autocast_wgt_dtype"),
        dtype == torch.bfloat16,
    )
    return CallFunction(
        aten.permute.default,
        maybe_cast_weight,
        KeywordArg("permute_axes"),
    )
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
"""
|
| 79 |
+
dequantize activation:
|
| 80 |
+
x = x.to(fp32)
|
| 81 |
+
x = x - zero_point
|
| 82 |
+
x = x * scale
|
| 83 |
+
"""
|
| 84 |
+
dequantize_per_tensor_activation_pattern = CallFunction(
|
| 85 |
+
aten.mul.Tensor,
|
| 86 |
+
CallFunction(
|
| 87 |
+
aten.sub.Tensor,
|
| 88 |
+
CallFunction(
|
| 89 |
+
prims.convert_element_type.default,
|
| 90 |
+
KeywordArg("x"),
|
| 91 |
+
KeywordArg("x_dq_dtype"),
|
| 92 |
+
),
|
| 93 |
+
KeywordArg("x_zp"),
|
| 94 |
+
),
|
| 95 |
+
KeywordArg("x_scale"),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
dequantize_per_channel_weight_pattern = CallFunction(
|
| 99 |
+
quantized_decomposed.dequantize_per_channel.default,
|
| 100 |
+
KeywordArg("q_weight"),
|
| 101 |
+
KeywordArg("w_scale"),
|
| 102 |
+
KeywordArg("w_zp"),
|
| 103 |
+
KeywordArg("w_axis"),
|
| 104 |
+
KeywordArg("w_quant_min"),
|
| 105 |
+
KeywordArg("w_quant_max"),
|
| 106 |
+
KeywordArg("w_dtype"),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
dequantize_per_channel_to_bf16_weight_pattern = (
|
| 110 |
+
_may_generate_pattern_with_dtype_convert(
|
| 111 |
+
dequantize_per_channel_weight_pattern,
|
| 112 |
+
KeywordArg("autocast_wgt_dtype"),
|
| 113 |
+
)
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
dequantize_per_channel_clone_weight_pattern = CallFunction(
|
| 117 |
+
aten.clone.default,
|
| 118 |
+
dequantize_per_channel_weight_pattern,
|
| 119 |
+
memory_format=KeywordArg("memory_format"),
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
|
| 123 |
+
aten.clone.default,
|
| 124 |
+
dequantize_per_channel_to_bf16_weight_pattern,
|
| 125 |
+
memory_format=KeywordArg("memory_format"),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_dequantize_qconv_pt2e_pattern(users=1):
    """Pattern for the onednn qconv2d_pointwise op emitted by weight prepack.

    `users` is the number of users the matched node must have (e.g. 2 for
    hardswish, whose decomposition consumes the conv output twice).
    """
    return CallFunction(
        torch.ops.onednn.qconv2d_pointwise.default,
        KeywordArg("x"),
        KeywordArg("x_scale"),  # x_scale
        KeywordArg("x_zp"),  # x_zp
        KeywordArg("packed_weight"),  # packed_weight
        KeywordArg("w_scale"),  # w_scale
        KeywordArg("w_zp"),  # w_zp
        KeywordArg("b"),  # bias
        KeywordArg("stride"),
        KeywordArg("padding"),
        KeywordArg("dilation"),
        KeywordArg("groups"),
        KeywordArg("inv_output_scale"),  # inv_output_scale = 1.0
        KeywordArg("output_zero_point"),  # output_zero_point = 0
        KeywordArg("output_dtype"),  # output_dtype = None
        KeywordArg("attr"),  # attr = "none"
        Arg(),  # scalars
        Arg(),  # algorithm
        _users=users,
    )
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors):
    """Pattern for the onednn qlinear_pointwise op emitted by weight prepack.

    `x_scale_zp_are_tensors` selects the `.tensor` overload (activation
    scale/zero-point passed as tensors) over the `.default` overload
    (passed as scalars).
    """
    if x_scale_zp_are_tensors:
        qlinear_op = torch.ops.onednn.qlinear_pointwise.tensor
    else:
        qlinear_op = torch.ops.onednn.qlinear_pointwise.default
    return CallFunction(
        qlinear_op,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        KeywordArg("x_zp"),
        KeywordArg("packed_weight"),
        KeywordArg("w_scale"),
        KeywordArg("w_zp"),
        KeywordArg("b"),
        KeywordArg("output_scale"),
        KeywordArg("output_zero_point"),
        KeywordArg("output_dtype"),
        KeywordArg("postop_name"),
        KeywordArg("postop_args"),
        KeywordArg("postop_algorithm"),
    )
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# Decomposed per-tensor dequant of the extra (accumulated) input of a binary
# fusion: (accum.to(accum_dq_dtype) - accum_zp) * accum_scale
dequantize_accum_pattern = CallFunction(
    aten.mul.Tensor,
    CallFunction(
        aten.sub.Tensor,
        CallFunction(
            prims.convert_element_type.default,
            KeywordArg("accum"),
            KeywordArg("accum_dq_dtype"),
        ),
        KeywordArg("accum_zp"),
    ),
    KeywordArg("accum_scale"),
)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def generate_pattern_with_binary(
    binary_post_op,
    computation_call,
    extra_input_pattern,
    int8_mixed_bf16_with_inplace_add=False,
):
    """Compose `binary_post_op(computation_call, extra_input_pattern)`.

    For int8-mixed-bf16 inplace add the graph carries a trailing
    `convert_element_type` after the add, so the binary pattern is wrapped in
    an optional dtype-convert node in that case.
    """
    return _may_generate_pattern_with_dtype_convert(
        CallFunction(binary_post_op, computation_call, extra_input_pattern),
        KeywordArg("convert_dtype_after_inplace_add"),
        int8_mixed_bf16_with_inplace_add,
    )
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def generate_pattern_with_unary(computation_call, unary_post_op):
    """Append a unary post-op to the computation pattern.

    hardtanh and hardswish are matched in their decomposed forms
    (clamp_min/clamp_max and add/clamp/mul/div respectively); any other
    non-None op is matched as a direct call on the computation output.
    A None post-op returns the pattern unchanged.
    """
    if unary_post_op is None:
        return computation_call
    if unary_post_op == aten.hardtanh.default:
        # hardtanh decomposes to clamp_min followed by clamp_max
        return CallFunction(
            aten.clamp_max,
            CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")),
            KeywordArg("max_value"),
        )
    if unary_post_op == aten.hardswish.default:
        # hardswish decomposes to x * clamp(x + 3, 0, 6) / 6; note the
        # computation output appears twice, hence _users=2 on its pattern.
        return CallFunction(
            aten.div,
            CallFunction(
                aten.mul,
                computation_call,
                CallFunction(
                    aten.clamp_max,
                    CallFunction(
                        aten.clamp_min,
                        CallFunction(aten.add, computation_call, 3),
                        0,
                    ),
                    6,
                ),
            ),
            6,
        )
    return CallFunction(
        unary_post_op,
        computation_call,
    )
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def generate_pattern_with_output_quant(computation_call, dtype=torch.float32):
    """
    quantize output:
        output = round(output * o_inv_scale)
        output = output + zero_point
        output = clamp_min(output, 0)
        output = clamp_max(output, 127)
        output = output.to(uint8)
    """
    # The pattern below is the inside-out nesting of the quantize steps above;
    # for int8-mixed-bf16 the computation output is first cast back
    # (convert_element_type) before being multiplied by o_inv_scale.
    assert dtype in [torch.float32, torch.bfloat16]
    quantized_op_output_pattern_pt2e = CallFunction(
        prims.convert_element_type.default,
        CallFunction(
            aten.clamp_max.default,
            CallFunction(
                aten.clamp_min.default,
                CallFunction(
                    aten.add.Tensor,
                    CallFunction(
                        aten.round.default,
                        CallFunction(
                            aten.mul.Tensor,
                            _may_generate_pattern_with_dtype_convert(
                                computation_call,
                                KeywordArg("autocast_output_quant_dtype"),
                                dtype == torch.bfloat16,
                            ),
                            KeywordArg("o_inv_scale"),
                        ),
                    ),
                    KeywordArg("o_zp"),
                ),
                KeywordArg("o_qmin"),
            ),
            KeywordArg("o_qmax"),
        ),
        KeywordArg("o_dtype"),
    )
    return quantized_op_output_pattern_pt2e
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value):
|
| 286 |
+
if kwarg_name in check_node.kwargs:
|
| 287 |
+
actual_value = check_node.kwargs[kwarg_name]
|
| 288 |
+
return actual_value == expected_value
|
| 289 |
+
else:
|
| 290 |
+
assert len(check_node.args) >= (args_index + 1)
|
| 291 |
+
actual_value = check_node.args[args_index]
|
| 292 |
+
return actual_value == expected_value
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _is_valid_quantized_conv2d_optimization_pattern(output_dtype):
    """Extra-check factory: accept a match only when the already-lowered
    qconv2d_pointwise node carries the expected `output_dtype` (arg index 13).
    A None `output_dtype` accepts every match."""

    def fn(match):
        if output_dtype is None:
            return True
        # Only keep matched pattern with same output_dtype
        qconv_node_after_weight_prepack = filter_nodes(
            match.nodes, torch.ops.onednn.qconv2d_pointwise
        )[0]
        return _check_node_kwarg_arg_value(
            qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype
        )

    return fn
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _register_quantized_conv_lowering(
    pattern,
    pass_number,
    computation_op,
    output_dtype,
    unary_attr,
    original_pattern_output_dtype=torch.float32,
):
    """Register a lowering for a matched qconv2d(+unary)(+quant) pattern.

    `output_dtype` is None for patterns that end in a re-quantize (int8
    output); otherwise the fused op emits fp32/bf16 and the output qparams
    are fixed to scale=1.0, zp=0 below.
    """
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_conv2d_optimization_pattern(output_dtype),
        pass_number=pass_number,
    )
    def qconv(match: Match, *args, **kwargs):
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        # Conv Params
        b, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )
        assert output_dtype in [None, torch.float32, torch.bfloat16]
        # Output QParams
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype is None else 0
        assert (
            kwargs["output_dtype"] is original_pattern_output_dtype
        )  # Expected int8-in fp32-out qconv in weight prepack phase
        assert (
            kwargs["attr"] == "none"
        )  # Expected no post op fused in weight prepack phase
        if unary_attr.op_name == "hardtanh":
            # hardtanh's clamp bounds are captured by the pattern and become
            # the post-op scalars of the fused kernel.
            min_value = kwargs.get("min_value")
            max_value = kwargs.get("max_value")
            unary_attr.scalars_attr = [min_value, max_value]

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        )
        counters["inductor"]["qconv2d_unary_matcher_count"] += 1
        counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qconv
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def _is_valid_quantized_linear_optimization_pattern(output_dtype):
    """Extra-check factory: accept a match only when the already-lowered
    qlinear_pointwise node carries the expected `output_dtype` (arg index 9).
    A None `output_dtype` accepts every match."""

    def fn(match):
        if output_dtype is None:
            return True
        # Only keep matched pattern with same output_dtype
        qlinear_node_after_weight_prepack = filter_nodes(
            match.nodes, torch.ops.onednn.qlinear_pointwise
        )[0]
        return _check_node_kwarg_arg_value(
            qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype
        )

    return fn
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def _register_quantized_linear_lowering(
    pattern,
    pass_number,
    computation_op,
    output_dtype,
    unary_attr,
    original_pattern_output_dtype=torch.float32,
):
    """Register a lowering for a matched qlinear(+unary)(+quant) pattern.

    Mirrors _register_quantized_conv_lowering: None `output_dtype` means the
    pattern ends with a re-quantize (int8 output); otherwise the fused op
    emits fp32/bf16 and the output qparams below are fixed to 1.0 / 0.
    """
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_linear_optimization_pattern(output_dtype),
        pass_number=pass_number,
    )
    def qlinear(match: Match, *args, **kwargs):
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )

        # bias
        b = kwargs["b"] if "b" in kwargs else None

        # Output QParams
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype is None else 0
        assert (
            kwargs["output_dtype"] is original_pattern_output_dtype
        )  # Expected int8-in fp32/bf16-out qlinear in weight prepack phase
        assert (
            kwargs["postop_name"] == "none"
        )  # Expected no post op fused in weight prepack phase

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        )
        counters["inductor"]["qlinear_unary_matcher_count"] += 1
        counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qlinear
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
    """Extra-check factory validating a matched qconv + binary(add) pattern
    before fusion; see the itemized conditions below."""
    # Check if it's a valid Conv Binary Pattern:
    # * qconv2d_pointwise should only has one users
    # * Extra input of binary node comes from dequant pattern
    # * the two inputs of binary node should have attribute "meta" and should be tensors
    # * the two inputs of binary node should have the same shape
    # * All users of the extra input in this pattern should be
    #   ancestor nodes of the compute node, except for the binary node
    #   connected to the compute node.
    def fn(match):
        compute_node = filter_nodes(match.nodes, torch.ops.onednn.qconv2d_pointwise)[0]
        # qconv2d_pointwise should only have one user
        if len(compute_node.users) != 1:
            return False
        binary_node_inputs = next(iter(compute_node.users)).args
        assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
        if output_dtype is not None:
            extra_input_of_binary_node = None
            for arg in binary_node_inputs:
                if arg != compute_node:
                    extra_input_of_binary_node = arg
                    break
            assert extra_input_of_binary_node is not None
            # Extra input of binary node comes from dequant pattern
            # (aten.mul.Tensor is the final node of the decomposed dequant).
            if (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or (
                extra_input_of_binary_node.target != aten.mul.Tensor
            ):
                return False

        # the two inputs of binary node should have attribute "meta" and should be tensors
        if not (
            hasattr(binary_node_inputs[0], "meta")
            and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
        ) or not (
            hasattr(binary_node_inputs[1], "meta")
            and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
        ):
            return False
        # the two inputs of binary node should have the same shape
        if (
            binary_node_inputs[0].meta["val"].size()  # type: ignore[union-attr]
            != binary_node_inputs[1].meta["val"].size()  # type: ignore[union-attr]
        ):
            return False

        # All users of the extra input in this pattern should be
        # ancestor nodes of the compute node, except for the binary node
        # connected to the compute node.

        from .mkldnn_fusion import _get_remaining_users

        extra_input_of_pattern = (
            match.kwargs["accum"]
            if output_dtype is None
            else match.kwargs["accum_after_dequant"]
        )
        if (
            len(
                _get_remaining_users(
                    extra_input_of_pattern,
                    compute_node,
                )
            )
            > 1
            or extra_input_of_pattern == compute_node.args[0]
        ):
            return False
        return True

    return fn
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def _register_quantized_conv_binary_lowering(
    pattern,
    pass_number,
    computation_op,
    output_dtype,
    binary_unary_attr,
):
    """Register a lowering for a matched qconv + binary(add)(+unary) pattern
    onto the fused onednn qconv2d_pointwise.binary op.

    The binary extra input (`accum`) is consumed in place, so it must be
    realized and proven inplace-safe before fusion.
    """
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_conv_binary_optimization_pattern(output_dtype),
        pass_number=pass_number,
    )
    def qconv_binary(match: Match, *args, **kwargs):
        x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
        # For int8 output (output_dtype is None) the extra input is still
        # quantized and carries its own qparams; otherwise it is already
        # dequantized, so scale/zp are identity values.
        accum = (
            kwargs["accum"] if output_dtype is None else kwargs["accum_after_dequant"]
        )
        accum_scale = kwargs["accum_scale"] if output_dtype is None else 1.0
        accum_zp = kwargs["accum_zp"] if output_dtype is None else 0
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        b, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )
        # Output QParams
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype is None else 0

        accum.realize()
        from .mkldnn_fusion import _can_be_inplace

        assert _can_be_inplace(
            accum
        ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation."

        computation_args = (
            x,
            x_scale,
            x_zp,
            accum,
            accum_scale,
            accum_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            binary_unary_attr.binary_op_name,
            binary_unary_attr.alpha,
            binary_unary_attr.unary_op_name,
            binary_unary_attr.scalars_attr,
            binary_unary_attr.algorithm_attr,
        )
        counters["inductor"]["qconv2d_binary_matcher_count"] += 1
        counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qconv_binary
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
def _register_quantization_unary_fusion():
    """Register all qconv2d/qlinear + unary post-op fusion patterns.

    Patterns ending in a re-quantize (int8 output) are registered at a lower
    pass_number than the float-output variants so that the larger pattern is
    tried first (pattern2 superset rule below).
    """
    class UnaryAttr:
        # Plain holder for the fused post-op name, its scalar args, and algorithm.
        def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
            self.op_name = op_name
            self.scalars_attr = scalars_attr if scalars_attr else []
            self.algorithm_attr = algorithm_attr if algorithm_attr else ""

    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
        # QConv2d
        # Priority 1 to match: QConv2d Unary pattern with int8 output
        # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
        # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
        conv_unary_replace_patterns = {
            UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
                get_dequantize_qconv_pt2e_pattern(1),
                dtype=original_pattern_output_dtype,
            ),
            UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
                generate_pattern_with_unary(
                    get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
                ),
                dtype=original_pattern_output_dtype,
            ),
            UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
                generate_pattern_with_unary(
                    get_dequantize_qconv_pt2e_pattern(1), aten.hardtanh.default
                ),
                dtype=original_pattern_output_dtype,
            ),
            UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
                generate_pattern_with_unary(
                    get_dequantize_qconv_pt2e_pattern(2), aten.hardswish.default
                ),
                dtype=original_pattern_output_dtype,
            ),
        }

        for unary_attr, patterns in conv_unary_replace_patterns.items():
            # Register qconv2d pattern for ExternKernel Lowering
            _register_quantized_conv_lowering(
                patterns,
                1,  # pass_number
                torch.ops.onednn.qconv2d_pointwise,  # computation_op
                None,  # output_dtype, None is the default value for int8 output
                unary_attr,  # unary_attr
                original_pattern_output_dtype=original_pattern_output_dtype,
            )

        # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
        conv_unary_replace_float_out_patterns = {
            UnaryAttr("relu", [], ""): generate_pattern_with_unary(
                get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
            ),
            UnaryAttr("hardtanh", [], ""): generate_pattern_with_unary(
                get_dequantize_qconv_pt2e_pattern(1), aten.hardtanh.default
            ),
            UnaryAttr("hardswish", [], ""): generate_pattern_with_unary(
                get_dequantize_qconv_pt2e_pattern(2), aten.hardswish.default
            ),
        }

        for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
            # Register qconv2d pattern for ExternKernel Lowering
            _register_quantized_conv_lowering(
                patterns,
                2,  # pass_number
                torch.ops.onednn.qconv2d_pointwise,  # computation_op
                original_pattern_output_dtype,  # output_dtype
                unary_attr,  # unary_attr
                original_pattern_output_dtype=original_pattern_output_dtype,
            )

        # QLinear
        for x_scale_zp_are_tensors in (False, True):
            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
            # Priority 1 to match: QLinear Unary pattern with int8 output
            linear_unary_replace_patterns = {
                UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
                    qlinear_pattern,
                    dtype=original_pattern_output_dtype,
                ),
                UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
                    dtype=original_pattern_output_dtype,
                ),
            }

            for unary_attr, patterns in linear_unary_replace_patterns.items():
                _register_quantized_linear_lowering(
                    patterns,
                    1,  # pass_number
                    torch.ops.onednn.qlinear_pointwise,  # computation_op
                    None,  # output_dtype
                    unary_attr,  # unary_attr
                    original_pattern_output_dtype=original_pattern_output_dtype,
                )

            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
            linear_unary_replace_float_out_patterns = {
                UnaryAttr("relu", [], ""): generate_pattern_with_unary(
                    qlinear_pattern, aten.relu.default
                ),
            }

            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
                _register_quantized_linear_lowering(
                    patterns,
                    2,  # pass_number
                    torch.ops.onednn.qlinear_pointwise,  # computation_op
                    original_pattern_output_dtype,  # output_dtype
                    unary_attr,  # unary_attr
                    original_pattern_output_dtype=original_pattern_output_dtype,
                )
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def _register_quantization_binary_fusion():
    """Register all qconv2d + binary(add)(+unary) fusion patterns, for both
    plain int8 and int8-mixed-bf16 (with inplace add) graph shapes.

    Pattern priority (pass_number) is ordered so that larger patterns
    (with re-quantize / with relu) are matched before their sub-patterns.
    """
    class BinaryUnaryAttr:
        # Plain holder for the fused binary op, its alpha, and an optional
        # trailing unary op with its scalars/algorithm.
        def __init__(
            self,
            binary_op_name: str,
            alpha=None,
            unary_op_name: str = "none",
            scalars_attr=None,
            algorithm_attr=None,
        ):
            self.binary_op_name = binary_op_name
            self.alpha = alpha if alpha else 1.0
            self.unary_op_name = unary_op_name
            self.scalars_attr = scalars_attr if scalars_attr else []
            self.algorithm_attr = algorithm_attr if algorithm_attr else ""

    for int8_mixed_bf16_with_inplace_add in [False, True]:
        # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
        binary_replace_patterns = {
            BinaryUnaryAttr(
                "sum", 1.0, "none", [], ""
            ): generate_pattern_with_output_quant(
                generate_pattern_with_binary(
                    aten.add.Tensor,
                    get_dequantize_qconv_pt2e_pattern(1),
                    dequantize_accum_pattern,
                    int8_mixed_bf16_with_inplace_add,
                ),
                dtype=torch.bfloat16
                if int8_mixed_bf16_with_inplace_add
                else torch.float32,
            ),
            BinaryUnaryAttr(
                "sum", 1.0, "relu", [], ""
            ): generate_pattern_with_output_quant(
                generate_pattern_with_unary(
                    generate_pattern_with_binary(
                        aten.add.Tensor,
                        get_dequantize_qconv_pt2e_pattern(1),
                        dequantize_accum_pattern,
                        int8_mixed_bf16_with_inplace_add,
                    ),
                    aten.relu.default,
                ),
                dtype=torch.bfloat16
                if int8_mixed_bf16_with_inplace_add
                else torch.float32,
            ),
        }

        for binary_unary_attr, patterns in binary_replace_patterns.items():
            _register_quantized_conv_binary_lowering(
                patterns,
                0,  # pass_number
                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                None,  # output_dtype
                binary_unary_attr,  # binary_unary_attr
            )

        # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
        binary_replace_float_out_patterns = {
            BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
                generate_pattern_with_binary(
                    aten.add.Tensor,
                    get_dequantize_qconv_pt2e_pattern(1),
                    KeywordArg("accum_after_dequant"),
                    int8_mixed_bf16_with_inplace_add,
                ),
                aten.relu.default,
            ),
        }

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            if int8_mixed_bf16_with_inplace_add:
                _register_quantized_conv_binary_lowering(
                    patterns,
                    0,  # pass_number
                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                    # Note that for int8-mixed-bf16 and non-inplace add, because we have
                    # q-dq inserted at extra input of add, so the non-inplace add has bf16 and fp32 inputs,
                    # the output dtype will be float32.
                    # For inplace add, there is a extra to_bf16 node at add output, so the fusion pattern has bfloat16 output.
                    torch.bfloat16,
                    binary_unary_attr,  # binary_unary_attr
                )
            else:
                _register_quantized_conv_binary_lowering(
                    patterns,
                    1,  # pass_number
                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                    torch.float32,
                    binary_unary_attr,  # binary_unary_attr
                )

        # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
        binary_replace_float_out_patterns = {
            BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary(
                aten.add.Tensor,
                get_dequantize_qconv_pt2e_pattern(1),
                KeywordArg("accum_after_dequant"),
                int8_mixed_bf16_with_inplace_add,
            ),
        }

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_conv_binary_lowering(
                patterns,
                1 if int8_mixed_bf16_with_inplace_add else 2,  # pass_number
                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                # Same output dtype setting as conv-add-relu pattern
                torch.bfloat16 if int8_mixed_bf16_with_inplace_add else torch.float32,
                binary_unary_attr,  # binary_unary_attr
            )
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
def _is_valid_quantized_maxpool2d_optimization_pattern():
    """Extra check: only fuse when the getitem selects the values output
    (index 0) of max_pool2d_with_indices, not the indices output."""

    def check(match):
        getitem_node = filter_nodes(match.nodes, operator.getitem)[0]
        return getitem_node.args[1] == 0

    return check
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
def _register_quantized_maxpool2d_lowering(
    pattern,
    computation_op,
):
    """Register a lowering from the dequant -> max_pool2d -> quant pattern
    onto the fused quantized max_pool2d op.

    Defaulted nn.MaxPool2d parameters may be absent from the matched kwargs
    (see _register_quantization_maxpool2d), so each is defaulted and then
    normalized to a length-2 list before calling the fused op.
    """
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(),
    )
    def qmaxpool2d(match: Match, *args, **kwargs):
        x = kwargs["x"]
        kernel_size = kwargs["kernel_size"]
        stride = kwargs["stride"] if ("stride" in kwargs) else None
        padding = kwargs["padding"] if ("padding" in kwargs) else 0
        dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
        ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False

        if padding == 0:
            padding = [0, 0]
        if dilation == 1:
            dilation = [1, 1]
        if not stride:
            # nn.MaxPool2d semantics: stride defaults to kernel_size
            stride = kernel_size
        kernel_size = pad_listlike(kernel_size, 2)
        stride = pad_listlike(stride, 2)
        padding = pad_listlike(padding, 2)
        dilation = pad_listlike(dilation, 2)

        assert len(kernel_size) == 2
        assert len(stride) == 2
        assert len(padding) == 2
        assert len(dilation) == 2

        computation_args = (
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
        computation_args, _ = require_channels_last(computation_op, *computation_args)
        return L[computation_op](*computation_args)

    return qmaxpool2d
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
def _register_quantization_maxpool2d():
    """Register quantized maxpool2d lowerings for every arity of optional
    MaxPool2d arguments that can appear in the exported FX graph.
    """
    # Currently, the default parameters are not in FX Graph generated by Dynamo export.
    # So, if user defines nn.MaxPool2d with different assignment of default parameter,
    # it will generate graph with different number of input nodes and hence
    # different pattern to be matched.
    # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901
    max_pool2d_args_list = [
        [
            KeywordArg("stride"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
            KeywordArg("dilation"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
            KeywordArg("dilation"),
            KeywordArg("ceil_mode"),
        ],
    ]

    for max_pool2d_args in max_pool2d_args_list:
        # dq -> max_pool2d_with_indices with this arity of trailing args.
        dequantize_maxpool2d_pattern = CallFunction(
            aten.max_pool2d_with_indices.default,
            dequantize_per_tensor_activation_pattern,
            KeywordArg("kernel_size"),
            *max_pool2d_args,
        )
        # The getitem that selects values (vs indices) from the op's tuple output.
        dequantize_maxpool2d_get_item_pattern = CallFunction(
            operator.getitem,
            dequantize_maxpool2d_pattern,
            Arg(),
        )
        _register_quantized_maxpool2d_lowering(
            generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
            quantized.max_pool2d.default,
        )
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def _is_input_output_same_scale_zp(check_node):
    """Build an extra-check predicate verifying that all dequant inputs and the
    requant output of the matched pattern share the same scale and zero point,
    so the op can run directly on the quantized tensors.

    ``check_node`` is the target of the computation node (e.g. aten.cat.default)
    used to identify the output-side mul node.
    """

    def fn(match):
        # Ensure all the inputs and output has same scale and zero point
        # Step 1: Check inputs/output zero point
        sub_nodes = filter_nodes(match.nodes, aten.sub.Tensor)
        zero_points = [node.args[1] for node in sub_nodes]
        add_nodes = filter_nodes(match.nodes, aten.add.Tensor)
        assert len(add_nodes) == 1, "expect only 1 add node at output quant pattern"
        zero_points.append(add_nodes[0].args[1])
        if not all(zero_point == zero_points[0] for zero_point in zero_points):
            return False

        # Step 2: Check inputs/output scale
        mul_nodes = filter_nodes(match.nodes, aten.mul.Tensor)
        # We need to find mul node at output since the scale value is reciprocal to input scale.
        # Mul node at output should connect to cat node directly.
        scales = [
            (
                mul_node.args[1]
                if mul_node.args[0].target is check_node  # type: ignore[union-attr]
                else 1.0 / mul_node.args[1]  # type: ignore[operator]
            )
            for mul_node in mul_nodes
        ]
        # Use a relative tolerance: scales come from float arithmetic.
        if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales):  # type: ignore[arg-type]
            return False

        return True

    return fn
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
def _register_quantized_cat_lowering(
    pattern,
    computation_op,
):
    """Register a lowering that concatenates quantized (uint8) tensors directly.

    Only fires when every input and the output share scale/zero-point, as
    enforced by the _is_input_output_same_scale_zp extra check.
    """

    @register_lowering_pattern(
        pattern,
        extra_check=_is_input_output_same_scale_zp(aten.cat.default),
    )
    def qcat(match: Match, inputs, dim, **kwargs):
        # Each entry of `inputs` is [x_i, x_i_dq_dtype, x_i_zp, x_i_scale];
        # slot 0 is the quantized tensor we actually concatenate.
        uint8_inputs = [quantized_entry[0] for quantized_entry in inputs]
        return L[computation_op](uint8_inputs, dim)

    return qcat
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
# Per-tensor dequantize expressed with positional Arg() slots (instead of
# KeywordArg) so it can be nested inside ListOf() for variadic inputs such as
# aten.cat: uint8 -> convert_element_type -> sub(zero_point) -> mul(scale).
_raw_dequantize_per_tensor_activation_pattern = CallFunction(
    aten.mul.Tensor,
    CallFunction(
        aten.sub.Tensor,
        CallFunction(
            prims.convert_element_type.default,
            Arg(),
            Arg(),
        ),
        Arg(),
    ),
    Arg(),
)
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
def _register_quantization_cat():
    """Register the dq -> aten.cat -> q lowering pattern."""
    variadic_dequant_inputs = ListOf(_raw_dequantize_per_tensor_activation_pattern)
    dequantize_cat_pattern = CallFunction(
        aten.cat.default,
        variadic_dequant_inputs,
        KeywordArg("dim"),
    )
    quantized_cat_pattern = generate_pattern_with_output_quant(dequantize_cat_pattern)
    _register_quantized_cat_lowering(quantized_cat_pattern, aten.cat)
|
| 1017 |
+
|
| 1018 |
+
|
| 1019 |
+
def _register_quantized_reshape_lowering(
    pattern,
    computation_op,
):
    """Register a lowering that reshapes the quantized tensor directly when the
    input and output share the same scale/zero-point."""

    @register_lowering_pattern(
        pattern,
        extra_check=_is_input_output_same_scale_zp(aten.reshape.default),
    )
    def qreshape(match: Match, *args, **kwargs):
        quantized_x = kwargs["x"]
        target_shape = kwargs["shape"]
        # Bookkeeping counters used by tests/debugging of matched patterns.
        counters["inductor"]["qreshape_matcher_count"] += 1
        counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes)
        return L[computation_op](quantized_x, target_shape)

    return qreshape
|
| 1035 |
+
|
| 1036 |
+
|
| 1037 |
+
def _register_quantization_reshape():
    """Register the dq -> aten.reshape -> q lowering pattern."""
    dequantize_reshape_pattern = CallFunction(
        torch.ops.aten.reshape.default,
        dequantize_per_tensor_activation_pattern,
        KeywordArg("shape"),
    )
    pattern_with_output_quant = generate_pattern_with_output_quant(
        dequantize_reshape_pattern
    )
    _register_quantized_reshape_lowering(pattern_with_output_quant, aten.reshape)
|
| 1047 |
+
|
| 1048 |
+
|
| 1049 |
+
def _register_quantization_lowerings():
    """Register every quantization lowering pattern in this module."""
    registrars = (
        _register_quantization_unary_fusion,
        _register_quantization_binary_fusion,
        _register_quantization_maxpool2d,
        _register_quantization_cat,
        _register_quantization_reshape,
    )
    for registrar in registrars:
        registrar()
|
| 1055 |
+
|
| 1056 |
+
|
| 1057 |
+
def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
    """Build an extra-check for the dequant-promotion pass: the matched pattern
    must end in a valid dequant chain (mul <- sub <- to_fp32, optionally wrapped
    in to_bf16 and/or reshape) whose end node has more than one user.
    """

    def _inner(match):
        assert dtype in [torch.float32, torch.bfloat16]
        dequant_pattern_end_node = match.output_node()
        if dequant_pattern_end_node.target not in [
            aten.mul.Tensor,
            prims.convert_element_type.default,
            aten.reshape.default,
        ]:
            return False

        # Walk back to the mul node of the dequant chain; the route depends on
        # whether a reshape and/or a to_bf16 convert wraps the chain.
        if dequant_pattern_end_node.target is aten.reshape.default:
            mul_node = (
                dequant_pattern_end_node.args[0]  # pattern: linear <- reshape <- mul
                if dtype == torch.float32
                else dequant_pattern_end_node.args[0].args[
                    0
                ]  # pattern: linear <- reshape <- to_bf16 <- mul
            )
        else:
            mul_node = (
                dequant_pattern_end_node  # pattern: linear <- mul
                if dtype == torch.float32
                else dequant_pattern_end_node.args[
                    0
                ]  # pattern: linear <- to_bf16 <- mul
            )

        sub_node = mul_node.args[0]
        to_fp32_node = sub_node.args[0]
        if (
            mul_node.target is aten.mul.Tensor
            and sub_node.target is aten.sub.Tensor
            and to_fp32_node.target is prims.convert_element_type.default
            and len(list(dequant_pattern_end_node.users)) > 1
        ):
            # If dequant pattern has more than 1 users, then do dequant promoted
            return True
        return False

    return _inner
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
    """Register a freezing-graph pass that duplicates a shared dequant chain so
    each of its users gets a private copy, enabling per-user int8 fusion.
    """

    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_dequant_promotion_pattern(dtype),
        pass_number=pass_number,
    )
    def dequant_promotion(match: Match, *args, **kwargs):
        # Dequant_promotion will transform
        # graph 1:
        #            quant
        #      + - - - | - - - +
        #      |    dequant    |
        #      |    /     \    |
        #      |  node1  node2 |
        #      + - | - - - | - +
        #        quant   quant
        # into:
        # graph 2:
        #            quant
        #      + - - / - \ - - +
        #      |dequant dequant|
        #      |    |      |   |
        #      | node1  node2  |
        #      + - | - - - | - +
        #        quant   quant
        # In graph 1, the dequant node is shared by node1 and node2,
        # as a result, neither node1 nor node2 could form an int8
        # fusion pattern.
        # After this transformation, the graph 2 could hit the int8
        # fusion pattern: dequant-node-quant, respectively for
        # node1 and node2.
        assert dtype in [torch.float32, torch.bfloat16]

        def clone_to_new_node(graph, source_node, user_node):
            # Clone the source_node to a new node
            # Replace user_node's input from source_node to new_node
            assert (
                source_node.op == "call_function"
            ), "clone_to_new_node only support node.op call_function"
            with graph.inserting_before(user_node):
                new_node = graph.call_function(
                    source_node.target,
                    args=source_node.args,
                    kwargs=source_node.kwargs,
                )
                new_node.meta = copy.copy(source_node.meta)
                user_node.replace_input_with(source_node, new_node)
            return new_node

        # Find the start node and end node of a dequant pattern
        # * End node should be the match.output_node()
        # * Start node should be the node of dtype convert to float32
        dequant_pattern_end_node = match.output_node()
        assert dequant_pattern_end_node.target in [
            aten.mul.Tensor,
            prims.convert_element_type.default,
            aten.reshape.default,
        ]

        # For a dequant pattern, we should expect see the node list as:
        # * OPT(aten.reshape.default)
        # * OPT(prims.convert_element_type.default) (to_bf16)
        # * aten.mul
        # * aten.sub
        # * prims.convert_element_type.default (to_fp32)
        def _find_first_node_in_dequant_pattern(_node):
            if (
                _node.target is prims.convert_element_type.default
                and _node.args[1] == torch.float32
            ):
                # For a dequant pattern, we expect the start node is a to_fp32 node
                return _node
            else:
                assert (
                    len(_node.args) >= 1
                ), "In in dequant pattern, each node should have more than 1 arg."
                # Recurse along the first argument toward the chain's start.
                return _find_first_node_in_dequant_pattern(_node.args[0])

        dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
            dequant_pattern_end_node
        )

        # Clone the dequant pattern for each user node
        graph = match.graph
        user_node_list = list(dequant_pattern_end_node.users)
        # The first user keeps the original chain; every other user gets a
        # fresh clone of each node from end node back to (but excluding) the
        # input of the chain's start node.
        for user_node in user_node_list[1:]:
            _source_node = dequant_pattern_end_node
            _user_node = user_node
            while _source_node != dequant_pattern_start_node.args[0]:
                _user_node = clone_to_new_node(graph, _source_node, _user_node)
                _source_node = _source_node.args[0]  # type: ignore[assignment]

        counters["inductor"]["dequant_promotion_matcher_count"] += 1
        counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
|
| 1194 |
+
|
| 1195 |
+
|
| 1196 |
+
def _is_valid_dequant_conv2d_pattern(dtype):
    """Build an extra-check for the qconv weight-prepack pass: only 4D CPU
    conv2d with a single-user activation dequant chain qualifies.
    """

    def _inner(match):
        # Here we do some further check to ensure:
        # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now.
        # 2. The dequant pattern has only 1 user of conv2d node.
        # If these conditions don't meet, we will not
        # insert weight prepack node into the matched pattern.
        conv_node = match.output_node()
        assert conv_node.target is aten.convolution.default
        input_meta_value = conv_node.args[0].meta.get("val")
        weight_meta_value = conv_node.args[1].meta.get("val")
        for meta_value in [input_meta_value, weight_meta_value]:
            if (
                meta_value is None
                or meta_value.device.type != "cpu"
                or meta_value.dim() != 4
            ):
                # Only support conv2d now
                return False

        assert dtype in [torch.float32, torch.bfloat16]
        # For bf16 the activation dequant is wrapped in an extra to_bf16 convert.
        if dtype == torch.float32:
            mul_node = conv_node.args[0]
        else:
            convert_to_bf16 = conv_node.args[0]
            mul_node = convert_to_bf16.args[0]
        sub_node = mul_node.args[0]
        to_fp32_node = sub_node.args[0]

        assert to_fp32_node.target is prims.convert_element_type.default
        assert sub_node.target is aten.sub.Tensor
        assert mul_node.target is aten.mul.Tensor
        if (
            len(list(to_fp32_node.users)) != 1
            or len(list(sub_node.users)) != 1
            or len(list(mul_node.users)) != 1
        ):
            # Ensure the dequant pattern only has 1 user
            # since we will delete the dequant pattern here
            return False
        return True

    return _inner
|
| 1239 |
+
|
| 1240 |
+
|
| 1241 |
+
def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
    """Register a freezing-graph pass that rewrites a dequant->convolution
    pattern into onednn.qconv_prepack + onednn.qconv2d_pointwise and erases the
    now-dead dequant/convolution nodes."""

    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_dequant_conv2d_pattern(dtype),
        pass_number=pass_number,
    )
    def qconv_weight_prepack(match: Match, *args, **kwargs):
        """
        Match the pattern:
        int8 activation
          |
        dequant_per_tensor
          |
        Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight

        Insert weight prepack node and change the pattern to:
        int8 activation
          |
        onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight
        """
        assert dtype in [torch.float32, torch.bfloat16]
        conv_node = match.output_node()
        assert conv_node.target is aten.convolution.default
        # Walk back through the activation dequant chain (optionally via a
        # to_bf16 convert when dtype is bfloat16).
        if dtype == torch.float32:
            mul_node = conv_node.args[0]
        else:
            convert_to_bf16 = conv_node.args[0]
            mul_node = convert_to_bf16.args[0]  # type: ignore[union-attr]
        sub_node = mul_node.args[0]  # type: ignore[union-attr]
        to_fp32_node = sub_node.args[0]  # type: ignore[union-attr]
        # The freezing pass may have inserted a clone (to channels-last)
        # between the weight dequant and the convolution.
        has_clone_to_channel_last_node_in_pattern = (
            conv_node.args[1].target is aten.clone.default  # type: ignore[union-attr]
        )
        clone_node = (
            conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
        )

        if dtype == torch.float32:
            dequant_per_channel = (
                clone_node.args[0]  # type: ignore[union-attr]
                if has_clone_to_channel_last_node_in_pattern
                else conv_node.args[1]
            )
        else:
            weight_to_bf16_node = (
                clone_node.args[0]  # type: ignore[union-attr]
                if has_clone_to_channel_last_node_in_pattern
                else conv_node.args[1]
            )
            dequant_per_channel = weight_to_bf16_node.args[0]  # type: ignore[union-attr]

        assert (
            dequant_per_channel.target  # type: ignore[union-attr]
            is quantized_decomposed.dequantize_per_channel.default
        )

        # Activation QParams
        qx, x_zp, x_scale = (
            kwargs["x"],
            kwargs["x_zp"],
            kwargs["x_scale"],
        )

        # Weight QParams
        qw, w_scale, w_zp = (
            kwargs["q_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )

        # Conv Params
        bias, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )

        x_shape = qx.meta.get("tensor_meta").shape
        if has_free_symbols(x_shape):
            # For dynamic shape case, we can't get activation shape ahead of runtime.
            x_shape = None
        graph = match.graph
        with graph.inserting_before(conv_node):
            # Insert weight prepack node and the QConv node
            packed_weight_inputs = (
                qw,
                w_scale,
                x_scale,
                x_zp,
                stride,
                padding,
                dilation,
                groups,
                x_shape,
            )
            packed_weight_op = torch.ops.onednn.qconv_prepack
            prepack_weight_node = graph.call_function(
                packed_weight_op, args=packed_weight_inputs
            )

            new_args: Tuple[Any, ...] = (
                qx,
                x_scale,
                x_zp,
                prepack_weight_node,
                w_scale,
                w_zp,
                bias,
                stride,
                padding,
                dilation,
                groups,
                1.0,  # inv_output_scale
                0,  # output_zero_point
                dtype,  # output_dtype
                "none",  # attr
                [],  # scalars
                "",  # algorithm
            )
            new_conv_node = graph.call_function(
                torch.ops.onednn.qconv2d_pointwise.default, args=new_args
            )
            conv_node.replace_all_uses_with(new_conv_node)
            new_conv_node.meta.update(conv_node.meta)

            # Erase the original conv node
            graph.erase_node(conv_node)
            # Erase the dequant pattern
            if dtype == torch.bfloat16:
                graph.erase_node(convert_to_bf16)  # type: ignore[possibly-undefined]
            # Erase the dequant pattern
            graph.erase_node(mul_node)
            graph.erase_node(sub_node)
            graph.erase_node(to_fp32_node)
            # Erase the dequant per channel pattern
            if clone_node is not None:
                graph.erase_node(clone_node)
            if dtype == torch.bfloat16:
                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
            graph.erase_node(dequant_per_channel)
            counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
            counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
                match.nodes
            )
|
| 1387 |
+
|
| 1388 |
+
|
| 1389 |
+
def _generate_dequant_convolution_node_pattern(
    _dequant_per_channel_pattern, dtype=torch.float32
):
    """Build the aten.convolution pattern fed by a dequantized activation and
    the supplied per-channel-dequantized weight pattern.

    For bfloat16, the activation dequant is additionally wrapped in a dtype
    conversion captured as "autocast_act_dtype".
    """
    assert dtype in [torch.float32, torch.bfloat16]
    activation_pattern = _may_generate_pattern_with_dtype_convert(
        dequantize_per_tensor_activation_pattern,
        KeywordArg("autocast_act_dtype"),
        dtype == torch.bfloat16,
    )
    return CallFunction(
        aten.convolution.default,
        activation_pattern,
        _dequant_per_channel_pattern,
        KeywordArg("b"),
        KeywordArg("stride"),
        KeywordArg("padding"),
        KeywordArg("dilation"),
        KeywordArg("is_transposed"),
        KeywordArg("out_padding"),
        KeywordArg("groups"),
    )
|
| 1410 |
+
|
| 1411 |
+
|
| 1412 |
+
def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
    """Return the two dequant-convolution patterns used for qconv weight prepack.

    The second variant accounts for the convert_conv_weights_to_channels_last
    freezing pass
    (https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362),
    which, depending on heuristics, may or may not insert a to(channels_last)
    clone between the convolution and the dequant_per_channel node.
    """
    assert dtype in [torch.float32, torch.bfloat16]
    is_fp32 = dtype == torch.float32
    plain_weight_pattern = (
        dequantize_per_channel_weight_pattern
        if is_fp32
        else dequantize_per_channel_to_bf16_weight_pattern
    )
    cloned_weight_pattern = (
        dequantize_per_channel_clone_weight_pattern
        if is_fp32
        else dequantize_per_channel_to_bf16_clone_weight_pattern
    )
    return (
        _generate_dequant_convolution_node_pattern(plain_weight_pattern, dtype),
        _generate_dequant_convolution_node_pattern(cloned_weight_pattern, dtype),
    )
|
| 1432 |
+
|
| 1433 |
+
|
| 1434 |
+
def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
    """Locate the linear-like node (addmm/mm/bmm) in a matched pattern.

    Returns (linear_node, output_reshape_node); output_reshape_node is only
    non-None in the contiguous >2D-input case, where the match's output is a
    reshape wrapping the linear op.
    """
    output_reshape_node = None
    if not input_dim_exceeds_two:
        linear_node = match.output_node()
    elif input_contiguous:
        # >2D contiguous input: the output node is reshape(linear(...)).
        output_reshape_node = match.output_node()
        assert output_reshape_node.target is aten.reshape.default
        linear_node = output_reshape_node.args[0]
    else:
        # >2D non-contiguous input: the linear was decomposed into a bmm.
        bmm_candidates = filter_nodes(match.nodes, aten.bmm.default)
        assert len(bmm_candidates) == 1
        linear_node = bmm_candidates[0]

    assert linear_node.target in (
        aten.addmm.default,
        aten.mm.default,
        aten.bmm.default,
    )
    return linear_node, output_reshape_node
|
| 1454 |
+
|
| 1455 |
+
|
| 1456 |
+
def _get_linear_dq_mul_node(
    linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
):
    """Walk from the linear node back to the mul node of the activation dequant
    chain, handling the optional reshape/expand and to_bf16 wrappers.

    Returns (mul_node, act_reshape_node, activation_to_bf16_node,
    act_expand_node); the latter three are None when not present in the route.
    """
    act_reshape_node = None
    activation_to_bf16_node = None
    act_expand_node = None
    if input_dim_exceeds_two:
        if input_contiguous:
            act_reshape_node = linear_node.args[input_index]
            assert act_reshape_node.target is aten.reshape.default
            if dtype == torch.float32:
                # pattern: linear -> reshape -> mul
                mul_node = act_reshape_node.args[0]
            else:
                # pattern: linear -> reshape -> to_bf16 -> mul
                activation_to_bf16_node = act_reshape_node.args[0]
                mul_node = activation_to_bf16_node.args[0]
        else:
            # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous
            act_expand_node = linear_node.args[input_index]
            assert act_expand_node.target is aten.expand.default
            if dtype == torch.float32:
                # pattern: bmm -> expand -> mul
                mul_node = act_expand_node.args[0]
            else:
                # pattern: bmm -> expand -> to_bf16 -> mul
                activation_to_bf16_node = act_expand_node.args[0]
                mul_node = activation_to_bf16_node.args[0]
    else:
        if dtype == torch.float32:
            # pattern: linear -> mul
            mul_node = linear_node.args[input_index]
        else:
            # pattern: linear -> to_bf16 -> mul
            activation_to_bf16_node = linear_node.args[input_index]
            mul_node = activation_to_bf16_node.args[0]
    return mul_node, act_reshape_node, activation_to_bf16_node, act_expand_node
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
    """Build an extra-check for the qlinear weight-prepack pass: verifies the
    activation dequant chain shape and single-user property, plus, for the bmm
    variant, that expand/permute sizes are consistent with a decomposed linear.
    """

    def _inner(match):
        # Check dequant pattern has only 1 user.
        (
            linear_node,
            _,
        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)

        # addmm carries bias as args[0]; activation index shifts by one.
        input_index = 1 if linear_node.target is aten.addmm.default else 0
        assert dtype in [torch.float32, torch.bfloat16]

        (
            mul_node,
            _,
            _,
            _,
        ) = _get_linear_dq_mul_node(
            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
        )

        sub_node = mul_node.args[0]
        to_fp32_node = sub_node.args[0]

        assert to_fp32_node.target is prims.convert_element_type.default
        assert sub_node.target is aten.sub.Tensor
        assert mul_node.target is aten.mul.Tensor
        if (
            len(list(to_fp32_node.users)) != 1
            or len(list(sub_node.users)) != 1
            or len(list(mul_node.users)) != 1
        ):
            # Ensure the dequant pattern only has 1 user
            # since we will delete the dequant pattern here
            return False

        # Extra check for bmm pattern
        if input_dim_exceeds_two and not input_contiguous:
            # Check for act
            # Act expand size should be exactly same as act size
            act_expand_size = match.kwargs["act_expand_size"]
            act_node = match.kwargs["x"]
            if not (
                hasattr(act_node, "meta")
                and isinstance(act_node.meta.get("val", None), torch.Tensor)
                and (act_node.meta["val"].size() == torch.Size(act_expand_size))
            ):
                return False

            # Check for wgt
            # wgt permute dims should be [1, 0]
            wgt_permute_dims = match.kwargs["permute_axes"]
            if wgt_permute_dims != [1, 0]:
                return False

            # Check below wgt size items:
            # wgt before expand should with dim 2
            # Expand size should with dim 3
            # Expand size[0] should same as act size[0]
            # Expand size[1] should same as wgt size[1]
            # Expand size[2] should same as wgt size[0]
            qweight_node = match.kwargs["q_weight"]
            wgt_expand_size = match.kwargs["wgt_expand_size"]
            if not (
                hasattr(qweight_node, "meta")
                and isinstance(qweight_node.meta.get("val", None), torch.Tensor)
                and len(qweight_node.meta["val"].size()) == 2
                and len(wgt_expand_size) == 3
                and wgt_expand_size[0] == act_node.meta["val"].size()[0]
                and wgt_expand_size[1] == qweight_node.meta["val"].size()[1]
                and wgt_expand_size[2] == qweight_node.meta["val"].size()[0]
            ):
                return False

        return True

    return _inner
|
| 1569 |
+
|
| 1570 |
+
|
| 1571 |
+
def _register_qlinear_weight_prepack_pass(
    pattern,
    pass_number,
    dtype=torch.float32,
    input_dim_exceeds_two=False,
    input_contiguous=True,
):
    """Register a freezing-graph pass that rewrites a dequant->linear pattern
    (addmm/mm, or the bmm decomposition for non-contiguous >2D input) into
    onednn.qlinear_prepack + onednn.qlinear_pointwise and erases the dead nodes
    of the matched pattern."""

    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_dequant_linear_pattern(
            dtype, input_dim_exceeds_two, input_contiguous
        ),
        pass_number=pass_number,
    )
    def qlinear_weight_prepack(match: Match, *args, **kwargs):
        """
        Match the pattern:
        int8 activation
          |
        dequant_per_tensor
          |
        mm/addmm <- t <- dequant_per_channel <- int8_weight

        Insert weight prepack node and change the pattern to:
        int8 activation
          |
        onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight
        """
        assert dtype in [torch.float32, torch.bfloat16]
        (
            linear_node,
            output_reshape_node,
        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
        # addmm carries bias as args[0]; activation/weight indices shift by one.
        input_index = 1 if linear_node.target is aten.addmm.default else 0
        weight_index = input_index + 1

        (
            mul_node,
            act_reshape_node,
            activation_to_bf16_node,
            act_expand_node,
        ) = _get_linear_dq_mul_node(
            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
        )

        sub_node = mul_node.args[0]
        to_fp32_node = sub_node.args[0]

        # For the bmm variant, the weight is expand(t(...)); otherwise the
        # transpose feeds the linear op directly.
        if input_dim_exceeds_two and not input_contiguous:
            wgt_expand_node = linear_node.args[weight_index]
            assert wgt_expand_node.target is aten.expand.default
            t_node = wgt_expand_node.args[0]
        else:
            t_node = linear_node.args[weight_index]

        if dtype == torch.float32:
            dequant_per_channel = t_node.args[0]
        else:
            weight_to_bf16_node = t_node.args[0]
            dequant_per_channel = weight_to_bf16_node.args[0]
        assert (
            dequant_per_channel.target
            is quantized_decomposed.dequantize_per_channel.default
        )

        # Activation QParams
        qx, x_zp, x_scale = (
            kwargs["x"],
            kwargs["x_zp"],
            kwargs["x_scale"],
        )

        # Weight QParams
        qw, w_scale, w_zp = (
            kwargs["q_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )

        # Params
        bias = kwargs["b"] if "b" in kwargs else None

        x_shape = qx.meta.get("tensor_meta").shape
        if has_free_symbols(x_shape):
            # For dynamic shape case, we can't get activation shape ahead of runtime.
            x_shape = None
        graph = match.graph
        with graph.inserting_before(linear_node):
            # Insert weight prepack node and the qlinear node
            packed_weight_inputs = (
                qw,
                x_shape,
            )
            packed_weight_op = torch.ops.onednn.qlinear_prepack
            prepack_weight_node = graph.call_function(
                packed_weight_op, args=packed_weight_inputs
            )

            new_args: Tuple[Any, ...] = (
                qx,
                x_scale,
                x_zp,
                prepack_weight_node,
                w_scale,
                w_zp,
                bias,
                1.0,  # output_scale
                0,  # output_zero_point
                dtype,  # output_dtype
                "none",  # post op name
                [],  # post op args
                "",  # post op algorithm
            )
            Node = torch.fx.node.Node
            # Tensor-valued (graph-node) activation qparams require the
            # .tensor overload; scalar qparams use .default.
            if isinstance(x_scale, Node) and isinstance(x_zp, Node):
                new_linear_node = graph.call_function(
                    torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
                )
            else:
                new_linear_node = graph.call_function(
                    torch.ops.onednn.qlinear_pointwise.default, args=new_args
                )
            # Replace the pattern's output node (which differs per variant)
            # with the new qlinear node and propagate its metadata.
            if input_dim_exceeds_two:
                if input_contiguous:
                    output_reshape_node.replace_all_uses_with(new_linear_node)
                    new_linear_node.meta.update(output_reshape_node.meta)
                else:
                    if bias:
                        output_add_node_for_bias = match.output_node()
                        assert output_add_node_for_bias.target is aten.add.Tensor
                        output_add_node_for_bias.replace_all_uses_with(new_linear_node)
                        new_linear_node.meta.update(output_add_node_for_bias.meta)
                    else:
                        linear_node.replace_all_uses_with(new_linear_node)
                        new_linear_node.meta.update(linear_node.meta)
            else:
                linear_node.replace_all_uses_with(new_linear_node)
                new_linear_node.meta.update(linear_node.meta)

            # Erase the original linear node
            if input_dim_exceeds_two:
                if input_contiguous:
                    graph.erase_node(output_reshape_node)
                elif not input_contiguous and bias:
                    graph.erase_node(output_add_node_for_bias)  # type: ignore[possibly-undefined]
            graph.erase_node(linear_node)
            if input_dim_exceeds_two:
                if input_contiguous:
                    graph.erase_node(act_reshape_node)
                else:
                    graph.erase_node(act_expand_node)
                    graph.erase_node(wgt_expand_node)  # type: ignore[possibly-undefined]
            if dtype == torch.bfloat16:
                graph.erase_node(activation_to_bf16_node)
            # Erase the dequant pattern
            graph.erase_node(mul_node)
            graph.erase_node(sub_node)
            graph.erase_node(to_fp32_node)
            # Erase the dequant per channel pattern
            graph.erase_node(t_node)
            if dtype == torch.bfloat16:
                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
            graph.erase_node(dequant_per_channel)

        counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
        counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
            match.nodes
        )
|
| 1739 |
+
|
| 1740 |
+
|
| 1741 |
+
def _generate_dequant_linear_node_pattern(
    _dequant_per_channel_pattern, dtype=torch.float32, input_dim_exceeds_two=False
):
    """Build the (bias, no-bias) dequant->linear pattern pair.

    The bias variant is rooted at aten.addmm.default and the no-bias variant
    at aten.mm.default; both share the same transposed weight sub-pattern and
    the same (optionally dtype-converted / reshaped) activation sub-pattern.
    Returns a 2-tuple: (dequant_linear_bias_pattern, dequant_linear_no_bias_pattern).
    """
    assert dtype in [torch.float32, torch.bfloat16]
    t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)

    def _act_input_pattern():
        # Dequantized per-tensor activation, optionally cast to bf16 and
        # optionally reshaped to 2D when the input rank exceeds 2.
        return _may_generate_pattern_with_reshape(
            _may_generate_pattern_with_dtype_convert(
                dequantize_per_tensor_activation_pattern,
                KeywordArg("autocast_act_dtype"),
                dtype == torch.bfloat16,
            ),
            KeywordArg("act_reshape_size"),
            input_dim_exceeds_two,
        )

    dequant_linear_bias_pattern = _may_generate_pattern_with_reshape(
        CallFunction(
            aten.addmm.default,
            KeywordArg("b"),
            _act_input_pattern(),
            t_pattern,
        ),
        KeywordArg("output_reshape_size"),
        input_dim_exceeds_two,
    )
    dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
        CallFunction(
            aten.mm.default,
            _act_input_pattern(),
            t_pattern,
        ),
        KeywordArg("output_reshape_size"),
        input_dim_exceeds_two,
    )
    return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
|
| 1782 |
+
|
| 1783 |
+
|
| 1784 |
+
def _generate_dequant_bmm_node_pattern(
    _dequant_per_channel_pattern,
    dtype=torch.float32,
    with_bias=False,
):
    """Build the dequant->bmm pattern used when a linear's activation has
    dim > 2 and is not contiguous (linear decomposes to expand+bmm then).

    When ``with_bias`` the pattern is wrapped in a trailing aten.add.Tensor
    that consumes the bias keyword arg "b"; otherwise the bare bmm pattern
    is returned.
    """
    assert dtype in [torch.float32, torch.bfloat16]

    # Weight side: dequant-per-channel -> permute (t) -> expand.
    t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
    expanded_weight = CallFunction(
        aten.expand.default,
        t_pattern,
        KeywordArg("wgt_expand_size"),
    )

    # Activation side: dequant-per-tensor -> optional to_bf16 -> expand.
    expanded_activation = CallFunction(
        aten.expand.default,
        _may_generate_pattern_with_dtype_convert(
            dequantize_per_tensor_activation_pattern,
            KeywordArg("autocast_act_dtype"),
            dtype == torch.bfloat16,
        ),
        KeywordArg("act_expand_size"),
    )

    dequant_bmm_pattern = CallFunction(
        aten.bmm.default,
        expanded_activation,
        expanded_weight,
    )

    if with_bias:
        # Bias shows up as a separate output add after the bmm decomposition.
        return CallFunction(
            aten.add.Tensor,
            dequant_bmm_pattern,
            KeywordArg("b"),
        )
    return dequant_bmm_pattern
|
| 1822 |
+
|
| 1823 |
+
|
| 1824 |
+
def _generate_qlinear_weight_prepack_patterns(
    dtype=torch.float32,
    input_dim_exceeds_two=False,
    input_contiguous=True,
    with_bias=False,
):
    """Dispatch to the appropriate dequant-linear pattern builder.

    An input with dim > 2 that is NOT contiguous means linear was decomposed
    into expand+bmm, so a single bmm-rooted pattern is returned; otherwise a
    (bias, no-bias) pair of addmm/mm-rooted patterns is returned.
    """
    decomposed_to_bmm = input_dim_exceeds_two and not input_contiguous
    if not decomposed_to_bmm:
        # addmm/mm based patterns; yields a (bias, no-bias) tuple.
        return _generate_dequant_linear_node_pattern(
            dequantize_per_channel_weight_pattern, dtype, input_dim_exceeds_two
        )
    # bmm based pattern; yields a single pattern (optionally with output add).
    return _generate_dequant_bmm_node_pattern(
        dequantize_per_channel_weight_pattern,
        dtype,
        with_bias,
    )
|
| 1840 |
+
|
| 1841 |
+
|
| 1842 |
+
def _register_dequant_promotion():
    """Register dequant-promotion pattern passes at pass_number=0.

    Registers one pass per (dtype, input-rank) combination so that a shared
    dequant (with optional bf16 cast / reshape) feeding multiple consumers is
    duplicated per consumer before weight prepack runs at pass_number=1.
    """
    dequant_pattern_cases = itertools.product(
        [torch.float32, torch.bfloat16], [True, False]
    )
    for dtype, input_dim_exceeds_two in dequant_pattern_cases:
        # 4 dequantization patterns will be matched based on the dtype and input dimension size.
        # Case 1: int8-mixed-fp32, input dim size is 2
        # Case 2: int8-mixed-fp32, input dim size exceeds 2
        # Case 3: int8-mixed-bf16, input dim size is 2
        # Case 4: int8-mixed-bf16, input dim size exceeds 2
        #            quant
        #     + - - - - | - - - - +
        #     |      dequant      |
        #     |         |         |
        #     |    OPT(to_bf16)   |
        #     |         |         |
        #     |    OPT(reshape)   |
        #     |      /     \      |
        #     |    node1  node2   |
        #     + - - | - - - | - - +
        #  OPT(reshape) OPT(reshape)
        #     + - - | - - - | - - +
        #  OPT(to_fp32) OPT(to_fp32)
        #     + - - | - - - | - - +
        #         quant   quant
        _register_dequant_promotion_pass(
            _may_generate_pattern_with_reshape(
                _may_generate_pattern_with_dtype_convert(
                    dequantize_per_tensor_activation_pattern,
                    KeywordArg("autocast_act_dtype"),
                    dtype == torch.bfloat16,
                ),
                KeywordArg("act_reshape_size"),
                with_reshape=input_dim_exceeds_two,
            ),
            pass_number=0,
            dtype=dtype,
        )  # pass_number=0 to run before weight prepack
|
| 1880 |
+
|
| 1881 |
+
|
| 1882 |
+
def _register_qconv_weight_prepack():
    """Register QConv weight-prepack pattern passes for fp32 and bf16 mixes."""
    for conv_dtype in (torch.float32, torch.bfloat16):
        # pass_number=1 so dequant promotion (registered at pass_number=0)
        # has already run when these patterns are matched.
        for prepack_pattern in _generate_qconv_weight_prepack_patterns(conv_dtype):
            _register_qconv_weight_prepack_pass(
                prepack_pattern, pass_number=1, dtype=conv_dtype
            )
|
| 1890 |
+
|
| 1891 |
+
|
| 1892 |
+
def _register_qlinear_weight_prepack():
    """Register QLinear weight-prepack pattern passes.

    Six linear-related pattern families are matched, keyed on dtype, input
    rank and input contiguity, and converted into QLinear (int8-fp32/bf16).

    Cases 1-4 (addmm/mm based): int8-mixed-fp32/bf16, input dim == 2, or
    input dim > 2 and contiguous:

      + - - - - | - - - - - - | - - - - - +
      |    dq_per_tensor  dq_per_channel  |
      |         |              |          |
      |    OPT(to_bf16)    OPT(to_bf16)   |
      |         |              |          |
      |    OPT(reshape)     permute       |
      |              \        /           |
      |               addmm/mm            |
      |                  |                |
      |             OPT(reshape)          |

    Cases 5-6 (bmm based): int8-mixed-fp32/bf16, input dim > 2 and NOT
    contiguous:

      + - - - - | - - - - - - | - - - - - +
      |    dq_per_tensor  dq_per_channel  |
      |         |              |          |
      |    OPT(to_bf16)    OPT(to_bf16)   |
      |         |              |          |
      |       expand        permute       |
      |          \             |          |
      |                     expand        |
      |                    /              |
      |                 bmm               |
      |                  |                |
      |              OPT(add)             |
    """
    mixed_dtypes = (torch.float32, torch.bfloat16)

    # Step 1: register patterns rooted at mm and addmm.
    for dq_dtype, exceeds_two in itertools.product(mixed_dtypes, (True, False)):
        # pass_number=1 so dequant promotion (pass_number=0) runs first.
        for prepack_pattern in _generate_qlinear_weight_prepack_patterns(
            dq_dtype, exceeds_two
        ):
            _register_qlinear_weight_prepack_pass(
                prepack_pattern,
                pass_number=1,
                dtype=dq_dtype,
                input_dim_exceeds_two=exceeds_two,
            )

    # Step 2: register patterns rooted at bmm.
    # Linear might be decomposed into bmm when input dim exceeds 2 and is not
    # contiguous; refer to:
    # https://github.com/pytorch/pytorch/blob/
    # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
    # In this case we can convert it back to qlinear.
    for dq_dtype, has_bias in itertools.product(mixed_dtypes, (True, False)):
        bmm_pattern = _generate_qlinear_weight_prepack_patterns(
            dtype=dq_dtype,
            input_dim_exceeds_two=True,
            input_contiguous=False,
            with_bias=has_bias,
        )
        # With bias there is a trailing output add, so match that larger
        # pattern earlier (pass_number=1); the bias-less one matches later.
        _register_qlinear_weight_prepack_pass(
            bmm_pattern,
            pass_number=1 if has_bias else 2,
            dtype=dq_dtype,
            input_dim_exceeds_two=True,
            input_contiguous=False,
        )
|
| 1969 |
+
|
| 1970 |
+
|
| 1971 |
+
@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
    """Register all quantization weight-prepack related pattern passes.

    Wrapped in ``functools.lru_cache(None)`` so registration runs at most
    once per process regardless of how many compilations invoke it.
    Ordering matters: dequant promotion is registered at pass_number=0 so it
    runs before the conv/linear prepack passes (pass_number>=1).
    """
    # Step 1: Dequant promotion for int8-mixed-fp32/bf16
    _register_dequant_promotion()

    # Step 2: QConv weight prepack
    _register_qconv_weight_prepack()

    # Step 3: QLinear weight prepack
    _register_qlinear_weight_prepack()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
| 7 |
+
from .. import config, inductor_prims
|
| 8 |
+
from ..pattern_matcher import (
|
| 9 |
+
CallFunctionVarArgs,
|
| 10 |
+
Match,
|
| 11 |
+
PatternMatcherPass,
|
| 12 |
+
register_graph_pattern,
|
| 13 |
+
)
|
| 14 |
+
from ..virtualized import V
|
| 15 |
+
|
| 16 |
+
log = logging.getLogger(__name__)
|
| 17 |
+
patterns = PatternMatcherPass()
|
| 18 |
+
aten = torch.ops.aten
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def replace_random_passes(gm: torch.fx.GraphModule):
    """Modify the given FX graph to use backend-native random ops.

    Returns the number of replacements performed (0 when the user asked for
    eager-matching RNG via ``config.fallback_random``).
    """
    if config.fallback_random:
        # Keep aten randomness so results bitwise-match eager mode.
        return 0

    total = patterns.apply(gm)
    total += fuse_seed_creation_pass(gm.graph)
    return total
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def fuse_seed_creation_pass(graph: torch.fx.Graph):
    """
    Horizontally fuse all the seed generation on each device

    a = inductor_seed(dev)
    b = inductor_seed(dev)

    Becomes:
    seeds = inductor_seeds(2, dev)
    a = inductor_lookup_seed(seeds, 0)
    b = inductor_lookup_seed(seeds, 1)

    We do this because seed creation is entirely launch overhead bound.

    Returns the number of devices whose seeds were fused.
    """
    # Bucket every inductor_prims.seed(dev) call by its device argument.
    seed_matcher = CallFunctionVarArgs(inductor_prims.seed)
    seeds_by_device = collections.defaultdict(list)
    for candidate in graph.nodes:
        if seed_matcher.match(candidate):
            seeds_by_device[candidate.args[0]].append(candidate)

    if not seeds_by_device:
        return 0

    for dev, seed_nodes in seeds_by_device.items():
        # One fused inductor_prims.seeds call replaces every per-op seed.
        with graph.inserting_before(seed_nodes[0]):
            fused = graph.call_function(
                inductor_prims.seeds, (len(seed_nodes), dev)
            )
            with V.fake_mode:
                fused.meta["val"] = torch.empty(
                    [len(seed_nodes)], device=dev, dtype=torch.int64
                )
                fused.meta["tensor_meta"] = _extract_tensor_metadata(
                    fused.meta["val"]
                )

        # Each original seed node becomes a cheap lookup into the fused buffer.
        for slot, old_seed in enumerate(seed_nodes):
            with graph.inserting_before(old_seed):
                lookup = graph.call_function(
                    inductor_prims.lookup_seed, (fused, slot)
                )
            old_seed.replace_all_uses_with(lookup)
            lookup.meta.update(old_seed.meta)
            graph.erase_node(old_seed)

    return len(seeds_by_device)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def default_kwargs(device):
    """Extra kwargs for inductor random ops; currently none for any device."""
    del device  # unused, kept for a uniform call signature
    return {}
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_device(device):
    """Return ``device`` if given, else the current default torch device."""
    if device is None:
        # Allocating a zero-dim tensor is a version-portable way to query
        # the default device without newer torch APIs.
        return torch.empty([]).device
    return device
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns)
def replace_random(
    match: Match,
    size,
    *,
    generator=None,
    dtype=None,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Rewrite aten.rand/aten.randn into inductor's seeded random primitive."""
    # An explicit generator carries user-managed RNG state we cannot mirror.
    if generator is not None:
        return

    overload_packet = match.output_node().target.overloadpacket  # type: ignore[union-attr]
    mode = {
        aten.rand: "rand",
        aten.randn: "randn",
    }[overload_packet]
    device = get_device(device)

    def replacement(size):
        result = inductor_prims.random(
            size, inductor_prims.seed(device), mode, **default_kwargs(device)
        )
        return result if dtype is None else result.to(dtype)

    match.replace_by_example(replacement, [size])
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns)
def replace_randint(
    match: Match,
    low,
    high,
    size,
    *,
    dtype=torch.int64,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Rewrite aten.randint.low into inductor's seeded randint primitive."""
    device = get_device(device)

    def replacement(size):
        drawn = inductor_prims.randint(low, high, size, inductor_prims.seed(device))
        return drawn.to(dtype)

    match.replace_by_example(replacement, [size])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc
ADDED
|
Binary file (19.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-311.pyc
ADDED
|
Binary file (7.36 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
# NOTE(review): auto-generated serialized SFDP pattern (joint fwd+bwd graph of
# scaled-dot-product attention with dropout and div-by-inv_scale_factor
# softmax). Outputs are [attention output, then three permuted bmm results —
# presumably grads of query/key/value; verify against the generator script].
# Do not edit the pattern expressions by hand; regenerate instead.
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale_factor'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_12_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# NOTE(review): auto-generated serialized SFDP pattern — inference (forward
# only) variant: no rand/gt dropout mask nodes and no backward bmm chain.
# Do not edit the pattern expressions by hand; regenerate instead.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 132 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 133 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 134 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 135 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 136 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
|
| 137 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 138 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 139 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 140 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 141 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
|
| 142 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 143 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 144 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
|
| 145 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 146 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 147 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 148 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 149 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 150 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 151 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
|
| 152 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
|
| 153 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 154 |
+
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
| 155 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 156 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 157 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 158 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 159 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
| 160 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 161 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 162 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 163 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 164 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 165 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 166 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 167 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
|
| 168 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
| 169 |
+
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 170 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
|
| 171 |
+
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
| 172 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 173 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 174 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
| 175 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
| 176 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
|
| 177 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 178 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
|
| 179 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
|
| 180 |
+
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 181 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
|
| 182 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 183 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 184 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 185 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 186 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 187 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 188 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 189 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 190 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 191 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 192 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 193 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 194 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 195 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 196 |
+
_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5,
|
| 197 |
+
permute_default_6,
|
| 198 |
+
permute_default_9,
|
| 199 |
+
permute_default_11,
|
| 200 |
+
None,
|
| 201 |
+
None
|
| 202 |
+
])
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 206 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 207 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 208 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored())
|
| 209 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 210 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 211 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 212 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 213 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 214 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 215 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 216 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
|
| 217 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 218 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 219 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 220 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 221 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 222 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 223 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 224 |
+
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
|
| 225 |
+
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
|
| 226 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 227 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 228 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 229 |
+
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 230 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
| 231 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 232 |
+
_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 35 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 36 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
| 37 |
+
bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2)
|
| 38 |
+
amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
|
| 39 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
|
| 40 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 41 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 42 |
+
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 43 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
|
| 44 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
|
| 45 |
+
bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
|
| 46 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 47 |
+
bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
|
| 48 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 49 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
| 50 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
|
| 51 |
+
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 52 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor)
|
| 53 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 54 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 55 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 56 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
| 57 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 58 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 59 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5, _users=2)
|
| 60 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
|
| 61 |
+
bmm_default_3 = CallFunction(aten.bmm.default, sub_Tensor_1, permute_default_2)
|
| 62 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 63 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, sub_Tensor_1)
|
| 64 |
+
permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored())
|
| 65 |
+
permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored())
|
| 66 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1'))
|
| 67 |
+
_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1,
|
| 68 |
+
bmm_default_3,
|
| 69 |
+
permute_default_4,
|
| 70 |
+
bmm_default_5,
|
| 71 |
+
None
|
| 72 |
+
])
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 76 |
+
bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2)
|
| 77 |
+
amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
|
| 78 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
|
| 79 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 80 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 81 |
+
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 82 |
+
clone_default = CallFunction(aten.clone.default, div_Tensor)
|
| 83 |
+
_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 87 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 88 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
| 89 |
+
bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default)
|
| 90 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2)
|
| 91 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 92 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 93 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 94 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 95 |
+
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 96 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 97 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
|
| 98 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
|
| 99 |
+
bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
|
| 100 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 101 |
+
bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
|
| 102 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 103 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
|
| 104 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
|
| 105 |
+
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 106 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
| 107 |
+
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
| 108 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 109 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 110 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
| 111 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
| 112 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
|
| 113 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 114 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
|
| 115 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
|
| 116 |
+
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored(), _users=2)
|
| 117 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
|
| 118 |
+
bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2)
|
| 119 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 120 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5)
|
| 121 |
+
permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored())
|
| 122 |
+
permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored())
|
| 123 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1'))
|
| 124 |
+
_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1,
|
| 125 |
+
bmm_default_3,
|
| 126 |
+
permute_default_4,
|
| 127 |
+
bmm_default_5,
|
| 128 |
+
None
|
| 129 |
+
])
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 133 |
+
bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default)
|
| 134 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2)
|
| 135 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 136 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 137 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 138 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 139 |
+
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 140 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
| 141 |
+
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
| 142 |
+
_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
|
| 35 |
+
expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
|
| 36 |
+
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 37 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 38 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 39 |
+
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 40 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
|
| 41 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 42 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 43 |
+
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 44 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
|
| 45 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
|
| 46 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 47 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 48 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 49 |
+
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
|
| 50 |
+
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
|
| 51 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
| 52 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 53 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 54 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 55 |
+
expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
| 56 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 57 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 58 |
+
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 59 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
| 60 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
| 61 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 62 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 63 |
+
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
| 64 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 65 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 66 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 67 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 68 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 69 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 70 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 71 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 72 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
|
| 73 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 74 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 75 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
|
| 76 |
+
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, sub_Tensor_1)
|
| 77 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
|
| 78 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 79 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 80 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 81 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 82 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 83 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 84 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 85 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 86 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 87 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 88 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 89 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 90 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 91 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 92 |
+
_sfdp_pattern_15_training = MultiOutputPattern([view_default_5,
|
| 93 |
+
permute_default_6,
|
| 94 |
+
permute_default_9,
|
| 95 |
+
permute_default_11,
|
| 96 |
+
None,
|
| 97 |
+
None
|
| 98 |
+
])
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
|
| 102 |
+
view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
|
| 103 |
+
expand_default = CallFunction(aten.expand.default, view_default, Ignored())
|
| 104 |
+
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 105 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 106 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 107 |
+
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 108 |
+
view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
|
| 109 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 110 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 111 |
+
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 112 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
|
| 113 |
+
view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 114 |
+
bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
|
| 115 |
+
view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 116 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
|
| 117 |
+
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
|
| 118 |
+
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
|
| 119 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
| 120 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 121 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 122 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 123 |
+
expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
| 124 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 125 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 126 |
+
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 127 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
| 128 |
+
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
| 129 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
|
| 130 |
+
_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
|
| 134 |
+
expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
|
| 135 |
+
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 136 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 137 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 138 |
+
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 139 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
|
| 140 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 141 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 142 |
+
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 143 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
|
| 144 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
|
| 145 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 146 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 147 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 148 |
+
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
|
| 149 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
|
| 150 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 151 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 152 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 153 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 154 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 155 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
|
| 156 |
+
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 157 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 158 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 159 |
+
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 160 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
| 161 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
| 162 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 163 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 164 |
+
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
| 165 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 166 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 167 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 168 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 169 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 170 |
+
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
| 171 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 172 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 173 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
| 174 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
| 175 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
|
| 176 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 177 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
|
| 178 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
|
| 179 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 180 |
+
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4)
|
| 181 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
|
| 182 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 183 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 184 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 185 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 186 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 187 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 188 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 189 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 190 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 191 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 192 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 193 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 194 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 195 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 196 |
+
_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_5,
|
| 197 |
+
permute_default_6,
|
| 198 |
+
permute_default_9,
|
| 199 |
+
permute_default_11,
|
| 200 |
+
None,
|
| 201 |
+
None
|
| 202 |
+
])
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
|
| 206 |
+
view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
|
| 207 |
+
expand_default = CallFunction(aten.expand.default, view_default, Ignored())
|
| 208 |
+
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 209 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 210 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 211 |
+
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 212 |
+
view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
|
| 213 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 214 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 215 |
+
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 216 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
|
| 217 |
+
view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 218 |
+
bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
|
| 219 |
+
view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 220 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
|
| 221 |
+
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
|
| 222 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
|
| 223 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 224 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 225 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 226 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 227 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 228 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 229 |
+
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 230 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 231 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 232 |
+
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 233 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
| 234 |
+
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
| 235 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
|
| 236 |
+
_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 35 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 36 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 37 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 38 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 39 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 40 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 41 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 42 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 43 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
|
| 44 |
+
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
| 45 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
| 46 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 47 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 48 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 49 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
| 50 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 51 |
+
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
| 52 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 53 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 54 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 55 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 56 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 57 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 58 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 59 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 60 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 61 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 62 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
| 63 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
| 64 |
+
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 65 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 66 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 67 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 68 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 69 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
| 70 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 71 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 72 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
|
| 73 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale_factor'))
|
| 74 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 75 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 76 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 77 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 78 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 79 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 80 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 81 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 82 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 83 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 84 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 85 |
+
_sfdp_pattern_3_training = MultiOutputPattern([view_default_5,
|
| 86 |
+
view_default_9,
|
| 87 |
+
permute_default_4,
|
| 88 |
+
view_default_11,
|
| 89 |
+
None,
|
| 90 |
+
None
|
| 91 |
+
])
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 95 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 96 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 97 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 98 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 99 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 100 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 101 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
|
| 102 |
+
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
| 103 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
| 104 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 105 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 106 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 107 |
+
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
|
| 108 |
+
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
| 109 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 110 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 111 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 112 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 113 |
+
_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 117 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 118 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 119 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 120 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 121 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 122 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 123 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 124 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 125 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
|
| 126 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 127 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 128 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 129 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 130 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 131 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 132 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
|
| 133 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
|
| 134 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 135 |
+
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
| 136 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 137 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 138 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 139 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 140 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 141 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 142 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 143 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 144 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 145 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 146 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
|
| 147 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
| 148 |
+
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 149 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
| 150 |
+
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
| 151 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 152 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 153 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
| 154 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
| 155 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
|
| 156 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 157 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
|
| 158 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
|
| 159 |
+
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 160 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
|
| 161 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 162 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 163 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 164 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 165 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 166 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 167 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 168 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 169 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 170 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 171 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 172 |
+
_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5,
|
| 173 |
+
view_default_9,
|
| 174 |
+
permute_default_4,
|
| 175 |
+
view_default_11,
|
| 176 |
+
None,
|
| 177 |
+
None
|
| 178 |
+
])
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 182 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 183 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 184 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 185 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 186 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 187 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 188 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
|
| 189 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 190 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 191 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 192 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 193 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 194 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 195 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 196 |
+
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
| 197 |
+
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
| 198 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 199 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 200 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 201 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 202 |
+
_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 35 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 36 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 37 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 38 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 39 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 40 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 41 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 42 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 43 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 44 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
|
| 45 |
+
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
| 46 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
| 47 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 48 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 49 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 50 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
| 51 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 52 |
+
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
| 53 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 54 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 55 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 56 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 57 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 58 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 59 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 60 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 61 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 62 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 63 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
| 64 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
| 65 |
+
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 66 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 67 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 68 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 69 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 70 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
| 71 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 72 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 73 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
|
| 74 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
|
| 75 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 76 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 77 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 78 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 79 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 80 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 81 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 82 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 83 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 84 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 85 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 86 |
+
_sfdp_pattern_6_training = MultiOutputPattern([view_default_5,
|
| 87 |
+
view_default_9,
|
| 88 |
+
permute_default_4,
|
| 89 |
+
view_default_11,
|
| 90 |
+
None,
|
| 91 |
+
None
|
| 92 |
+
])
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 96 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 97 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 98 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 99 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 100 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 101 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 102 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 103 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
|
| 104 |
+
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
| 105 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
| 106 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 107 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 108 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 109 |
+
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
|
| 110 |
+
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
| 111 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 112 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 113 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 114 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 115 |
+
_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 119 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 120 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 121 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 122 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 123 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 124 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 125 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 126 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 127 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 128 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
|
| 129 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
|
| 130 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 131 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 132 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 133 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 134 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 135 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
|
| 136 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
|
| 137 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 138 |
+
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
| 139 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 140 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 141 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 142 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 143 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 144 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 145 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 146 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 147 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 148 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 149 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
|
| 150 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
| 151 |
+
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 152 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
| 153 |
+
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
| 154 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 155 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 156 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
| 157 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
| 158 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
|
| 159 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 160 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
|
| 161 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
|
| 162 |
+
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 163 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored())
|
| 164 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 165 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 166 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 167 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 168 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 169 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 170 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 171 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 172 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 173 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 174 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 175 |
+
_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5,
|
| 176 |
+
view_default_9,
|
| 177 |
+
permute_default_4,
|
| 178 |
+
view_default_11,
|
| 179 |
+
None,
|
| 180 |
+
None
|
| 181 |
+
])
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 185 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 186 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 187 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 188 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 189 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 190 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 191 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 192 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
|
| 193 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
|
| 194 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 195 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 196 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 197 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 198 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 199 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 200 |
+
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
| 201 |
+
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
| 202 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 203 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 204 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 205 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 206 |
+
_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 35 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 36 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 37 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
|
| 38 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 39 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 40 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 41 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 42 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
|
| 43 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 44 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 45 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
|
| 46 |
+
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
| 47 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
| 48 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 49 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 50 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 51 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 52 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 53 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 54 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 55 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 56 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 57 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
| 58 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 59 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 60 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 61 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 62 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 63 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
|
| 64 |
+
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
|
| 65 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 66 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 67 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 68 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 69 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 70 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
|
| 71 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 72 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 73 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
|
| 74 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
|
| 75 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 76 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 77 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 78 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 79 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 80 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 81 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 82 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 83 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 84 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 85 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 86 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 87 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 88 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 89 |
+
_sfdp_pattern_8_training = MultiOutputPattern([view_default_5,
|
| 90 |
+
permute_default_6,
|
| 91 |
+
permute_default_9,
|
| 92 |
+
permute_default_11
|
| 93 |
+
])
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 97 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 98 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 99 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored())
|
| 100 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 101 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 102 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 103 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 104 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 105 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 106 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 107 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
|
| 108 |
+
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
| 109 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
| 110 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 111 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 112 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 113 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 114 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 115 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 116 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 117 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 118 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 119 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
| 120 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 121 |
+
_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 125 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 126 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 127 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
|
| 128 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 129 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 130 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 131 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 132 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
|
| 133 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 134 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 135 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 136 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 137 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 138 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 139 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 140 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 141 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 142 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 143 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 144 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 145 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 146 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 147 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 148 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
| 149 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 150 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 151 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 152 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 153 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 154 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 155 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 156 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 157 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 158 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 159 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 160 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
|
| 161 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 162 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 163 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
|
| 164 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 165 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored())
|
| 166 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 167 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 168 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 169 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 170 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 171 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 172 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 173 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 174 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 175 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 176 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 177 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 178 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 179 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 180 |
+
_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5,
|
| 181 |
+
permute_default_6,
|
| 182 |
+
permute_default_9,
|
| 183 |
+
permute_default_11
|
| 184 |
+
])
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 188 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 189 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 190 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored())
|
| 191 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 192 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 193 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 194 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 195 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 196 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 197 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 198 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 199 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 200 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 201 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 202 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 203 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 204 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 205 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 206 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 207 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 208 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 209 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 210 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 211 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
| 212 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 213 |
+
_sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ..lowering import register_lowering
|
| 4 |
+
from ..select_algorithm import (
|
| 5 |
+
autotune_select_algorithm,
|
| 6 |
+
ExternKernelChoice,
|
| 7 |
+
TritonTemplate,
|
| 8 |
+
)
|
| 9 |
+
from ..utils import ceildiv as cdiv, use_aten_gemm_kernels, use_triton_template
|
| 10 |
+
|
| 11 |
+
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options
|
| 12 |
+
|
| 13 |
+
aten = torch.ops.aten
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def bmm_grid(b, m, n, meta):
|
| 17 |
+
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
bmm_template = TritonTemplate(
|
| 21 |
+
name="bmm",
|
| 22 |
+
grid=bmm_grid,
|
| 23 |
+
source=r"""
|
| 24 |
+
{{def_kernel("A", "B")}}
|
| 25 |
+
M = {{size("A", -2)}}
|
| 26 |
+
N = {{size("B", -1)}}
|
| 27 |
+
K = {{size("A", -1)}}
|
| 28 |
+
|
| 29 |
+
stride_aq = {{stride("A", 0)}}
|
| 30 |
+
stride_am = {{stride("A", 1)}}
|
| 31 |
+
stride_ak = {{stride("A", 2)}}
|
| 32 |
+
|
| 33 |
+
stride_bq = {{stride("B", 0)}}
|
| 34 |
+
stride_bk = {{stride("B", 1)}}
|
| 35 |
+
stride_bn = {{stride("B", 2)}}
|
| 36 |
+
|
| 37 |
+
# based on triton.ops.matmul
|
| 38 |
+
pid = tl.program_id(0)
|
| 39 |
+
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
| 40 |
+
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
| 41 |
+
|
| 42 |
+
# re-order program ID for better L2 performance
|
| 43 |
+
width = GROUP_M * grid_n
|
| 44 |
+
group_id = pid // width
|
| 45 |
+
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
| 46 |
+
pid_m = group_id * GROUP_M + (pid % group_size)
|
| 47 |
+
pid_n = (pid % width) // (group_size)
|
| 48 |
+
|
| 49 |
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 50 |
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 51 |
+
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
| 52 |
+
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
| 53 |
+
rk = tl.arange(0, BLOCK_K)
|
| 54 |
+
|
| 55 |
+
idx_q = tl.program_id(1) # batch dimension for BMM
|
| 56 |
+
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
|
| 57 |
+
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
|
| 58 |
+
|
| 59 |
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
| 60 |
+
for k in range(K, 0, -BLOCK_K):
|
| 61 |
+
if EVEN_K:
|
| 62 |
+
a = tl.load(A)
|
| 63 |
+
b = tl.load(B)
|
| 64 |
+
else:
|
| 65 |
+
a = tl.load(A, mask=rk[None, :] < k, other=0.)
|
| 66 |
+
b = tl.load(B, mask=rk[:, None] < k, other=0.)
|
| 67 |
+
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
|
| 68 |
+
A += BLOCK_K * stride_ak
|
| 69 |
+
B += BLOCK_K * stride_bk
|
| 70 |
+
|
| 71 |
+
# rematerialize rm and rn to save registers
|
| 72 |
+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 73 |
+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 74 |
+
idx_q = tl.program_id(1) # batch dimension for BMM
|
| 75 |
+
idx_m = rm[:, None]
|
| 76 |
+
idx_n = rn[None, :]
|
| 77 |
+
mask = (idx_m < M) & (idx_n < N)
|
| 78 |
+
|
| 79 |
+
# inductor generates a suffix
|
| 80 |
+
{{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
|
| 81 |
+
""",
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
|
| 85 |
+
aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@register_lowering(aten.bmm)
|
| 89 |
+
def tuned_bmm(mat1, mat2, *, layout=None):
|
| 90 |
+
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
|
| 91 |
+
|
| 92 |
+
# options to tune from
|
| 93 |
+
choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
|
| 94 |
+
if use_triton_template(layout):
|
| 95 |
+
for config in mm_configs(m, n, k):
|
| 96 |
+
bmm_template.maybe_append_choice(
|
| 97 |
+
choices,
|
| 98 |
+
input_nodes=(mat1, mat2),
|
| 99 |
+
layout=layout,
|
| 100 |
+
**mm_options(config, m, n, k, layout),
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Don't register this since it is slower than decomposing it
|
| 107 |
+
# @register_lowering(aten.baddbmm)
|
| 108 |
+
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
| 109 |
+
m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
|
| 110 |
+
|
| 111 |
+
# options to tune from
|
| 112 |
+
choices = (
|
| 113 |
+
[aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
|
| 114 |
+
if use_aten_gemm_kernels()
|
| 115 |
+
else []
|
| 116 |
+
)
|
| 117 |
+
if use_triton_template(layout):
|
| 118 |
+
for config in mm_configs(m, n, k):
|
| 119 |
+
bmm_template.maybe_append_choice(
|
| 120 |
+
choices,
|
| 121 |
+
input_nodes=(inp, mat1, mat2),
|
| 122 |
+
layout=layout,
|
| 123 |
+
**mm_options(config, m, n, k, layout),
|
| 124 |
+
prefix_args=1,
|
| 125 |
+
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch._inductor.virtualized import V
|
| 7 |
+
from .. import config as inductor_config
|
| 8 |
+
from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate
|
| 9 |
+
from ..lowering import register_lowering
|
| 10 |
+
from ..select_algorithm import (
|
| 11 |
+
autotune_select_algorithm,
|
| 12 |
+
ExternKernelChoice,
|
| 13 |
+
TritonTemplate,
|
| 14 |
+
)
|
| 15 |
+
from ..utils import (
|
| 16 |
+
use_aten_gemm_kernels,
|
| 17 |
+
use_cutlass_template,
|
| 18 |
+
use_max_autotune,
|
| 19 |
+
use_triton_template,
|
| 20 |
+
)
|
| 21 |
+
from .mm_common import (
|
| 22 |
+
addmm_epilogue,
|
| 23 |
+
int8_mm_configs,
|
| 24 |
+
mm_args,
|
| 25 |
+
mm_configs,
|
| 26 |
+
mm_grid,
|
| 27 |
+
mm_options,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
log = logging.getLogger(__name__)
|
| 31 |
+
aten = torch.ops.aten
|
| 32 |
+
|
| 33 |
+
# Triton template for plain mm (C = A @ B). The source below is Jinja-expanded
# by TritonTemplate; `{{...}}` hooks are filled in per choice (sizes, strides,
# output store). The kernel body follows triton.ops.matmul, with grouped
# program-id reordering for L2 reuse and an optional prologue cast of B.
mm_template = TritonTemplate(
    name="mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
|
| 94 |
+
|
| 95 |
+
# Extern (ATen/cuBLAS) kernel choices used as autotuning baselines for the
# Triton/CUTLASS template candidates below.
aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")


aten_addmm = ExternKernelChoice(
    torch.addmm, "at::addmm_out", op_overload=aten.addmm.default
)

aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm")
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _is_int8_mat(mat):
|
| 106 |
+
return mat.get_dtype() in (torch.int8, torch.uint8)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
    """
    Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
    kernel under the hood. There are a few shapes where this is slower,
    but they are rare.
    """
    # A row-broadcast bias (stride 0 along dim 0) or a single-row bias can be
    # collapsed to its first row; addmm then broadcasts it back.
    row_broadcast = inp.stride(0) == 0 or inp.size(0) == 1
    bias = inp[0] if row_broadcast else inp
    return torch.addmm(bias, mat1, mat2, out=out, alpha=alpha, beta=beta)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Wrap the Python bias_addmm helper as an autotuning choice; it has no C++
# counterpart (hence the None kernel name).
aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
    # Autotuned lowering for aten.mm: gathers candidate kernels (ATen extern,
    # Triton templates, CUTLASS templates) and lets the autotuner pick one.
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)

    # options to tune from
    choices = [aten_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []

    # Template choices are skipped entirely for zero-size outputs (m * n == 0);
    # only the ATen extern choice (if any) remains in that case.
    if m * n != 0 and use_triton_template(layout):
        for config in mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )

    if m * n != 0 and use_cutlass_template(layout):
        CUTLASSGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )

    # Imported locally to avoid a module-level import cycle with ..ir.
    from torch._inductor.ir import FixedLayout, FlexibleLayout

    if (
        len(choices) == 1
        and use_aten_gemm_kernels()
        and isinstance(layout, FixedLayout)
    ):
        # If we are not autotuning, we can swap to a FlexibleLayout
        # in order to get fusion optimizations to kick in, e.g. ConcatFusion
        layout = FlexibleLayout(
            device=layout.device, dtype=layout.dtype, size=layout.size
        )
        # Re-bind the ATen choice against the relaxed layout.
        choices = [aten_mm.bind((mat1, mat2), layout)]

    return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@register_lowering(aten._int_mm, type_promotion_kind=None)
def tuned_int_mm(mat1, mat2, *, layout=None):
    # Autotuned lowering for aten._int_mm (int8 x int8 -> int32 matmul).
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1, mat2, layout=layout, out_dtype=torch.int32
    )
    choices = (
        [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
    )
    if m * n != 0 and use_triton_template(layout, enable_int32=True):
        # TODO: Re-enable eager mode implementation once cuBLAS is fixed
        # NOTE: the ATen choice above is deliberately discarded here — when the
        # Triton path is usable, only template candidates are autotuned.
        choices = []
        for config in int8_mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    # Autotuned lowering for aten.addmm (out = beta * inp + alpha * mat1 @ mat2).
    # mm_args also expands `inp` to the output shape (inp_expanded).
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
    if m * n == 0 or not use_max_autotune():
        # Fast path: no autotuning — use the ATen kernel with the *unexpanded*
        # bias so ATen's own broadcasting (and cublasLt bias kernel) applies.
        choices = (
            [
                aten_addmm.bind(
                    (inp, mat1, mat2),
                    layout,
                    alpha=alpha,
                    beta=beta,
                )
            ]
            if use_aten_gemm_kernels()
            else []
        )
        return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)

    choices = (
        [
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                alpha=alpha,
                beta=beta,
            )
        ]
        if use_aten_gemm_kernels()
        else []
    )

    if (
        use_aten_gemm_kernels()
        and inp_expanded.get_stride()[0] == 0
        and inp_expanded.get_device().type == "cuda"
        and inductor_config.triton.autotune_cublasLt
    ):
        # unexpand inp to make sure fused addmm from cublasLt is used
        # (inserted at index 0 so it is benchmarked first).
        choices.insert(
            0,
            aten_bias_addmm.bind(
                (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
            ),
        )

    if use_triton_template(layout):
        for config in mm_configs(m, n, k):
            # inp_expanded is a prefix arg handled by the addmm epilogue.
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(inp_expanded, mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
            )

    if use_cutlass_template(layout):
        # input_reorder maps (inp, mat1, mat2) -> CUTLASS's (mat1, mat2, bias) order.
        CUTLASSGemmTemplate.add_cutlass_gemm_choices(
            choices,
            layout,
            [mat1, mat2, inp_expanded],
            alpha=alpha,
            beta=beta,
            input_reorder=[2, 0, 1],
            fuseable=False,
        )

    return autotune_select_algorithm(
        "addmm", choices, [inp_expanded, mat1, mat2], layout
    )
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def fallback_mixed_mm(mat1, mat2, *, out):
    """Eager fallback for mixed-dtype mm: upcast mat2 to mat1's dtype, then mm."""
    promoted = mat2.to(mat1.dtype)
    return torch.mm(mat1, promoted, out=out)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# Autotuning choice for the eager mixed-dtype fallback; no C++ counterpart.
aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@functools.lru_cache(None)
def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
    """True when the CUDA device at ``index`` (0 if None) has compute capability <= 7.x."""
    device_index = index or 0
    return torch.cuda.get_device_properties(device_index).major <= 7
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def tuned_mixed_mm(mat1, mat2, mat2_dtype):
    # Autotuned mm where mat2 has a narrower dtype than mat1; the Triton
    # template casts mat2 to `mat2_dtype` in the prologue (B_PROLOGUE_CAST_TYPE).
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
    choices = [aten_fallback_mixed_mm.bind((mat1, mat2), layout)]
    if (
        mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous()
    ) or _is_sm7x_or_older_gpu(layout.device.index):
        # can't use triton kernel unless one of these is true or if running on v100 (numerical issues)
        return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout)
    if inductor_config.force_mixed_mm:
        # Forced mode: drop the eager fallback and autotune templates only.
        choices = []
    # e.g. torch.float16 -> "tl.float16" for the template's prologue cast.
    b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
    has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
    for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
        mm_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2),
            layout=layout,
            **mm_options(config, m, n, k, layout, b_prologue_cast_type),
        )
    return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# This op is a special case of the int_mm op which we use based on the pattern
# _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent
# realization of the int32 _int_mm output by forcing fusion with the mul op.
# This is only used when config.force_fuse_int_mm_with_mul = True
def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
    # Output dtype defaults to promote(mat3.dtype, int32) when not given.
    out_dtype = (
        torch.promote_types(mat3.get_dtype(), torch.int32)
        if out_dtype is None
        else out_dtype
    )
    m, n, k, layout, mat1, mat2, mat3 = mm_args(
        mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
    )
    # Template-only autotune: no ATen extern choice for the fused op.
    choices: List[Any] = []
    for config in int8_mm_configs(m, n, k):
        # mat3 is a suffix arg: fed to the mul epilogue, not the matmul core.
        # Accumulation is forced to int32 regardless of the output dtype.
        mm_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2, mat3),
            layout=layout,
            **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"),
            suffix_args=1,
            epilogue_fn=V.ops.mul,
        )
    return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import logging
|
| 3 |
+
from typing import cast, List, Tuple
|
| 4 |
+
|
| 5 |
+
import sympy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch._inductor.select_algorithm import realize_inputs
|
| 9 |
+
from torch._inductor.virtualized import V
|
| 10 |
+
|
| 11 |
+
from .. import config as inductor_config
|
| 12 |
+
from ..utils import ceildiv as cdiv, next_power_of_2
|
| 13 |
+
|
| 14 |
+
log = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def triton_config(num_stages, num_warps, **kwargs):
    # Thin factory for triton.Config; `kwargs` carries the block-size meta
    # (BLOCK_M/BLOCK_N/BLOCK_K, and matrix_instr_nonkdim on ROCm).
    # triton is imported lazily so this module loads without it installed.
    from triton import Config

    return Config(kwargs, num_stages=num_stages, num_warps=num_warps)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def filtered_configs(
    m: int,
    n: int,
    k: int,
    configs: List[Tuple[int, int, int, int, int]],
    has_int8_tensor=False,
):
    """Heuristic to shrink configs when they are bigger than the input size"""

    # According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424
    # it's safer to use at least [32, 32] block size for int8/uint8
    # tensors
    min_block_size = 32 if has_int8_tensor else 16
    # Round each problem dim up to a power of two (using a size hint for
    # symbolic/unbacked dims), clamped below by min_block_size.
    m = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    n = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                n, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    k = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                k, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    # Track emitted configs so shrinking doesn't yield duplicates.
    used = set()
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        # shrink configs for small sizes
        block_m = max(min(block_m, m), min_block_size)
        block_n = max(min(block_n, n), min_block_size)
        block_k = max(min(block_k, k), min_block_size)
        # each warp computes 16x16 tile = 256
        num_warps = min(num_warps, block_m * block_n // 256)
        if torch.version.hip:
            # On ROCm, additionally sweep matrix_instr_nonkdim (0 = let the
            # compiler choose; 16 = force 16x16 MFMA instructions).
            for matrix_instr_nonkdim in [0, 16]:
                if matrix_instr_nonkdim != 0 and (
                    block_m % matrix_instr_nonkdim != 0
                    or block_n % matrix_instr_nonkdim != 0
                ):
                    # block_m and block_n must be a multiple of matrix_instr_nonkdim
                    continue
                if (
                    block_m,
                    block_n,
                    block_k,
                    num_stages,
                    num_warps,
                    matrix_instr_nonkdim,
                ) not in used:
                    used.add(
                        (
                            block_m,
                            block_n,
                            block_k,
                            num_stages,
                            num_warps,
                            matrix_instr_nonkdim,
                        )
                    )
                    yield triton_config(
                        BLOCK_M=block_m,
                        BLOCK_N=block_n,
                        BLOCK_K=block_k,
                        num_stages=num_stages,
                        num_warps=num_warps,
                        matrix_instr_nonkdim=matrix_instr_nonkdim,
                    )
        else:
            if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
                used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
                yield triton_config(
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    BLOCK_K=block_k,
                    num_stages=num_stages,
                    num_warps=num_warps,
                )
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform
mm_kernel_configs = [
    # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 128, 32, 3, 4), "cond": True},
    {"config": (128, 64, 32, 3, 4), "cond": True},
    {"config": (64, 128, 32, 4, 8), "cond": True},
    {"config": (128, 64, 32, 4, 8), "cond": True},
    {"config": (64, 32, 32, 5, 8), "cond": True},
    {"config": (32, 64, 32, 5, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 64, 3, 8), "cond": True},
    {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
    {"config": (64, 64, 16, 2, 4), "cond": True},
    {"config": (32, 32, 16, 1, 2), "cond": True},
]

int8_mm_kernel_configs = [
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 128, 32, 3, 4), "cond": True},
    {"config": (128, 64, 32, 3, 4), "cond": True},
    {"config": (64, 128, 32, 4, 8), "cond": True},
    {"config": (128, 64, 32, 4, 8), "cond": True},
    {"config": (64, 32, 32, 5, 8), "cond": True},
    {"config": (32, 64, 32, 5, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 64, 3, 8), "cond": True},
    # {"config": (32, 32, 128, 2, 4), "cond": True},
    # {"config": (64, 64, 16, 2, 4), "cond": True},
    # {"config": (32, 32, 16, 1, 2), "cond": True},
    {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
    {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
]

# Create filtered list of configs based on cond evaluation


mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mm_kernel_configs
    if config["cond"]
)
int8_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in int8_mm_kernel_configs
    if config["cond"]
)

# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
    mm_platform_configs = tuple(
        (config[0], config[1], config[2], 1, config[4])
        for config in mm_platform_configs
    )
    # BUGFIX: this previously iterated `mm_platform_configs`, which silently
    # replaced the int8 config set with the fp config set on ROCm. Rebuild the
    # int8 configs from the int8 list itself.
    int8_platform_configs = tuple(
        (config[0], config[1], config[2], 1, config[4])
        for config in int8_platform_configs
    )

# Partials binding the platform config tables to the shrink-to-fit heuristic;
# call sites pass (m, n, k) and optionally has_int8_tensor.
mm_configs = functools.partial(
    filtered_configs,
    configs=mm_platform_configs,
)

int8_mm_configs = functools.partial(
    filtered_configs,
    configs=int8_platform_configs,
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def mm_grid(m, n, meta):
    """
    The CUDA grid size for matmul triton templates.
    """
    # One program per (BLOCK_M x BLOCK_N) output tile, flattened into axis 0.
    num_tiles = cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"])
    return (num_tiles, 1, 1)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def acc_type(dtype):
    """Triton accumulator dtype string for an output dtype.

    Half-precision outputs (fp16/bf16) accumulate in fp32; everything else
    accumulates in its own dtype.
    """
    if dtype is torch.float16 or dtype is torch.bfloat16:
        return "tl.float32"
    # str(torch.float32) == "torch.float32" -> "tl.float32"
    return "tl." + str(dtype).replace("torch.", "")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
    """
    Common options to matmul triton templates.
    """
    # EVEN_K: true when BLOCK_K divides K symbolically, so the K loop needs
    # no tail masking.
    even_k_symbolic = (
        # it isn't worth guarding on this
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
        == config.kwargs["BLOCK_K"]
    )
    # TF32 is allowed only if globally enabled, and (under force_same_precision)
    # only for shapes cuBLAS would also use TF32 on (m,n % 16 == 0, k % 8 == 0).
    allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
        not inductor_config.force_same_precision
        or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ALLOW_TF32=allow_tf32,
        ACC_TYPE=acc_type(layout.dtype),
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def mm_args(mat1, mat2, *others, layout=None, out_dtype=None, use_4x2_dim=False):
    """
    Common arg processing for mm,bmm,addmm,etc
    """
    mat1, mat2 = realize_inputs(mat1, mat2)
    # Split off any leading batch dims: mat1 is [*b1, m, k1], mat2 is [*b2, k2, n].
    *b1, m, k1 = mat1.get_size()
    *b2, k2, n = mat2.get_size()
    # Guard that the batch dims agree pairwise.
    b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
    if use_4x2_dim:
        # Packed 4-bit case: mat2's K dim holds two values per element.
        k2 = k2 * 2
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        if out_dtype is None:
            out_dtype = mat1.get_dtype()
        layout = FixedLayout(
            mat1.get_device(),
            out_dtype,
            [*b, m, n],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."

    # Imported locally to avoid a circular import with ..lowering.
    from ..lowering import expand

    # Extra inputs (e.g. addmm's bias) are broadcast to the output shape.
    others = [realize_inputs(expand(x, layout.size)) for x in others]

    return [m, n, k, layout, mat1, mat2, *others]
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def addmm_epilogue(dtype, alpha, beta):
    """Build an epilogue computing ``alpha * acc + beta * bias`` for addmm templates.

    Scaling is skipped entirely when the corresponding factor is exactly 1.
    """

    def epilogue(acc, bias):
        scaled_acc = (
            acc if alpha == 1 else V.ops.mul(acc, V.ops.constant(alpha, dtype))
        )
        scaled_bias = (
            bias if beta == 1 else V.ops.mul(bias, V.ops.constant(beta, dtype))
        )
        return V.ops.add(scaled_acc, scaled_bias)

    return epilogue
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ..lowering import lowerings
|
| 6 |
+
from ..select_algorithm import (
|
| 7 |
+
autotune_select_algorithm,
|
| 8 |
+
ExternKernelChoice,
|
| 9 |
+
TritonTemplate,
|
| 10 |
+
)
|
| 11 |
+
from ..utils import use_aten_gemm_kernels, use_triton_template
|
| 12 |
+
from ..virtualized import V
|
| 13 |
+
from .mm_common import mm_args, mm_grid, mm_options
|
| 14 |
+
|
| 15 |
+
aten = torch.ops.aten
|
| 16 |
+
|
| 17 |
+
# Extern fallback computing A@B + C@D via the fused inductor C++ op.
aten_mm_plus_mm = ExternKernelChoice(
    torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm"
)
|
| 20 |
+
|
| 21 |
+
# Triton template computing acc = A@B + C@D into one output tile.
# NOTE(review): both K loops are bounded by K1 (A's reduction dim) — this
# presumably relies on K1 == K2 being guarded upstream; confirm at call sites.
mm_plus_mm_template = TritonTemplate(
    name="mm_plus_mm",
    grid=mm_grid,
    debug=False,
    source=r"""
{{def_kernel("A", "B", "C", "D")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K1 = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    # K2 = {{size("C", 1)}}
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}
    stride_cm = {{stride("C", 0)}}
    stride_ck = {{stride("C", 1)}}
    stride_dk = {{stride("D", 0)}}
    stride_dn = {{stride("D", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
    D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k1 in range(K1, 0, -BLOCK_K):
        # First matmul with A @ B
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k1, other=0.)
            b = tl.load(B, mask=rk[:, None] < k1, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    for k2 in range(K1, 0, -BLOCK_K):

        # Second matmul with C @ D
        if EVEN_K:
            c = tl.load(C)
            d = tl.load(D)
        else:
            c = tl.load(C, mask=rk[None, :] < k2, other=0.)
            d = tl.load(D, mask=rk[:, None] < k2, other=0.)
        acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
        C += BLOCK_K * stride_ck
        D += BLOCK_K * stride_dk


    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@functools.lru_cache(None)
|
| 103 |
+
def mm_configs():
|
| 104 |
+
import triton
|
| 105 |
+
|
| 106 |
+
# List of dictionaries to store the kernel configs. Configs that evaluate to true
|
| 107 |
+
# will be utilised on the target platform
|
| 108 |
+
mm_triton_configs = [
|
| 109 |
+
{
|
| 110 |
+
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
|
| 111 |
+
"num_stages": 2,
|
| 112 |
+
"num_warps": 4,
|
| 113 |
+
"cond": True,
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
|
| 117 |
+
"num_stages": 3,
|
| 118 |
+
"num_warps": 8,
|
| 119 |
+
"cond": True,
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
|
| 123 |
+
"num_stages": 4,
|
| 124 |
+
"num_warps": 16,
|
| 125 |
+
"cond": True,
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
|
| 129 |
+
"num_stages": 4,
|
| 130 |
+
"num_warps": 8,
|
| 131 |
+
"cond": True,
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
|
| 135 |
+
"num_stages": 4,
|
| 136 |
+
"num_warps": 8,
|
| 137 |
+
"cond": True,
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
|
| 141 |
+
"num_stages": 1,
|
| 142 |
+
"num_warps": 8,
|
| 143 |
+
"cond": True,
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
|
| 147 |
+
"num_stages": 1,
|
| 148 |
+
"num_warps": 8,
|
| 149 |
+
"cond": True,
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
|
| 153 |
+
"num_stages": 1,
|
| 154 |
+
"num_warps": 8,
|
| 155 |
+
"cond": torch.version.hip is None,
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
|
| 159 |
+
"num_stages": 2,
|
| 160 |
+
"num_warps": 4,
|
| 161 |
+
"cond": True,
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
|
| 165 |
+
"num_stages": 1,
|
| 166 |
+
"num_warps": 2,
|
| 167 |
+
"cond": True,
|
| 168 |
+
},
|
| 169 |
+
]
|
| 170 |
+
|
| 171 |
+
# Filter out configs in which cond evaluates to true
|
| 172 |
+
# On ROCm convert num_stages to 1 as pipelining provides no benefit
|
| 173 |
+
if torch.version.hip:
|
| 174 |
+
filtered_configs = [
|
| 175 |
+
triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
|
| 176 |
+
for c in mm_triton_configs
|
| 177 |
+
if c["cond"]
|
| 178 |
+
]
|
| 179 |
+
else:
|
| 180 |
+
filtered_configs = [
|
| 181 |
+
triton.Config(
|
| 182 |
+
c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
|
| 183 |
+
)
|
| 184 |
+
for c in mm_triton_configs
|
| 185 |
+
if c["cond"]
|
| 186 |
+
]
|
| 187 |
+
|
| 188 |
+
return filtered_configs
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
| 192 |
+
"""
|
| 193 |
+
Computes mm(mat1, mat2) + mm(mat3, mat4)
|
| 194 |
+
"""
|
| 195 |
+
m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
|
| 196 |
+
m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
|
| 197 |
+
# Optimization is optional, because we can always just not do the fusion
|
| 198 |
+
if (
|
| 199 |
+
m1 * n1 == 0
|
| 200 |
+
or m2 * n2 == 0
|
| 201 |
+
or not V.graph.sizevars.statically_known_list_equals(
|
| 202 |
+
mat1.get_size(), mat3.get_size()
|
| 203 |
+
)
|
| 204 |
+
or not V.graph.sizevars.statically_known_list_equals(
|
| 205 |
+
mat2.get_size(), mat4.get_size()
|
| 206 |
+
)
|
| 207 |
+
):
|
| 208 |
+
# TODO(jansel): support different K values when this is fixed:
|
| 209 |
+
# https://github.com/openai/triton/issues/967
|
| 210 |
+
return lowerings[aten.add](
|
| 211 |
+
lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4)
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
assert layout1 == layout2
|
| 215 |
+
# options to tune from
|
| 216 |
+
choices = (
|
| 217 |
+
[aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)]
|
| 218 |
+
if use_aten_gemm_kernels()
|
| 219 |
+
else []
|
| 220 |
+
)
|
| 221 |
+
if use_triton_template(layout1):
|
| 222 |
+
for config in mm_configs():
|
| 223 |
+
# see https://github.com/openai/triton/issues/1298
|
| 224 |
+
# BLOCK_K = K causes llvm error
|
| 225 |
+
if config.kwargs["BLOCK_K"] < k1:
|
| 226 |
+
mm_plus_mm_template.maybe_append_choice(
|
| 227 |
+
choices,
|
| 228 |
+
input_nodes=(mat1, mat2, mat3, mat4),
|
| 229 |
+
layout=layout1,
|
| 230 |
+
**mm_options(config, m1, n1, k1, layout1),
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
return autotune_select_algorithm(
|
| 234 |
+
"mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1
|
| 235 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
#include <ATen/core/TensorBase.h>
|
| 5 |
+
|
| 6 |
+
namespace at::detail {
|
| 7 |
+
|
| 8 |
+
C10_EXPORT TensorBase empty_mps(
|
| 9 |
+
IntArrayRef size,
|
| 10 |
+
c10::optional<ScalarType> dtype_opt,
|
| 11 |
+
c10::optional<Layout> layout_opt,
|
| 12 |
+
c10::optional<Device> device_opt,
|
| 13 |
+
c10::optional<bool> pin_memory_opt,
|
| 14 |
+
c10::optional<c10::MemoryFormat> memory_format_opt);
|
| 15 |
+
C10_EXPORT TensorBase empty_mps(
|
| 16 |
+
IntArrayRef size, const TensorOptions &options);
|
| 17 |
+
|
| 18 |
+
C10_EXPORT TensorBase empty_strided_mps(
|
| 19 |
+
IntArrayRef size,
|
| 20 |
+
IntArrayRef stride,
|
| 21 |
+
ScalarType dtype,
|
| 22 |
+
c10::optional<Device> device_opt);
|
| 23 |
+
|
| 24 |
+
C10_EXPORT TensorBase empty_strided_mps(
|
| 25 |
+
IntArrayRef size,
|
| 26 |
+
IntArrayRef stride,
|
| 27 |
+
const TensorOptions &options);
|
| 28 |
+
|
| 29 |
+
} // namespace at::detail
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
namespace at::mps {
|
| 4 |
+
|
| 5 |
+
static const char * indexing_metal_shaders = R"INDEX_METAL(
|
| 6 |
+
#include <metal_stdlib>
|
| 7 |
+
#include <metal_atomic>
|
| 8 |
+
|
| 9 |
+
using namespace metal;
|
| 10 |
+
|
| 11 |
+
#if __METAL_VERSION__ < 300
|
| 12 |
+
struct IndexAB {
|
| 13 |
+
// Allow up to 16 indices
|
| 14 |
+
metal::array<constant void *, 16> indexArray [[ id(0) ]];
|
| 15 |
+
};
|
| 16 |
+
#else
|
| 17 |
+
struct IndexAB {
|
| 18 |
+
constant int64_t* indexArray;
|
| 19 |
+
};
|
| 20 |
+
|
| 21 |
+
#endif
|
| 22 |
+
|
| 23 |
+
template<typename T, typename OffsetsT>
|
| 24 |
+
kernel void index_select(
|
| 25 |
+
#if __METAL_VERSION__ >= 300
|
| 26 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 27 |
+
#else
|
| 28 |
+
constant IndexAB & indexAB [[buffer(0)]],
|
| 29 |
+
#endif
|
| 30 |
+
constant void * indexSizes [[buffer(1)]],
|
| 31 |
+
constant void * indexStrides [[buffer(2)]],
|
| 32 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 33 |
+
constant void * inputData [[buffer(4)]],
|
| 34 |
+
device void * outputData [[buffer(5)]],
|
| 35 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 36 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 37 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 38 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 39 |
+
int64_t offset = 0;
|
| 40 |
+
for (uint32_t i = 0; i < num_indices; i++) {
|
| 41 |
+
#if __METAL_VERSION__ >= 300
|
| 42 |
+
constant int64_t* indexArray = indexAB[i].indexArray;
|
| 43 |
+
#else
|
| 44 |
+
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
| 45 |
+
#endif
|
| 46 |
+
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
| 47 |
+
if (index < 0) {
|
| 48 |
+
index += index_sizes[i];
|
| 49 |
+
}
|
| 50 |
+
offset += index * index_strides[i];
|
| 51 |
+
}
|
| 52 |
+
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
|
| 53 |
+
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
|
| 54 |
+
*out = *in;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
template<typename T, typename OffsetsT>
|
| 58 |
+
void index_put_impl(
|
| 59 |
+
#if __METAL_VERSION__ >= 300
|
| 60 |
+
constant IndexAB * indexAB,
|
| 61 |
+
#else
|
| 62 |
+
constant IndexAB & indexAB,
|
| 63 |
+
#endif
|
| 64 |
+
constant int64_t * index_sizes,
|
| 65 |
+
constant int64_t * index_strides,
|
| 66 |
+
constant OffsetsT * offsets,
|
| 67 |
+
constant void * inputData,
|
| 68 |
+
device void * outputData,
|
| 69 |
+
constant uint32_t & num_indices,
|
| 70 |
+
uint thread_index) {
|
| 71 |
+
int64_t offset = 0;
|
| 72 |
+
for (uint32_t i = 0; i < num_indices; i++) {
|
| 73 |
+
#if __METAL_VERSION__ >= 300
|
| 74 |
+
constant int64_t* indexArray = indexAB[i].indexArray;
|
| 75 |
+
#else
|
| 76 |
+
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
| 77 |
+
#endif
|
| 78 |
+
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
| 79 |
+
|
| 80 |
+
if (index < 0) {
|
| 81 |
+
index += index_sizes[i];
|
| 82 |
+
}
|
| 83 |
+
offset += index * index_strides[i];
|
| 84 |
+
}
|
| 85 |
+
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
|
| 86 |
+
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
|
| 87 |
+
*out = *in;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
template<typename T, typename OffsetsT>
|
| 91 |
+
kernel void index_put_serial(
|
| 92 |
+
#if __METAL_VERSION__ >= 300
|
| 93 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 94 |
+
#else
|
| 95 |
+
constant IndexAB & indexAB [[buffer(0)]],
|
| 96 |
+
#endif
|
| 97 |
+
constant void * indexSizes [[buffer(1)]],
|
| 98 |
+
constant void * indexStrides [[buffer(2)]],
|
| 99 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 100 |
+
constant void * inputData [[buffer(4)]],
|
| 101 |
+
device void * outputData [[buffer(5)]],
|
| 102 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 103 |
+
constant uint * numIters [[buffer(7)]],
|
| 104 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 105 |
+
|
| 106 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 107 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 108 |
+
|
| 109 |
+
for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
|
| 110 |
+
index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
template<typename T, typename OffsetsT>
|
| 115 |
+
kernel void index_put(
|
| 116 |
+
#if __METAL_VERSION__ >= 300
|
| 117 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 118 |
+
#else
|
| 119 |
+
constant IndexAB & indexAB [[buffer(0)]],
|
| 120 |
+
#endif
|
| 121 |
+
constant void * indexSizes [[buffer(1)]],
|
| 122 |
+
constant void * indexStrides [[buffer(2)]],
|
| 123 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 124 |
+
constant void * inputData [[buffer(4)]],
|
| 125 |
+
device void * outputData [[buffer(5)]],
|
| 126 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 127 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 128 |
+
|
| 129 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 130 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 131 |
+
index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
#if __METAL_VERSION__ < 300
|
| 135 |
+
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
|
| 136 |
+
template \
|
| 137 |
+
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
|
| 138 |
+
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
|
| 139 |
+
constant IndexAB & indexAB [[buffer(0)]], \
|
| 140 |
+
constant void * indexSizes [[buffer(1)]], \
|
| 141 |
+
constant void * indexStrides [[buffer(2)]], \
|
| 142 |
+
constant IDX_DTYPE * offsets [[buffer(3)]], \
|
| 143 |
+
constant void * inputData [[buffer(4)]], \
|
| 144 |
+
device void * outputData [[buffer(5)]], \
|
| 145 |
+
constant uint32_t & num_indices [[buffer(6)]], \
|
| 146 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 147 |
+
#else
|
| 148 |
+
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
|
| 149 |
+
template \
|
| 150 |
+
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
|
| 151 |
+
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
|
| 152 |
+
constant IndexAB * indexAB [[buffer(0)]], \
|
| 153 |
+
constant void * indexSizes [[buffer(1)]], \
|
| 154 |
+
constant void * indexStrides [[buffer(2)]], \
|
| 155 |
+
constant IDX_DTYPE * offsets [[buffer(3)]], \
|
| 156 |
+
constant void * inputData [[buffer(4)]], \
|
| 157 |
+
device void * outputData [[buffer(5)]], \
|
| 158 |
+
constant uint32_t & num_indices [[buffer(6)]], \
|
| 159 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 160 |
+
#endif
|
| 161 |
+
|
| 162 |
+
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
|
| 163 |
+
REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
|
| 164 |
+
REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
|
| 165 |
+
REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
|
| 166 |
+
REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
|
| 167 |
+
REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
|
| 168 |
+
REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
|
| 169 |
+
REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
|
| 170 |
+
REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
|
| 171 |
+
|
| 172 |
+
REGISTER_INDEX_OP_ALL_DTYPES(select);
|
| 173 |
+
REGISTER_INDEX_OP_ALL_DTYPES(put);
|
| 174 |
+
|
| 175 |
+
#if __METAL_VERSION__ < 300
|
| 176 |
+
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
|
| 177 |
+
template \
|
| 178 |
+
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
|
| 179 |
+
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
|
| 180 |
+
constant IndexAB & indexAB [[buffer(0)]], \
|
| 181 |
+
constant void * indexSizes [[buffer(1)]], \
|
| 182 |
+
constant void * indexStrides [[buffer(2)]], \
|
| 183 |
+
constant IDX_DTYPE * offsets [[buffer(3)]], \
|
| 184 |
+
constant void * inputData [[buffer(4)]], \
|
| 185 |
+
device void * outputData [[buffer(5)]], \
|
| 186 |
+
constant uint32_t & num_indices [[buffer(6)]], \
|
| 187 |
+
constant uint * numIters [[buffer(7)]], \
|
| 188 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 189 |
+
#else
|
| 190 |
+
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
|
| 191 |
+
template \
|
| 192 |
+
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
|
| 193 |
+
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
|
| 194 |
+
constant IndexAB * indexAB [[buffer(0)]], \
|
| 195 |
+
constant void * indexSizes [[buffer(1)]], \
|
| 196 |
+
constant void * indexStrides [[buffer(2)]], \
|
| 197 |
+
constant IDX_DTYPE * offsets [[buffer(3)]], \
|
| 198 |
+
constant void * inputData [[buffer(4)]], \
|
| 199 |
+
device void * outputData [[buffer(5)]], \
|
| 200 |
+
constant uint32_t & num_indices [[buffer(6)]], \
|
| 201 |
+
constant uint * numIters [[buffer(7)]], \
|
| 202 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 203 |
+
#endif
|
| 204 |
+
|
| 205 |
+
#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
|
| 206 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
|
| 207 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
|
| 208 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
|
| 209 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
|
| 210 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
|
| 211 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
|
| 212 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
|
| 213 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
|
| 214 |
+
|
| 215 |
+
REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
|
| 216 |
+
|
| 217 |
+
template<typename StridesT, typename DataT>
|
| 218 |
+
kernel void kernel_index_offsets(constant StridesT * strides [[buffer(0)]],
|
| 219 |
+
device DataT * data_offsets [[buffer(1)]],
|
| 220 |
+
constant uint * iter_shape [[buffer(2)]],
|
| 221 |
+
constant uint & num_dimensions [[buffer(3)]],
|
| 222 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 223 |
+
data_offsets[thread_index] = 0;
|
| 224 |
+
uint32_t idx = thread_index;
|
| 225 |
+
for (uint32_t dim = 0; dim < num_dimensions; dim++) {
|
| 226 |
+
uint32_t remainder = idx % iter_shape[dim];
|
| 227 |
+
idx /= iter_shape[dim];
|
| 228 |
+
|
| 229 |
+
data_offsets[thread_index] += remainder * DataT(strides[dim]);
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
template
|
| 234 |
+
[[host_name("kernel_index_offsets_32")]]
|
| 235 |
+
kernel void kernel_index_offsets<packed_uint3, uint3>(
|
| 236 |
+
constant packed_uint3 * strides [[buffer(0)]],
|
| 237 |
+
device uint3 * data_offsets [[buffer(1)]],
|
| 238 |
+
constant uint * iter_shape [[buffer(2)]],
|
| 239 |
+
constant uint & num_dimensions [[buffer(3)]],
|
| 240 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 241 |
+
|
| 242 |
+
template
|
| 243 |
+
[[host_name("kernel_index_offsets_64")]]
|
| 244 |
+
kernel void kernel_index_offsets<packed_uint3, ulong3>(
|
| 245 |
+
constant packed_uint3 * strides [[buffer(0)]],
|
| 246 |
+
device ulong3 * data_offsets [[buffer(1)]],
|
| 247 |
+
constant uint * iter_shape [[buffer(2)]],
|
| 248 |
+
constant uint & num_dimensions [[buffer(3)]],
|
| 249 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 250 |
+
|
| 251 |
+
template<typename T, typename E, typename OffsetsT>
|
| 252 |
+
kernel void index_put_accumulate_native_dtypes(
|
| 253 |
+
#if __METAL_VERSION__ >= 300
|
| 254 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 255 |
+
#else
|
| 256 |
+
constant IndexAB & indexAB [[buffer(0)]],
|
| 257 |
+
#endif
|
| 258 |
+
constant void * indexSizes [[buffer(1)]],
|
| 259 |
+
constant void * indexStrides [[buffer(2)]],
|
| 260 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 261 |
+
constant void * inputData [[buffer(4)]],
|
| 262 |
+
device void * outputData [[buffer(5)]],
|
| 263 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 264 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 265 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 266 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 267 |
+
int64_t offset = 0;
|
| 268 |
+
for (uint32_t i = 0; i < num_indices; i++) {
|
| 269 |
+
#if __METAL_VERSION__ >= 300
|
| 270 |
+
constant int64_t* indexArray = indexAB[i].indexArray;
|
| 271 |
+
#else
|
| 272 |
+
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
| 273 |
+
#endif
|
| 274 |
+
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
| 275 |
+
if (index < 0) {
|
| 276 |
+
index += index_sizes[i];
|
| 277 |
+
}
|
| 278 |
+
offset += index * index_strides[i];
|
| 279 |
+
}
|
| 280 |
+
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
|
| 281 |
+
constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
|
| 282 |
+
atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
template<typename T>
|
| 286 |
+
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
|
| 287 |
+
device atomic_uint* uintAddr = (device atomic_uint*)addr;
|
| 288 |
+
uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
|
| 289 |
+
T updated = as_type<T>(expected) + value;
|
| 290 |
+
while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
|
| 291 |
+
updated = as_type<T>(expected) + value;
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
template<typename T, typename OffsetsT>
|
| 296 |
+
kernel void atomic_index_put_accumulate(
|
| 297 |
+
#if __METAL_VERSION__ >= 300
|
| 298 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 299 |
+
#else
|
| 300 |
+
constant IndexAB & indexAB [[buffer(0)]],
|
| 301 |
+
#endif
|
| 302 |
+
constant void * indexSizes [[buffer(1)]],
|
| 303 |
+
constant void * indexStrides [[buffer(2)]],
|
| 304 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 305 |
+
constant void * inputData [[buffer(4)]],
|
| 306 |
+
device void * outputData [[buffer(5)]],
|
| 307 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 308 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 309 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 310 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 311 |
+
int64_t offset = 0;
|
| 312 |
+
for (uint32_t i = 0; i < num_indices; i++) {
|
| 313 |
+
#if __METAL_VERSION__ >= 300
|
| 314 |
+
constant int64_t* indexArray = indexAB[i].indexArray;
|
| 315 |
+
#else
|
| 316 |
+
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
| 317 |
+
#endif
|
| 318 |
+
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
| 319 |
+
if (index < 0) {
|
| 320 |
+
index += index_sizes[i];
|
| 321 |
+
}
|
| 322 |
+
offset += index * index_strides[i];
|
| 323 |
+
}
|
| 324 |
+
device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
|
| 325 |
+
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
|
| 326 |
+
atomic_fetch_add_relaxed<T>(out, *in);
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
template
|
| 330 |
+
[[host_name("index_put_accumulate_32bit_float_idx32")]]
|
| 331 |
+
kernel void atomic_index_put_accumulate<float, uint3>(
|
| 332 |
+
#if __METAL_VERSION__ >= 300
|
| 333 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 334 |
+
#else
|
| 335 |
+
constant IndexAB & indexAB [[buffer(0)]],
|
| 336 |
+
#endif
|
| 337 |
+
constant void * indexSizes [[buffer(1)]],
|
| 338 |
+
constant void * indexStrides [[buffer(2)]],
|
| 339 |
+
constant uint3 * offsets [[buffer(3)]],
|
| 340 |
+
constant void * inputData [[buffer(4)]],
|
| 341 |
+
device void * outputData [[buffer(5)]],
|
| 342 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 343 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 344 |
+
|
| 345 |
+
template
|
| 346 |
+
[[host_name("index_put_accumulate_32bit_float_idx64")]]
|
| 347 |
+
kernel void atomic_index_put_accumulate<float, ulong3>(
|
| 348 |
+
#if __METAL_VERSION__ >= 300
|
| 349 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 350 |
+
#else
|
| 351 |
+
constant IndexAB & indexAB [[buffer(0)]],
|
| 352 |
+
#endif
|
| 353 |
+
constant void * indexSizes [[buffer(1)]],
|
| 354 |
+
constant void * indexStrides [[buffer(2)]],
|
| 355 |
+
constant ulong3 * offsets [[buffer(3)]],
|
| 356 |
+
constant void * inputData [[buffer(4)]],
|
| 357 |
+
device void * outputData [[buffer(5)]],
|
| 358 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 359 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 360 |
+
|
| 361 |
+
template
|
| 362 |
+
[[host_name("index_put_accumulate_32bit_int_idx32")]]
|
| 363 |
+
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
|
| 364 |
+
#if __METAL_VERSION__ >= 300
|
| 365 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 366 |
+
#else
|
| 367 |
+
constant IndexAB & indexAB [[buffer(0)]],
|
| 368 |
+
#endif
|
| 369 |
+
constant void * indexSizes [[buffer(1)]],
|
| 370 |
+
constant void * indexStrides [[buffer(2)]],
|
| 371 |
+
constant uint3 * offsets [[buffer(3)]],
|
| 372 |
+
constant void * inputData [[buffer(4)]],
|
| 373 |
+
device void * outputData [[buffer(5)]],
|
| 374 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 375 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 376 |
+
|
| 377 |
+
template
|
| 378 |
+
[[host_name("index_put_accumulate_32bit_int_idx64")]]
|
| 379 |
+
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
|
| 380 |
+
#if __METAL_VERSION__ >= 300
|
| 381 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 382 |
+
#else
|
| 383 |
+
constant IndexAB & indexAB [[buffer(0)]],
|
| 384 |
+
#endif
|
| 385 |
+
constant void * indexSizes [[buffer(1)]],
|
| 386 |
+
constant void * indexStrides [[buffer(2)]],
|
| 387 |
+
constant ulong3 * offsets [[buffer(3)]],
|
| 388 |
+
constant void * inputData [[buffer(4)]],
|
| 389 |
+
device void * outputData [[buffer(5)]],
|
| 390 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 391 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 392 |
+
)INDEX_METAL";
|
| 393 |
+
|
| 394 |
+
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
|
| 395 |
+
struct __attribute__ ((packed)) packed_uint5{{
|
| 396 |
+
uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
|
| 397 |
+
}};
|
| 398 |
+
|
| 399 |
+
template<typename Y, typename X>
|
| 400 |
+
Y cast(const X x);
|
| 401 |
+
|
| 402 |
+
template<>
|
| 403 |
+
{1} cast<{1}, {0}>(const {0} x) {{
|
| 404 |
+
return {2};
|
| 405 |
+
}}
|
| 406 |
+
|
| 407 |
+
kernel void scatter_kernel_5(uint linear_index [[thread_position_in_grid]],
|
| 408 |
+
constant void * src_ [[buffer(0)]],
|
| 409 |
+
device void * dst_ [[buffer(1)]],
|
| 410 |
+
constant packed_uint5 & size [[buffer(2)]],
|
| 411 |
+
constant packed_uint5 & stride [[buffer(3)]],
|
| 412 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 413 |
+
if (linear_index >= numel) return;
|
| 414 |
+
|
| 415 |
+
constant {0} * src = (constant {0} *)src_;
|
| 416 |
+
device {1} * dst = (device {1} *)dst_;
|
| 417 |
+
|
| 418 |
+
packed_uint5 local_index;
|
| 419 |
+
local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
|
| 420 |
+
local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
|
| 421 |
+
local_index.z = linear_index / (size.u * size.w) % size.z;
|
| 422 |
+
local_index.w = linear_index / size.u % size.w;
|
| 423 |
+
local_index.u = linear_index % size.u;
|
| 424 |
+
|
| 425 |
+
packed_uint5 strided_index;
|
| 426 |
+
strided_index.x = local_index.x * stride.x;
|
| 427 |
+
strided_index.y = local_index.y * stride.y;
|
| 428 |
+
strided_index.z = local_index.z * stride.z;
|
| 429 |
+
strided_index.w = local_index.w * stride.w;
|
| 430 |
+
strided_index.u = local_index.u * stride.u;
|
| 431 |
+
|
| 432 |
+
dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
|
| 433 |
+
}}
|
| 434 |
+
|
| 435 |
+
kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
|
| 436 |
+
constant void * src_ [[buffer(0)]],
|
| 437 |
+
device void * dst_ [[buffer(1)]],
|
| 438 |
+
constant packed_uint4 & size [[buffer(2)]],
|
| 439 |
+
constant packed_uint4 & stride [[buffer(3)]],
|
| 440 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 441 |
+
if (linear_index >= numel) return;
|
| 442 |
+
|
| 443 |
+
constant {0} * src = (constant {0} *)src_;
|
| 444 |
+
device {1} * dst = (device {1} *)dst_;
|
| 445 |
+
|
| 446 |
+
packed_uint4 local_index;
|
| 447 |
+
local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
|
| 448 |
+
local_index.y = linear_index / (size[3] * size[2]) % size[1];
|
| 449 |
+
local_index.z = linear_index / size[3] % size[2];
|
| 450 |
+
local_index.w = linear_index % size[3];
|
| 451 |
+
|
| 452 |
+
const packed_uint4 strided_index = local_index * stride;
|
| 453 |
+
dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
|
| 454 |
+
}}
|
| 455 |
+
|
| 456 |
+
kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]],
|
| 457 |
+
constant void * src_ [[buffer(0)]],
|
| 458 |
+
device void * dst_ [[buffer(1)]],
|
| 459 |
+
constant packed_uint3 & size [[buffer(2)]],
|
| 460 |
+
constant packed_uint3 & stride [[buffer(3)]],
|
| 461 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 462 |
+
if (linear_index >= numel) return;
|
| 463 |
+
|
| 464 |
+
constant {0} * src = (constant {0} *)src_;
|
| 465 |
+
device {1} * dst = (device {1} *)dst_;
|
| 466 |
+
|
| 467 |
+
packed_uint3 local_index;
|
| 468 |
+
local_index.x = linear_index / (size[2] * size[1]) % size[0];
|
| 469 |
+
local_index.y = linear_index / size[2] % size[1];
|
| 470 |
+
local_index.z = linear_index % size[2];
|
| 471 |
+
|
| 472 |
+
const packed_uint3 strided_index = local_index * stride;
|
| 473 |
+
dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
|
| 474 |
+
}}
|
| 475 |
+
|
| 476 |
+
kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]],
|
| 477 |
+
constant void * src_ [[buffer(0)]],
|
| 478 |
+
device void * dst_ [[buffer(1)]],
|
| 479 |
+
constant packed_uint2 & size [[buffer(2)]],
|
| 480 |
+
constant packed_uint2 & stride [[buffer(3)]],
|
| 481 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 482 |
+
if (linear_index >= numel) return;
|
| 483 |
+
|
| 484 |
+
constant {0} * src = (constant {0} *)src_;
|
| 485 |
+
device {1} * dst = (device {1} *)dst_;
|
| 486 |
+
|
| 487 |
+
packed_uint2 local_index;
|
| 488 |
+
local_index.x = linear_index / size[1] % size[0];
|
| 489 |
+
local_index.y = linear_index % size[1];
|
| 490 |
+
|
| 491 |
+
const packed_uint2 strided_index = local_index * stride;
|
| 492 |
+
dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
|
| 493 |
+
}}
|
| 494 |
+
|
| 495 |
+
kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]],
|
| 496 |
+
constant void * src_ [[buffer(0)]],
|
| 497 |
+
device void * dst_ [[buffer(1)]],
|
| 498 |
+
constant int & size [[buffer(2)]],
|
| 499 |
+
constant int & stride [[buffer(3)]],
|
| 500 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 501 |
+
if (linear_index >= numel) return;
|
| 502 |
+
|
| 503 |
+
constant {0} * src = (constant {0} *)src_;
|
| 504 |
+
device {1} * dst = (device {1} *)dst_;
|
| 505 |
+
|
| 506 |
+
const int local_index = linear_index % size;
|
| 507 |
+
const int strided_index = local_index * stride;
|
| 508 |
+
dst[strided_index] = cast<{1}>(src[linear_index]);
|
| 509 |
+
}}
|
| 510 |
+
)METAL_SCATTER";
|
| 511 |
+
|
| 512 |
+
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
|
| 513 |
+
struct __attribute__ ((packed)) packed_uint5{{
|
| 514 |
+
uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
|
| 515 |
+
}};
|
| 516 |
+
|
| 517 |
+
template<typename Y, typename X>
|
| 518 |
+
Y cast(const X x);
|
| 519 |
+
|
| 520 |
+
template<>
|
| 521 |
+
{1} cast<{1}, {0}>(const {0} x) {{
|
| 522 |
+
return {2};
|
| 523 |
+
}}
|
| 524 |
+
|
| 525 |
+
kernel void gather_kernel_5(uint linear_index [[thread_position_in_grid]],
|
| 526 |
+
constant void * src_ [[buffer(0)]],
|
| 527 |
+
device void * dst_ [[buffer(1)]],
|
| 528 |
+
constant packed_uint5 & size [[buffer(2)]],
|
| 529 |
+
constant packed_uint5 & stride [[buffer(3)]],
|
| 530 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 531 |
+
if (linear_index >= numel) return;
|
| 532 |
+
|
| 533 |
+
constant {0} * src = (constant {0} *)src_;
|
| 534 |
+
device {1} * dst = (device {1} *)dst_;
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
packed_uint5 local_index;
|
| 538 |
+
local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
|
| 539 |
+
local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
|
| 540 |
+
local_index.z = linear_index / (size.u * size.w) % size.z;
|
| 541 |
+
local_index.w = linear_index / size.u % size.w;
|
| 542 |
+
local_index.u = linear_index % size.u;
|
| 543 |
+
|
| 544 |
+
packed_uint5 strided_index;
|
| 545 |
+
strided_index.x = local_index.x * stride.x;
|
| 546 |
+
strided_index.y = local_index.y * stride.y;
|
| 547 |
+
strided_index.z = local_index.z * stride.z;
|
| 548 |
+
strided_index.w = local_index.w * stride.w;
|
| 549 |
+
strided_index.u = local_index.u * stride.u;
|
| 550 |
+
|
| 551 |
+
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
|
| 552 |
+
}}
|
| 553 |
+
|
| 554 |
+
kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
|
| 555 |
+
constant void * src_ [[buffer(0)]],
|
| 556 |
+
device void * dst_ [[buffer(1)]],
|
| 557 |
+
constant packed_uint4 & size [[buffer(2)]],
|
| 558 |
+
constant packed_uint4 & stride [[buffer(3)]],
|
| 559 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 560 |
+
if (linear_index >= numel) return;
|
| 561 |
+
|
| 562 |
+
constant {0} * src = (constant {0} *)src_;
|
| 563 |
+
device {1} * dst = (device {1} *)dst_;
|
| 564 |
+
|
| 565 |
+
packed_uint4 local_index;
|
| 566 |
+
local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
|
| 567 |
+
local_index.y = linear_index / (size[3] * size[2]) % size[1];
|
| 568 |
+
local_index.z = linear_index / size[3] % size[2];
|
| 569 |
+
local_index.w = linear_index % size[3];
|
| 570 |
+
|
| 571 |
+
const packed_uint4 strided_index = local_index * stride;
|
| 572 |
+
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
|
| 573 |
+
}}
|
| 574 |
+
|
| 575 |
+
kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]],
|
| 576 |
+
constant void * src_ [[buffer(0)]],
|
| 577 |
+
device void * dst_ [[buffer(1)]],
|
| 578 |
+
constant packed_uint3 & size [[buffer(2)]],
|
| 579 |
+
constant packed_uint3 & stride [[buffer(3)]],
|
| 580 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 581 |
+
if (linear_index >= numel) return;
|
| 582 |
+
|
| 583 |
+
constant {0} * src = (constant {0} *)src_;
|
| 584 |
+
device {1} * dst = (device {1} *)dst_;
|
| 585 |
+
|
| 586 |
+
packed_uint3 local_index;
|
| 587 |
+
local_index.x = linear_index / (size[2] * size[1]) % size[0];
|
| 588 |
+
local_index.y = linear_index / size[2] % size[1];
|
| 589 |
+
local_index.z = linear_index % size[2];
|
| 590 |
+
|
| 591 |
+
const packed_uint3 strided_index = local_index * stride;
|
| 592 |
+
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
|
| 593 |
+
}}
|
| 594 |
+
|
| 595 |
+
kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]],
|
| 596 |
+
constant void * src_ [[buffer(0)]],
|
| 597 |
+
device void * dst_ [[buffer(1)]],
|
| 598 |
+
constant packed_uint2 & size [[buffer(2)]],
|
| 599 |
+
constant packed_uint2 & stride [[buffer(3)]],
|
| 600 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 601 |
+
if (linear_index >= numel) return;
|
| 602 |
+
|
| 603 |
+
constant {0} * src = (constant {0} *)src_;
|
| 604 |
+
device {1} * dst = (device {1} *)dst_;
|
| 605 |
+
|
| 606 |
+
packed_uint2 local_index;
|
| 607 |
+
local_index.x = linear_index / size[1] % size[0];
|
| 608 |
+
local_index.y = linear_index % size[1];
|
| 609 |
+
|
| 610 |
+
const packed_uint2 strided_index = local_index * stride;
|
| 611 |
+
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
|
| 612 |
+
}}
|
| 613 |
+
|
| 614 |
+
kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]],
|
| 615 |
+
constant void * src_ [[buffer(0)]],
|
| 616 |
+
device void * dst_ [[buffer(1)]],
|
| 617 |
+
constant int & size [[buffer(2)]],
|
| 618 |
+
constant int & stride [[buffer(3)]],
|
| 619 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 620 |
+
if (linear_index >= numel) return;
|
| 621 |
+
|
| 622 |
+
constant {0} * src = (constant {0} *)src_;
|
| 623 |
+
device {1} * dst = (device {1} *)dst_;
|
| 624 |
+
|
| 625 |
+
const int local_index = linear_index % size;
|
| 626 |
+
const int strided_index = local_index * stride;
|
| 627 |
+
dst[linear_index] = cast<{1}>(src[strided_index]);
|
| 628 |
+
}}
|
| 629 |
+
)METAL_GATHER";
|
| 630 |
+
} // namespace at::mps
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
| 5 |
+
#include <c10/macros/Macros.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
#include <ATen/Context.h>
|
| 8 |
+
#include <ATen/mps/MPSStream.h>
|
| 9 |
+
#include <ATen/mps/MPSEvent.h>
|
| 10 |
+
|
| 11 |
+
#ifdef __OBJC__
|
| 12 |
+
#include <Foundation/Foundation.h>
|
| 13 |
+
#include <Metal/Metal.h>
|
| 14 |
+
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
#include <ATen/Tensor.h>
|
| 18 |
+
#include <c10/core/MemoryFormat.h>
|
| 19 |
+
#include <c10/core/Storage.h>
|
| 20 |
+
#include <c10/core/TensorImpl.h>
|
| 21 |
+
#include <sys/_types/_size_t.h>
|
| 22 |
+
#include <memory>
|
| 23 |
+
#include <c10/core/UndefinedTensorImpl.h>
|
| 24 |
+
#include <c10/util/intrusive_ptr.h>
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
namespace at::mps {
|
| 28 |
+
|
| 29 |
+
typedef MPSEvent* mpsEvent_t;
|
| 30 |
+
|
| 31 |
+
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
|
| 32 |
+
// https://github.com/pytorch/pytorch/issues/77170
|
| 33 |
+
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
| 34 |
+
static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
|
| 35 |
+
|
| 36 |
+
// constructor
|
| 37 |
+
MPSGuardImpl() {}
|
| 38 |
+
explicit MPSGuardImpl(c10::DeviceType t) {
|
| 39 |
+
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
// returns the type
|
| 43 |
+
c10::DeviceType type() const override {
|
| 44 |
+
return c10::DeviceType::MPS;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
Device exchangeDevice(Device d) const override {
|
| 48 |
+
return Device(c10::DeviceType::MPS, 0);
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
Device getDevice() const override {
|
| 52 |
+
return Device(c10::DeviceType::MPS, 0);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
c10::optional<Device> uncheckedGetDevice() const noexcept {
|
| 56 |
+
return Device(c10::DeviceType::MPS, 0);
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
void setDevice(Device d) const override {
|
| 60 |
+
TORCH_INTERNAL_ASSERT(d.is_mps());
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
void uncheckedSetDevice(Device d) const noexcept override {
|
| 64 |
+
// TODO: Currently setting only device 0
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
Stream getStream(Device d) const noexcept override {
|
| 68 |
+
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
Stream getDefaultStream(Device d) const override {
|
| 72 |
+
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// NB: These do NOT set the current device
|
| 76 |
+
Stream exchangeStream(Stream s) const noexcept override {
|
| 77 |
+
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
| 78 |
+
}
|
| 79 |
+
DeviceIndex deviceCount() const noexcept override {
|
| 80 |
+
if (at::hasMPS()) {
|
| 81 |
+
//TODO: extend it for multi-device case
|
| 82 |
+
return 1;
|
| 83 |
+
} else {
|
| 84 |
+
return 0;
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
// Event-related functions
|
| 89 |
+
void createEvent(
|
| 90 |
+
mpsEvent_t* event,
|
| 91 |
+
const EventFlag flag) const;
|
| 92 |
+
|
| 93 |
+
void destroyEvent(
|
| 94 |
+
void* event,
|
| 95 |
+
const DeviceIndex device_index) const noexcept override;
|
| 96 |
+
|
| 97 |
+
void record(
|
| 98 |
+
void** event,
|
| 99 |
+
const Stream& stream,
|
| 100 |
+
const DeviceIndex device_index,
|
| 101 |
+
const EventFlag flag) const override;
|
| 102 |
+
|
| 103 |
+
void block(
|
| 104 |
+
void* event,
|
| 105 |
+
const Stream& stream) const override;
|
| 106 |
+
|
| 107 |
+
bool queryEvent(void* event) const override;
|
| 108 |
+
|
| 109 |
+
};
|
| 110 |
+
|
| 111 |
+
/// A variant of OptionalDeviceGuard that is specialized for MPS.
|
| 112 |
+
struct OptionalMPSGuard {
|
| 113 |
+
explicit OptionalMPSGuard() : guard_() {}
|
| 114 |
+
|
| 115 |
+
explicit OptionalMPSGuard(c10::optional<Device> device_opt)
|
| 116 |
+
: guard_(device_opt) {}
|
| 117 |
+
|
| 118 |
+
/// Set the current MPS device to the passed device index, if it is not
|
| 119 |
+
/// nullopt
|
| 120 |
+
explicit OptionalMPSGuard(c10::optional<DeviceIndex> device_index_opt)
|
| 121 |
+
: guard_(device_index_opt) {}
|
| 122 |
+
|
| 123 |
+
// Copy is not allowed
|
| 124 |
+
OptionalMPSGuard(const OptionalMPSGuard&) = delete;
|
| 125 |
+
OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
|
| 126 |
+
OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
|
| 127 |
+
OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
|
| 128 |
+
|
| 129 |
+
/// Sets the MPS device to the given device, initializing the guard if it
|
| 130 |
+
/// is not already initialized. Errors if the given device is not a MPS
|
| 131 |
+
/// device.
|
| 132 |
+
void set_device(Device device) {
|
| 133 |
+
guard_.set_device(device);
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
/// Sets the MPS device to the given device, initializing the guard if it is
|
| 137 |
+
/// not already initialized. Errors if the given device is not a MPS device.
|
| 138 |
+
void reset_device(Device device) {
|
| 139 |
+
guard_.reset_device(device);
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
/// Sets the MPS device to the given device index, initializing the guard if
|
| 143 |
+
/// it is not already initialized.
|
| 144 |
+
void set_index(DeviceIndex device_index) {
|
| 145 |
+
guard_.set_index(device_index);
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
/// Returns the device that was set immediately prior to initialization of the
|
| 149 |
+
/// guard, or nullopt if the guard is uninitialized.
|
| 150 |
+
c10::optional<Device> original_device() const {
|
| 151 |
+
return guard_.original_device();
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
/// Returns the most recent device that was set using this device guard,
|
| 155 |
+
/// either from construction, or via set_device, if the guard is initialized,
|
| 156 |
+
/// or nullopt if the guard is uninitialized.
|
| 157 |
+
c10::optional<Device> current_device() const {
|
| 158 |
+
return guard_.current_device();
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
/// Restore the original MPS device, resetting this guard to uninitialized
|
| 162 |
+
/// state.
|
| 163 |
+
void reset() {
|
| 164 |
+
guard_.reset();
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
private:
|
| 168 |
+
c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
|
| 169 |
+
};
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);
|
| 173 |
+
|
| 174 |
+
} // namespace at::mps
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/CPUApplyUtils.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/Dispatch_v2.h>
|
| 6 |
+
#include <ATen/ExpandBase.h>
|
| 7 |
+
#include <ATen/core/DistributionsHelper.h>
|
| 8 |
+
#include <ATen/native/TensorIterator.h>
|
| 9 |
+
#include <ATen/native/cpu/Loops.h>
|
| 10 |
+
#include <limits>
|
| 11 |
+
#include <mutex>
|
| 12 |
+
|
| 13 |
+
#ifdef CPU_CAPABILITY_AVX2
|
| 14 |
+
#include <ATen/native/cpu/avx_mathfun.h>
|
| 15 |
+
#include <c10/util/irange.h>
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
namespace at {
|
| 20 |
+
namespace native {
|
| 21 |
+
namespace templates {
|
| 22 |
+
namespace cpu {
|
| 23 |
+
namespace {
|
| 24 |
+
|
| 25 |
+
// ==================================================== Random ========================================================
|
| 26 |
+
|
| 27 |
+
template<typename RNG>
|
| 28 |
+
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
|
| 29 |
+
AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
|
| 30 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 31 |
+
cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
|
| 32 |
+
uniform_int_from_to_distribution<scalar_t> random(range, base);
|
| 33 |
+
return random(generator);
|
| 34 |
+
});
|
| 35 |
+
}), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// This is the special kernel to handle single specific case:
|
| 39 |
+
// from(inclusive) = std::numeric_limits<int64_t>::lowest()
|
| 40 |
+
// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
|
| 41 |
+
template<typename RNG>
|
| 42 |
+
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
|
| 43 |
+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
|
| 44 |
+
if constexpr (std::is_same<scalar_t, int64_t>::value ||
|
| 45 |
+
std::is_same<scalar_t, double>::value ||
|
| 46 |
+
std::is_same<scalar_t, float>::value ||
|
| 47 |
+
std::is_same<scalar_t, at::BFloat16>::value) {
|
| 48 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 49 |
+
cpu_serial_kernel(iter, [generator]() -> scalar_t {
|
| 50 |
+
uniform_int_full_range_distribution<scalar_t> random;
|
| 51 |
+
return random(generator);
|
| 52 |
+
});
|
| 53 |
+
} else {
|
| 54 |
+
TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
|
| 55 |
+
}
|
| 56 |
+
});
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
template<typename RNG>
|
| 60 |
+
struct RandomFromToKernel {
|
| 61 |
+
void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, c10::optional<Generator> gen) {
|
| 62 |
+
random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
|
| 63 |
+
}
|
| 64 |
+
void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
|
| 65 |
+
random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
|
| 66 |
+
}
|
| 67 |
+
};
|
| 68 |
+
|
| 69 |
+
template<typename RNG>
|
| 70 |
+
void random_kernel(TensorIteratorBase& iter, RNG generator) {
|
| 71 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 72 |
+
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
|
| 73 |
+
cpu_serial_kernel(iter, [generator]() -> scalar_t {
|
| 74 |
+
uniform_int_distribution<scalar_t> random;
|
| 75 |
+
return random(generator);
|
| 76 |
+
});
|
| 77 |
+
});
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
template<typename RNG>
|
| 81 |
+
struct RandomKernel {
|
| 82 |
+
void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
|
| 83 |
+
random_kernel(iter, check_generator<RNG>(gen));
|
| 84 |
+
}
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
// ==================================================== Normal ========================================================
|
| 88 |
+
|
| 89 |
+
#ifdef CPU_CAPABILITY_AVX2
|
| 90 |
+
static void normal_fill_16_AVX2(float *data,
|
| 91 |
+
const __m256* two_pi,
|
| 92 |
+
const __m256* one,
|
| 93 |
+
const __m256* minus_two,
|
| 94 |
+
const __m256* mean,
|
| 95 |
+
const __m256* std_v) {
|
| 96 |
+
const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
|
| 97 |
+
const __m256 u2 = _mm256_loadu_ps(data + 8);
|
| 98 |
+
// sincos256_ps and log256_ps are from avx_mathfun.h
|
| 99 |
+
const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
|
| 100 |
+
const __m256 theta = _mm256_mul_ps(*two_pi, u2);
|
| 101 |
+
__m256 sintheta, costheta;
|
| 102 |
+
sincos256_ps(theta, &sintheta, &costheta);
|
| 103 |
+
const __m256 n1 = _mm256_mul_ps(radius, costheta);
|
| 104 |
+
const __m256 n2 = _mm256_mul_ps(radius, sintheta);
|
| 105 |
+
_mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
|
| 106 |
+
_mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
template<typename RNG>
|
| 110 |
+
void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
|
| 111 |
+
float *data = self.data_ptr<float>();
|
| 112 |
+
auto size = self.numel();
|
| 113 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 114 |
+
for (const auto i : c10::irange(size)) {
|
| 115 |
+
at::uniform_real_distribution<float> uniform(0, 1);
|
| 116 |
+
data[i] = uniform(generator);
|
| 117 |
+
}
|
| 118 |
+
const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi<double>);
|
| 119 |
+
const __m256 one = _mm256_set1_ps(1.0f);
|
| 120 |
+
const __m256 minus_two = _mm256_set1_ps(-2.0f);
|
| 121 |
+
const __m256 mean_v = _mm256_set1_ps(mean);
|
| 122 |
+
const __m256 std_v = _mm256_set1_ps(std);
|
| 123 |
+
|
| 124 |
+
for (int64_t i = 0; i < size - 15; i += 16) {
|
| 125 |
+
normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
if (size % 16 != 0) {
|
| 129 |
+
// Recompute the last 16 values.
|
| 130 |
+
data = data + size - 16;
|
| 131 |
+
for (const auto i : c10::irange(16)) {
|
| 132 |
+
at::uniform_real_distribution<float> uniform(0, 1);
|
| 133 |
+
data[i] = uniform(generator);
|
| 134 |
+
}
|
| 135 |
+
normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
#endif
|
| 139 |
+
|
| 140 |
+
template <typename scalar_t>
|
| 141 |
+
static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
|
| 142 |
+
for (const auto j : c10::irange(8)) {
|
| 143 |
+
const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
|
| 144 |
+
const scalar_t u2 = data[j + 8];
|
| 145 |
+
const scalar_t radius = std::sqrt(-2 * std::log(u1));
|
| 146 |
+
const scalar_t theta = 2.0f * c10::pi<double> * u2;
|
| 147 |
+
data[j] = radius * std::cos(theta) * std + mean;
|
| 148 |
+
data[j + 8] = radius * std::sin(theta) * std + mean;
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
template <typename scalar_t, typename RNG>
|
| 153 |
+
void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
|
| 154 |
+
scalar_t *data = self.data_ptr<scalar_t>();
|
| 155 |
+
auto size = self.numel();
|
| 156 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 157 |
+
for (const auto i : c10::irange(size)) {
|
| 158 |
+
at::uniform_real_distribution<scalar_t> uniform(0, 1);
|
| 159 |
+
data[i] = uniform(generator);
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
for (int64_t i = 0; i < size - 15; i += 16) {
|
| 163 |
+
normal_fill_16<scalar_t>(data + i, mean, std);
|
| 164 |
+
}
|
| 165 |
+
if (size % 16 != 0) {
|
| 166 |
+
// Recompute the last 16 values.
|
| 167 |
+
data = data + size - 16;
|
| 168 |
+
for (const auto i : c10::irange(16)) {
|
| 169 |
+
at::uniform_real_distribution<scalar_t> uniform(0, 1);
|
| 170 |
+
data[i] = uniform(generator);
|
| 171 |
+
}
|
| 172 |
+
normal_fill_16<scalar_t>(data, mean, std);
|
| 173 |
+
}
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
template<typename RNG>
|
| 177 |
+
void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
|
| 178 |
+
auto size = self.numel();
|
| 179 |
+
if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
|
| 180 |
+
#ifdef CPU_CAPABILITY_AVX2
|
| 181 |
+
normal_fill_AVX2(self, static_cast<float>(mean), static_cast<float>(std), generator);
|
| 182 |
+
#else
|
| 183 |
+
normal_fill(self, static_cast<float>(mean), static_cast<float>(std), generator);
|
| 184 |
+
#endif
|
| 185 |
+
} else {
|
| 186 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
|
| 187 |
+
if (size >= 16 && self.is_contiguous()) {
|
| 188 |
+
normal_fill<scalar_t>(self, static_cast<scalar_t>(mean), static_cast<scalar_t>(std), generator);
|
| 189 |
+
} else {
|
| 190 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 191 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 192 |
+
cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
|
| 193 |
+
at::normal_distribution<double> normal(mean, std);
|
| 194 |
+
return static_cast<scalar_t>(normal(generator));
|
| 195 |
+
});
|
| 196 |
+
}
|
| 197 |
+
});
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
template<typename RNG>
|
| 202 |
+
struct NormalKernel {
|
| 203 |
+
void operator()(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
|
| 204 |
+
normal_kernel(self, mean, std, check_generator<RNG>(gen));
|
| 205 |
+
}
|
| 206 |
+
};
|
| 207 |
+
|
| 208 |
+
// ==================================================== Uniform =======================================================
|
| 209 |
+
|
| 210 |
+
template<typename RNG>
|
| 211 |
+
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
|
| 212 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
|
| 213 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 214 |
+
auto from = static_cast<scalar_t>(from_);
|
| 215 |
+
auto to = static_cast<scalar_t>(to_);
|
| 216 |
+
at::uniform_real_distribution<scalar_t> uniform(from, to);
|
| 217 |
+
cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
|
| 218 |
+
return static_cast<scalar_t>(uniform(generator));
|
| 219 |
+
});
|
| 220 |
+
});
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
template<typename RNG>
|
| 224 |
+
struct UniformKernel {
|
| 225 |
+
void operator()(TensorIteratorBase& iter, double from, double to, c10::optional<Generator> gen) {
|
| 226 |
+
uniform_kernel(iter, from, to, check_generator<RNG>(gen));
|
| 227 |
+
}
|
| 228 |
+
};
|
| 229 |
+
|
| 230 |
+
// ==================================================== Cauchy ========================================================
|
| 231 |
+
|
| 232 |
+
template<typename RNG>
|
| 233 |
+
void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
|
| 234 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
|
| 235 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 236 |
+
at::cauchy_distribution<double> cauchy(median, sigma);
|
| 237 |
+
cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
|
| 238 |
+
return static_cast<scalar_t>(cauchy(generator));
|
| 239 |
+
});
|
| 240 |
+
});
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
template<typename RNG>
|
| 244 |
+
struct CauchyKernel {
|
| 245 |
+
void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional<Generator> gen) {
|
| 246 |
+
cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
|
| 247 |
+
}
|
| 248 |
+
};
|
| 249 |
+
|
| 250 |
+
// ================================================== LogNormal =======================================================
|
| 251 |
+
|
| 252 |
+
template<typename RNG>
|
| 253 |
+
void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
|
| 254 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
|
| 255 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 256 |
+
at::lognormal_distribution<double> logNormal(mean, std);
|
| 257 |
+
cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
|
| 258 |
+
return static_cast<scalar_t>(logNormal(generator));
|
| 259 |
+
});
|
| 260 |
+
});
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
template<typename RNG>
|
| 264 |
+
struct LogNormalKernel {
|
| 265 |
+
void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional<Generator> gen) {
|
| 266 |
+
log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
|
| 267 |
+
}
|
| 268 |
+
};
|
| 269 |
+
|
| 270 |
+
// =================================================== Geometric ======================================================
|
| 271 |
+
|
| 272 |
+
template<typename RNG>
|
| 273 |
+
void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
|
| 274 |
+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
|
| 275 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 276 |
+
at::geometric_distribution<double> geometric(p);
|
| 277 |
+
cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
|
| 278 |
+
return static_cast<scalar_t>(geometric(generator));
|
| 279 |
+
});
|
| 280 |
+
});
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
template<typename RNG>
|
| 284 |
+
struct GeometricKernel {
|
| 285 |
+
void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
|
| 286 |
+
geometric_kernel(iter, p, check_generator<RNG>(gen));
|
| 287 |
+
}
|
| 288 |
+
};
|
| 289 |
+
|
| 290 |
+
// ================================================== Exponential =====================================================
|
| 291 |
+
|
| 292 |
+
template<typename RNG>
|
| 293 |
+
void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
|
| 294 |
+
TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
|
| 295 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
|
| 296 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 297 |
+
at::exponential_distribution<double> exponential(lambda);
|
| 298 |
+
cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
|
| 299 |
+
return static_cast<scalar_t>(exponential(generator));
|
| 300 |
+
});
|
| 301 |
+
});
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
template<typename RNG>
|
| 305 |
+
struct ExponentialKernel {
|
| 306 |
+
void operator()(TensorIteratorBase& iter, double lambda, c10::optional<Generator> gen) {
|
| 307 |
+
exponential_kernel(iter, lambda, check_generator<RNG>(gen));
|
| 308 |
+
}
|
| 309 |
+
};
|
| 310 |
+
|
| 311 |
+
// ================================================== Bernoulli =======================================================
|
| 312 |
+
|
| 313 |
+
template<typename RNG>
|
| 314 |
+
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
|
| 315 |
+
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
|
| 316 |
+
self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
|
| 317 |
+
// See Note [Acquire lock when using random generators]
|
| 318 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 319 |
+
using self_t = scalar_t;
|
| 320 |
+
auto p_cpu = p_.to(kCPU);
|
| 321 |
+
auto p = expand_inplace(self, p_cpu);
|
| 322 |
+
auto iter = TensorIteratorConfig()
|
| 323 |
+
.add_output(self)
|
| 324 |
+
.add_input(*p)
|
| 325 |
+
.check_all_same_dtype(false)
|
| 326 |
+
.build();
|
| 327 |
+
if (p->scalar_type() == kDouble) {
|
| 328 |
+
cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
|
| 329 |
+
at::bernoulli_distribution<double> bernoulli(p_val);
|
| 330 |
+
return static_cast<self_t>(bernoulli(generator));
|
| 331 |
+
});
|
| 332 |
+
} else {
|
| 333 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
|
| 334 |
+
p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
|
| 335 |
+
using p_t = scalar_t;
|
| 336 |
+
cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
|
| 337 |
+
at::bernoulli_distribution<float> bernoulli(p_val);
|
| 338 |
+
return static_cast<self_t>(bernoulli(generator));
|
| 339 |
+
});
|
| 340 |
+
});
|
| 341 |
+
}
|
| 342 |
+
});
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
template<typename RNG>
|
| 346 |
+
void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
|
| 347 |
+
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
|
| 348 |
+
self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
|
| 349 |
+
// See Note [Acquire lock when using random generators]
|
| 350 |
+
std::lock_guard<std::mutex> lock(generator->mutex_);
|
| 351 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 352 |
+
cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
|
| 353 |
+
at::bernoulli_distribution<double> bernoulli(p);
|
| 354 |
+
return static_cast<scalar_t>(bernoulli(generator));
|
| 355 |
+
});
|
| 356 |
+
});
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
template<typename RNG>
|
| 360 |
+
struct BernoulliKernel {
|
| 361 |
+
void operator()(const TensorBase &self, double p, c10::optional<Generator> gen) {
|
| 362 |
+
bernoulli_kernel(self, p, check_generator<RNG>(gen));
|
| 363 |
+
}
|
| 364 |
+
void operator()(const TensorBase &self, const TensorBase &p_, c10::optional<Generator> gen) {
|
| 365 |
+
bernoulli_kernel(self, p_, check_generator<RNG>(gen));
|
| 366 |
+
}
|
| 367 |
+
};
|
| 368 |
+
|
| 369 |
+
}}}}}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Config.h>
|
| 4 |
+
#if AT_MKLDNN_ENABLED()
|
| 5 |
+
#include <ATen/Tensor.h>
|
| 6 |
+
#include <ATen/native/quantized/PackedParams.h>
|
| 7 |
+
#include <ideep.hpp>
|
| 8 |
+
#include <cpuinfo.h>
|
| 9 |
+
|
| 10 |
+
#include <c10/util/CallOnce.h>
|
| 11 |
+
|
| 12 |
+
using PrimitiveCacheKey = std::tuple<
|
| 13 |
+
double, // input_scale
|
| 14 |
+
int64_t, // input_zero_point
|
| 15 |
+
std::vector<int64_t>, // input_shape
|
| 16 |
+
double, // output_scale
|
| 17 |
+
int64_t, // output_zero_point
|
| 18 |
+
int64_t, // OMP_number_of_threads
|
| 19 |
+
double, // accum_scale
|
| 20 |
+
int64_t>; // accum_zero_point
|
| 21 |
+
|
| 22 |
+
enum CacheKeyIndex {
|
| 23 |
+
InputScale,
|
| 24 |
+
InputZeroPoint,
|
| 25 |
+
InputShape,
|
| 26 |
+
OutputScale,
|
| 27 |
+
OutputZeroPoint,
|
| 28 |
+
NumOfThreads,
|
| 29 |
+
};
|
| 30 |
+
|
| 31 |
+
// Base class of primitive cache
|
| 32 |
+
struct PrimitiveCache {
|
| 33 |
+
PrimitiveCacheKey key;
|
| 34 |
+
|
| 35 |
+
bool hit(const PrimitiveCacheKey& key) {
|
| 36 |
+
return this->key == key;
|
| 37 |
+
}
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
using LinearParams = ideep::matmul_forward_params;
|
| 41 |
+
using Conv = dnnl::convolution_forward;
|
| 42 |
+
using ConvDesc = dnnl::convolution_forward::primitive_desc;
|
| 43 |
+
using ConvParams = ideep::convolution_forward_params;
|
| 44 |
+
using Deconv = dnnl::deconvolution_forward;
|
| 45 |
+
using DeconvDesc = dnnl::deconvolution_forward::primitive_desc;
|
| 46 |
+
using DeconvParams = ideep::deconv_forward_params;
|
| 47 |
+
|
| 48 |
+
struct LinearPrimitiveCache : PrimitiveCache {
|
| 49 |
+
LinearPrimitiveCache() {}
|
| 50 |
+
|
| 51 |
+
LinearPrimitiveCache(
|
| 52 |
+
const PrimitiveCacheKey& key,
|
| 53 |
+
const LinearParams& param) {
|
| 54 |
+
this->key = key;
|
| 55 |
+
this->param = param;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
LinearParams param;
|
| 59 |
+
|
| 60 |
+
// For dynamic qlinear, scale and zero point
|
| 61 |
+
// are set at execution time. So we only need to compare
|
| 62 |
+
// the rest part of key.
|
| 63 |
+
bool hit_dynamic(const PrimitiveCacheKey& new_key) {
|
| 64 |
+
auto cached_input_shape = std::get<InputShape>(this->key);
|
| 65 |
+
auto new_input_shape = std::get<InputShape>(new_key);
|
| 66 |
+
return (
|
| 67 |
+
cached_input_shape == new_input_shape &&
|
| 68 |
+
std::get<NumOfThreads>(this->key) == std::get<NumOfThreads>(new_key));
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
LinearParams& get_param() {
|
| 72 |
+
return param;
|
| 73 |
+
}
|
| 74 |
+
};
|
| 75 |
+
|
| 76 |
+
struct ConvPrimitiveCache : PrimitiveCache {
|
| 77 |
+
ConvPrimitiveCache() {}
|
| 78 |
+
|
| 79 |
+
ConvPrimitiveCache(
|
| 80 |
+
const PrimitiveCacheKey& key,
|
| 81 |
+
const ConvParams& params) {
|
| 82 |
+
this->key = key;
|
| 83 |
+
this->params = params;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
ConvParams params;
|
| 87 |
+
|
| 88 |
+
ConvParams& get_params() {
|
| 89 |
+
return params;
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
struct DeconvPrimitiveCache : PrimitiveCache {
|
| 94 |
+
DeconvPrimitiveCache() {}
|
| 95 |
+
|
| 96 |
+
DeconvPrimitiveCache(
|
| 97 |
+
const PrimitiveCacheKey& key,
|
| 98 |
+
const DeconvParams& params) {
|
| 99 |
+
this->key = key;
|
| 100 |
+
this->params = params;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
DeconvParams params;
|
| 104 |
+
|
| 105 |
+
DeconvParams& get_params() {
|
| 106 |
+
return params;
|
| 107 |
+
}
|
| 108 |
+
};
|
| 109 |
+
|
| 110 |
+
enum PostOps {
|
| 111 |
+
NoPostOp,
|
| 112 |
+
Relu,
|
| 113 |
+
LeakyRelu,
|
| 114 |
+
Tanh,
|
| 115 |
+
Gelu
|
| 116 |
+
};
|
| 117 |
+
|
| 118 |
+
static std::unordered_map<std::string, PostOps> POST_OP_TABLE = {
|
| 119 |
+
{"none", NoPostOp},
|
| 120 |
+
{"relu", Relu},
|
| 121 |
+
{"leaky_relu", LeakyRelu},
|
| 122 |
+
{"tanh", Tanh},
|
| 123 |
+
{"gelu", Gelu}
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
|
| 127 |
+
PackedLinearWeightsOnednn(
|
| 128 |
+
std::unique_ptr<ideep::tensor> weight,
|
| 129 |
+
c10::optional<ideep::tensor> bias,
|
| 130 |
+
at::Tensor orig_weight,
|
| 131 |
+
c10::optional<at::Tensor> orig_bias)
|
| 132 |
+
: weight_(std::move(weight)),
|
| 133 |
+
bias_(std::move(bias)),
|
| 134 |
+
orig_weight_(std::move(orig_weight)),
|
| 135 |
+
orig_bias_(std::move(orig_bias)) {
|
| 136 |
+
cache_initialized_flag = std::make_unique<c10::once_flag>();
|
| 137 |
+
}
|
| 138 |
+
std::unique_ptr<ideep::tensor> weight_;
|
| 139 |
+
c10::optional<ideep::tensor> bias_;
|
| 140 |
+
at::Tensor orig_weight_;
|
| 141 |
+
c10::optional<at::Tensor> orig_bias_;
|
| 142 |
+
|
| 143 |
+
at::Tensor apply(
|
| 144 |
+
at::Tensor input,
|
| 145 |
+
double output_scale,
|
| 146 |
+
int64_t output_zero_point) override;
|
| 147 |
+
at::Tensor apply_relu(
|
| 148 |
+
at::Tensor input,
|
| 149 |
+
double output_scale,
|
| 150 |
+
int64_t output_zero_point) override;
|
| 151 |
+
|
| 152 |
+
at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
|
| 153 |
+
at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;
|
| 154 |
+
|
| 155 |
+
at::Tensor apply_leaky_relu(
|
| 156 |
+
at::Tensor input,
|
| 157 |
+
double output_scale,
|
| 158 |
+
int64_t output_zero_point,
|
| 159 |
+
double negative_slope);
|
| 160 |
+
|
| 161 |
+
at::Tensor apply_tanh(
|
| 162 |
+
at::Tensor input,
|
| 163 |
+
double output_scale,
|
| 164 |
+
int64_t output_zero_point);
|
| 165 |
+
|
| 166 |
+
std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
|
| 167 |
+
|
| 168 |
+
c10::optional<at::Tensor> bias() override {
|
| 169 |
+
return orig_bias_;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
|
| 173 |
+
at::Tensor weight,
|
| 174 |
+
c10::optional<at::Tensor> bias);
|
| 175 |
+
|
| 176 |
+
private:
|
| 177 |
+
LinearPrimitiveCache prim_cache;
|
| 178 |
+
std::unique_ptr<c10::once_flag> cache_initialized_flag;
|
| 179 |
+
|
| 180 |
+
template <PostOps post_op>
|
| 181 |
+
at::Tensor apply_impl(
|
| 182 |
+
at::Tensor input,
|
| 183 |
+
double output_scale,
|
| 184 |
+
int64_t output_zero_point,
|
| 185 |
+
torch::List<at::Scalar> post_op_args = torch::List<at::Scalar>());
|
| 186 |
+
|
| 187 |
+
template <bool ReluFused>
|
| 188 |
+
at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);
|
| 189 |
+
|
| 190 |
+
LinearPrimitiveCache& get_cache() {
|
| 191 |
+
return prim_cache;
|
| 192 |
+
}
|
| 193 |
+
};
|
| 194 |
+
|
| 195 |
+
template <int kSpatialDim = 2>
|
| 196 |
+
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
|
| 197 |
+
PackedConvWeightsOnednn(
|
| 198 |
+
std::unique_ptr<ideep::tensor> weight,
|
| 199 |
+
c10::optional<ideep::tensor> bias,
|
| 200 |
+
at::Tensor orig_weight,
|
| 201 |
+
c10::optional<at::Tensor> orig_bias,
|
| 202 |
+
torch::List<int64_t> stride,
|
| 203 |
+
torch::List<int64_t> padding,
|
| 204 |
+
torch::List<int64_t> output_padding,
|
| 205 |
+
torch::List<int64_t> dilation,
|
| 206 |
+
int64_t groups,
|
| 207 |
+
uint8_t transpose)
|
| 208 |
+
: weight_(std::move(weight)),
|
| 209 |
+
bias_(std::move(bias)),
|
| 210 |
+
orig_weight_(std::move(orig_weight)),
|
| 211 |
+
orig_bias_(std::move(orig_bias)),
|
| 212 |
+
stride_(std::move(stride)),
|
| 213 |
+
padding_(std::move(padding)),
|
| 214 |
+
output_padding_(std::move(output_padding)),
|
| 215 |
+
dilation_(std::move(dilation)),
|
| 216 |
+
groups_(groups),
|
| 217 |
+
transpose_(transpose) {
|
| 218 |
+
cache_initialized_flag = std::make_unique<c10::once_flag>();
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
std::unique_ptr<ideep::tensor> weight_;
|
| 222 |
+
c10::optional<ideep::tensor> bias_;
|
| 223 |
+
at::Tensor orig_weight_;
|
| 224 |
+
c10::optional<at::Tensor> orig_bias_;
|
| 225 |
+
torch::List<int64_t> stride_;
|
| 226 |
+
torch::List<int64_t> padding_;
|
| 227 |
+
torch::List<int64_t> output_padding_;
|
| 228 |
+
torch::List<int64_t> dilation_;
|
| 229 |
+
int64_t groups_;
|
| 230 |
+
uint8_t transpose_;
|
| 231 |
+
|
| 232 |
+
at::Tensor apply(
|
| 233 |
+
const at::Tensor& input,
|
| 234 |
+
double output_scale,
|
| 235 |
+
int64_t output_zero_point) override;
|
| 236 |
+
|
| 237 |
+
at::Tensor apply_relu(
|
| 238 |
+
const at::Tensor& input,
|
| 239 |
+
double output_scale,
|
| 240 |
+
int64_t output_zero_point) override;
|
| 241 |
+
|
| 242 |
+
at::Tensor apply_dynamic(
|
| 243 |
+
const at::Tensor& input,
|
| 244 |
+
bool reduce_range) override;
|
| 245 |
+
|
| 246 |
+
at::Tensor apply_add(
|
| 247 |
+
const at::Tensor& input,
|
| 248 |
+
const at::Tensor& accum,
|
| 249 |
+
double output_scale,
|
| 250 |
+
int64_t output_zero_point);
|
| 251 |
+
|
| 252 |
+
at::Tensor apply_add_relu(
|
| 253 |
+
const at::Tensor& input,
|
| 254 |
+
const at::Tensor& accum,
|
| 255 |
+
double output_scale,
|
| 256 |
+
int64_t output_zero_point);
|
| 257 |
+
|
| 258 |
+
std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
|
| 259 |
+
|
| 260 |
+
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
|
| 261 |
+
at::Tensor weight,
|
| 262 |
+
c10::optional<at::Tensor> bias,
|
| 263 |
+
torch::List<int64_t> stride,
|
| 264 |
+
torch::List<int64_t> padding,
|
| 265 |
+
torch::List<int64_t> output_padding,
|
| 266 |
+
torch::List<int64_t> dilation,
|
| 267 |
+
int64_t groups,
|
| 268 |
+
bool transpose);
|
| 269 |
+
|
| 270 |
+
torch::List<int64_t> stride() const override {
|
| 271 |
+
return stride_;
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
torch::List<int64_t> padding() const override {
|
| 275 |
+
return padding_;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
torch::List<int64_t> output_padding() const override {
|
| 279 |
+
return output_padding_;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
torch::List<int64_t> dilation() const override {
|
| 283 |
+
return dilation_;
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
int64_t groups() const override {
|
| 287 |
+
return groups_;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
bool transpose() const override {
|
| 291 |
+
return (bool)transpose_;
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
private:
|
| 295 |
+
ConvPrimitiveCache conv_prim_cache;
|
| 296 |
+
DeconvPrimitiveCache deconv_prim_cache;
|
| 297 |
+
std::unique_ptr<c10::once_flag> cache_initialized_flag;
|
| 298 |
+
|
| 299 |
+
template <bool ReluFused>
|
| 300 |
+
at::Tensor apply_impl(
|
| 301 |
+
const at::Tensor& input,
|
| 302 |
+
const c10::optional<at::Tensor>& accum,
|
| 303 |
+
double output_scale,
|
| 304 |
+
int64_t output_zero_point);
|
| 305 |
+
|
| 306 |
+
ConvPrimitiveCache& get_conv_cache() {
|
| 307 |
+
assert(!transpose());
|
| 308 |
+
return conv_prim_cache;
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
DeconvPrimitiveCache& get_deconv_cache() {
|
| 312 |
+
assert(transpose());
|
| 313 |
+
return deconv_prim_cache;
|
| 314 |
+
}
|
| 315 |
+
};
|
| 316 |
+
|
| 317 |
+
namespace onednn_utils {
|
| 318 |
+
|
| 319 |
+
static ideep::attr_t create_attr_by_post_op(
|
| 320 |
+
const std::string& post_op_name,
|
| 321 |
+
const torch::List<c10::optional<at::Scalar>>& post_op_args,
|
| 322 |
+
const dnnl::algorithm post_algorithm) {
|
| 323 |
+
using ideep::tensor;
|
| 324 |
+
PostOps post_op = POST_OP_TABLE[post_op_name];
|
| 325 |
+
if (post_op == Relu) {
|
| 326 |
+
return ideep::attr_t::fuse_relu();
|
| 327 |
+
} else if (post_op == LeakyRelu) {
|
| 328 |
+
return ideep::attr_t::fuse_relu_v2(/*alpha=*/post_op_args[0].value().to<float>());
|
| 329 |
+
} else if (post_op == Tanh) {
|
| 330 |
+
return ideep::attr_t::fuse_tanh();
|
| 331 |
+
} else if (post_op == Gelu) {
|
| 332 |
+
return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm);
|
| 333 |
+
}
|
| 334 |
+
return ideep::attr_t();
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
// Try to reorder tensor to expected desc at runtime
|
| 338 |
+
// Do it in a `try...catch...` manner to avoid oneDNN's errors
|
| 339 |
+
// TODO: Move it to third_party/ideep
|
| 340 |
+
static void try_reorder(
|
| 341 |
+
ideep::tensor& t,
|
| 342 |
+
const ideep::tensor::desc&& desc,
|
| 343 |
+
ideep::scale_t scales) {
|
| 344 |
+
if (t.get_desc() != desc) {
|
| 345 |
+
try {
|
| 346 |
+
t = t.reorder_if_differ_in(desc);
|
| 347 |
+
} catch (...) {
|
| 348 |
+
ideep::tensor&& plain = t.to_public(nullptr, t.get_data_type());
|
| 349 |
+
t = plain.reorder_if_differ_in(desc);
|
| 350 |
+
}
|
| 351 |
+
t.set_scale(scales);
|
| 352 |
+
}
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
// ONEDNN requires symmetric quantization of weight
|
| 356 |
+
// Use this util function to check.
|
| 357 |
+
static bool is_weight_symmetric_quant(
|
| 358 |
+
const at::Tensor& weight,
|
| 359 |
+
bool is_transposed_conv) {
|
| 360 |
+
bool is_symmetric = true;
|
| 361 |
+
const auto qtype = weight.qscheme();
|
| 362 |
+
if (qtype == c10::kPerTensorAffine) {
|
| 363 |
+
is_symmetric &= (weight.q_zero_point() == 0);
|
| 364 |
+
} else if (qtype == c10::kPerChannelAffine) {
|
| 365 |
+
if (is_transposed_conv) {
|
| 366 |
+
// This case is currently not supported in PyTorch
|
| 367 |
+
// but we do not want to raise an error in this util function.
|
| 368 |
+
is_symmetric = false;
|
| 369 |
+
} else {
|
| 370 |
+
auto output_channels = weight.size(0);
|
| 371 |
+
for (int i = 0; i < output_channels; ++i) {
|
| 372 |
+
auto zp = weight.q_per_channel_zero_points()[i].item<int32_t>();
|
| 373 |
+
is_symmetric &= (zp == 0);
|
| 374 |
+
}
|
| 375 |
+
}
|
| 376 |
+
} else {
|
| 377 |
+
// This case is currently not supported in PyTorch
|
| 378 |
+
// but we do not want to raise an error in this util function.
|
| 379 |
+
is_symmetric = false;
|
| 380 |
+
}
|
| 381 |
+
return is_symmetric;
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
// When qengine is x86, use this util func to check if onednn kernel
|
| 385 |
+
// is preferred than fbgemm's to get better performance.
|
| 386 |
+
static bool should_use_onednn_quant(
|
| 387 |
+
const at::Tensor& weight,
|
| 388 |
+
bool is_transposed_conv,
|
| 389 |
+
int groups,
|
| 390 |
+
torch::List<int64_t> output_padding) {
|
| 391 |
+
// Performance of onednn is only validated on Linux right now.
|
| 392 |
+
// Also, the heuristics for dispatching are based on perf data on Linux.
|
| 393 |
+
// So, for x86 qengine, we always use fbgemm kernels if OS is not Linux.
|
| 394 |
+
// TODO Support more OSs.
|
| 395 |
+
#if !defined(__linux__)
|
| 396 |
+
return false;
|
| 397 |
+
#else
|
| 398 |
+
bool vnni_available = cpuinfo_has_x86_avx512vnni();
|
| 399 |
+
bool w_sym_quant =
|
| 400 |
+
is_weight_symmetric_quant(weight, is_transposed_conv);
|
| 401 |
+
bool opad_all_zero =
|
| 402 |
+
std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; });
|
| 403 |
+
return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero;
|
| 404 |
+
#endif
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
} // onednn_utils
|
| 408 |
+
|
| 409 |
+
at::Tensor _qconv_prepack_onednn(
|
| 410 |
+
at::Tensor weight, // from CPU backend instead of QuantizedCPU
|
| 411 |
+
at::Tensor weight_scales, // Weight zero points must be 0 for onednn
|
| 412 |
+
double input_scale,
|
| 413 |
+
int64_t input_zero_point,
|
| 414 |
+
torch::List<int64_t> stride,
|
| 415 |
+
torch::List<int64_t> padding,
|
| 416 |
+
torch::List<int64_t> dilation,
|
| 417 |
+
int64_t groups,
|
| 418 |
+
c10::optional<torch::List<int64_t>> input_shape=c10::nullopt);
|
| 419 |
+
|
| 420 |
+
static at::Tensor _quantized_convolution_onednn(
|
| 421 |
+
at::Tensor act, // contains quantized values but not QTensor
|
| 422 |
+
double act_scale,
|
| 423 |
+
int64_t act_zero_point,
|
| 424 |
+
at::Tensor weight, // MKLDNN tensor with quantized values
|
| 425 |
+
at::Tensor weight_scales,
|
| 426 |
+
at::Tensor weight_zero_points,
|
| 427 |
+
c10::optional<at::Tensor> bias, // Bias is packed if not None
|
| 428 |
+
torch::List<int64_t> stride,
|
| 429 |
+
torch::List<int64_t> padding,
|
| 430 |
+
torch::List<int64_t> dilation,
|
| 431 |
+
bool transposed,
|
| 432 |
+
int64_t groups,
|
| 433 |
+
double inv_output_scale,
|
| 434 |
+
int64_t output_zero_point,
|
| 435 |
+
c10::optional<at::Tensor> accum=c10::nullopt, // accum to fused with conv add
|
| 436 |
+
double accum_scale=1.0,
|
| 437 |
+
int64_t accum_zero_point=0,
|
| 438 |
+
bool fp32_output=false,
|
| 439 |
+
c10::optional<c10::string_view> binary_attr=c10::nullopt,
|
| 440 |
+
c10::optional<at::Scalar> binary_alpha=c10::nullopt,
|
| 441 |
+
c10::optional<c10::string_view> unary_attr=c10::nullopt,
|
| 442 |
+
torch::List<c10::optional<at::Scalar>> unary_scalars=torch::List<c10::optional<at::Scalar>>(),
|
| 443 |
+
c10::optional<c10::string_view> unary_algorithm=c10::nullopt);
|
| 444 |
+
|
| 445 |
+
#endif // #if AT_MKLDNN_ENABLED()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifdef USE_XNNPACK
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
#include <ATen/core/Tensor.h>
|
| 7 |
+
#include <ATen/native/xnnpack/Common.h>
|
| 8 |
+
|
| 9 |
+
using xnnpack_operator = at::native::xnnpack::Operator;
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
namespace native {
|
| 13 |
+
namespace xnnp_utils {
|
| 14 |
+
|
| 15 |
+
/*
|
| 16 |
+
* Return shape in the same order as the memory format
|
| 17 |
+
* e.g. channels_last will return NHWC instead of NCHW
|
| 18 |
+
*/
|
| 19 |
+
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
|
| 20 |
+
|
| 21 |
+
/*
|
| 22 |
+
* Input is always int8_t, output can be [int8_t, uint8_t].
|
| 23 |
+
* input + offset = output
|
| 24 |
+
* int8_t + 128 = uint8_t
|
| 25 |
+
* int8_t + 0 = int8_t
|
| 26 |
+
*/
|
| 27 |
+
template <typename PT>
|
| 28 |
+
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
|
| 29 |
+
|
| 30 |
+
template <int kSpatialDim>
|
| 31 |
+
Tensor convert_conv_weights_to_channel_last_tensor(
|
| 32 |
+
const at::Tensor& src,
|
| 33 |
+
int groups,
|
| 34 |
+
bool transpose);
|
| 35 |
+
|
| 36 |
+
/*
|
| 37 |
+
* Series of create wrapper functions to call xnn_create_[de]conv* functions.
|
| 38 |
+
*/
|
| 39 |
+
C10_ALWAYS_INLINE
|
| 40 |
+
enum xnn_status xnnp_create_convolution2d_nhwc(
|
| 41 |
+
uint32_t pad_top,
|
| 42 |
+
uint32_t pad_right,
|
| 43 |
+
uint32_t pad_bottom,
|
| 44 |
+
uint32_t pad_left,
|
| 45 |
+
uint32_t kernel_h,
|
| 46 |
+
uint32_t kernel_w,
|
| 47 |
+
uint32_t stride_h,
|
| 48 |
+
uint32_t stride_w,
|
| 49 |
+
uint32_t dilation_h,
|
| 50 |
+
uint32_t dilation_w,
|
| 51 |
+
uint32_t groups,
|
| 52 |
+
size_t group_input_channels,
|
| 53 |
+
size_t group_output_channels,
|
| 54 |
+
size_t ip_chan_stride,
|
| 55 |
+
size_t op_chan_stride,
|
| 56 |
+
int8_t izp,
|
| 57 |
+
float ip_scale,
|
| 58 |
+
int8_t kzp,
|
| 59 |
+
const float* k_scales,
|
| 60 |
+
const int8_t* kernel,
|
| 61 |
+
const int32_t* bias,
|
| 62 |
+
int8_t ozp,
|
| 63 |
+
float op_scale,
|
| 64 |
+
int8_t op_min,
|
| 65 |
+
int8_t op_max,
|
| 66 |
+
uint32_t flags,
|
| 67 |
+
xnn_operator_t* op,
|
| 68 |
+
bool per_channel,
|
| 69 |
+
bool transpose) {
|
| 70 |
+
/* Symmetric quantization forces kzp = 0 */
|
| 71 |
+
TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero."
|
| 72 |
+
"But got: ", kzp);
|
| 73 |
+
|
| 74 |
+
if (transpose) {
|
| 75 |
+
TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
|
| 76 |
+
return xnn_create_deconvolution2d_nhwc_qs8(
|
| 77 |
+
pad_top, /* uint32_t output_padding_top */
|
| 78 |
+
pad_right, /* uint32_t output_padding_right */
|
| 79 |
+
pad_bottom, /* uint32_t output_padding_bottom */
|
| 80 |
+
pad_left, /* uint32_t output_padding_left */
|
| 81 |
+
kernel_h, /* uint32_t kernel_height */
|
| 82 |
+
kernel_w, /* uint32_t kernel_width */
|
| 83 |
+
stride_h, /* uint32_t stride_height */
|
| 84 |
+
stride_w, /* uint32_t stride_width */
|
| 85 |
+
dilation_h, /* uint32_t dilation_height */
|
| 86 |
+
dilation_w, /* uint32_t dilation_width */
|
| 87 |
+
groups, /* uint32_t groups */
|
| 88 |
+
group_input_channels, /* size_t group_input_channels */
|
| 89 |
+
group_output_channels, /* size_t group_output_channels */
|
| 90 |
+
ip_chan_stride, /* size_t input_pixel_stride */
|
| 91 |
+
op_chan_stride, /* size_t output_pixel_stride */
|
| 92 |
+
izp, /* int8_t input_zero_point */
|
| 93 |
+
ip_scale, /* float input_scale */
|
| 94 |
+
k_scales[0], /* float kernel_scale */
|
| 95 |
+
kernel, /* const int8_t* kernel */
|
| 96 |
+
bias, /* const int32_t* bias */
|
| 97 |
+
ozp, /* int8_t output_zero_point */
|
| 98 |
+
op_scale, /* float output_scale */
|
| 99 |
+
op_min, /* int8_t output_min */
|
| 100 |
+
op_max, /* int8_t output_max */
|
| 101 |
+
flags, /* uint32_t flags */
|
| 102 |
+
nullptr, /* xnn_caches_t caches */
|
| 103 |
+
nullptr, /* xnn_weights_cache_t weights_cache */
|
| 104 |
+
op); /* xnn_operator_t* deconvolution_op_out */
|
| 105 |
+
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
if (!per_channel) {
|
| 109 |
+
return xnn_create_convolution2d_nhwc_qs8(
|
| 110 |
+
pad_top, /* uint32_t input_padding_top */
|
| 111 |
+
pad_right, /* uint32_t input_padding_right */
|
| 112 |
+
pad_bottom, /* uint32_t input_padding_bottom */
|
| 113 |
+
pad_left, /* uint32_t input_padding_left */
|
| 114 |
+
kernel_h, /* uint32_t kernel_height */
|
| 115 |
+
kernel_w, /* uint32_t kernel_width */
|
| 116 |
+
stride_h, /* uint32_t subsampling_height */
|
| 117 |
+
stride_w, /* uint32_t subsampling_width */
|
| 118 |
+
dilation_h, /* uint32_t dilation_height */
|
| 119 |
+
dilation_w, /* uint32_t dilation_width */
|
| 120 |
+
groups, /* uint32_t groups */
|
| 121 |
+
group_input_channels, /* size_t group_input_channels */
|
| 122 |
+
group_output_channels, /* size_t group_output_channels*/
|
| 123 |
+
ip_chan_stride, /* size_t input_channel_stride */
|
| 124 |
+
op_chan_stride, /* size_t output_channel_stride */
|
| 125 |
+
izp, /* int8_t input_zero_point */
|
| 126 |
+
ip_scale, /* float input_scale */
|
| 127 |
+
k_scales[0], /* float kernel_scale */
|
| 128 |
+
kernel, /* const int8_t* kernel */
|
| 129 |
+
bias, /* const int32_t* bias */
|
| 130 |
+
ozp, /* int8_t output_zero_point */
|
| 131 |
+
op_scale, /* float output_scale */
|
| 132 |
+
op_min, /* int8_t output_min */
|
| 133 |
+
op_max, /* int8_t output_max */
|
| 134 |
+
flags, /* uint32_t flags */
|
| 135 |
+
nullptr, /* xnn_caches_t caches */
|
| 136 |
+
nullptr, /* xnn_weights_cache_t weights_cache */
|
| 137 |
+
op); /* xnn_operator_t* convolution_op_out */
|
| 138 |
+
} else { /* per_channel */
|
| 139 |
+
return xnn_create_convolution2d_nhwc_qs8_qc8w(
|
| 140 |
+
pad_top, /* uint32_t input_padding_top */
|
| 141 |
+
pad_right, /* uint32_t input_padding_right */
|
| 142 |
+
pad_bottom, /* uint32_t input_padding_bottom */
|
| 143 |
+
pad_left, /* uint32_t input_padding_left */
|
| 144 |
+
kernel_h, /* uint32_t kernel_height */
|
| 145 |
+
kernel_w, /* uint32_t kernel_width */
|
| 146 |
+
stride_h, /* uint32_t subsampling_height */
|
| 147 |
+
stride_w, /* uint32_t subsampling_width */
|
| 148 |
+
dilation_h, /* uint32_t dilation_height */
|
| 149 |
+
dilation_w, /* uint32_t dilation_width */
|
| 150 |
+
groups, /* uint32_t groups */
|
| 151 |
+
group_input_channels, /* size_t group_input_channels */
|
| 152 |
+
group_output_channels, /* size_t group_output_channels*/
|
| 153 |
+
ip_chan_stride, /* size_t input_channel_stride */
|
| 154 |
+
op_chan_stride, /* size_t output_channel_stride */
|
| 155 |
+
izp, /* int8_t input_zero_point */
|
| 156 |
+
ip_scale, /* float input_scale */
|
| 157 |
+
k_scales, /* const float* kernel_scale */
|
| 158 |
+
kernel, /* const int8_t* kernel */
|
| 159 |
+
bias, /* const int32_t* bias */
|
| 160 |
+
ozp, /* int8_t output_zero_point */
|
| 161 |
+
op_scale, /* float output_scale */
|
| 162 |
+
op_min, /* int8_t output_min */
|
| 163 |
+
op_max, /* int8_t output_max */
|
| 164 |
+
flags, /* uint32_t flags */
|
| 165 |
+
nullptr, /* xnn_caches_t caches */
|
| 166 |
+
nullptr, /* xnn_weights_cache_t weights_cache */
|
| 167 |
+
op); /* xnn_operator_t* convolution_op_out */
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
/*
|
| 172 |
+
* Series of reshape wrapper functions to call xnn_reshape_[de]conv* functions.
|
| 173 |
+
*/
|
| 174 |
+
C10_ALWAYS_INLINE
|
| 175 |
+
enum xnn_status xnnp_reshape_convolution2d_nhwc(
|
| 176 |
+
xnn_operator_t op,
|
| 177 |
+
size_t batch,
|
| 178 |
+
size_t in_h,
|
| 179 |
+
size_t in_w,
|
| 180 |
+
pthreadpool_t pt_pool,
|
| 181 |
+
bool per_channel = false,
|
| 182 |
+
bool transpose = false,
|
| 183 |
+
uint32_t adj_h = 0,
|
| 184 |
+
uint32_t adj_w = 0) {
|
| 185 |
+
if(transpose) {
|
| 186 |
+
TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
|
| 187 |
+
return xnn_reshape_deconvolution2d_nhwc_qs8(
|
| 188 |
+
op, /* xnn_operator_t deconvolution_op */
|
| 189 |
+
batch, /* size_t batch_size */
|
| 190 |
+
in_h, /* size_t input_height */
|
| 191 |
+
in_w, /* size_t input_width */
|
| 192 |
+
adj_h, /* uint32_t adjustment_height */
|
| 193 |
+
adj_w, /* uint32_t adjustment_width */
|
| 194 |
+
nullptr, /* size_t* output_height_out */
|
| 195 |
+
nullptr, /* size_t* output_width_out */
|
| 196 |
+
pt_pool); /* pthreadpool_t threadpool */
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
size_t workspace_size = SIZE_MAX;
|
| 200 |
+
size_t workspace_alignment = SIZE_MAX;
|
| 201 |
+
|
| 202 |
+
if (!per_channel) {
|
| 203 |
+
return xnn_reshape_convolution2d_nhwc_qs8(
|
| 204 |
+
op, /* xnn_operator_t convolution_op */
|
| 205 |
+
batch, /* size_t batch_size */
|
| 206 |
+
in_h, /* size_t input_height */
|
| 207 |
+
in_w, /* size_t input_width */
|
| 208 |
+
&workspace_size, /* size_t* workspace_size */
|
| 209 |
+
&workspace_alignment, /* size_t* workspace_alignment */
|
| 210 |
+
nullptr, /* size_t* output_height_out */
|
| 211 |
+
nullptr, /* size_t* output_width_out */
|
| 212 |
+
pt_pool); /* pthreadpool_t threadpool */
|
| 213 |
+
} else { /* per_channel */
|
| 214 |
+
return xnn_reshape_convolution2d_nhwc_qs8_qc8w(
|
| 215 |
+
op, /* xnn_operator_t convolution_op */
|
| 216 |
+
batch, /* size_t batch_size */
|
| 217 |
+
in_h, /* size_t input_height */
|
| 218 |
+
in_w, /* size_t input_width */
|
| 219 |
+
&workspace_size, /* size_t* workspace_size */
|
| 220 |
+
&workspace_alignment, /* size_t* workspace_alignment */
|
| 221 |
+
nullptr, /* size_t* output_height_out */
|
| 222 |
+
nullptr, /* size_t* output_width_out */
|
| 223 |
+
pt_pool); /* pthreadpool_t threadpool */
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
/*
|
| 229 |
+
* Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
|
| 230 |
+
*/
|
| 231 |
+
C10_ALWAYS_INLINE
|
| 232 |
+
enum xnn_status xnnp_setup_convolution2d_nhwc(
|
| 233 |
+
xnn_operator_t op,
|
| 234 |
+
const int8_t* inp,
|
| 235 |
+
int8_t* outp,
|
| 236 |
+
bool per_channel = false,
|
| 237 |
+
bool transpose = false) {
|
| 238 |
+
if(transpose) {
|
| 239 |
+
TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
|
| 240 |
+
|
| 241 |
+
return xnn_setup_deconvolution2d_nhwc_qs8(
|
| 242 |
+
op, /* xnn_operator_t deconvolution_op */
|
| 243 |
+
inp, /* const int8_t* input */
|
| 244 |
+
outp); /* int8_t* output */
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
if (!per_channel) {
|
| 248 |
+
return xnn_setup_convolution2d_nhwc_qs8(
|
| 249 |
+
op, /* xnn_operator_t deconvolution_op */
|
| 250 |
+
nullptr, /* void workspace */
|
| 251 |
+
inp, /* const int8_t* input */
|
| 252 |
+
outp); /* int8_t* output */
|
| 253 |
+
} else { /* per_channel */
|
| 254 |
+
return xnn_setup_convolution2d_nhwc_qs8_qc8w(
|
| 255 |
+
op, /* xnn_operator_t deconvolution_op */
|
| 256 |
+
nullptr, /* void workspace */
|
| 257 |
+
inp, /* const int8_t* input */
|
| 258 |
+
outp); /* int8_t* output */
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
/*
|
| 264 |
+
* Series of wrapper functions to call xnn_create* and xnn_setup*
|
| 265 |
+
* functions for linear
|
| 266 |
+
*/
|
| 267 |
+
C10_ALWAYS_INLINE
|
| 268 |
+
enum xnn_status xnnp_create_fully_connected_nc(
|
| 269 |
+
size_t input_channels,
|
| 270 |
+
size_t output_channels,
|
| 271 |
+
size_t input_stride,
|
| 272 |
+
size_t output_stride,
|
| 273 |
+
int8_t input_zero_point,
|
| 274 |
+
float input_scale,
|
| 275 |
+
int8_t kernel_zero_point,
|
| 276 |
+
float kernel_scale,
|
| 277 |
+
const int8_t* kernel,
|
| 278 |
+
const int32_t* bias,
|
| 279 |
+
int8_t output_zero_point,
|
| 280 |
+
float output_scale,
|
| 281 |
+
int8_t output_min,
|
| 282 |
+
int8_t output_max,
|
| 283 |
+
uint32_t flags,
|
| 284 |
+
xnn_operator_t* fully_connected_op_out) {
|
| 285 |
+
/* Symmetric quantization forces kzp = 0 */
|
| 286 |
+
TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero."
|
| 287 |
+
"But got: ", kernel_zero_point);
|
| 288 |
+
return xnn_create_fully_connected_nc_qs8(
|
| 289 |
+
input_channels, /* size_t input_channels */
|
| 290 |
+
output_channels, /* size_t output_channels */
|
| 291 |
+
input_stride, /* size_t input_stride */
|
| 292 |
+
output_stride, /* size_t output_stride */
|
| 293 |
+
input_zero_point, /* int8_t input_zero_point */
|
| 294 |
+
input_scale, /* float input_scale */
|
| 295 |
+
kernel_scale, /* float kernel_scale */
|
| 296 |
+
kernel, /* const int8_t* kernel */
|
| 297 |
+
bias, /* const int32_t* bias */
|
| 298 |
+
output_zero_point, /* int8_t output_zero_point */
|
| 299 |
+
output_scale, /* float output_scale */
|
| 300 |
+
output_min, /* int8_t output_min */
|
| 301 |
+
output_max, /* int8_t output_max */
|
| 302 |
+
flags, /* uint32_t flags */
|
| 303 |
+
nullptr, /* xnn_caches_t caches */
|
| 304 |
+
nullptr, /* xnn_weights_cache_t */
|
| 305 |
+
fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
C10_ALWAYS_INLINE
|
| 309 |
+
enum xnn_status xnnp_reshape_fully_connected_nc(
|
| 310 |
+
xnn_operator_t fully_connected_op,
|
| 311 |
+
size_t batch_size,
|
| 312 |
+
pthreadpool_t threadpool) {
|
| 313 |
+
return xnn_reshape_fully_connected_nc_qs8(
|
| 314 |
+
fully_connected_op, /* xnn_operator_t fully_connected_op */
|
| 315 |
+
batch_size, /* size_t batch_size */
|
| 316 |
+
threadpool); /* pthreadpool_t threadpool */
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
C10_ALWAYS_INLINE
|
| 320 |
+
enum xnn_status xnnp_setup_fully_connected_nc(
|
| 321 |
+
xnn_operator_t fully_connected_op,
|
| 322 |
+
const int8_t* input,
|
| 323 |
+
int8_t* output) {
|
| 324 |
+
return xnn_setup_fully_connected_nc_qs8(
|
| 325 |
+
fully_connected_op, /* xnn_operator_t fully_connected_op */
|
| 326 |
+
input, /* const int8_t* input */
|
| 327 |
+
output /* int8_t* output */
|
| 328 |
+
);
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
} // namespace xnnp_utils
|
| 332 |
+
} // namespace native
|
| 333 |
+
} // namespace at
|
| 334 |
+
|
| 335 |
+
#endif // USE_XNNPACK
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
namespace native {
|
| 7 |
+
namespace mobile {
|
| 8 |
+
|
| 9 |
+
Tensor allocate_padded_contiguous_if_needed(
|
| 10 |
+
const Tensor& input,
|
| 11 |
+
c10::MemoryFormat memory_format);
|
| 12 |
+
|
| 13 |
+
// TODO: Remove this function when at::native::empty() is modified to accept a
|
| 14 |
+
// custom memory allocator.
|
| 15 |
+
|
| 16 |
+
at::Tensor empty_with_tail_padding(
|
| 17 |
+
IntArrayRef size,
|
| 18 |
+
const caffe2::TypeMeta dtype,
|
| 19 |
+
c10::MemoryFormat memory_format,
|
| 20 |
+
c10::optional<DimnameList> maybe_names);
|
| 21 |
+
|
| 22 |
+
} // namespace mobile
|
| 23 |
+
} // namespace native
|
| 24 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/ArrayRef.h>
|
| 4 |
+
#include <vector>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
namespace native {
|
| 8 |
+
|
| 9 |
+
template <typename T>
|
| 10 |
+
inline std::vector<T> _expand_param_if_needed(
|
| 11 |
+
ArrayRef<T> list_param,
|
| 12 |
+
const char* param_name,
|
| 13 |
+
int64_t expected_dim) {
|
| 14 |
+
if (list_param.size() == 1) {
|
| 15 |
+
return std::vector<T>(expected_dim, list_param[0]);
|
| 16 |
+
} else if ((int64_t)list_param.size() != expected_dim) {
|
| 17 |
+
std::ostringstream ss;
|
| 18 |
+
ss << "expected " << param_name << " to be a single integer value or a "
|
| 19 |
+
<< "list of " << expected_dim << " values to match the convolution "
|
| 20 |
+
<< "dimensions, but got " << param_name << "=" << list_param;
|
| 21 |
+
AT_ERROR(ss.str());
|
| 22 |
+
} else {
|
| 23 |
+
return list_param.vec();
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
inline std::vector<int64_t> expand_param_if_needed(
|
| 28 |
+
IntArrayRef list_param,
|
| 29 |
+
const char* param_name,
|
| 30 |
+
int64_t expected_dim) {
|
| 31 |
+
return _expand_param_if_needed(list_param, param_name, expected_dim);
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
inline std::vector<c10::SymInt> expand_param_if_needed(
|
| 35 |
+
SymIntArrayRef list_param,
|
| 36 |
+
const char* param_name,
|
| 37 |
+
int64_t expected_dim) {
|
| 38 |
+
return _expand_param_if_needed(list_param, param_name, expected_dim);
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
} // namespace native
|
| 42 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
#include <memory>
|
| 5 |
+
#include <mutex>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
|
| 9 |
+
// Hashing machinery for Params
|
| 10 |
+
// Fowler–Noll–Vo hash function
|
| 11 |
+
// see
|
| 12 |
+
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
|
| 13 |
+
template <typename Params>
|
| 14 |
+
struct ParamsHash {
|
| 15 |
+
// Params must be a POD because we read out its memory
|
| 16 |
+
// contents as char* when hashing
|
| 17 |
+
static_assert(std::is_standard_layout_v<Params>, "Params is not POD");
|
| 18 |
+
|
| 19 |
+
size_t operator()(const Params& params) const {
|
| 20 |
+
auto ptr = reinterpret_cast<const uint8_t*>(¶ms);
|
| 21 |
+
uint32_t value = 0x811C9DC5;
|
| 22 |
+
for (const auto i : c10::irange(sizeof(Params))) {
|
| 23 |
+
value ^= ptr[i];
|
| 24 |
+
value *= 0x01000193;
|
| 25 |
+
}
|
| 26 |
+
return (size_t)value;
|
| 27 |
+
}
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
template <typename Params>
|
| 31 |
+
struct ParamsEqual {
|
| 32 |
+
// Params must be a POD because we read out its memory
|
| 33 |
+
// contents as char* when comparing
|
| 34 |
+
static_assert(std::is_standard_layout_v<Params>, "Params is not POD");
|
| 35 |
+
|
| 36 |
+
bool operator()(const Params& a, const Params& b) const {
|
| 37 |
+
auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
|
| 38 |
+
auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
|
| 39 |
+
return memcmp(ptr1, ptr2, sizeof(Params)) == 0;
|
| 40 |
+
}
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
// Provide explicit byte-for-byte constructors to avoid uwittingly leaving
|
| 44 |
+
// padding bytes unitialized (e.g., when passing Params by value)
|
| 45 |
+
template <typename T>
|
| 46 |
+
struct ParamsWrapper {
|
| 47 |
+
T pod;
|
| 48 |
+
static_assert(
|
| 49 |
+
std::is_standard_layout_v<T>,
|
| 50 |
+
"ParamsWrapper cannot wrap non-POD data");
|
| 51 |
+
|
| 52 |
+
ParamsWrapper() {
|
| 53 |
+
memset(&(this->pod), 0, sizeof(this->pod));
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
ParamsWrapper(const ParamsWrapper& other) {
|
| 57 |
+
memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
ParamsWrapper(ParamsWrapper&& other) noexcept {
|
| 61 |
+
memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
ParamsWrapper& operator=(const ParamsWrapper& other) {
|
| 65 |
+
memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
|
| 66 |
+
return *this;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
ParamsWrapper& operator=(ParamsWrapper&& other) noexcept {
|
| 70 |
+
memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
|
| 71 |
+
return *this;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
inline friend bool operator==(
|
| 75 |
+
const ParamsWrapper& lhs,
|
| 76 |
+
const ParamsWrapper& rhs) noexcept {
|
| 77 |
+
auto ptr1 = reinterpret_cast<const uint8_t*>(&(lhs.pod));
|
| 78 |
+
auto ptr2 = reinterpret_cast<const uint8_t*>(&(rhs.pod));
|
| 79 |
+
return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0;
|
| 80 |
+
}
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
// Wrapped version: this allows the outer struct to have custom copy and move
|
| 84 |
+
// constructors for additional safety
|
| 85 |
+
template <typename ParamsWrapper>
|
| 86 |
+
struct ParamsWrapperHash {
|
| 87 |
+
// Params must be a POD because we read out its memory
|
| 88 |
+
// contents as char* when hashing
|
| 89 |
+
static_assert(
|
| 90 |
+
std::is_standard_layout_v<decltype(ParamsWrapper::pod)>,
|
| 91 |
+
"ParamsWrapper cannot wrap non-POD data");
|
| 92 |
+
|
| 93 |
+
size_t operator()(const ParamsWrapper& params_wrapper) const {
|
| 94 |
+
auto ptr = reinterpret_cast<const uint8_t*>(&(params_wrapper.pod));
|
| 95 |
+
uint32_t value = 0x811C9DC5;
|
| 96 |
+
for (const auto i : c10::irange(sizeof(params_wrapper.pod))) {
|
| 97 |
+
value ^= ptr[i];
|
| 98 |
+
value *= 0x01000193;
|
| 99 |
+
}
|
| 100 |
+
return (size_t)value;
|
| 101 |
+
}
|
| 102 |
+
};
|
| 103 |
+
|
| 104 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cpu {
|
| 19 |
+
|
| 20 |
+
TORCH_API void _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale);
|
| 21 |
+
|
| 22 |
+
} // namespace cpu
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cpu {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::vector<at::Tensor> _histogramdd_bin_edges(const at::Tensor & self, at::IntArrayRef bins, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false);
|
| 21 |
+
|
| 22 |
+
} // namespace cpu
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_from_bin_tensors_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _histogramdd_from_bin_tensors {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &, at::TensorList, const c10::optional<at::Tensor> &, bool);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_histogramdd_from_bin_tensors")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight, bool density);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight, bool density);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _histogramdd_from_bin_tensors_out {
|
| 29 |
+
using schema = at::Tensor & (const at::Tensor &, at::TensorList, const c10::optional<at::Tensor> &, bool, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_histogramdd_from_bin_tensors")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)")
|
| 35 |
+
static at::Tensor & call(const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight, bool density, at::Tensor & out);
|
| 36 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight, bool density, at::Tensor & out);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
|
| 26 |
+
inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _thnn_differentiable_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const c10::optional<at::Tensor> & input_bias, const c10::optional<at::Tensor> & hidden_bias) {
|
| 27 |
+
return at::_ops::_thnn_differentiable_gru_cell_backward::call(grad_hy, input_gates, hidden_gates, hx, input_bias, hidden_bias);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_fused_gru_cell_backward.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
|
| 26 |
+
inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _thnn_fused_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) {
|
| 27 |
+
return at::_ops::_thnn_fused_gru_cell_backward::call(grad_hy, workspace, has_bias);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
|
| 31 |
+
inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_gru_cell_backward_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) {
|
| 32 |
+
return at::_ops::_thnn_fused_gru_cell_backward_out::call(grad_hy, workspace, has_bias, out0, out1, out2, out3, out4);
|
| 33 |
+
}
|
| 34 |
+
// aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
|
| 35 |
+
inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_gru_cell_backward_outf(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) {
|
| 36 |
+
return at::_ops::_thnn_fused_gru_cell_backward_out::call(grad_hy, workspace, has_bias, out0, out1, out2, out3, out4);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace meta {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor asin(const at::Tensor & self);
|
| 21 |
+
TORCH_API at::Tensor & asin_out(at::Tensor & out, const at::Tensor & self);
|
| 22 |
+
TORCH_API at::Tensor & asin_outf(const at::Tensor & self, at::Tensor & out);
|
| 23 |
+
TORCH_API at::Tensor & asin_(at::Tensor & self);
|
| 24 |
+
|
| 25 |
+
} // namespace meta
|
| 26 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API batch_norm_backward_reduce {
|
| 18 |
+
using schema = ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, bool, bool, bool);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::batch_norm_backward_reduce")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)")
|
| 24 |
+
static ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g);
|
| 25 |
+
static ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API batch_norm_backward_reduce_out {
|
| 29 |
+
using schema = ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, bool, bool, bool, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::batch_norm_backward_reduce")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))")
|
| 35 |
+
static ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3);
|
| 36 |
+
static ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_or.h
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/bitwise_or_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
| 26 |
+
inline at::Tensor & bitwise_or_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
|
| 27 |
+
return at::_ops::bitwise_or_Tensor_out::call(self, other, out);
|
| 28 |
+
}
|
| 29 |
+
// aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
| 30 |
+
inline at::Tensor & bitwise_or_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
|
| 31 |
+
return at::_ops::bitwise_or_Tensor_out::call(self, other, out);
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
| 35 |
+
inline at::Tensor & bitwise_or_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
|
| 36 |
+
return at::_ops::bitwise_or_Scalar_out::call(self, other, out);
|
| 37 |
+
}
|
| 38 |
+
// aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
| 39 |
+
inline at::Tensor & bitwise_or_outf(const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
|
| 40 |
+
return at::_ops::bitwise_or_Scalar_out::call(self, other, out);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
// aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
|
| 44 |
+
inline at::Tensor bitwise_or(const at::Tensor & self, const at::Scalar & other) {
|
| 45 |
+
return at::_ops::bitwise_or_Scalar::call(self, other);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// aten::bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
|
| 49 |
+
inline at::Tensor bitwise_or(const at::Scalar & self, const at::Tensor & other) {
|
| 50 |
+
return at::_ops::bitwise_or_Scalar_Tensor::call(self, other);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
// aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
|
| 54 |
+
inline at::Tensor bitwise_or(const at::Tensor & self, const at::Tensor & other) {
|
| 55 |
+
return at::_ops::bitwise_or_Tensor::call(self, other);
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
| 59 |
+
inline at::Tensor & bitwise_or_out(at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
|
| 60 |
+
return at::_ops::bitwise_or_Scalar_Tensor_out::call(self, other, out);
|
| 61 |
+
}
|
| 62 |
+
// aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
| 63 |
+
inline at::Tensor & bitwise_or_outf(const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
|
| 64 |
+
return at::_ops::bitwise_or_Scalar_Tensor_out::call(self, other, out);
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautogradnonfunctional {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor cat(const at::ITensorListRef & tensors, int64_t dim=0);
|
| 21 |
+
|
| 22 |
+
} // namespace compositeexplicitautogradnonfunctional
|
| 23 |
+
} // namespace at
|