diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bee89d49b935b7394c6759977480a158c79a5b1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25c53edd304c9729d3d61aeb283cd5acea01033c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21d4a4bb34e6126a9e6728372a07f446a6c1c95c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11e89f136afab0492afd4cf7fbdb6be990e4e710 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d98095909a1410c428f1a5b577f33bf8ba9c5bc9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1af69a0b5b5373f54b8a7237a1846b805caf0035 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..d86565885370d9f9cc8c8adaefcc1c30e240762d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/binary_folding.py @@ -0,0 +1,277 @@ +import functools +import itertools + +import torch +from ..._dynamo.utils import counters + +from ..pattern_matcher import Arg, CallFunction, KeywordArg +from .freezing_patterns import register_binary_folding_pattern + +aten = torch.ops.aten +prims = torch.ops.prims + + +def mark_mixed_dtype_conv(conv): + conv_dtype = conv.meta["val"].dtype + if conv_dtype not in (torch.float16, torch.bfloat16): + return + + if not len(conv.users) == 1: + return + + conv_user = next(iter(conv.users.keys())) + if not isinstance(conv_user.meta["val"], torch.Tensor): + return + + if not conv_user.meta["val"].dtype == torch.float32: + return + + while conv_user.target in _binary_ops: + if not len(conv_user.users) == 1: + return + + conv_user = next(iter(conv_user.users.keys())) + + if not ( + conv_user.target == prims.convert_element_type.default + and conv_user.args[1] == conv_dtype + ): + return + + conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype + + +def mark_mixed_dtype_allowed_convs(gm): + """ + Mark convolutions which we will binary fold even with mixed precision constants. We constant fold in the higher precision + for better accuracy and then recover the original precision after. + """ + for node in gm.graph.nodes: + if node.target is aten.convolution.default: + mark_mixed_dtype_conv(node) + + +def recover_original_precision_folded_convs(gm): + """ + After binary folding conv weights and biases to a higher dtype, recover the original precision they were in. + """ + graph = gm.graph + convs = [node for node in graph.nodes if node.target is aten.convolution.default] + for node in convs: + orig_dtype = node.meta.get("_allow_conv_mixed_dtype_folding", None) + if orig_dtype is None: + continue + + with graph.inserting_before(node): + for idx in [1, 2]: + old_input = node.args[idx] + if old_input is None: + continue + + new_input = graph.create_node( + "call_function", + prims.convert_element_type.default, + (old_input, orig_dtype), + ) + node.replace_input_with(old_input, new_input) + + +_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor] + + +@functools.lru_cache(None) +def binary_folding_init(): + _conv_args = [Arg() for _ in range(9)] + _computation_ops = [aten.convolution.default] + _computation_calls = [CallFunction(aten.convolution.default, *_conv_args, _users=1)] + + """ + In order to fuse add/sub/mul/div with conv, the dimensions of its + constant tensor must satisfy the following: + - with resizing, broadcast to w/ weight/bias tensor shape + - broadcast to the conv output shape + It needs to have a shape that can resize to weight/bias + tensor shape because we need to run the op with the conv + weights/bias without changing their sizes. + It needs to broadcast to the conv output shape so that we do + accidentally change the shape of op output by pre-fusing it + compared to eager. + The only dimension value shared by weight/bias/conv output + is they all contain a dim with value = channels-out. In the + conv output tensor, this is in the second dimension, + so the pointwise op tensor may have a second dimension of + value == channels-out, but all the other dimensions have to be 1 + """ + + def _op_not_broadcasting_with_conv(weight_tensor, other_tensor): + # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp + weight_shape = weight_tensor.shape + other_shape = other_tensor.shape + if len(weight_shape) < len(other_shape): + return False + if len(weight_shape) == len(other_shape) + 1: + # weight shape is [o, i, *], other_shape is [o, 1...]. + for i in reversed(range(len(other_shape))): + if i == 0 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + else: + # weight shape is [o, i, *], other_shape is [1, i, *] + for i in reversed(range(len(other_shape))): + if i == 1 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + return True + + def _check_conv_and_broadcast_op(conv_node, other): + # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp. + # conv.weight + if conv_node.args[1].op != "get_attr": + return False + # conv.bias + if conv_node.args[1] is not None and conv_node.args[1].op != "get_attr": + return False + if ( + not isinstance(other, int) + and not isinstance(other, float) + and other.op != "get_attr" + ): + return False + + if not len(conv_node.args[1].users) == 1: + return False + + weight_meta_value = conv_node.args[1].meta.get("val") + if weight_meta_value is None: + return False + # Avoid fusing op that causes type promotion + # restricting to float avoids int/float difficulties with scalar overload + if not weight_meta_value.is_floating_point(): + return False + if isinstance(other, torch.fx.Node) and other.op == "get_attr": + other_meta_value = other.meta.get("val") + if not other_meta_value.is_floating_point(): + return False + if ( + torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) + != weight_meta_value.dtype + ): + if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False): + return False + + if ( + other_meta_value.dtype != torch.float + and weight_meta_value.dtype not in (torch.float16, torch.bfloat16) + ): + return False + + if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value): + return False + else: + # TODO: support scalar case + return False + + return True + + def _is_foldable_pattern(match): + binary_node = match.output_node() + computation_node = binary_node.args[0] + other = binary_node.args[1] + if binary_node.args[0].target not in _computation_ops: + computation_node = binary_node.args[1] + other = binary_node.args[0] + if binary_node.args[0].target == aten.convolution.default: + return _check_conv_and_broadcast_op(computation_node, other) + + return False + + def resize_scalar_or_tensor_to_shape(graph, other, shape): + # TODO: support scalar case + if other.meta.get("val").numel() == 1: + # expand errors if the shape input has less # dims than the tensor input + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, (1,)), + ) + res = graph.create_node( + "call_function", + aten.expand.default, + (res, shape), + ) + else: + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, shape), + ) + return res + + def _create_new_conv_node(graph, conv_node, binary_node, other): + assert conv_node.target == aten.convolution.default + conv_args = list(conv_node.args) + weight_meta_value = conv_node.args[1].meta.get("val") + bias = conv_args[2] + if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, other, (weight_meta_value.size(0),) + ) + new_bias = graph.create_node( + "call_function", + binary_node.target, + (0 if bias is None else bias, other_reshape), + ) + conv_args[2] = new_bias + else: + assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor] + weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))] + weight_broadcast_shape[0] = weight_meta_value.size(0) + other_reshape1 = resize_scalar_or_tensor_to_shape( + graph, other, tuple(weight_broadcast_shape) + ) + new_weight = graph.create_node( + "call_function", binary_node.target, (conv_args[1], other_reshape1) + ) + new_weight.meta.update(conv_args[1].meta) + conv_args[1] = new_weight + if bias is not None: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, other, (weight_meta_value.size(0),) + ) + new_bias = graph.create_node( + "call_function", binary_node.target, (bias, other_reshape) + ) + new_bias.meta.update(bias.meta) + conv_args[2] = new_bias + return graph.create_node("call_function", conv_node.target, tuple(conv_args)) + + for _computation_call, binary_op in itertools.product( + _computation_calls, _binary_ops + ): + + @register_binary_folding_pattern( + CallFunction(binary_op, _computation_call, KeywordArg("other")), + extra_check=_is_foldable_pattern, + ) + def folded_op(match, *args, **kwargs): + counters["inductor"]["binary_folding"] += 1 + other = kwargs.get("other") + binary_node = match.output_node() + computation_node = ( + binary_node.args[0] + if binary_node.args[0].target in _computation_ops + else binary_node.args[1] + ) + graph = match.graph + with graph.inserting_before(binary_node): + # TODO: support linear? + assert computation_node.target == aten.convolution.default + new_computation_node = _create_new_conv_node( + graph, computation_node, binary_node, other + ) + binary_node.replace_all_uses_with(new_computation_node) + new_computation_node.meta.update(computation_node.meta) + graph.erase_node(binary_node) + graph.erase_node(computation_node) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..14f484255bea239498ed11334c93fc9311e4120f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -0,0 +1,221 @@ +import logging +from typing import List, Optional + +import torch +from torch import Tensor +from torch._dynamo.utils import counters +from torch._inductor import utils + +from ..pattern_matcher import ( + Arg, + CallFunction, + config_flag, + Ignored, + Match, + register_graph_pattern, +) +from .post_grad import decompose_mm_pass + +aten = torch.ops.aten +log = logging.getLogger(__name__) + +# TODO: need a better strategy for decomposing mm +MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 +MAX_OTHER_DIMENSION_DECOMPOSITION = 32 + + +def check_device(a: Tensor, b: Tensor) -> bool: + return a.is_cuda and b.is_cuda + + +def should_decompose_common( + mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None +) -> bool: + return ( + torch._inductor.config.decompose_mem_bound_mm + and check_device(mat1, mat2) + and not utils.any_is_symbolic(mat1, mat2, input) + ) + + +def should_decompose_bmm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + if not should_decompose_common(mat1, mat2): + return False + else: + if len(mat1.shape) != 3 or len(mat2.shape) != 3: + return False + if mat1.shape[0] < MIN_FIRST_DIMENSION_DECOMPOSITION: + return False + # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION + if (mat1.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION) + ( + mat1.shape[2] < MAX_OTHER_DIMENSION_DECOMPOSITION + ) + (mat2.shape[2] < MAX_OTHER_DIMENSION_DECOMPOSITION) < 2: + return False + return True + + +def should_decompose_mm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + return ( + should_decompose_common(mat1, mat2) + and len(mat1.shape) == 2 + and len(mat2.shape) == 2 + and mat1.shape[0] >= MIN_FIRST_DIMENSION_DECOMPOSITION + and mat2.shape[0] < MAX_OTHER_DIMENSION_DECOMPOSITION + and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION + ) + + +def should_decompose_mmt(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + return ( + should_decompose_common(mat1, mat2) + and len(mat1.shape) == 2 + and len(mat2.shape) == 2 + and mat1.shape[0] >= MIN_FIRST_DIMENSION_DECOMPOSITION + and mat1.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION + and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION + ) + + +def should_decompose_mm_largek(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + return ( + should_decompose_common(mat1, mat2) + and len(mat1.shape) == 2 + and len(mat2.shape) == 2 + and mat1.shape[1] >= MIN_FIRST_DIMENSION_DECOMPOSITION + and mat1.shape[0] < MAX_OTHER_DIMENSION_DECOMPOSITION + and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION + ) + + +def is_node_meta_valid(node: torch.fx.Node): + return "val" in node.meta + + +def print_decompose_pattern(match: Match, inputs: List[torch.fx.Node]): + node = match.nodes[-1] + log.debug( + "Decompose %s with input shape: %s", + node.target, + ", ".join( + str(input.meta["val"].shape) if "val" in input.meta else "None" + for input in inputs + ), + ) + + +@register_graph_pattern( + CallFunction(aten.bmm, Arg(), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2) + + if should_decompose_bmm(mat1, mat2): + counters["inductor"]["decompose_bmm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.addmm, Arg(), Arg(), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_addmm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, + mat3: torch.fx.Node, +): + def repl(mat1, mat2, mat3): + return torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2) + mat1 + + if should_decompose_mm(mat2, mat3): + counters["inductor"]["decompose_addmm"] += 1 + match.replace_by_example(repl, [mat1, mat2, mat3]) + print_decompose_pattern(match, [mat1, mat2, mat3]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, CallFunction(aten.permute, Arg(), Ignored()), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_mmt( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, None] * mat2[:, None, :], dim=0) + + if should_decompose_mmt(mat1, mat2): + counters["inductor"]["decompose_mmt"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, Arg(), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_mm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2) + + if should_decompose_mm(mat1, mat2): + counters["inductor"]["decompose_mm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, Arg(), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_mm_large_k( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + mat1 = mat1.permute(1, 0) + return torch.sum(mat1[:, :, None] * mat2[:, None, :], dim=0) + + if should_decompose_mm_largek(mat1, mat2): + counters["inductor"]["decompose_mm_large_k"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + return diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a5b7ccd6854c68e1002206fdc35a5969bb20d9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn + +from torch._dynamo.utils import counters +from torch._inductor import config as inductor_config +from torch.func import functional_call + +from ..pattern_matcher import CallModuleVarArgs, Match, register_graph_pattern + +from .pre_grad import efficient_conv_bn_eval_pass + + +def efficient_conv_bn_eval( + bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Tune-Mode ConvBN Blocks For Efficient Transfer Learning" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module. + conv (nn.modules.conv._ConvNd): a conv module + x (torch.Tensor): Input feature map. + """ + + assert bn.running_var is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv.weight + if conv.bias is not None: + bias_on_the_fly = conv.bias + else: + bias_on_the_fly = torch.zeros_like(bn.running_var) + + if bn.weight is not None: + bn_weight = bn.weight + else: + bn_weight = torch.ones_like(bn.running_var) + + if bn.bias is not None: + bn_bias = bn.bias + else: + bn_bias = torch.zeros_like(bn.running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv.weight.ndim - 1) + if isinstance(conv, nn.modules.conv._ConvTransposeNd): + # for transposed conv, the C_out dimension should at index 1. + target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn.running_mean + ) + + input = x + params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly} + output = functional_call(conv, params, input) + return output + + +@register_graph_pattern( + CallModuleVarArgs( + [ + nn.modules.batchnorm._BatchNorm, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.SyncBatchNorm, + ], + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs): + # We matched a BN node + bn_node = match.nodes[0] + graph = match.graph + gm = graph.owning_module + bn_mod = getattr(gm, bn_node.target) # type: ignore[arg-type] + + # We can only use efficient conv-bn for eval mode with track_running_stats + if not bn_mod.track_running_stats or bn_mod.training: + return + + # Check if the input is Conv + if bn_node.args: + input_node = bn_node.args[0] + else: + input_node = bn_node.kwargs["input"] + if input_node.op != "call_module": # type: ignore[union-attr] + return + if not hasattr(gm, input_node.target): # type: ignore[arg-type, union-attr] + return + input_mod = getattr(gm, input_node.target) # type: ignore[arg-type, union-attr] + supported_convs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ] + if not any(isinstance(input_mod, cls) for cls in supported_convs): + return + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + # Find a pair of conv and bn computation nodes to optimize. + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(conv_node): + # create `get_attr` node to access modules + # note that we directly call `create_node` to fill the `name` + # argument. `graph.get_attr` and + # `graph.call_function` does not allow the `name` argument. + conv_get_node = graph.create_node( + op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr] + ) + bn_get_node = graph.create_node( + op="get_attr", target=bn_node.target, name="get_bn" + ) + if conv_node.args: # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + else: + conv_input = conv_node.kwargs["input"] # type: ignore[union-attr] + # prepare args for the fused function + args = (bn_get_node, conv_get_node, conv_input) + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval, + args=args, + name="efficient_conv_bn_eval", + ) + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..6c1f52f47022ed4c2e77611e06d42b9b2d84df41 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/freezing_patterns.py @@ -0,0 +1,212 @@ +import functools + +import torch +from torch._inductor.compile_fx import fake_tensor_prop +from ..._dynamo.utils import counters + +from .. import config +from ..pattern_matcher import ( + _return_true, + CallFunction, + fwd_only, + Ignored, + init_once_fakemode, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + register_replacement, + stable_topological_sort, +) + +aten = torch.ops.aten + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + +binary_folding_pass = PatternMatcherPass() + + +def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs): + """ + Passes that are applied to the graph to freeze pass. + """ + + from ..freezing import constant_fold + + lazy_init() + # We need a few rounds of binary folding to get rid of all the + # unnecessary nodes, but may need a good method to chose the rounds number. + # works like: conv+binary+binary. + binary_folding = counters["inductor"]["binary_folding"] + fake_tensor_prop(gm, aot_example_inputs, True) + + torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_convs(gm) + for _ in range(4): + constant_fold(gm) + # Make sure meta['val'] is properly set for all nodes + fake_tensor_prop(gm, aot_example_inputs, True) + binary_folding_pass.apply(gm.graph) # type: ignore[arg-type] + # If we don't have binary folding, we don't need to run the pass again. + # TODO: remove the need to run fake_tensor_prop on the whole model. + if counters["inductor"]["binary_folding"] == binary_folding: + break + binary_folding = counters["inductor"]["binary_folding"] + + torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_convs(gm) + + constant_fold(gm) + fake_tensor_prop(gm, aot_example_inputs, True) + + for pattern in pass_patterns: + pattern.apply(gm.graph) # type: ignore[arg-type] + + # The CPU weight packing always assume the conv's weight is channels last, + # So make sure the layout_optimization is on when doing it. + if ( + torch._C._has_mkldnn + and config.cpp.weight_prepack + and config.layout_optimization + ): + from .mkldnn_fusion import _eliminate_duplicate_packed_nodes + + _eliminate_duplicate_packed_nodes(gm) + + stable_topological_sort(gm.graph) + gm.recompile() + gm.graph.lint() + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn and config.cpp.weight_prepack: + from .mkldnn_fusion import _mkldnn_weight_pack_init + + _mkldnn_weight_pack_init() + + from .binary_folding import binary_folding_init + + addmm_patterns_init() + binary_folding_init() + + +def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0): + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=pass_patterns[pass_number], + ) + + +def register_binary_folding_pattern(pattern, extra_check=_return_true): + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=binary_folding_pass, + ) + + +@functools.lru_cache(None) +def addmm_patterns_init(): + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False) + + def check_concat_weights(match): + weights = [ + match.kwargs["w1"], + match.kwargs["w2"], + ] + if "w3" in match.kwargs: + weights.append(match.kwargs["w3"]) + + return all( + w.op == "get_attr" and w.meta["val"].shape == weights[0].meta["val"].shape + for w in weights + ) + + def matmul_fuse_pattern(inp, w1, w2, w3): + return (inp @ w1, inp @ w2, inp @ w3) + + def matmul_replacement(inp, w1, w2, w3): + cat_t = torch.cat((w1, w2, w3), dim=1) + mm = inp @ cat_t + return mm.chunk(3, dim=1) + + register_replacement( + matmul_fuse_pattern, + matmul_replacement, + [val(), val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3"), + ) + + def matmul_fuse_pattern_two(inp, w1, w2): + return (inp @ w1, inp @ w2) + + def matmul_replacement_two(inp, w1, w2): + cat_t = torch.cat((w1, w2), dim=1) + mm = inp @ cat_t + return mm.chunk(2, dim=1) + + register_replacement( + matmul_fuse_pattern_two, + matmul_replacement_two, + [val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2"), + ) + + def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3): + return ( + aten.addmm(b1, inp, w1), + aten.addmm(b2, inp, w2), + aten.addmm(b3, inp, w3), + ) + + def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_b = torch.cat((b1, b2, b3)) + return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1) + + register_replacement( + addmm_fuse_pattern_second, + addmm_fuse_replacement_second, + [val() for _ in range(7)], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"), + ) + + +def same_dtype(match): + return match.output_node().args[0].meta["val"].dtype == match.kwargs["dtype"] + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + Ignored(), + KeywordArg("dtype"), + ), + pass_dict=pass_patterns[0], + extra_check=same_dtype, +) +def unnecessary_dtype_convert(match: Match, **kwargs): + """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding""" + graph = match.graph + node = match.output_node() + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4baf12d4eae35c4dab6eae285cf313e52ae5ec6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/numeric_utils.py @@ -0,0 +1,210 @@ +import gc +import logging +import os +import random +import traceback + +import numpy + +import torch +import torch.optim as optim + +from .. import config + +logger: logging.Logger = logging.getLogger(__name__) + +MAIN_RANDOM_SEED = 1337 + +# Set the CUBLAS_WORKSPACE_CONFIG environment variable +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + +# If the two forward functions involve any non-deterministic operations, +# such as certain types of parallelism or asynchronous execution, +# this can also lead to different outputs. +def set_deterministic() -> None: + """Make torch manual seed deterministic.""" + + torch.manual_seed(MAIN_RANDOM_SEED) + random.seed(MAIN_RANDOM_SEED) + numpy.random.seed(MAIN_RANDOM_SEED) + torch.use_deterministic_algorithms(True) + + +def clean_memory() -> None: + """Clean memory to avoid OOM.""" + gc.collect() + torch.cuda.empty_cache() + + +# We compare the numerical results before and after pre/post grad fx passes +# transformation to make sure the numerical results are the same. +def compare_dict_tensors(dict_base, dict_control, precision): + if len(set(dict_base.keys())) != len(set(dict_control.keys())): + logger.warning("Mismatch keys found before and after pre/post grad fx passes.") + logger.debug("keys before pre/post grad fx passes %s", dict_base.keys()) + logger.debug("keys after pre/post grad fx passes %s", dict_control.keys()) + return False + is_allclose = True + for key in dict_base.keys(): + if key not in dict_control: + logger.warning( + "Mismatch parameter name %s does not exist after pre/post grad fx passes", + key, + ) + # Some parameters have `None`, and not every param has a valid .grad field, we skip them + if dict_base[key] is None or dict_control[key] is None: + continue + if not torch.allclose( + dict_base[key], + dict_control[key], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.warning( + "Mismatch parameter values found before and after pre/post grad fx passes." + ) + logger.debug("value before pre/post grad fx passes %s", dict_base[key]) + logger.debug("value after pre/post grad fx passes %s", dict_control[key]) + is_allclose = False + return is_allclose + + +def compare_tuple_tensors(tuple_base, tuple_control, precision): + if len(tuple_base) != len(tuple_control): + logger.warning( + "Mismatch fw output length. before transformation: %s, after transformation: %s", + len(tuple_base), + len(tuple_control), + ) + return False + is_allclose = True + for i in range(len(tuple_base)): + # Some parameters have `None`, we skip them + if tuple_base[i] is None or tuple_control[i] is None: + continue + if not torch.allclose( + tuple_base[i], + tuple_control[i], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.debug( + "forward output before pre/post grad fx passes %s", tuple_base[i] + ) + logger.debug( + "forward output after pre/post grad fx passes %s", tuple_control[i] + ) + is_allclose = False + return is_allclose + + +def compare_parameters(model_base, model_control, precision): + return compare_dict_tensors( + dict(model_base.named_parameters()), + dict(model_control.named_parameters()), + precision, + ) + + +def compare_forward_output(pred_base, pred_control, precision): + return compare_tuple_tensors( + pred_base, + pred_control, + precision, + ) + + +def compare_gradients(model_base, model_control, precision): + grad_base = {key: param.grad for key, param in model_base.named_parameters()} + grad_pt2 = {key: param.grad for key, param in model_control.named_parameters()} + return compare_dict_tensors( + grad_base, + grad_pt2, + precision, + ) + + +def run_model( + model_base, model_control, model_input, num_iterations=10, precision=1e-4 +): + clean_memory() + for i in range(num_iterations): + logger.info("start %s iteration", i) + set_deterministic() + pred_base = model_base(*model_input) + set_deterministic() + pred_control = model_control(*model_input) + + res = compare_parameters(model_base, model_control, precision) + logger.info("compare parameters. Numerical result : %s", res) + + res = compare_forward_output(pred_base, pred_control, precision) + logger.info("compare loss/predict. Numerical result : %s", res) + # tensor may not have a grad_fn + try: + _ = pred_base[0].sum().backward(retain_graph=True) + _ = pred_control[0].sum().backward(retain_graph=True) + res = compare_gradients(model_base, model_control, precision) + logger.info("compare param grad. Numerical result : %s", res) + except Exception as e: + logger.exception("Exception %s when compare gradients", e) + traceback.print_exc() + + if config.fx_passes_numeric_check["requires_optimizer"]: + try: + optimizer_base = optim.SGD( + [param for name, param in model_base.named_parameters()], lr=0.01 + ) + optimizer_base.step() + + optimizer_control = optim.SGD( + [param for name, param in model_control.named_parameters()], lr=0.01 + ) + optimizer_control.step() + + res = compare_parameters(model_base, model_control, precision) + logger.info( + "compare parameters with optimizer added. Numerical result : %s", + res, + ) + except Exception as e: + logger.exception( + "Exception %s when optimizer is added to check parameter names", e + ) + traceback.print_exc() + else: + logger.warning( + "no parameter with optimizer to compare with length %s before transformation" + " and the length %s after transformation", + len(dict(model_base.named_parameters())), + len(dict(model_control.named_parameters())), + ) + + +def numeric_check_if_enabled( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations, + precision, +): + # need to topo-sort graphmodule before we run the model, + # otherwise it may fail as refer before def + # fail silently in order not to block the model run + try: + with torch.autograd.set_detect_anomaly(True): + run_model( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations=num_iterations, + precision=precision, + ) + except Exception as e: + logger.warning( + "Runtime numeric check failed in pre grad fx passes with error: %s", e + ) + traceback.print_exc() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..12fcef85a8a2cdb1d7a7b78eb371562a8849e5f3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/pad_mm.py @@ -0,0 +1,567 @@ +import functools +from typing import List, Optional, Set, Union + +import torch +from torch import Tensor +from torch._inductor import utils +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._mode_utils import no_dispatch +from torch.utils._triton import has_triton + +from ..pattern_matcher import ( + fwd_only, + joint_fwd_bwd, + Match, + MatchContext, + register_replacement, +) +from ..utils import is_view + +aten = torch.ops.aten + + +# This flag is only used for testing purpose. +# Changing it to True will ignore comparing do_bench times +# between original pattern and padded one. +_skip_do_bench_times = False + + +def fetch_fake_tensors(match, kwarg_names) -> List[Tensor]: + kwargs = match.kwargs + return [kwargs[name].meta["val"] for name in kwarg_names] + + +def unwrap_fake_args(*arg_names): + def decorator(func): + def wrapper(match): + fake_tensors = fetch_fake_tensors(match, arg_names) + return func(*fake_tensors) + + return wrapper + + return decorator + + +def get_alignment_size(x: Tensor) -> int: + if x.dtype == torch.float16 or x.dtype == torch.half or x.dtype == torch.bfloat16: + return 8 + elif x.dtype == torch.float32 or x.dtype == torch.float: + return 4 + else: + return 0 + + +def check_device(a: Tensor, b: Tensor) -> bool: + return a.is_cuda and b.is_cuda + + +def check_dtype(a: Tensor, b: Tensor) -> bool: + return a.is_floating_point() and b.is_floating_point() + + +def _result_layout_affects_graph_output(match: Match) -> bool: + """ + Check if the matched GEMM operation potentially affects the graph output strides. + returns True if the matched op's output buffer does not pass through functions which certainly + redefine the memory layout before being part of the graph output. + """ + + if match.ctx is not None: + assert isinstance(match.ctx, MatchContext) + search_node: torch.fx.Node = match.output_node() + else: + return True + + assert search_node is not None + seen: Set[torch.fx.Node] = set() + + def find_output(node: torch.fx.Node, is_start_node=False): + if not isinstance(node, torch.fx.Node): + return False + if node in seen: + return False + seen.add(node) + if node.op == "output": + return True + if node.op != "call_function": + return False + if not is_start_node and ( + (not isinstance(node.target, torch._ops.OpOverload)) + or (not is_view(node.target)) + ): + return False + if node.users is not None and len(node.users) > 0: + for n in node.users: + if find_output(n): + return True + return False + + return find_output(search_node, True) + + +def should_pad_common( + mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None +) -> bool: + # It's fine we have symbolic shapes or strides as long as they + # have hints. Later, we will make sure we only pad non-symbolic dimensions. + def valid_shape_and_stride(t: Optional[Tensor]) -> bool: + if t is None: + return True + + symbolic_cnt = 0 + for x in t.size(): + if isinstance(x, int): + continue + elif utils.is_symbolic(x): + if not x.node.has_hint(): + return False + symbolic_cnt += 1 + else: + return False + # filter out cases where all dimentions are symbolic + if symbolic_cnt == len(t.size()): + return False + return all( + isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint()) + for x in t.stride() + ) + + return ( + torch._inductor.config.shape_padding + and check_device(mat1, mat2) + and check_dtype(mat1, mat2) + and all(valid_shape_and_stride(t) for t in (mat1, mat2, input)) + ) + + +def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int: + # we don't pad x if it is symbolic + if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0: + return 0 + return int((x // alignment_size + 1) * alignment_size) - x + + +def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor: + if padded_length == 0: + return x + pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :]) + return torch.cat([x, pad], dim=dim) + + +def addmm_pattern( + input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float +) -> Tensor: + return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + +def should_pad_addmm(match: Match) -> bool: + if ( + torch._inductor.config.keep_output_stride + and _result_layout_affects_graph_output(match) + ): + return False + mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input")) + return should_pad_common(mat1, mat2, input) and should_pad_bench( + mat1, mat2, torch.ops.aten.addmm, input=input + ) + + +def addmm_replace( + input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0 +) -> Tensor: + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + + if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0: + return pad_addmm( + input, + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + beta, + alpha, + ) + + return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + +def pad_addmm( + input: Optional[Tensor], + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + beta=1.0, + alpha=1.0, +): + # addmm decomp with padding will go through pad_addmm multiple times if multiple dimensions are needed to be padded + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 1) + mat2 = pad_dim(mat2, k_padded_length, 0) + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 1) + elif m_padded_length != 0: + mat1 = pad_dim(mat1, m_padded_length, 0) + + # the add broadcasts, so we only pad if the dimension != 1 + if input is not None and k_padded_length == 0: + if n_padded_length != 0: + if input.dim() == 2 and input.shape[1] != 1: + input = pad_dim(input, n_padded_length, 1) + elif input.dim() == 1 and input.shape[0] != 1: + input = pad_dim(input, n_padded_length, 0) + elif m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1: + input = pad_dim(input, m_padded_length, 0) + + if k_padded_length != 0: + return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha) + elif n_padded_length != 0: + return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)[ + :, :-n_padded_length + ] + else: + return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)[ + :-m_padded_length, : + ] + + +def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: + denominator = M * K + N * K + M * N + if denominator == 0: + return False + arithmetic_intensity = (M * N * K) / denominator + + # Fails with AMD + try: + machine_balance = ( + 1000 * utils.get_device_tflops(dtype) + ) / utils.get_gpu_dram_gbps() + except Exception: + return True + + # dram_gbps might be underestimating bandwidth because of cache. + # if we estimate machine balance too low we might miss some speedups, + # if we extimate too high there will be unnecessary compilation time increase. + # TODO - finetune coefficient here. As a reference point, Triton mm model assumes + # 80% of reads are in cache and cache is 4x faster than dram_gbps + machine_balance = machine_balance * 0.5 + + return arithmetic_intensity > machine_balance + + +@functools.lru_cache(None) +def get_pad_cache(): + return torch._inductor.codecache.LocalCache() + + +def get_cached_should_pad(key): + return get_pad_cache().lookup(key) + + +def set_cached_should_pad(key, value): + return get_pad_cache().set_value(key, value=value) + + +def should_pad_bench_key( + mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None +) -> str: + def tensor_key(t): + return (t.shape, t.stride(), t.dtype) + + tf32_key = ( + None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32 + ) + key = ( + tensor_key(mat1), + tensor_key(mat2), + op, + input if input is None else tensor_key(input), + tf32_key, + ) + + return str(key) + + +def should_pad_bench( + mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None +) -> bool: + if not has_triton(): + return False + + do_bench = functools.partial( + utils.do_bench, + warmup=5, + ) + + with no_dispatch(): + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + m = mat1.shape[0] + k = mat1.shape[1] + n = mat2.shape[1] + + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + elif op is torch.ops.aten.bmm: + m = mat1.shape[1] + k = mat1.shape[2] + n = mat2.shape[2] + + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + else: + return False + + if m_padded_length == k_padded_length == n_padded_length == 0: + return False + + if not is_mm_compute_bound(m, k, n, mat1.dtype): + return False + + # We don't want to look up the cache for cases that are trivially false + # since it does file io + key = should_pad_bench_key(mat1, mat2, op, input) + + cached_pad = get_cached_should_pad(key) + if cached_pad is not None: + return cached_pad + + def realize_symbols(ds): + return [d if isinstance(d, int) else d.node.hint for d in ds] + + def realize_tensor(t): + if isinstance(t, FakeTensor): + size_hints = realize_symbols(t.size()) + stride_hint = realize_symbols(t.stride()) + real_size = ( + sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1 + ) + real_t = torch.randn(real_size, dtype=t.dtype, device=t.device) + return torch.as_strided(real_t, size_hints, stride_hint) + else: + return torch.randn_like(t) + + mat1 = realize_tensor(mat1) + mat2 = realize_tensor(mat2) + if op is torch.ops.aten.bmm or op is torch.ops.aten.mm: + ori_time = do_bench( + lambda: op(mat1, mat2), + ) + else: + if input is not None: + input = realize_tensor(input) + ori_time = do_bench( + lambda: op(input, mat1, mat2), + ) + + mat1_pad = torch.randn_like(mat1) + mat2_pad = torch.randn_like(mat2) + + if op is torch.ops.aten.addmm: + input_pad = None + if input is not None and input.is_cuda: + input_pad = torch.randn_like(input) + pad_time = do_bench( + lambda: pad_addmm( + input_pad, + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + ), + ) + elif op is torch.ops.aten.mm: + pad_time = do_bench( + lambda: pad_mm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + ), + ) + else: + pad_time = do_bench( + lambda: pad_bmm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + ), + ) + + # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable + # tradeoff between performance improvement from shape padding and overhead from additional memory ops + # TODO: Build a learned model which would be better than this heuristic + should_pad = _skip_do_bench_times or ori_time > pad_time * 1.1 + set_cached_should_pad(key, should_pad) + + return should_pad + + +def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.mm(mat1, mat2) + + +def should_pad_mm(match: Match) -> bool: + if ( + torch._inductor.config.keep_output_stride + and _result_layout_affects_graph_output(match) + ): + return False + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + mat1, mat2, torch.ops.aten.mm + ) + + +def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + + return pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length) + + +def pad_mm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, +) -> Tensor: + # mm_replace will go through pad_mm multiple times if multiple dimensions are needed to be padded + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 1) + mat2 = pad_dim(mat2, k_padded_length, 0) + return torch.ops.aten.mm(mat1, mat2) + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 1) + return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length] + else: + mat1 = pad_dim(mat1, m_padded_length, 0) + return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :] + + +def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.bmm(mat1, mat2) + + +def should_pad_bmm(match: Match) -> bool: + if ( + torch._inductor.config.keep_output_stride + and _result_layout_affects_graph_output(match) + ): + return False + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + mat1, mat2, torch.ops.aten.bmm + ) + + +def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2)) + + if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0: + return pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length) + + return aten.bmm(mat1, mat2) + + +def pad_bmm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, +) -> Tensor: + # bmm_replace will go through pad_bmm multiple times if multiple dimensions are needed to be padded + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 2) + mat2 = pad_dim(mat2, k_padded_length, 1) + + return aten.bmm(mat1, mat2) + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 2) + return aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous() + else: + mat1 = pad_dim(mat1, m_padded_length, 1) + return aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous() + + +@functools.lru_cache(None) +def _pad_mm_init(): + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values dont actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + + dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + + dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + + dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True) + + # workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + rep = {"beta": 0.213377, "alpha": 0.113377} + + for pattern, replacement, args, workaround, extra_check in [ + ( + mm_pattern, + mm_replace, + [dim2a(), dim2b()], + {}, + should_pad_mm, + ), + ( + bmm_pattern, + bmm_replace, + [dim3a(), dim3b()], + {}, + should_pad_bmm, + ), + ( + addmm_pattern, + addmm_replace, + [dim1a(), dim2a(), dim2b()], + rep, + should_pad_addmm, + ), + ]: + assert isinstance(workaround, dict) # mypy is unable to infer the type properly + register_replacement( + pattern, + replacement, + args, + joint_fwd_bwd, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) + register_replacement( + pattern, + replacement, + args, + fwd_only, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbdf9cd113c213f00a644b33803b4858d6402af --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/quantization.py @@ -0,0 +1,1980 @@ +import copy +import functools +import itertools +import math +import operator +from typing import Any, Tuple + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from ..lowering import lowerings as L, require_channels_last +from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match +from ..utils import pad_listlike +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern + +aten = torch.ops.aten +prims = torch.ops.prims +quantized_decomposed = torch.ops.quantized_decomposed +quantized = torch.ops.quantized + +""" +The quantization.py file primarily incorporates passes related to quantization fusion +in inductor, includes: +1. Dequant Promotion; +2. Conv/GEMM weight prepack with oneDNN Library; +3. Conv/GEMM quantization fusion with output quant node (if have); +4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more; + +It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference +of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is +1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM. +2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node. +Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16 +quantization. +""" + + +def _may_generate_pattern_with_dtype_convert(pattern, dtype=Arg(), dtype_convert=True): + if dtype_convert: + return CallFunction( + prims.convert_element_type.default, + pattern, + dtype, + ) + else: + return pattern + + +def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True): + if with_reshape: + return CallFunction( + torch.ops.aten.reshape.default, + pattern, + reshape_size, + ) + else: + return pattern + + +def _generate_linear_t_pattern( + _dequant_per_channel_pattern, + dtype, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = CallFunction( + aten.permute.default, + _may_generate_pattern_with_dtype_convert( + _dequant_per_channel_pattern, + KeywordArg("autocast_wgt_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("permute_axes"), + ) + return t_pattern + + +""" +dequantize activation: + x = x.to(fp32) + x = x - zero_point + x = x * scale +""" +dequantize_per_tensor_activation_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + KeywordArg("x"), + KeywordArg("x_dq_dtype"), + ), + KeywordArg("x_zp"), + ), + KeywordArg("x_scale"), +) + +dequantize_per_channel_weight_pattern = CallFunction( + quantized_decomposed.dequantize_per_channel.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("w_axis"), + KeywordArg("w_quant_min"), + KeywordArg("w_quant_max"), + KeywordArg("w_dtype"), +) + +dequantize_per_channel_to_bf16_weight_pattern = ( + _may_generate_pattern_with_dtype_convert( + dequantize_per_channel_weight_pattern, + KeywordArg("autocast_wgt_dtype"), + ) +) + +dequantize_per_channel_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + +dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_to_bf16_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + + +def get_dequantize_qconv_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv2d_pointwise.default, + KeywordArg("x"), + KeywordArg("x_scale"), # x_scale + KeywordArg("x_zp"), # x_zp + KeywordArg("packed_weight"), # packed_weight + KeywordArg("w_scale"), # w_scale + KeywordArg("w_zp"), # w_zp + KeywordArg("b"), # bias + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("inv_output_scale"), # inv_output_scale = 1.0 + KeywordArg("output_zero_point"), # output_zero_point = 0 + KeywordArg("output_dtype"), # output_dtype = None + KeywordArg("attr"), # attr = "none" + Arg(), # scalars + Arg(), # algorithm + _users=users, + ) + + +def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + ) + + +dequantize_accum_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + KeywordArg("accum"), + KeywordArg("accum_dq_dtype"), + ), + KeywordArg("accum_zp"), + ), + KeywordArg("accum_scale"), +) + + +def generate_pattern_with_binary( + binary_post_op, + computation_call, + extra_input_pattern, + int8_mixed_bf16_with_inplace_add=False, +): + binary_pattern = CallFunction( + binary_post_op, + computation_call, + extra_input_pattern, + ) + return _may_generate_pattern_with_dtype_convert( + binary_pattern, + KeywordArg("convert_dtype_after_inplace_add"), + int8_mixed_bf16_with_inplace_add, + ) + + +def generate_pattern_with_unary(computation_call, unary_post_op): + if unary_post_op is not None: + if unary_post_op == aten.hardtanh.default: + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + if unary_post_op == aten.hardswish.default: + return CallFunction( + aten.div, + CallFunction( + aten.mul, + computation_call, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, + CallFunction(aten.add, computation_call, 3), + 0, + ), + 6, + ), + ), + 6, + ) + else: + return CallFunction( + unary_post_op, + computation_call, + ) + return computation_call + + +def generate_pattern_with_output_quant(computation_call, dtype=torch.float32): + """ + quantize output: + output = round(output * o_inv_scale) + output = output + zero_point + output = clamp_min(output, 0) + output = clamp_max(output, 127) + output = output.to(uint8) + """ + assert dtype in [torch.float32, torch.bfloat16] + quantized_op_output_pattern_pt2e = CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.clamp_max.default, + CallFunction( + aten.clamp_min.default, + CallFunction( + aten.add.Tensor, + CallFunction( + aten.round.default, + CallFunction( + aten.mul.Tensor, + _may_generate_pattern_with_dtype_convert( + computation_call, + KeywordArg("autocast_output_quant_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("o_inv_scale"), + ), + ), + KeywordArg("o_zp"), + ), + KeywordArg("o_qmin"), + ), + KeywordArg("o_qmax"), + ), + KeywordArg("o_dtype"), + ) + return quantized_op_output_pattern_pt2e + + +def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value): + if kwarg_name in check_node.kwargs: + actual_value = check_node.kwargs[kwarg_name] + return actual_value == expected_value + else: + assert len(check_node.args) >= (args_index + 1) + actual_value = check_node.args[args_index] + return actual_value == expected_value + + +def _is_valid_quantized_conv2d_optimization_pattern(output_dtype): + def fn(match): + if output_dtype is not None: + # Only keep matched pattern with same output_dtype + qconv_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qconv2d_pointwise + )[0] + return _check_node_kwarg_arg_value( + qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype + ) + return True + + return fn + + +def _register_quantized_conv_lowering( + pattern, + pass_number, + computation_op, + output_dtype, + unary_attr, + original_pattern_output_dtype=torch.float32, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_conv2d_optimization_pattern(output_dtype), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + assert output_dtype in [None, torch.float32, torch.bfloat16] + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + assert ( + kwargs["output_dtype"] is original_pattern_output_dtype + ) # Expected int8-in fp32-out qconv in weight prepack phase + assert ( + kwargs["attr"] == "none" + ) # Expected no post op fused in weight prepack phase + if unary_attr.op_name == "hardtanh": + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + unary_attr.scalars_attr = [min_value, max_value] + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ) + counters["inductor"]["qconv2d_unary_matcher_count"] += 1 + counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv + + +def _is_valid_quantized_linear_optimization_pattern(output_dtype): + def fn(match): + if output_dtype is not None: + # Only keep matched pattern with same output_dtype + qlinear_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qlinear_pointwise + )[0] + return _check_node_kwarg_arg_value( + qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype + ) + return True + + return fn + + +def _register_quantized_linear_lowering( + pattern, + pass_number, + computation_op, + output_dtype, + unary_attr, + original_pattern_output_dtype=torch.float32, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_linear_optimization_pattern(output_dtype), + pass_number=pass_number, + ) + def qlinear(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs["b"] if "b" in kwargs else None + + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + assert ( + kwargs["output_dtype"] is original_pattern_output_dtype + ) # Expected int8-in fp32/bf16-out qlinear in weight prepack phase + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ) + counters["inductor"]["qlinear_unary_matcher_count"] += 1 + counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear + + +def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype): + # Check if it's a valid Conv Binary Pattern: + # * qconv2d_pointwise should only has one users + # * Extra input of binary node comes from dequant pattern + # * the two inputs of binary node should have attribute "meta" and should be tensors + # * the two inputs of binary node should have the same shape + # * All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + def fn(match): + compute_node = filter_nodes(match.nodes, torch.ops.onednn.qconv2d_pointwise)[0] + # qconv2d_pointwise should only have one user + if len(compute_node.users) != 1: + return False + binary_node_inputs = next(iter(compute_node.users)).args + assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + if output_dtype is not None: + extra_input_of_binary_node = None + for arg in binary_node_inputs: + if arg != compute_node: + extra_input_of_binary_node = arg + break + assert extra_input_of_binary_node is not None + # Extra input of binary node comes from dequant pattern + if (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or ( + extra_input_of_binary_node.target != aten.mul.Tensor + ): + return False + + # the two inputs of binary node should have attribute "meta" and should be tensors + if not ( + hasattr(binary_node_inputs[0], "meta") + and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ) or not ( + hasattr(binary_node_inputs[1], "meta") + and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ): + return False + # the two inputs of binary node should have the same shape + if ( + binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr] + != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr] + ): + return False + + # All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + + from .mkldnn_fusion import _get_remaining_users + + extra_input_of_pattern = ( + match.kwargs["accum"] + if output_dtype is None + else match.kwargs["accum_after_dequant"] + ) + if ( + len( + _get_remaining_users( + extra_input_of_pattern, + compute_node, + ) + ) + > 1 + or extra_input_of_pattern == compute_node.args[0] + ): + return False + return True + + return fn + + +def _register_quantized_conv_binary_lowering( + pattern, + pass_number, + computation_op, + output_dtype, + binary_unary_attr, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_conv_binary_optimization_pattern(output_dtype), + pass_number=pass_number, + ) + def qconv_binary(match: Match, *args, **kwargs): + x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"] + accum = ( + kwargs["accum"] if output_dtype is None else kwargs["accum_after_dequant"] + ) + accum_scale = kwargs["accum_scale"] if output_dtype is None else 1.0 + accum_zp = kwargs["accum_zp"] if output_dtype is None else 0 + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + + accum.realize() + from .mkldnn_fusion import _can_be_inplace + + assert _can_be_inplace( + accum + ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation." + + computation_args = ( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_unary_attr.binary_op_name, + binary_unary_attr.alpha, + binary_unary_attr.unary_op_name, + binary_unary_attr.scalars_attr, + binary_unary_attr.algorithm_attr, + ) + counters["inductor"]["qconv2d_binary_matcher_count"] += 1 + counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv_binary + + +def _register_quantization_unary_fusion(): + class UnaryAttr: + def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + # QConv2d + # Priority 1 to match: QConv2d Unary pattern with int8 output + # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. + # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant + conv_unary_replace_patterns = { + UnaryAttr("none", [], ""): generate_pattern_with_output_quant( + get_dequantize_qconv_pt2e_pattern(1), + dtype=original_pattern_output_dtype, + ), + UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.relu.default + ), + dtype=original_pattern_output_dtype, + ), + UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.hardtanh.default + ), + dtype=original_pattern_output_dtype, + ), + UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(2), aten.hardswish.default + ), + dtype=original_pattern_output_dtype, + ), + } + + for unary_attr, patterns in conv_unary_replace_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_quantized_conv_lowering( + patterns, + 1, # pass_number + torch.ops.onednn.qconv2d_pointwise, # computation_op + None, # output_dtype, None is the default value for int8 output + unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, + ) + + # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output + conv_unary_replace_float_out_patterns = { + UnaryAttr("relu", [], ""): generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.relu.default + ), + UnaryAttr("hardtanh", [], ""): generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.hardtanh.default + ), + UnaryAttr("hardswish", [], ""): generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(2), aten.hardswish.default + ), + } + + for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_quantized_conv_lowering( + patterns, + 2, # pass_number + torch.ops.onednn.qconv2d_pointwise, # computation_op + original_pattern_output_dtype, # output_dtype + unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, + ) + + # QLinear + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + # Priority 1 to match: QLinear Unary pattern with int8 output + linear_unary_replace_patterns = { + UnaryAttr("none", [], ""): generate_pattern_with_output_quant( + qlinear_pattern, + dtype=original_pattern_output_dtype, + ), + UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + dtype=original_pattern_output_dtype, + ), + } + + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_quantized_linear_lowering( + patterns, + 1, # pass_number + torch.ops.onednn.qlinear_pointwise, # computation_op + None, # output_dtype + unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, + ) + + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + UnaryAttr("relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + } + + for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): + _register_quantized_linear_lowering( + patterns, + 2, # pass_number + torch.ops.onednn.qlinear_pointwise, # computation_op + original_pattern_output_dtype, # output_dtype + unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, + ) + + +def _register_quantization_binary_fusion(): + class BinaryUnaryAttr: + def __init__( + self, + binary_op_name: str, + alpha=None, + unary_op_name: str = "none", + scalars_attr=None, + algorithm_attr=None, + ): + self.binary_op_name = binary_op_name + self.alpha = alpha if alpha else 1.0 + self.unary_op_name = unary_op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + for int8_mixed_bf16_with_inplace_add in [False, True]: + # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output + binary_replace_patterns = { + BinaryUnaryAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + ), + dtype=torch.bfloat16 + if int8_mixed_bf16_with_inplace_add + else torch.float32, + ), + BinaryUnaryAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + ), + aten.relu.default, + ), + dtype=torch.bfloat16 + if int8_mixed_bf16_with_inplace_add + else torch.float32, + ), + } + + for binary_unary_attr, patterns in binary_replace_patterns.items(): + _register_quantized_conv_binary_lowering( + patterns, + 0, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + None, # output_dtype + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = { + BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + ), + aten.relu.default, + ), + } + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + if int8_mixed_bf16_with_inplace_add: + _register_quantized_conv_binary_lowering( + patterns, + 0, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + # Note that for int8-mixed-bf16 and non-inplace add, because we have + # q-dq inserted at extra input of add, so the non-inplace add has bf16 and fp32 inputs, + # the output dtype will be float32. + # For inplace add, there is a extra to_bf16 node at add output, so the fusion pattern has bfloat16 output. + torch.bfloat16, + binary_unary_attr, # binary_unary_attr + ) + else: + _register_quantized_conv_binary_lowering( + patterns, + 1, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + torch.float32, + binary_unary_attr, # binary_unary_attr + ) + + # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = { + BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + ), + } + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_conv_binary_lowering( + patterns, + 1 if int8_mixed_bf16_with_inplace_add else 2, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + # Same output dtype setting as conv-add-relu pattern + torch.bfloat16 if int8_mixed_bf16_with_inplace_add else torch.float32, + binary_unary_attr, # binary_unary_attr + ) + + +def _is_valid_quantized_maxpool2d_optimization_pattern(): + def fn(match): + # Only match the pattern which max_pool2d_with_indices returns value + # instead of indices. + get_item_node = filter_nodes(match.nodes, operator.getitem)[0] + return get_item_node.args[1] == 0 + + return fn + + +def _register_quantized_maxpool2d_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(), + ) + def qmaxpool2d(match: Match, *args, **kwargs): + x = kwargs["x"] + kernel_size = kwargs["kernel_size"] + stride = kwargs["stride"] if ("stride" in kwargs) else None + padding = kwargs["padding"] if ("padding" in kwargs) else 0 + dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1 + ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False + + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + + computation_args = ( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + computation_args, _ = require_channels_last(computation_op, *computation_args) + return L[computation_op](*computation_args) + + return qmaxpool2d + + +def _register_quantization_maxpool2d(): + # Currently, the default parameters are not in FX Graph generated by Dynamo export. + # So, if user defines nn.MaxPool2d with different assignment of default parameter, + # it will generate graph with different number of input nodes and hence + # different pattern to be matched. + # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901 + max_pool2d_args_list = [ + [ + KeywordArg("stride"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("ceil_mode"), + ], + ] + + for max_pool2d_args in max_pool2d_args_list: + dequantize_maxpool2d_pattern = CallFunction( + aten.max_pool2d_with_indices.default, + dequantize_per_tensor_activation_pattern, + KeywordArg("kernel_size"), + *max_pool2d_args, + ) + dequantize_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_maxpool2d_pattern, + Arg(), + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern), + quantized.max_pool2d.default, + ) + + +def _is_input_output_same_scale_zp(check_node): + def fn(match): + # Ensure all the inputs and output has same scale and zero point + # Step 1: Check inputs/output zero point + sub_nodes = filter_nodes(match.nodes, aten.sub.Tensor) + zero_points = [node.args[1] for node in sub_nodes] + add_nodes = filter_nodes(match.nodes, aten.add.Tensor) + assert len(add_nodes) == 1, "expect only 1 add node at output quant pattern" + zero_points.append(add_nodes[0].args[1]) + if not all(zero_point == zero_points[0] for zero_point in zero_points): + return False + + # Step 2: Check inputs/output scale + mul_nodes = filter_nodes(match.nodes, aten.mul.Tensor) + # We need to find mul node at output since the scale value is reciprocal to input scale. + # Mul node at output should connect to cat node directly. + scales = [ + ( + mul_node.args[1] + if mul_node.args[0].target is check_node # type: ignore[union-attr] + else 1.0 / mul_node.args[1] # type: ignore[operator] + ) + for mul_node in mul_nodes + ] + if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type] + return False + + return True + + return fn + + +def _register_quantized_cat_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.cat.default), + ) + def qcat(match: Match, inputs, dim, **kwargs): + # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] + uint8_inputs = [input[0] for input in inputs] + return L[computation_op](uint8_inputs, dim) + + return qcat + + +_raw_dequantize_per_tensor_activation_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + Arg(), + Arg(), + ), + Arg(), + ), + Arg(), +) + + +def _register_quantization_cat(): + dequantize_cat_pattern = CallFunction( + aten.cat.default, + ListOf(_raw_dequantize_per_tensor_activation_pattern), + KeywordArg("dim"), + ) + _register_quantized_cat_lowering( + generate_pattern_with_output_quant(dequantize_cat_pattern), + aten.cat, + ) + + +def _register_quantized_reshape_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.reshape.default), + ) + def qreshape(match: Match, *args, **kwargs): + qx = kwargs["x"] + shape = kwargs["shape"] + counters["inductor"]["qreshape_matcher_count"] += 1 + counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes) + return L[computation_op](qx, shape) + + return qreshape + + +def _register_quantization_reshape(): + dequantize_reshape_pattern = CallFunction( + torch.ops.aten.reshape.default, + dequantize_per_tensor_activation_pattern, + KeywordArg("shape"), + ) + _register_quantized_reshape_lowering( + generate_pattern_with_output_quant(dequantize_reshape_pattern), + aten.reshape, + ) + + +def _register_quantization_lowerings(): + _register_quantization_unary_fusion() + _register_quantization_binary_fusion() + _register_quantization_maxpool2d() + _register_quantization_cat() + _register_quantization_reshape() + + +def _is_valid_dequant_promotion_pattern(dtype=torch.float32): + def _inner(match): + assert dtype in [torch.float32, torch.bfloat16] + dequant_pattern_end_node = match.output_node() + if dequant_pattern_end_node.target not in [ + aten.mul.Tensor, + prims.convert_element_type.default, + aten.reshape.default, + ]: + return False + + if dequant_pattern_end_node.target is aten.reshape.default: + mul_node = ( + dequant_pattern_end_node.args[0] # pattern: linear <- reshape <- mul + if dtype == torch.float32 + else dequant_pattern_end_node.args[0].args[ + 0 + ] # pattern: linear <- reshape <- to_bf16 <- mul + ) + else: + mul_node = ( + dequant_pattern_end_node # pattern: linear <- mul + if dtype == torch.float32 + else dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- to_bf16 <- mul + ) + + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] + if ( + mul_node.target is aten.mul.Tensor + and sub_node.target is aten.sub.Tensor + and to_fp32_node.target is prims.convert_element_type.default + and len(list(dequant_pattern_end_node.users)) > 1 + ): + # If dequant pattern has more than 1 users, then do dequant promoted + return True + return False + + return _inner + + +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_promotion_pattern(dtype), + pass_number=pass_number, + ) + def dequant_promotion(match: Match, *args, **kwargs): + # Dequant_promotion will transform + # graph 1: + # quant + # + - - - | - - - + + # | dequant | + # | / \ | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # into: + # graph 2: + # quant + # + - - / - \ - - + + # |dequant dequant| + # | | | | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # In graph 1, the dequant node is shared by node1 and node2, + # as a result, neither node1 nor node2 could form an int8 + # fusion pattern. + # After this transformation, the graph 2 could hit the int8 + # fusion pattern: dequant-node-quant, respectively for + # node1 and node2. + assert dtype in [torch.float32, torch.bfloat16] + + def clone_to_new_node(graph, source_node, user_node): + # Clone the source_node to a new node + # Replace user_node's input from source_node to new_node + assert ( + source_node.op == "call_function" + ), "clone_to_new_node only support node.op call_function" + with graph.inserting_before(user_node): + new_node = graph.call_function( + source_node.target, + args=source_node.args, + kwargs=source_node.kwargs, + ) + new_node.meta = copy.copy(source_node.meta) + user_node.replace_input_with(source_node, new_node) + return new_node + + # Find the start node and end node of a dequant pattern + # * End node should be the match.output_node() + # * Start node should be the node of dtype convert to float32 + dequant_pattern_end_node = match.output_node() + assert dequant_pattern_end_node.target in [ + aten.mul.Tensor, + prims.convert_element_type.default, + aten.reshape.default, + ] + + # For a dequant pattern, we should expect see the node list as: + # * OPT(aten.reshape.default) + # * OPT(prims.convert_element_type.default) (to_bf16) + # * aten.mul + # * aten.sub + # * prims.convert_element_type.default (to_fp32) + def _find_first_node_in_dequant_pattern(_node): + if ( + _node.target is prims.convert_element_type.default + and _node.args[1] == torch.float32 + ): + # For a dequant pattern, we expect the start node is a to_fp32 node + return _node + else: + assert ( + len(_node.args) >= 1 + ), "In in dequant pattern, each node should have more than 1 arg." + return _find_first_node_in_dequant_pattern(_node.args[0]) + + dequant_pattern_start_node = _find_first_node_in_dequant_pattern( + dequant_pattern_end_node + ) + + # Clone the dequant pattern for each user node + graph = match.graph + user_node_list = list(dequant_pattern_end_node.users) + for user_node in user_node_list[1:]: + _source_node = dequant_pattern_end_node + _user_node = user_node + while _source_node != dequant_pattern_start_node.args[0]: + _user_node = clone_to_new_node(graph, _source_node, _user_node) + _source_node = _source_node.args[0] # type: ignore[assignment] + + counters["inductor"]["dequant_promotion_matcher_count"] += 1 + counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) + + +def _is_valid_dequant_conv2d_pattern(dtype): + def _inner(match): + # Here we do some further check to ensure: + # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. + # 2. The dequant pattern has only 1 user of conv2d node. + # If these conditions don't meet, we will not + # insert weight prepack node into the matched pattern. + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 4 + ): + # Only support conv2d now + return False + + assert dtype in [torch.float32, torch.bfloat16] + if dtype == torch.float32: + mul_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + mul_node = convert_to_bf16.args[0] + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] + + assert to_fp32_node.target is prims.convert_element_type.default + assert sub_node.target is aten.sub.Tensor + assert mul_node.target is aten.mul.Tensor + if ( + len(list(to_fp32_node.users)) != 1 + or len(list(sub_node.users)) != 1 + or len(list(mul_node.users)) != 1 + ): + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + return True + + return _inner + + +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_conv2d_pattern(dtype), + pass_number=pass_number, + ) + def qconv_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + if dtype == torch.float32: + mul_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + mul_node = convert_to_bf16.args[0] # type: ignore[union-attr] + sub_node = mul_node.args[0] # type: ignore[union-attr] + to_fp32_node = sub_node.args[0] # type: ignore[union-attr] + has_clone_to_channel_last_node_in_pattern = ( + conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] + ) + clone_node = ( + conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None + ) + + if dtype == torch.float32: + dequant_per_channel = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + else: + weight_to_bf16_node = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + + assert ( + dequant_per_channel.target # type: ignore[union-attr] + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Conv Params + bias, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(conv_node): + # Insert weight prepack node and the QConv node + packed_weight_inputs = ( + qw, + w_scale, + x_scale, + x_zp, + stride, + padding, + dilation, + groups, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qconv_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: Tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + 1.0, # inv_output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # attr + [], # scalars + "", # algorithm + ) + new_conv_node = graph.call_function( + torch.ops.onednn.qconv2d_pointwise.default, args=new_args + ) + conv_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(conv_node.meta) + + # Erase the original conv node + graph.erase_node(conv_node) + # Erase the dequant pattern + if dtype == torch.bfloat16: + graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined] + # Erase the dequant pattern + graph.erase_node(mul_node) + graph.erase_node(sub_node) + graph.erase_node(to_fp32_node) + # Erase the dequant per channel pattern + if clone_node is not None: + graph.erase_node(clone_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_channel) + counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_convolution_node_pattern( + _dequant_per_channel_pattern, dtype=torch.float32 +): + assert dtype in [torch.float32, torch.bfloat16] + dequant_convolution_node_pattern = CallFunction( + aten.convolution.default, + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + _dequant_per_channel_pattern, + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("is_transposed"), + KeywordArg("out_padding"), + KeywordArg("groups"), + ) + return dequant_convolution_node_pattern + + +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): + assert dtype in [torch.float32, torch.bfloat16] + return ( + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_weight_pattern, + dtype, + ), + # There is another pattern due to the pass of convert_conv_weights_to_channels_last + # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. + # Depend on some heuristics, it may or may not insert to(channel_last) node + # between convolution and dequant_per_channel node + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_clone_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_clone_weight_pattern, + dtype, + ), + ) + + +def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): + output_reshape_node = None + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node = match.output_node() + assert output_reshape_node.target is aten.reshape.default + linear_node = output_reshape_node.args[0] + else: + linear_nodes = filter_nodes(match.nodes, aten.bmm.default) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + else: + linear_node = match.output_node() + + assert linear_node.target in ( + aten.addmm.default, + aten.mm.default, + aten.bmm.default, + ) + return linear_node, output_reshape_node + + +def _get_linear_dq_mul_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous +): + act_reshape_node = None + activation_to_bf16_node = None + act_expand_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + assert act_reshape_node.target is aten.reshape.default + if dtype == torch.float32: + # pattern: linear -> reshape -> mul + mul_node = act_reshape_node.args[0] + else: + # pattern: linear -> reshape -> to_bf16 -> mul + activation_to_bf16_node = act_reshape_node.args[0] + mul_node = activation_to_bf16_node.args[0] + else: + # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous + act_expand_node = linear_node.args[input_index] + assert act_expand_node.target is aten.expand.default + if dtype == torch.float32: + mul_node = act_expand_node.args[0] + else: + activation_to_bf16_node = act_expand_node.args[0] + mul_node = activation_to_bf16_node.args[0] + else: + if dtype == torch.float32: + # pattern: linear -> mul + mul_node = linear_node.args[input_index] + else: + # pattern: linear -> to_bf16 -> mul + activation_to_bf16_node = linear_node.args[input_index] + mul_node = activation_to_bf16_node.args[0] + return mul_node, act_reshape_node, activation_to_bf16_node, act_expand_node + + +def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous): + def _inner(match): + # Check dequant pattern has only 1 user. + ( + linear_node, + _, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + input_index = 1 if linear_node.target is aten.addmm.default else 0 + assert dtype in [torch.float32, torch.bfloat16] + + ( + mul_node, + _, + _, + _, + ) = _get_linear_dq_mul_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] + + assert to_fp32_node.target is prims.convert_element_type.default + assert sub_node.target is aten.sub.Tensor + assert mul_node.target is aten.mul.Tensor + if ( + len(list(to_fp32_node.users)) != 1 + or len(list(sub_node.users)) != 1 + or len(list(mul_node.users)) != 1 + ): + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + + # Extra check for bmm pattern + if input_dim_exceeds_two and not input_contiguous: + # Check for act + # Act expand size should be exactly same as act size + act_expand_size = match.kwargs["act_expand_size"] + act_node = match.kwargs["x"] + if not ( + hasattr(act_node, "meta") + and isinstance(act_node.meta.get("val", None), torch.Tensor) + and (act_node.meta["val"].size() == torch.Size(act_expand_size)) + ): + return False + + # Check for wgt + # wgt permute dims should be [1, 0] + wgt_permute_dims = match.kwargs["permute_axes"] + if wgt_permute_dims != [1, 0]: + return False + + # Check below wgt size items: + # wgt before expand should with dim 2 + # Expand size should with dim 3 + # Expand size[0] should same as act size[0] + # Expand size[1] should same as wgt size[1] + # Expand size[2] should same as wgt size[0] + qweight_node = match.kwargs["q_weight"] + wgt_expand_size = match.kwargs["wgt_expand_size"] + if not ( + hasattr(qweight_node, "meta") + and isinstance(qweight_node.meta.get("val", None), torch.Tensor) + and len(qweight_node.meta["val"].size()) == 2 + and len(wgt_expand_size) == 3 + and wgt_expand_size[0] == act_node.meta["val"].size()[0] + and wgt_expand_size[1] == qweight_node.meta["val"].size()[1] + and wgt_expand_size[2] == qweight_node.meta["val"].size()[0] + ): + return False + + return True + + return _inner + + +def _register_qlinear_weight_prepack_pass( + pattern, + pass_number, + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_linear_pattern( + dtype, input_dim_exceeds_two, input_contiguous + ), + pass_number=pass_number, + ) + def qlinear_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + mm/addmm <- t <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + ( + linear_node, + output_reshape_node, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + ( + mul_node, + act_reshape_node, + activation_to_bf16_node, + act_expand_node, + ) = _get_linear_dq_mul_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] + + if input_dim_exceeds_two and not input_contiguous: + wgt_expand_node = linear_node.args[weight_index] + assert wgt_expand_node.target is aten.expand.default + t_node = wgt_expand_node.args[0] + else: + t_node = linear_node.args[weight_index] + + if dtype == torch.float32: + dequant_per_channel = t_node.args[0] + else: + weight_to_bf16_node = t_node.args[0] + dequant_per_channel = weight_to_bf16_node.args[0] + assert ( + dequant_per_channel.target + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Params + bias = kwargs["b"] if "b" in kwargs else None + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + qw, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: Tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + Node = torch.fx.node.Node + if isinstance(x_scale, Node) and isinstance(x_zp, Node): + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + else: + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.default, args=new_args + ) + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_reshape_node.meta) + else: + if bias: + output_add_node_for_bias = match.output_node() + assert output_add_node_for_bias.target is aten.add.Tensor + output_add_node_for_bias.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_add_node_for_bias.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + + # Erase the original linear node + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(output_reshape_node) + elif not input_contiguous and bias: + graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined] + graph.erase_node(linear_node) + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(act_reshape_node) + else: + graph.erase_node(act_expand_node) + graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] + if dtype == torch.bfloat16: + graph.erase_node(activation_to_bf16_node) + # Erase the dequant pattern + graph.erase_node(mul_node) + graph.erase_node(sub_node) + graph.erase_node(to_fp32_node) + # Erase the dequant per channel pattern + graph.erase_node(t_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_channel) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_linear_node_pattern( + _dequant_per_channel_pattern, dtype=torch.float32, input_dim_exceeds_two=False +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + dequant_linear_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _generate_dequant_bmm_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + with_bias=False, +): + # When activation of linear dim exceed 2 and not contiguous + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + + assert dtype in [torch.float32, torch.bfloat16] + dequant_bmm_pattern = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + + def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias): + if _with_bias: + return CallFunction( + aten.add.Tensor, + _dequant_bmm_pattern, + KeywordArg("b"), + ) + else: + return _dequant_bmm_pattern + + return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias) + + +def _generate_qlinear_weight_prepack_patterns( + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, + with_bias=False, +): + if input_dim_exceeds_two and not input_contiguous: + return _generate_dequant_bmm_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + with_bias, + ) + else: + return _generate_dequant_linear_node_pattern( + dequantize_per_channel_weight_pattern, dtype, input_dim_exceeds_two + ) + + +def _register_dequant_promotion(): + dequant_pattern_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False] + ) + for dtype, input_dim_exceeds_two in dequant_pattern_cases: + # 4 dequantization patterns will be matched based on the dtype and input dimension size. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 + # quant + # + - - - - | - - - - + + # | dequant | + # | | | + # | OPT(to_bf16) | + # | | | + # | OPT(reshape) | + # | / \ | + # | node1 node2 | + # + - - | - - - | - - + + # OPT(reshape) OPT(reshape) + # + - - | - - - | - - + + # OPT(to_fp32) OPT(to_fp32) + # + - - | - - - | - - + + # quant quant + _register_dequant_promotion_pass( + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + with_reshape=input_dim_exceeds_two, + ), + pass_number=0, + dtype=dtype, + ) # pass_number=0 to run before weight prepack + + +def _register_qconv_weight_prepack(): + for dtype in [torch.float32, torch.bfloat16]: + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qconv_weight_prepack_pass( + weight_prepack_pattern, pass_number=1, dtype=dtype + ) + + +def _register_qlinear_weight_prepack(): + # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous. + # Then convert the pattern into a QLinear node with int8_fp32/bf16. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | OPT(reshape) permute | + # | \ / | + # | addmm/mm | + # | | | + # | OPT(reshape) | + + # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous + # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | expand permute | + # | \ | | + # | expand | + # | / | + # | bmm | + # | | | + # | OPT(add) | + + linear_weight_prepack_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False] + ) + + # Step 1: register patterns from mm and addmm + for dtype, input_dim_exceeds_two in linear_weight_prepack_cases: + weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( + dtype, input_dim_exceeds_two + ) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qlinear_weight_prepack_pass( + weight_prepack_pattern, + pass_number=1, + dtype=dtype, + input_dim_exceeds_two=input_dim_exceeds_two, + ) + + # Step 2: register patterns from bmm + # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous + # refer to: + # https://github.com/pytorch/pytorch/blob/ + # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 + # in this case, we can convert it back to qlinear + for dtype, with_bias in itertools.product( + [torch.float32, torch.bfloat16], [True, False] + ): + bmm_pattern = _generate_qlinear_weight_prepack_patterns( + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + with_bias=with_bias, + ) + _register_qlinear_weight_prepack_pass( + bmm_pattern, + pass_number=1 + if with_bias + else 2, # if with_bias, there is an output add, so we should try to match it firstly + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + ) + + +@functools.lru_cache(None) +def _register_quantization_weight_pack_pass(): + # Step 1: Dequant promotion for int8-mixed-fp32/bf16 + _register_dequant_promotion() + + # Step 2: QConv weight prepack + _register_qconv_weight_prepack() + + # Step 3: QLinear weight prepack + _register_qlinear_weight_prepack() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py new file mode 100644 index 0000000000000000000000000000000000000000..45863aac16ec190da804b557d31dbf0f9b5702bb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/replace_random.py @@ -0,0 +1,139 @@ +import collections +import logging + +import torch + +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from .. import config, inductor_prims +from ..pattern_matcher import ( + CallFunctionVarArgs, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..virtualized import V + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() +aten = torch.ops.aten + + +def replace_random_passes(gm: torch.fx.GraphModule): + """Modify the given FX graph to use backend-native random ops""" + if config.fallback_random: + return 0 + + count = patterns.apply(gm) + count += fuse_seed_creation_pass(gm.graph) + + return count + + +def fuse_seed_creation_pass(graph: torch.fx.Graph): + """ + Horizontally fuse all the seed generation on each device + + a = inductor_seed(dev) + b = inductor_seed(dev) + + Becomes: + seeds = inductor_seeds(2, dev) + a = inductor_lookup_seed(seeds, 0) + b = inductor_lookup_seed(seeds, 1) + + We do this because seed creation is entirely launch overhead bound. + """ + device_seeds = collections.defaultdict(list) + for node in graph.nodes: + if CallFunctionVarArgs(inductor_prims.seed).match(node): + device_seeds[node.args[0]].append(node) + + if not device_seeds: + return 0 + + for device, seeds in device_seeds.items(): + with graph.inserting_before(seeds[0]): + combined = graph.call_function(inductor_prims.seeds, (len(seeds), device)) + with V.fake_mode: + combined.meta["val"] = torch.empty( + [len(seeds)], device=device, dtype=torch.int64 + ) + combined.meta["tensor_meta"] = _extract_tensor_metadata( + combined.meta["val"] + ) + + for idx, seed in enumerate(seeds): + with graph.inserting_before(seed): + new_seed = graph.call_function( + inductor_prims.lookup_seed, (combined, idx) + ) + seed.replace_all_uses_with(new_seed) + new_seed.meta.update(seed.meta) + graph.erase_node(seed) + + return len(device_seeds) + + +def default_kwargs(device): + return {} + + +def get_device(device): + if device is not None: + return device + return torch.empty([]).device # default device + + +@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns) +def replace_random( + match: Match, + size, + *, + generator=None, + dtype=None, + device=None, + layout=None, + pin_memory=None, +): + if generator is not None: + return + + def replacement(size): + result = inductor_prims.random( + size, inductor_prims.seed(device), mode, **default_kwargs(device) + ) + if dtype is not None: + result = result.to(dtype) + return result + + mode = { + aten.rand: "rand", + aten.randn: "randn", + }[ + match.output_node().target.overloadpacket # type: ignore[union-attr] + ] # type: ignore[union-attr] + device = get_device(device) + match.replace_by_example(replacement, [size]) + + +@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns) +def replace_randint( + match: Match, + low, + high, + size, + *, + dtype=torch.int64, + device=None, + layout=None, + pin_memory=None, +): + def replacement(size): + result = inductor_prims.randint(low, high, size, inductor_prims.seed(device)) + return result.to(dtype) + + device = get_device(device) + match.replace_by_example(replacement, [size]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfae7eb27beb1cc84b34eaedb2d420403835ed15 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bab912deabfc6dac3f213cb008e882cfda54e2b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2d96992bbd3e3b74460e9c192abce2ddb146016 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cbb0cd6380362819180be0c26e79c40f46f2380 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e642e443acf8482f9d1a0b116ddc0f56ddd7e04 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70e7766ba1fc10cef24174df91b8982910212bb2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py new file mode 100644 index 0000000000000000000000000000000000000000..5d1e8fd009256beef3cdfcb93f360ef71a77bbc7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -0,0 +1,232 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py new file mode 100644 index 0000000000000000000000000000000000000000..c20cf28a7a56599d73f5bb28961a313c6f9f2bb6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -0,0 +1,142 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5, _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, sub_Tensor_1, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, sub_Tensor_1) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value')) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value')) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py new file mode 100644 index 0000000000000000000000000000000000000000..8a156501274654b9f4286f3d9485609bb3c71d80 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -0,0 +1,236 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, sub_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_15_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py new file mode 100644 index 0000000000000000000000000000000000000000..7b1a4d6ad10a096e1387619f27b2ed6489a7d89a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -0,0 +1,202 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py new file mode 100644 index 0000000000000000000000000000000000000000..2067fdbdfb028ace5aa01c2323c3ca16f44b02c1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -0,0 +1,206 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py new file mode 100644 index 0000000000000000000000000000000000000000..ebfdf9b888d7ffa8593d35575b0f3135d7a5a10f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -0,0 +1,213 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f318a185ce50bbebbe5935d48ebc2d70df6f929 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11ec53de566925a58b81165c51517ec272cb06b2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py new file mode 100644 index 0000000000000000000000000000000000000000..1878cef79f0f5c8a7358dc96cdee63b176005109 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/bmm.py @@ -0,0 +1,128 @@ +import torch + +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import ceildiv as cdiv, use_aten_gemm_kernels, use_triton_template + +from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options + +aten = torch.ops.aten + + +def bmm_grid(b, m, n, meta): + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1) + + +bmm_template = TritonTemplate( + name="bmm", + grid=bmm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", -2)}} + N = {{size("B", -1)}} + K = {{size("A", -1)}} + + stride_aq = {{stride("A", 0)}} + stride_am = {{stride("A", 1)}} + stride_ak = {{stride("A", 2)}} + + stride_bq = {{stride("B", 0)}} + stride_bk = {{stride("B", 1)}} + stride_bn = {{stride("B", 2)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + idx_q = tl.program_id(1) # batch dimension for BMM + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_q = tl.program_id(1) # batch dimension for BMM + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}} +""", +) + +aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out") +aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out") + + +@register_lowering(aten.bmm) +def tuned_bmm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + + # options to tune from + choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + if use_triton_template(layout): + for config in mm_configs(m, n, k): + bmm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout) + + +# Don't register this since it is slower than decomposing it +# @register_lowering(aten.baddbmm) +def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout) + + # options to tune from + choices = ( + [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)] + if use_aten_gemm_kernels() + else [] + ) + if use_triton_template(layout): + for config in mm_configs(m, n, k): + bmm_template.maybe_append_choice( + choices, + input_nodes=(inp, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + + return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..8021720b01d2761e0aa758a1feb2ac4d3effa320 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py @@ -0,0 +1,312 @@ +import functools +import logging +from typing import Any, Dict, List, Optional + +import torch +from torch._inductor.virtualized import V +from .. import config as inductor_config +from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import ( + use_aten_gemm_kernels, + use_cutlass_template, + use_max_autotune, + use_triton_template, +) +from .mm_common import ( + addmm_epilogue, + int8_mm_configs, + mm_args, + mm_configs, + mm_grid, + mm_options, +) + +log = logging.getLogger(__name__) +aten = torch.ops.aten + +mm_template = TritonTemplate( + name="mm", + grid=mm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) + if B_PROLOGUE_CAST_TYPE is not None: + b = b.to(B_PROLOGUE_CAST_TYPE) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + +aten_mm = ExternKernelChoice(torch.mm, "at::mm_out") + + +aten_addmm = ExternKernelChoice( + torch.addmm, "at::addmm_out", op_overload=aten.addmm.default +) + +aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm") + + +def _is_int8_mat(mat): + return mat.get_dtype() in (torch.int8, torch.uint8) + + +def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): + """ + Giving torch.addmm a 1D tensor calls a different (faster) cublasLt + kernel under the hood. There are a few shapes where this is slower, + but they are rare. + """ + if inp.stride(0) == 0 or inp.size(0) == 1: + return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta) + return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) + + +aten_bias_addmm = ExternKernelChoice(bias_addmm, None) + + +@register_lowering(aten.mm, type_promotion_kind=None) +def tuned_mm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + + # options to tune from + choices = [aten_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + + if m * n != 0 and use_triton_template(layout): + for config in mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + if m * n != 0 and use_cutlass_template(layout): + CUTLASSGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True + ) + + from torch._inductor.ir import FixedLayout, FlexibleLayout + + if ( + len(choices) == 1 + and use_aten_gemm_kernels() + and isinstance(layout, FixedLayout) + ): + # If we are not autotuning, we can swap to a FlexibleLayout + # in order to get fusion optimizations to kick in, e.g. ConcatFusion + layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + choices = [aten_mm.bind((mat1, mat2), layout)] + + return autotune_select_algorithm("mm", choices, [mat1, mat2], layout) + + +@register_lowering(aten._int_mm, type_promotion_kind=None) +def tuned_int_mm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=torch.int32 + ) + choices = ( + [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + ) + if m * n != 0 and use_triton_template(layout, enable_int32=True): + # TODO: Re-enable eager mode implementation once cuBLAS is fixed + choices = [] + for config in int8_mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout) + + +@register_lowering(aten.addmm, type_promotion_kind=None) +def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout) + if m * n == 0 or not use_max_autotune(): + choices = ( + [ + aten_addmm.bind( + (inp, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout) + + choices = ( + [ + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if ( + use_aten_gemm_kernels() + and inp_expanded.get_stride()[0] == 0 + and inp_expanded.get_device().type == "cuda" + and inductor_config.triton.autotune_cublasLt + ): + # unexpand inp to make sure fused addmm from cublasLt is used + choices.insert( + 0, + aten_bias_addmm.bind( + (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta + ), + ) + + if use_triton_template(layout): + for config in mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + + if use_cutlass_template(layout): + CUTLASSGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + input_reorder=[2, 0, 1], + fuseable=False, + ) + + return autotune_select_algorithm( + "addmm", choices, [inp_expanded, mat1, mat2], layout + ) + + +def fallback_mixed_mm(mat1, mat2, *, out): + return torch.mm(mat1, mat2.to(mat1.dtype), out=out) + + +aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None) + + +@functools.lru_cache(None) +def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: + props = torch.cuda.get_device_properties(index or 0) + return props.major <= 7 + + +def tuned_mixed_mm(mat1, mat2, mat2_dtype): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None) + choices = [aten_fallback_mixed_mm.bind((mat1, mat2), layout)] + if ( + mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous() + ) or _is_sm7x_or_older_gpu(layout.device.index): + # can't use triton kernel unless one of these is true or if running on v100 (numerical issues) + return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout) + if inductor_config.force_mixed_mm: + choices = [] + b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") + has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2) + for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout, b_prologue_cast_type), + ) + return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout) + + +# This op is a special case of the int_mm op which we use based on the pattern +# _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent +# realization of the int32 _int_mm output by forcing fusion with the mul op. +# This is only used when config.force_fuse_int_mm_with_mul = True +def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None): + out_dtype = ( + torch.promote_types(mat3.get_dtype(), torch.int32) + if out_dtype is None + else out_dtype + ) + m, n, k, layout, mat1, mat2, mat3 = mm_args( + mat1, mat2, mat3, layout=layout, out_dtype=out_dtype + ) + choices: List[Dict[Any, Any]] = [] + for config in int8_mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2, mat3), + layout=layout, + **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"), + suffix_args=1, + epilogue_fn=V.ops.mul, + ) + return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..12a280cb91bdcd839e159be1c5c1964e585de965 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_common.py @@ -0,0 +1,262 @@ +import functools +import logging +from typing import cast, List, Tuple + +import sympy + +import torch +from torch._inductor.select_algorithm import realize_inputs +from torch._inductor.virtualized import V + +from .. import config as inductor_config +from ..utils import ceildiv as cdiv, next_power_of_2 + +log = logging.getLogger(__name__) + + +def triton_config(num_stages, num_warps, **kwargs): + from triton import Config + + return Config(kwargs, num_stages=num_stages, num_warps=num_warps) + + +def filtered_configs( + m: int, + n: int, + k: int, + configs: List[Tuple[int, int, int, int, int]], + has_int8_tensor=False, +): + """Heuristic to shrink configs when they are bigger than the input size""" + + # According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424 + # it's safer to use at least [32, 32] block size for int8/uint8 + # tensors + min_block_size = 32 if has_int8_tensor else 16 + m = max( + next_power_of_2( + V.graph.sizevars.size_hint( + m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + ), + min_block_size, + ) + n = max( + next_power_of_2( + V.graph.sizevars.size_hint( + n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + ), + min_block_size, + ) + k = max( + next_power_of_2( + V.graph.sizevars.size_hint( + k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + ), + min_block_size, + ) + used = set() + for block_m, block_n, block_k, num_stages, num_warps in configs: + # shrink configs for small sizes + block_m = max(min(block_m, m), min_block_size) + block_n = max(min(block_n, n), min_block_size) + block_k = max(min(block_k, k), min_block_size) + # each warp computes 16x16 tile = 256 + num_warps = min(num_warps, block_m * block_n // 256) + if torch.version.hip: + for matrix_instr_nonkdim in [0, 16]: + if matrix_instr_nonkdim != 0 and ( + block_m % matrix_instr_nonkdim != 0 + or block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + if ( + block_m, + block_n, + block_k, + num_stages, + num_warps, + matrix_instr_nonkdim, + ) not in used: + used.add( + ( + block_m, + block_n, + block_k, + num_stages, + num_warps, + matrix_instr_nonkdim, + ) + ) + yield triton_config( + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_stages=num_stages, + num_warps=num_warps, + matrix_instr_nonkdim=matrix_instr_nonkdim, + ) + else: + if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used: + used.add((block_m, block_n, block_k, num_stages, num_warps, 0)) + yield triton_config( + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_stages=num_stages, + num_warps=num_warps, + ) + + +# List of dictionaries to store the kernel configs. Configs that evaluate to true +# will be utilised on the target platform +mm_kernel_configs = [ + # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 128, 32, 3, 4), "cond": True}, + {"config": (128, 64, 32, 3, 4), "cond": True}, + {"config": (64, 128, 32, 4, 8), "cond": True}, + {"config": (128, 64, 32, 4, 8), "cond": True}, + {"config": (64, 32, 32, 5, 8), "cond": True}, + {"config": (32, 64, 32, 5, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (64, 64, 64, 3, 8), "cond": True}, + {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None}, + {"config": (64, 64, 16, 2, 4), "cond": True}, + {"config": (32, 32, 16, 1, 2), "cond": True}, +] + +int8_mm_kernel_configs = [ + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 128, 32, 3, 4), "cond": True}, + {"config": (128, 64, 32, 3, 4), "cond": True}, + {"config": (64, 128, 32, 4, 8), "cond": True}, + {"config": (128, 64, 32, 4, 8), "cond": True}, + {"config": (64, 32, 32, 5, 8), "cond": True}, + {"config": (32, 64, 32, 5, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (64, 64, 64, 3, 8), "cond": True}, + # {"config": (32, 32, 128, 2, 4), "cond": True}, + # {"config": (64, 64, 16, 2, 4), "cond": True}, + # {"config": (32, 32, 16, 1, 2), "cond": True}, + {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None}, + {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None}, +] + +# Create filtered list of configs based on cond evaluation + + +mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in mm_kernel_configs + if config["cond"] +) +int8_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in int8_mm_kernel_configs + if config["cond"] +) + +# On ROCm convert num_stages to 1 as pipelining provides no benefit +if torch.version.hip: + mm_platform_configs = tuple( + (config[0], config[1], config[2], 1, config[4]) + for config in mm_platform_configs + ) + int8_platform_configs = tuple( + (config[0], config[1], config[2], 1, config[4]) + for config in mm_platform_configs + ) + +mm_configs = functools.partial( + filtered_configs, + configs=mm_platform_configs, +) + +int8_mm_configs = functools.partial( + filtered_configs, + configs=int8_platform_configs, +) + + +def mm_grid(m, n, meta): + """ + The CUDA grid size for matmul triton templates. + """ + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1) + + +def acc_type(dtype): + if dtype in (torch.float16, torch.bfloat16): + return "tl.float32" + return f"tl.{dtype}".replace("torch.", "") + + +def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None): + """ + Common options to matmul triton templates. + """ + even_k_symbolic = ( + # it isn't worth guarding on this + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) + == config.kwargs["BLOCK_K"] + ) + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) + ) + return dict( + GROUP_M=8, + EVEN_K=even_k_symbolic, + ALLOW_TF32=allow_tf32, + ACC_TYPE=acc_type(layout.dtype), + B_PROLOGUE_CAST_TYPE=b_prologue_cast_type, + num_stages=config.num_stages, + num_warps=config.num_warps, + **config.kwargs, + ) + + +def mm_args(mat1, mat2, *others, layout=None, out_dtype=None, use_4x2_dim=False): + """ + Common arg processing for mm,bmm,addmm,etc + """ + mat1, mat2 = realize_inputs(mat1, mat2) + *b1, m, k1 = mat1.get_size() + *b2, k2, n = mat2.get_size() + b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)] + if use_4x2_dim: + k2 = k2 * 2 + k = V.graph.sizevars.guard_equals(k1, k2) + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + layout = FixedLayout( + mat1.get_device(), + out_dtype, + [*b, m, n], + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + from ..lowering import expand + + others = [realize_inputs(expand(x, layout.size)) for x in others] + + return [m, n, k, layout, mat1, mat2, *others] + + +def addmm_epilogue(dtype, alpha, beta): + def epilogue(acc, bias): + if alpha != 1: + acc = V.ops.mul(acc, V.ops.constant(alpha, dtype)) + if beta != 1: + bias = V.ops.mul(bias, V.ops.constant(beta, dtype)) + return V.ops.add(acc, bias) + + return epilogue diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..95ef6f043dfce7a5f8e4e8064e9a25809af29ac6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/mm_plus_mm.py @@ -0,0 +1,235 @@ +import functools + +import torch + +from ..lowering import lowerings +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import use_aten_gemm_kernels, use_triton_template +from ..virtualized import V +from .mm_common import mm_args, mm_grid, mm_options + +aten = torch.ops.aten + +aten_mm_plus_mm = ExternKernelChoice( + torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm" +) + +mm_plus_mm_template = TritonTemplate( + name="mm_plus_mm", + grid=mm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C", "D")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K1 = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + # K2 = {{size("C", 1)}} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + stride_cm = {{stride("C", 0)}} + stride_ck = {{stride("C", 1)}} + stride_dk = {{stride("D", 0)}} + stride_dn = {{stride("D", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck) + D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k1 in range(K1, 0, -BLOCK_K): + # First matmul with A @ B + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k1, other=0.) + b = tl.load(B, mask=rk[:, None] < k1, other=0.) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + for k2 in range(K1, 0, -BLOCK_K): + + # Second matmul with C @ D + if EVEN_K: + c = tl.load(C) + d = tl.load(D) + else: + c = tl.load(C, mask=rk[None, :] < k2, other=0.) + d = tl.load(D, mask=rk[:, None] < k2, other=0.) + acc += tl.dot(c, d, allow_tf32=ALLOW_TF32) + C += BLOCK_K * stride_ck + D += BLOCK_K * stride_dk + + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + + +@functools.lru_cache(None) +def mm_configs(): + import triton + + # List of dictionaries to store the kernel configs. Configs that evaluate to true + # will be utilised on the target platform + mm_triton_configs = [ + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 2, + "num_warps": 4, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 3, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 4, + "num_warps": 16, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, + "num_stages": 4, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 4, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, + "num_stages": 1, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, + "num_stages": 1, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, + "num_stages": 1, + "num_warps": 8, + "cond": torch.version.hip is None, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, + "num_stages": 2, + "num_warps": 4, + "cond": True, + }, + { + "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, + "num_stages": 1, + "num_warps": 2, + "cond": True, + }, + ] + + # Filter out configs in which cond evaluates to true + # On ROCm convert num_stages to 1 as pipelining provides no benefit + if torch.version.hip: + filtered_configs = [ + triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"]) + for c in mm_triton_configs + if c["cond"] + ] + else: + filtered_configs = [ + triton.Config( + c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"] + ) + for c in mm_triton_configs + if c["cond"] + ] + + return filtered_configs + + +def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): + """ + Computes mm(mat1, mat2) + mm(mat3, mat4) + """ + m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) + # Optimization is optional, because we can always just not do the fusion + if ( + m1 * n1 == 0 + or m2 * n2 == 0 + or not V.graph.sizevars.statically_known_list_equals( + mat1.get_size(), mat3.get_size() + ) + or not V.graph.sizevars.statically_known_list_equals( + mat2.get_size(), mat4.get_size() + ) + ): + # TODO(jansel): support different K values when this is fixed: + # https://github.com/openai/triton/issues/967 + return lowerings[aten.add]( + lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4) + ) + + assert layout1 == layout2 + # options to tune from + choices = ( + [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)] + if use_aten_gemm_kernels() + else [] + ) + if use_triton_template(layout1): + for config in mm_configs(): + # see https://github.com/openai/triton/issues/1298 + # BLOCK_K = K causes llvm error + if config.kwargs["BLOCK_K"] < k1: + mm_plus_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2, mat3, mat4), + layout=layout1, + **mm_options(config, m1, n1, k1, layout1), + ) + + return autotune_select_algorithm( + "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1 + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..88a29547406cd144a2df80e1505e3fd433d0ee59 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h @@ -0,0 +1,29 @@ +// Copyright © 2022 Apple Inc. + +#pragma once +#include + +namespace at::detail { + +C10_EXPORT TensorBase empty_mps( + IntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + c10::optional memory_format_opt); +C10_EXPORT TensorBase empty_mps( + IntArrayRef size, const TensorOptions &options); + +C10_EXPORT TensorBase empty_strided_mps( + IntArrayRef size, + IntArrayRef stride, + ScalarType dtype, + c10::optional device_opt); + +C10_EXPORT TensorBase empty_strided_mps( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions &options); + +} // namespace at::detail diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h new file mode 100644 index 0000000000000000000000000000000000000000..2b3cfae0c3eeaa4754c1fa97f2ee9c02efe6b2d4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h @@ -0,0 +1,630 @@ +#pragma once + +namespace at::mps { + +static const char * indexing_metal_shaders = R"INDEX_METAL( +#include +#include + +using namespace metal; + +#if __METAL_VERSION__ < 300 +struct IndexAB { + // Allow up to 16 indices + metal::array indexArray [[ id(0) ]]; +}; +#else +struct IndexAB { + constant int64_t* indexArray; +}; + +#endif + +template +kernel void index_select( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant OffsetsT * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]) { + constant int64_t * index_sizes = (constant int64_t *)indexSizes; + constant int64_t * index_strides = (constant int64_t *)indexStrides; + int64_t offset = 0; + for (uint32_t i = 0; i < num_indices; i++) { +#if __METAL_VERSION__ >= 300 + constant int64_t* indexArray = indexAB[i].indexArray; +#else + constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i]; +#endif + int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)]; + if (index < 0) { + index += index_sizes[i]; + } + offset += index * index_strides[i]; + } + device T * out = (device T*)((device char*)outputData + offsets[thread_index].x); + constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset); + *out = *in; +} + +template +void index_put_impl( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB, +#else + constant IndexAB & indexAB, +#endif + constant int64_t * index_sizes, + constant int64_t * index_strides, + constant OffsetsT * offsets, + constant void * inputData, + device void * outputData, + constant uint32_t & num_indices, + uint thread_index) { + int64_t offset = 0; + for (uint32_t i = 0; i < num_indices; i++) { +#if __METAL_VERSION__ >= 300 + constant int64_t* indexArray = indexAB[i].indexArray; +#else + constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i]; +#endif + int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)]; + + if (index < 0) { + index += index_sizes[i]; + } + offset += index * index_strides[i]; + } + device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset); + constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y); + *out = *in; +} + +template +kernel void index_put_serial( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant OffsetsT * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], + constant uint * numIters [[buffer(7)]], + uint thread_index [[thread_position_in_grid]]) { + + constant int64_t * index_sizes = (constant int64_t *)indexSizes; + constant int64_t * index_strides = (constant int64_t *)indexStrides; + + for (uint iter_i = 0; iter_i < *numIters; iter_i++) { + index_put_impl(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i); + } +} + +template +kernel void index_put( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant OffsetsT * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]) { + + constant int64_t * index_sizes = (constant int64_t *)indexSizes; + constant int64_t * index_strides = (constant int64_t *)indexStrides; + index_put_impl(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index); +} + +#if __METAL_VERSION__ < 300 +#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \ +template \ +[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \ +kernel void index_ ## INDEX_OP_TYPE( \ + constant IndexAB & indexAB [[buffer(0)]], \ + constant void * indexSizes [[buffer(1)]], \ + constant void * indexStrides [[buffer(2)]], \ + constant IDX_DTYPE * offsets [[buffer(3)]], \ + constant void * inputData [[buffer(4)]], \ + device void * outputData [[buffer(5)]], \ + constant uint32_t & num_indices [[buffer(6)]], \ + uint thread_index [[thread_position_in_grid]]); +#else +#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \ +template \ +[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \ +kernel void index_ ## INDEX_OP_TYPE( \ + constant IndexAB * indexAB [[buffer(0)]], \ + constant void * indexSizes [[buffer(1)]], \ + constant void * indexStrides [[buffer(2)]], \ + constant IDX_DTYPE * offsets [[buffer(3)]], \ + constant void * inputData [[buffer(4)]], \ + device void * outputData [[buffer(5)]], \ + constant uint32_t & num_indices [[buffer(6)]], \ + uint thread_index [[thread_position_in_grid]]); +#endif + +#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \ + REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \ + REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \ + REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \ + REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \ + REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \ + REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \ + REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \ + REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3); + +REGISTER_INDEX_OP_ALL_DTYPES(select); +REGISTER_INDEX_OP_ALL_DTYPES(put); + +#if __METAL_VERSION__ < 300 +#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \ +template \ +[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \ +kernel void index_ ## INDEX_OP_TYPE( \ + constant IndexAB & indexAB [[buffer(0)]], \ + constant void * indexSizes [[buffer(1)]], \ + constant void * indexStrides [[buffer(2)]], \ + constant IDX_DTYPE * offsets [[buffer(3)]], \ + constant void * inputData [[buffer(4)]], \ + device void * outputData [[buffer(5)]], \ + constant uint32_t & num_indices [[buffer(6)]], \ + constant uint * numIters [[buffer(7)]], \ + uint thread_index [[thread_position_in_grid]]); +#else +#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \ +template \ +[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \ +kernel void index_ ## INDEX_OP_TYPE( \ + constant IndexAB * indexAB [[buffer(0)]], \ + constant void * indexSizes [[buffer(1)]], \ + constant void * indexStrides [[buffer(2)]], \ + constant IDX_DTYPE * offsets [[buffer(3)]], \ + constant void * inputData [[buffer(4)]], \ + device void * outputData [[buffer(5)]], \ + constant uint32_t & num_indices [[buffer(6)]], \ + constant uint * numIters [[buffer(7)]], \ + uint thread_index [[thread_position_in_grid]]); +#endif + +#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \ + REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \ + REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \ + REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \ + REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \ + REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \ + REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \ + REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \ + REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3); + +REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial); + +template +kernel void kernel_index_offsets(constant StridesT * strides [[buffer(0)]], + device DataT * data_offsets [[buffer(1)]], + constant uint * iter_shape [[buffer(2)]], + constant uint & num_dimensions [[buffer(3)]], + uint thread_index [[thread_position_in_grid]]) { + data_offsets[thread_index] = 0; + uint32_t idx = thread_index; + for (uint32_t dim = 0; dim < num_dimensions; dim++) { + uint32_t remainder = idx % iter_shape[dim]; + idx /= iter_shape[dim]; + + data_offsets[thread_index] += remainder * DataT(strides[dim]); + } +} + +template +[[host_name("kernel_index_offsets_32")]] +kernel void kernel_index_offsets( + constant packed_uint3 * strides [[buffer(0)]], + device uint3 * data_offsets [[buffer(1)]], + constant uint * iter_shape [[buffer(2)]], + constant uint & num_dimensions [[buffer(3)]], + uint thread_index [[thread_position_in_grid]]); + +template +[[host_name("kernel_index_offsets_64")]] +kernel void kernel_index_offsets( + constant packed_uint3 * strides [[buffer(0)]], + device ulong3 * data_offsets [[buffer(1)]], + constant uint * iter_shape [[buffer(2)]], + constant uint & num_dimensions [[buffer(3)]], + uint thread_index [[thread_position_in_grid]]); + +template +kernel void index_put_accumulate_native_dtypes( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant OffsetsT * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]) { + constant int64_t * index_sizes = (constant int64_t *)indexSizes; + constant int64_t * index_strides = (constant int64_t *)indexStrides; + int64_t offset = 0; + for (uint32_t i = 0; i < num_indices; i++) { +#if __METAL_VERSION__ >= 300 + constant int64_t* indexArray = indexAB[i].indexArray; +#else + constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i]; +#endif + int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)]; + if (index < 0) { + index += index_sizes[i]; + } + offset += index * index_strides[i]; + } + device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset); + constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y); + atomic_fetch_add_explicit(out, *in, memory_order_relaxed); +} + +template +__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) { + device atomic_uint* uintAddr = (device atomic_uint*)addr; + uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed); + T updated = as_type(expected) + value; + while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type(updated), memory_order_relaxed, memory_order_relaxed)) { + updated = as_type(expected) + value; + } +} + +template +kernel void atomic_index_put_accumulate( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant OffsetsT * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]) { + constant int64_t * index_sizes = (constant int64_t *)indexSizes; + constant int64_t * index_strides = (constant int64_t *)indexStrides; + int64_t offset = 0; + for (uint32_t i = 0; i < num_indices; i++) { +#if __METAL_VERSION__ >= 300 + constant int64_t* indexArray = indexAB[i].indexArray; +#else + constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i]; +#endif + int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)]; + if (index < 0) { + index += index_sizes[i]; + } + offset += index * index_strides[i]; + } + device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset); + constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y); + atomic_fetch_add_relaxed(out, *in); +} + +template +[[host_name("index_put_accumulate_32bit_float_idx32")]] +kernel void atomic_index_put_accumulate( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant uint3 * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]); + +template +[[host_name("index_put_accumulate_32bit_float_idx64")]] +kernel void atomic_index_put_accumulate( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant ulong3 * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]); + +template +[[host_name("index_put_accumulate_32bit_int_idx32")]] +kernel void index_put_accumulate_native_dtypes( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant uint3 * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]); + +template +[[host_name("index_put_accumulate_32bit_int_idx64")]] +kernel void index_put_accumulate_native_dtypes( +#if __METAL_VERSION__ >= 300 + constant IndexAB * indexAB [[buffer(0)]], +#else + constant IndexAB & indexAB [[buffer(0)]], +#endif + constant void * indexSizes [[buffer(1)]], + constant void * indexStrides [[buffer(2)]], + constant ulong3 * offsets [[buffer(3)]], + constant void * inputData [[buffer(4)]], + device void * outputData [[buffer(5)]], + constant uint32_t & num_indices [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]); +)INDEX_METAL"; + +static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER( +struct __attribute__ ((packed)) packed_uint5{{ + uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u; +}}; + +template +Y cast(const X x); + +template<> +{1} cast<{1}, {0}>(const {0} x) {{ + return {2}; +}} + +kernel void scatter_kernel_5(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint5 & size [[buffer(2)]], + constant packed_uint5 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint5 local_index; + local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x; + local_index.y = linear_index / (size.u * size.w * size.z) % size.y; + local_index.z = linear_index / (size.u * size.w) % size.z; + local_index.w = linear_index / size.u % size.w; + local_index.u = linear_index % size.u; + + packed_uint5 strided_index; + strided_index.x = local_index.x * stride.x; + strided_index.y = local_index.y * stride.y; + strided_index.z = local_index.z * stride.z; + strided_index.w = local_index.w * stride.w; + strided_index.u = local_index.u * stride.u; + + dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]); +}} + +kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint4 & size [[buffer(2)]], + constant packed_uint4 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint4 local_index; + local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0]; + local_index.y = linear_index / (size[3] * size[2]) % size[1]; + local_index.z = linear_index / size[3] % size[2]; + local_index.w = linear_index % size[3]; + + const packed_uint4 strided_index = local_index * stride; + dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]); +}} + +kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint3 & size [[buffer(2)]], + constant packed_uint3 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint3 local_index; + local_index.x = linear_index / (size[2] * size[1]) % size[0]; + local_index.y = linear_index / size[2] % size[1]; + local_index.z = linear_index % size[2]; + + const packed_uint3 strided_index = local_index * stride; + dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]); +}} + +kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint2 & size [[buffer(2)]], + constant packed_uint2 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint2 local_index; + local_index.x = linear_index / size[1] % size[0]; + local_index.y = linear_index % size[1]; + + const packed_uint2 strided_index = local_index * stride; + dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]); +}} + +kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant int & size [[buffer(2)]], + constant int & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + const int local_index = linear_index % size; + const int strided_index = local_index * stride; + dst[strided_index] = cast<{1}>(src[linear_index]); +}} +)METAL_SCATTER"; + +static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER( +struct __attribute__ ((packed)) packed_uint5{{ + uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u; +}}; + +template +Y cast(const X x); + +template<> +{1} cast<{1}, {0}>(const {0} x) {{ + return {2}; +}} + +kernel void gather_kernel_5(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint5 & size [[buffer(2)]], + constant packed_uint5 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + + packed_uint5 local_index; + local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x; + local_index.y = linear_index / (size.u * size.w * size.z) % size.y; + local_index.z = linear_index / (size.u * size.w) % size.z; + local_index.w = linear_index / size.u % size.w; + local_index.u = linear_index % size.u; + + packed_uint5 strided_index; + strided_index.x = local_index.x * stride.x; + strided_index.y = local_index.y * stride.y; + strided_index.z = local_index.z * stride.z; + strided_index.w = local_index.w * stride.w; + strided_index.u = local_index.u * stride.u; + + dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]); +}} + +kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint4 & size [[buffer(2)]], + constant packed_uint4 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint4 local_index; + local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0]; + local_index.y = linear_index / (size[3] * size[2]) % size[1]; + local_index.z = linear_index / size[3] % size[2]; + local_index.w = linear_index % size[3]; + + const packed_uint4 strided_index = local_index * stride; + dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]); +}} + +kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint3 & size [[buffer(2)]], + constant packed_uint3 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint3 local_index; + local_index.x = linear_index / (size[2] * size[1]) % size[0]; + local_index.y = linear_index / size[2] % size[1]; + local_index.z = linear_index % size[2]; + + const packed_uint3 strided_index = local_index * stride; + dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]); +}} + +kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant packed_uint2 & size [[buffer(2)]], + constant packed_uint2 & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + packed_uint2 local_index; + local_index.x = linear_index / size[1] % size[0]; + local_index.y = linear_index % size[1]; + + const packed_uint2 strided_index = local_index * stride; + dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]); +}} + +kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]], + constant void * src_ [[buffer(0)]], + device void * dst_ [[buffer(1)]], + constant int & size [[buffer(2)]], + constant int & stride [[buffer(3)]], + constant uint32_t & numel [[buffer(4)]]) {{ + if (linear_index >= numel) return; + + constant {0} * src = (constant {0} *)src_; + device {1} * dst = (device {1} *)dst_; + + const int local_index = linear_index % size; + const int strided_index = local_index * stride; + dst[linear_index] = cast<{1}>(src[strided_index]); +}} +)METAL_GATHER"; +} // namespace at::mps diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..fe43fcf40fd34a38c9633b0c7a01e5668bfc9aba --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h @@ -0,0 +1,174 @@ +// Copyright © 2022 Apple Inc. + +#pragma once +#include +#include +#include +#include +#include +#include + +#ifdef __OBJC__ +#include +#include +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at::mps { + +typedef MPSEvent* mpsEvent_t; + +// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl +// https://github.com/pytorch/pytorch/issues/77170 +struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr c10::DeviceType static_type = c10::DeviceType::MPS; + + // constructor + MPSGuardImpl() {} + explicit MPSGuardImpl(c10::DeviceType t) { + TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS); + } + + // returns the type + c10::DeviceType type() const override { + return c10::DeviceType::MPS; + } + + Device exchangeDevice(Device d) const override { + return Device(c10::DeviceType::MPS, 0); + } + + Device getDevice() const override { + return Device(c10::DeviceType::MPS, 0); + } + + c10::optional uncheckedGetDevice() const noexcept { + return Device(c10::DeviceType::MPS, 0); + } + + void setDevice(Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_mps()); + } + + void uncheckedSetDevice(Device d) const noexcept override { + // TODO: Currently setting only device 0 + } + + Stream getStream(Device d) const noexcept override { + return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); + } + + Stream getDefaultStream(Device d) const override { + return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); + } + + // NB: These do NOT set the current device + Stream exchangeStream(Stream s) const noexcept override { + return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); + } + DeviceIndex deviceCount() const noexcept override { + if (at::hasMPS()) { + //TODO: extend it for multi-device case + return 1; + } else { + return 0; + } + } + + // Event-related functions + void createEvent( + mpsEvent_t* event, + const EventFlag flag) const; + + void destroyEvent( + void* event, + const DeviceIndex device_index) const noexcept override; + + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override; + + void block( + void* event, + const Stream& stream) const override; + + bool queryEvent(void* event) const override; + +}; + +/// A variant of OptionalDeviceGuard that is specialized for MPS. +struct OptionalMPSGuard { + explicit OptionalMPSGuard() : guard_() {} + + explicit OptionalMPSGuard(c10::optional device_opt) + : guard_(device_opt) {} + + /// Set the current MPS device to the passed device index, if it is not + /// nullopt + explicit OptionalMPSGuard(c10::optional device_index_opt) + : guard_(device_index_opt) {} + + // Copy is not allowed + OptionalMPSGuard(const OptionalMPSGuard&) = delete; + OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete; + OptionalMPSGuard(OptionalMPSGuard&& other) = delete; + OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete; + + /// Sets the MPS device to the given device, initializing the guard if it + /// is not already initialized. Errors if the given device is not a MPS + /// device. + void set_device(Device device) { + guard_.set_device(device); + } + + /// Sets the MPS device to the given device, initializing the guard if it is + /// not already initialized. Errors if the given device is not a MPS device. + void reset_device(Device device) { + guard_.reset_device(device); + } + + /// Sets the MPS device to the given device index, initializing the guard if + /// it is not already initialized. + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } + + /// Returns the device that was set immediately prior to initialization of the + /// guard, or nullopt if the guard is uninitialized. + c10::optional original_device() const { + return guard_.original_device(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device, if the guard is initialized, + /// or nullopt if the guard is uninitialized. + c10::optional current_device() const { + return guard_.current_device(); + } + + /// Restore the original MPS device, resetting this guard to uninitialized + /// state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalDeviceGuard guard_; +}; + + +C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl); + +} // namespace at::mps diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h new file mode 100644 index 0000000000000000000000000000000000000000..1a1039b916f8e47fb7771236611d4616d17bd445 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h @@ -0,0 +1,369 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPU_CAPABILITY_AVX2 +#include +#include +#endif + + +namespace at { +namespace native { +namespace templates { +namespace cpu { +namespace { + +// ==================================================== Random ======================================================== + +template +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) { + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] { + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t { + uniform_int_from_to_distribution random(range, base); + return random(generator); + }); + }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +// This is the special kernel to handle single specific case: +// from(inclusive) = std::numeric_limits::lowest() +// to(exclusive) = None (= std::numeric_limits::max() + 1) +template +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] { + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [generator]() -> scalar_t { + uniform_int_full_range_distribution random; + return random(generator); + }); + } else { + TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16"); + } + }); +} + +template +struct RandomFromToKernel { + void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, c10::optional gen) { + random_from_to_kernel(iter, range, base, check_generator(gen)); + } + void operator()(TensorIteratorBase& iter, c10::optional gen) { + random_full_64_bits_range_kernel(iter, check_generator(gen)); + } +}; + +template +void random_kernel(TensorIteratorBase& iter, RNG generator) { + std::lock_guard lock(generator->mutex_); + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] { + cpu_serial_kernel(iter, [generator]() -> scalar_t { + uniform_int_distribution random; + return random(generator); + }); + }); +} + +template +struct RandomKernel { + void operator()(TensorIteratorBase& iter, c10::optional gen) { + random_kernel(iter, check_generator(gen)); + } +}; + +// ==================================================== Normal ======================================================== + +#ifdef CPU_CAPABILITY_AVX2 +static void normal_fill_16_AVX2(float *data, + const __m256* two_pi, + const __m256* one, + const __m256* minus_two, + const __m256* mean, + const __m256* std_v) { + const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data)); + const __m256 u2 = _mm256_loadu_ps(data + 8); + // sincos256_ps and log256_ps are from avx_mathfun.h + const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1))); + const __m256 theta = _mm256_mul_ps(*two_pi, u2); + __m256 sintheta, costheta; + sincos256_ps(theta, &sintheta, &costheta); + const __m256 n1 = _mm256_mul_ps(radius, costheta); + const __m256 n2 = _mm256_mul_ps(radius, sintheta); + _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean)); + _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean)); +} + +template +void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) { + float *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi); + const __m256 one = _mm256_set1_ps(1.0f); + const __m256 minus_two = _mm256_set1_ps(-2.0f); + const __m256 mean_v = _mm256_set1_ps(mean); + const __m256 std_v = _mm256_set1_ps(std); + + for (int64_t i = 0; i < size - 15; i += 16) { + normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v); + } + + if (size % 16 != 0) { + // Recompute the last 16 values. + data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v); + } +} +#endif + +template +static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) { + for (const auto j : c10::irange(8)) { + const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log. + const scalar_t u2 = data[j + 8]; + const scalar_t radius = std::sqrt(-2 * std::log(u1)); + const scalar_t theta = 2.0f * c10::pi * u2; + data[j] = radius * std::cos(theta) * std + mean; + data[j + 8] = radius * std::sin(theta) * std + mean; + } +} + +template +void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) { + scalar_t *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + + for (int64_t i = 0; i < size - 15; i += 16) { + normal_fill_16(data + i, mean, std); + } + if (size % 16 != 0) { + // Recompute the last 16 values. + data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + normal_fill_16(data, mean, std); + } +} + +template +void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) { + auto size = self.numel(); + if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) { +#ifdef CPU_CAPABILITY_AVX2 + normal_fill_AVX2(self, static_cast(mean), static_cast(std), generator); +#else + normal_fill(self, static_cast(mean), static_cast(std), generator); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] { + if (size >= 16 && self.is_contiguous()) { + normal_fill(self, static_cast(mean), static_cast(std), generator); + } else { + auto iter = TensorIterator::borrowing_nullary_op(self); + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t { + at::normal_distribution normal(mean, std); + return static_cast(normal(generator)); + }); + } + }); + } +} + +template +struct NormalKernel { + void operator()(Tensor& self, double mean, double std, c10::optional gen) { + normal_kernel(self, mean, std, check_generator(gen)); + } +}; + +// ==================================================== Uniform ======================================================= + +template +void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + auto from = static_cast(from_); + auto to = static_cast(to_); + at::uniform_real_distribution uniform(from, to); + cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t { + return static_cast(uniform(generator)); + }); + }); +} + +template +struct UniformKernel { + void operator()(TensorIteratorBase& iter, double from, double to, c10::optional gen) { + uniform_kernel(iter, from, to, check_generator(gen)); + } +}; + +// ==================================================== Cauchy ======================================================== + +template +void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::cauchy_distribution cauchy(median, sigma); + cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t { + return static_cast(cauchy(generator)); + }); + }); +} + +template +struct CauchyKernel { + void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional gen) { + cauchy_kernel(iter, median, sigma, check_generator(gen)); + } +}; + +// ================================================== LogNormal ======================================================= + +template +void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::lognormal_distribution logNormal(mean, std); + cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t { + return static_cast(logNormal(generator)); + }); + }); +} + +template +struct LogNormalKernel { + void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional gen) { + log_normal_kernel(iter, mean, std, check_generator(gen)); + } +}; + +// =================================================== Geometric ====================================================== + +template +void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::geometric_distribution geometric(p); + cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t { + return static_cast(geometric(generator)); + }); + }); +} + +template +struct GeometricKernel { + void operator()(TensorIteratorBase& iter, double p, c10::optional gen) { + geometric_kernel(iter, p, check_generator(gen)); + } +}; + +// ================================================== Exponential ===================================================== + +template +void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) { + TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::exponential_distribution exponential(lambda); + cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t { + return static_cast(exponential(generator)); + }); + }); +} + +template +struct ExponentialKernel { + void operator()(TensorIteratorBase& iter, double lambda, c10::optional gen) { + exponential_kernel(iter, lambda, check_generator(gen)); + } +}; + +// ================================================== Bernoulli ======================================================= + +template +void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(generator->mutex_); + using self_t = scalar_t; + auto p_cpu = p_.to(kCPU); + auto p = expand_inplace(self, p_cpu); + auto iter = TensorIteratorConfig() + .add_output(self) + .add_input(*p) + .check_all_same_dtype(false) + .build(); + if (p->scalar_type() == kDouble) { + cpu_serial_kernel(iter, [&](const double p_val) -> self_t { + at::bernoulli_distribution bernoulli(p_val); + return static_cast(bernoulli(generator)); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, + p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] { + using p_t = scalar_t; + cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t { + at::bernoulli_distribution bernoulli(p_val); + return static_cast(bernoulli(generator)); + }); + }); + } + }); +} + +template +void bernoulli_kernel(const TensorBase &self, double p, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "bernoulli_scalar_cpu_", [&] { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(generator->mutex_); + auto iter = TensorIterator::borrowing_nullary_op(self); + cpu_serial_kernel(iter, [p, generator]() -> scalar_t { + at::bernoulli_distribution bernoulli(p); + return static_cast(bernoulli(generator)); + }); + }); +} + +template +struct BernoulliKernel { + void operator()(const TensorBase &self, double p, c10::optional gen) { + bernoulli_kernel(self, p, check_generator(gen)); + } + void operator()(const TensorBase &self, const TensorBase &p_, c10::optional gen) { + bernoulli_kernel(self, p_, check_generator(gen)); + } +}; + +}}}}} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..4537feddd0c7b2110f5da66ab995d4c9685d62e0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h @@ -0,0 +1,445 @@ +#pragma once + +#include +#if AT_MKLDNN_ENABLED() +#include +#include +#include +#include + +#include + +using PrimitiveCacheKey = std::tuple< + double, // input_scale + int64_t, // input_zero_point + std::vector, // input_shape + double, // output_scale + int64_t, // output_zero_point + int64_t, // OMP_number_of_threads + double, // accum_scale + int64_t>; // accum_zero_point + +enum CacheKeyIndex { + InputScale, + InputZeroPoint, + InputShape, + OutputScale, + OutputZeroPoint, + NumOfThreads, +}; + +// Base class of primitive cache +struct PrimitiveCache { + PrimitiveCacheKey key; + + bool hit(const PrimitiveCacheKey& key) { + return this->key == key; + } +}; + +using LinearParams = ideep::matmul_forward_params; +using Conv = dnnl::convolution_forward; +using ConvDesc = dnnl::convolution_forward::primitive_desc; +using ConvParams = ideep::convolution_forward_params; +using Deconv = dnnl::deconvolution_forward; +using DeconvDesc = dnnl::deconvolution_forward::primitive_desc; +using DeconvParams = ideep::deconv_forward_params; + +struct LinearPrimitiveCache : PrimitiveCache { + LinearPrimitiveCache() {} + + LinearPrimitiveCache( + const PrimitiveCacheKey& key, + const LinearParams& param) { + this->key = key; + this->param = param; + } + + LinearParams param; + + // For dynamic qlinear, scale and zero point + // are set at execution time. So we only need to compare + // the rest part of key. + bool hit_dynamic(const PrimitiveCacheKey& new_key) { + auto cached_input_shape = std::get(this->key); + auto new_input_shape = std::get(new_key); + return ( + cached_input_shape == new_input_shape && + std::get(this->key) == std::get(new_key)); + } + + LinearParams& get_param() { + return param; + } +}; + +struct ConvPrimitiveCache : PrimitiveCache { + ConvPrimitiveCache() {} + + ConvPrimitiveCache( + const PrimitiveCacheKey& key, + const ConvParams& params) { + this->key = key; + this->params = params; + } + + ConvParams params; + + ConvParams& get_params() { + return params; + } +}; + +struct DeconvPrimitiveCache : PrimitiveCache { + DeconvPrimitiveCache() {} + + DeconvPrimitiveCache( + const PrimitiveCacheKey& key, + const DeconvParams& params) { + this->key = key; + this->params = params; + } + + DeconvParams params; + + DeconvParams& get_params() { + return params; + } +}; + +enum PostOps { + NoPostOp, + Relu, + LeakyRelu, + Tanh, + Gelu +}; + +static std::unordered_map POST_OP_TABLE = { + {"none", NoPostOp}, + {"relu", Relu}, + {"leaky_relu", LeakyRelu}, + {"tanh", Tanh}, + {"gelu", Gelu} +}; + +struct PackedLinearWeightsOnednn : public LinearPackedParamsBase { + PackedLinearWeightsOnednn( + std::unique_ptr weight, + c10::optional bias, + at::Tensor orig_weight, + c10::optional orig_bias) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)) { + cache_initialized_flag = std::make_unique(); + } + std::unique_ptr weight_; + c10::optional bias_; + at::Tensor orig_weight_; + c10::optional orig_bias_; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override; + + at::Tensor apply_leaky_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point, + double negative_slope); + + at::Tensor apply_tanh( + at::Tensor input, + double output_scale, + int64_t output_zero_point); + + std::tuple> unpack() override; + + c10::optional bias() override { + return orig_bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + c10::optional bias); + + private: + LinearPrimitiveCache prim_cache; + std::unique_ptr cache_initialized_flag; + + template + at::Tensor apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point, + torch::List post_op_args = torch::List()); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false); + + LinearPrimitiveCache& get_cache() { + return prim_cache; + } +}; + +template +struct PackedConvWeightsOnednn : public ConvPackedParamsBase { + PackedConvWeightsOnednn( + std::unique_ptr weight, + c10::optional bias, + at::Tensor orig_weight, + c10::optional orig_bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + uint8_t transpose) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose) { + cache_initialized_flag = std::make_unique(); + } + + std::unique_ptr weight_; + c10::optional bias_; + at::Tensor orig_weight_; + c10::optional orig_bias_; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + uint8_t transpose_; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override; + + at::Tensor apply_add( + const at::Tensor& input, + const at::Tensor& accum, + double output_scale, + int64_t output_zero_point); + + at::Tensor apply_add_relu( + const at::Tensor& input, + const at::Tensor& accum, + double output_scale, + int64_t output_zero_point); + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return (bool)transpose_; + } + + private: + ConvPrimitiveCache conv_prim_cache; + DeconvPrimitiveCache deconv_prim_cache; + std::unique_ptr cache_initialized_flag; + + template + at::Tensor apply_impl( + const at::Tensor& input, + const c10::optional& accum, + double output_scale, + int64_t output_zero_point); + + ConvPrimitiveCache& get_conv_cache() { + assert(!transpose()); + return conv_prim_cache; + } + + DeconvPrimitiveCache& get_deconv_cache() { + assert(transpose()); + return deconv_prim_cache; + } +}; + +namespace onednn_utils { + +static ideep::attr_t create_attr_by_post_op( + const std::string& post_op_name, + const torch::List>& post_op_args, + const dnnl::algorithm post_algorithm) { + using ideep::tensor; + PostOps post_op = POST_OP_TABLE[post_op_name]; + if (post_op == Relu) { + return ideep::attr_t::fuse_relu(); + } else if (post_op == LeakyRelu) { + return ideep::attr_t::fuse_relu_v2(/*alpha=*/post_op_args[0].value().to()); + } else if (post_op == Tanh) { + return ideep::attr_t::fuse_tanh(); + } else if (post_op == Gelu) { + return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm); + } + return ideep::attr_t(); +} + +// Try to reorder tensor to expected desc at runtime +// Do it in a `try...catch...` manner to avoid oneDNN's errors +// TODO: Move it to third_party/ideep +static void try_reorder( + ideep::tensor& t, + const ideep::tensor::desc&& desc, + ideep::scale_t scales) { + if (t.get_desc() != desc) { + try { + t = t.reorder_if_differ_in(desc); + } catch (...) { + ideep::tensor&& plain = t.to_public(nullptr, t.get_data_type()); + t = plain.reorder_if_differ_in(desc); + } + t.set_scale(scales); + } +} + +// ONEDNN requires symmetric quantization of weight +// Use this util function to check. +static bool is_weight_symmetric_quant( + const at::Tensor& weight, + bool is_transposed_conv) { + bool is_symmetric = true; + const auto qtype = weight.qscheme(); + if (qtype == c10::kPerTensorAffine) { + is_symmetric &= (weight.q_zero_point() == 0); + } else if (qtype == c10::kPerChannelAffine) { + if (is_transposed_conv) { + // This case is currently not supported in PyTorch + // but we do not want to raise an error in this util function. + is_symmetric = false; + } else { + auto output_channels = weight.size(0); + for (int i = 0; i < output_channels; ++i) { + auto zp = weight.q_per_channel_zero_points()[i].item(); + is_symmetric &= (zp == 0); + } + } + } else { + // This case is currently not supported in PyTorch + // but we do not want to raise an error in this util function. + is_symmetric = false; + } + return is_symmetric; +} + +// When qengine is x86, use this util func to check if onednn kernel +// is preferred than fbgemm's to get better performance. +static bool should_use_onednn_quant( + const at::Tensor& weight, + bool is_transposed_conv, + int groups, + torch::List output_padding) { + // Performance of onednn is only validated on Linux right now. + // Also, the heuristics for dispatching are based on perf data on Linux. + // So, for x86 qengine, we always use fbgemm kernels if OS is not Linux. + // TODO Support more OSs. +#if !defined(__linux__) + return false; +#else + bool vnni_available = cpuinfo_has_x86_avx512vnni(); + bool w_sym_quant = + is_weight_symmetric_quant(weight, is_transposed_conv); + bool opad_all_zero = + std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; }); + return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero; +#endif +} + +} // onednn_utils + +at::Tensor _qconv_prepack_onednn( + at::Tensor weight, // from CPU backend instead of QuantizedCPU + at::Tensor weight_scales, // Weight zero points must be 0 for onednn + double input_scale, + int64_t input_zero_point, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + c10::optional> input_shape=c10::nullopt); + +static at::Tensor _quantized_convolution_onednn( + at::Tensor act, // contains quantized values but not QTensor + double act_scale, + int64_t act_zero_point, + at::Tensor weight, // MKLDNN tensor with quantized values + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, // Bias is packed if not None + torch::List stride, + torch::List padding, + torch::List dilation, + bool transposed, + int64_t groups, + double inv_output_scale, + int64_t output_zero_point, + c10::optional accum=c10::nullopt, // accum to fused with conv add + double accum_scale=1.0, + int64_t accum_zero_point=0, + bool fp32_output=false, + c10::optional binary_attr=c10::nullopt, + c10::optional binary_alpha=c10::nullopt, + c10::optional unary_attr=c10::nullopt, + torch::List> unary_scalars=torch::List>(), + c10::optional unary_algorithm=c10::nullopt); + +#endif // #if AT_MKLDNN_ENABLED() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..ff334d4c8d48ceeb4fa83fdbcd2e678a3e2d887d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h @@ -0,0 +1,335 @@ +#pragma once + +#ifdef USE_XNNPACK +#include + +#include +#include + +using xnnpack_operator = at::native::xnnpack::Operator; + +namespace at { +namespace native { +namespace xnnp_utils { + +/* + * Return shape in the same order as the memory format + * e.g. channels_last will return NHWC instead of NCHW + */ +std::vector get_mem_format_aware_shape(const at::Tensor& in); + +/* + * Input is always int8_t, output can be [int8_t, uint8_t]. + * input + offset = output + * int8_t + 128 = uint8_t + * int8_t + 0 = int8_t + */ +template +void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out); + +template +Tensor convert_conv_weights_to_channel_last_tensor( + const at::Tensor& src, + int groups, + bool transpose); + +/* + * Series of create wrapper functions to call xnn_create_[de]conv* functions. + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_create_convolution2d_nhwc( + uint32_t pad_top, + uint32_t pad_right, + uint32_t pad_bottom, + uint32_t pad_left, + uint32_t kernel_h, + uint32_t kernel_w, + uint32_t stride_h, + uint32_t stride_w, + uint32_t dilation_h, + uint32_t dilation_w, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t ip_chan_stride, + size_t op_chan_stride, + int8_t izp, + float ip_scale, + int8_t kzp, + const float* k_scales, + const int8_t* kernel, + const int32_t* bias, + int8_t ozp, + float op_scale, + int8_t op_min, + int8_t op_max, + uint32_t flags, + xnn_operator_t* op, + bool per_channel, + bool transpose) { + /* Symmetric quantization forces kzp = 0 */ + TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero." + "But got: ", kzp); + + if (transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + return xnn_create_deconvolution2d_nhwc_qs8( + pad_top, /* uint32_t output_padding_top */ + pad_right, /* uint32_t output_padding_right */ + pad_bottom, /* uint32_t output_padding_bottom */ + pad_left, /* uint32_t output_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t stride_height */ + stride_w, /* uint32_t stride_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels */ + ip_chan_stride, /* size_t input_pixel_stride */ + op_chan_stride, /* size_t output_pixel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales[0], /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* deconvolution_op_out */ + + } + + if (!per_channel) { + return xnn_create_convolution2d_nhwc_qs8( + pad_top, /* uint32_t input_padding_top */ + pad_right, /* uint32_t input_padding_right */ + pad_bottom, /* uint32_t input_padding_bottom */ + pad_left, /* uint32_t input_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t subsampling_height */ + stride_w, /* uint32_t subsampling_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels*/ + ip_chan_stride, /* size_t input_channel_stride */ + op_chan_stride, /* size_t output_channel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales[0], /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* convolution_op_out */ + } else { /* per_channel */ + return xnn_create_convolution2d_nhwc_qs8_qc8w( + pad_top, /* uint32_t input_padding_top */ + pad_right, /* uint32_t input_padding_right */ + pad_bottom, /* uint32_t input_padding_bottom */ + pad_left, /* uint32_t input_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t subsampling_height */ + stride_w, /* uint32_t subsampling_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels*/ + ip_chan_stride, /* size_t input_channel_stride */ + op_chan_stride, /* size_t output_channel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales, /* const float* kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* convolution_op_out */ + } +} + +/* + * Series of reshape wrapper functions to call xnn_reshape_[de]conv* functions. + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_reshape_convolution2d_nhwc( + xnn_operator_t op, + size_t batch, + size_t in_h, + size_t in_w, + pthreadpool_t pt_pool, + bool per_channel = false, + bool transpose = false, + uint32_t adj_h = 0, + uint32_t adj_w = 0) { + if(transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + return xnn_reshape_deconvolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + adj_h, /* uint32_t adjustment_height */ + adj_w, /* uint32_t adjustment_width */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } + + size_t workspace_size = SIZE_MAX; + size_t workspace_alignment = SIZE_MAX; + + if (!per_channel) { + return xnn_reshape_convolution2d_nhwc_qs8( + op, /* xnn_operator_t convolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + &workspace_size, /* size_t* workspace_size */ + &workspace_alignment, /* size_t* workspace_alignment */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } else { /* per_channel */ + return xnn_reshape_convolution2d_nhwc_qs8_qc8w( + op, /* xnn_operator_t convolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + &workspace_size, /* size_t* workspace_size */ + &workspace_alignment, /* size_t* workspace_alignment */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } +} + + +/* + * Series of setup wrapper functions to call xnn_setup_[de]conv* functions. + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_setup_convolution2d_nhwc( + xnn_operator_t op, + const int8_t* inp, + int8_t* outp, + bool per_channel = false, + bool transpose = false) { + if(transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + + return xnn_setup_deconvolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } + + if (!per_channel) { + return xnn_setup_convolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + nullptr, /* void workspace */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } else { /* per_channel */ + return xnn_setup_convolution2d_nhwc_qs8_qc8w( + op, /* xnn_operator_t deconvolution_op */ + nullptr, /* void workspace */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } +} + + +/* + * Series of wrapper functions to call xnn_create* and xnn_setup* + * functions for linear + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_create_fully_connected_nc( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + int8_t input_zero_point, + float input_scale, + int8_t kernel_zero_point, + float kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_operator_t* fully_connected_op_out) { + /* Symmetric quantization forces kzp = 0 */ + TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero." + "But got: ", kernel_zero_point); + return xnn_create_fully_connected_nc_qs8( + input_channels, /* size_t input_channels */ + output_channels, /* size_t output_channels */ + input_stride, /* size_t input_stride */ + output_stride, /* size_t output_stride */ + input_zero_point, /* int8_t input_zero_point */ + input_scale, /* float input_scale */ + kernel_scale, /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + output_zero_point, /* int8_t output_zero_point */ + output_scale, /* float output_scale */ + output_min, /* int8_t output_min */ + output_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t */ + fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */ +} + +C10_ALWAYS_INLINE +enum xnn_status xnnp_reshape_fully_connected_nc( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool) { + return xnn_reshape_fully_connected_nc_qs8( + fully_connected_op, /* xnn_operator_t fully_connected_op */ + batch_size, /* size_t batch_size */ + threadpool); /* pthreadpool_t threadpool */ +} + +C10_ALWAYS_INLINE +enum xnn_status xnnp_setup_fully_connected_nc( + xnn_operator_t fully_connected_op, + const int8_t* input, + int8_t* output) { + return xnn_setup_fully_connected_nc_qs8( + fully_connected_op, /* xnn_operator_t fully_connected_op */ + input, /* const int8_t* input */ + output /* int8_t* output */ + ); +} + +} // namespace xnnp_utils +} // namespace native +} // namespace at + +#endif // USE_XNNPACK diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h new file mode 100644 index 0000000000000000000000000000000000000000..bd153aaa67529694c3bfc2494e01ab86a1cbddda --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace mobile { + +Tensor allocate_padded_contiguous_if_needed( + const Tensor& input, + c10::MemoryFormat memory_format); + +// TODO: Remove this function when at::native::empty() is modified to accept a +// custom memory allocator. + +at::Tensor empty_with_tail_padding( + IntArrayRef size, + const caffe2::TypeMeta dtype, + c10::MemoryFormat memory_format, + c10::optional maybe_names); + +} // namespace mobile +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..adb5f1cfa49f9726db5a9304b2546b1ceff52eb3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace at { +namespace native { + +template +inline std::vector _expand_param_if_needed( + ArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + if (list_param.size() == 1) { + return std::vector(expected_dim, list_param[0]); + } else if ((int64_t)list_param.size() != expected_dim) { + std::ostringstream ss; + ss << "expected " << param_name << " to be a single integer value or a " + << "list of " << expected_dim << " values to match the convolution " + << "dimensions, but got " << param_name << "=" << list_param; + AT_ERROR(ss.str()); + } else { + return list_param.vec(); + } +} + +inline std::vector expand_param_if_needed( + IntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +inline std::vector expand_param_if_needed( + SymIntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h new file mode 100644 index 0000000000000000000000000000000000000000..6b7894cb8549f59d34b9b52d660780b729ada575 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +// Hashing machinery for Params +// Fowler–Noll–Vo hash function +// see +// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function +template +struct ParamsHash { + // Params must be a POD because we read out its memory + // contents as char* when hashing + static_assert(std::is_standard_layout_v, "Params is not POD"); + + size_t operator()(const Params& params) const { + auto ptr = reinterpret_cast(¶ms); + uint32_t value = 0x811C9DC5; + for (const auto i : c10::irange(sizeof(Params))) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +template +struct ParamsEqual { + // Params must be a POD because we read out its memory + // contents as char* when comparing + static_assert(std::is_standard_layout_v, "Params is not POD"); + + bool operator()(const Params& a, const Params& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(Params)) == 0; + } +}; + +// Provide explicit byte-for-byte constructors to avoid uwittingly leaving +// padding bytes unitialized (e.g., when passing Params by value) +template +struct ParamsWrapper { + T pod; + static_assert( + std::is_standard_layout_v, + "ParamsWrapper cannot wrap non-POD data"); + + ParamsWrapper() { + memset(&(this->pod), 0, sizeof(this->pod)); + } + + ParamsWrapper(const ParamsWrapper& other) { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + } + + ParamsWrapper(ParamsWrapper&& other) noexcept { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + } + + ParamsWrapper& operator=(const ParamsWrapper& other) { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + return *this; + } + + ParamsWrapper& operator=(ParamsWrapper&& other) noexcept { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + return *this; + } + + inline friend bool operator==( + const ParamsWrapper& lhs, + const ParamsWrapper& rhs) noexcept { + auto ptr1 = reinterpret_cast(&(lhs.pod)); + auto ptr2 = reinterpret_cast(&(rhs.pod)); + return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0; + } +}; + +// Wrapped version: this allows the outer struct to have custom copy and move +// constructors for additional safety +template +struct ParamsWrapperHash { + // Params must be a POD because we read out its memory + // contents as char* when hashing + static_assert( + std::is_standard_layout_v, + "ParamsWrapper cannot wrap non-POD data"); + + size_t operator()(const ParamsWrapper& params_wrapper) const { + auto ptr = reinterpret_cast(&(params_wrapper.pod)); + uint32_t value = 0x811C9DC5; + for (const auto i : c10::irange(sizeof(params_wrapper.pod))) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..711871d3d5b9aaee43e32ff100d1989c293eb919 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API void _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale); + +} // namespace cpu +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e9522076acf945233926db9d7562dbf78cdd8db1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API ::std::vector _histogramdd_bin_edges(const at::Tensor & self, at::IntArrayRef bins, c10::optional> range=c10::nullopt, const c10::optional & weight={}, bool density=false); + +} // namespace cpu +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_from_bin_tensors_ops.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_from_bin_tensors_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..b3c813aa7e4c033dafb925515cce8b64f11a8afc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_histogramdd_from_bin_tensors_ops.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API _histogramdd_from_bin_tensors { + using schema = at::Tensor (const at::Tensor &, at::TensorList, const c10::optional &, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_histogramdd_from_bin_tensors") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor") + static at::Tensor call(const at::Tensor & self, at::TensorList bins, const c10::optional & weight, bool density); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const c10::optional & weight, bool density); +}; + +struct TORCH_API _histogramdd_from_bin_tensors_out { + using schema = at::Tensor & (const at::Tensor &, at::TensorList, const c10::optional &, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_histogramdd_from_bin_tensors") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & self, at::TensorList bins, const c10::optional & weight, bool density, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const c10::optional & weight, bool density, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..3e631b902ad1333ca7fc9859ff6e97dcccebdcc1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward.h @@ -0,0 +1,30 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::_thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) +inline ::std::tuple _thnn_differentiable_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const c10::optional & input_bias, const c10::optional & hidden_bias) { + return at::_ops::_thnn_differentiable_gru_cell_backward::call(grad_hy, input_gates, hidden_gates, hx, input_bias, hidden_bias); +} + +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_fused_gru_cell_backward.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_fused_gru_cell_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..b716162d690981abe0c6667c47b95fac92ebd659 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_thnn_fused_gru_cell_backward.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) +inline ::std::tuple _thnn_fused_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_gru_cell_backward::call(grad_hy, workspace, has_bias); +} + +// aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) +inline ::std::tuple _thnn_fused_gru_cell_backward_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_gru_cell_backward_out::call(grad_hy, workspace, has_bias, out0, out1, out2, out3, out4); +} +// aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) +inline ::std::tuple _thnn_fused_gru_cell_backward_outf(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) { + return at::_ops::_thnn_fused_gru_cell_backward_out::call(grad_hy, workspace, has_bias, out0, out1, out2, out3, out4); +} + +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..be29d19a4bc4d0a7966a9a160e69121d88ee8c80 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/asin_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor asin(const at::Tensor & self); +TORCH_API at::Tensor & asin_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & asin_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & asin_(at::Tensor & self); + +} // namespace meta +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..cd21c5d2a48c90fcd775ec8d7e06f4e7698febd4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_backward_reduce_ops.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API batch_norm_backward_reduce { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional &, bool, bool, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::batch_norm_backward_reduce") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)") + static ::std::tuple call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional & weight, bool input_g, bool weight_g, bool bias_g); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional & weight, bool input_g, bool weight_g, bool bias_g); +}; + +struct TORCH_API batch_norm_backward_reduce_out { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const c10::optional &, bool, bool, bool, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::batch_norm_backward_reduce") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))") + static ::std::tuple call(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); +}; + +}} // namespace at::_ops diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_or.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_or.h new file mode 100644 index 0000000000000000000000000000000000000000..3f7078d79979a1e253c826cae6fd917ccd4bd982 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_or.h @@ -0,0 +1,67 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bitwise_or_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Tensor_out::call(self, other, out); +} +// aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bitwise_or_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_or_Tensor_out::call(self, other, out); +} + +// aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bitwise_or_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_or_Scalar_out::call(self, other, out); +} +// aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bitwise_or_outf(const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_or_Scalar_out::call(self, other, out); +} + +// aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor bitwise_or(const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_or_Scalar::call(self, other); +} + +// aten::bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor +inline at::Tensor bitwise_or(const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Scalar_Tensor::call(self, other); +} + +// aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor bitwise_or(const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Tensor::call(self, other); +} + +// aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bitwise_or_out(at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Scalar_Tensor_out::call(self, other, out); +} +// aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & bitwise_or_outf(const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_or_Scalar_Tensor_out::call(self, other, out); +} + +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e504a5eaf81ddf258dd7682f86be0f21731d5c2e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor cat(const at::ITensorListRef & tensors, int64_t dim=0); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/celu_native.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/celu_native.h new file mode 100644 index 0000000000000000000000000000000000000000..13bd89480c93aebd88c7977a3ad9fd1ac3fc7524 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/celu_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor celu(const at::Tensor & self, const at::Scalar & alpha=1.0); +TORCH_API at::Tensor & celu_out(const at::Tensor & self, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & celu_(at::Tensor & self, const at::Scalar & alpha=1.0); +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/expand.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/expand.h new file mode 100644 index 0000000000000000000000000000000000000000..cc066dac3efe1cba6e32c20a8ec8a28948229574 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/expand.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +namespace symint { + template ::value>> + at::Tensor expand(const at::Tensor & self, at::IntArrayRef size, bool implicit=false) { + return at::_ops::expand::call(self, c10::fromIntArrayRefSlow(size), implicit); + } +} + +namespace symint { + template ::value>> + at::Tensor expand(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) { + return at::_ops::expand::call(self, size, implicit); + } +} + +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/exponential_compositeexplicitautograd_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/exponential_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..8f212862e1cf7f9501fbc502c519f29ed9c28294 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/exponential_compositeexplicitautograd_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor exponential(const at::Tensor & self, double lambd=1, c10::optional generator=c10::nullopt); +TORCH_API at::Tensor & exponential_out(at::Tensor & out, const at::Tensor & self, double lambd=1, c10::optional generator=c10::nullopt); +TORCH_API at::Tensor & exponential_outf(const at::Tensor & self, double lambd, c10::optional generator, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fill_meta_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fill_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..97fe69bc140d0408365f51e546b2ebbc0350a031 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fill_meta_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor & fill_(at::Tensor & self, const at::Scalar & value); +TORCH_API at::Tensor & fill_(at::Tensor & self, const at::Tensor & value); + +} // namespace meta +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hardshrink_backward_meta_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hardshrink_backward_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..819e120a26567a7bc6bb961cadca5b7e1efb3816 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hardshrink_backward_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor hardshrink_backward(const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd); +TORCH_API at::Tensor & hardshrink_backward_out(at::Tensor & grad_input, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd); +TORCH_API at::Tensor & hardshrink_backward_outf(const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input); + +} // namespace meta +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift.h new file mode 100644 index 0000000000000000000000000000000000000000..c8a3fec981b4cc696c4e34b1e84a8f560dabd626 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::lift(Tensor self) -> Tensor +inline at::Tensor lift(const at::Tensor & self) { + return at::_ops::lift::call(self); +} + +// aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & lift_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::lift_out::call(self, out); +} +// aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & lift_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::lift_out::call(self, out); +} + +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift_fresh_ops.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift_fresh_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..2cad94381ed1bad9aab9ec46bcc52151de4d5bd7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift_fresh_ops.h @@ -0,0 +1,28 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API lift_fresh { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::lift_fresh") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "lift_fresh(Tensor(a) self) -> Tensor(a)") + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +}} // namespace at::_ops diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lstm.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lstm.h new file mode 100644 index 0000000000000000000000000000000000000000..a2bcee47190d39b3bf61666cf4f0b2a9451c6336 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lstm.h @@ -0,0 +1,35 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor) +inline ::std::tuple lstm(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::lstm_input::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); +} + +// aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor) +inline ::std::tuple lstm(const at::Tensor & data, const at::Tensor & batch_sizes, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + return at::_ops::lstm_data::call(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); +} + +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/max_pool2d_backward_ops.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/max_pool2d_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..c6310664f9e97a948750f802ec6ddd08f3ddb152 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/max_pool2d_backward_ops.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API max_pool2d_backward { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::max_pool2d_backward") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor") + static at::Tensor call(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); +}; + +struct TORCH_API max_pool2d_backward_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::max_pool2d_backward") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/native_group_norm.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/native_group_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..f7ede3a6c7f902babff2e91c3573737777e9622e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/native_group_norm.h @@ -0,0 +1,91 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) +inline ::std::tuple native_group_norm(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { + return at::_ops::native_group_norm::call(input, weight, bias, N, C, HxW, group, eps); +} +namespace symint { + template ::value>> + ::std::tuple native_group_norm(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { + return at::_ops::native_group_norm::call(input, weight, bias, N, C, HxW, group, eps); + } +} + +// aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) +inline ::std::tuple native_group_norm_symint(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + return at::_ops::native_group_norm::call(input, weight, bias, N, C, HxW, group, eps); +} +namespace symint { + template ::value>> + ::std::tuple native_group_norm(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + return at::_ops::native_group_norm::call(input, weight, bias, N, C, HxW, group, eps); + } +} + +// aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +inline ::std::tuple native_group_norm_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { + return at::_ops::native_group_norm_out::call(input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); +} +namespace symint { + template ::value>> + ::std::tuple native_group_norm_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { + return at::_ops::native_group_norm_out::call(input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } +} + +// aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +inline ::std::tuple native_group_norm_outf(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_out::call(input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); +} +namespace symint { + template ::value>> + ::std::tuple native_group_norm_outf(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_out::call(input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } +} + +// aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +inline ::std::tuple native_group_norm_symint_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + return at::_ops::native_group_norm_out::call(input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); +} +namespace symint { + template ::value>> + ::std::tuple native_group_norm_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + return at::_ops::native_group_norm_out::call(input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } +} + +// aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +inline ::std::tuple native_group_norm_symint_outf(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_out::call(input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); +} +namespace symint { + template ::value>> + ::std::tuple native_group_norm_outf(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_out::call(input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } +} + +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/ones_compositeexplicitautograd_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/ones_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..835950ad52dde183027152d26ec0444be693f9bf --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/ones_compositeexplicitautograd_dispatch.h @@ -0,0 +1,34 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor ones(at::IntArrayRef size, c10::optional names, at::TensorOptions options={}); +TORCH_API at::Tensor ones(at::IntArrayRef size, c10::optional names, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); +TORCH_API at::Tensor & ones_out(at::Tensor & out, at::IntArrayRef size, c10::optional names); +TORCH_API at::Tensor & ones_outf(at::IntArrayRef size, c10::optional names, at::Tensor & out); +TORCH_API at::Tensor ones(at::IntArrayRef size, at::TensorOptions options={}); +TORCH_API at::Tensor ones(at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); +TORCH_API at::Tensor ones_symint(c10::SymIntArrayRef size, at::TensorOptions options={}); +TORCH_API at::Tensor ones_symint(c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); +TORCH_API at::Tensor & ones_out(at::Tensor & out, at::IntArrayRef size); +TORCH_API at::Tensor & ones_outf(at::IntArrayRef size, at::Tensor & out); +TORCH_API at::Tensor & ones_symint_out(at::Tensor & out, c10::SymIntArrayRef size); +TORCH_API at::Tensor & ones_symint_outf(c10::SymIntArrayRef size, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/pad_sequence_compositeimplicitautograd_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/pad_sequence_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..9906a613c8f40c0b52e882b6967263c39b4415f5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/pad_sequence_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor pad_sequence(at::TensorList sequences, bool batch_first=false, double padding_value=0.0); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/rshift.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/rshift.h new file mode 100644 index 0000000000000000000000000000000000000000..edb93bd66e6c9c107e8c9ab50caa58f5721c29bf --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/rshift.h @@ -0,0 +1,53 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor __rshift__(const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__rshift___Scalar::call(self, other); +} + +// aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor __rshift__(const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__rshift___Tensor::call(self, other); +} + +// aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & __rshift___out(at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__rshift___Scalar_out::call(self, other, out); +} +// aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & __rshift___outf(const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::__rshift___Scalar_out::call(self, other, out); +} + +// aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & __rshift___out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__rshift___Tensor_out::call(self, other, out); +} +// aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & __rshift___outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::__rshift___Tensor_out::call(self, other, out); +} + +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/scatter_add_cuda_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/scatter_add_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..59328cbc756e7ecc5072c49732afe74680451ed6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/scatter_add_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor scatter_add(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); +TORCH_API at::Tensor & scatter_add_out(at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); +TORCH_API at::Tensor & scatter_add_outf(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out); +TORCH_API at::Tensor & scatter_add_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); + +} // namespace cuda +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/size_native.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/size_native.h new file mode 100644 index 0000000000000000000000000000000000000000..094bb83d6e347676db41d567056127a9c9fa1f98 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/size_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API int64_t size(const at::Tensor & self, int64_t dim); +TORCH_API int64_t size(const at::Tensor & self, at::Dimname dim); +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/slogdet_native.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/slogdet_native.h new file mode 100644 index 0000000000000000000000000000000000000000..190cadb20538d27009379dfe8ba5c2532b33b803 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/slogdet_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple slogdet(const at::Tensor & self); +TORCH_API ::std::tuple slogdet_out(const at::Tensor & self, at::Tensor & sign, at::Tensor & logabsdet); +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/special_erfc.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/special_erfc.h new file mode 100644 index 0000000000000000000000000000000000000000..ee375d259501813756296c9dd05b4fe36026a66a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/special_erfc.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::special_erfc(Tensor self) -> Tensor +inline at::Tensor special_erfc(const at::Tensor & self) { + return at::_ops::special_erfc::call(self); +} + +// aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & special_erfc_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_erfc_out::call(self, out); +} +// aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & special_erfc_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_erfc_out::call(self, out); +} + +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/special_i1_native.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/special_i1_native.h new file mode 100644 index 0000000000000000000000000000000000000000..b0a6e26aff1cea39c89e19d5036085b9d883281a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/special_i1_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_special_i1_out : public at::meta::structured_special_i1 { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/swapaxes_compositeimplicitautograd_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/swapaxes_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f05972c32a4a7aa87f61f55cf2c2bbbd8d7301a2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/swapaxes_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor swapaxes(const at::Tensor & self, int64_t axis0, int64_t axis1); +TORCH_API at::Tensor & swapaxes_(at::Tensor & self, int64_t axis0, int64_t axis1); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/upsample_nearest3d_cpu_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/upsample_nearest3d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..695dfbfbe29a03a579f039a6fe6552ea66a8549a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/upsample_nearest3d_cpu_dispatch.h @@ -0,0 +1,28 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor upsample_nearest3d(const at::Tensor & self, at::IntArrayRef output_size, c10::optional scales_d=c10::nullopt, c10::optional scales_h=c10::nullopt, c10::optional scales_w=c10::nullopt); +TORCH_API at::Tensor upsample_nearest3d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_d=c10::nullopt, c10::optional scales_h=c10::nullopt, c10::optional scales_w=c10::nullopt); +TORCH_API at::Tensor & upsample_nearest3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, c10::optional scales_d=c10::nullopt, c10::optional scales_h=c10::nullopt, c10::optional scales_w=c10::nullopt); +TORCH_API at::Tensor & upsample_nearest3d_outf(const at::Tensor & self, at::IntArrayRef output_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w, at::Tensor & out); +TORCH_API at::Tensor & upsample_nearest3d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_d=c10::nullopt, c10::optional scales_h=c10::nullopt, c10::optional scales_w=c10::nullopt); +TORCH_API at::Tensor & upsample_nearest3d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/vdot_compositeexplicitautograd_dispatch.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/vdot_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1049b7b4874e7f50311bef4b819d0f79a3bf15fe --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/vdot_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor & vdot_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & vdot_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/view_as_ops.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/view_as_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..d7328784f27f5b2f732d0ad0dcfa8cd683e9945a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/view_as_ops.h @@ -0,0 +1,28 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API view_as { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::view_as") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "view_as(Tensor(a) self, Tensor other) -> Tensor(a)") + static at::Tensor call(const at::Tensor & self, const at::Tensor & other); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other); +}; + +}} // namespace at::_ops diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/dnnl_config.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/dnnl_config.h new file mode 100644 index 0000000000000000000000000000000000000000..48925e1e3ab49ae135c6e9c4c501aa2f5e030913 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/dnnl_config.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_CONFIG_H +#define DNNL_CONFIG_H + +#include "oneapi/dnnl/dnnl_config.h" + +#endif /* DNNL_CONFIG_H */ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/libshm.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/libshm.h new file mode 100644 index 0000000000000000000000000000000000000000..28024aa2338d1f46ce280abeb92a633f89be1385 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/libshm.h @@ -0,0 +1,46 @@ +#pragma once + +#include + +#ifdef __cplusplus + +void libshm_init(const char* manager_exec_path); + +// Superclass to run a constructor before at::RefcountedMapAllocator +class THManagedMapAllocatorInit { + protected: + THManagedMapAllocatorInit(const char* manager_handle, const char* filename); + std::string manager_handle_; +}; + +// Like a at::RefcountedMapAllocator, but it also makes use of an external +// shared memory manager process to ensure that shared memory regions actually +// get freed in the end (even if processes lose the memory). +class THManagedMapAllocator : private THManagedMapAllocatorInit, + public at::RefcountedMapAllocator { + public: + THManagedMapAllocator( + const char* manager_handle, + const char* filename, + int flags, + size_t size); + + void close() override; + + ~THManagedMapAllocator() override { + close(); + } + + static at::DataPtr makeDataPtr( + const char* manager_handle, + const char* filename, + int flags, + size_t size); + static THManagedMapAllocator* fromDataPtr(const at::DataPtr&); + + const char* manager_handle() const { + return manager_handle_.c_str(); + } +}; + +#endif diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/qnnpack_func.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/qnnpack_func.h new file mode 100644 index 0000000000000000000000000000000000000000..10bbc000192d7e03745e2cf3fb263a9655cde00c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/qnnpack_func.h @@ -0,0 +1,166 @@ +#pragma once + +#include +#include + +namespace qnnpack { +class PrePackConvWeights final { + public: + PrePackConvWeights( + const pytorch_qnnp_operator_t convolution, + const uint8_t* kernel_zero_points, + const uint8_t* kernel, + const int32_t* bias); + + void* getPackedWeights() const + { + return packed_weights_; + } + + int64_t getOutputChannels() const + { + return output_channels_; + } + + ~PrePackConvWeights() + { + if (packed_weights_ != nullptr) { + free(packed_weights_); + } + } + + PrePackConvWeights() = delete; + PrePackConvWeights(const PrePackConvWeights&) = delete; + PrePackConvWeights& operator=(const PrePackConvWeights&) = delete; + + private: + void* packed_weights_ = nullptr; + int64_t output_channels_; +}; + +class PackBMatrix final { + public: + PackBMatrix( + size_t input_channels, + size_t output_channels, + const uint8_t* kernel_zero_points, + const float* requantization_scale, + const uint8_t* kernel, + const int32_t* bias); + + // This constructor is to be used for dynamic mode + // quantization. In dynamic mode, we dont yet support + // per channel quantization, and paying the cost of + // memory allocation for per channel zero point and + // requant scale will hurt performance. + PackBMatrix( + size_t input_channels, + size_t output_channels, + const uint8_t kernel_zero_point, + const float requantization_scale, + const uint8_t* kernel, + const int32_t* bias); + + void* getPackedWeights() const + { + return packed_weights_; + } + + void unpackWeights( + const uint8_t* kernel_zero_points, + int8_t* kernel + ) const; + + size_t getInputChannels() const + { + return input_channels_; + } + + size_t getOutputChannels() const + { + return output_channels_; + } + + ~PackBMatrix() + { + if (packed_weights_ != nullptr) { + free(packed_weights_); + } + } + + PackBMatrix() = delete; + PackBMatrix(const PackBMatrix&) = delete; + PackBMatrix& operator=(const PackBMatrix&) = delete; + + private: + void* packed_weights_ = nullptr; + size_t input_channels_; + size_t output_channels_; +}; + +enum pytorch_qnnp_status qnnpackLinear( + const size_t batch_size, + const size_t input_channels, + const size_t output_channels, + const uint8_t input_zero_point, + const uint8_t* kernel_zero_points, + const float* requantization_scales, + const uint8_t output_zero_point, + const uint8_t output_min, + const uint8_t output_max, + const uint8_t* input, + const size_t input_stride, + void* packed_weights, + uint8_t* output, + const size_t output_stride, + pthreadpool_t threadpool); + +enum pytorch_qnnp_status qnnpackConv( + const pytorch_qnnp_operator_t convolution, + void* packed_weights, + const size_t batch_size, + const size_t input_depth, + const size_t input_height, + const size_t input_width, + const uint8_t input_zero_point, + const uint8_t* input, + const uint8_t* kernel_zero_points, + const float* requantization_scales, + const uint8_t output_zero_point, + const uint8_t output_min, + const uint8_t output_max, + uint8_t* output, + pthreadpool_t threadpool); + +enum pytorch_qnnp_status qnnpackDeConv( + const pytorch_qnnp_operator_t deconvolution, + void* packed_weights, + const size_t batch_size, + const size_t input_height, + const size_t input_width, + const uint8_t input_zero_point, + const uint8_t* input, + const uint8_t* kernel_zero_points, + const float* requantization_scales, + const uint8_t output_zero_point, + const uint8_t output_min, + const uint8_t output_max, + uint8_t* output, + pthreadpool_t threadpool); + +enum pytorch_qnnp_status qnnpackLinearDynamic( + const size_t batch_size, + const size_t input_channels, + const size_t output_channels, + const uint8_t input_zero_point, + const uint8_t* kernel_zero_points, + const float* dequantization_scales, + const uint8_t* input, + const size_t input_stride, + void* packed_weights, + const float* bias, + float* output, + const size_t output_stride, + pthreadpool_t threadpool); + +} // namespace qnnpack