diff --git a/.gitattributes b/.gitattributes index eb80d51c0f40b475ff09def757c067bda3dc0fcc..4a87fd35bd9136500fb82e0ee0c662e6c4136d91 100644 --- a/.gitattributes +++ b/.gitattributes @@ -74,3 +74,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/V tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 new file mode 100644 index 0000000000000000000000000000000000000000..600038b3afdecfa54266947d525483495cbe93a4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:647373d0020a53c70bd44d2950f81f6c5edec206899855800a76aabe1ae27e02 +size 745240 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ce836707536649108fbd019f58c714198aa0cb7 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7f419e7dcef74aeec1ed42c7e10b6b1b85f8da0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93e5c50b28f948ab0ff8a2c0f05a047d43a6d723 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa8029c48cf19ca616891fe7c6af5ab8a890503f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14f6a4caac3196547e9f7784916435f22897a4f9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de7a018ce934df91ae0c37e6abaaea9755f96a5b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e887265be9312d995629aedb8f822c12b668e75 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aefeb5f171eceb510ff80f8161b372952509dc05 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed1b4dc515a3716d9c9c1782d8dc441309bea9e7 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py new file mode 100644 index 0000000000000000000000000000000000000000..38545d5663bd6138fbf770c377f7598f1a38b74f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py @@ -0,0 +1,124 @@ +import operator +from functools import partial +from typing import Any, Callable, Dict + +from sympy import Expr + +import torch +from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from .ir import InterpreterShim, LoopBody, LoopBodyBlock +from .utils import cache_on_self, dominated_nodes +from .virtualized import V + + +class BoundVars: + """ + Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() + It exposes the ranges of the nodes in the `bounds` 
variable + + Note. A current limitation of this analysis is that it just works on a per-loop basis. + We should be able to propagate the bounds between across the whole graph. This may benefit + the case a bounded variable is returned by a kernel and fed into another. + """ + + def __init__(self, loop_body: LoopBody) -> None: + self.loop_body = loop_body + self.replacement_vals = { + k: ValueRanges[Expr](0, v - 1) + if (isinstance(v, int) or v.is_number) + else bound_sympy(v) + for k, v in loop_body.var_ranges.items() + } + # avoid computing these values, pessimistically assume that they are unbounded + self.unbounded_vars = dominated_nodes( + node + for node in self.loop_body.get_nodes() + if node.target in ["load", "reduction", operator.getitem] + or "masked_subblock" in node.target + ) + # To access this variable call `get_bounds()` + self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {} + + @cache_on_self + def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: + submodules = self.swap_submodules(self.loop_body.submodules) + + # Initialize the environment with the unbounded variables + for node in self.unbounded_vars: + # we need to evaluate masked_subblock to recurse, and we need to set indirect values + if not isinstance(node.target, str) or ( + "masked_subblock" not in node.target + and "set_indirect" not in node.target + ): + self._bounds[node] = ValueRanges[Expr].unknown() + + with V.set_ops_handler(ValueRangeAnalysis()): + interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + interpreter.run(V.get_ops_handler(), initial_env=self._bounds) + return self._bounds + + def swap_submodules( + self, submodules: Dict[str, Callable[..., Any]] + ) -> Dict[str, Callable[..., ValueRanges[Expr]]]: + result: Dict[str, Callable[..., ValueRanges[Expr]]] = {} + for key in submodules.keys(): + if key == "get_index": + result[key] = self.get_index + elif "masked_subblock" in key: + subblock = self.loop_body.subblocks[key] + # The result within the lambda will reference to the final + # set of modules at the end of the for-loop as it stores a reference to it + + # bind subblock in a function because python lambdas close over by reference + # moving the lambda out of make_fn would close over the reference to subblock, + # so all lambdas would have the same subblock reference that is the final + # subblock in the loop + def make_fn(subblock): + return lambda mask, value: self.masked_subblock( + subblock, self._bounds, mask, value, result + ) + + result[key] = make_fn(subblock) + + elif "set_indirect" in key: + idx = int(key[len("set_indirect") :]) + var = self.loop_body.indirect_vars[idx] + indirect = partial(self.set_indirect, var) + result[key] = indirect + else: + assert "scan" in key + result[key] = submodules[key] + + return result + + def masked_subblock( + self, + subblock: LoopBodyBlock, + env: Dict[torch.fx.Node, ValueRanges[Expr]], + mask: Any, + value: Any, + submodules: Dict[str, Callable[..., Any]], + ) -> ValueRanges[Expr]: + interp = InterpreterShim(subblock.graph, submodules) + interp.run(V.get_ops_handler(), initial_env=env) + output = [node for node in subblock.graph.nodes if node.target == "output"] + assert len(output) == 1 + # dont bother unioning with value since the load from buffer will be + # pessimistically assumed to be inf anyway + return interp.env[output[0]] + + def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]: + assert isinstance(new, ValueRanges) + self.replacement_vals[old] = new + return new + + def 
get_index(self, name: Expr) -> ValueRanges[Expr]: + expr = self.loop_body.indexing_exprs[name] + bound = self.replacement_vals.get(expr) + if bound is None: + bound = bound_sympy(expr, self.replacement_vals) + # The following assertion is true at the time of this writing + # We don't assert is as to not execute bound_sympy when bound is not None + # assert bound is None or bound == bound_sympy(expr, self.replacement_vals) + self.replacement_vals[name] = bound + return bound diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a22915764d0170bb1648f947b552849ac670a5e4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54d80103ce6f7a03c21db01e0bb604484d7bc557 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c84c3ff945103ab28f9253851c4b622bb243faa Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dbd3a248b5f0da87d82545f62bbf9c1125277e4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9619eb741ae22c30917951701c138c358534356b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc differ diff --git 
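# ---------------------------------------------------------------------------
# Editor's illustration (not part of the patch): a minimal sketch of the
# value-range helpers that BoundVars above builds on. It uses only the names
# bounds.py itself imports (ValueRanges, bound_sympy); the symbol "i" and the
# range [0, 15] are made up for the example, mirroring ValueRanges[Expr](0, v - 1).
import sympy
from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy

i = sympy.Symbol("i", integer=True)
# Pretend the loop variable i is known to lie in [0, 15].
ranges = {i: ValueRanges(0, 15)}
# bound_sympy propagates the known ranges through an index expression,
# much like BoundVars.get_index does for loop_body.indexing_exprs:
# 2*i + 3 is then known to lie in [3, 33].
print(bound_sympy(2 * i + 3, ranges))
# ---------------------------------------------------------------------------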
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2dfca704b65bf85626ad54e0f922cf594fc1785f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py @@ -0,0 +1,1755 @@ +import contextlib +import dataclasses +import functools +import itertools +import logging +import operator +import re +from itertools import chain +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy +from sympy.printing.printer import Printer + +import torch +import torch.fx +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch.utils import _pytree as pytree +from torch.utils._sympy.value_ranges import ValueRanges + +from .. import config, metrics +from ..utils import ( + DeferredLineBase, + do_bench, + free_symbol_startswith, + IndentedBuffer, + sympy_dot, + sympy_index_symbol, + sympy_subs, + unique, +) +from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V + +if TYPE_CHECKING: + from ..ir import TensorBox + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + + +def data_type_logger(msg): + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Data type propagation: %s", msg) + + +@dataclasses.dataclass +class WorkspaceArg: + """A temporary buffer used for a single kernel, then discarded. + + Not registered as a traditional buffer since there are no users, + so it would be dead code eliminated. + """ + + nbytes: sympy.Expr + zero_fill: bool + + +@dataclasses.dataclass +class TensorArg: + name: str + buffer: str + dtype: torch.dtype + offset: sympy.Expr = sympy.Integer(0) + + +@dataclasses.dataclass +class SizeArg: + name: str + expr: sympy.Expr + + +@dataclasses.dataclass +class DeviceCodegen: + scheduling: type + wrapper_codegen: type + + +KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg] + +device_codegens: Dict[str, DeviceCodegen] = {} + + +class DeviceOpOverrides: + def import_get_raw_stream_as(self, name): + raise NotImplementedError() + + def set_device(self, device_idx): + raise NotImplementedError() + + def synchronize(self): + raise NotImplementedError() + + def device_guard(self, device_idx): + raise NotImplementedError() + + +device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {} + + +# The code generated by Inductor consists of two main parts: kernel code and wrapper code. +# For any new backend looking to integrate with Inductor, customization of these two main +# parts are necessary to generate its specific code. +# +# Kernel code generation is determined by different Scheduling. Consequently, a new +# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, +# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. +# +# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code +# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen, +# and override specific member functions to create backend-specific Python wrapper code. +# +# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part +# of the logic for either Scheduling or WrapperCodeGen. 
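# ---------------------------------------------------------------------------
# Editor's illustration (not part of the patch): a hedged sketch of the
# out-of-tree registration flow this comment block describes; the registration
# helpers are defined a few lines below in this same file. The two classes are
# empty placeholders, a real backend supplies its own Scheduling and
# WrapperCodeGen implementations, as the Intel example linked below does.
from torch._inductor.codegen.common import (
    get_scheduling_for_device,
    get_wrapper_codegen_for_device,
    register_backend_for_device,
)


class ExampleDeviceScheduling:
    """Placeholder kernel-codegen scheduling for a hypothetical 'example' device."""


class ExampleWrapperCodeGen:
    """Placeholder Python-wrapper codegen for the same hypothetical device."""


register_backend_for_device("example", ExampleDeviceScheduling, ExampleWrapperCodeGen)
assert get_scheduling_for_device("example") is ExampleDeviceScheduling
assert get_wrapper_codegen_for_device("example") is ExampleWrapperCodeGen
# ---------------------------------------------------------------------------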
So the Scheduling and WrapperCodeGen interfaces +# provide flexibility to the backend. A backend can choose to implement these classes from scratch, +# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, +# register_backend_for_device, to equip a new backend at runtime. +# +# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. +# This backend can be used as a reference: +# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 +def register_backend_for_device( + device: str, device_scheduling: type, device_wrapper_codegen: type +): + device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen) + + +def get_scheduling_for_device(device: str): + return device_codegens[device].scheduling if device in device_codegens else None + + +def get_wrapper_codegen_for_device(device: str): + return ( + device_codegens[device].wrapper_codegen if device in device_codegens else None + ) + + +def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes): + from ..ir import FlexibleLayout + + # added contiguous index prevents reordering + return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] + + +def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides): + device_op_overrides_dict[device] = device_op_overrides + + +def get_device_op_overrides(device: str): + assert isinstance(device, str) + + if not device_op_overrides_dict.keys(): + from .cuda import device_op_overrides # noqa: F401 + + if device in device_op_overrides_dict.keys(): + return device_op_overrides_dict[device] + + return DeviceOpOverrides() + + +@functools.lru_cache(None) +def boolean_ops(): + return ( + "is_inf", + "is_nan", + "bitwise_xor", + "logical_not", + "signbit", + "le", + "lt", + "ge", + "gt", + "eq", + "ne", + ) + + +DTYPE_TO_COMPUTATION_DTYPE = { + torch.bfloat16: torch.float, + torch.float16: torch.float, + **{ + dtype: dtype + for dtype in [ + torch.bool, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ] + }, +} + + +class DataTypePropagation: + def __init__(self, body) -> None: + self.body = body + self.graphs: Dict[Union[Callable[..., Any], str], Any] = { + "root": body.root_block.graph + } + for k, v in body.subblocks.items(): + self.graphs[k] = v.graph + + def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): + inputs = node.all_input_nodes + input_nodes = [ + n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" + ] + if len(input_nodes) == 0: + return None + + all_input_nodes_propogated = all( + OptimizationContext.key in n.meta + and n.meta[OptimizationContext.key].dtype is not None + for n in input_nodes + ) + if not all_input_nodes_propogated: + return None + + return functools.reduce( + torch.promote_types, + [n.meta[OptimizationContext.key].dtype for n in input_nodes], + ) + + def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node): + sub_graph = self.graphs[node.target] + dtype = self.propagate_graph(sub_graph) + assert dtype + return dtype + + def deduce_node_dtype(self, node: torch.fx.Node): + if node.target in boolean_ops(): + return torch.bool + + if node.op == "placeholder": + return None + + if node.target == "output": + # we can infer output node if it only have 1 arg + if len(node.args) != 1: 
+ return None + + if node.target in ( + "to_dtype", + "index_expr", + ): + return node.args[-1] + + if node.target in ( + "rand", + "randn", + ): + return torch.float + + if node.target in ( + "get_index", + "index_expr", + ): + return torch.int64 + + if node.target in ( + "load", + "store", + "store_reduction", + ): + buf_name = node.args[1] + return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + + if node.target == operator.getitem: + return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type] + + assert isinstance(node.target, str) + + if node.target == "reduction": + return node.args[1] + + if node.target == "constant": + return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]] # type: ignore[index] + + if node.target.startswith("masked_subblock"): + return self.deduce_node_dtype_by_subgraph(node) + + return self.deduce_node_dtype_by_inputs(node) + + def propagate_graph(self, graph: torch.fx.Graph): + assert graph.nodes + graph_dtype = None + # For masked_subblock, we use output's dtype to represent + # the dtype of this subgraph. For other cases, graph_dtype + # might be None + for node in graph.nodes: + if OptimizationContext.key in node.meta: + opt_ctx = node.meta[OptimizationContext.key] + else: + opt_ctx = OptimizationContext() + + opt_ctx.dtype = self.deduce_node_dtype(node) + node.meta[OptimizationContext.key] = opt_ctx + if node.target == "output": + graph_dtype = opt_ctx.dtype + return graph_dtype + + def propagate(self): + self.propagate_graph(self.graphs["root"]) + + @classmethod + def propagate_loopbody(cls, body): + return cls(body).propagate() + + @classmethod + def propagate_scheduler_node(cls, node): + from ..ir import LoopBody + from ..scheduler import SchedulerNode + + assert isinstance(node, SchedulerNode) + assert isinstance(node._body, LoopBody) + DataTypePropagation.propagate_loopbody(node._body) + + +class ExprPrinter(Printer): + @staticmethod + def paren(string): + def all_in_parens(string): + if string[0] != "(" or len(string) < 2: + return False + count = 1 + for i, char in enumerate(string[1:]): + if char == "(": + count += 1 + elif char == ")": + count -= 1 + if count == 0 and i != len(string) - 2: + return False + assert count == 0 + return True + + if ( + isinstance(string, CSEVariable) + or re.match(r"^[a-z0-9_.]+$", string, re.I) + or re.match(r"^\([^)]*\)$", string, re.I) + or string == "" + ): + return string + # don't put extra parens for strings that are already wrapped in parens + if all_in_parens(string): + return string + return f"({string})" + + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + + def _print_Relational(self, expr): + return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) + + def _print_Mul(self, expr): + return "*".join(map(self.paren, map(self._print, expr.args))) + + def _print_Add(self, expr): + return " + ".join(map(self.paren, map(self._print, expr.args))) + + def _print_Mod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_CleanDiv(self, expr): + return self._print_FloorDiv(expr) + + def _print_GreaterThan(self, expr): + # GreaterThan: >= + # StrictlyGreaterThan: > + # Go figure... 
+ return " >= ".join(map(self.paren, map(self._print, expr.args))) + + def _print_align(self, expr): + assert len(expr.args) == 1 + return f"align({self._print(expr.args[0])})" + + +class PythonPrinter(ExprPrinter): + def _print_ModularIndexing(self, expr): + x, div, mod = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + mod = self.paren(self.doprint(mod)) + if div != "1": + x = f"({x} // {div})" + return f"{x} % {mod}" + + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"({x} // {div})" + + def _helper_sqrt(self, expr): + return f"math.sqrt({self._print(expr)})" + + def _print_Pow(self, expr): + # Pow() confuses triton + base, exp = expr.args + # NB: Remember this is sizevar computation! You don't typically + # expect to have to do floating point computation including exponents + # in sizevar compute. Instead of adding support for floating + # point pow, you should make upstream retranslate the Sympy expression + # into Tensor expressions earlier and do that instead. + if exp == 0.5: + return self._helper_sqrt(base) + elif exp == -0.5: + return "1/" + self._helper_sqrt(base) + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + if exp > 0: + return "*".join([self.paren(base)] * exp) + elif exp < 0: + return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + return "1" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"abs({self._print(expr.args[0])})" + + def _print_Max(self, expr): + assert len(expr.args) >= 2 + return f"max({', '.join(map(self._print, expr.args))})" + + def _print_Min(self, expr): + assert len(expr.args) >= 2 + return f"min({', '.join(map(self._print, expr.args))})" + + def _print_cos(self, expr): + assert len(expr.args) == 1 + return f"math.cos({self._print(expr.args[0])})" + + def _print_cosh(self, expr): + assert len(expr.args) == 1 + return f"math.cosh({self._print(expr.args[0])})" + + def _print_acos(self, expr): + assert len(expr.args) == 1 + return f"math.acos({self._print(expr.args[0])})" + + def _print_sin(self, expr): + assert len(expr.args) == 1 + return f"math.sin({self._print(expr.args[0])})" + + def _print_sinh(self, expr): + assert len(expr.args) == 1 + return f"math.sinh({self._print(expr.args[0])})" + + def _print_asin(self, expr): + assert len(expr.args) == 1 + return f"math.asin({self._print(expr.args[0])})" + + def _print_tan(self, expr): + assert len(expr.args) == 1 + return f"math.tan({self._print(expr.args[0])})" + + def _print_tanh(self, expr): + assert len(expr.args) == 1 + return f"math.tanh({self._print(expr.args[0])})" + + def _print_atan(self, expr): + assert len(expr.args) == 1 + return f"math.atan({self._print(expr.args[0])})" + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return f"round({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + assert isinstance(ndigits, sympy.Integer) + return f"round({self._print(number)}, {ndigits})" + + +class OpOverrides: + def __init__(self, parent): + super().__init__() + self._parent = parent + + def __getattr__(self, item): + return getattr(self._parent, item) + + @staticmethod + def 
identity(value): + # used to trigger cse + return value + + @staticmethod + def constant(value, dtype): + return repr(value) + + @staticmethod + def reciprocal(x): + return ops.truediv("1", x) + + @staticmethod + def square(x): + return ops.mul(x, x) + + @staticmethod + def bitwise_not(x): + return f"~{ExprPrinter.paren(x)}" + + @staticmethod + def logical_not(a): + return f"{ExprPrinter.paren(a)} == 0" + + @staticmethod + def bitwise_and(x, y): + return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_or(x, y): + return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_xor(x, y): + return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_left_shift(x, y): + return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_right_shift(x, y): + return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" + + @staticmethod + def remainder(a, b): + r = ops.mod(a, b) + return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r) + + @staticmethod + def load_seed(name, offset): + return ops.load(name, sympy.Integer(offset)) + + @classmethod + def _initialize_pointwise_overrides(cls, target): + assert target in {"triton", "cpp", "cppvec"}, target + + def pointwise_factory_1(impl): + def func(x): + return impl.format(x=x) + + return func + + def pointwise_factory_2(impl): + def func(x, y): + return impl.format(x=x, y=y) + + return func + + for funcname, data in pointwise_overrides_data.items(): + impl = getattr(data, target) + if isinstance(impl, str): + nof_args = 2 if "{y}" in impl else 1 + # extend the following dictionary with factory + # functions for a specific number of arguments as + # needed: + factory = {1: pointwise_factory_1, 2: pointwise_factory_2}[nof_args] + setattr(cls, funcname, staticmethod(factory(impl))) + + +@dataclasses.dataclass +class OverridesData: + name: str + cpp: str + triton: Optional[str] = None # None when not impl in libdevice/triton + cppvec: Optional[str] = None # None when not impl in aten/.../vec + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +pointwise_overrides_data: Dict[str, OverridesData] = dict( + airy_ai=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="airy_ai_forward({x})", + name="special_airy_ai", + ), + bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="bessel_j0_forward({x})", + triton="libdevice.j0({x})", + name="special_bessel_j0", + ), + bessel_j1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="bessel_j1_forward({x})", + triton="libdevice.j1({x})", + name="special_bessel_j1", + ), + bessel_y0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="bessel_y0_forward({x})", + triton="libdevice.y0({x})", + name="special_bessel_y0", + ), + bessel_y1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="bessel_y1_forward({x})", + triton="libdevice.y1({x})", + name="special_bessel_y1", + ), + digamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_digamma({x})", + cppvec="{x}.digamma()", + name="digamma", + ), + # no cpp nor triton implementation for entr, it is defined as decomposition + # erf, erfc + erfcx=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + 
cpp="calc_erfcx({x})", + triton="libdevice.erfcx({x})", + name="special_erfcx", + ), + # erfinv, exp2, expit, gammaln + igamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_igamma({x}, {y})", + name="igamma", + ), + igammac=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_igammac({x}, {y})", + name="igammac", + ), + gammainc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_igamma({x}, {y})", + name="special_gammainc", + ), + gammaincc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_igammac({x}, {y})", + name="special_gammaincc", + ), + i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_i0({x})", + triton="libdevice.cyl_bessel_i0({x})", + cppvec="{x}.i0()", + name="i0", + ), + i0e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_i0e({x})", + cppvec="{x}.i0e()", + name="special_i0e", + ), + i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_i1({x})", + triton="libdevice.cyl_bessel_i1({x})", + name="special_i1", + ), + i1e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_i1e({x})", + name="special_i1e", + ), + log_ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_log_ndtr({x})", + name="special_log_ndtr", + ), + # logit + modified_bessel_i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="modified_bessel_i0_forward({x})", + triton="libdevice.cyl_bessel_i0({x})", + name="special_modified_bessel_i0", + ), + modified_bessel_i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="modified_bessel_i1_forward({x})", + triton="libdevice.cyl_bessel_i1({x})", + name="special_modified_bessel_i1", + ), + modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="modified_bessel_k0_forward({x})", + name="special_modified_bessel_k0", + ), + modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="modified_bessel_k1_forward({x})", + name="special_modified_bessel_k1", + ), + # multigamma + ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_ndtr({x})", + name="special_ndtr", + ), + ndtri=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_ndtri({x})", + name="special_ndtri", + ), + polygamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_polygamma({y}, {x})", + name="polygamma", + ), + # psi - alias to digamma + # round + scaled_modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="scaled_modified_bessel_k0_forward({x})", + name="special_scaled_modified_bessel_k0", + ), + scaled_modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="scaled_modified_bessel_k1_forward({x})", + name="special_scaled_modified_bessel_k1", + ), + # sinc + spherical_bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="spherical_bessel_j0_forward({x})", + name="special_spherical_bessel_j0", + ), + zeta=OverridesData( + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="zeta({x}, {y})", + name="special_zeta", + ), + chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="chebyshev_polynomial_t_forward({x}, {y})", + name="special_chebyshev_polynomial_t", + ), + chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="chebyshev_polynomial_u_forward({x}, {y})", + name="special_chebyshev_polynomial_u", + ), + chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="chebyshev_polynomial_v_forward({x}, {y})", + name="special_chebyshev_polynomial_v", + ), + chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="chebyshev_polynomial_w_forward({x}, {y})", + name="special_chebyshev_polynomial_w", + ), + legendre_polynomial_p=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="legendre_polynomial_p_forward({x}, {y})", + name="special_legendre_polynomial_p", + ), + shifted_chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="shifted_chebyshev_polynomial_t_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_t", + ), + shifted_chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="shifted_chebyshev_polynomial_u_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_u", + ), + shifted_chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="shifted_chebyshev_polynomial_v_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_v", + ), + shifted_chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="shifted_chebyshev_polynomial_w_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_w", + ), + hermite_polynomial_h=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="hermite_polynomial_h_forward({x}, {y})", + name="special_hermite_polynomial_h", + ), + hermite_polynomial_he=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="hermite_polynomial_he_forward({x}, {y})", + name="special_hermite_polynomial_he", + ), + laguerre_polynomial_l=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="laguerre_polynomial_l_forward({x}, {y})", + name="special_laguerre_polynomial_l", + ), +) + + +# Use mypy to check protocol implemented correctly +def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]: + return h + + +class DeferredLine(DeferredLineBase): + """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" + + def __init__(self, name, line): + super().__init__(line) + self.name = name + assert not isinstance(line, DeferredLineBase) + + def __call__(self): + if all( + self.name not in x + for x in ( + V.graph.removed_buffers, + V.kernel.removed_buffers, + V.graph.inplaced_to_remove, + V.kernel.inplaced_to_remove, + ) + ): + return self.line + return None + + def _new_line(self, line): + return DeferredLine(self.name, line) + + +class BracesBuffer(IndentedBuffer): + def indent(self, offset=1): + @contextlib.contextmanager + def ctx(): + for _ in range(offset): + self.writeline("{") + self._indent += 1 + for _ in range(-offset): + self._indent -= 
1 + self.writeline("}") + yield + for _ in range(-offset): + self.writeline("{") + self._indent += 1 + for _ in range(offset): + self._indent -= 1 + self.writeline("}") + + return ctx() + + +class InplacedBuffer(NamedTuple): + inner_name: str + other_names: List[str] + + +class KernelArgs: + @staticmethod + def _lookup(prefix, odict, name): + assert isinstance(name, (str, sympy.Symbol)) + if name not in odict: + odict[name] = f"{prefix}{len(odict)}" + return odict[name] + + def __init__(self, sizevars=None): + self.input_buffers = dict() + self.output_buffers = dict() + self.inplace_buffers = dict() + self.sizevars = sizevars or dict() + self.workspace_arg = None + + def __repr__(self): + return "KernelArgs({})".format( + ", ".join( + map( + repr, + [ + self.input_buffers, + self.output_buffers, + self.inplace_buffers, + self.sizevars, + ], + ) + ) + ) + + def _buffer_is_marked_removed(self, name): + return isinstance(name, str) and name.startswith("REMOVED") + + def input(self, name): + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.output_buffers: + return self.output_buffers[name] + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name + if name.startswith("seed"): + return self._lookup("seed", self.input_buffers, name) + return self._lookup("in_ptr", self.input_buffers, name) + + def output(self, name): + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name + return self._lookup("out_ptr", self.output_buffers, name) + + def make_inplace(self, input_name, output_name): + assert output_name not in self.inplace_buffers + if input_name in self.inplace_buffers: + buf = self.inplace_buffers[input_name] + buf.other_names.append(output_name) + self.inplace_buffers[output_name] = buf + else: + buf = InplacedBuffer( + f"in_out_ptr{len(unique(self.inplace_buffers.values()))}", + [input_name, output_name], + ) + self.inplace_buffers[input_name] = buf + self.inplace_buffers[output_name] = buf + + def workspace(self, nbytes: sympy.Expr, zero_fill: bool): + if self.workspace_arg is None: + self.workspace_arg = WorkspaceArg(nbytes, zero_fill) + return "ws_ptr", 0 + + offset = self.workspace_arg.nbytes + zero_fill = zero_fill or self.workspace_arg.zero_fill + self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill) + return "ws_ptr", offset + + def seed_offset(self, name, value): + if value in self.sizevars: + return self.sizevars[value] + if name in self.sizevars.values(): + name = ( + f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" + ) + self.sizevars[value] = name + return name + + def size(self, name): + if str(name) == "seed": + self.sizevars["seed"] = "seed" + return "seed" + return self._lookup("ks", self.sizevars, name) + + def call_names(self): + return chain( + self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() + ) + + def wrap_ptr_arg(self, buf, dtype): + return buf + + def wrap_size_arg(self, size): + return str(size) + + def cpp_argdefs(self): + from .cpp import DTYPE_TO_CPP, INDEX_TYPE + + call_args = [] + arg_defs = [] + arg_types = [] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + outer = inplaced.other_names[-1] + inner = inplaced.inner_name + dtype = 
V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.input_buffers.items(): + if outer in self.inplace_buffers: + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"const {cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"const {cpp_dtype}*") + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.sizevars.items(): + arg_defs.append(f"const {INDEX_TYPE} {inner}") + call_args.append(self.wrap_size_arg(outer)) + arg_types.append(f"const {INDEX_TYPE}") + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + assert self.workspace_arg is None, "Workspace not supported on CPU " + return arg_defs, call_args, arg_types + + def python_argdefs(self): + arg_defs = [] + call_args = [] + precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = [] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + arg_defs.append(inplaced.inner_name) + call_args.append(inplaced.other_names[-1]) + precompile_args.append( + TensorArg( + name=inplaced.inner_name, + buffer=inplaced.other_names[-1], + dtype=V.graph.get_dtype(inplaced.other_names[-1]), + ) + ) + for outer, inner in chain( + self.input_buffers.items(), self.output_buffers.items() + ): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + arg_defs.append(inner) + call_args.append(outer) + precompile_args.append( + TensorArg( + name=inner, + buffer=outer, + dtype=V.graph.get_dtype(outer), + ) + ) + for outer, inner in self.sizevars.items(): + arg_defs.append(inner) + call_args.append(outer) + precompile_args.append(SizeArg(inner, outer)) + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + if self.workspace_arg is not None: + arg_defs.append("ws_ptr") + call_args.append("workspace") + precompile_args.append(self.workspace_arg) + + return arg_defs, call_args, precompile_args + + def aliases(self): + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + for other in inplaced.other_names: + if ( + other in V.graph.inplaced_to_remove + or other in V.kernel.inplaced_to_remove + ): + continue + if other in self.input_buffers: + yield self.input_buffers[other], inplaced.inner_name + if other in self.output_buffers: + yield self.output_buffers[other], inplaced.inner_name + + def is_removed(self, name): + def _is_removed(name, buffers): + return name not in buffers or self._buffer_is_marked_removed(buffers[name]) + + return _is_removed(name, self.output_buffers) and _is_removed( + name, self.inplace_buffers + ) + + # Includes inplace buffers, excludes removed buffers. Essentially, + # after you do a call into this kernel, which buffers actually contain + # updated data? Modeled off of python_argdefs. 
+ def live_output_buffers(self): + live_outs = set() + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + live_outs.add(inplaced.other_names[-1]) + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + live_outs.add(outer) + return live_outs + + +class CSEVariable: + """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. + To do so, the backends can simply overload `Kernel.create_cse_var` + The "CSEVariable.update_on_args" method gives you a hook for annotations + See example of TritonCSEVariable in triton.py + """ + + def __init__(self, name, bounds: ValueRanges[Any]): + assert isinstance(bounds, ValueRanges) + self.name = name + self.bounds = bounds + + def __str__(self): + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other) -> bool: + return type(other) == type(self) and other.name == self.name + + def update_on_args(self, name, args, kwargs): + pass + + +class CppWrapperKernelArgs(KernelArgs): + def wrap_ptr_arg(self, buf, dtype): + from .cpp import DTYPE_TO_CPP + + if config.abi_compatible: + # In the abi_compatible model, we just return the buf here. + # We will form correct call args later in wrapper.generate_kernel_all. + return buf + else: + return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" + + def wrap_size_arg(self, size): + return f"{size}" + + +class CSE: + """Common subexpression elimination""" + + def __init__( + self, + prefix="", + suffix="", + name_prefix="tmp", + iter_buffers=None, + store_cache=None, + reduction_cache=None, + varname_map=None, + ): + self.prefix = prefix + self.suffix = suffix + self.cache = {} + self.name_prefix = name_prefix + self.store_cache = store_cache or {} + self.reduction_cache = reduction_cache or {} + self.iter_buffer_ids = iter_buffers or itertools.count() + self.invalidated_stores = set() + self.varname_map = varname_map or {} + + def invalidate(self, keep_vars: Set[str]): + for name, tmp in list(self.store_cache.items()): + if tmp not in keep_vars: + del self.store_cache[name] + self.invalidated_stores.add(name) + self.cache = {k: v for k, v in self.cache.items() if v in keep_vars} + + def clone(self): + # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional + return CSE( + prefix=self.prefix, + suffix=self.suffix, + name_prefix=self.name_prefix, + iter_buffers=self.iter_buffer_ids, + store_cache=self.store_cache, + varname_map=self.varname_map, + ) + + def generate( + self, + buffer: IndentedBuffer, + expr: Union[str, CSEVariable, OpsValue, IndentedBuffer], + *, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + write=True, + assignment=True, + ) -> CSEVariable: + if isinstance(expr, OpsValue): + expr = expr.value + + assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr) + assert write or assignment + if isinstance(expr, CSEVariable): + # If the expressions were always created with all the information, we could + # assert expr.bounds == bounds, but sometimes the expression is created + # with the loose ValueRanges.unknown(), so we need to tighten the bounds + expr.bounds = expr.bounds.tighten(bounds) + return expr + cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr + var = self.cache.get(cache_key, None) + if not var: + var = self.newvar(bounds) if assignment else None + self.cache[cache_key] = 
var + if write: + if V.kernel.current_node: + V.kernel.current_node.codegen_originating_info( + buffer, only_once=True + ) + if isinstance(expr, IndentedBuffer): + if assignment: + buffer.writeline(f"{self.prefix}{var} =") + buffer.splice(expr) + buffer.writeline(self.suffix) + else: + if assignment: + line = f"{self.prefix}{var} = {expr}{self.suffix}" + else: + line = f"{expr}{self.suffix}" + buffer.writeline(line) + else: + var.bounds = var.bounds.tighten(bounds) + + return var + + def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name, bounds) + self.varname_map[var_name] = var + return var + + +class IndirectAssertLine(DeferredLineBase): + def __init__(self, line, assert_fn, var, mask, size_map): + self.var = var + self.mask = mask + self.line = line + self.assert_fn = assert_fn + self.size_map = size_map + + def __call__(self): + size, size_str = self.size_map[(self.var, self.mask)] + + # We assert if we've not been able to prove the bound + assert_min = (self.var.bounds.lower >= 0) != sympy.true + assert_max = (self.var.bounds.upper < size) != sympy.true + + # FooBar interview question + if not (assert_min or assert_max): + return None + elif assert_min and assert_max: + # The conditions need to be in parens because of Python's operator precedence. + # It'd be less error-prone to use and/or/not, which is suported by triton + cond = f"(0 <= {self.var}) & ({self.var} < {size_str})" + cond_print = f"0 <= {self.var} < {size_str}" + elif assert_min: + cond = f"0 <= {self.var}" + cond_print = cond + else: + assert assert_max + cond = f"{self.var} < {size_str}" + cond_print = cond + + if self.mask: + cond = f"({cond}) | ~{self.mask}" + return self.line.format( + assert_fn=self.assert_fn, cond=cond, cond_print=cond_print + ) + + def _new_line(self, line): + return IndirectAssertLine( + line, self.assert_fn, self.var, self.mask, self.size_map + ) + + +class CodeGen: + def __init__(self): + super().__init__() + self.exit_stack = contextlib.ExitStack() + + def __enter__(self): + self.exit_stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + +class Kernel(CodeGen): + newvar_prefix = "" + suffix = "" + overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None + # TODO: these look dead, but with all the getattr it's hard to tell... 
+ load_format: None = None + store_format: None = None + + def __init__(self, args=None, increase_kernel_count=True): + super().__init__() + if increase_kernel_count: + metrics.generated_kernel_count += 1 + self.args = args or KernelArgs() + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.cse: CSE = CSE(self.newvar_prefix, self.suffix) + self.must_keep_buffers = set() + self.store_buffer_names = set() + self._load_mask = None + # set in set_current_node + self.current_node = None + self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None + # Upper bounds for indirect_indexing and their str representation + # NB: None, None is never stored in map, but it is the assumed + # "not set" value for the dict + self.indirect_max_sizes: Dict[ + Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]] + ] = {} + + self.removed_buffers = set() + self.inplaced_to_remove = set() + + # key: the buffer to write + # value: the buffer to read and whose memory can be reused for + # the buffer specified by key + self.inplace_update_buffers = dict() + # Set minimum number of elements processed per thread. + self.min_elem_per_thread = 1 + self.kernel_name = None + + @contextlib.contextmanager + def set_current_node(self, node): + prior = self.current_node + self.current_node = node + self.node_to_bounds = node._body.bounds().get_bounds() + try: + yield + finally: + self.current_node = prior + + @contextlib.contextmanager + def swap_buffers(self, lb, cb=None, sb=None): + if cb is None: + cb = lb + loads = self.loads + compute = self.compute + stores = self.stores + cse = self.cse + self.loads = lb + self.compute = cb + self.stores = sb + self.cse = cse.clone() + try: + yield + finally: + self.loads = loads + self.compute = compute + self.stores = stores + self.cse = cse + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + raise NotImplementedError() + + def indirect_load(self, name: str, index: sympy.Expr): + """A load the depends on an index we have read""" + prior = self.loads + try: + # put the load in the compute section as it might have deps + self.loads = self.compute + return self.load(name, index) + finally: + self.loads = prior + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + raise NotImplementedError() + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + raise NotImplementedError() + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + raise NotImplementedError() + + def scan( + self, + dtype: torch.dtype, + combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable], + value: CSEVariable, + init: int, + ) -> CSEVariable: + raise NotImplementedError() + + def bucketize( + self, + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + raise NotImplementedError() + + @property + def assert_function(self) -> str: + raise NotImplementedError() + + def index_to_str(self, index: sympy.Expr) -> str: + raise NotImplementedError() + + def __enter__(self): + # TODO: hoist this to top level + class CSEProxy: + self.name = "CSEProxy" + + @staticmethod + def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] 
+ def inner(*args, **kwargs): + # TritonTemplateKernel has no current_node + buf_bounds = ValueRanges.unknown() + if hasattr(V.interpreter, "current_node"): + fx_node = V.interpreter.current_node + assert isinstance(self.node_to_bounds, dict) + buf_bounds = self.node_to_bounds.get( + fx_node, ValueRanges.unknown() + ) + + value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + + def do_cse(v): + csevar = self.cse.generate(self.compute, v, bounds=buf_bounds) + csevar.update_on_args(name, args, kwargs) + return csevar + + return pytree.tree_map(do_cse, value) + + return inner + + @staticmethod + def indirect_indexing( + var: CSEVariable, size: sympy.Expr, check: bool = True + ): + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: # type: ignore[operator] + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance( + size, sympy.Number + ): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg = var.bounds & ValueRanges(-sympy.oo, -1) + new_bounds = ValueRanges(neg.lower + size, neg.upper + size) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: # type: ignore[operator] + pos = var.bounds & ValueRanges(0, sympy.oo) + new_bounds = new_bounds | pos + + stm = ops.add(var, self.rename_indexing(size)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: # type: ignore[operator] + lt = ops.lt(var, "0") + stm = ops.where(lt, stm, var) + new_var = self.cse.generate(self.compute, stm, bounds=new_bounds) + + new_var.update_on_args("index_wrap", (var,), {}) + var = new_var + + if self.generate_assert(check): + mask = self.load_mask(var) + + # An assertion line may have been written already, if so just + # update the max size. 
+ map_key = (var, mask) + existing_size, _ = self.indirect_max_sizes.get( + map_key, (None, None) + ) + if existing_size is not None: + size = sympy.Min(size, existing_size) + else: + line = ( + '{assert_fn}({cond}, "index out of bounds: {cond_print}")' + ) + self.compute.writeline( + IndirectAssertLine( + line, + self.assert_function, + var, + mask, + self.indirect_max_sizes, + ) + ) + + self.indirect_max_sizes[map_key] = (size, self.index_to_str(size)) + return sympy_index_symbol(str(var)) + + @staticmethod + def load(name: str, index: sympy.Expr) -> CSEVariable: + if name in self.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_startswith(index, "tmp"): + return self.indirect_load(name, index) + store_cache = self.cse.store_cache + if name in store_cache: + return store_cache[name] + return self.load(name, index) + + @staticmethod + def store( + name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.store_buffer_names.add(name) + if mode is None: + self.cse.store_cache[name] = value + if self.current_node: + for other_name in self.current_node.get_mutations(): + self.cse.store_cache[other_name] = value + if name not in V.graph.removed_buffers: + return self.store(name, index, value, mode=mode) + else: + return None # type: ignore[return-value] + + @staticmethod + def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): + self.store_buffer_names.add(name) + self.cse.store_cache[name] = value + if self.current_node: + for other_name in self.current_node.get_mutations(): + self.cse.store_cache[other_name] = value + + if name not in V.graph.removed_buffers: + return self.store_reduction(name, index, value) + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def scan( + dtype: torch.dtype, + combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable], + value: CSEVariable, + init: int, + ) -> CSEVariable: + return self.scan(dtype, combine_fn, value, init) + + @staticmethod + def bucketize( + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + [Note: Inductor bucketize op] + + Given values (tensor) and offsets_name (reference to the name of a 1D + tensor), calculate the bucket that each value belongs to. + + e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True + return = [ 0, 1, 1, 1, 1, 3, 3, 4]. + + When right == False, bucket i refers to range (offsets[i], offsets[i+1]]. + When right == True, bucket i refers to range [offsets[i], offsets[i+1]). + + Offsets must be non-decreasing or the result is undefined. 
+ """ + return self.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + # Use mypy to check protocol implemented correctly + def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: + return h + + super().__enter__() + assert self.overrides + parent_handler = self.overrides(V.get_ops_handler()) + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Note that V.graph.scheduler can be None when codegening triton template + kernels. + """ + if V.graph.scheduler: + V.graph.scheduler.remove_kernel_local_buffers() + super().__exit__(exc_type, exc_val, exc_tb) + + def generate_assert(self, check): + return (check or config.debug_index_asserts) and config.assert_indirect_indexing + + def load_mask(self, var) -> str: + # only the triton kernel requires mask + return "" + + def rename_indexing(self, index) -> sympy.Expr: + # adds the necessary kernel args for index expressions + # and renames variables in index expressions to kernel arg names + if isinstance(index, (list, tuple)): + return [self.rename_indexing(x) for x in index] # type: ignore[return-value] + index = V.graph.sizevars.simplify(index) + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) + replacements = { + x: self.args.size(x) + for x in sorted_symbols + if x.name.startswith(("s", "u", "ps")) + or (x.name.startswith("i") and not x.name.startswith("idx")) + } + return sympy_subs(index, replacements) + + def create_cse_var(self, *args, **kwargs): + return CSEVariable(*args, **kwargs) + + +@dataclasses.dataclass +class OptimizationContext: + key: ClassVar[str] = "opt_ctx" + + # Load value as mask + is_load_as_mask: bool = False + + dtype: Optional[torch.dtype] = None + ops_name: str = "" + + # Load uint8/int8 value as float32 + is_load_int8_as_float: bool = False + + +@functools.lru_cache(None) +def jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]] + + +class ChoiceCaller: + """ + Represents a possible choice used in autotune_process.py. + During autotuning, self.benchmark() is first called to get benchmark result, + and if this choice is selected, self.output_node() is called to get the output_node. + + Children classes: TritonTemplateCaller, CUDATemplateCaller. + """ + + def __init__(self, name, input_nodes, layout): + super().__init__() + self.name = name + self.layout = layout + self.input_nodes = input_nodes + + def benchmark(self, *args, out) -> float: + algo = self.to_callable() + return do_bench(lambda: algo(*args, out=out)) + + def call_name(self) -> str: + raise NotImplementedError() + + def to_callable(self): + raise NotImplementedError() + + def hash_key(self) -> str: + raise NotImplementedError() + + def output_node(self) -> "TensorBox": + raise NotImplementedError() + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return {} + + +class KernelTemplate: + """ + Base class for defining kernel templates. 
+ + Children classes: TritonTemplate, CUDATemplate + """ + + @staticmethod + def _template_from_string(source): + env = jinja2_env() + if env is not None: + return env.from_string(source) + return None + + @staticmethod + def _fake_get_dtype(fake_out): + _get_dtype_real = V.graph.get_dtype + + def get_dtype(name): + if name == fake_out.get_name(): + return fake_out.get_dtype() + return _get_dtype_real(name) + + return get_dtype + + def __init__(self, name: str): + self.name = name + + def maybe_append_choice(self, choices, **kwargs): + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + + choices: A list of ChoiceCallers. + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. + """ + + try: + choices.append(self.generate(**kwargs)) + except NotImplementedError: + pass + + def generate(self, **kwargs) -> ChoiceCaller: + """ + Generates a ChoiceCaller instance from the given arguments. + """ + + raise NotImplementedError() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..c30a3f6434a3d9f0bb7f2e7383e715207a060442 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py @@ -0,0 +1,4038 @@ +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import re +import sys +from copy import copy, deepcopy +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import sympy + +import torch +import torch.fx +from torch._inductor import dependencies +from torch._inductor.ir import StorageBox, TensorBox +from torch._prims_common import is_float_dtype +from torch.utils import _pytree as pytree +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from .. 
import codecache, config, ir, metrics +from ..codegen.wrapper import WrapperCodeGen +from ..optimize_indexing import range_expressable_in_32_bits +from ..scheduler import ( + BaseScheduling, + ForeachKernelSchedulerNode, + FusedSchedulerNode, + SchedulerNode, +) +from ..utils import ( + cache_on_self, + get_fused_kernel_name, + is_welford_reduction, + parallel_num_threads, + sympy_index_symbol, + sympy_product, + sympy_subs, +) + +from ..virtualized import ops, OpsValue, V +from .common import ( + BracesBuffer, + CppWrapperKernelArgs, + CSE, + CSEVariable, + DataTypePropagation, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + ExprPrinter, + IndentedBuffer, + Kernel, + KernelArgs, + OpOverrides, + OptimizationContext, +) + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + +DTYPE_TO_CPP = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "half", + torch.int64: "long", + torch.int32: "int", + torch.int16: "short", + torch.int8: "signed char", + torch.uint64: "unsigned long", + torch.uint32: "unsigned int", + torch.uint16: "unsigned short", + torch.uint8: "unsigned char", + torch.uint32: "unsigned int", + torch.uint64: "unsigned long", + torch.bool: "bool", + torch.bfloat16: "bfloat16", + torch.complex64: "complex64", + torch.float8_e4m3fn: "float8_e4m3fn", + torch.float8_e5m2: "float8_e5m2", +} + +DTYPE_TO_ATEN = { + torch.float32: "at::kFloat", + torch.float64: "at::kDouble", + torch.float16: "at::kHalf", + torch.int64: "at::kLong", + torch.int32: "at::kInt", + torch.int16: "at::kShort", + torch.int8: "at::kChar", + torch.uint64: "at::kUInt64", + torch.uint32: "at::kUInt32", + torch.uint16: "at::kUInt16", + torch.uint8: "at::kByte", + torch.uint32: "at::kUInt32", + torch.uint64: "at::kUInt64", + torch.bool: "at::kBool", + torch.bfloat16: "at::kBFloat16", + torch.complex32: "at::kComplexHalf", + torch.complex64: "at::kComplexFloat", + torch.complex128: "at::kComplexDouble", + torch.float8_e4m3fn: "at::kFloat8_e4m3fn", + torch.float8_e5m2: "at::kFloat8_e5m2", + torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", +} + +DEVICE_TO_ATEN = { + "cpu": "at::kCPU", + "cuda": "at::kCUDA", +} + +INDEX_TYPE = "long" + +NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"} +RTYPE_TO_CPP = { + "sum": "+", + "prod": "*", + "xor_sum": "^", + "min": "min", + "max": "max", + "argmin": "argmin", + "argmax": "argmax", + "any": "||", + "welford_reduce": "welford", + "welford_combine": "welford", +} +VECTORIZABLE_RTYPES = { + "max", + "min", + "sum", + "prod", + "xor_sum", + "welford_reduce", + "welford_combine", +} + +PYTHON_TO_CPP = { + "Tensor": "at::Tensor", + "int": "long", + "float": "double", + "bool": "bool", + "str": "std::string", + "ScalarType": "c10::ScalarType", + "MemoryFormat": "at::MemoryFormat", + "Layout": "at::Layout", + "Device": "at::Device", + "number": "at::Scalar", +} + +CONTAINER_PYTHON_TO_CPP = { + "List": "std::vector", + "Optional": "c10::optional", +} + +DTYPE_LOWP_FP = [ + torch.bfloat16, + torch.float16, +] + + +def value_to_cpp(value, cpp_type): + if value == float("-inf"): + return f"-std::numeric_limits<{cpp_type}>::infinity()" + elif value == float("inf"): + return f"std::numeric_limits<{cpp_type}>::infinity()" + elif isinstance(value, bool): + return f"static_cast<{cpp_type}>({str(value).lower()})" + elif math.isnan(value): + return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" + else: + return f"static_cast<{cpp_type}>({repr(value)})" + + +def reduction_init(reduction_type, dtype): + if 
dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, the initial + # constant for reduction must be promoted as well + dtype = torch.float32 + if reduction_type in ("xor_sum", "sum", "any"): + return 0 + if reduction_type == "prod": + return 1 + if reduction_type in {"max", "argmax"}: + return ( + f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::min()" + ) + if reduction_type in {"min", "argmin"}: + return ( + f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::max()" + ) + if is_welford_reduction(reduction_type): + return f"Welford<{DTYPE_TO_CPP[dtype]}>()" + raise AssertionError(reduction_type) + + +def reduction_acc_type(reduction_type, dtype): + assert reduction_type not in {"argmin", "argmax"} + scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]] + if is_welford_reduction(reduction_type): + return f"Welford<{scalar_type}>" + + return scalar_type + + +def reduction_combine(reduction_type, var, next_value): + if reduction_type == "sum": + return f"{var} + {next_value}" + if reduction_type == "prod": + return f"{var} * {next_value}" + if reduction_type == "xor_sum": + return f"{var} ^ {next_value}" + if reduction_type == "any": + return f"{var} || {next_value}" + if reduction_type in ("min", "max"): + return f"{reduction_type}_propagate_nan({var}, {next_value})" + if reduction_type == "welford_reduce": + return f"welford_combine({var}, {next_value})" + if reduction_type == "welford_combine": + if isinstance(next_value, tuple): + mean, m2, weight = next_value + else: + mean, m2, weight = reduction_project(reduction_type, next_value) + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + raise AssertionError(reduction_type) + + +def reduction_project(reduction_type, acc): + if is_welford_reduction(reduction_type): + return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight" + elif reduction_type in {"argmin", "argmax"}: + return f"{acc}.index" + return acc + + +def is_to_lowp_dtype(expr): + to_exprs = ["cvt_fp32_to_lowp_fp", "c10::convert"] + if any(to_expr in expr for to_expr in to_exprs): + if "half" in expr: + return torch.half + if "bfloat16" in expr: + return torch.bfloat16 + return None + + +def get_lowp_to_fp32_expr(lowp_var, src_dtype, kernel): + if isinstance(kernel, CppVecKernel): + return f"cvt_lowp_fp_to_fp32<{DTYPE_TO_CPP[src_dtype]}>({lowp_var})" + else: + assert isinstance(kernel, CppKernel) + return f"c10::convert({lowp_var})" + + +index_value_name_counter = 1 + + +def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar): + global index_value_name_counter + struct_name = f"IndexValue_{index_value_name_counter}" + index_value_name_counter += 1 + + # A small annoyance, due to it being a little cumbersome to just throw {} into strings + prefix = [ + f"struct {struct_name} {{size_t index; {DTYPE_TO_CPP[src_dtype]} value;}};", + f"{struct_name} {tmpvar}{{0, {reduction_init(reduction_type, src_dtype)}}};", + ] + + if reduction_type in ["argmax", "argmin"]: + compare_op = "greater_or_nan" if reduction_type == "argmax" else "less_or_nan" + prefix.extend( + [ + "#if !defined(__clang_major__) || __clang_major__ > 9", + f"#pragma omp declare reduction({reduction_type} : {struct_name} :\\", + f" omp_out = {compare_op}(omp_in.value, omp_out.value, omp_in.index, omp_out.index) ? 
omp_in : omp_out)\\", + f"\tinitializer(omp_priv = {{0, {reduction_init(reduction_type, src_dtype)}}})", + "#endif", + ] + ) + + return prefix + + +@functools.lru_cache +def stride_at(index: sympy.Expr, var: sympy.Symbol): + replacement = {var: var + 1} + new_index = sympy_subs(index, replacement) # type: ignore[arg-type] + return sympy.simplify(new_index - index) + + +@functools.lru_cache +def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int): + """ + Simplifies the index expression within the range of a vectorized loop. + Given a vectorized loop variable `var` in the range of a loop with `vec_length`, + this function transforms the `index` into an equivalent form. It handles + simplifications for cases where `var` can be expressed as `vec_length * a + b`, + where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences + of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations. + + NOTE: + The simplified index expression is intended for analysis purposes only, not + for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables + which are not dependent on the loop variable `var` in the vectorized range. Check + https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details. + + Examples: + 1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then + `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable + when `div` is divisible by 16. + 2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free + variable when `mod` is divisible by 16. + """ + + div_freevar_id = 0 + mod_freevar_id = 0 + + def visit_indexing_div(divisor): + nonlocal div_freevar_id + result = FloorDiv(var, divisor) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_div_c{div_freevar_id}") + div_freevar_id += 1 + return result + + def visit_modular_indexing(divisor, modulus): + nonlocal mod_freevar_id + result = ModularIndexing(var, divisor, modulus) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length: + result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + return result + + original_index = index + + div = sympy.Wild("divisor") + if index.has(FloorDiv): + index = index.replace(FloorDiv(var, div), visit_indexing_div) + + mod = sympy.Wild("modulus") + if index.has(ModularIndexing): + index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) + + index = sympy.simplify(index) + if index != original_index: + return simplify_index_in_vec_range(index, var, vec_length) + + return index + + +@functools.lru_cache +def stride_at_vec_range(index: sympy.Expr, var: sympy.Symbol, vec_length: int): + index_vec_simplified = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index_vec_simplified, var) + + +class CppPrinter(ExprPrinter): + def _print_Integer(self, expr): + return f"{int(expr)}L" + + def _print_Where(self, expr): + c = self.paren(self.doprint(expr.args[0])) + p = self.paren(self.doprint(expr.args[1])) + q = self.paren(self.doprint(expr.args[2])) + return f"{c} ? 
{p} : {q}" + + def _print_ModularIndexing(self, expr): + x, div, mod = expr.args + x = self.paren(self.doprint(x)) + if div != 1: + div = self.paren(self.doprint(div)) + if expr.is_integer: + x = f"c10::div_floor_integer({x}, {div})" + else: + x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + mod = self.paren(self.doprint(mod)) + return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})" + + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + if expr.is_integer: + return f"c10::div_floor_integer({x}, {div})" + return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Pow(self, expr): + # Uses float constants to perform FP div + base, exp = expr.args + base = self._print(base) + + if exp == 0.5 or exp == -0.5: + return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" + assert exp.is_integer + exp = int(exp) + if exp > 0: + r = "*".join([self.paren(base)] * exp) + elif exp < 0: + r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + r = "1.0" + + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Rational(self, expr): + # Uses float constants to perform FP div + if expr.q == 1: + r = f"{expr.p}" + else: + r = f"{expr.p}.0/{expr.q}.0" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Min(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::min({args[0]}, {args[1]})" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::min({il})" + + def _print_Max(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::max({args[0]}, {args[1]})" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::max({il})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"std::abs({self._print(expr.args[0])})" + + def _print_cos(self, expr): + assert len(expr.args) == 1 + return f"std::cos({self._print(expr.args[0])})" + + def _print_cosh(self, expr): + assert len(expr.args) == 1 + return f"std::cosh({self._print(expr.args[0])})" + + def _print_acos(self, expr): + assert len(expr.args) == 1 + return f"std::acos({self._print(expr.args[0])})" + + def _print_sin(self, expr): + assert len(expr.args) == 1 + return f"std::sin({self._print(expr.args[0])})" + + def _print_sinh(self, expr): + assert len(expr.args) == 1 + return f"std::sinh({self._print(expr.args[0])})" + + def _print_asin(self, expr): + assert len(expr.args) == 1 + return f"std::asin({self._print(expr.args[0])})" + + def _print_tan(self, expr): + assert len(expr.args) == 1 + return f"std::tan({self._print(expr.args[0])})" + + def _print_tanh(self, expr): + assert len(expr.args) == 1 + return f"std::tanh({self._print(expr.args[0])})" + + def _print_atan(self, expr): + assert len(expr.args) == 1 + return f"std::atan({self._print(expr.args[0])})" + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return f"std::lrint({self._print(expr.args[0])})" + + def 
_print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + return f"static_cast(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})" + + +# A function to print, useful for printing sympy symbols. +cexpr = CppPrinter().doprint + + +def cexpr_index(index): + return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" + + +class RecordOptimizationContext: + def __init__(self, func_name: str = ""): + self.func_name = func_name + self.current_node: Optional[torch.fx.Node] = None + self.opt_ctx: Optional[OptimizationContext] = None + + def __enter__(self): + assert V.interpreter + assert V.interpreter.current_node + + self.current_node = V.interpreter.current_node + assert self.current_node is not None + if OptimizationContext.key in self.current_node.meta: + self.opt_ctx = self.current_node.meta[OptimizationContext.key] + else: + self.opt_ctx = OptimizationContext() + assert self.opt_ctx is not None + self.opt_ctx.ops_name = self.func_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.current_node + assert self.opt_ctx + self.current_node.meta[OptimizationContext.key] = self.opt_ctx + + def get_opt_ctx(self): + return self.opt_ctx + + def get_fx_node(self): + assert self.current_node + return self.current_node + + +def get_opt_ctx(node: torch.fx.Node) -> OptimizationContext: + return node.meta.get(OptimizationContext.key, None) + + +def get_current_node_opt_ctx() -> OptimizationContext: + assert V.interpreter.current_node + return get_opt_ctx(V.interpreter.current_node) + + +class CppVecUnsupportedError(Exception): + pass + + +class CppCSEVariable(CSEVariable): + def __init__(self, name, bounds: ValueRanges[Any]): + super().__init__(name, bounds) + self.is_vec = False + self.dtype: Optional[torch.dtype] = None + self.dependent_itervars: Set[sympy.Symbol] = set() + + def update_on_args(self, name, args, kwargs): + if name == "load": + # args[1] is index + self._set_dependent_itervars(args[1]) + else: + # propagate relevant itervars and is_vec from args + self.dependent_itervars.update( + *[ + arg.dependent_itervars + for arg in args + if isinstance(arg, CppCSEVariable) + ] + ) + if name == "index_expr": + self._set_dependent_itervars(args[0]) + if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): + self.is_vec = True + # NOTE [dtype of CppCSEVariable] + # Deciding dtype according to the current optimization context is not + # always accurate since the dtypes are initialized during dtype propagation + # at the beginning of the codegen. It is possible that some ops are invoked + # during the codegen of the current op and take different dtypes from the + # current op. + # TODO(jgong5): A more accurate way of deciding the dtype of the variables is to + # propagate the dtypes here inside `update_on_args`. + if ( + hasattr(V.interpreter, "current_node") + and get_current_node_opt_ctx() is not None + ): + self.dtype = get_current_node_opt_ctx().dtype + + def _set_dependent_itervars(self, index: sympy.Expr): + """ + Set the relevant itervars for this variable based on the `index` expression. + This includes the itervars directly used in the `index` as well as relevant itervars + of other cse variables used in the `index`. 
+ """ + for s in index.free_symbols: + if s in V.kernel.itervars: + self.dependent_itervars.add(s) # type: ignore[arg-type] + elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined] + self.dependent_itervars.update( + V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined] + ) + + def depends_on(self, itervar: sympy.Symbol): + return itervar in self.dependent_itervars + + +class CppOverrides(OpOverrides): + """Map element-wise ops to C++""" + + @staticmethod + def add(a, b): + return f"decltype({a})({a} + {b})" + + @staticmethod + def sub(a, b): + return f"decltype({a})({a} - {b})" + + @staticmethod + def mul(a, b): + return f"decltype({a})({a} * {b})" + + @staticmethod + def to_dtype(x, dtype, src_dtype=None): + assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" + return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({x})" + + @staticmethod + def to_dtype_bitcast(x, dtype, src_dtype): + assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" + if src_dtype in (torch.float16, torch.bfloat16): + # c10::bit_cast requires the source and target have the bitwidth. + # Because the input tensor's dtype could be promoted, e.g. from float16 to + # float, we have to cast the tensor to its original source dtype before + # invoking bit_cast. We also need to convert the bit-casted tensor + # back to float to make sure we keep using higher precision values + # for the rest of the computation. + cast_x = f"c10::convert<{DTYPE_TO_CPP[src_dtype]}>({x})" + cast_x = f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({cast_x})" + return f"c10::convert<{DTYPE_TO_CPP[torch.float32]}>({cast_x})" + else: + return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})" + + @staticmethod + def abs(x): + return f"std::abs({x})" + + @staticmethod + def sin(x): + return f"std::sin({x})" + + @staticmethod + def cos(x): + return f"std::cos({x})" + + @staticmethod + def neg(x): + return f"decltype({x})(-{x})" + + @staticmethod + def exp(x): + # return f"Sleef_expf_u10({x})" + return f"std::exp({x})" + + @staticmethod + def exp2(x): + return f"std::exp2({x})" + + @staticmethod + def expm1(x): + return f"std::expm1({x})" + + @staticmethod + def erf(x): + return f"std::erf({x})" + + @staticmethod + def erfc(x): + return f"std::erfc({x})" + + @staticmethod + def erfinv(x): + return f"calc_erfinv({x})" + + @staticmethod + def sqrt(x): + return f"std::sqrt({x})" + + @staticmethod + def rsqrt(x): + return f"1 / std::sqrt({x})" + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::log1p({x})" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def tan(x): + return f"std::tan({x})" + + @staticmethod + def tanh(x): + return f"std::tanh({x})" + + @staticmethod + def signbit(x): + return f"std::signbit({x})" + + @staticmethod + def pow(a, b): + return f"std::pow({a}, {b})" + + @staticmethod + def log(x): + return f"std::log({x})" + + @staticmethod + def round(x): + return f"std::nearbyint({x})" + + @staticmethod + def floor(x): + return f"std::floor({x})" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + quot = f"{a} / {b}" + rem = f"{a} % {b}" + return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? 
{quot} - 1 : {quot}) : {quot})" + + @staticmethod + def ceil(x): + return f"std::ceil({x})" + + @staticmethod + def trunc(x): + return f"std::trunc({x})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def fmod(a, b): + return f"std::fmod({a}, {b})" + + @staticmethod + def isinf(x): + return f"std::isinf({x})" + + @staticmethod + def isnan(x): + return f"std::isnan({x})" + + @staticmethod + def lgamma(x): + return f"std::lgamma({x})" + + @staticmethod + def acos(x): + return f"std::acos({x})" + + @staticmethod + def acosh(x): + return f"std::acosh({x})" + + @staticmethod + def cosh(x): + return f"std::cosh({x})" + + @staticmethod + def sinh(x): + return f"std::sinh({x})" + + @staticmethod + def asin(x): + return f"std::asin({x})" + + @staticmethod + def asinh(x): + return f"std::asinh({x})" + + @staticmethod + def atan2(x, y): + return f"std::atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"std::atan({x})" + + @staticmethod + def atanh(x): + return f"std::atanh({x})" + + @staticmethod + def copysign(x, y): + return f"std::copysign({x}, {y})" + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): + return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) + + code = BracesBuffer() + exponent = V.kernel.cse.newvar() + mantissa = V.kernel.cse.newvar() + code.writeline(f"int32_t {exponent};") + code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.cache[cache_key] = cse_var + return mantissa, exponent + + @staticmethod + def hypot(x, y): + return f"std::hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"std::log10({x})" + + @staticmethod + def nextafter(x, y): + return f"std::nextafter({x}, {y})" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::max({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"min_propagate_nan({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"max_propagate_nan({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"{a} ? 
{b} : {c}" + + @staticmethod + def mod(a, b): + return f"mod({a}, {b})" + + @staticmethod + def constant(val, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx and opt_ctx.dtype is not None + dtype = opt_ctx.dtype + if dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, constants + # must be promoted as well + dtype = torch.float32 + return value_to_cpp(val, DTYPE_TO_CPP[dtype]) + + @staticmethod + def index_expr(expr, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx and opt_ctx.dtype is not None + dtype = opt_ctx.dtype + return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype) + + @staticmethod + def masked(mask, body, other): + code = BracesBuffer() + + # Write masked operation into a lambda + body_var = V.kernel.cse.newvar() + code.writeline(f"auto {body_var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + # Use the lambda's return type as the type of other + other_code = value_to_cpp(other, f"decltype({body_var}())") + return f"{mask} ? {body_var}() : {other_code}" + + @staticmethod + def logical_and(a, b): + return f"{a} && {b}" + + @staticmethod + def logical_not(a): + return f"!{a}" + + @staticmethod + def logical_or(a, b): + return f"{a} || {b}" + + @staticmethod + def logical_xor(a, b): + return f"{a} != {b}" + + @staticmethod + def bitwise_and(a, b): + return f"decltype({a})({a} & {b})" + + @staticmethod + def bitwise_not(a): + return f"decltype({a})(~{a})" + + @staticmethod + def bitwise_or(a, b): + return f"decltype({a})({a} | {b})" + + @staticmethod + def bitwise_xor(a, b): + return f"decltype({a})({a} ^ {b})" + + @staticmethod + def bitwise_left_shift(a, b): + return f"decltype({a})({a} << {b})" + + @staticmethod + def bitwise_right_shift(a, b): + return f"decltype({a})({a} >> {b})" + + @staticmethod + def rand(seed: sympy.Expr, offset: sympy.Expr): + return f"normalized_rand_cpu({seed}, {offset})" + + @staticmethod + def randn(seed: sympy.Expr, offset: sympy.Expr): + return f"randn_cpu({seed}, {offset})" + + @staticmethod + def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high): + return f"randint64_cpu({seed}, {offset}, {low}, {high})" + + @staticmethod + def sigmoid(x): + return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))" + + @staticmethod + def sign(x): + code = BracesBuffer() + scalar_zero = f"decltype({x})(0)" + scalar_one = f"decltype({x})(1)" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};") + code.writeline(f"auto right = {x} < 0 ? {scalar_one} : {scalar_zero};") + code.writeline("return left - right;") + code.writeline("()") + return code + + +CppOverrides._initialize_pointwise_overrides("cpp") + + +class CppVecOverrides(CppOverrides): + """Map element-wise ops to aten vectorization C++""" + + def __new__(cls, *args, **kargs): + self = super().__new__(cls) + + def wrap(func): + # `CppVecKernel` generates both scalar ops and vector ops according to + # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` + # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in + # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to + # `CppOverrides` when all inputs are scalars. 
+ # + # Notes on ops handled separately in their own functions: + # `ops.masked`: + # needs recursive handling of masked body. + # `ops.index_expr`: + # needs to further analyze the dependency of the index expression on + # the tiling itervar. + def wrapper(*args, **kwargs): + scalars = [ + arg + for arg in args + if isinstance(arg, CppCSEVariable) and not arg.is_vec + ] + vectors = [ + arg + for arg in args + if isinstance(arg, CppCSEVariable) and arg.is_vec + ] + new_args = list(args) + if scalars and vectors: + # broadcast scalar args to vector if needed + new_args = [] + vec_dtype = vectors[0].dtype + for arg in args: + if isinstance(arg, CppCSEVariable) and not arg.is_vec: + assert isinstance(V.kernel, CppVecKernel) + # align scalar data type to the vector for binary ops + if len(args) == 2 and arg.dtype != vec_dtype: + arg = ops.to_dtype(arg, vec_dtype) + arg = arg.value if isinstance(arg, OpsValue) else arg + # See NOTE [dtype of CppCSEVariable]: we have to fix arg.dtype since + # the dtype from optimization context could be wrong. + assert isinstance(arg, CppCSEVariable) + arg.dtype = vec_dtype + new_arg = V.kernel.broadcast(arg) + new_args.append(new_arg) + else: + new_args.append(arg) + if vectors: + return func(*new_args, **kwargs) + else: + # fallback to scalar ops + scalar_ops = super(CppVecOverrides, self) + scalar_func = getattr( + scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined] + ) + assert scalar_func is not None + return scalar_func(*args, **kwargs) + + return wrapper + + for name, method in vars(CppVecOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in [ + "masked", + "index_expr", + ]: + setattr(self, name, wrap(method.__func__)) + return self + + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def truediv(a, b): + return f"{a} / {b}" + + @staticmethod + def abs(x): + return f"{x}.abs()" + + @staticmethod + def sin(x): + return f"{x}.sin()" + + @staticmethod + def cos(x): + return f"{x}.cos()" + + @staticmethod + def exp(x): + return f"{x}.exp()" + + @staticmethod + def exp2(x): + return f"{x}.exp2()" + + @staticmethod + def expm1(x): + # decompose for a better performance + vec_one = f"decltype({x})(1)" + return f"{x}.exp() - {vec_one}" + + @staticmethod + def erf(x): + return f"{x}.erf()" + + @staticmethod + def erfc(x): + return f"{x}.erfc()" + + @staticmethod + def erfinv(x): + return f"{x}.erfinv()" + + @staticmethod + def sqrt(x): + return f"{x}.sqrt()" + + @staticmethod + def eq(x, y): + return f"to_float_mask({x} == {y})" + + @staticmethod + def ne(x, y): + return f"to_float_mask({x} != {y})" + + @staticmethod + def lt(x, y): + return f"to_float_mask({x} < {y})" + + @staticmethod + def gt(x, y): + return f"to_float_mask({x} > {y})" + + @staticmethod + def le(x, y): + return f"to_float_mask({x} <= {y})" + + @staticmethod + def ge(x, y): + return f"to_float_mask({x} >= {y})" + + @staticmethod + def and_(x, y): + return f"{x} & {y}" + + @staticmethod + def rsqrt(x): + return f"{x}.rsqrt()" + + @staticmethod + def pow(a, b): + return f"{a}.pow({b})" + + @staticmethod + def log(x): + return f"{x}.log()" + + @staticmethod + def round(x): + return f"{x}.round()" + + @staticmethod + def floor(x): + return f"{x}.floor()" + + @staticmethod + def ceil(x): + return f"{x}.ceil()" + + @staticmethod + def trunc(x): + return f"{x}.trunc()" + + @staticmethod 
+ def fmod(a, b): + return f"{a}.fmod({b})" + + @staticmethod + def lgamma(x): + return f"{x}.lgamma()" + + @staticmethod + def logical_and(a, b): + return f"({a} != 0) & ({b} != 0)" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"({a} != 0) | ({b} != 0)" + + @staticmethod + def logical_xor(a, b): + return f"({a} != 0) ^ ({b} != 0)" + + @staticmethod + def tan(a): + return f"{a}.tan()" + + @staticmethod + def tanh(a): + vec_one = f"decltype({a})(1)" + vec_two = f"decltype({a})(2)" + vec_minus_two = f"decltype({a})(-2)" + return f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}" + + @staticmethod + def reciprocal(a): + return f"{a}.reciprocal()" + + @staticmethod + def atan(x): + return f"{x}.atan()" + + @staticmethod + def acos(x): + return f"{x}.acos()" + + @staticmethod + def asin(x): + return f"{x}.asin()" + + @staticmethod + def cosh(x): + return f"{x}.cosh()" + + @staticmethod + def sinh(x): + return f"{x}.sinh()" + + @staticmethod + def log10(x): + return f"{x}.log10()" + + @staticmethod + def nextafter(x): + return f"{x}.nextafter()" + + @staticmethod + def copysign(a, b): + return f"{a}.copysign({b})" + + @staticmethod + def atan2(a, b): + return f"{a}.atan2({b})" + + @staticmethod + def hypot(a, b): + return f"{a}.hypot({b})" + + @staticmethod + def atanh(x): + # For real x, atanh(x) = 1/2 * log((1+x)/(1-x)) + vec_one = f"decltype({x})(1)" + vec_one_half = f"decltype({x})(0.5)" + return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()" + + @staticmethod + def asinh(x): + # For real x, asinh(x) = log(x + sqrt(1 + x**2)) + vec_one = f"decltype({x})(1)" + return f"({x} + ({vec_one} + {x}*{x}).sqrt()).log()" + + @staticmethod + def acosh(x): + return f"{x}.acosh()" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" 
+ elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"at::vec::clamp_min({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + # TODO: this seems to be dead + @staticmethod + def sigmoid(x): + return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" + + @staticmethod + def neg(x): + return f"{x}.neg()" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + _t = f"decltype({a})" + quot = f"{a} / {b}" + has_rem = f"({a} % {b} != {_t}(0))" + is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))" + return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def minimum(a, b): + return f"at::vec::minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"at::vec::maximum({a}, {b})" + + @staticmethod + def square(a): + return f"{a} * {a}" + + @staticmethod + def where(a, b, c): + assert isinstance(b, CppCSEVariable) + if b.dtype != torch.float: + raise CppVecUnsupportedError( + "where with non-float tensor is not supported in vectorized codegen" + ) + return f"decltype({b})::blendv({c}, {b}, {a})" + + @staticmethod + def sign(x): + code = BracesBuffer() + vec_zero = f"decltype({x})(0)" + vec_one = f"decltype({x})(1)" + blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" + blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {blendv_l};") + code.writeline(f"auto right = {blendv_r};") + code.writeline("return left - right;") + code.writeline("()") + return code + + @staticmethod + def to_dtype(x, dtype, src_dtype=None): + assert dtype in [ + torch.bool, + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ], f"{__name__} does not support {dtype}" + node: torch.fx.Node = V.interpreter.current_node + assert node and isinstance(node, torch.fx.Node) + opt_ctx_x = get_opt_ctx(node.args[1]) + assert opt_ctx_x + if opt_ctx_x.dtype in (torch.float, torch.float32) and dtype == torch.bool: + return f"vec_convert_to_mask({x})" + if opt_ctx_x.dtype == torch.bool and dtype in (torch.float, torch.float32): + return f"mask_convert_to_float({x})" + if opt_ctx_x.dtype == torch.bool and dtype in DTYPE_LOWP_FP: + return f"mask_convert_to_lowp<{DTYPE_TO_CPP[dtype]}>({x})" + if opt_ctx_x.dtype == torch.bool and dtype == torch.int64: + return f"mask_convert_to_int64({x})" + if opt_ctx_x.dtype in (torch.float, torch.float32) and dtype in DTYPE_LOWP_FP: + return f"cvt_fp32_to_lowp_fp<{DTYPE_TO_CPP[dtype]}>({x})" + if opt_ctx_x.dtype in DTYPE_LOWP_FP and dtype in (torch.float, torch.float32): + return f"cvt_lowp_fp_to_fp32<{DTYPE_TO_CPP[opt_ctx_x.dtype]}>({x})" + if opt_ctx_x.dtype in (torch.uint8, torch.int8) and dtype in ( + torch.float, + torch.float32, + ): + # Note: this function only convert inputs number of elements equal to at::vec::Vectorized.size() + return f"at::vec::convert_int8_to_float({x})" + if opt_ctx_x.dtype in (torch.float, torch.float32) and dtype in ( + torch.uint8, + torch.int8, + ): + # if we already handle the saturation previously. + # * Pattern match of quantization op in the loop body. + # * Skip the explicit saturation and clamp inside at::vec::convert_float_to_int8. 
+ return f"at::vec::convert_float_to_int8<{DTYPE_TO_CPP[dtype]}>({x})" + if opt_ctx_x.dtype == torch.int32 and dtype == torch.float: + return f"at::vec::convert_to_fp_of_same_size({x})" + if opt_ctx_x.dtype == torch.float and dtype == torch.int32: + return f"at::vec::convert_to_int_of_same_size({x})" + if opt_ctx_x.dtype == torch.int64 and dtype == torch.float: + return f"cvt_int64_to_fp32({x})" + if opt_ctx_x.dtype == torch.float and dtype == torch.int64: + return f"cvt_fp32_to_int64({x})" + if opt_ctx_x.dtype == torch.int32 and dtype == torch.int64: + return f"cvt_int32_to_int64({x})" + if opt_ctx_x.dtype == torch.int64 and dtype == torch.int32: + return f"cvt_int64_to_int32({x})" + # TODO(jgong5): support conversion for other types + # currently we only allow load/store torch.uint8 and handle conversion there + return f"({x})" + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"{x}.log1p()" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def masked(mask, body, other): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + var = V.kernel.cse.newvar() + with V.kernel.masked(mask) as new_mask: + code.writeline(f"auto {var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + body_code = f"{var}()" + body_code_vec = ( + body_code + if result.is_vec + else f"{V.kernel._get_vec_type(torch.float)}({body_code})" + ) + other_code = value_to_cpp(other, "float") + other_code_vec = f"{V.kernel._get_vec_type(torch.float)}({other_code})" + assert isinstance(new_mask, CppCSEVariable), new_mask + if new_mask.is_vec or result.is_vec: + if result.dtype != torch.float: + raise CppVecUnsupportedError( + "masked with non-float tensor is not supported in vectorized codegen" + ) + type = f"decltype({body_code_vec})" + float_mask = f"to_float_mask({new_mask})" + code = BracesBuffer() + code.writeline("[&]") + with V.kernel.swap_buffers(code), code.indent(): + code.writeline(f"if (all_zero({float_mask}))") + with code.indent(): + code.writeline(f"return {other_code_vec};") + code.writeline("else") + with code.indent(): + code.writeline( + f"return {type}::blendv({other_code_vec}, {body_code_vec}, {float_mask});" + ) + code.writeline("()") + csevar = V.kernel.cse.generate( + V.kernel.compute, + code, + ) + else: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code} : {other_code}" + ) + # `result` is explicitly added to the args for correct propagation + # of relevant itervars and vectorization status. 
+ csevar.update_on_args("masked", (mask, body, other, result), {}) + return csevar + + @staticmethod + def index_expr(expr, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx and opt_ctx.dtype is not None + dtype = opt_ctx.dtype + assert dtype == torch.int32 + assert isinstance(V.kernel, CppVecKernel) + index = V.kernel.rename_indexing(expr) + tiling_var = V.kernel.itervars[V.kernel.tiling_idx] + stride = stride_at_vec_range(index, tiling_var, V.kernel.tiling_factor) + if stride.is_number and not V.kernel.index_indirect_depends_on( + index, tiling_var + ): + if stride == 0: + return CppOverrides.index_expr(expr, dtype) + value = ops.to_dtype(cexpr(index), dtype) + if isinstance(value, OpsValue): + value = value.value + csevar = V.kernel.arange(value, stride) + else: + csevar = V.kernel.load_non_contiguous(None, index, dtype, V.kernel.compute) + csevar.update_on_args("index_expr", (expr, dtype), {}) + return csevar + + +CppVecOverrides._initialize_pointwise_overrides("cppvec") + + +class CppTile2DOverrides(CppVecOverrides): + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppTile2DKernel) + expr = V.kernel.transform_indexing(expr) + return CppVecOverrides.index_expr(expr, dtype) + + +class CppKernel(Kernel): + overrides = CppOverrides # type: ignore[assignment] + sexpr = cexpr + newvar_prefix = "auto " + suffix = ";" + + def __init__(self, args, num_threads): + super().__init__(args) + self.call_ranges: Optional[Tuple[sympy.Expr, ...]] = None + self.ranges: List[sympy.Expr] = [] + self.itervars: List[sympy.Symbol] = [] + self.reduction_depth = None + self.reduction_prefix = IndentedBuffer() + self.reduction_suffix = IndentedBuffer() + self.reduction_var_map = {} + self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") + self.preloads = IndentedBuffer() + self.poststores = IndentedBuffer() + self.num_threads = num_threads # num_threads the kernel specialized for + self.reduction_omp_dec: Dict[Tuple[str, str], str] = {} + + @contextlib.contextmanager + def masked(self, mask): + """Context manager to add an additional mask to loads and stores.""" + prior = self._load_mask + if prior: + mask = ops.and_(mask, prior) + if isinstance(mask, OpsValue): + mask = mask.value + assert isinstance(mask, CppCSEVariable) + # see NOTE [dtype of CppCSEVariable] + # mask's dtype should be bool + mask.dtype = torch.bool + + self._load_mask = mask + try: + yield mask + finally: + self._load_mask = prior + + def cache_fp32_cse_var_before_lowp_store(self, var_to_store): + """ + https://github.com/pytorch/pytorch/issues/115260 + For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is + in low-precision floating point data type. When the output of node1 also serves as the output of the + kernel, the result of nodes would be different from the case when output of node1 is not the output + of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on + storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type + to the cse cache. + + Example (pseudo code): + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = load(buf) + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + Without cse cache trick: + node1_output = ... 
+ node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + With cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + # also add `to_dtype(node1_input_lowp, dtype=torch.float)` -> `node1_output` to cse cache + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = node1_output # hit cse cache + """ + + if var_to_store.dtype not in DTYPE_LOWP_FP: + # only need to cache fp32 cse var while var_to_store is lowp data + return + + def find_fp32_var(var, cache): + fp32_cse_var = None + fp32_cse_var_name = None + lowp_dtype = None + for expr, cse_var in cache.items(): + if cse_var == var: + lowp_dtype = is_to_lowp_dtype(expr) + if lowp_dtype: + m = re.search(r"tmp\d+", expr) + assert m + fp32_cse_var_name = m.group() + if fp32_cse_var_name: + for cse_var in cache.values(): + if cse_var.name == fp32_cse_var_name: + fp32_cse_var = cse_var + break + assert fp32_cse_var is not None + return fp32_cse_var, lowp_dtype + + fp32_var, lowp_dtype = find_fp32_var(var_to_store, self.cse.cache) + if fp32_var: + self.cse.cache[ + get_lowp_to_fp32_expr(var_to_store, lowp_dtype, self) + ] = fp32_var + + def scale_index_with_offset( + self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0 + ): + var = self.itervars[itervar_idx] + replacement = {var: var * scale + offset} + new_index = sympy_subs(index, replacement) + return new_index + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in cpp code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel. + """ + return cexpr(self.rename_indexing(index)) + + def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + """ + Check if an index has free symbol CppCSEVariable that depends on `itervar`. + """ + return any( + self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined] + for s in index.free_symbols + if s.name in self.cse.varname_map # type: ignore[attr-defined] + and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined] + ) + + def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + return itervar in index.free_symbols or self.index_indirect_depends_on( + index, itervar + ) + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + line = f"{var}[{cexpr_index(index)}]" + if V.graph.get_dtype(name) in [torch.float16]: + line = f"static_cast({line})" + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (name, index), {}) + return csevar + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + self.cache_fp32_cse_var_before_lowp_store(value) + index = self.rename_indexing(index) + if mode is None: + line = f"{var}[{cexpr_index(index)}] = {value};" + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + line = f"{var}[{cexpr_index(index)}] += {value};" + else: + dtype = V.graph.get_dtype(name) + # mirroring static_cast(...) 
in load: + value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})" + line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});" + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + + def reduction(self, dtype, src_dtype, reduction_type, value): + argmax_or_argmin = reduction_type in {"argmax", "argmin"} + + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + self.reduction_var_map[acc] = reduction_type + if argmax_or_argmin: + self.reduction_prefix.writelines( + argmax_argmin_prefix(reduction_type, src_dtype, acc) + ) + compare_op = ( + "greater_or_nan" if reduction_type == "argmax" else "less_or_nan" + ) + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + self.stores.writelines( + [ + f"if(!({compare_op}({acc}.value, {value}, {acc}.index, {cexpr_index(index)}))) {{", + f" {acc}.index = {cexpr_index(index)}; {acc}.value = {value};", + "}", + ], + ) + else: + acc_type = reduction_acc_type(reduction_type, dtype) + + if (reduction_type, acc_type) not in self.reduction_omp_dec: + if RTYPE_TO_CPP[reduction_type] not in NATIVE_OMP_RTYPES: + # Scalar reduction for other reductions are declared by default + self.reduction_prefix.splice( + f"""\ + #pragma omp declare reduction(\ + {RTYPE_TO_CPP[reduction_type]}:{acc_type}:\ + omp_out = {reduction_combine(reduction_type, "omp_out", "omp_in")}) \ + initializer(omp_priv={{{reduction_init(reduction_type, dtype)}}}) + """ + ) + self.reduction_omp_dec[reduction_type, acc_type] = RTYPE_TO_CPP[ + reduction_type + ] + + self.reduction_prefix.writeline( + f"{acc_type} {acc} = {reduction_init(reduction_type, dtype)};" + ) + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value)};" + ) + + result = reduction_project(reduction_type, acc) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + self.reduction_suffix.writeline( + DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};") + ) + + def set_ranges(self, lengths, reduction_lengths): + if self.call_ranges: + assert self.call_ranges == tuple(lengths) + tuple( + reduction_lengths + ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + assert self.reduction_depth == len(lengths) + else: + self.call_ranges = tuple(lengths) + tuple(reduction_lengths) + self.ranges = [self.rename_indexing(x) for x in self.call_ranges] + self.itervars = [ + sympy_index_symbol(f"x{n}") for n in range(len(self.ranges)) + ] + self.reduction_depth = len(lengths) + return ( + self.itervars[: self.reduction_depth], + self.itervars[self.reduction_depth :], + ) + + def size_hint(self): + return V.graph.sizevars.size_hint( + sympy_product(self.call_ranges), fallback=8192 + ) + + def codegen_loops_impl(self, loop_nest, code, worksharing): + threads = parallel_num_threads() + assert self.call_ranges is not None + par_depth = self.decide_parallel_depth( + self.call_ranges[: loop_nest.max_parallel_depth()], threads + ) + with contextlib.ExitStack() as stack: + if par_depth: + if loop_nest.is_reduction_only(): + # need to close 
the worksharing scope to define reduction vars outside it + worksharing.close() + else: + worksharing.parallel(threads) + loop_nest.mark_parallel(par_depth) + elif threads > 1: + if worksharing.single(): + stack.enter_context(code.indent()) + + def gen_kernel(kernel): + with contextlib.ExitStack() as stack: + assert kernel + if hasattr(kernel, "codegen_inner_loops"): + code.splice(kernel.preloads) + kernel.codegen_inner_loops(code) + stack.enter_context(code.indent()) + code.splice(kernel.loads) + code.splice(kernel.compute) + code.splice(kernel.stores) + if hasattr(kernel, "codegen_inner_loops"): + code.splice(kernel.poststores) + + def get_reduction_code_buffer(loops, is_suffix=True): + for loop in loops: + for kernel in loop.get_kernels(): + if is_suffix: + return kernel.reduction_suffix + else: + return kernel.reduction_prefix + return None + + def gen_loops(loops: List[LoopLevel], in_reduction=False): + with contextlib.ExitStack() as stack_outer: + if loops: + loop = loops[0] + if loop.is_reduction() and not in_reduction: + reduction_prefix = get_reduction_code_buffer( + loops, is_suffix=False + ) + if reduction_prefix: + stack_outer.enter_context(code.indent()) + code.splice(reduction_prefix) + if loop_nest.is_reduction_only() and loop.parallel: + worksharing.parallel(threads) + + for loop in loops: + gen_loop(loop, in_reduction) + + if loops: + loop = loops[0] + if loop_nest.is_reduction_only() and loop.parallel: + worksharing.close() + if loop.is_reduction() and not in_reduction: + code.splice( + get_reduction_code_buffer(loops, is_suffix=True) + ) + + def gen_loop(loop: LoopLevel, in_reduction=False): + with contextlib.ExitStack() as stack: + loop_lines = loop.lines() + if loop_lines is None: + return + code.writelines(loop_lines) + stack.enter_context(code.indent()) + # generate inner loops or loop body + if loop.inner: + gen_loops(loop.inner, loop.is_reduction()) + else: + kernels = loop.get_kernels() + assert len(kernels) == 1 + gen_kernel(kernels[0]) + + stack.enter_context(code.indent()) + if loop_nest.root: + gen_loops(loop_nest.root) + else: + gen_kernel(loop_nest.kernel) + + def codegen_loops(self, code, worksharing): + loop_nest = LoopNestWithSplit.build(self) + self.codegen_loops_impl(loop_nest, code, worksharing) + + @property + def assert_function(self) -> str: + if V.graph.aot_mode: + return "AOTI_TORCH_CHECK" + else: + return "TORCH_CHECK" + + def decide_parallel_depth(self, ranges, threads): + seq = self.size_hint() + par = 1 + depth = 0 + for expr in ranges: + hint = V.graph.sizevars.size_hint(expr, fallback=8192) + if par >= 2 * threads or par == threads: + break + if seq // threads < config.cpp.min_chunk_size: + # not enough work + break + depth += 1 + par *= hint + seq /= hint + # if we assume thread number is dynamic, make sure we + # have at least one parallel scope and let OMP runtime + # to manage the serial vs. parallel. 
+ if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0: + depth = 1 + return depth + + @contextlib.contextmanager + def write_to_suffix(self): + prior = (self.loads, self.compute, self.stores, self.cse) + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.cse = self.cse.clone() + yield + self.reduction_suffix.splice(self.loads) + self.reduction_suffix.splice(self.compute) + self.reduction_suffix.splice(self.stores) + (self.loads, self.compute, self.stores, self.cse) = prior + + def create_cse_var(self, *args, **kwargs): + return CppCSEVariable(*args, **kwargs) + + +class CppVecKernel(CppKernel): + overrides = CppVecOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor=0, + tiling_idx=-1, + tiling_dtype=torch.float, + ): + super().__init__(args, num_threads) + self.vec_isa = codecache.pick_vec_isa() + assert self.vec_isa + if tiling_factor == 0: + tiling_factor = self.vec_isa.nelements(dtype=tiling_dtype) + self.tiling_factor = tiling_factor + self.tiling_idx = tiling_idx + + def _get_num_vectors(self, dtype: torch.dtype) -> int: + num_vectors = math.ceil( + self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + ) + assert num_vectors >= 1 + return num_vectors + + def _get_vec_type(self, dtype: torch.dtype) -> str: + num_vectors = self._get_num_vectors(dtype) + if num_vectors == 1: + return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" + else: + return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_vec_load_line( + self, + var: str, + index: sympy.Expr, + dtype: torch.dtype, + load_mask: Optional[CppCSEVariable] = None, + ): + """ + Get a load line str that loads a vector from `var` at `index` of type `dtype`. + If `load_mask` is not None, we do a masked load accordingly. + Notes on the `dtype`: + 1. We always load `self.tiling_factor` number of elements regardless of the `dtype`. + It means we load half of the vector lanes for 16-bit data types and quarter of the + vector lanes for 8-bit data types. + 2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors. + """ + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx is not None + load_mask_str = f"to_float_mask({load_mask})" if load_mask else None + loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var + if dtype in (torch.uint8, torch.int8) and opt_ctx.is_load_int8_as_float: + assert self._get_num_vectors(torch.uint8) == 1 + line = ( + f"masked_load({loadbuf}, {load_mask_str})" + if load_mask_str + else f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>::loadu_one_fourth({loadbuf})" + ) + elif opt_ctx.is_load_as_mask: + line = f"flag_to_float_vec({loadbuf})" + elif dtype in DTYPE_LOWP_FP: + line = ( + f"masked_load({loadbuf}, {load_mask_str})" + if load_mask_str + else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {self.tiling_factor})" + ) + else: + line = ( + f"masked_load({loadbuf}, {load_mask_str})" + if load_mask_str + else f"{self._get_vec_type(dtype)}::loadu({loadbuf})" + ) + return line + + def load_non_contiguous( + self, + var: Optional[str], + index: sympy.Expr, + dtype: torch.dtype, + buffer: Optional[IndentedBuffer] = None, + ) -> CppCSEVariable: + """ + Load a vector in a non-contiguous way. The vector is initialized from an array that is + filled in an inner loop over the tiling factor. + :param var: buffer to load from, i.e. `var[transformed(index)]`. If None, we load the index + as index expression, i.e. 
`transformed(index)`. + :param index: index into the `var` or the index expression by its own if `var` is None. + The `index` could contain indirect indexing or the tiling itervar. When used in + the inner loop, the index is transformed as follows: + 1. the index is linearized along the tiling dim. + 2. the indirect indexing vector variables are transformed into arrays over the tiling dim. + :param dtype: data type of `var` or `index` if `var` is None. + :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. + :return: a CppCSEVariable that represents the loaded vector. + """ + if buffer is None: + buffer = self.loads + + def get_result_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.tiling_factor * (4 // dtype.itemsize) + else: + return self.tiling_factor + + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: + assert vec_var.is_vec + code = BracesBuffer() + code.writeline("[&]") + with self.swap_buffers(code), code.indent(): + vec_dtype = vec_var.dtype + assert vec_dtype is not None + if vec_dtype == torch.bool: + vec_dtype = torch.float + result_size = get_result_size(vec_dtype) + code.writeline( + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {result_size}> tmpbuf;" + ) + line = f"{vec_var}.store(tmpbuf.data());" + code.writeline(line) + code.writeline("return tmpbuf;") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + return csevar + + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx is not None + is_mask = opt_ctx.is_load_as_mask + code = BracesBuffer() + code.writeline("[&]") + with self.swap_buffers(code), code.indent(): + result_type = "float" if is_mask else f"{DTYPE_TO_CPP[dtype]}" + result_size = get_result_size(dtype) + result_declare = ( + f"__at_align__ std::array<{result_type}, {result_size}> tmpbuf;" + ) + code.writeline(result_declare) + itervar_inner = sympy_index_symbol( + f"{self.itervars[self.tiling_idx]}_inner" + ) + replacements = {} + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if s.name.startswith("tmp") # type: ignore[attr-defined] + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + array_var = vec_to_array(indirect_var) + replacements[indirect_var] = f"{array_var}[{itervar_inner}]" + load_mask = None + if self._load_mask is not None: + assert isinstance(self._load_mask, CppCSEVariable), self._load_mask + if self._load_mask.is_vec: + load_mask = ( + f"vector_lane_mask_check({self._load_mask}, {itervar_inner})" + ) + else: + load_mask = f"{self._load_mask} != 0" + index = sympy_subs(index, replacements) # type: ignore[arg-type] + index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=itervar_inner + ) + if codecache.is_gcc(): + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") + else: + code.writeline(f"#pragma unroll {self.tiling_factor}") + code.writeline( + f"for (long {itervar_inner} = 0; {itervar_inner} < {self.tiling_factor}; {itervar_inner}++)" + ) + with code.indent(), contextlib.ExitStack() as stack: + rhs = ( + f"{var}[{cexpr_index(index)}]" + if var is not None + else f"{cexpr_index(index)}" + ) + if is_mask: + rhs = f"flag_to_float_scalar({rhs})" + if load_mask: + code.writeline(f"if ({load_mask})") + stack.enter_context(code.indent()) + code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") + load_line = self._get_vec_load_line("tmpbuf.data()", 0, 
dtype) # type: ignore[arg-type] + code.writeline(f"return {load_line};") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + + def load(self, name: str, index: sympy.Expr): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + var = self.args.input(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + tiling_var = self.itervars[self.tiling_idx] + stride = stride_at_vec_range(index, tiling_var, self.tiling_factor) + if stride == 0: + # load scalar and lazily broadcast it on demand + return super().load(name, index) + non_contiguous = stride != 1 or self.index_indirect_depends_on( + index, tiling_var + ) + if non_contiguous: + csevar = self.load_non_contiguous(var, index, dtype) + else: + line = self._get_vec_load_line(var, index, dtype, self._load_mask) + csevar = self.cse.generate(self.loads, line) # type: ignore[assignment] + assert isinstance(csevar, CppCSEVariable) + csevar.update_on_args("load", (name, index), {}) + csevar.is_vec = True + return csevar + + def _get_vec_store_line( + self, + value: Union[str, CppCSEVariable], + var: str, + index: sympy.Expr, + dtype: torch.dtype, + ): + """ + Get a store line str that stores `value` into `var` at `index` of `dtype`. + :param value: Vectorized type templaterized on `dtype`. + :param var: buffer to store into. + :index: index into the `var`. + """ + # when value's type is str (e.g., welford reduction), caller should make sure + # it is a vector + assert isinstance(value, str) or ( + isinstance(value, CppCSEVariable) and value.is_vec + ), value + tiling_var = self.itervars[self.tiling_idx] + assert index.has(tiling_var), f"index: {index}, tiling_var: {tiling_var}" + var_expr = f"{var} + {cexpr_index(index)}" + stride = stride_at_vec_range(index, tiling_var, self.tiling_factor) + non_contiguous = stride != 1 or self.index_indirect_depends_on( + index, tiling_var + ) + if non_contiguous: + var_expr = "tmpbuf" + if dtype == torch.float: + line = f"{value}.store({var_expr});" + else: + line = f"{value}.store({var_expr}, {self.tiling_factor});" + if non_contiguous: + inner = sympy_index_symbol(f"{tiling_var}_inner") + new_index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=inner + ) + tmp_bufsize = ( + f"{self.tiling_factor}*sizeof(float)/sizeof({DTYPE_TO_CPP[dtype]})" + ) + line = ( + f"{{ __at_align__ {DTYPE_TO_CPP[dtype]} tmpbuf[{tmp_bufsize}]; {line} " + f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++) " + f"{var}[{cexpr_index(new_index)}] = tmpbuf[{inner}]; }}" + ) + return line + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert mode is None + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + var = self.args.output(name) + self.cache_fp32_cse_var_before_lowp_store(value) + index = self.rename_indexing(index) + self.stores.writeline( + DeferredLine( + name, + self._get_vec_store_line(value, var, index, V.graph.get_dtype(name)), + ) + ) + + def reduction(self, dtype, src_dtype, reduction_type, value): + assert reduction_type in { + "max", + "min", + "sum", + "prod", + "xor_sum", + "welford_reduce", + "welford_combine", + } + assert dtype == src_dtype + assert dtype in [torch.float, torch.int64] + assert isinstance(value, CppCSEVariable), 
value + + if not value.is_vec: + value = self.broadcast(value) + + acc_type = reduction_acc_type(reduction_type, dtype) + acc_type_vec = self.reduction_acc_type_vec(reduction_type, dtype) + + if (reduction_type, acc_type) not in self.reduction_omp_dec: + if RTYPE_TO_CPP[reduction_type] not in NATIVE_OMP_RTYPES: + # Scalar reduction for other reductions are declared by default + self.reduction_prefix.splice( + f"""\ +#pragma omp declare reduction(\ +{RTYPE_TO_CPP[reduction_type]}:{acc_type}:\ +omp_out = {reduction_combine(reduction_type, "omp_out", "omp_in")}) \ +initializer(omp_priv={{{reduction_init(reduction_type, dtype)}}}) + """ + ) + self.reduction_omp_dec[reduction_type, acc_type] = RTYPE_TO_CPP[ + reduction_type + ] + + if (reduction_type, acc_type_vec) not in self.reduction_omp_dec: + self.reduction_prefix.splice( + f"""\ +#pragma omp declare reduction(\ +{RTYPE_TO_CPP[reduction_type]}:{acc_type_vec}:\ +omp_out = {self.reduction_combine_vec(reduction_type, "omp_out", "omp_in")}) \ +initializer(omp_priv={{{self.reduction_init_vec(reduction_type, dtype)}}}) + """ + ) + self.reduction_omp_dec[reduction_type, acc_type_vec] = RTYPE_TO_CPP[ + reduction_type + ] + + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + acc_vec = f"{acc}_vec" + + self.reduction_var_map[acc_vec] = reduction_type + self.reduction_prefix.writeline( + f"{acc_type} {acc} = {reduction_init(reduction_type, dtype)};" + ) + self.reduction_prefix.writeline( + f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};" + ) + self.stores.writeline( + f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value)};" + ) + + tmpvar: Union[str, CSEVariable] + if self.tiling_idx >= self.reduction_depth: + # Horizontal reduction + if is_welford_reduction(reduction_type): + assert ( + self._get_num_vectors(dtype) == 1 + ), "Welford reduction does not support VectorizedN (N>1)" + next_value = f"welford_vec_reduce_all({acc_vec})" + else: + reduce_all_body = ( + "{ return " + + self.reduction_combine_vec(reduction_type, "x", "y") + + "; }" + ) + vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" + vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>" + next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" + + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, next_value)};" + ) + tmpvar = acc + else: + tmpvar = acc_vec + + result = reduction_project(reduction_type, tmpvar) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + out_dtype = V.graph.get_dtype(name) + # Only float reductions are vectorized currently + dtype = torch.float + if self.tiling_idx >= self.reduction_depth: + # Horizontal reduction + self.reduction_suffix.writeline( + DeferredLine( + name, + f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});", + ) + ) + else: + # Vertical reduction + store_lines = [] + if out_dtype != dtype: + if out_dtype in DTYPE_LOWP_FP and dtype == torch.float: + _lowp_fp_tmpvar_vec = f"{DTYPE_TO_CPP[out_dtype]}_{value}" + store_lines = [ + DeferredLine( + name, + f"auto {_lowp_fp_tmpvar_vec} = 
cvt_fp32_to_lowp_fp<{DTYPE_TO_CPP[out_dtype]}>({value});", + ) + ] + value = _lowp_fp_tmpvar_vec + else: + raise AssertionError( + f"Unsupported reduction type from {dtype} to {out_dtype}" + ) + store_lines += [ + DeferredLine( + name, + self._get_vec_store_line(value, var, index, out_dtype), + ) + ] + self.reduction_suffix.writelines(store_lines) + + def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable: + assert not scalar_var.is_vec + if scalar_var.dtype == torch.bool: + vec_var = self.cse.generate( + self.compute, f"to_float_mask({scalar_var.name})" + ) + else: + assert scalar_var.dtype is not None + vec_var = self.cse.generate( + self.compute, + f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})", + ) + assert isinstance(vec_var, CppCSEVariable) + vec_var.dtype = scalar_var.dtype + vec_var.dependent_itervars = scalar_var.dependent_itervars + vec_var.is_vec = True + return vec_var + + def arange( + self, index: Union[sympy.Expr, CppCSEVariable], stride: sympy.Symbol + ) -> CppCSEVariable: + if isinstance(index, sympy.Expr): + index = cexpr(index) + else: + assert isinstance(index, CppCSEVariable) + assert not index.is_vec + csevar = self.cse.generate( + self.compute, + f"{self._get_vec_type(torch.int32)}::arange({index}, {stride})", + ) + assert isinstance(csevar, CppCSEVariable) + csevar.dtype = torch.int32 + csevar.is_vec = True + return csevar + + def reduction_init_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>()" + + scalar_init = reduction_init(reduction_type, dtype) + return f"{vec_type}({scalar_init})" + + def reduction_acc_type_vec(self, reduction_type, dtype): + assert reduction_type not in {"argmin", "argmax"} + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>" + + return vec_type + + def reduction_combine_vec(self, reduction_type, var, next_value): + if reduction_type == "max": + return f"at::vec::maximum({var}, {next_value})" + elif reduction_type == "min": + return f"at::vec::minimum({var}, {next_value})" + elif reduction_type == "sum": + return f"{var} + {next_value}" + elif reduction_type == "prod": + return f"{var} * {next_value}" + elif reduction_type == "xor_sum": + return f"{var} ^ {next_value}" + elif reduction_type == "welford_reduce": + return f"welford_combine({var}, {next_value})" + elif reduction_type == "welford_combine": + if isinstance(next_value, tuple): + # When reading a value from Inductor IR we have a tuple of variable names + mean, m2, weight = next_value + else: + # When combining intermediate accumulators we have a Welford struct + mean, m2, weight = reduction_project(reduction_type, next_value) + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + else: + raise NotImplementedError() + + +class CppTile2DKernel(CppVecKernel): + """ + A vector kernel that handles the 2d tiles with the tile size defined in `tiling_factor` on + the inner-most loop level and one of the outer loop level (`outer_tiling_idx`). When the data + tile is accessed in a contiguous way from the outer loop axis, a transposition is applied on the + tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization + logic from its parent `CppVecKernel` is leveraged for load/store/compute. 
The transposed tile load + and store are generated into kernel.preloads and kernel.poststores buffers. + + The loop structure looks like below: + for ... + for i_outer ... + for ... + for inner_most ... + // generated by CppTile2DKernel + float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads + float tmp1[16*16]; // into kernel.preloads + for i_inner ... { // the kernel inner loop + vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores + } + at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores + for inner_most ... (tail) + // generated by CppVecKernel + ... + for i_outer ... (tail) + for ... + for ... + // generated by CppKernel + ... + """ + + overrides = CppTile2DOverrides # type: ignore[assignment] + + def __init__(self, args, num_threads, tiling_factor, tiling_indices, tiling_dtype): + super().__init__( + args, num_threads, tiling_factor, tiling_indices[1], tiling_dtype + ) + self.tiling_indices = tiling_indices + + def inner_itervar(self): + return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner") + + def need_vec_transpose(self, index): + outer_var = self.itervars[self.outer_idx] + inner_var = self.itervars[self.tiling_idx] + outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor) + inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor) + return ( + self._load_mask is None # TODO: support transposition with mask + and outer_stride == 1 + and index.has(inner_var) + and not inner_stride.has(inner_var) + and not inner_stride.has(outer_var) + ) + + def gen_transposed_tile_load_store(self, name, var, index, is_store): + # transposed tile load/store outside the kernel inner loop + dtype = V.graph.get_dtype(name) + factor = self.tiling_factor + src = f"{var} + {cexpr_index(index)}" + dst = "__place_holder__" + ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" + ld_dst = f"{factor}" + if is_store: + src, dst = dst, src + ld_src, ld_dst = ld_dst, ld_src + + need_define = True + load_or_store = f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{factor},{factor}>({src}, {ld_src}, {dst}, {ld_dst});" + if is_store: + tile_var = self.cse.newvar() + elif load_or_store not in self.cse.cache: + tile_var = self.cse.generate(self.preloads, load_or_store, write=False) + else: + need_define = False + tile_var = self.cse.cache[load_or_store] + + if need_define: + define_line = f"{DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}] __attribute__ ((aligned ({factor})));" + self.preloads.writeline(define_line) + + load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) + if is_store: + self.poststores.writeline(DeferredLine(name, load_or_store)) + else: + self.preloads.writeline(load_or_store) + + return tile_var + + def load(self, name: str, index: sympy.Expr): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + var = self.args.input(name) + index = self.rename_indexing(index) + + inner = self.inner_itervar() + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=False + ) + # vector load inside the kernel inner loop + loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}" + dtype = V.graph.get_dtype(name) + line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (name, index), {}) + assert 
isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + else: + new_index = self.transform_indexing(index) + return super().load(name, new_index) + + def store(self, name, index, value, mode=None): + assert "buf" in name + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + var = self.args.output(name) + + inner = self.inner_itervar() + index = self.rename_indexing(index) + assert mode is None + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=True + ) + # vector store inside the kernel inner loop + storebuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}" + if V.graph.get_dtype(name) in DTYPE_LOWP_FP: + line = f"{value}.store({storebuf}, {self.tiling_factor});" + elif V.graph.get_dtype(name) in (torch.uint8, torch.int8): + line = f"{value}.store({storebuf}, {self.tiling_factor});" + else: + line = f"{value}.store({storebuf});" + self.stores.writeline(DeferredLine(name, line)) + else: + new_index = self.transform_indexing(index) + super().store(name, new_index, value, mode) + + def codegen_inner_loops(self, code): + inner = self.inner_itervar() + code.writeline( + f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++)" + ) + + def set_ranges(self, group, reduction_group): + vars = super().set_ranges(group, reduction_group) + # do vertical reduction as the tail loop + self.outer_idx, self.tiling_idx = ( + self.tiling_indices + if self.tiling_indices[1] < self.reduction_depth + else reversed(self.tiling_indices) + ) + return vars + + def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: + return self.scale_index_with_offset( + index, + itervar_idx=self.outer_idx, + offset=self.inner_itervar(), + ) + + +class CppVecKernelChecker(CppVecKernel): + def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1): + super().__init__(args, num_threads, tiling_factor, tiling_idx) + + # Since this kernel is only for checker but does not generate any + # code, so we need to decrease the kernel count. + metrics.generated_kernel_count -= 1 + + # Used to record the graph wrapper code as the wrapper_code status could be + # changed during graph run. + self._orig_wrapper_code = None + + self.simd_vec = True + + self.fast_vec_list = [] + for k, v in CppVecOverrides.__dict__.items(): + if isinstance(v, staticmethod): + self.fast_vec_list.append(k) + self.exit_stack = contextlib.ExitStack() + + # Cache all the load result + self.load_supported_dtypes: List[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.bool, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ] + self.store_supported_dtypes: List[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ] + # Cache the dtypes of the store operation. 
If the store is mixing dtypes, the + # vectorization would not support it as it is hard to determine the vec dtype + self.store_dtypes: List[torch.dtype] = [] + # The dtype is used for vectorization + self.vec_dtype: torch.dtype = torch.float32 + + def disable_vec(self, msg=None): + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Disabled vectorization: %s", msg) + self.simd_vec = False + + def is_mask(self, name: str, users: Dict[torch.fx.Node, None]): + load_type = V.graph.get_dtype(name) + if load_type == torch.bool: + return all(user.target in ("where", "masked") for user in users.keys()) + elif load_type in (torch.uint8, torch.int8): + """ + If the load value is torch.uint8/int8, then we only support the loaded + value is as the mask. + """ + if not all( + user.target == "to_dtype" and user.args[-1] == torch.bool + for user in users.keys() + ): + return False + + for to_dtype_node in users.keys(): + assert to_dtype_node.target == "to_dtype" + if not all( + user.target in ("where", "masked") + for user in to_dtype_node.users.keys() + ): + return False + return True + else: + return False + + def is_load_int8_as_float(self, name: str, users: Dict[torch.fx.Node, None]): + """ + Check: + 1. load_type is torch.uint8 or torch.int8 + 2. has 1 user node of target to_dtype + 3. dtype of to_dtype is torch.float + """ + load_type = V.graph.get_dtype(name) + if load_type not in (torch.uint8, torch.int8): + return False + if len(users) == 1: + user = next(iter(users)) + if (user.target == "to_dtype") and (user.args[-1] == torch.float): + return True + return False + return False + + def can_store_fp32_as_int8(self, store_var: str, value_node: torch.fx.Node): + """ + Check: + 1. store_type is torch.uint8/torch.int8 + 2. value_node is of target to_dtype + 3. 
dtype of to_dtype node is torch.uint8/torch.int8 + """ + store_type = V.graph.get_dtype(store_var) + if store_type not in (torch.uint8, torch.int8): + return False + if value_node.target == "to_dtype" and value_node.args[-1] in ( + torch.uint8, + torch.int8, + ): + return True + + return False + + def is_load_integer_scalar_tensor(self, name: str, index: sympy.Expr): + load_dtype = V.graph.get_dtype(name) + buffer = V.graph.get_buffer(name) + return ( + load_dtype in [torch.int32, torch.int64] + and isinstance(buffer, TensorBox) + and isinstance(buffer.data, StorageBox) + and (len(buffer.data.layout.size) == 0) + and (index == 0) + ) + + def load(self, name: str, index: sympy.Expr): + with RecordOptimizationContext(__name__) as node_ctx: + load_dtype = V.graph.get_dtype(name) + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + opt_ctx.dtype = load_dtype + opt_ctx.is_load_as_mask = self.is_mask(name, node_ctx.get_fx_node().users) + opt_ctx.is_load_int8_as_float = self.is_load_int8_as_float( + name, node_ctx.get_fx_node().users + ) + + var = self.cse.newvar() + + if len(self.itervars) == 0: + self.disable_vec("not a loop") + return var + + if load_dtype in (torch.bool, torch.uint8, torch.int8) and not ( + opt_ctx.is_load_as_mask or opt_ctx.is_load_int8_as_float + ): + if not opt_ctx.is_load_as_mask: + self.disable_vec(f"{load_dtype} not loaded as mask") + elif not opt_ctx.is_load_int8_as_float: + self.disable_vec(f"{load_dtype} not loaded as float") + return var + + if ( + (load_dtype not in self.load_supported_dtypes) + and not self.is_load_integer_scalar_tensor(name, index) + and index.has(self.itervars[self.tiling_idx]) + ): + self.disable_vec(f"{load_dtype} not supported by load") + return var + + return var + + def store(self, name, index, value, mode=None): + with RecordOptimizationContext(__name__) as node_ctx: + if len(self.itervars) == 0: + self.disable_vec("not a loop") + return self.simd_vec + + store_dtype = V.graph.get_dtype(name) + + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + opt_ctx.dtype = store_dtype + + store_dtype = torch.float if store_dtype == torch.float32 else store_dtype + self.store_dtypes.append(store_dtype) + if store_dtype not in self.store_supported_dtypes: + self.disable_vec(f"{store_dtype} not supported by store") + return self.simd_vec + + if store_dtype in (torch.uint8, torch.int8): + value_node = node_ctx.get_fx_node().all_input_nodes[-1] + if not self.can_store_fp32_as_int8(name, value_node): + self.disable_vec("not support store float32 as uint8/int8") + return self.simd_vec + + assert "buf" in name + index = self.rename_indexing(index) + + if mode: + self.disable_vec(f"store mode: {mode}") + return self.simd_vec + + if index.is_number: + self.disable_vec(f"constant store index: {index}") + return self.simd_vec + + def reduction(self, dtype, src_dtype, reduction_type, value): + if ( + (dtype == torch.float and src_dtype == torch.float) + or (dtype == torch.int64 and src_dtype == torch.int64) + and reduction_type in VECTORIZABLE_RTYPES + ): + pass + else: + self.disable_vec( + f"reduction: dtype {dtype}, src_dtype {src_dtype}, reduction_type {reduction_type}" + ) + if is_welford_reduction(reduction_type): + return tuple([self.simd_vec] * 3) + return self.simd_vec + + def store_reduction(self, name, index, value): + return self.simd_vec + + def is_supported_cmp(self, node: torch.fx.Node): + def get_node_dtype(node): + if type(node) == torch.fx.Node: + opt_ctx: OptimizationContext = 
get_current_node_opt_ctx() + return opt_ctx.dtype if opt_ctx else None + else: + return None + + def get_cmp_dtypes(node: torch.fx.Node): + return get_node_dtype(node.args[-2]), get_node_dtype(node.args[-1]) + + assert len(node.args) >= 2 + # cmp(x, y): y is a magic value like x >= 1 + if type(node.args[-1]) in [int, float]: + return True + # cmp(x, y): x is a magic value like 1 >= y + if type(node.args[-2]) in [int, float]: + return False + + left_dtype, right_dtype = get_cmp_dtypes(node) + if left_dtype is None or right_dtype is None: + # TODO(Eikan): To record, deduce and propagate the data type of every expression. + return True + else: + return left_dtype == right_dtype + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self._orig_wrapper_code is not None + # Restore the wrapper_code + V.graph.wrapper_code = self._orig_wrapper_code + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def __enter__(self): + # Record the graph wrapper code. The wrapper_code status could be + # changed during graph run. Regarding this checker, we also need to + # run the graph but we don't expect to change any status that would + # impact the code generation. Hence, we record the graph wrapper code + # and replace it with a dummy wrapper_code and then restore to the + # original one as long as the checker is finished. + self._orig_wrapper_code = V.graph.wrapper_code + V.graph.wrapper_code = WrapperCodeGen() + + parent_handler = V.MockHandler() + + class VecCheckerProxy: + bin_cmp_ops = ["eq", "ne", "le", "ge", "lt", "gt"] + + @staticmethod + def _bin_cmp_op(x, y): + current_node: torch.fx.Node = V.interpreter.current_node + if not self.is_supported_cmp(current_node): + self.disable_vec(f"binary comparison op: {current_node}") + return self.simd_vec + + @staticmethod + def __getattr__(name): # type: ignore[misc] + def inner(*args, **kwargs): + if name in VecCheckerProxy.bin_cmp_ops: + return VecCheckerProxy._bin_cmp_op(args, kwargs) + + if name not in self.fast_vec_list: + self.disable_vec(f"op: {name}") + + parent_val = getattr(parent_handler, name)(*args, **kwargs) + return pytree.tree_map(lambda _: self.simd_vec, parent_val) + + return inner + + @staticmethod + def load(name: str, index: sympy.Expr): + return self.load(name, index) + + @staticmethod + def store(name, index, value, mode=None): + return self.store(name, index, value, mode=mode) + + @staticmethod + def reduction(dtype, src_dtype, reduction_type, value): + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def store_reduction(name, index, value): + return self.store_reduction(name, index, value) + + @staticmethod + def constant(val, dtype): + with RecordOptimizationContext(__name__) as node_ctx: + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + # VecKernel override dtype for constant + # Vectorization only support int32/fp32 now + # So if dtype = int64/fp64, we will cast it to int32/fp32 if possible + i32_iinfo = torch.iinfo(torch.int32) + if ( + dtype == torch.int64 + and val <= i32_iinfo.max + and val >= i32_iinfo.min + ): + opt_ctx.dtype = torch.int32 + + f32_iinfo = torch.finfo(torch.float32) + if dtype == torch.double: + if ( + (val <= f32_iinfo.max and val >= f32_iinfo.min) + or (val == torch.inf) + or (val == -torch.inf) + ): + opt_ctx.dtype = torch.float32 + + supported_dtypes = [ + torch.float32, + torch.int32, + torch.int64, + torch.bfloat16, + torch.float16, + torch.bool, + ] + + if opt_ctx.dtype not in supported_dtypes or ( + opt_ctx.dtype == torch.int32 + 
and not all( + user.target in VecCheckerProxy.bin_cmp_ops + for user in node_ctx.current_node.users + ) + ): + self.disable_vec(f"constant dtype: {opt_ctx.dtype}") + return val + + @staticmethod + def index_expr(expr, dtype): + assert len(self.ranges) == len(self.itervars) + if not len(self.ranges) or not all( + not isinstance(range, sympy.Expr) or sympy.simplify(range).is_number + for range in self.ranges + ): + # if the range value is sympy.Expr, we might could not deduce the accurate loop interval. + self.disable_vec(f"index_expr: {expr}, dtype {dtype}") + return self.cse.newvar() + + def can_use_int32(): + free_symbols = list(expr.free_symbols) + sizes = { + k: v + for k, v in zip(self.itervars, self.ranges) + if k in free_symbols + } + # Trivial case: Range empty + if any(v == 0 for v in sizes.values()): + return True + + vars_ranges = {k: ValueRanges(0, v - 1) for k, v in sizes.items()} + if not vars_ranges or len(vars_ranges) != len(free_symbols): + i32_iinfo = torch.iinfo(torch.int32) + return ( + expr.is_number + and expr <= i32_iinfo.max + and expr >= i32_iinfo.min + ) + expr_ranges = bound_sympy(expr, vars_ranges) + if math.isinf(expr_ranges.lower) or math.isinf(expr_ranges.upper): # type: ignore[arg-type] + return False + # If something takes the values 0..7, we will compare in the loop + # x < 8. As such, for the loop not to overflow in the last iteration, we want + # to check that expr_ranges.upper + 1 is representable as well + return range_expressable_in_32_bits( + ValueRanges( + int(expr_ranges.lower), int(expr_ranges.upper) + 1 # type: ignore[arg-type] + ) + ) + + with RecordOptimizationContext(__name__) as node_ctx: + assert len(self.ranges) == len(self.itervars) + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + if ( + dtype == torch.int64 + and can_use_int32() + and all( + user.target in VecCheckerProxy.bin_cmp_ops + for user in node_ctx.current_node.users + ) + ): + opt_ctx.dtype = torch.int32 + else: + opt_ctx.dtype = dtype + self.disable_vec(f"index_expr: {expr}, dtype {dtype}") + + tmp_var = self.cse.newvar() + return tmp_var + + @staticmethod + def indirect_indexing(index_var, size, check=True): + return sympy_index_symbol(str(index_var)) + + @staticmethod + def masked(mask, body, other): + body() + return self.cse.newvar() + + @staticmethod + def to_dtype(x, dtype, src_dtype=None): + with RecordOptimizationContext(__name__) as node_ctx: + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + opt_ctx.dtype = dtype + + cur_node = node_ctx.get_fx_node() + input_value: torch.fx.Node = cur_node.all_input_nodes[1] + if dtype == torch.float: + if input_value.target in [ + "load", + ]: + # Support masked_load for BF16/FP16. Because the legalization will + # insert to_dtype to convert the BF16/FP16 input to FP32. 
+ dtype = ( + V.graph.get_dtype(input_value.args[1]) # type: ignore[arg-type] + if input_value.target == "load" + else input_value.args[-1] + ) + if dtype in [ + torch.float16, + torch.bfloat16, + torch.float, + torch.float64, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ]: + # Convert from dtype to torch.float + pass + else: + self.disable_vec(f"to_dtype: dtype {dtype}") + elif dtype in DTYPE_LOWP_FP: + if not all(usr.target == "store" for usr in cur_node.users): + self.disable_vec( + "to_dtype: bfloat16/float16 expecting users are all stores" + ) + return x + + store_names = [usr.args[1] for usr in cur_node.users] + if not all( + V.graph.get_dtype(name) in [dtype] for name in store_names + ): + self.disable_vec( + "to_dtype: expecting all stores into bfloat16 or float16" + ) + return x + elif dtype == torch.bool: + pass + elif dtype in (torch.uint8, torch.int8): + # Only allow below 2 cases: + # Case 1: to_int8 and store which corresponding to the single quant node + # at last of fusion pattern. + is_to_int8_and_store = all( + usr.target in ["store"] for usr in cur_node.users + ) + # Case 2: to_int8 and to_float which corresponding to pair of quant/dequant node + # at middle of fusion pattern. + is_to_int8_and_to_float = all( + ( + usr.target in ["to_dtype"] + and usr.args[2] == torch.float32 + ) + for usr in cur_node.users + ) + if not (is_to_int8_and_store or is_to_int8_and_to_float): + self.disable_vec(f"to_dtype: dtype {dtype}") + elif dtype in [torch.int64, torch.int32]: + pass + else: + self.disable_vec(f"to_dtype: dtype {dtype}") + return x + + self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + +class CppKernelProxy(CppKernel): + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.kernel_group = kernel_group + self.loop_nest = None + self.call_ranges = None + self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() + + def data_type_propagation(self, nodes): + for _node in nodes: + assert isinstance(_node, SchedulerNode) + DataTypePropagation.propagate_scheduler_node(_node) + + # Check if all the nodes of a given fx graph can support BF16/FP16 + def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): + if not isinstance(scheduler_node._body, ir.LoopBody): + return True + + _lowp_fp_type: Optional[torch.dtype] = None + + # Propagate the dtype to check if all the fx node is bf16/fp16 + DataTypePropagation.propagate_scheduler_node(scheduler_node) + + sub_blocks = [scheduler_node._body.root_block] + list( + scheduler_node._body.subblocks.values() + ) + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + # TODO(Eikan): Regarding get_index and index_expr, we should conclude the + # the data type as well. 
+ if _node.op == "placeholder" or _node.target in ( + "get_index", + "index_expr", + ): + continue + + # Fast path if all operations can support bf16/fp16 without converting to fp32 + if _node.target not in [ + "load", + "store", + "abs", + "neg", + "output", + ]: + return False + + if hasattr(_node, "meta") and _node.meta: + assert OptimizationContext.key in _node.meta + opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] + if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: + return False + if _lowp_fp_type: + assert ( + _lowp_fp_type == opt_ctx.dtype + ), "scheduler node do not support bf16/fp16 mix" + else: + _lowp_fp_type = opt_ctx.dtype + else: + return False + + scheduler_node._lowp_fp_type = _lowp_fp_type # type: ignore[attr-defined] + return True + + def legalize_lowp_fp_dtype(self, nodes): + def add_to_dtype(sub_graph: torch.fx.Graph): + def is_lowp_fp_load(node: torch.fx.Node): + if node.target not in ["load"]: + return False + assert len(node.args) == 3 + load_dtype = V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + return load_dtype in DTYPE_LOWP_FP + + def is_lowp_fp_store(node: torch.fx.Node): + if node.target != "store": + return False + _, store_var, _, _, _ = node.args + store_dtype = V.graph.get_dtype(store_var) # type: ignore[arg-type] + return store_dtype in DTYPE_LOWP_FP + + sub_graph_nodes = list(sub_graph.nodes) + to_lowp_fp_legalized_nodes = [] + for _node in sub_graph_nodes: + if is_lowp_fp_load(_node): + # No need to promote to float if all users are direct stores + if all(user.target == "store" for user in _node.users): + continue + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + to_type_node_args = to_type_node.args + _node.replace_all_uses_with(to_type_node) + to_type_node.args = to_type_node_args + metrics.cpp_to_dtype_count += 1 + elif is_lowp_fp_store(_node): + ops, name, _, value_var, _ = _node.args + # No need to promote to float if it is a user of a load which are all directly stored + if value_var.target == "load" and all( + user.target == "store" for user in value_var.users + ): + continue + dtype = V.graph.get_dtype(name) + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, dtype) + ) + _node.replace_input_with(value_var, to_type_node) + metrics.cpp_to_dtype_count += 1 + elif _node.target == "reduction": + ( + ops, + dtype, + src_dtype, + reduction_type, + value, + ) = _node.args + if src_dtype in DTYPE_LOWP_FP: + # Since we always convert the load/store value to float if the tensor is bfloat16/float16. + # Therefore, the reduction should never work with bfloat16/float16 value. Hence, we update + # the bfloat16/float16 reduction by + # 1) updating the src_dtype to float + # and 2) updating the dtype to float if it is bfloat16/float16. + assert dtype in [ + torch.float, + torch.bfloat16, + torch.float16, + torch.int64, + ] + _node.args = ( + ops, + torch.float if dtype in DTYPE_LOWP_FP else dtype, + torch.float, + reduction_type, + value, + ) + elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP: + (ops, x, _) = _node.args + # The legalization always loads the BF16/FP16 tensor as FP32 for computation + # and converts back to BF16/FP16 after the computation. + # Hence, there should be no computation w/ BF16/FP16. + # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32. 
+ # Save the legalized to_dtype node for the elimination(eliminate_to_dtype step): + # 1) Eliminate the redundant to_dtype node if we have a pattern as follows: + # graph(): + # %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float)) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16)) + # Regarding the first to_dtype, it is redundant because + # the second to_type also converts to the torch.bfloat16/torch.float16. + # Hence, we remove the first to_type. + to_lowp_fp_legalized_nodes.append(_node) + _node.args = (ops, x, torch.float) + else: + pass + + def eliminate_to_dtype(sub_graph: torch.fx.Graph): + def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph): + # Eliminate the redundant to_dtype node. Let's consider a pattern as follows: + # graph(): + # %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {}) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {}) + # Regarding the first to_dtype, it is redundant because the second to_type also converts to the + # torch.float. Hence, we remove the first to_type + def _used_by_to(to_node: torch.fx.Node): + return all(usr.target == "to_dtype" for usr in to_node.users) + + all_to_nodes = [ + node for node in sub_graph.nodes if node.target == "to_dtype" + ] + all_to_nodes_and_users = [ + {node: node.users} for node in all_to_nodes if _used_by_to(node) + ] + for node_users in all_to_nodes_and_users: + for node, users in node_users.items(): + if node in sub_graph.nodes and ( + all(usr.args[-1] == node.args[-1] for usr in users) + or ( + node in to_lowp_fp_legalized_nodes + and all( + usr.args[-1] in DTYPE_LOWP_FP for usr in users + ) + ) + ): + val_node = node.all_input_nodes[-1] + node.replace_all_uses_with(val_node) + sub_graph.erase_node(node) + + # For debug mode, the graph of LoopBody will attach a new GraphModule as + # owning_module for debugging while the release mode will not. The lint will + # check whether the graph has owning_module to decide if it needs to check + # call_module. LoopBody might contain get_index as a module call. But it + # is just a function. Hence, it cannot pass the lint check for debug mode. + # We bypass the check if the owning_module is None. Eventually, we should call + # get_index via call_function but not call_module. 
+ if sub_graph.owning_module is None: + sub_graph.lint() + + _eliminate_duplicate_to_node(sub_graph) + + eliminate_to_dtype(sub_graph) + + def _legalize_lowp_fp(loop_body: ir.LoopBody): + sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values()) + for sub_block in sub_blocks: + add_to_dtype(sub_block.graph) + + if all( + isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node) + for _node in nodes + ): + # Mark the load node to load bf16/fp16 + for _node in nodes: + sub_blocks = [_node._body.root_block] + list( + _node._body.subblocks.values() + ) + for sub_block in sub_blocks: + for fx_node in sub_block.graph.nodes: + if fx_node.target in ["load", "store"]: + assert fx_node.meta + assert OptimizationContext.key in fx_node.meta + opt_ctx: OptimizationContext = fx_node.meta[ + OptimizationContext.key + ] + assert opt_ctx.dtype in DTYPE_LOWP_FP + + # Bypass the legalization as the kernel can run with bf16/fp16 directly + return + + for _node in nodes: + assert isinstance(_node, SchedulerNode) + assert isinstance(_node._body, ir.LoopBody) + node: SchedulerNode = _node + + def is_memory_copy_scheduler_node(node: SchedulerNode): + op_counts = node.read_writes.op_counts + return ( + len(op_counts) == 2 and "load" in op_counts and "store" in op_counts + ) + + should_legalize = not is_memory_copy_scheduler_node(node) + if should_legalize: + body: ir.LoopBody = node._body + _legalize_lowp_fp(body) + + def codegen_nodes(self, nodes: List[SchedulerNode]): + # Legalize BF16 node by adding to_dtype explicitly + self.legalize_lowp_fp_dtype(nodes) + self.data_type_propagation(nodes) + + assert len(nodes) >= 1 + first_node = nodes[0] + vec_dtype = ( + first_node._lowp_fp_type # type: ignore[attr-defined] + if all( + hasattr(_node, "_lowp_fp_type") + and _node._lowp_fp_type == first_node._lowp_fp_type # type: ignore[attr-defined] + for _node in nodes + ) + else torch.float + ) + + kernel_group = self.kernel_group + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + + self.set_ranges(group, reduction_group) + + def codegen_kernel(cls, *args): + with kernel_group.new_kernel(cls, *args) as kernel: + # Ugly hack to maintain the metrics kernel count since + # we only count in CppKernelProxy, not those contained in it + metrics.generated_kernel_count -= 1 + + run(kernel) + return kernel + + def run(kernel): + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + in_suffix = False + for node in nodes: + if node.group[1] in [ + (group, reduction_group), + (group + reduction_group, ()), + ]: + assert not in_suffix + node.run(vars, reduction_vars) + else: + in_suffix = True + assert node.group[1] == ( + group, + (), + ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}" + # we can fuse in some extra pointwise into the suffix + with kernel.write_to_suffix(): + node.run(vars, ()) + + scalar_kernel = codegen_kernel(CppKernel) + V.graph.removed_buffers |= scalar_kernel.removed_buffers + V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove + self.loop_nest = LoopNestWithSplit.build(scalar_kernel) + + if not self.picked_vec_isa: + return + + def select_tiling_indices(tiling_factor): + all_index = [] + for node in nodes: + rw = dependencies.extract_read_writes(node._body, *node._sizes) + all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] + contig_vars = set() + contig_vars_list = [] + non_contig_stride_const = set() + non_contig_stride_other = set() + for index in all_index: + for var in 
index.free_symbols: + if not re.search(r"^d\d+$", var.name): + continue + stride = stride_at_vec_range(index, var, tiling_factor) + if stride == 0: + continue + elif stride == 1: + contig_vars.add(int(var.name[1:])) + contig_vars_list.append(int(var.name[1:])) + elif all(s.name.startswith("s") for s in stride.free_symbols): + non_contig_stride_const.add(int(var.name[1:])) + else: + non_contig_stride_other.add(int(var.name[1:])) + contig_only = ( + contig_vars - non_contig_stride_const - non_contig_stride_other + ) + if len(contig_vars) == 0: + # no contiguous vars + return [len(self.itervars) - 1] + if contig_only: + return sorted(contig_only)[-1:] + contig_and_const_stride = ( + contig_vars & non_contig_stride_const + ) - non_contig_stride_other + contig_vars_sorted = sorted(contig_vars) + if ( + len(contig_vars_sorted) == 2 + and contig_vars_sorted[-1] in contig_and_const_stride + and contig_vars_sorted[-1] == len(self.itervars) - 1 + ): + return contig_vars_sorted + return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:] + + def select_tiling(dtype: torch.dtype = torch.float): + # TODO(jgong5): support alternative tiling factors and data types + tiling_factor = self.picked_vec_isa.nelements(dtype=dtype) + tiling_indices = select_tiling_indices(tiling_factor) + if tiling_indices: + could_vec = True + for tiling_indice in tiling_indices: + with CppVecKernelChecker( + deepcopy(self.kernel_group.args), + parallel_num_threads(), + tiling_factor, + tiling_indice, + ) as vec_checker: + run(vec_checker) + could_vec = could_vec and vec_checker.simd_vec + if not could_vec: + break + if could_vec: + if len(tiling_indices) == 1: + return [tiling_factor], tiling_indices + if len(tiling_indices) == 2: + return [tiling_factor, tiling_factor], tiling_indices + return [], [] + + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + # But the generated scalar kernel has updated these global contexts. Hence, the other kernels + # should not do this again to avoid context conflict. By now, we only control the + # config.inplace_buffers. In the future, we could maintain more contexts. + with torch._inductor.config.patch(inplace_buffers=False): + tiling_factors, tiling_indices = select_tiling(vec_dtype) + assert len(tiling_factors) == len(tiling_indices) + try: + if len(tiling_indices) == 1: + vec_kernel = codegen_kernel( + CppVecKernel, tiling_factors[0], tiling_indices[0], vec_dtype + ) + metrics.generated_cpp_vec_kernel_count += 1 + main_loop, tail_loop = self.loop_nest.split_with_tiling( + tiling_indices[0], factor=tiling_factors[0] + ) + main_loop.set_kernel(vec_kernel) + tail_loop.set_kernel(scalar_kernel) + main_loop.simd_vec = True + tail_loop.simd_omp = True + # We chop the loop into two cubes by the nelements - main loop and tail loop. + # Regarding the main loop, it is straightforward that it could be vectorized with + # nelements. But for the tail loop, it still could be vectorized. For example, + # if the nelements is 8(256bits), then the tail loop still could be vectorized + # as 4(128bits). 
+ tail_loop.simd_nelements = tiling_factors[0] // 2 + elif len(tiling_indices) == 2: + assert ( + tiling_indices[1] == len(self.itervars) - 1 + and tiling_factors[0] == tiling_factors[1] + ) + tile2d_kernel = codegen_kernel( + CppTile2DKernel, tiling_factors[0], tiling_indices, vec_dtype + ) + vec_kernel = codegen_kernel( + CppVecKernel, tiling_factors[0], tiling_indices[0], vec_dtype + ) + metrics.generated_cpp_vec_kernel_count += 2 + outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling( + tiling_indices[0], factor=tiling_factors[0] + ) + outer_tail_loop.set_kernel(scalar_kernel) + ( + inner_main_loop, + inner_tail_loop, + ) = outer_main_loop.split_with_tiling( + tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0] + ) + inner_main_loop.set_kernel(tile2d_kernel) + inner_tail_loop.set_kernel(vec_kernel) + except CppVecUnsupportedError as e: + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Disabled vectorization: %s", e) + + def codegen_loops(self, code, worksharing): + self.codegen_loops_impl(self.loop_nest, code, worksharing) + + +class ReasonFusedNodes(Enum): + SAME_VARS_REDUCE = "same_vars_reduce" + COMPATIBLE_REDUCTION = "compatible_reduction" + COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction" + + +class CppScheduling(BaseScheduling): + # ctypes limits the number of args to 1024, refer to: + # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237 + # We set a conservative threshold here. + MAX_FUSED_KERNEL_ARGS_NUM = 500 + + def __init__(self, scheduler): + self.scheduler = scheduler + self.get_kernel_group() + self._ready_to_flush = False + + def _set_flush_status(self, status: bool): + self._ready_to_flush = status + + def group_fn(self, sizes): + return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + + def get_kernel_group(self): + from .cpp_wrapper_cpu import CppWrapperCpu + + self.kernel_group: Union[CppWrapperKernelGroup, KernelGroup] + if isinstance(V.graph.wrapper_code, CppWrapperCpu): + self.kernel_group = CppWrapperKernelGroup() + else: + self.kernel_group = KernelGroup() + + def fuse(self, node1, node2): + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + else: + if ( + self._why_fuse_nodes(node1, node2) + == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + ): + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + assert reduce1 == () and reduce2 == (), (reduce1, reduce2) + + def get_indexing_ranges_exprs(node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0 + return get_indexing_ranges_exprs(node.snodes[0]) + else: + assert isinstance(node, SchedulerNode) + comp_buffer = node.node + assert isinstance(comp_buffer, ir.ComputedBuffer) + _, body, _ = comp_buffer.get_default_sizes_body() + return body.var_ranges, list(body.indexing_exprs.values()) + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + assert isinstance(node_to_recomp, SchedulerNode) + + ref_node = node2 if len(vars1) < len(vars2) else node1 + + extra_indexing_constraints = get_indexing_ranges_exprs(ref_node) + + node_to_recomp.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints + ) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + assert vars1 == vars2, (vars1, vars2) + + return FusedSchedulerNode.fuse(node1, node2) + + def 
_why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]: + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + if vars1 == vars2 and reduce1 == reduce2: + return ReasonFusedNodes.SAME_VARS_REDUCE + if reduce1 == () and vars1 == vars2 + reduce2: + return ReasonFusedNodes.COMPATIBLE_REDUCTION + if self._can_fuse_nodes_with_compatible_ranges(node1, node2): + return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + # TODO(jansel): allow fusion pointwise (vars1, ()) suffix? + return None + + def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): + # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges + # e.g. (s0, s1, s2) and (s0 * s1 * s2) + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + c1 = reduce1 == () and reduce2 == () + c2 = math.prod(vars1) == math.prod(vars2) + c3 = len(vars1) == 1 or len(vars2) == 1 + if not (c1 and c2 and c3): + return False + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + ref_node = node2 if len(vars1) < len(vars2) else node1 + + # We can not recompute sizes and body for nodes other than SchedulerNode + # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode + if isinstance(node_to_recomp, FusedSchedulerNode): + return False + + def get_buffer(node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0 + # use the last scheduler node from the list as it has the most + # relevant indexing expressions + return get_buffer(node.snodes[-1]) + else: + assert isinstance(node, SchedulerNode) + return node.node + + ref_node_buffer = get_buffer(ref_node) + if isinstance(ref_node_buffer, ir.TemplateBuffer): + return False + + assert isinstance(ref_node_buffer, ir.ComputedBuffer) + + # It may happen that node1 and node2 compatible number of elements + # but different original ranges, for example: + # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2} + # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details + # TODO: we can fix if it allows us to CSE at least one of the variables + var_ranges1 = ref_node_buffer.get_read_writes().var_ranges + var_ranges2 = node_to_recomp.node.get_read_writes().var_ranges + if var_ranges1 != var_ranges2: + return False + + return True + + def _can_fuse_horizontal_impl(self, node1, node2): + assert isinstance(node1, (FusedSchedulerNode, SchedulerNode)) + assert isinstance(node2, (FusedSchedulerNode, SchedulerNode)) + return self._why_fuse_nodes(node1, node2) is not None + + def can_fuse_horizontal(self, node1, node2): + if ( + len(node1.get_nodes()) + len(node2.get_nodes()) + > config.cpp.max_horizontal_fusion_size + ): + return False + + return self._can_fuse_horizontal_impl(node1, node2) + + def can_fuse_vertical(self, node1, node2): + return self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() + + def codegen_nodes(self, nodes: List[SchedulerNode]): + """ + Turn an set of pre-fused nodes into a C++ kernel. 
+ """ + kernel_group = self.kernel_group + + cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy.codegen_nodes(nodes) + + kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) + + args_num = self._get_scheduled_num_args() + if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: + self._set_flush_status(True) + + def _get_scheduled_num_args(self): + return self.kernel_group.get_num_args() + + def ready_to_flush(self): + return self._ready_to_flush + + def codegen_sync(self): + pass + + def flush(self): + self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) + self.get_kernel_group() + self._set_flush_status(False) + + +class KernelGroup: + def __init__(self): + super().__init__() + self.args = KernelArgs() + self.loops_code = BracesBuffer() + self.ws = WorkSharing(self.loops_code) + self.stack = contextlib.ExitStack() + self.stack.enter_context(self.ws) + self.scheduled_nodes = [] + + def new_kernel(self, cls, *args): + return cls(self.args, parallel_num_threads(), *args) + + def finalize_kernel(self, new_kernel, nodes): + self.scheduled_nodes += nodes + code = self.loops_code + ws = self.ws + new_kernel.codegen_loops(code, ws) + + def get_num_args(self): + arg_defs, call_args, arg_types = self.args.cpp_argdefs() + args_num = len(arg_defs) + return args_num + + def codegen_define_and_call(self, wrapper): + self.stack.close() + if not self.scheduled_nodes: + return + + fused_name = ( + get_fused_kernel_name(self.scheduled_nodes, config.cpp.descriptive_names) + if config.cpp.descriptive_names + else "" + ) + kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) + arg_defs, call_args, arg_types = self.args.cpp_argdefs() + arg_defs = ",\n".ljust(25).join(arg_defs) + code = BracesBuffer() + # TODO: support kernel profile on other platforms + enable_kernel_profile = ( + config.cpp.enable_kernel_profile and sys.platform == "linux" + ) + if enable_kernel_profile: + code.writelines(["#include "]) + kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" + code.writeline(codecache.cpp_prefix()) + + code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})') + with code.indent(): + if enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + code.writelines( + [ + f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef({{}}));' + ] + ) + for old, new in self.args.aliases(): + code.writeline(f"auto {old} = {new};") + code.splice(self.loops_code) + + codecache_def = IndentedBuffer() + if not V.graph.cpp_wrapper: + codecache_def.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") + codecache_def.splice(code) + if not V.graph.cpp_wrapper: + codecache_def.writeline("''')") + + codecache_str = codecache_def.getvalue() + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. 
+ codecache_str = codecache_str.replace("#pragma CMT", "//") + wrapper.define_kernel(kernel_name, codecache_str, cuda=False) + # generate the code to call this + wrapper.generate_kernel_call( + kernel_name, call_args, cuda=False, arg_types=arg_types + ) + + +class CppWrapperKernelGroup(KernelGroup): + def __init__(self): + super().__init__() + self.args = CppWrapperKernelArgs() + + +class WorkSharing: + def __init__(self, code): + self.code = code + self.in_parallel = False + self.num_threads = None + self.stack = contextlib.ExitStack() + + def parallel(self, threads): + if self.in_parallel and threads != self.num_threads: + # wrong number of threads + self.close() + if not self.in_parallel: + self.num_threads = threads + self.in_parallel = True + if config.cpp.dynamic_threads: + self.code.writeline("#pragma omp parallel") + else: + self.code.writeline(f"#pragma omp parallel num_threads({threads})") + self.stack.enter_context(self.code.indent()) + + def single(self): + if self.in_parallel: + self.code.writeline("#pragma omp single") + return self.in_parallel + + def close(self): + self.stack.close() + self.in_parallel = False + + def __enter__(self): + self.stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stack.__exit__(exc_type, exc_val, exc_tb) + + +@dataclasses.dataclass +class LoopLevel: + var: Optional[sympy.Expr] = None + size: Optional[sympy.Expr] = None + offset: sympy.Expr = sympy.Integer(0) + steps: sympy.Expr = sympy.Integer(1) + parallel: int = 0 + simd_omp: bool = False + simd_vec: bool = False + collapsed: bool = False + reduction_var_map: Optional[Dict[str, str]] = None + parent: Optional["LoopLevel"] = None + # the next inner level of the loop, empty if it is inner-most + # contains >1 LoopLevel if the inner level of loop is split + inner: List["LoopLevel"] = dataclasses.field(default_factory=list) + # kernel assigned to this loop level, only valid when it is a leaf + kernel: Optional[CppKernel] = None + + def __post_init__(self): + # Regarding the C++/OpenMP backend, `codecache.pick_vec_isa()` to check + # vectorization ISA is a time-consuming and one-shot operation. It leads + # to taking a longer time to import `codegen.cpp` package because the + # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while + # the decorator will invoke `codecache.pick_vec_isa()` to initialize the + # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation + # overhead to the Triton backend. Therefore, we moved the `simd_nelements` to + # `__post_init__` + picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() + self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 + + def get_kernels(self) -> List[CppKernel]: + """Get all kernel objects under this loop level""" + if self.kernel: + return [self.kernel] + kernels = [] + for loop in self.inner: + kernels += loop.get_kernels() + return kernels + + def set_kernel(self, kernel: CppKernel): + """ + Set the kernel under this loop level. No split is allowed under + this loop level. 
+ """ + if not self.inner: + self.kernel = kernel + loop: Optional[LoopLevel] = self + assert loop is not None + if loop.is_reduction(): + loop.reduction_var_map = kernel.reduction_var_map.copy() + loop = loop.parent + while loop is not None and loop.is_reduction(): + assert loop.reduction_var_map is not None + loop.reduction_var_map.update(kernel.reduction_var_map) + loop = loop.parent + return + assert len(self.inner) == 1 + self.inner[0].set_kernel(kernel) + + def get_loops_at(self, depth) -> List["LoopLevel"]: + if depth == 0: + return [self] + else: + loops = [] + for loop in self.inner: + loops += loop.get_loops_at(depth - 1) + return loops + + def is_reduction(self): + return bool(self.reduction_var_map) + + def split_with_tiling(self, depth, factor): + def clone_inner(): + inner = [] + if self.inner: + for loop in self.inner: + inner.append(loop.clone()) + return inner + + def do_split_with_tiling(): + sympy_factor = sympy.Integer(factor) + + offset = FloorDiv(self.size, sympy_factor) * sympy_factor + main_loop = LoopLevel(self.var, offset) + main_loop.steps = sympy_factor + main_loop.parallel = self.parallel + main_loop.collapsed = False + main_loop.reduction_var_map = self.reduction_var_map + main_loop.inner = clone_inner() + if main_loop.inner: + for loop in main_loop.inner: + loop.parent = main_loop + + tail_loop = LoopLevel(self.var, self.size) + tail_loop.offset = offset + tail_loop.parallel = self.parallel + tail_loop.collapsed = False + tail_loop.reduction_var_map = self.reduction_var_map + tail_loop.inner = clone_inner() + if tail_loop.inner: + for loop in tail_loop.inner: + loop.parent = tail_loop + + return main_loop, tail_loop + + if depth == 0: + main_loop, tail_loop = do_split_with_tiling() + parent = self.parent + if parent: + parent.inner = [main_loop, tail_loop] + main_loop.parent = parent + tail_loop.parent = parent + return main_loop, tail_loop + else: + assert len(self.inner) == 1 + return self.inner[0].split_with_tiling(depth - 1, factor) + + def clone(self): + loop = copy(self) + loop.inner = [] + if self.inner: + for inner_loop in self.inner: + inner_loop_clone = inner_loop.clone() + inner_loop_clone.parent = loop + loop.inner.append(inner_loop_clone) + loop.kernel = deepcopy(self.kernel) + return loop + + def lines(self): + offset_expr = cexpr_index(self.offset) + size_expr = cexpr_index(self.size) + if config.cpp.no_redundant_loops and offset_expr == size_expr: + return None + if self.reduction_var_map: + reduction = " " + " ".join( + f"reduction({RTYPE_TO_CPP[rtype]}:{var})" + for var, rtype in self.reduction_var_map.items() + ) + else: + reduction = "" + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) + if self.parallel: + # TODO(jansel): look into chunk size and other schedules + line1 = f"#pragma omp for{reduction} " + if self.parallel > 1: + line1 += f" collapse({self.parallel})" + if self.simd_omp: + line1 = line1.replace(" for ", f" for {simd}") + elif self.simd_vec: + line1 = "" + elif self.simd_omp: + line1 = f"#pragma omp {simd}{reduction}" + elif not self.reduction_var_map and codecache.is_gcc(): + line1 = "#pragma GCC ivdep" + else: + line1 = "" + offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" + size_str = f"{self.var}<{size_expr}" + steps_str = f"{self.var}+={cexpr_index(self.steps)}" + line2 = f"for({offset_str}; {size_str}; {steps_str})" + if self.collapsed or not line1: + return [line2] + return [line1, line2] + + +@dataclasses.dataclass +class LoopNestWithSplit: + """ + 
A loop-nest like structure but with some loop level split along + the loop range into the main tiling loop and the tail. It is built + with the `build` method as a loop nest and then split with + `split_with_tiling` at some depth. + + A typical case is for vectorization where we typically split at the inner-most + loop level. A more complicated case is 2D tiling where we split at + both inner-most and outer levels. + """ + + root: Optional[List[LoopLevel]] = None + kernel: Optional[CppKernel] = None + + @staticmethod + def build(kernel: CppKernel): + """Build a LoopNest with the given `kernel` as the leaf""" + itervars = kernel.itervars + ranges = kernel.ranges + reduction_depth = kernel.reduction_depth + assert reduction_depth is not None + + root: List[LoopLevel] = [] + levels: List[LoopLevel] = root + loop: Optional[LoopLevel] = None + for loop_idx, (var, size) in enumerate(zip(itervars, ranges)): + loop = LoopLevel(var, size, parent=loop) + if loop_idx >= reduction_depth: + loop.reduction_var_map = kernel.reduction_var_map.copy() + levels.append(loop) + levels = loop.inner + loop_nest = LoopNestWithSplit(root) + if loop: + loop.kernel = kernel + else: + loop_nest.kernel = kernel + return loop_nest + + def __bool__(self): + return bool(self.root) + + def get_loops_at(self, depth) -> List[LoopLevel]: + """Get all the loop levels at the given `depth` (most outer loop has depth 0)""" + loops: List[LoopLevel] = [] + assert self.root is not None + for loop in self.root: + loops += loop.get_loops_at(depth) + return loops + + @cache_on_self + def max_parallel_depth(self): + """ + Maximal allowed depth for parallelism: + 1) Levels without splitting and + 2) All reduction or non-reduction levels + When the loop is split at the top level, the max depth is 1. + """ + max_depth = 0 + assert self.root is not None + loops = self.root + if len(loops) > 1: + return 1 + is_reduction = loops[0].is_reduction() if loops else False + while len(loops) == 1 and loops[0].is_reduction() == is_reduction: + max_depth += 1 + loops = loops[0].inner + return max_depth + + def is_reduction_only(self): + """ + Whether all the loops are for reduction. Reduction loops + are always the inner most ones. + """ + return ( + self.root is not None and len(self.root) > 0 and self.root[0].is_reduction() + ) + + def mark_parallel(self, par_depth): + assert ( + par_depth <= self.max_parallel_depth() + ), "Parallel depth cannot exceed the maximal allowed parallel depth" + assert self.root is not None + loops = self.root + for loop in loops: + loop.parallel = par_depth + for i in range(1, par_depth): + loops = loops[0].inner + loops[0].collapsed = True + + def split_with_tiling(self, depth, factor): + """ + Split the loop into main and tail loops at given `depth` so that the range + of the main loop has range `floor_div(range, factor) * factor` and + the tail loop handles the remainder. The main loop is tiled + according to the `factor`. 
+ """ + loops = self.get_loops_at(depth) + assert len(loops) == 1 + split_loops = loops[0].split_with_tiling(0, factor) + if depth == 0: + self.root = split_loops + return split_loops diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac8aec556a5624cbc1b89e055d53bf1524df0b6d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbacf689942f47369c1c904f4da0c00cfeaa0897 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3533402f0c99a510994be363fc35250091f6dcd Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21b6dffff59c94df12dbc7894ffaf1bb570f0367 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a3b6c0c095a9b12719a210aa4ed13ba54147a13 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9dc2d8a7e149e1d9023a4aa45f0b6ba166e2d1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -0,0 +1,212 @@ +import logging +from typing import cast, List 
+ +from ...._dynamo.utils import counters + +from ... import config, ir +from ...codecache import code_hash, get_path +from ...ir import ComputedBuffer, CUDATemplateBuffer, Pointwise +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import IndentedBuffer + +from .cutlass_epilogue_gen import CUTLASSEVTOpNotImplementedError + +log = logging.getLogger(__name__) + + +class CUDACPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for CUDA C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and CUDA C++ specific template code generation. + """ + + def __init__(self, scheduler: Scheduler): + super().__init__() + self.scheduler = scheduler + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def is_cuda_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, CUDATemplateBuffer + ) + + def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template( + node.get_template_node() + ) + + def _can_fuse_epilogue_impl( + self, + cuda_template_buffer: CUDATemplateBuffer, + epilogue_nodes: List[ir.IRNode], + additional_node: ir.IRNode, + ) -> bool: + """ + Check if the given node can be fused with the epilogue. At the moment, Kernels + support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes. + + Args: + cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer + epilogue_nodes : List[ir.Buffer]: The list of already fused epilogue nodes. + additional_node: The ir.Buffer node to be checked if it can be fused with the epilogue. + Returns: + - bool: True if the given node can be fused with the epilogue, False otherwise. + + """ + if not isinstance(cuda_template_buffer, CUDATemplateBuffer): + return False + if not cuda_template_buffer.template.can_fuse_epilogue: + # The used GEMM op does not support fusing epilogues + return False + if not isinstance(additional_node, ComputedBuffer): + return False + if not isinstance(additional_node.data, Pointwise): + return False + # We can fuse a Pointwise op that depends on the last fused epilogue node + # if any. 
If there is no epilogue node yet, it needs to depend on the template + # node + node_name = additional_node.get_computed_buffer_name() + if node_name is None: + return False + + if len(epilogue_nodes) == 0: + if cuda_template_buffer.name not in additional_node.get_read_names(): + return False + else: + last_epilogue_node = epilogue_nodes[-1] + assert isinstance(last_epilogue_node, ir.ComputedBuffer) # for mypy + last_epilogue_name = ( + last_epilogue_node.name + if last_epilogue_node.name is not None + else last_epilogue_node.data.name # type: ignore[attr-defined] + ) + if last_epilogue_name not in additional_node.get_read_names(): + return False + if additional_node.layout != cuda_template_buffer.layout: + return False + try: + from torch._inductor.codegen.cuda.cutlass_epilogue_gen import ( + CutlassEVTEpilogueArgumentFormatter, + CutlassEVTEpilogueTypeFormatter, + ) + + CutlassEVTEpilogueTypeFormatter.ir_to_evt_string( + cast(str, cuda_template_buffer.name), "anything", [additional_node] + ) + CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string( + cast(str, cuda_template_buffer.name), [additional_node] + ) + except CUTLASSEVTOpNotImplementedError as e: + not_implemented_op = str(e) + if not_implemented_op.startswith("_op_"): + not_implemented_op = not_implemented_op[4:] + log.warning( + f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}, likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950 + ) + return False + else: + # Likely due to unsupported dtype. + log.warning( + f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}. Reason: {not_implemented_op}" # noqa: G004, B950 + ) + return False + return True + + @staticmethod + def _unwrap_epilogue_nodes(fused_node: FusedSchedulerNode) -> List[ir.IRNode]: + nodes = fused_node.get_nodes() + template_node = fused_node.get_template_node() + nodes.remove(template_node) + return [n.node for n in nodes] + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode): + return self._can_fuse_epilogue_impl( + cast(CUDATemplateBuffer, node1.node), [], node2.node + ) + elif self.is_cuda_cpp_fused_template(node1) and isinstance( + node2, SchedulerNode + ): + fnode1 = cast(FusedSchedulerNode, node1) + return self._can_fuse_epilogue_impl( + fnode1.get_template_node().node, + self._unwrap_epilogue_nodes(fnode1), + node2.node, + ) + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.cuda(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''', 'so')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, 
compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + """ + Codegen a CUDA template, possibly with fused epilogues + """ + counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cuda_cpp_template( + template_node + ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node) + epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes] + assert all( + isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes + ), "Epilogue nodes must all be instances of ir.ComputedBuffer" + kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule) + kernel.call_kernel(kernel_name, ctb, epilogue_ir_nodes) + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py new file mode 100644 index 0000000000000000000000000000000000000000..6171921173e9717853eafa9497292a7dbdaaf903 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py @@ -0,0 +1,45 @@ +import functools +import logging +from typing import Optional + +import torch + +from ... 
import config + +log = logging.getLogger(__name__) + + +def get_cuda_arch() -> Optional[str]: + try: + cuda_arch = config.cuda.arch + if cuda_arch is None: + # Get Compute Capability of the first Visible device + major, minor = torch.cuda.get_device_capability(0) + return str(major * 10 + minor) + return str(cuda_arch) + except Exception as e: + log.error("Error getting cuda arch: %s", e) + return None + + +def get_cuda_version() -> Optional[str]: + try: + cuda_version = config.cuda.version + if cuda_version is None: + cuda_version = torch.version.cuda + return cuda_version + except Exception as e: + log.error("Error getting cuda version: %s", e) + return None + + +@functools.lru_cache(None) +def nvcc_exist(nvcc_path: str = "nvcc") -> bool: + if nvcc_path is None: + return False + import subprocess + + res = subprocess.call( + ["which", nvcc_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + return res == 0 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py new file mode 100644 index 0000000000000000000000000000000000000000..258e54266477c6b3cab695cb4214f7c71f5315d5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py @@ -0,0 +1,242 @@ +import functools +import itertools +import logging +from typing import List, Optional +from unittest.mock import patch + +import sympy + +import torch +from ...autotune_process import CUDABenchmarkRequest, TensorMeta +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout + +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel + +log = logging.getLogger(__name__) + + +class CUDATemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes: List[Buffer], + layout: Layout, + input_reorder: Optional[List[int]] = None, + ): + """ + + Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the CUDATemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer("buf_out", layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate( # type: ignore[override] + self, + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller + may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. 
+ """ + kernel_name = f"cuda_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), CUDATemplateKernel( + kernel_name=kernel_name, + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + + kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}" + + # create the BenchmarkRequest + bmreq = CUDABenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: CUDATemplateBuffer, + epilogue_nodes: Optional[List[IRNode]] = None, + ): + kernel = CUDATemplateKernel( + kernel_name="KERNEL_NAME", + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return CUDATemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + kwargs, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + using bfloat16 = nv_bfloat16; + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + +class CUTLASSTemplate(CUDATemplate): + """ + CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the + CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels. 
+ """ + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cute/tensor.hpp" + #include "cutlass/cutlass.h" + #include "cutlass/numeric_types.h" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/host_tensor.h" + #include "cutlass/util/reference/host/tensor_fill.h" + #include "cutlass/util/reference/device/tensor_fill.h" + #include "cutlass/util/device_memory.h" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using namespace cute; + #define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + throw std::runtime_error(msg); \\ + } \\ + } + + // Used as pass-through functor in EVT just for type casting / rounding + template + struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } + }; + + """ + ) + return res + + def cute_int(self, int_str: str, var_name: str) -> str: + res = "" + if int_str in {"1", "1L"}: + res = "cute::Int<1>{}" + else: + res = int_str + + return f"{res} /* {var_name} */" + + _DTYPE_TO_CUTLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "cutlass::half_t", + torch.int32: "int", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "cutlass::bfloat16_t", + } + + def cutlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4828aab466ab519dbfcd018932f102da6e3208 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py @@ -0,0 +1,360 @@ +from typing import Dict, List +from unittest.mock import patch + +import sympy + +import torch._inductor.virtualized as virtualized +from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise +from torch._inductor.utils import IndentedBuffer, sympy_str + + +# Used as a magic string to indicate an unsupported sympy expression +# became part of generated C++ code. +_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]" + + +def _arg_str(a): + if isinstance(a, sympy.Expr): + # If this return value containting the _MAGIC_SYMPY_ERROR_STRING + # is used as part of the final generated C++ code, + # a CUTLASSEVTOpNotImplementedError is raised to indicate that + # the op could not be converted to a valid EVT expression. + return f"{_MAGIC_SYMPY_ERROR_STRING}('{sympy_str(a)}')" + return str(a) + + +class CUTLASSEVTOpNotImplementedError(NotImplementedError): + pass + + +class CutlassEVTEpilogueTypeFormatter: + """ + Codegen class, which provides an entry point to generate + Cutlass "Epilogue Visitor Tree" (EVT) functor declarations. + + See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder + for more about EVTs and how they are declared and used to generate. + + Notes: + * Used by CUTLASSGemmTemplate. 
+ * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. + + + """ + + def __init__(self, accumulator_node_name, evt_type_name): + """ + + Initialize an instance of CutlassEVTEpilogueTypeFormatter. + + Parameters: + - accumulator_node_name (str): The name of the output Buffer for the GEMM operation in the original (unfused) + IR graph. + - evt_type_name (str): The output name of the EVT type we are generating. + + """ + self.accumulator_node_name = accumulator_node_name + self.output = IndentedBuffer(0) + self.var_counter = 0 + self.evt_type_name = evt_type_name + self.aliases = dict() + + @staticmethod + def ir_to_evt_string( + template_output_node_name: str, + evt_type_name: str, + epilogue_nodes: List[IRNode], + ): + """ + Formats IR nodes into a string representation compatible with Cutlass EVT format. + + Args: + template_output_node_name (str): The name of the template output node. + evt_type_name (str): The name of the EVT type. + epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be + ComputedBuffer nodes wrapping Pointwise nodes. + + Returns: + A string representation of the IR nodes formatted according to the Cutlass EVT format. + """ + formatter = CutlassEVTEpilogueTypeFormatter( + template_output_node_name, evt_type_name + ) + + with virtualized.V.set_ops_handler(formatter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + for node in epilogue_nodes: + if isinstance(node, ComputedBuffer): + pnode = node.data + else: + raise RuntimeError( + "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer" + ) + assert isinstance(pnode, Pointwise) + index = pnode._index(pnode.ranges) + result = pnode.inner_fn(index) + # each epilogue node results in a single "using" statement and may refer to the previous steps by name + formatter.aliases[node.name] = result + res = formatter.getvalue(result) # type: ignore[possibly-undefined] + if _MAGIC_SYMPY_ERROR_STRING in res: + raise CUTLASSEVTOpNotImplementedError( + "sympy / indexing expressions not yet supported in EVT fusion" + ) + else: + return res + + def __getattr__(self, name): + """ + Resolve V.ops. calls, after this instance has been installed as V.ops handler. + """ + + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} + fn = getattr(self, f"_op_{name}") + line = fn(*fargs, **fkwargs) + self.var_counter += 1 + varname = f"EVT_expr_{self.var_counter}" + # replace line with a new variable name + self.output.writeline(f"using {varname} = {line};") + return varname + + if name.startswith("_"): + raise CUTLASSEVTOpNotImplementedError(name) + if hasattr(self, f"_op_{name}"): + return inner + else: + raise CUTLASSEVTOpNotImplementedError(name) + + def _op_load(self, name, index_expr): + # Load an input to an operation. Might be the output of the matmul, the result + # of a previous epilogue node, a constant or (TODO) an auxiliary input. 
+ if name == self.accumulator_node_name: + return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */" + elif name in self.aliases: + return self.aliases[name] + else: + # return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */" + raise CUTLASSEVTOpNotImplementedError( + f"Operand {name} not found. Auxiliary inputs not supported yet." + ) + + def _op_constant(self, value, dtype): + # Load a constant + if str(dtype) in ("torch.float16", "torch.float32"): + return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast /* value={value}, dtype={dtype} */" + else: + raise CUTLASSEVTOpNotImplementedError( + f"Unsupported dtype for constant: {dtype}" + ) + + def _cutlass_binary_functional_op(self, op, a, b): + # Perform a named operation on two inputs + # see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops + return f"cutlass::epilogue::fusion::Sm90EVT,{a},{b}>" # noqa: B950 + + def _convert_to_output_dtype(self, a): + # Convert the final output to the dtype of the output buffer + return f"cutlass::epilogue::fusion::Sm90EVT,{a}>" # noqa: B950 + + def _op_to_dtype(self, a, *args, **kwargs): + # no-op in our case, since we convert to the output dtype at the end and convert everything to the accumulator + # dtype. + # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible + # throughout the fusion chain. + return a # noqa: B950 + + def _op_mul(self, a, b): + return self._cutlass_binary_functional_op("multiplies", a, b) + + def _op_div(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_truediv(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_ge(self, a, b): + return self._cutlass_binary_functional_op("greater_equal", a, b) + + def _op_add(self, a, b): + return self._cutlass_binary_functional_op("plus", a, b) + + def _op_sub(self, a, b): + return self._cutlass_binary_functional_op("minus", a, b) + + def _op_minimum(self, a, b): + return self._cutlass_binary_functional_op("minimum", a, b) + + def _op_maximum(self, a, b): + return self._cutlass_binary_functional_op("maximum", a, b) + + def _op_relu(self, a): + const_zero = self._op_constant(0.0, "torch.float32") + return f"cutlass::epilogue::fusion::Sm90EVT,{a}, {const_zero}>" # noqa: B950 + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise CUTLASSEVTOpNotImplementedError() + + # Add more ops here... + def getvalue(self, result) -> str: + # Return final result + dtype_converted_expr = self._convert_to_output_dtype( + f"EVT_expr_{self.var_counter}" + ) + self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};") + return self.output.getvalue() + + +class CutlassEVTEpilogueArgumentFormatter: + """ + Codegen class, which provides an entry point to generate + Cutlass "Epilogue Visitor Tree" (EVT) Argument initializers + + See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder + for more about EVTs and how they are declared and used to generate. + + Notes: + * Used by CUTLASSGemmTemplate. + * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. 
+ + + """ + + def __init__(self, accumulator_node_name: str): + """ + + Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly. + Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method. + + Args: + accumulator_node_name (str): The name of the accumulator node which should contain + the Matmul result before fusion according to the IR graph. + """ + self.accumulator_node_name: str = accumulator_node_name # + self.output: IndentedBuffer = IndentedBuffer(0) # The output buffer for codegen + self.var_counter: int = ( + 0 # used to generate variable names, incremented for each new variable + ) + self.aliases: Dict[str, str] = dict() # Aliases for subexpression functors + + @staticmethod + def ir_to_evt_argument_string( + template_output_node_name: str, + epilogue_nodes: List[IRNode], + ) -> str: + formatter = CutlassEVTEpilogueArgumentFormatter( + template_output_node_name, + ) + + with virtualized.V.set_ops_handler(formatter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + for node in epilogue_nodes: + assert isinstance(node, ComputedBuffer) + pnode = node.data + assert isinstance(pnode, Pointwise) + index = pnode._index(pnode.ranges) + result = pnode.inner_fn(index) + # each epilogue node results in a single "using" statement and may refer to the previous steps by name + if node.name is not None: + formatter.aliases[node.name] = result + + res: str = formatter.getvalue(result) # type: ignore[possibly-undefined] + if _MAGIC_SYMPY_ERROR_STRING in res: + raise CUTLASSEVTOpNotImplementedError( + "sympy / indexing expressions not yet supported in EVT fusion" + ) + else: + return res + + def __getattr__(self, name): + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} + fn = getattr(self, f"_op_{name}") + line = fn(*fargs, **fkwargs) + return line + + if name.startswith("_"): + raise CUTLASSEVTOpNotImplementedError(name) + + if hasattr(self, f"_op_{name}"): + return inner + else: + raise CUTLASSEVTOpNotImplementedError(name) + + def _op_load(self, name, index_expr): + if name == self.accumulator_node_name: + return "{}" + elif name in self.aliases: + return self.aliases[name] + else: + raise CUTLASSEVTOpNotImplementedError( + f"Operand {name} not found. Auxiliary inputs not supported yet." 
+ ) + + def _op_constant(self, value, dtype): + if str(dtype) in ("torch.float16", "torch.float32"): + return "{ static_cast(" + str(value) + ") }" + else: + raise CUTLASSEVTOpNotImplementedError( + f"Unsupported dtype for constant: {dtype}" + ) + + def _cutlass_binary_functional_op(self, op, a, b): + return f"{{ /*{op}: */ {a}, {b} }}" + + def _op_mul(self, a, b): + return self._cutlass_binary_functional_op("multiplies", a, b) + + def _op_div(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_truediv(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_ge(self, a, b): + return self._cutlass_binary_functional_op("greater_equal", a, b) + + def _op_add(self, a, b): + return self._cutlass_binary_functional_op("plus", a, b) + + def _op_sub(self, a, b): + return self._cutlass_binary_functional_op("minus", a, b) + + def _op_minimum(self, a, b): + return self._cutlass_binary_functional_op("minimum", a, b) + + def _op_maximum(self, a, b): + return self._cutlass_binary_functional_op("maximum", a, b) + + def _op_relu(self, a): + const_zero = self._op_constant(0.0, "torch.float32") + return "{" + str(a) + ", " + const_zero + "}" + + def _op_to_dtype(self, a, dtype, src_dtype=None): + # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible + # throughout the fusion chain. + assert dtype in ( + "torch.float32", + "torch.float16", + ), f"Unsupported dtype: {dtype}" + assert src_dtype in ( + None, + "torch.float32", + "torch.float16", + ), f"Unsupported source dtype: {src_dtype}" + return a + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise CUTLASSEVTOpNotImplementedError() + + def getvalue(self, result) -> str: + return "{" + str(result) + "}" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..2a386a114e86f9556f0302151b42ae8c37b7e806 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -0,0 +1,186 @@ +from ..cutlass_utils import try_import_cutlass + +if try_import_cutlass(): + import enum + + from cutlass_library.library import * # noqa: F401, F403 + from cutlass_library.gemm_operation import * # noqa: F401, F403 + + # copied / modified from original at + # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658 + # to support EVT similar to + # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L315C69-L315C69 # noqa: B950 + class EmitGemmUniversal3xInstanceWithEVT: + """Responsible for emitting a CUTLASS 3.x template definition""" + + def __init__(self, operation_suffix=""): + self.operation_suffix = 
operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + ] + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > + """ + self.gemm_template = """ + using EpilogueScheduleType = ${epilogue_schedule}; + static_assert(cute::is_same_v || + cute::is_same_v, + "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue"); + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementAcc = ${element_accumulator}; + using ElementD = ${element_d}; + ${epilogue_functor}; + using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + EpilogueScheduleType, + ${operation_name}_epilogue_functor + >::CollectiveOp; + + using ${operation_name}_mainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape, + ${stages}, + ${kernel_schedule} + >::CollectiveOp; + + // Gemm operator ${operation_name} + using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler}>; + + // Define named type + struct ${operation_name} : + public ${operation_name}_base { }; + + """ + + # + def instance_template(self): + return """ + ${compile_guard_start} + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); + ${compile_guard_end} + """ + + # + def emit(self, operation): + tile_shape = operation.tile_description.tile_shape + warp_count = operation.tile_description.warp_count + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" + else: + stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout" # noqa: B950 + warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)] + + ( + instance_layout_A, + instance_layout_B, + instance_layout_C, + instance_layout_D, + ) = ( + operation.A.layout, + operation.B.layout, + operation.C.layout, + operation.D.layout, + ) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), # type: ignore[name-defined] + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], # type: ignore[name-defined] + } + epilogue_functor = SubstituteTemplate( # type: ignore[name-defined] + self.builtin_epilogue_functor_template, values + ) + + elif callable(operation.epilogue_functor): + epilogue_functor = 
operation.epilogue_functor( + operation.procedural_name() + "_epilogue_functor" + ) + else: + epilogue_functor = str(operation.epilogue_functor) + # + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], # type: ignore[name-defined] + "layout_a": LayoutTag[instance_layout_A], # type: ignore[name-defined] + "element_b": DataTypeTag[operation.B.element], # type: ignore[name-defined] + "layout_b": LayoutTag[instance_layout_B], # type: ignore[name-defined] + "element_c": DataTypeTag[operation.C.element], # type: ignore[name-defined] + "layout_c": LayoutTag[instance_layout_C], # type: ignore[name-defined] + "element_d": DataTypeTag[operation.D.element], # type: ignore[name-defined] + "layout_d": LayoutTag[instance_layout_D], # type: ignore[name-defined] + "element_accumulator": DataTypeTag[operation.accumulator_type()], # type: ignore[name-defined] + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], # type: ignore[name-defined] # noqa: B950 + "arch": "cutlass::arch::Sm%d" % operation.arch, + "tile_shape_m": str(operation.tile_description.tile_shape[0]), + "tile_shape_n": str(operation.tile_description.tile_shape[1]), + "tile_shape_k": str(operation.tile_description.tile_shape[2]), + "cluster_m": str(operation.tile_description.cluster_shape[0]), + "cluster_n": str(operation.tile_description.cluster_shape[1]), + "cluster_k": str(operation.tile_description.cluster_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), # type: ignore[name-defined] + "epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]), # type: ignore[name-defined] + "epilogue_functor": epilogue_functor, + "stages": stage_count_string, + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "align_c": str(operation.C.alignment), + "align_d": str(operation.C.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], # type: ignore[name-defined] + "transform_b": ComplexTransformTag[operation.B.complex_transform], # type: ignore[name-defined] + "math_operation": MathOperationTag[ # type: ignore[name-defined] + operation.tile_description.math_instruction.math_operation + ], + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), # type: ignore[name-defined] + "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]), # type: ignore[name-defined] + } + + return SubstituteTemplate(self.gemm_template, values) # type: ignore[name-defined] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..93a8c08b6a0f25a05f4c194abcef91205f4fc748 --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -0,0 +1,18 @@ +from ..common import DeviceOpOverrides, register_device_op_overrides + + +class CUDADeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return f"from torch._C import _cuda_getCurrentRawStream as {name}" + + def set_device(self, device_idx): + return f"torch.cuda.set_device({device_idx})" + + def synchronize(self): + return "torch.cuda.synchronize()" + + def device_guard(self, device_idx): + return f"torch.cuda._DeviceGuard({device_idx})" + + +register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..f82e01daad63c818b10f30cf0876c669359989c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -0,0 +1,75 @@ +from typing import List + +from ..scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode +from .cuda.cuda_cpp_scheduling import CUDACPPScheduling + +from .triton import TritonScheduling + + +class CUDACombinedScheduling(BaseScheduling): + """ + Scheduler for CUDA Kernels, which delegates calls as appropriate + to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices + and use a unified-wrapper for codegen. + + If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code, + this would also be the place to do it. + """ + + def __init__(self, scheduler: Scheduler): + super().__init__() + self._scheduler = scheduler + self._triton_scheduling = TritonScheduling(scheduler) + self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._cuda_cpp_scheduling.is_cuda_cpp_template( + node + ) or self._cuda_cpp_scheduling.is_cuda_cpp_fused_template(node): + return self._cuda_cpp_scheduling + return self._triton_scheduling + + def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2): + return True + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + for node in (node1, node2): + if self._cuda_cpp_scheduling.is_cuda_cpp_template( + node + ) or self._cuda_cpp_scheduling.is_cuda_cpp_fused_template(node): + return self._cuda_cpp_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn(self, sizes): + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node): + return self._cuda_cpp_scheduling.codegen_template( + template_node, epilogue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes + ) + + def codegen_nodes(self, nodes: List[SchedulerNode]): + return self._triton_scheduling.codegen_nodes(nodes) + + def codegen_sync(self): + return self._triton_scheduling.codegen_sync() + + def flush(self): + return 
self._triton_scheduling.flush() + + def codegen_foreach(self, *args, **kwargs): + return self._triton_scheduling.codegen_foreach(*args, **kwargs) + + def benchmark_fused_nodes(self, nodes): + return self._triton_scheduling.benchmark_fused_nodes(nodes) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..689b5fd9dbc79f9cb5c0892f12b8a3e05196717f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py @@ -0,0 +1,130 @@ +import functools + +from typing import Dict, Set, Tuple + +import torch +from torch._dynamo.utils import counters + +from torch._ops import OpOverload, OpOverloadPacket +from ..pattern_matcher import fwd_only, register_replacement + +aten = torch.ops.aten + + +@functools.lru_cache(None) +def _misc_patterns_init(): + from .joint_graph import patterns as joint_graph_patterns + from .post_grad import pass_patterns as post_grad_patterns_all + + post_grad_patterns = post_grad_patterns_all[1] # medium priority + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # These patterns do 2 things + # 1. Since we know that index is completely unique, we can codegen it using + # stores instead of atomic adds, which is quite a bit faster. + # 2. Also, since we are guaranteed that they are completely within bounds, + # we can use unsafe indexing and skip debug asserts + def randperm_index_add_pattern(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return torch.index_add(x, dim=0, source=y, index=index), index + + def randperm_index_add_replacement(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return ( + torch.ops.aten._unsafe_index_put( + x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False + ), + index, + ) + + register_replacement( + randperm_index_add_pattern, + randperm_index_add_replacement, + [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + ) + + def randperm_index_pattern(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten.index(x, (index,)), index + + def randperm_index_replacement(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten._unsafe_index(x, (index,)), index + + pattern = register_replacement( + randperm_index_pattern, + randperm_index_replacement, + [torch.empty(4, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + scalar_workaround={"slice_shape": 42}, + ) + + +class NumpyCompatNormalization: + numpy_compat: Dict[str, Tuple[str, ...]] = { + "dim": ("axis",), + "keepdim": ("keepdims",), + "input": ("x", "a", "x1"), + "other": ("x2",), + } + inverse_mapping: Dict[str, str] + cache: Dict["torch.fx.graph.Target", Set[str]] + + def __init__(self): + self.cache = {} # callable -> tuple of replaceable args e.g. 
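Note: the randperm replacement above is only valid because `randperm` produces unique, in-bounds indices, so a plain (non-accumulating) write matches `index_add` on those rows. A quick eager-mode sanity check of that equivalence (a sketch, assuming a PyTorch build exposing the same private aten ops the replacement itself calls):

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
y = torch.randn(2, 8)
index = torch.randperm(x.shape[0])[: y.shape[0]]

# Original pattern: scatter-add y into the rows selected by the permutation.
out_pattern = torch.index_add(x, dim=0, source=y, index=index)
# Replacement: gather those rows, add y, and write back without accumulation.
out_replacement = torch.ops.aten._unsafe_index_put(
    x, (index,), torch.ops.aten._unsafe_index(x, (index,)) + y, accumulate=False
)
assert torch.allclose(out_pattern, out_replacement)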
["axis"] + self.inverse_mapping = {} + for actual_kwarg, numpy_kwargs in self.numpy_compat.items(): + for numpy_kwarg in numpy_kwargs: + assert numpy_kwarg not in self.inverse_mapping + self.inverse_mapping[numpy_kwarg] = actual_kwarg + + def __call__(self, graph: torch.fx.Graph): + for node in graph.nodes: + if node.op != "call_function": + continue + if isinstance(node.target, (OpOverload, OpOverloadPacket)): + # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't. + continue + kwargs = node.kwargs + + if node.target in self.cache: + replaceable_kwargs = self.cache[node.target] + else: + signatures = torch.fx.operator_schemas.get_signature_for_torch_op( + node.target + ) + signatures = () if signatures is None else signatures + replaceable_kwargs = set() + for sig in signatures: + for param_name in sig.parameters.keys(): + if param_name in self.numpy_compat: + replaceable_kwargs.update(self.numpy_compat[param_name]) + + self.cache[node.target] = replaceable_kwargs + + if not replaceable_kwargs: + continue + + new_kwargs = {} + kwargs_changed = False + for k, v in kwargs.items(): + if k in replaceable_kwargs: + kwargs_changed = True + new_kwargs[self.inverse_mapping[k]] = v + else: + new_kwargs[k] = v + + if kwargs_changed: + node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs) + counters["inductor"]["numpy_compat_normalization"] += 1 + + +numpy_compat_normalization = NumpyCompatNormalization() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..aa32e1a592a22b3b724a0a11cf3922495b811e37 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -0,0 +1,1204 @@ +import functools +import operator +from functools import reduce +from typing import Any, Tuple + +import torch + +from torch.fx.experimental.symbolic_shapes import has_free_symbols + +from .. 
import ir + +from ..lowering import lowerings as L +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + get_arg_value, + KeywordArg, + MULTIPLE, +) +from ..virtualized import ops +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern +from .quantization import ( + _register_quantization_lowerings, + _register_quantization_weight_pack_pass, +) + +if torch._C._has_mkldnn: + aten = torch.ops.aten + mkldnn = torch.ops.mkldnn + prims = torch.ops.prims + + _conv_args = [Arg() for _ in range(10)] + _linear_args = [Arg() for _ in range(6)] + _conv_transpose_args = [Arg() for _ in range(11)] + + def _conv_call(users=1): + return CallFunction( + mkldnn._convolution_pointwise.default, *_conv_args, _users=users + ) + + def _linear_call(users=1): + return CallFunction( + mkldnn._linear_pointwise.default, *_linear_args, _users=users + ) + + def _conv_transpose_call(users=1): + return CallFunction( + mkldnn._convolution_transpose_pointwise.default, + *_conv_transpose_args, + _users=users, + ) + + def _to_float(input_call, users=1): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_float"), + _users=users, + ) + + def _to_bf16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_bf16"), + _users=1, + ) + + def _to_fp16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_fp16"), + _users=1, + ) + + def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype): + # only insert to_dtype if lowp_dtype is True + computation_call = ( + _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users) + ) + out = unary_fusion(computation_call) + if lowp_dtype == torch.bfloat16: + return _to_bf16(out) + elif lowp_dtype == torch.float16: + return _to_fp16(out) + else: + return out + + def _gelu_fusion_1(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.erf, + CallFunction(aten.mul, computation_call, 0.7071067811865476), + ), + 1, + ), + ) + + def _gelu_fusion_2(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.tanh, + CallFunction( + aten.mul, + CallFunction( + aten.add, + computation_call, + CallFunction( + aten.mul, + CallFunction( + aten.mul, + CallFunction( + aten.mul, computation_call, computation_call + ), + computation_call, + ), + 0.044715, + ), + ), + 0.7978845608028654, + ), + ), + 1, + ), + ) + + def _hardswish_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.mul, + computation_call, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + ), + 6, + ) + + def _silu_fusion(computation_call): + return CallFunction( + aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call) + ) + + def _hardsigmoid_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + 6, + ) + + def _leaky_relu_fusion(computation_call): + return CallFunction( + aten.where, + CallFunction(aten.gt, computation_call, 0), + computation_call, + CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")), + ) + + def 
_hardtanh_fusion(computation_call): + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + + def _combined_fusion(computation_call, elementwise_op): + return CallFunction(elementwise_op, computation_call) + + # binary_op(other, computation_op) + def _binary_fusion_v1(computation_call, binary_fn): + return CallFunction(binary_fn, KeywordArg("other"), computation_call) + + # binary_op(computation_op, other) + def _binary_fusion_v2(computation_call, binary_fn): + return CallFunction(binary_fn, computation_call, KeywordArg("other")) + + def _is_single_computation_op(computation_op): + def fn(match): + computation_nodes = filter_nodes(match.nodes, computation_op) + if len(computation_nodes) < 1: + return False + if any(n.args[-3] != "none" for n in computation_nodes): + return False + return True + + return fn + + def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): + def fn(match): + matched = _is_single_computation_op(computation_op)(match) + computation_node = filter_nodes(match.nodes, computation_op)[0] + if lowp_dtype: + conversion_dtype_nodes = filter_nodes( + match.nodes, prims.convert_element_type.default + ) + if len(conversion_dtype_nodes) != 2: + return False + # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16 + if computation_node == conversion_dtype_nodes[0].args[0]: + to_float = conversion_dtype_nodes[0].args[1] + to_lp = conversion_dtype_nodes[1].args[1] + else: + to_float = conversion_dtype_nodes[1].args[1] + to_lp = conversion_dtype_nodes[0].args[1] + matched = matched and to_float == torch.float and to_lp == lowp_dtype + return matched + + return fn + + def _register_unary_fusion_lowering( + pattern, unary_attr, computation_op, lowp_dtype=None + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype), + ) + def fn(match, *args, **kwargs): + computation_args = list(args)[:-3] + [ + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + return L[computation_op](*computation_args) + + return fn + + def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op) + ) + def fn(match, *args, **kwargs): + negative_slope = kwargs.get("negative_slope") + if isinstance(negative_slope, ir.TensorBox): + matched = False + else: # inp is a Number + matched = True + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + if matched: + computation_args = computation_args[:-3] + [ + "leaky_relu", + [negative_slope], + "", + ] + return L[computation_op](*computation_args) + else: + # computation_args += ["none", [], ""] + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.where]( + L[aten.gt](out, 0), + out, + L[aten.mul](out, negative_slope), + ) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, 
extra_check=_is_single_computation_op(computation_op) + ) + def fn(match, *args, **kwargs): + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + if isinstance(min_value, ir.TensorBox) or isinstance( + max_value, ir.TensorBox + ): + matched = False + else: # inp is a Number + assert max_value is not None + matched = min_value <= max_value + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + if matched: + computation_args = computation_args[:-3] + [ + "hardtanh", + [min_value, max_value], + "", + ] + return L[computation_op](*computation_args) + else: + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + _binary_attr = { + aten.add: "add", + ops.add: "add", + aten.sub: "sub", + ops.sub: "sub", + } + + def _is_valid_binary(match, fn): + binary_nodes = filter_nodes(match.nodes, fn) + if len(binary_nodes) < 1: + return False + + def get_meta_value(argument: torch.fx.node.Argument): + # Only torch.fx.Node is expected to have meta. + if isinstance(argument, torch.fx.Node): + return argument.meta.get("val", None) + return None + + if any( + not isinstance(get_meta_value(n.args[0]), torch.Tensor) + or not isinstance(get_meta_value(n.args[1]), torch.Tensor) + for n in binary_nodes + ): + return False + # check alpha is one. + if any( + get_arg_value(n, 2, kwarg_name="alpha") != 1.0 + and get_arg_value(n, 2, kwarg_name="alpha") is not None + for n in binary_nodes + ): + return False + if any( + get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size() + or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device + or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype + for n in binary_nodes + ): + return False + # check args[0] and args[1] is not same + if any(n.args[0] == n.args[1] for n in binary_nodes): + return False + return True + + def _is_valid_computation_binary(computation_op, binary_op, other_index=None): + def fn(match): + if not _is_single_computation_op(computation_op)(match): + return False + if not _is_valid_binary(match, binary_op): + return False + return True + + return fn + + def _get_remaining_users(extra_input_node, compute_node): + # Think about this pattern: + # ReLU + # / \ + # Conv1 + # / \ + # Conv2 + # \ / + # Add + # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add. + # The Conv1 is the ancestor node of the current compute node (Conv2). + # This indicates that the buffer of ReLU has completed all its usage, + # So we can safely make changes to it now by doing Conv2->Add inplace fusion. + # Take above case as example: + # * extra_input_node: ReLU + # * compute_node: Conv2 + # _get_remaining_users will return the users of extra_input_node which are not + # ancestor node of compute_node. 
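Note: the unary patterns above (`_gelu_fusion_1/2`, `_hardswish_fusion`, `_silu_fusion`, `_hardsigmoid_fusion`) are the aten-level decompositions of the corresponding activations. A short eager-mode check of that correspondence, using the same constants as the patterns (a sketch, not part of the pass):

import torch
import torch.nn.functional as F

x = torch.randn(1024)

# _gelu_fusion_1: erf-based ("exact") GELU.
gelu_erf = (x * 0.5) * (torch.erf(x * 0.7071067811865476) + 1)
assert torch.allclose(gelu_erf, F.gelu(x, approximate="none"), atol=1e-6)

# _gelu_fusion_2: tanh-approximated GELU.
gelu_tanh = (x * 0.5) * (
    torch.tanh((x + x * x * x * 0.044715) * 0.7978845608028654) + 1
)
assert torch.allclose(gelu_tanh, F.gelu(x, approximate="tanh"), atol=1e-6)

# _hardswish_fusion: x * clamp(x + 3, 0, 6) / 6.
hardswish = x * torch.clamp_max(torch.clamp_min(x + 3, 0), 6) / 6
assert torch.allclose(hardswish, F.hardswish(x), atol=1e-6)

# _silu_fusion ("swish"): x * sigmoid(x).
assert torch.allclose(x * torch.sigmoid(x), F.silu(x), atol=1e-6)

# _hardsigmoid_fusion: clamp(x + 3, 0, 6) / 6.
hardsigmoid = torch.clamp_max(torch.clamp_min(x + 3, 0), 6) / 6
assert torch.allclose(hardsigmoid, F.hardsigmoid(x), atol=1e-6)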
+ def _is_ancestor_node(_current_node, _ancestor_node): + # Check whether _ancestor_node is the ancestor node of _current_node + _node_list = [_current_node] + _visited_nodes = set() + while len(_node_list) != 0: + _current_node = _node_list.pop(0) + if _current_node not in _visited_nodes: + _visited_nodes.add(_current_node) + if _current_node == _ancestor_node: + return True + elif isinstance( + _current_node, torch.fx.Node + ) and _current_node.op not in ["placeholder", "output", "get_attr"]: + for input in _current_node.all_input_nodes: + _node_list.append(input) # noqa: PERF402 + return False + + return [ + user + for user in list(extra_input_node.users) + if not _is_ancestor_node(compute_node, user) + ] + + def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index): + def fn(match): + if not _is_valid_computation_binary(computation_op, binary_op)(match): + return False + binary_nodes = filter_nodes(match.nodes, binary_op) + + def _get_compute_node(_binary_node, _other_index): + assert ( + len(_binary_node.all_input_nodes) == 2 + ), "Binary node should have 2 input nodes." + _compute_index = 1 if (_other_index == 0) else 0 + return _binary_node.args[_compute_index] + + def _other_input_not_inplaceable(_binary_node, _other_index): + _compute_node = _get_compute_node(_binary_node, _other_index) + return ( + len( + _get_remaining_users( + _binary_node.args[_other_index], _compute_node + ) + ) + > 1 + or _binary_node.args[_other_index] == _compute_node.args[0] + ) + + if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes): + return False + if any( + n.args[other_index].op in ["placeholder", "output"] + for n in binary_nodes + ): + return False + return True + + return fn + + def _register_binary_unary_fusion_lowering( + pattern, + computation_op, + binary_op, + fusion_op, + unary_attr=None, + ): + @register_lowering_pattern( + pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op) + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + return L[fusion_op](*computation_args) + + return fn + + def _can_be_inplace(_other): + if isinstance(_other.data, ir.View): + return _can_be_inplace(_other.data) + else: + return not ( + isinstance(_other.data, ir.ReinterpretView) + or isinstance( + _other.get_layout(), (ir.MutationLayout, ir.AliasedLayout) + ) + ) + + def _register_binary_unary_maybe_inplace_fusion_lowering( + pattern, + computation_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + unary_attr=None, + other_index=None, + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_binary_inplace( + computation_op, binary_op, other_index + ), + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, 
None, [], None] + # Make sure the other is not an alias or mutation(fx side doesn't has such info). + other.realize() + if not _can_be_inplace(other): + return L[outplace_fusion_op](*computation_args) + return L[inplace_fusion_op](*computation_args) + + return fn + + computation_ops = [ + mkldnn._convolution_pointwise.default, + mkldnn._linear_pointwise.default, + mkldnn._convolution_transpose_pointwise.default, + ] + + class UnaryAttr: + def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + def _register_unary_fusion(): + computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call] + + def _unary_fusion_patterns(lowp_dtype): + replacement_unary_fusion_patterns = { + UnaryAttr("gelu", algorithm_attr="tanh"): [ + _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("gelu", algorithm_attr="none"): [ + _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardswish"): [ + _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardsigmoid"): [ + _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("swish"): [ + _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + } + if not lowp_dtype: + call_user1 = [call_fn(users=1) for call_fn in computation_call_fns] + replacement_unary_fusion_patterns.update( + { + UnaryAttr("relu"): [ + _combined_fusion(u, aten.relu) for u in call_user1 + ], + UnaryAttr("sigmoid"): [ + _combined_fusion(u, aten.sigmoid) for u in call_user1 + ], + UnaryAttr("tanh"): [ + _combined_fusion(u, aten.tanh) for u in call_user1 + ], + } + ) + + return replacement_unary_fusion_patterns + + for lowp_dtype in [torch.bfloat16, torch.float16, None]: + replace_patterns = _unary_fusion_patterns(lowp_dtype) + for unary_attr, patterns in replace_patterns.items(): + _register_unary_fusion_lowering( + patterns[0], unary_attr, computation_ops[0], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[1], unary_attr, computation_ops[1], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[2], unary_attr, computation_ops[2], lowp_dtype + ) + _leaky_relu_patterns = [ + _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops): + _register_leaky_relu_fusion_lowering( + pattern, computation_op, lowp_dtype + ) + hardtanh_patterns = [ + _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(hardtanh_patterns, computation_ops): + _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype) + + def _register_inplace_fusion(): + binary_ops = [aten.add, ops.add] + inplace_fusion_op = mkldnn._convolution_pointwise_.binary + outplace_fusion_op = mkldnn._convolution_pointwise.binary + conv_call = _conv_call(users=1) + conv_op = computation_ops[0] + for binary_op in binary_ops: + binary_v1 = _binary_fusion_v1(conv_call, binary_op) + binary_unary_v1 = _combined_fusion(binary_v1, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v1, + conv_op, + 
binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v1, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + ) + binary_v2 = _binary_fusion_v2(conv_call, binary_op) + binary_unary_v2 = _combined_fusion(binary_v2, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + ) + + def _register_binary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [ + mkldnn._convolution_pointwise.binary, + mkldnn._linear_pointwise.binary, + ] + _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern = _binary_fusion_v2(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + for binary_op in [aten.add, ops.add]: + pattern = _binary_fusion_v1(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + def _register_binary_unary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [mkldnn._convolution_pointwise.binary] + _computation_user_1 = [_conv_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern_v1 = _combined_fusion( + _binary_fusion_v2(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v1, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + for binary_op in [aten.add, ops.add]: + pattern_v2 = _combined_fusion( + _binary_fusion_v1(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v2, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + + def _recover_linear(): + # convert reshape+linear+reshape to a single linear for applying fusion path. 
+ @register_freezing_graph_pattern( + CallFunction( + aten.reshape.default, + CallFunction( + mkldnn._linear_pointwise.default, + CallFunction( + aten.reshape.default, + Arg(), + KeywordArg("reshape_1"), + _users=MULTIPLE, + ), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("reshape_2"), + ), + pass_number=1, + ) + def reshape_linear_reshape_pattern(match, *args, **kwargs): + reshape_1 = kwargs.get("reshape_1") + reshape_2 = kwargs.get("reshape_2") + assert isinstance(reshape_1, list) + assert isinstance(reshape_2, list) + assert len(reshape_1) == 2 + dynamic_shapes = not all( + isinstance(x, int) for x in ([reshape_1[0]] + reshape_2[:-1]) + ) + + graph = match.graph + reshape_2_node = match.output_node() + linear_input_node = reshape_2_node.args[0].args[0].args[0] + # check linear's input's shape[:-1] == reshape_2[:-1] + # and check product(reshape_2[:-1]) == reshape_1[0] + if dynamic_shapes: + # TODO: Haozhe investigate how add guard here + return + else: + can_remove_reshape = linear_input_node.meta.get("val").shape[ + :-1 + ] == torch.Size(reshape_2[:-1]) + can_remove_reshape = can_remove_reshape and ( + reduce(operator.mul, reshape_2[:-1]) == reshape_1[0] + ) + + if can_remove_reshape: + repl = graph.call_function(mkldnn._linear_pointwise.default, args) + repl.meta.update(reshape_2_node.meta) + reshape_2_node.replace_all_uses_with(repl) + old_linear_node = reshape_2_node.args[0] + reshape_1_node = old_linear_node.args[0] + graph.erase_node(reshape_2_node) + graph.erase_node(old_linear_node) + if len(reshape_1_node.users) == 0: + graph.erase_node(reshape_1_node) + + def is_linear_add_bias(match): + add_node = match.output_node() + linear_node = add_node.args[0] + weight_meta = linear_node.args[1].meta.get("val") + bias_meta = add_node.args[1].meta.get("val") + if weight_meta is None or bias_meta is None: + return False + return ( + linear_node.args[2] is None + and bias_meta.dim() == 1 + and bias_meta.size(0) == weight_meta.size(0) + ) + + # convert linear+bias to a single linear for applying fusion path. 
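Note: the reshape+linear+reshape fold above is valid because a linear layer acts independently on the last dimension, so flattening the leading dimensions first and restoring them afterwards changes nothing. A minimal eager-mode check (a sketch, with `torch.nn.functional.linear` standing in for `mkldnn._linear_pointwise`):

import torch
import torch.nn.functional as F

x = torch.randn(2, 7, 16)     # (batch, seq, in_features)
w = torch.randn(32, 16)       # (out_features, in_features)
b = torch.randn(32)

# reshape -> linear -> reshape, as matched by the pattern
out_reshaped = F.linear(x.reshape(2 * 7, 16), w, b).reshape(2, 7, 32)
# single linear on the original 3D input, as produced by the rewrite
out_direct = F.linear(x, w, b)

assert torch.allclose(out_reshaped, out_direct, atol=1e-6)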
+ @register_freezing_graph_pattern( + CallFunction( + aten.add.Tensor, + CallFunction(mkldnn._linear_pointwise.default, *_linear_args), + Arg(), + ), + pass_number=1, + extra_check=is_linear_add_bias, + ) + def linear_bias_pattern(match, *args): + graph = match.graph + add_node = match.output_node() + linear_node = add_node.args[0] + new_args = list(linear_node.args) + new_args[2] = add_node.args[1] + repl = graph.call_function( + mkldnn._linear_pointwise.default, tuple(new_args) + ) + repl.meta.update(add_node.meta) + add_node.replace_all_uses_with(repl) + match.erase_nodes(graph) + + def _is_packable_mkldnn_rnn_layer(match): + lstm_node = match.output_node() + POS_WEIGHTS = [1, 2] + POS_INPUTS = [0, 5, 6] + POS_ARGS = POS_WEIGHTS + POS_INPUTS + # Weights should be Constant + if any( + lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS + ): + return False + + # Meta info for weights and inputs should be available + if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS): + return False + + # Check device + if any( + lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu" + for POS_ARG in POS_ARGS + ): + return False + + # Check dtype + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16 + and not mkldnn._is_mkldnn_bf16_supported() + for POS_ARG in POS_ARGS + ): + return False + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16 + and not mkldnn._is_mkldnn_fp16_supported() + for POS_ARG in POS_ARGS + ): + return False + + return True + + def _is_packable_convolution(match): + """ + Check if the node is supported for MKLDNN convolution. + """ + conv_node = match.output_node() + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + input_size = input_meta_value.shape + if conv_node.args[1].op != "get_attr": + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 4 + ): + return False + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not mkldnn._is_mkldnn_bf16_supported(): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not mkldnn._is_mkldnn_fp16_supported(): + return False + is_transposed = conv_node.args[-3] + if is_transposed: + # TODO: Support dynamic shape case for MKLDNN conv transpose. + if has_free_symbols(input_size): + return False + groups = conv_node.args[-1] + in_channels = weight_meta_value.size(0) + # doesn't support group_depthwise_conv_transpose. + if groups > 1 and groups == in_channels: + return False + # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big + output_paddings = conv_node.args[-2] + strides = conv_node.args[3] + if any( + output_padding >= stride + for output_padding, stride in zip(output_paddings, strides) + ): + return False + return True + + def _is_packable_linear(match): + """ + Check if the node is supported for MKLDNN linear. 
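Note: `linear_bias_pattern` above folds a trailing `aten.add` of a 1-D bias into the linear op itself. The eager-mode identity it relies on (a sketch, with `aten.addmm` standing in for the fused mkldnn op):

import torch

x = torch.randn(4, 16)
w = torch.randn(32, 16)
b = torch.randn(32)

unfused = torch.mm(x, w.t()) + b      # bias-free linear followed by an add
fused = torch.addmm(b, x, w.t())      # bias folded into the matmul
assert torch.allclose(unfused, fused, atol=1e-6)
assert torch.allclose(fused, torch.nn.functional.linear(x, w, b), atol=1e-6)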
+ """ + linear_node = match.output_node() + # weight_idx is 1 for aten.mm and is 2 for aten.addmm + weight_idx = 2 if linear_node.target == aten.addmm.default else 1 + if linear_node.args[weight_idx].op != "get_attr": + return False + input_meta_value = linear_node.args[weight_idx - 1].meta.get("val") + weight_meta_value = linear_node.args[weight_idx].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + batch_size = input_meta_value.shape[0] + is_lp_weight = weight_meta_value.dtype in ( + torch.bfloat16, + torch.float16, + ) + # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. + # on aarch64, use mkldnn op for fp32 as well if acl is enabled + if ( + not is_lp_weight + and not mkldnn._is_mkldnn_acl_supported() + and ((not torch._C.has_mkl) or has_free_symbols(batch_size)) + ): + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 2 + ): + return False + if weight_idx == 2: + bias_meta_value = linear_node.args[0].meta.get("val") + if ( + bias_meta_value is None + or meta_value.device.type != "cpu" + or bias_meta_value.dim() != 1 + or bias_meta_value.size(0) != weight_meta_value.size(1) + ): + return False + + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not mkldnn._is_mkldnn_bf16_supported(): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not mkldnn._is_mkldnn_fp16_supported(): + return False + return True + + _aten_conv_args = ( + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + KeywordArg("is_transposed"), + Arg(), + Arg(), + ) + + _aten_mkldnn_rnn_layer_args = ( + Arg(), # input + Arg(), # weight0 + Arg(), # weight1 + Arg(), # weight2 + Arg(), # weight3 + Arg(), # hx_ + Arg(), # cx_ + KeywordArg("reverse"), # reverse + Arg(), # batch_sizes + Arg(), # mode + Arg(), # hidden_size + Arg(), # num_layers + Arg(), # has_biases + Arg(), # bidirectional + Arg(), # batch_first + Arg(), # train + ) + + def _register_weight_pack_pass(): + @register_freezing_graph_pattern( + CallFunction(aten.convolution.default, *_aten_conv_args), + extra_check=_is_packable_convolution, + ) + def convolution(match, *args, **kwargs): + is_transposed = kwargs.get("is_transposed") + assert isinstance(is_transposed, bool) + graph = match.graph + conv_node = match.output_node() + input_size = conv_node.args[0].meta.get("val").shape + with graph.inserting_before(conv_node): + constant_args = [args[4], args[3], args[5], args[-1]] + packed_weight_op = mkldnn._reorder_convolution_weight + packed_conv_op = mkldnn._convolution_pointwise.default + if is_transposed: + constant_args.insert(1, args[-2]) # output_padding + packed_weight_op = mkldnn._reorder_convolution_transpose_weight + packed_conv_op = mkldnn._convolution_transpose_pointwise.default + if not has_free_symbols(input_size): + packed_weight_inputs = ( + (args[1],) + tuple(constant_args) + (input_size,) + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + else: + assert not is_transposed + # For dynamic shape case, we need to pack weight in runtime. 
+ packed_weight_node = args[1] + packed_conv_inputs = ( + (args[0], packed_weight_node, args[2]) + + tuple(constant_args) + + ("none", [], "") + ) + packed_conv_node = graph.create_node( + "call_function", packed_conv_op, tuple(packed_conv_inputs) + ) + conv_node.replace_all_uses_with(packed_conv_node) + packed_conv_node.meta.update(conv_node.meta) + graph.erase_node(conv_node) + + @register_freezing_graph_pattern( + CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args), + extra_check=_is_packable_mkldnn_rnn_layer, + ) + def mkldnn_rnn_layer(match, *args, **kwargs): + def get_item(graph, node, index): + return graph.call_function(operator.getitem, (node, index)) + + graph = match.graph + lstm_node = match.output_node() + input = args[0] + weight0, weight1 = args[1:3] + reverse = kwargs.get("reverse") + packed_lstm_op = aten.mkldnn_rnn_layer.default + hidden_size = args[9] + has_biases = args[11] + batch_first = args[13] + with graph.inserting_before(lstm_node): + packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default + packed_weight_inputs = ( + weight0, + weight1, + hidden_size, + reverse, + has_biases, + batch_first, + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, packed_weight_inputs, {}, "name" + ) + packed_weight_items = [ + get_item(graph, packed_weight_node, i) for i in range(2) + ] + pack_lstm_inputs = ( + args[0], + *packed_weight_items, + args[3], + args[4], + args[5], + args[6], + reverse, + *args[7:], + ) + + packed_lstm_node = graph.create_node( + "call_function", packed_lstm_op, args=pack_lstm_inputs + ) + lstm_node.replace_all_uses_with(packed_lstm_node) + packed_lstm_node.meta.update(lstm_node.meta) + graph.erase_node(lstm_node) + + @register_freezing_graph_pattern( + CallFunction(aten.addmm.default, Arg(), Arg(), Arg()), + extra_check=_is_packable_linear, + ) + @register_freezing_graph_pattern( + CallFunction(aten.mm.default, Arg(), Arg()), + extra_check=_is_packable_linear, + ) + def linear(match, *args, **kwargs): + graph = match.graph + linear_node = match.output_node() + input = args[0] if linear_node.target == aten.mm.default else args[1] + bias = None if linear_node.target == aten.mm.default else args[0] + weight = args[1] if linear_node.target == aten.mm.default else args[2] + with graph.inserting_before(linear_node): + transpose_weight_node = graph.create_node( + "call_function", aten.permute.default, (weight, (1, 0)) + ) + weight_dtype = weight.meta.get("val").dtype + is_lp_weight = weight_dtype in ( + torch.bfloat16, + torch.float16, + ) + batch_size = input.meta.get("val").shape[0] + if has_free_symbols(batch_size): + assert ( + is_lp_weight or mkldnn._is_mkldnn_acl_supported() + ), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" + # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. + packed_weight_inputs = ( + transpose_weight_node, + batch_size.node.shape_env.size_hint(batch_size.node.expr) + if has_free_symbols(batch_size) + else batch_size, + ) + packed_weight_op = ( + mkldnn._reorder_linear_weight + if (is_lp_weight or mkldnn._is_mkldnn_acl_supported()) + else torch.ops.mkl._mkl_reorder_linear_weight + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + packed_linear_inputs: Tuple[Any, ...] 
= (input, packed_weight_node) + if is_lp_weight or mkldnn._is_mkldnn_acl_supported(): + packed_linear_inputs += (bias, "none", [], "") + packed_linear_op = mkldnn._linear_pointwise.default + else: + packed_linear_inputs += (transpose_weight_node, bias, batch_size) + packed_linear_op = torch.ops.mkl._mkl_linear + packed_linear_node = graph.create_node( + "call_function", packed_linear_op, packed_linear_inputs + ) + linear_node.replace_all_uses_with(packed_linear_node) + packed_linear_node.meta.update(linear_node.meta) + graph.erase_node(linear_node) + + def _eliminate_duplicate_packed_nodes(gm): + """ + Combine packed weight nodes with the same inputs to reduce memory usage. + for example: + class Model(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(32, 32, bias=True) + + def forward(self, x): + return self.linear(self.linear(x)) + + the above's packed weight nodes are duplicate if two linear calls have same input size. + """ + if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): + return gm + + packed_weight_ops = [ + torch._C._nn.mkldnn_reorder_conv2d_weight, + mkldnn._reorder_convolution_transpose_weight, + mkldnn._reorder_linear_weight, + mkldnn._reorder_mkldnn_rnn_layer_weight, + ] + if torch._C.has_mkl: + packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight) + + for node in gm.graph.nodes: + if node.target in packed_weight_ops and len(node.args[0].users) > 1: + for user_node in list(node.args[0].users.keys()): + if ( + user_node.target == node.target + and user_node != node + and user_node.args == node.args + ): + user_node.replace_all_uses_with(node) + gm.graph.erase_node(user_node) + + @functools.lru_cache(None) + def _mkldnn_fusion_init(): + # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. 
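Note: `_eliminate_duplicate_packed_nodes` above is a targeted common-subexpression elimination over the weight-packing ops. A generic torch.fx sketch of the same idea (hypothetical helper, not the Inductor pass itself):

import torch
import torch.fx

def f(x):
    a = torch.relu(x)
    b = torch.relu(x)   # same op, same input: redundant
    return a + b

gm = torch.fx.symbolic_trace(f)

# Deduplicate call_function nodes with identical target and args, keeping the
# first occurrence, in the spirit of _eliminate_duplicate_packed_nodes.
seen = {}
for node in list(gm.graph.nodes):
    if node.op != "call_function":
        continue
    key = (node.target, node.args)
    if key in seen:
        node.replace_all_uses_with(seen[key])
        gm.graph.erase_node(node)
    else:
        seen[key] = node
gm.recompile()

x = torch.randn(3)
assert torch.equal(gm(x), f(x))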
+ # Otherwise even the matmul or innerproduct can not be accelerated with acl + if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and not torch.ops.mkldnn._is_mkldnn_acl_supported() + ): + _register_unary_fusion() + _register_inplace_fusion() + _register_binary_unary_fusion() + _register_binary_fusion() + _register_quantization_lowerings() + + @functools.lru_cache(None) + def _mkldnn_weight_pack_init(): + if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available(): + _register_weight_pack_pass() + _recover_linear() + _register_quantization_weight_pack_pass() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..0564c6fa846933be3c9ec15b8e30e99e19e41fdc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py @@ -0,0 +1,1100 @@ +import copy +import functools +import itertools +import logging +import operator +from collections import Counter, defaultdict +from typing import Any, Dict, List, Optional, Set, Union + +from sympy import Expr + +import torch +import torch._inductor as inductor +import torch.utils._pytree as pytree +from torch import fx +from torch._decomp import register_decomposition +from torch._dynamo.utils import counters, optimus_scuba_log + +from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype + +from torch._utils_internal import upload_graph +from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq + +from .. import config, ir, pattern_matcher +from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage + +from ..lowering import lowerings as L +from ..pattern_matcher import ( + _return_true, + Arg, + CallFunction, + CallFunctionVarArgs, + filter_nodes, + get_arg_value, + get_mutation_region_id, + Ignored, + init_once_fakemode, + KeywordArg, + ListOf, + Match, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, + stable_topological_sort, +) +from ..utils import decode_device, is_pointwise_use +from ..virtualized import V +from .group_batch_fusion import group_batch_fusion_passes +from .reinplace import reinplace_inplaceable_ops + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] +# patterns applied only in inference +inference_patterns = PatternMatcherPass() +decompose_mm_pass = PatternMatcherPass() + + +def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): + """ + Passes that run on after grad. This is called once on the forwards + graph and once on the backwards graph. + + The IR here has been normalized and functionalized. 
+ """ + if config.dce: + # has some issues with mutation in inference mode + gm.graph.eliminate_dead_code() + + if is_inference and config.reorder_for_locality: + reorder_for_locality(gm.graph) + + fake_tensor_updater = FakeTensorUpdater(gm.graph) + + if config.post_grad_custom_pre_pass is not None: + config.post_grad_custom_pre_pass(gm.graph) + + if config.pattern_matcher: + lazy_init() + inductor_before_change = copy.deepcopy(counters["inductor"]) + group_batch_fusion_passes(gm.graph, pre_grad=False) + if counters["inductor"] != inductor_before_change: + optimus_scuba_log["group_batch_fusion_post_grad"] = upload_graph(gm.graph) + remove_noop_ops(gm.graph) + for patterns in pass_patterns: + patterns.apply(gm.graph) # type: ignore[arg-type] + if is_inference: + inference_patterns.apply(gm.graph) # type: ignore[arg-type] + decompose_mm_pass.apply(gm.graph) # type: ignore[arg-type] + + if config.post_grad_custom_post_pass is not None: + config.post_grad_custom_post_pass(gm.graph) + + stable_topological_sort(gm.graph) + + move_constructors_to_cuda(gm.graph) + + fake_tensor_updater.incremental_update() + + # Keep these last, since they introduces mutation. Look at + # ./fx_passes/README.md for a discussion of mutation invariants. + reinplace_inplaceable_ops(gm.graph) + decompose_auto_functionalized(gm.graph) + + gm.recompile() + gm.graph.lint() + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn: + from . import decompose_mem_bound_mm # noqa: F401 + from .mkldnn_fusion import _mkldnn_fusion_init + + _mkldnn_fusion_init() + + +def reorder_for_locality(graph: torch.fx.Graph): + def visit(other_node): + if ( + other_node.op == "call_function" + and other_node.target != operator.getitem + and all((n in seen_nodes) for n in other_node.users) + and get_mutation_region_id(graph, node) + == get_mutation_region_id(graph, other_node) + ): + # move node's producers right before it + node.prepend(other_node) + + seen_nodes = set() + + # only reorder nodes before the first copy_ in the graph. + # copy_ will appear at the end of functionalized graphs when there is mutation on inputs, + # and this reordering doesnt work well with mutation + first_copy = next( + ( + node + for node in graph.nodes + if node.op == "call_function" + and node.target == torch.ops.aten.copy_.default + ), + None, + ) + past_mutating_epilogue = True if first_copy is None else False + + for node in reversed(graph.nodes): + seen_nodes.add(node) + if not past_mutating_epilogue: + past_mutating_epilogue = node is first_copy + continue + + torch.fx.map_arg((node.args, node.kwargs), visit) + + +def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1): + """ + Register an aten to inductor IR replacement pattern + """ + return pattern_matcher.register_lowering_pattern( + pattern, extra_check, pass_dict=pass_patterns[pass_number] + ) + + +################################################################################ +# Actual patterns below this point. 
+# Priority of patterns is: +# - later output nodes first +# - order patterns are defined in +################################################################################ + + +def is_valid_mm_plus_mm(match: Match): + *b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape + *b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape + if k1 != k2: + return False + + *b1, m2, k3 = match.kwargs["mat3"].meta.get("tensor_meta").shape + *b2, k4, n2 = match.kwargs["mat4"].meta.get("tensor_meta").shape + if k3 != k4: + return False + + if m1 != m2 or n1 != n2: + return False + + return True + + +@register_lowering_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")), + CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")), + ), + extra_check=is_valid_mm_plus_mm, +) +def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): + return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4) + + +def cuda_and_enabled_mixed_mm(match): + return (config.use_mixed_mm or config.force_mixed_mm) and getattr( + match.kwargs["mat1"].meta.get("val"), "is_cuda", False + ) + + +def cuda_and_enabled_mixed_mm_and_not_int8(match): + return ( + cuda_and_enabled_mixed_mm(match) + and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False) + and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8) + != torch.int8 + ) # bitshift numerics in triton and pytorch don't match for torch.int8 + + +""" + this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor + (where the int4 and uint4x2 are represented with int8 and uint8 respectively) + where every other row of the int4 is packed with the row above it as: + uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4 + + unpack formulas: + int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8 + int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8 + + thus matching on unpack formula: + torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8)) + + note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior + of the kernel matches the pytorch formula for all dtypes except torch.int8 + where the bitwise numerics in triton do not match those in pytorch. 
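Note: the pack/unpack formulas quoted above round-trip exactly; an eager-mode check of the dtype bookkeeping (a sketch, which exercises the formulas rather than the Triton kernel):

import torch

int4 = torch.randint(-8, 8, (6, 5), dtype=torch.int8)   # int4 values stored in int8

# pack: uint4x2[k, n] = (8 + int4[2k, n]) + ((8 + int4[2k+1, n]) << 4)
lo = (int4[0::2] + 8).to(torch.uint8)
hi = (int4[1::2] + 8).to(torch.uint8)
uint4x2 = lo + (hi << 4)

# unpack, as in the formulas above
even = (uint4x2 & 0xF).to(torch.int8) - 8     # int4[2k,   n]
odd = (uint4x2 >> 4).to(torch.int8) - 8       # int4[2k+1, n]

assert torch.equal(even, int4[0::2])
assert torch.equal(odd, int4[1::2])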
+""" + + +@register_lowering_pattern( + CallFunction( + aten.mm.default, + KeywordArg("mat1"), + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + aten.bitwise_and.Scalar, + KeywordArg("mat2"), + 0xF, + ), + CallFunction( + aten.__rshift__.Scalar, + KeywordArg("mat2"), + 4, + ), + ), + 1, + ), + KeywordArg("mat2_mm_shape"), + ), + KeywordArg("mat2_dtype"), + ), + 8, + ), + ), + extra_check=cuda_and_enabled_mixed_mm_and_not_int8, +) +def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype): + return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm( + mat1, mat2, mat2_mm_shape, mat2_dtype + ) + + +""" + torch.mm(mat1, mat2.to(mat2_dtype)) +""" + + +@register_lowering_pattern( + CallFunction( + aten.mm, + KeywordArg("mat1"), + CallFunction( + prims.convert_element_type.default, + KeywordArg("mat2"), + KeywordArg("mat2_dtype"), + ), + ), + extra_check=cuda_and_enabled_mixed_mm, +) +def mixed_mm(match: Match, mat1, mat2, mat2_dtype): + return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype) + + +@register_graph_pattern( + CallFunction( + aten.cumsum.default, + CallFunction( + torch.ops.aten.full.default, + KeywordArg("shape"), + KeywordArg("fill_value"), + dtype=KeywordArg("dtype"), + layout=Ignored(), + device=KeywordArg("device"), + pin_memory=False, + _users=MULTIPLE, + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + pass_dict=pass_patterns[1], +) +def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim): + """Based on a pattern in OPTForCausalLM""" + + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + # cumsum promotes all integral types to int64 + dtype = torch.int64 + + def repl(*shape): + dim_size = shape[dim] + idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype) + + inter_shape = [1] * len(shape) + inter_shape[dim] = dim_size + return (idx * fill_value).view(inter_shape).expand(shape) + + # only replace the output node, not all nodes + match.nodes = [match.output_node()] + with V.fake_mode: + match.replace_by_example(repl, list(shape)) + + +def shape_of_mm(a, b): + m, _ = a.get_size() + _, n = b.get_size() + return [m, n] + + +@register_lowering_pattern( + CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()), +) +def cat_mm(match, inputs, dim): + return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm) + + +@register_lowering_pattern( + CallFunction( + aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg() + ), +) +def cat_addmm(match, inputs, dim): + def shape_of(bias, a, b): + m, _ = a.get_size() + _, n = b.get_size() + return [m, n] + + return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of) + + +def cat_tuned_op(match, inputs, dim, *, op, shape_of): + """ + Memory planning to remove cat. We can't use the stock memory + planner since autotuning matmuls needs to know the output layout. + """ + if len(inputs) == 1: + return op(*inputs[0]) + + # TODO(jansel): rewrite this as a bmm? 
+ if dim < 0: + dim += len(shape_of(*inputs[0])) + assert dim in (0, 1) + notdim = 1 - dim + + new_size: Optional[Union[List[Expr], List[int]]] = None + offsets_start = [] + offsets_end = [] + + # compute output sizes + for i in range(len(inputs)): + shape = shape_of(*inputs[i]) + if new_size is None: + new_size = shape + else: + new_size[notdim] = V.graph.sizevars.guard_equals( # type: ignore[call-overload] + shape[notdim], new_size[notdim] + ) + new_size[dim] += shape[dim] + offsets_start.append(new_size[dim] - shape[dim]) + offsets_end.append(new_size[dim]) + + assert new_size is not None + dtype = functools.reduce( + torch.promote_types, + [x.get_dtype() for x in itertools.chain.from_iterable(inputs)], + ) + device = inputs[0][0].get_device() + kernel = ir.ConcatKernel( + name=None, + layout=ir.FixedLayout(device, dtype, new_size), + inputs=[], + ) + kernel_tensor = ir.TensorBox.create(kernel) + + for i in range(len(inputs)): + dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i]) + src = op(*inputs[i], layout=dst.get_layout()).data.data + assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer)) + src.layout = ir.AliasedLayout(dst) + kernel.inputs.append(src) + + kernel.name = V.graph.register_buffer(kernel) + kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs) + return kernel_tensor + + +_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2) + + +@register_lowering_pattern( + CallFunction( + aten.cat, + [ + _cat_1, + CallFunction( + aten.slice, + _cat_1, + 1, + 0, + KeywordArg("size"), + ), + ], + 1, + ) +) +def cat_slice_cat(match, cat_input, size, dim=1): + """ + This is an example of a more complex pattern where cat_1 is used + multiple times inside the pattern. We fold 2 calls to cat into one. + + Matches: + cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1) + slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807) + slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19) + cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1) + + + Rewrite to: + slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19) + cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1) + """ + first, *rest = cat_input + # Optimization is optional, because we can just not fold the cat + # size should be within first.get_size()[dim] such that the optimization is valid. + # For negative `end`, we currently fallback to not optimizing. 
+ if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]): + # fold 2 cats into 1 cat + return L[aten.cat]( + [ + first, + *rest, + L[aten.slice](first, dim, 0, size), + ], + dim, + ) + else: + # don't expect to hit this case, just fall back + tmp = L[aten.cat](cat_input, dim) + return L[aten.cat]( + [ + tmp, + L[aten.slice](tmp, dim, 0, size), + ], + dim, + ) + + +def is_valid_splitwithsizes_cat(match): + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + cat_nodes = filter_nodes(match.nodes, aten.cat) + get_item_nodes = filter_nodes(match.nodes, operator.getitem) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + # The dim of split and cat should match for passthrough + if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"): + return False + get_item_args = { + get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes + } + assert None not in get_item_args + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # All parts of split should be included in the cat + if get_item_args != set(range(len(split_sizes))): + return False + # The order of get_item_args should same with cat_node used. + # For example, if the split_node like split_with_sizes(input, [2, 2, 3], 1), + # the cat node should be like cat([get_item(0), get_item(1), get_item(2)], 1). + cat_items_args_order = [ + get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0) + ] + if cat_items_args_order != list(range(len(split_sizes))): + return False + + return True + + +def same_meta(node1: torch.fx.Node, node2: torch.fx.Node): + """True if two nodes have the same metadata""" + val1 = node1.meta.get("val") + val2 = node2.meta.get("val") + return ( + val1 is not None + and val2 is not None + and statically_known_true(sym_eq(val1.size(), val2.size())) + and val1.layout == val2.layout + and val1.dtype == val2.dtype + and val1.device == val2.device + and ( + val1.layout != torch.strided + or statically_known_true(sym_eq(val1.stride(), val2.stride())) + ) + ) + + +noop_registry: Dict[Any, Any] = {} + + +def register_noop_decomp(targets, nop_arg=0): + def register_fun(cond): + register_decomposition(targets, registry=noop_registry, unsafe=True)( + (cond, nop_arg) + ) + return cond + + return register_fun + + +@register_noop_decomp(aten.slice) +def slice_noop(self, dim=0, start=None, end=None, step=1): + if start is None or end is None: + return False + if start == 0 and end >= 2**63 - 1 and step == 1: + return True + return False + + +@register_noop_decomp(aten.slice_scatter, 1) +def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1): + if start is None: + start = 0 + if end is None: + end = 2**63 - 1 + if start == 0 and end >= 2**63 - 1 and step == 1: + return True + return False + + +@register_noop_decomp(aten.repeat) +def repeat_noop(self, repeats): + return all(r == 1 for r in repeats) + + +@register_noop_decomp(aten.constant_pad_nd) +def constant_pad_nd(x, padding, fill_value=0): + return all(p == 0 for p in padding) + + +@register_noop_decomp(torch.ops.prims.convert_element_type) +def convert_element_type_noop(x, dtype: torch.dtype): + return x.dtype == dtype + + +@register_noop_decomp(torch.ops.prims.device_put) +def device_put_noop(x, device): + return x.device == decode_device(device) + + +@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc]) +def int_noop(x): + return is_integer_dtype(x.dtype) + + +@register_noop_decomp([aten.pow]) 
+def pow_noop(a, b): + return isinstance(b, int) and b == 1 + + +@register_noop_decomp([aten.cat], lambda args: args[0][0]) +def cat_noop(inputs, dim=0): + return len(inputs) == 1 + + +@register_noop_decomp(aten.view) +def view_noop(arg, size): + return arg.shape == size + + +# Note, we also always have a check for identical metadata, which is why these +# are safe +@register_noop_decomp([aten.copy], nop_arg=1) +@register_noop_decomp([aten.alias, aten.clone]) +def true_noop(*args, **kwargs): + return True + + +def remove_noop_ops(graph: torch.fx.Graph): + """ + Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph. + """ + inputs = set() + input_storages = set() + output_storages = set() + + for node in graph.nodes: + if node.op == "placeholder": + inputs.add(node) + input_storages.add(get_node_storage(node)) + else: + break + + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" + for out in output_node.args[0]: + if isinstance(out, torch.fx.Node): + output_storages.add(get_node_storage(out)) + + for node in graph.nodes: + if node.target in noop_registry: + cond, src_index = noop_registry[node.target] + if isinstance(src_index, int): + src = node.args[src_index] + else: + src = src_index(node.args) + if not isinstance(src, torch.fx.Node): + continue + # Don't introduce new aliasing between inputs and outputs. + # See fx_passes/README.md for a discussion of why this is + # necessary. + node_storage = get_node_storage(node) + src_storage = get_node_storage(src) + node_is_view = node_storage == src_storage + if ( + not node_is_view + and node_storage in output_storages + and (src_storage in input_storages or src_storage in output_storages) + ): + continue + + # Even if input and outputs are expected to alias, + # don't make "node is src" True + if ( + node_is_view + and node in output_node.args + and (src in inputs or src in output_node.args) + ): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + if same_meta(node, src) and cond(*args, **kwargs): + node.replace_all_uses_with(src) + graph.erase_node(node) + + +def decompose_auto_functionalized(graph): + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized), + pass_dict=graph_pass, + ) + def replacement(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense + + only_clone_these_tensors = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. 
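Note: eager-mode analogues of a few of the noop conditions registered above; each op reproduces its input, which is why `remove_noop_ops` may drop the corresponding node (subject to the metadata and aliasing checks it performs):

import torch

x = torch.randn(4, 5)

assert torch.equal(x[0:2**63 - 1], x)         # aten.slice over the full range
assert torch.equal(x.to(torch.float32), x)    # convert_element_type to the same dtype
assert torch.equal(x.repeat(1, 1), x)         # aten.repeat with all-ones repeats
assert torch.equal(x.view(4, 5), x)           # aten.view to the same shape
assert torch.equal(torch.cat([x]), x)         # aten.cat of a single tensor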
+ def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs) + + with V.fake_mode: + match.replace_by_example(decomp, flat_args, run_dce=False) + + graph_pass.apply(graph) + for node in graph.nodes: + if node.target is torch.ops.higher_order.auto_functionalized: + raise AssertionError("auto_functionalized was not removed") + + +@register_lowering_pattern( + CallFunction( + aten.cat, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split_with_sizes, + KeywordArg("input_"), + Ignored(), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ), + ), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_splitwithsizes_cat, +) +def splitwithsizes_cat_replace(match, input_): + return input_ + + +def is_valid_cat_splitwithsizes(match): + cat_nodes = filter_nodes(match.nodes, aten.cat) + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + + # the cat node has other users: can't eliminate + if len(cat_node.users) > 1: + return False + + # the dim of the cat and split should match + dim = get_arg_value(split_node, 2, "dim") + if dim != get_arg_value(cat_node, 1, "dim"): + return False + + cat_inputs = list(get_arg_value(cat_node, 0)) + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # the number of input tensors in cat and the + # length of the split sizes should match + if len(cat_inputs) != len(split_sizes): + return False + + for cat_input, split_size in zip(cat_inputs, split_sizes): + # each cat input tensor's size along dim + # should match the corresponding split size + if "val" not in cat_input.meta: + return False + cat_input_size = cat_input.meta["val"].size(dim) + if cat_input_size != split_size: + return False + + return True + + +@register_lowering_pattern( + CallFunction( + aten.split_with_sizes, + CallFunction( + aten.cat, + KeywordArg("input_"), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_cat_splitwithsizes, +) +def cat_splitwithsizes_replace(match, input_): + return input_ + + +def view_to_reshape(gm): + """ + Replace view ops in the GraphModule to reshape ops. 
+ """ + for nd in gm.graph.nodes: + if nd.target == torch.ops.aten.view.default: + nd.target = torch.ops.aten.reshape.default + + +def should_prefer_unfused_addmm(match): + inp = match.kwargs["inp"] + if not inp.meta["val"].is_cuda: + return False + + output = match.output_node() + return all(is_pointwise_use(use) for use in output.users) + + +@register_graph_pattern( + CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()), + pass_dict=pass_patterns[2], + extra_check=should_prefer_unfused_addmm, +) +def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp): + def repl(inp, x1, x2): + return x1 @ x2 + inp + + with V.fake_mode: + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def is_valid_addmm_fusion(match): + mat1, mat2 = match.args + inp = match.kwargs["inp"] + + if not ( + isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor) + ): + return False # Input is a number + + in_shape = inp.meta["val"].shape + mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1] + matched = is_expandable_to(in_shape, mm_shape) + if not matched: + return False # Shape mismatch + + return not should_prefer_unfused_addmm(match) + + +@register_graph_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, Arg(), Arg()), + KeywordArg("inp"), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +@register_graph_pattern( + CallFunction( + aten.add, + KeywordArg("inp"), + CallFunction(aten.mm, Arg(), Arg()), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +def addmm(match, mat1, mat2, *, inp): + def repl(inp, mat1, mat2): + return aten.addmm(inp, mat1, mat2) + + with V.fake_mode: + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def check_shape_cuda_and_fused_int_mm_mul_enabled(match): + return ( + config.force_fuse_int_mm_with_mul + and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2 + and getattr(match.args[2].meta.get("val"), "is_cuda", False) + ) + + +@register_lowering_pattern( + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.mul, + CallFunction( + aten._int_mm, + Arg(), + Arg(), + ), + Arg(), + ), + Arg(), + ), + check_shape_cuda_and_fused_int_mm_mul_enabled, +) +@register_lowering_pattern( + CallFunction( + aten.mul, + CallFunction( + aten._int_mm, + Arg(), + Arg(), + ), + Arg(), + ), + check_shape_cuda_and_fused_int_mm_mul_enabled, +) +def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None): + return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype) + + +class ConstructorMoverPass: + def __init__(self, target: str, allow_outputs: bool = False) -> None: + """ + Move constructors from cpu to the target_device. + + Sweeps through the module, looking for constructor nodes that can be moved + to the target_device. + + A constructor node can be moved to the target_device iff all of its users + can also be moved (tested by cannot_be_moved). Otherwise, all dependent + constructor nodes won't be moved. + + - target: target device type + - allow_outputs: allow outputs to be moved + """ + + self.target = target + self.allow_outputs = allow_outputs + + assert isinstance(target, str), ( + "target should be a string representing the device type. " + f"Got: {type(target).__name__}" + ) + + def allow_cpu_device(self, node: fx.Node) -> bool: + """ + Returns whether a node that returns a tensor on the target device may have + cpu tensors as input. 
+ """ + return node.target in ( + torch.ops.aten.index.Tensor, + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + torch.ops.aten.copy.default, + torch.ops.aten.copy_.default, + torch.ops.aten.slice_scatter.default, + ) + + def cannot_be_moved(self, node: fx.Node) -> bool: + """ + Returns whether a node can be moved to the target device. + + If this function returns False, it means that this node and all of its users + won't be moved into the target device. + """ + if node.target == "output": + return not self.allow_outputs + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + return True + + return False + + def get_node_device(self, node: fx.Node) -> Optional[torch.device]: + """ + Get the device of a node. + """ + ten = node.meta.get("val") + return None if not isinstance(ten, torch.Tensor) else ten.device + + def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]: + """ + Get the number of cpu inputs to a node + """ + cpu_indeg: Dict[fx.Node, int] = Counter() + + for node in graph.nodes: + cpu_count = 0 + + def add_cpu_inp(node): + nonlocal cpu_count + device = self.get_node_device(node) + cpu_count += device is not None and device.type == "cpu" + + pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs)) + + if cpu_count: + cpu_indeg[node] = cpu_count + + return cpu_indeg + + def __call__(self, graph: fx.Graph) -> None: + target_devices = set() + constructors = [] + + for node in graph.nodes: + device = self.get_node_device(node) + if device and device.type == self.target: + target_devices.add(device) + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + continue + + if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target): + continue + + if not node.kwargs.get("device") == torch.device("cpu"): + continue + + constructors.append(node) + + # not handling multiple target devices initially + if not constructors or len(target_devices) != 1: + return + + movable_constructors = self.find_movable_constructors(graph, constructors) + + for node in movable_constructors: + kwargs = node.kwargs.copy() + kwargs["device"] = next(iter(target_devices)) + node.kwargs = kwargs + + def find_movable_constructors( + self, graph: fx.Graph, constructors: List[fx.Node] + ) -> Set[fx.Node]: + """ + Starting from the cpu constructors, iterate through the graph and test that all of their + downstream uses can safely be moved to cpu. + """ + cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph) + + # which constructors cannot be moved to cuda + cannot_move_to_cuda: Set[fx.Node] = set() + + # For any node in the graph, which constructors does it have a dependency on + constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set) + + # if a cpu node has a dependency on two different cpu constructors, + # then if either constructor cannot be moved to cuda, the other cannot as well. 
+ # In this case any node with a dependency on one will have a dependency on the other + equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = { + c: {c} for c in constructors + } + + def make_dependencies_equivalent( + set1: Set[fx.Node], set2: Set[fx.Node] + ) -> Set[fx.Node]: + # could use union find but not worth complexity here + set1.update(set2) + for obj in set1: + equal_constructor_sets[obj] = set1 + return set1 + + queue: List[fx.Node] = list(constructors) + + for c in queue: + constructor_dependencies[c].add(c) + + while queue: + node = queue.pop() + dependencies = constructor_dependencies[node] + + for user in node.users: + if self.cannot_be_moved(user): + cannot_move_to_cuda.update(dependencies) + break + + # this node was used on a op which takes in multiple devices and output a cuda + # tensor. we can convert its cpu input to cuda without making further changes + node_device = self.get_node_device(user) + if ( + self.allow_cpu_device(user) + and node_device + and node_device.type == self.target + ): + del cpu_indeg[user] + else: + # otherwise, we should continue look at its downstream uses + cpu_indeg[user] -= 1 + if cpu_indeg[user] == 0: + del cpu_indeg[user] + queue.append(user) + + unioned_set = make_dependencies_equivalent( + dependencies, constructor_dependencies[user] + ) + constructor_dependencies[user] = unioned_set + + for node in cpu_indeg: + if constructor_dependencies[node]: + cannot_move_to_cuda.update(constructor_dependencies[node]) + + all_cannot_move_to_cuda = cannot_move_to_cuda.copy() + for constructor in cannot_move_to_cuda: + all_cannot_move_to_cuda.update(equal_constructor_sets[constructor]) + + return set(constructors) - all_cannot_move_to_cuda + + +def move_constructors_to_cuda(graph: fx.Graph) -> None: + """ + Moves intermediary tensors which are constructed on the cpu to cuda when safe + """ + ConstructorMoverPass("cuda")(graph) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py new file mode 100644 index 0000000000000000000000000000000000000000..b5fbf8b234b55fb0d09e91e900fe8294d54bdcef --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py @@ -0,0 +1,182 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = 
CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = 
CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = 
CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py new file mode 100644 index 0000000000000000000000000000000000000000..8d53c01ec23521021837afc6606c3a5388817150 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -0,0 +1,202 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, 
sum_dim_IntList, _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6) +mul_Tensor_7 = CallFunction(aten.mul.Tensor, sub_Tensor_1, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = 
CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, 
view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +mul_Tensor_7 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, 
convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py new file mode 100644 index 0000000000000000000000000000000000000000..7a2fa0c888cbf0c08476f879d653625278a0b264 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -0,0 +1,186 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = 
CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = 
CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = 
CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py new file mode 100644 index 0000000000000000000000000000000000000000..be277a9cf6142743fb5c6b7bb4584081fc52f356 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py @@ -0,0 +1,114 @@ +# mypy: ignore-errors + +# This is an 
auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py +from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference, _sfdp_pattern_1_half_training, _sfdp_pattern_1_half_inference) +from ._sfdp_pattern_2 import (_sfdp_pattern_2_training, _sfdp_pattern_2_inference, _sfdp_pattern_2_half_training, _sfdp_pattern_2_half_inference) +from ._sfdp_pattern_3 import (_sfdp_pattern_3_training, _sfdp_pattern_3_inference, _sfdp_pattern_3_half_training, _sfdp_pattern_3_half_inference) +from ._sfdp_pattern_4 import (_sfdp_pattern_4_training, _sfdp_pattern_4_inference, _sfdp_pattern_4_half_training, _sfdp_pattern_4_half_inference) +from ._sfdp_pattern_5 import (_sfdp_pattern_5_training, _sfdp_pattern_5_inference, _sfdp_pattern_5_half_training, _sfdp_pattern_5_half_inference) +from ._sfdp_pattern_6 import (_sfdp_pattern_6_training, _sfdp_pattern_6_inference, _sfdp_pattern_6_half_training, _sfdp_pattern_6_half_inference) +from ._sfdp_pattern_7 import (_sfdp_pattern_7_training, _sfdp_pattern_7_inference, _sfdp_pattern_7_half_training, _sfdp_pattern_7_half_inference) +from ._sfdp_pattern_8 import (_sfdp_pattern_8_training, _sfdp_pattern_8_inference, _sfdp_pattern_8_half_training, _sfdp_pattern_8_half_inference) +from ._sfdp_pattern_9 import (_sfdp_pattern_9_training, _sfdp_pattern_9_inference, _sfdp_pattern_9_half_training, _sfdp_pattern_9_half_inference) +from ._sfdp_pattern_10 import (_sfdp_pattern_10_training, _sfdp_pattern_10_inference, _sfdp_pattern_10_half_training, _sfdp_pattern_10_half_inference) +from ._sfdp_pattern_11 import (_sfdp_pattern_11_training, _sfdp_pattern_11_inference, _sfdp_pattern_11_half_training, _sfdp_pattern_11_half_inference) +from ._sfdp_pattern_12 import (_sfdp_pattern_12_training, _sfdp_pattern_12_inference, _sfdp_pattern_12_half_training, _sfdp_pattern_12_half_inference) +from ._sfdp_pattern_13 import (_sfdp_pattern_13_training, _sfdp_pattern_13_inference, _sfdp_pattern_13_half_training, _sfdp_pattern_13_half_inference) +from ._sfdp_pattern_14 import (_sfdp_pattern_14_training, _sfdp_pattern_14_inference, _sfdp_pattern_14_half_training, _sfdp_pattern_14_half_inference) +from ._sfdp_pattern_15 import (_sfdp_pattern_15_training, _sfdp_pattern_15_inference, _sfdp_pattern_15_half_training, _sfdp_pattern_15_half_inference) +from ._sfdp_pattern_16 import (_sfdp_pattern_16_training, _sfdp_pattern_16_inference, _sfdp_pattern_16_bs1_training, _sfdp_pattern_16_bs1_inference, _sfdp_pattern_16_half_training, _sfdp_pattern_16_half_inference, _sfdp_pattern_16_half_bs1_training, _sfdp_pattern_16_half_bs1_inference, _sfdp_pattern_16_half_mask_fp32_training, _sfdp_pattern_16_half_mask_fp32_inference, _sfdp_pattern_16_half_mask_fp32_bs1_training, _sfdp_pattern_16_half_mask_fp32_bs1_inference) +from ._sfdp_pattern_17 import (_sfdp_pattern_17_training, _sfdp_pattern_17_inference, _sfdp_pattern_17_half_training, _sfdp_pattern_17_half_inference) + +central_index = { + '_sfdp_pattern_1_training': _sfdp_pattern_1_training, + '_sfdp_pattern_1_inference': _sfdp_pattern_1_inference, + '_sfdp_pattern_2_training': _sfdp_pattern_2_training, + '_sfdp_pattern_2_inference': _sfdp_pattern_2_inference, + '_sfdp_pattern_3_training': _sfdp_pattern_3_training, + '_sfdp_pattern_3_inference': _sfdp_pattern_3_inference, + '_sfdp_pattern_4_training': _sfdp_pattern_4_training, + '_sfdp_pattern_4_inference': _sfdp_pattern_4_inference, + '_sfdp_pattern_5_training': _sfdp_pattern_5_training, + 
'_sfdp_pattern_5_inference': _sfdp_pattern_5_inference, + '_sfdp_pattern_6_training': _sfdp_pattern_6_training, + '_sfdp_pattern_6_inference': _sfdp_pattern_6_inference, + '_sfdp_pattern_7_training': _sfdp_pattern_7_training, + '_sfdp_pattern_7_inference': _sfdp_pattern_7_inference, + '_sfdp_pattern_8_training': _sfdp_pattern_8_training, + '_sfdp_pattern_8_inference': _sfdp_pattern_8_inference, + '_sfdp_pattern_9_training': _sfdp_pattern_9_training, + '_sfdp_pattern_9_inference': _sfdp_pattern_9_inference, + '_sfdp_pattern_10_training': _sfdp_pattern_10_training, + '_sfdp_pattern_10_inference': _sfdp_pattern_10_inference, + '_sfdp_pattern_11_training': _sfdp_pattern_11_training, + '_sfdp_pattern_11_inference': _sfdp_pattern_11_inference, + '_sfdp_pattern_12_training': _sfdp_pattern_12_training, + '_sfdp_pattern_12_inference': _sfdp_pattern_12_inference, + '_sfdp_pattern_13_training': _sfdp_pattern_13_training, + '_sfdp_pattern_13_inference': _sfdp_pattern_13_inference, + '_sfdp_pattern_14_training': _sfdp_pattern_14_training, + '_sfdp_pattern_14_inference': _sfdp_pattern_14_inference, + '_sfdp_pattern_15_training': _sfdp_pattern_15_training, + '_sfdp_pattern_15_inference': _sfdp_pattern_15_inference, + '_sfdp_pattern_16_training': _sfdp_pattern_16_training, + '_sfdp_pattern_16_inference': _sfdp_pattern_16_inference, + '_sfdp_pattern_16_bs1_training': _sfdp_pattern_16_bs1_training, + '_sfdp_pattern_16_bs1_inference': _sfdp_pattern_16_bs1_inference, + '_sfdp_pattern_17_training': _sfdp_pattern_17_training, + '_sfdp_pattern_17_inference': _sfdp_pattern_17_inference, + '_sfdp_pattern_1_half_training': _sfdp_pattern_1_half_training, + '_sfdp_pattern_1_half_inference': _sfdp_pattern_1_half_inference, + '_sfdp_pattern_2_half_training': _sfdp_pattern_2_half_training, + '_sfdp_pattern_2_half_inference': _sfdp_pattern_2_half_inference, + '_sfdp_pattern_3_half_training': _sfdp_pattern_3_half_training, + '_sfdp_pattern_3_half_inference': _sfdp_pattern_3_half_inference, + '_sfdp_pattern_4_half_training': _sfdp_pattern_4_half_training, + '_sfdp_pattern_4_half_inference': _sfdp_pattern_4_half_inference, + '_sfdp_pattern_5_half_training': _sfdp_pattern_5_half_training, + '_sfdp_pattern_5_half_inference': _sfdp_pattern_5_half_inference, + '_sfdp_pattern_6_half_training': _sfdp_pattern_6_half_training, + '_sfdp_pattern_6_half_inference': _sfdp_pattern_6_half_inference, + '_sfdp_pattern_7_half_training': _sfdp_pattern_7_half_training, + '_sfdp_pattern_7_half_inference': _sfdp_pattern_7_half_inference, + '_sfdp_pattern_8_half_training': _sfdp_pattern_8_half_training, + '_sfdp_pattern_8_half_inference': _sfdp_pattern_8_half_inference, + '_sfdp_pattern_9_half_training': _sfdp_pattern_9_half_training, + '_sfdp_pattern_9_half_inference': _sfdp_pattern_9_half_inference, + '_sfdp_pattern_10_half_training': _sfdp_pattern_10_half_training, + '_sfdp_pattern_10_half_inference': _sfdp_pattern_10_half_inference, + '_sfdp_pattern_11_half_training': _sfdp_pattern_11_half_training, + '_sfdp_pattern_11_half_inference': _sfdp_pattern_11_half_inference, + '_sfdp_pattern_12_half_training': _sfdp_pattern_12_half_training, + '_sfdp_pattern_12_half_inference': _sfdp_pattern_12_half_inference, + '_sfdp_pattern_13_half_training': _sfdp_pattern_13_half_training, + '_sfdp_pattern_13_half_inference': _sfdp_pattern_13_half_inference, + '_sfdp_pattern_14_half_training': _sfdp_pattern_14_half_training, + '_sfdp_pattern_14_half_inference': _sfdp_pattern_14_half_inference, + '_sfdp_pattern_15_half_training': _sfdp_pattern_15_half_training, 
+ '_sfdp_pattern_15_half_inference': _sfdp_pattern_15_half_inference, + '_sfdp_pattern_16_half_training': _sfdp_pattern_16_half_training, + '_sfdp_pattern_16_half_inference': _sfdp_pattern_16_half_inference, + '_sfdp_pattern_16_half_bs1_training': _sfdp_pattern_16_half_bs1_training, + '_sfdp_pattern_16_half_bs1_inference': _sfdp_pattern_16_half_bs1_inference, + '_sfdp_pattern_17_half_training': _sfdp_pattern_17_half_training, + '_sfdp_pattern_17_half_inference': _sfdp_pattern_17_half_inference, + '_sfdp_pattern_16_half_mask_fp32_training': _sfdp_pattern_16_half_mask_fp32_training, + '_sfdp_pattern_16_half_mask_fp32_inference': _sfdp_pattern_16_half_mask_fp32_inference, + '_sfdp_pattern_16_half_mask_fp32_bs1_training': _sfdp_pattern_16_half_mask_fp32_bs1_training, + '_sfdp_pattern_16_half_mask_fp32_bs1_inference': _sfdp_pattern_16_half_mask_fp32_bs1_inference, +} + + +def get_serialized_pattern(key): + import torch._inductor # noqa: F401 + from torch._inductor import config + if config.fallback_random: + return None + + # TODO - could add more validation that the same set of decomps used when + # tracing SDPA are also used in current context. softmax, dropout, etc + # decomp use is stable so not an issue in practice. + return central_index.get(key) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..59a967e93463c998c617e9ade0035f694d8e7b48 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py @@ -0,0 +1,1537 @@ +import itertools +import logging +import operator +from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union + +from typing_extensions import TypeAlias + +import torch +from torch._dynamo.utils import counters + +from ..pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethodVarArgs, + config_flag, + FailedMatch, + get_arg_value, + Ignored, + KeywordArg, + ListOf, + Match, + MatchContext, + MULTIPLE, + PatternExpr, + register_graph_pattern, + RepeatedExpr, +) +from .group_batch_fusion import is_node_meta_valid +from .pre_grad import ( + merge_getitem_cat_pass, + merge_splits_pass, + normalization_pass, + split_cat_pass, + unbind_stack_pass, +) + +log = logging.getLogger(__name__) + +_Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...] +_TransformParam: TypeAlias = Tuple[ + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], +] +_Range: TypeAlias = Tuple[int, int] + + +def _get_split_args_default(split_node): + input_kwarg = "tensor" + split_size_kwarg = "split_size_or_sections" + dim_kwarg = "dim" + default_dim_value = 0 + if split_node.op == "call_method": + split_size_kwarg = "split_size" + return ( + get_arg_value(split_node, 0, input_kwarg), + get_arg_value(split_node, 1, split_size_kwarg), + get_arg_value(split_node, 2, dim_kwarg) or default_dim_value, + ) + + +# noqa: W605 +# ############The pattern to be optimized is######### +# unbind (dim=0) +# / ... \ +# getitem getitem -> user=1 +# | | +# split split -> dim=1, user=1, split_section_size=1 +# | | +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + +# ################After transformation############# +# unbind (dim=0) +# / ... 
\ +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + + +def remove_split_with_size_one( + graph: torch.fx.Graph, + node: torch.fx.Node, + input: torch.fx.Node, +): + # find the grand children of the split_node + next_users = find_next_users(node) + user = next(iter(node.users.keys())) + # replace the users of grand child node with the input node + for next_user in next_users: + next_user.replace_input_with(user, input) + # erase the split node and its child + graph.erase_node(user) + graph.erase_node(node) + + counters["inductor"]["remove_split_with_size_one"] += 1 + + +def normalize_split_base( + match: Match, + _get_split_args: Callable[ + [torch.fx.Node], Tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]] + ], +): + """ + Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in + subsequent optimizations + """ + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if "example_value" not in split_node.meta: + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + # remove the dummy split whose split sections size is one + if len(split_sections) == 1: + remove_split_with_size_one(graph, split_node, split_input) + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["example_value"].dim() + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["split_cat_norm"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_split_default(match: Match, *args, **kwargs): + return normalize_split_base(match, _get_split_args_default) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.unbind, users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallMethodVarArgs("unbind", users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_unbind_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + input = get_arg_value(node, 0, "input") + dim = get_arg_value(node, 1, "dim") + if dim is None: + axis = node.kwargs.get("axis") + if axis is not None: + dim = axis + else: + dim = 0 + if input is None: + log.debug("couldn't find unbind args") + return + if "example_value" not in input.meta: + log.debug("example value absent for node: %s", input) + return + ndim = input.meta["example_value"].ndim + if dim < 0: # Normalize unbind dim + dim += ndim + with graph.inserting_after(node): + 
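# A minimal eager-mode sketch of the normalization performed above: a split given an integer
# split_size is rewritten into the explicit per-section form, so later passes only ever see one
# flavor of torch.split. Standalone illustration, not part of the pass itself.
import torch

x = torch.randn(6, 4)
by_size = torch.split(x, 2, dim=0)               # split_size form: three chunks of two rows
by_sections = torch.split(x, [2, 2, 2], dim=0)   # normalized "sections" form
assert all(torch.equal(a, b) for a, b in zip(by_size, by_sections))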
new_node = graph.call_function( + torch.unbind, + args=(input,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["split_cat_norm"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.cat, users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_cat_default(match: Match, *args, **kwargs): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if "example_value" not in tensor.meta: + log.debug("example value absent for node: %s", tensor) + return + + ndim = cat_node.meta["example_value"].dim() + + def is_empty_tensor(x): + # special case where torch.cat supports cat'ing with an empty tensor + x_shape = x.meta["example_value"].shape + return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0) + + assert all( + ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors + ) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=(tensors,), + kwargs={"dim": cat_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["split_cat_norm"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.stack, users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_stack_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(node, 0, "tensors") + dim = get_arg_value(node, 1, "dim") or 0 + if tensors is None or dim is None: + log.debug("couldn't find stack args") + return + assert isinstance(tensors, (list, tuple)) + + # A bug in pytorch, some nodes miss the example_value metadata + for tensor in itertools.chain([node], tensors): + if "example_value" not in tensor.meta: + log.debug("example value absent for node: %s", tensor) + return + + ndim = node.meta["example_value"].dim() + if dim < 0: # Normalize dim + dim += ndim + + with graph.inserting_after(node): + new_node = graph.call_function( + node.target, + args=(tensors,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["split_cat_norm"] += 1 + + +def find_next_users(split_node: torch.fx.Node) -> List[torch.fx.Node]: + next_users = [] + for getitem_node in split_node.users.keys(): + for getitem_user in getitem_node.users.keys(): + if getitem_user not in next_users: + next_users.append(getitem_user) + return next_users + + +@register_graph_pattern( + CallMethodVarArgs("squeeze", users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_squeeze_default(match: Match, *args, **kwargs): + squeeze_node = match.nodes[0] + squeeze_input = get_arg_value(squeeze_node, 0) + + if "dim" in squeeze_node.kwargs: + assert 
len(squeeze_node.args) == 1 + dim = squeeze_node.kwargs["dim"] + elif len(squeeze_node.args) == 1: + # squeeze(Tensor) + dim = None + elif len(squeeze_node.args) == 2: + # squeeze(Tensor self, int dim) + # squeeze(Tensor self, int[] dim) + dim = squeeze_node.args[1] + else: + # squeeze(Tensor self, int[] dim) (called with varargs) + dim = squeeze_node.args[1:] + + if isinstance(dim, Sequence) and len(dim) == 1: + dim = dim[0] + + with match.graph.inserting_after(squeeze_node): + if dim is None: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,) + ) + else: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim} + ) + squeeze_node.replace_all_uses_with(new_squeeze_node) + match.graph.erase_node(squeeze_node) + + +class TorchSplit(CallFunction): + """ + Matches a call to torch.split if it is in a normalized form. Ensures that all users of + splits are unique getitems. + """ + + def __init__(self, arg, sizes, func=torch.split): + # using KeywordArg("dim") for `dim` checks they all match + super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim")) + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = super()._match(node, ctx) + if not m: + return m + split_sections = node.args[1] + if not isinstance(split_sections, (list, tuple)): + return FailedMatch("split not normalized") + # check users are all unique getitems + seen_idxs = set() + for user in node.users: + if not CallFunction(operator.getitem, Arg(), Arg()).match(user): + # This should ideally never happen. Split user should always be a getitem + return FailedMatch(f"user of split not a getitem: {user}") + if not isinstance(user.args[1], int): + return FailedMatch("only integer getitems are handled") + if user.args[1] in seen_idxs: + return FailedMatch(f"duplicate getitem {user.args[1]}") + if user.args[-1] < 0: # type: ignore[operator] + # This shouldn't ideally happen as dynamo normalizes indexes to positive + return FailedMatch("negative index") + seen_idxs.add(user.args[1]) + return m + + +@register_graph_pattern( + TorchSplit( + CallFunction( + operator.getitem, + TorchSplit( + KeywordArg("first_split_input"), + KeywordArg("first_split_sections"), + ), + Ignored(), + ), + KeywordArg("next_split_sections"), + ), + pass_dict=merge_splits_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_splits( + match: Match, + first_split_input: torch.fx.Node, + first_split_sections: List[int], + next_split_sections: List[int], + # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim + dim: int, +): + node = match.output_node() + # it is possible that the split has no users, + # we check the corner case and skip the pattern + if len(node.users.keys()) == 0: + return + graph = match.graph + first_split = node.args[0].args[0] # type: ignore[union-attr] + next_split_index = node.args[0].args[1] # type: ignore[union-attr] + + new_split_sections = list(first_split_sections) + new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc] + + first_split_dim = first_split.kwargs["dim"] # type: ignore[union-attr] + + to_remove = [] + + with graph.inserting_before(first_split): + # Add the new split node + new_split = graph.call_function( + torch.split, + args=(first_split_input, new_split_sections), + kwargs={"dim": first_split_dim}, + ) + first_split_num_to_user = { + user.args[1]: user for user in first_split.users.keys() # type: 
ignore[union-attr] + } + + new_split_num = 0 + for split_num in range(len(first_split_sections)): + if split_num not in first_split_num_to_user: + new_split_num += 1 + continue + old_getitem = first_split_num_to_user[split_num] + if split_num != next_split_index: + old_getitem.update_arg(0, new_split) + old_getitem.update_arg(1, new_split_num) + new_split_num += 1 + else: + next_split_num_to_user = { + user.args[1]: user for user in node.users.keys() + } + # It is not necessary all getitems from the split node are used. + # We use the num of users to check the getitems to be merged. + for next_split_num in range(len(node.users.keys())): + with graph.inserting_after(new_split): + new_getitem = graph.call_function( + operator.getitem, args=(new_split, new_split_num) + ) + new_split_num += 1 + next_getitem = next_split_num_to_user[next_split_num] + new_getitem.meta.update(next_getitem.meta) + next_getitem.replace_all_uses_with(new_getitem) + to_remove.append(next_getitem) + to_remove.append(node) + to_remove.append(old_getitem) + + to_remove.append(first_split) # type: ignore[arg-type] + for node in to_remove: + graph.erase_node(node) + + counters["inductor"]["consecutive_split_merged"] += 1 + + +class SplitCatSimplifier: + """ + Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat" + pattern. However, there are various cases where they can't and we need to simplify split/ add transforms before cat. + Some such cases are: + 1. Final node has additional args (not coming from the initial split) + 2. Shuffling of args between split/cat + 3. Some final nodes are non-(cat/stack) + 4. Split-dim != cat-dim (but equal split) + + Note that any combination of the above cases can happen. + + To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged. + Then, we simplify the split accordingly. In the best case, split can be entirely removed. + + To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`). + + Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added. + + """ + + def simplify( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: List[int], + ): + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by + # a tuple indicating the split ranges. See `get_user_input_list` for more details + user_inputs_list = self.get_user_input_list(split_node, next_users) + # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and + # we can simply replace the split node. Otherwise, we simplify it. 
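# Sketch of the equivalence behind merge_splits, checked in plain eager PyTorch: re-splitting one
# section of a split is the same as a single split whose sections list has that entry expanded in
# place. Illustration only.
import torch

x = torch.randn(10)
first, second = torch.split(x, [4, 6])
second_a, second_b = torch.split(second, [2, 4])
merged = torch.split(x, [4, 2, 4])               # [4, 6] with the 6 expanded into [2, 4]
assert torch.equal(merged[1], second_a) and torch.equal(merged[2], second_b)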
+ simplified_split_ranges = self.get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges: # Simplification not possible + return + transform_params_list = self.get_transform_params( + split_node, next_users, user_inputs_list + ) + if not transform_params_list: + return + + # Start actual replacement + user_inputs_list_new = self.replace_split( + graph, split_node, split_sections, user_inputs_list, simplified_split_ranges + ) + self.replace_cat( + graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type] + ) + self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] + + def get_user_input_list( + self, split_node: torch.fx.Node, next_users: List[torch.fx.Node] + ) -> List[List[Union[torch.fx.Node, _Range]]]: + """ + Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner + list represents the inputs to that particular node. This list can either contain + - a tuple representing the ranges of get_items that should go into the cat (closed interval) + - torch.fx.Node representing "other" inputs (which are not coming from our split) + """ + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]] = [] + for user in next_users: + if user.target in {torch.cat, torch.stack}: + user_inputs_list.append(self.get_merged_user_inputs(split_node, user)) + else: + user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type] + return user_inputs_list + + def get_merged_user_inputs( + self, split_node: torch.fx.Node, cat_node: torch.fx.Node + ) -> List[Union[torch.fx.Node, _Range]]: + user_inputs = get_arg_value(cat_node, 0, "tensors") + simplified_user_inputs = [] + split_users = set(split_node.users.keys()) + for user_input in user_inputs: + if user_input not in split_users: + simplified_user_inputs.append(user_input) + else: + # Add which "getitem" cat depends on + simplified_user_inputs.append(user_input.args[1]) + return self.merge_consecutive_inputs(simplified_user_inputs) + + def get_non_cat_node_input( + self, split_node: torch.fx.Node, node: torch.fx.Node + ) -> List[_Range]: + """ + Get input for a non cat node in the same format as `get_merged_user_inputs` + """ + node_input = [] + split_users = set(split_node.users.keys()) + for node_arg in node.all_input_nodes: + if node_arg in split_users: + getitem_num = get_arg_value(node_arg, 1) + node_input.append((getitem_num, getitem_num)) + return node_input + + def merge_consecutive_inputs( + self, inputs: List[Union[torch.fx.Node, int]] + ) -> List[Union[torch.fx.Node, _Range]]: + """ + Merge consecutive inputs going into a user node. + + For e.g. 
+ [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1] + """ + merged_ranges = [] + cur_range = None + for input_ in inputs: + if isinstance(input_, int): + if not cur_range: + cur_range = [input_, input_] + elif input_ == cur_range[1] + 1: + cur_range[1] += 1 + else: + merged_ranges.append(tuple(cur_range)) + cur_range = [input_, input_] + else: + if cur_range: + merged_ranges.append(tuple(cur_range)) + cur_range = None + merged_ranges.append(input_) # type: ignore[arg-type] + if cur_range: + merged_ranges.append(tuple(cur_range)) + return merged_ranges # type: ignore[return-value] + + def get_simplified_split_ranges( + self, + split_sections, + next_users, + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[_Range]]: + ranges = set() + for user_node, user_inputs in zip(next_users, user_inputs_list): + ranges |= { + user_input + for user_input in user_inputs + if isinstance(user_input, tuple) + } + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + split_ranges = sorted( + [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges] + ) + + if not self.has_non_overlapping_ranges( + split_ranges, + ): # This need not be a strict condition + # However, we keep it now for simplicity. + return None + split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1]) + if len(split_sections) == len(split_ranges): # Simplification not possible + return None + counters["inductor"]["scmerge_split_sections_removed"] = len( + split_sections + ) - len(split_ranges) + return split_ranges + + def has_non_overlapping_ranges(self, ranges: List[_Range]) -> bool: + for range_, next_range in zip(ranges, ranges[1:]): + if range_[1] > next_range[0]: + return False + return True + + def fill_gaps(self, ranges: List[_Range], min_: int, max_: int) -> List[_Range]: + cur = min_ + filled_ranges = [] + for a, b in ranges: + if cur < a: + filled_ranges.append((cur, a)) + filled_ranges.append((a, b)) + cur = b + if filled_ranges[-1][1] < max_: + filled_ranges.append((filled_ranges[-1][1], max_)) + return filled_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[List[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. 
+ + We replace a split node with an unflatten followed by a movedim + """ + split_dim = split_node.kwargs["dim"] + split_sections = split_node.args[1] + transform_params_list: List[List[_TransformParam]] = [] + + for user_node, user_inputs in zip(next_users, user_inputs_list): + if user_node.target not in {torch.cat, torch.stack}: + transform_params_list.append([]) + continue + + cat_dim = get_arg_value(user_node, 1, "dim") + transform_params: List[_TransformParam] = [] + for user_input in user_inputs: + if split_dim == cat_dim and user_node.target == torch.cat: + # No transform needed + transform_params.append((None, None, None, None)) + elif isinstance(user_input, tuple): # Split being simplified + # Verify equal split + subset_split_sections = split_sections[ # type: ignore[index] + user_input[0] : user_input[1] + 1 + ] + # All sections should be equal + if len(set(subset_split_sections)) != 1: + return None + + num_splits = len(subset_split_sections) + unflatten_params = (split_dim, (num_splits, -1)) + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + transform_params.append( + (unflatten_params, movedim_params, None, None) + ) + elif ( + user_node.target == torch.stack or split_dim != cat_dim + ): # We need to unsqueeze inputs not coming through split + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-split inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + def replace_split( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: List[int], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + split_ranges: List[_Range], + ) -> List[List[torch.fx.Node]]: + """ + Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it + into a split with lesser sections if len(split_ranges) > 1. + + Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node. 
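# Standalone check of the transform described above for the split_dim != cat_dim case with an
# equal split: unflatten + movedim (plus the trailing flatten added in replace_cat) reproduces the
# original split -> cat. Illustration only.
import torch

x = torch.randn(2, 6)
split_dim, cat_dim, num_splits = 1, 0, 3
ref = torch.cat(torch.split(x, 2, dim=split_dim), dim=cat_dim)
alt = torch.unflatten(x, split_dim, (num_splits, -1))   # insert the "group" dim at split_dim
alt = torch.movedim(alt, split_dim, cat_dim)            # move the group dim to cat_dim
alt = torch.flatten(alt, cat_dim, cat_dim + 1)          # the flatten replace_cat appends
assert torch.equal(ref, alt)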
+ """ + split_input = split_node.args[0] + split_dim = split_node.kwargs["dim"] + if len(split_ranges) == 1: # We can completely eliminate the split node + split_items = [split_input] + else: + with graph.inserting_after(split_node): + new_split = graph.call_function( + torch.split, + args=( + split_input, + [r[1] - r[0] for r in split_ranges], + ), + kwargs={"dim": split_dim}, + ) + new_split.meta.update(split_node.meta) + counters["inductor"]["scmerge_split_added"] += 1 + with graph.inserting_after(new_split): + split_items = [ + graph.call_function(operator.getitem, args=(new_split, i)) + for i in range(len(split_ranges)) + ] + # Now assign the right getitem to the right input + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + new_user_inputs_list = [] + for user_inputs in user_inputs_list: + new_user_inputs = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # Find the correct new getitem (present in split_items) + new_user_inputs.append( + split_items[ + split_ranges.index( + ( + cumulative_sizes[user_input[0]], + cumulative_sizes[user_input[1] + 1], + ) + ) + ] + ) + else: + new_user_inputs.append(user_input) + new_user_inputs_list.append(new_user_inputs) + return new_user_inputs_list # type: ignore[return-value] + + def replace_cat( + self, + graph: torch.fx.GraphModule, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list_new, + transform_params_list: List[List[_TransformParam]], + ): + split_dim = split_node.kwargs["dim"] + + split_users = split_node.users.keys() + new_cats = [] + for user_node, user_inputs_new, transform_params in zip( + next_users, user_inputs_list_new, transform_params_list + ): + if user_node.target not in {torch.cat, torch.stack}: + # Change the args and kwargs of non-cat/stack nodes. 
Replace old getitems (belonging to + # the original split node) with the newer getitems + next_cat_input = 0 + for input_node in user_node.all_input_nodes: + if input_node in split_users: + user_node.replace_input_with( + input_node, user_inputs_new[next_cat_input] + ) + next_cat_input += 1 + continue + + # Handle cat/stack user nodes + cat_dim = get_arg_value(user_node, 1, "dim") + user_inputs_new_transformed = [] + # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them + to_stack = [] + stack_dim = None + with graph.inserting_before(user_node): + for user_input_new, transform_param in zip( + user_inputs_new, transform_params + ): + # Apply transforms + ( + unflatten_params, + movedim_params, + unsqueeze_params, + flatten_params, + ) = transform_param + if unsqueeze_params and ( + stack_dim is None or stack_dim == unsqueeze_params[0] + ): + to_stack.append(user_input_new) + stack_dim = unsqueeze_params[0] + continue + elif to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + to_stack = [] + stack_dim = None + user_inputs_new_transformed.append(stacked_input) + if unsqueeze_params: + to_stack.append(user_input_new) + stack_dim = unsqueeze_params[0] + continue + + if unflatten_params: + user_input_new = graph.call_function( + torch.unflatten, args=(user_input_new, *unflatten_params) + ) + if movedim_params: + user_input_new = graph.call_function( + torch.movedim, args=(user_input_new, *movedim_params) + ) + if flatten_params: + user_input_new = graph.call_function( + torch.flatten, args=(user_input_new, *flatten_params) + ) + user_inputs_new_transformed.append(user_input_new) + if to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + user_inputs_new_transformed.append(stacked_input) + + with graph.inserting_after(user_node): + if len(user_inputs_new_transformed) > 1: + new_cat_node = graph.call_function( + torch.cat, + args=(user_inputs_new_transformed,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta.update(user_node.meta) + counters["inductor"]["scmerge_cat_added"] += 1 + else: + new_cat_node = user_inputs_new_transformed[-1] + + if ( + user_node.target == torch.cat + and split_dim != cat_dim + and split_node.target == torch.split + ): + with graph.inserting_after(new_cat_node): + new_cat_node = graph.call_function( + torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) + ) + user_node.replace_all_uses_with(new_cat_node) + new_cats.append(new_cat_node) + + def erase_old_nodes( + self, + graph: torch.fx.GraphModule, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + ): + to_remove = [split_node] + counters["inductor"]["scmerge_split_removed"] += 1 + to_remove.extend(split_node.users.keys()) + for next_user in next_users: + if next_user.target not in {torch.cat, torch.stack}: + continue + counters["inductor"]["scmerge_cat_removed"] += 1 + to_remove.append(next_user) + for node in reversed(to_remove): + graph.erase_node(node) + + +class UnbindCatRemover(SplitCatSimplifier): + """ + Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier. + + Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this, + other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`, + hence we extend that class. 
+ """ + + def remove_unbind( + self, + graph: torch.fx.Graph, + unbind_node: torch.fx.Node, + ): + num_unbind = ( # type: ignore[operator] + max(getitem_node.args[1] for getitem_node in unbind_node.users.keys()) + 1 # type: ignore[operator, union-attr, type-var] + ) + split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] + + super().simplify(graph, unbind_node, split_sections) + + def get_simplified_split_ranges( + self, + split_sections: List[int], + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[_Range]]: + simplified_split_ranges = super().get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges or len(simplified_split_ranges) != 1: + return None + return simplified_split_ranges + + def get_transform_params( + self, + unbind_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[List[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. + + Here is the rough transforms we apply: + + x -> unbind -> stack => x -> movedim + + x -> unbind -> cat => x -> movedim -> flatten + + When cat/stack nodes have additional args: + + addn ---| addn -> unsqueeze ---| + x -> unbind -> stack => x -> movedim -> cat + + addn ---| addn ---| + x -> unbind -> cat => x -> movedim -> flatten -> cat + + (Note application of these depends on the dims as well) + + + """ + split_dim = unbind_node.kwargs["dim"] + transform_params_list: List[List[_TransformParam]] = [] + for user_node, user_inputs in zip(next_users, user_inputs_list): + cat_dim = get_arg_value(user_node, 1, "dim") or 0 + transform_params: List[_TransformParam] = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # User input is coming from unbind + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + flatten_params = None + if user_node.target == torch.cat: + flatten_params = (cat_dim, cat_dim + 1) + transform_params.append( + (None, movedim_params, None, flatten_params) + ) + elif ( + user_node.target == torch.stack + ): # We need to unsqueeze inputs not coming through unbind into cat + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-unbind inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + +class GetItem(CallFunction): + def __init__(self, arg, index, _users=1): + super().__init__(operator.getitem, arg, index, _users=_users) + + def find_anchor_nodes(self, ctx: MatchContext, searched: Set[torch.fx.Node]): + # We generally match GetItem with arg being an Arg(). So, we never return the anchor + # nodes as the stored node in ctx.pattern_to_node is returned. 
Here we override find_anchor_nodes + # to not use ctx.pattern_to_node + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ) + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_split_squeeze( + match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int +): + graph = match.graph + split = next(node for node in match.nodes if node.target == torch.split) + if not all(s == 1 for s in split_sizes): + return + if isinstance(dim, Sequence): + return + next_users = find_next_users(split) + if not all(node.target == torch.squeeze for node in next_users): + return + with graph.inserting_before(match.output_node()): + unbind = graph.call_function( + torch.unbind, args=(split_input,), kwargs={"dim": dim} + ) + for item_index, getitem_node in sorted( + [ + (getitem_node.args[1], getitem_node) + for getitem_node in split.users.keys() + ] + ): + squeeze = next(iter(getitem_node.users.keys())) + new_get_item = graph.call_function( + operator.getitem, args=(unbind, item_index) + ) + squeeze.replace_all_uses_with(new_get_item) + new_get_item.meta.update(squeeze.meta) + graph.erase_node(squeeze) + graph.erase_node(getitem_node) + graph.erase_node(split) + counters["inductor"]["split_squeeze_replaced"] += 1 + + +getitem_unbind = ListOf( + GetItem( + CallFunction( + torch.unbind, + KeywordArg("unbind_input"), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE), + pass_dict=unbind_stack_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=unbind_stack_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=unbind_stack_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + UnbindCatRemover().remove_unbind(match.graph, unbind_node) + + +getitem_split = ListOf( + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + tensors=getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) 
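# The fact merge_split_squeeze relies on, checked eagerly: a size-1 split along a dim followed by
# squeezing that dim gives exactly torch.unbind along the same dim. Illustration only.
import torch

x = torch.randn(3, 5)
squeezed = [t.squeeze(0) for t in torch.split(x, 1, dim=0)]
unbound = torch.unbind(x, dim=0)
assert all(torch.equal(a, b) for a, b in zip(squeezed, unbound))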
+@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + Ignored(), + _users=MULTIPLE, + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def simplify_split_cat(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + SplitCatSimplifier().simplify(match.graph, split_node, split_sections) + + +# noqa: W605 +# ############pattern to be optimized is######### + +# split_node(dim=1) +# / \ ... / \ +# getitem getitem getitem getitem -> user=1 +# \ / \ / +# cat (user=mul, dim=1) cat(user=mul, dim=1) +# | \ | \ + +# ################after transformation############# + +# split_node(dim=1) +# / ... \ +# getitem getitem +# | \ | \ + + +def has_same_parent_node(node: torch.fx.Node): + # the input nodes of the node should come from the same parent + prev_node = None + for getitem in node.args[0]: # type: ignore[union-attr] + if getitem.target != operator.getitem: # type: ignore[union-attr] + return False + if prev_node is None: + prev_node = getitem.args[0] # type: ignore[union-attr] + else: + if getitem.args[0] != prev_node: + return False + return True + + +def remove_zeros(split_sections: List[int]): + """ + Remove zeros from the list and get the index mapping dict from getitem + in split node to getitem in new split node + """ + new_split_sections, index_mapping = [], {} + idx = 0 + for i in range(len(split_sections)): + if split_sections[i] > 0: + new_split_sections.append(split_sections[i]) + index_mapping[i] = idx + idx += 1 + + return new_split_sections, index_mapping + + +def is_sorted_and_consecutive(arr: List[int]) -> bool: + # check if the array is sorted + if arr == sorted(arr): + # check if the differences between adjacent elements are all 1 + return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:])) + else: + return False + + +def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: List[int]) -> int: + """ + Calculate the fused tensor size in the indices + """ + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + return fused_tensor_size + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=merge_getitem_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_getitem_cat(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # 'immutable_list' object does not support mutation. 
Create a new copy of it + split_sections = list(split_sections) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") + # check the all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + # check all getitem only has one single user + if ( + split_dim != cat_dim + or not has_same_parent_node(cat_user) + or not all(len(arg.users) == 1 for arg in cat_user.args[0]) # type: ignore[union-attr] + ): + continue + # find the index of getitems to be cated/stacked + indices = [] + for arg in cat_user.args[0]: # type: ignore[union-attr] + indices.append(arg.args[1]) # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # update the arg of cat user, only keep the first getitem + cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index] + # calculate the fused tensor sizes in the indices + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + # update the split sections + split_sections[indices[0]] = calculate_fused_tensor_size( + split_node, indices + ) + # padding others with zeros to keep the same dict size + for i in indices[1:]: + split_sections[i] = 0 + # remove all unused indexes in the split_node + new_split_sections, index_mapping = remove_zeros(split_sections) + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + # remove all unused getitem nodes + to_remove = [cat_user] + # dictionary keys changed during iteration + new_split_getitem_nodes = list(new_split_node.users.keys()) + for getitem_node in new_split_getitem_nodes: + if getitem_node.args[1] in indices[1:]: + to_remove.append(getitem_node) + # update meta data of getitem + elif getitem_node.args[1] == indices[0]: + cat_user.replace_all_uses_with(getitem_node) + getitem_node.meta.update(cat_user.meta) + else: + # update getitem index for new split node + getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) + graph.erase_node(split_node) + for getitem_node in to_remove: + graph.erase_node(getitem_node) + # update the split sections of new split node + new_split_node.update_arg(1, new_split_sections) + split_node = new_split_node + split_sections = new_split_sections + + counters["inductor"]["getitem_cat_merged"] += 1 + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op /cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) -> -> user=multiple +# / \ ... 
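# The fact merge_getitem_cat exploits, as a standalone check: concatenating consecutive outputs of
# a split (along the split dim) equals a single output of a split with those sections fused.
# Illustration only.
import torch

x = torch.randn(10)
fine = torch.split(x, [2, 3, 5], dim=0)
coarse = torch.split(x, [2, 8], dim=0)           # sections 3 and 5 fused into 8
assert torch.equal(torch.cat(fine[1:], dim=0), coarse[1])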
/ \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def mutate_cat_node(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") or 0 + # check that all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + if split_dim != cat_dim or not has_same_parent_node(cat_user): + continue + # find the index of getitems to be cat + indices, idx_to_getitem = [], {} + for getitem in cat_user.args[0]: # type: ignore[union-attr] + indices.append(getitem.args[1]) # type: ignore[union-attr] + idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # case 1: the cat uses all getitems from the split + if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type] + # replace the users of the cat node to be the input of the split node + cat_user.replace_all_uses_with(split_node.args[0]) + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["cat_mutated"] += 1 + # case 2: the cat uses some getitems from the split + elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] + # check the split dim, and construct the slice tuple + start_fused_size = calculate_fused_tensor_size( + split_node, list(range(indices[0])) + ) + end_fused_size = start_fused_size + calculate_fused_tensor_size( + split_node, indices + ) + slice_list = [] + for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append(slice(start_fused_size, end_fused_size, None)) + with graph.inserting_after(split_node): + slice_node = graph.call_function( + operator.getitem, + args=(split_node.args[0], tuple(slice_list)), + ) + cat_user.replace_all_uses_with(slice_node) + slice_node.meta.update(cat_user.meta) + + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["cat_mutated"] += 1 + + +# noqa: W605 +# ############The pattern to be optimized is######### +# split_node (dim=1) +# / ... \ ... / \ +# getitem getitem getitem getitem -> user=1 +# \ / +# stack (dim=0) -> user=1, getitems to be consecutive +# | +# tahn -> user=1 +# | +# unbind (dim=0) +# | + +# ################After transformation############# +# split_node (dim=1) +# / ... 
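# Eager sketch of the two cases handled by mutate_cat_node: cat'ing every output of a split along
# the split dim reproduces the input, and cat'ing a consecutive subset is just a slice of the
# input. Illustration only.
import torch

x = torch.randn(4, 9)
parts = torch.split(x, [2, 3, 4], dim=1)
assert torch.equal(torch.cat(parts, dim=1), x)             # all sections: the cat is the input
assert torch.equal(torch.cat(parts[1:], dim=1), x[:, 2:])  # consecutive subset: a slice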
/ \ +# getitem getitem getitem -> user=1 +# | +# tahn +# | +# split +# | + + +@register_graph_pattern( + CallFunction( + torch.tanh, + CallFunction( + torch.stack, + getitem_split, + dim=Ignored(), + ), + ), + pass_dict=merge_getitem_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + torch.tanh, + CallFunction( + torch.stack, + tensors=getitem_split, + dim=Ignored(), + ), + ), + pass_dict=merge_getitem_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + torch.tanh, + CallFunction( + torch.stack, + getitem_split, + Ignored(), + ), + ), + pass_dict=merge_getitem_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # 'immutable_list' object does not support mutation. Create a new copy of it + split_sections = list(split_sections) + for user in next_users: + # stack user only has one user + if user.target == torch.stack: + stack_dim = get_arg_value(user, 1, "dim") or 0 + unbind_user = find_next_users(user)[0] + if unbind_user.target != torch.unbind: + continue + unbind_dim = get_arg_value(unbind_user, 1, "dim") or 0 + # stack and unbind should have the same dim + # check the all getitems in the user from the same node + # check all the getitems only has single user + if ( + stack_dim != unbind_dim + or not has_same_parent_node(user) + or not all(len(arg.users) == 1 for arg in user.args[0]) # type: ignore[union-attr] + ): + continue + # find the index of getitems to be stacked + indices = [] + split_sections_for_unbind = [] + for arg in user.args[0]: # type: ignore[union-attr] + indices.append(arg.args[1]) # type: ignore[union-attr] + split_sections_for_unbind.append(split_sections[arg.args[1]]) # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # update the arg of stack user, only keep the first getitem + user.update_arg(0, user.args[0][0]) # type: ignore[index] + # calculate the fused tensor sizes in the indices + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, index, assignment] + # update the split sections + split_sections[indices[0]] = calculate_fused_tensor_size( + split_node, indices + ) + # padding others with zeros to keep the same dict size + for i in indices[1:]: + split_sections[i] = 0 + # remove all unused indexes in the split_node + new_split_sections, index_mapping = remove_zeros(split_sections) + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + replace_unbind_with_split = graph.call_function( + torch.split, + args=(unbind_user.args[0], split_sections_for_unbind), + kwargs={"dim": split_dim}, + ) + unbind_user.replace_all_uses_with(replace_unbind_with_split) + replace_unbind_with_split.meta.update(unbind_user.meta) + # remove getitem and split, 
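# Why the stack -> tanh -> unbind rewrite above is sound, checked eagerly: an elementwise op
# applied to a stack and then unbound equals applying it to each input directly. Illustration only.
import torch

ts = list(torch.randn(3, 4).unbind(0))
fused = torch.unbind(torch.tanh(torch.stack(ts, dim=0)), dim=0)
direct = [torch.tanh(t) for t in ts]
assert all(torch.equal(a, b) for a, b in zip(fused, direct))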
stack + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + # remove all unused getitem nodes + to_remove = [unbind_user] + # dictionary keys changed during iteration + new_split_getitem_nodes = list(new_split_node.users.keys()) + for getitem_node in new_split_getitem_nodes: + if getitem_node.args[1] in indices[1:]: + to_remove.append(getitem_node) + # update meta data of getitem + elif getitem_node.args[1] == indices[0]: + user.replace_all_uses_with(getitem_node) + getitem_node.meta.update(user.meta) + else: + # update getitem index for new split node + getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) + graph.erase_node(split_node) + graph.erase_node(user) + for getitem_node in to_remove: + graph.erase_node(getitem_node) + # update the split sections of new split node + new_split_node.update_arg(1, new_split_sections) + split_node = new_split_node + split_sections = new_split_sections + + counters["inductor"]["stack_tahn_unbind_merged"] += 1 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..69f9807120ac7e3acdd4123d7917d038db79a526 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import logging +from typing import Optional, Sequence + +import torch +from torch import _prims, Tensor + +log = logging.getLogger(__name__) + + +def make_prim( + schema: str, + impl_aten, + return_type=_prims.RETURN_TYPE.NEW, + doc: str = "", + tags: Optional[Sequence[torch.Tag]] = None, +): + def meta(*args, **kwargs): + return _prims.TensorMeta(impl_aten(*args, **kwargs)) + + return _prims._make_prim( + schema=schema, + return_type=return_type, + meta=meta, + impl_aten=impl_aten, + doc=doc, + tags=tags, + ) + + +def eager_force_stride(input_tensor: Tensor, stride) -> Tensor: + if input_tensor.stride() == stride: + return input_tensor + new_tensor = input_tensor.clone().as_strided( + input_tensor.shape, + stride, + ) + new_tensor.copy_(input_tensor) + return new_tensor + + +# Custom prims used for handling randomness +seed = make_prim( + "inductor_seed(Device device) -> Tensor", + lambda device: torch.randint(2**63 - 1, [], device=device), + doc="create a fresh seed (one per call) for use with inductor_rand", + tags=(torch.Tag.nondeterministic_seeded,), +) +seeds = make_prim( + "inductor_seeds(int count, Device device) -> Tensor", + lambda count, device: torch.randint(2**63 - 1, [count], device=device), + doc="Horizontal fusion of many inductor_seed() calls", + tags=(torch.Tag.nondeterministic_seeded,), +) +lookup_seed = make_prim( + # if inductor_lookup_seed changes, update partitioners.py + "inductor_lookup_seed(Tensor seeds, int index) -> Tensor", + lambda seeds, index: seeds[index], + doc="Extract a single seed from the result of inductor_seeds()", +) +random = make_prim( + "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor", + lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device), + doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused", +) +randint = make_prim( + "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor", + lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device), + 
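# Eager reference behavior of the randomness prims defined above, taken from their impl_aten
# lambdas and shown standalone for illustration:
import torch

device = "cpu"
seeds = torch.randint(2**63 - 1, [4], device=device)   # what inductor_seeds(4, device) computes
one_seed = seeds[2]                                     # what inductor_lookup_seed(seeds, 2) extracts
sample = torch.rand([8, 8], device=one_seed.device)     # inductor_random(size, seed, "rand") in eager mode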
doc="torch.randint() using backend-specific RNG that can be fused", +) +force_stride_order = make_prim( + "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor", + eager_force_stride, + doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise", +) +masked_scatter_with_index = make_prim( + "inductor_masked_scatter_with_index(Tensor input, Tensor mask, Tensor source_idx, Tensor source) -> Tensor", + lambda input_tensor, mask, index, source: torch.masked_scatter( + input_tensor, mask, source + ), + doc="masked_scatter with precomputed indices", +) +_unsafe_index_put_ = make_prim( + "_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", + lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_( + self, indices, values, accumulate + ), + doc="Unsafe index_put_ (doesn't issue device asserts)", +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..9f575a9cfd2ea1a1801f20ae47d6859c18be0512 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py @@ -0,0 +1,6006 @@ +import functools +import itertools +import logging +import os +import warnings +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import sympy + +import torch +import torch.ao.quantization.fx._decomposed +import torch.fx +import torch.utils._pytree as pytree +from torch._higher_order_ops.triton_kernel_wrap import ( + triton_kernel_wrapper_functional, + triton_kernel_wrapper_mutation, +) +from torch._prims_common import ( + canonicalize_dim, + canonicalize_dims, + check, + dtype_to_type, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + get_computation_dtype, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, + Number, +) +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from .._dynamo.utils import import_submodule + +from . 
import config, inductor_prims, ir, test_operators # NOQA: F401 +from .decomposition import decompositions, get_decompositions +from .ir import ( + ExpandView, + IndexingConstant, + is_triton, + ops_wrapper, + PermuteView, + Pointwise, + Reduction, + SqueezeView, + TensorBox, + validate_ir, + View, +) +from .utils import ( + ceildiv, + decode_device, + is_dynamic, + is_pointwise_use, + pad_listlike, + parallel_num_threads, + sympy_product, +) +from .virtualized import ops, V + +log = logging.getLogger(__name__) +lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {} +layout_constraints: Dict[torch._ops.OpOverload, Callable[..., Any]] = {} +fallbacks: Set[torch._ops.OpOverload] = set() +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +needs_realized_inputs: Set[torch._ops.OpOverload] = set() +foreach_ops: Set[torch._ops.OpOverload] = set() +inplace_foreach_ops: Set[torch._ops.OpOverload] = set() +inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = dict() +quantized_decomposed = torch.ops.quantized_decomposed + + +def assert_nyi(cond, msg): + if not cond: + raise NotImplementedError(f"inductor does not support {msg}") + + +def add_needs_realized_inputs(fn): + if isinstance(fn, (list, tuple, set)): + return [add_needs_realized_inputs(x) for x in fn] + needs_realized_inputs.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + needs_realized_inputs.add(getattr(fn, overload)) + + +def add_layout_constraint(fn, constraint): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + layout_constraints[getattr(fn, overload)] = constraint + else: + layout_constraints[fn] = constraint + + +add_needs_realized_inputs( + [ + aten.as_strided, + aten.avg_pool2d, + aten.avg_pool2d_backward, + aten.bmm, + aten.convolution, + aten.convolution_backward, + aten.max_pool2d_with_indices, + aten.max_pool2d_with_indices_backward, + aten.mm, + aten.upsample_nearest2d, + aten._upsample_nearest_exact2d, + aten.upsample_bicubic2d, + aten._int_mm, + ] +) + +# TODO(jansel): ezyang says we won't need this in the future, try removing it +# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28 +DTYPE_ID_LOOKUP = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.int16, + 3: torch.int32, + 4: torch.int64, + 5: torch.float16, + 6: torch.float32, + 7: torch.float64, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex32, + 11: torch.bool, + 15: torch.bfloat16, + # TODO(jansel): add quantized types? 
+ # _(c10::qint8, QInt8) /* 12 */ + # _(c10::quint8, QUInt8) /* 13 */ + # _(c10::qint32, QInt32) /* 14 */ + # _(c10::quint4x2, QUInt4x2) /* 16 */ + # _(c10::quint2x4, QUInt2x4) /* 17 */ +} + + +def decode_dtype(dtype: int): + if not isinstance(dtype, int): + return dtype + assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP" + dtype = DTYPE_ID_LOOKUP[dtype] + return dtype + + +def is_integer_type(x): + if isinstance(x, TensorBox): + return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + elif isinstance(x, sympy.Expr): + return x.is_integer is True # type: ignore[attr-defined] + else: + return isinstance(x, int) + + +def is_boolean_type(x): + if isinstance(x, TensorBox): + return is_boolean_dtype(x.get_dtype()) + else: + return isinstance(x, bool) + + +def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): + def construct_input(inp): + if isinstance(inp, (Number, sympy.Expr)): + return inp + else: + assert hasattr(inp, "get_dtype") + dim = len(inp.get_size()) + # construct a tmp tensor to feed into torch.result_type + return torch.zeros([1] * dim, dtype=inp.get_dtype()) + + inps = [construct_input(arg) for arg in args] + _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind) + return dtype + + +def get_overloads(aten_fn): + if not isinstance(aten_fn, (list, tuple)): + aten_fn = [aten_fn] + else: + aten_fn = list(aten_fn) + + for fn in list(aten_fn): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + if other_fn not in lowerings: + aten_fn.append(other_fn) + + return aten_fn + + +def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool): + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + if (type_promotion_kind or convert_input_to_bool) and indices: + if convert_input_to_bool: + dtype = torch.bool + else: + # FIXME that's a crude approximation for promoting args + promoting_args = [ + a + for a in args + if isinstance(a, (Number, sympy.Expr)) or hasattr(a, "dtype") + ] + dtype = get_promoted_dtype( + *promoting_args, type_promotion_kind=type_promotion_kind + ) + + # sometimes args are an immutable list so we can't mutate them + def promote(arg): + if isinstance(arg, TensorBox): + return to_dtype(arg, dtype) + elif isinstance(arg, ir.Constant): + return ir.Constant(arg.value, dtype, args[indices[0]].get_device()) + else: + return arg + + args = [promote(a) for a in args] + if broadcast and indices: + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + + return args + + +def _register_foreach_lowering(aten_fn, decomp_fn): + """ + Add a foreach lowering to lowerings dict. 
+ + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + assert len(args) <= 2 + out = decomp_fn(*args, **kwargs) + validate_ir(out) + return out + + aten_fns = get_overloads(aten_fn) + foreach_ops.update(aten_fns) + lowerings.update(dict.fromkeys(aten_fns, wrapped)) + return wrapped + + +def _register_lowering( + aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool +): + """ + Add a lowering to lowerings dict + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + args: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]] = list(args) + unpacked = False + # TODO maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = args[0] + + # explicitly assert for "out=" ops for better error messages + assert not any( + x == "out" for x in kwargs.keys() + ), "out= ops aren't yet supported" + # kwargs tensors not supported yet unless it's a fallback op + assert not any(isinstance(x, TensorBox) for x in kwargs.values()) or all( + fn in fallbacks for fn in aten_fn + ) + + args = transform_args( + args, broadcast, type_promotion_kind, convert_input_to_bool + ) + + if unpacked: + args = [args] + + out = decomp_fn(*args, **kwargs) + validate_ir(out) + + return out + + aten_fn = get_overloads(aten_fn) + + lowerings.update(dict.fromkeys(aten_fn, wrapped)) + return wrapped + + +def register_lowering( + aten_fn, + broadcast=False, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, +): + """ + Shim to support decorator syntax. + """ + return functools.partial( + _register_lowering, + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + ) + + +def broadcast_symbolic_shapes(a, b): + """ + Broadcasting logic based on symbolic shapes. + + We give the shapes 0 and 1 concrete values, while all other shapes + are symbolic sympy formulas. 
+ """ + output = [] + for x, y in itertools.zip_longest( + reversed(a), reversed(b), fillvalue=sympy.Integer(1) + ): + if y == 1: + output.append(x) + elif x == 1: + output.append(y) + else: + V.graph.sizevars.guard_equals(x, y) + if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): + output.append(y) # prefer shorter formula + else: + output.append(x) + return tuple(reversed(output)) + + +def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): + assert ( + override_return_dtype is None or type_promotion_kind is None + ), "only one of override_return_dtype or type_promotion_kind may be given" + + if override_return_dtype is None and type_promotion_kind is None: + type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + + if not any(isinstance(x, (sympy.Expr, int, float)) for x in inputs): + return inputs + if all(isinstance(x, (int, float, sympy.Expr)) for x in inputs): + dtype = override_return_dtype or get_promoted_dtype( + *inputs, type_promotion_kind=type_promotion_kind + ) + + def const_func(x): + if isinstance(x, sympy.Expr): + return ir.IndexingConstant(x, dtype, decode_device(None)) + else: + return ir.Constant(x, dtype, decode_device(None)) + + return [const_func(x) for x in inputs] + ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView))) + out = [] + for x in inputs: + if isinstance(x, (int, float)): + out.append( + ExpandView.create( + ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) + ) + ) + elif isinstance(x, sympy.Expr): + out.append( + ExpandView.create( + IndexingConstant(x, ex.get_dtype(), ex.get_device()), + list(ex.get_size()), + ) + ) + else: + out.append(x) + + return out + + +def make_pointwise( + fn, + override_return_dtype=None, + override_device=None, + override_fn_when_input_bool=None, + override_fn_when_cuda_float64=None, + allow_alpha=False, + triton_fallback=None, +): + def inner(*inputs: List[TensorBox], alpha=None): + if triton_fallback is not None and any(map(is_triton, inputs)): + assert not allow_alpha # not implemented + return triton_fallback(*inputs) + + inputs = promote_constants(inputs, override_return_dtype) + if allow_alpha: + if alpha is not None and alpha != 1: + inputs = list(inputs) + inputs[-1] = mul(inputs[-1], alpha) + else: + assert alpha is None + loaders = [x.make_loader() for x in inputs] + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + is_cuda = decode_device(inputs[0].get_device()).type == "cuda" + + for other in inputs[1:]: + assert isinstance(other, ir.BaseConstant) or len(ranges) == len( + other.get_size() + ), f"ndim mismatch {fn} {ranges} {other.get_size()}" + + def inner_fn(index): + assert len(index) == len(ranges), f"wrong ndim {index} {ranges}" + if dtype == torch.bool and override_fn_when_input_bool is not None: + return override_fn_when_input_bool(*[load(index) for load in loaders]) + elif override_fn_when_cuda_float64 and is_cuda and dtype == torch.float64: + return override_fn_when_cuda_float64(*[load(index) for load in loaders]) + else: + return fn(*[load(index) for load in loaders]) + + if not override_device: + device = None + for i in inputs: + if i.get_device().type == "cuda": + device = i.get_device() + break + if not device: + device = inputs[0].get_device() + + device = override_device or device + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + return inner + + +def make_foreach_pointwise(pw_fn, allow_alpha=False): + def 
inner(*inputs: List[List[TensorBox]], alpha=1): + # group by device, whether any of the inputs are dynamic, and whether their types match + # (proxy for type promotion) + def group_args(arg_pairs): + out = defaultdict(list) + for i, args in enumerate(arg_pairs): + use_foreach = not is_dynamic(*args) + device = None + for t in args: + if isinstance(t, TensorBox): + device = t.data.get_device() + break + assert ( + device is not None + ), "foreach op should have at least one tensor arg" + out[(device, use_foreach)].append((i, args)) + return out + + realize_outputs = ( + len(V.graph.current_node.users) == 0 + or V.graph.current_node.target in inplace_foreach_ops + ) + for node in V.graph.current_node.users: + for user in node.users: + if not (user.op == "call_function" and (user.target in foreach_ops)): + realize_outputs = True + + a_list_input = None + for input in inputs: + if isinstance(input, (list, tuple)): + a_list_input = input + break + assert ( + a_list_input is not None + ), "at least one input must be a list to a foreach op" + + # broadcast scalar inputs to match length of list inputs + broadcast_inputs = [] + for input in inputs: + if not isinstance(input, (list, tuple)): + broadcast_inputs.append([input] * len(a_list_input)) + else: + broadcast_inputs.append(input) + + groups = group_args(zip(*broadcast_inputs)) + + outputs = [None] * len(a_list_input) + for (device, use_foreach), group in groups.items(): + buffer_list = [] + for ( + output_ind, + args, + ) in group: + if allow_alpha: + output = pw_fn(*args, alpha=alpha) + else: + output = pw_fn(*args) + + outputs[output_ind] = output + + if device.type == "cuda" and use_foreach and realize_outputs: + buffer_list.append(output.realize()) + + if buffer_list: + V.graph.register_list(buffer_list) + + assert all(x is not None for x in outputs) + return outputs + + return inner + + +def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): + src_dtype = x.get_dtype() + if src_dtype == dtype: + return clone(x) if copy else x + + def _to_dtype(x): + return ops.to_dtype(x, dtype, src_dtype=src_dtype) + + return make_pointwise(_to_dtype, override_return_dtype=dtype)(x) + + +@register_lowering(prims.convert_element_type, type_promotion_kind=None) +def _convert_element_type(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + if x.get_size(): + # Decompose since aa aten fallback is more friendly for c++ codegen. + # This decompostion doesn't work for empty tensor, which needs more investigation. + dst = empty_like(x, dtype=dtype) + ir.InplaceCopyFallback.create(dst, x) + return dst + else: + return fallback_handler( + prims.convert_element_type.default, add_to_fallback_set=False + )(x, dtype) + return to_dtype(x, dtype, copy=True) + + +def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False): + x_dtype = x.get_dtype() + if x_dtype == dtype: + return clone(x) if copy else x + + def _get_primitive_bitwidth(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits + else: + return torch.iinfo(dtype).bits + + src_bits = _get_primitive_bitwidth(x_dtype) + dst_bits = _get_primitive_bitwidth(dtype) + if src_bits != dst_bits: + raise NotImplementedError( + f"bitcast {x_dtype} to different bitwidth type {dtype} is not supported yet." + ) + + def _to_dtype_bitcast(x): + # Because we may promote tensor type from float16 or bfloat16 + # to float, we will need to pass the original src dtype (i.e. 
x_dtype), + # which is used for correctly constructing type conversion before bitcast, + # which requires the bitwidth of the input tensor type is the same as the + # target type. + return ops.to_dtype_bitcast(x, dtype, x_dtype) + + return make_pointwise(_to_dtype_bitcast, override_return_dtype=dtype)(x) + + +@register_lowering(aten.view.dtype, type_promotion_kind=None) +def _view_dtype(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + return TensorBox.create( + ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype) + ) + return to_dtype_bitcast(x, dtype, copy=True) + + +def to_device(x: TensorBox, device: torch.device, *, copy=False): + device = decode_device(device) + if x.get_device() == device: + return clone(x) if copy else x + return TensorBox.create(ir.DeviceCopy.create(x, device)) + + +@register_lowering(prims.device_put, type_promotion_kind=None) +def _device_put(x: TensorBox, device: torch.device): + return to_device(x, device, copy=True) + + +def register_pointwise( + aten_fn, + name=None, + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + override_return_dtype=None, + override_fn_when_input_bool=None, + allow_alpha=False, + use_libdevice_for_f64=False, + triton_fallback=None, +): + """A pointwise function that maps ops.{name} to inputs""" + name = name or aten_fn.__name__ + fn = ops_wrapper(name) + if use_libdevice_for_f64: + fn_libdevice = ops_wrapper("libdevice_" + name) + if override_fn_when_input_bool is not None: + override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) + + fn = make_pointwise( + fn, + override_return_dtype=override_return_dtype, + override_fn_when_input_bool=override_fn_when_input_bool, + override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined] + allow_alpha=allow_alpha, + triton_fallback=triton_fallback, + ) + fn = register_lowering( + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + convert_input_to_bool=convert_input_to_bool, + )(fn) + return fn + + +def register_frexp(): + """A pointwise function that maps ops.frexp to inputs""" + name = "frexp" + frexp = ops_wrapper("frexp") + + def frexp0(*args, **kwargs): + return frexp(*args, **kwargs)[0] + + def frexp1(*args, **kwargs): + return frexp(*args, **kwargs)[1] + + pw_fns = [ + make_pointwise(frexp0), + make_pointwise(frexp1, override_return_dtype=torch.int32), + ] + + def fn(*args, **kwargs): + return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs) + + fn = register_lowering( + aten.frexp, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + )(fn) + return fn + + +register_frexp() + + +def register_foreach_pointwise( + aten_fn, + pointwise_lowering_fn, + allow_alpha=False, +): + fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha) + fn = _register_foreach_lowering(aten_fn, fn) + return fn + + +@register_lowering(aten.where, broadcast=False, type_promotion_kind=None) +def where(cond, a, b): + def fn(*args): + return ops.where(*args) + + if isinstance(a, (float, int)): + a = constant_like(a)(b) + if isinstance(b, (float, int)): + b = constant_like(b)(a) + + args = [cond, a, b] + dtype = get_promoted_dtype( + args[1], args[2], 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + return make_pointwise(fn, override_return_dtype=dtype)( + args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) + ) + + +@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) +def broadcast_tensors(*inputs): + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + return broadcast_tensors(*inputs[0]) + target: List[sympy.Expr] = functools.reduce( + broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] + ) + outputs = [] + for x in inputs: + sizes = x.get_size() + if len(sizes) != len(target) or any( + ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target) + ): + x = expand(x, target) + outputs.append(x) + return outputs + + +@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of]) +def nop(x): + return x # AOT autograd handles this for us + + +if hasattr(aten, "lift_fresh"): + register_lowering(aten.lift_fresh)(nop) + + +@register_lowering(aten.squeeze, type_promotion_kind=None) +def squeeze(x, dim=None): + assert isinstance(x, TensorBox) + if dim is None: + return TensorBox(SqueezeView.create(x.data)) + + dim = canonicalize_dims(len(x.get_size()), dim) + dims = set((dim,) if not isinstance(dim, tuple) else dim) + + new_shape = [] + for d, s in enumerate(x.get_size()): + if not (d in dims and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1))): + new_shape.append(s) + + # squeeze does nothing if the size isn't 1 + return view(x, new_shape) if new_shape != x.get_size() else x + + +@register_lowering(aten.squeeze_copy, type_promotion_kind=None) +def squeeze_copy(x, dim=None): + return clone(squeeze(x, dim)) + + +@register_lowering([aten.squeeze_]) +def squeeze_(x, dim=None): + val = squeeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +@register_lowering(aten.isinf) +def isinf(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isinf") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.isnan) +def isnan(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isnan") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.ceil) +def ceil(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("ceil") + return make_pointwise(fn)(x) + + +@register_lowering(aten.floor) +def floor(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("floor") + return make_pointwise(fn)(x) + + +@register_lowering(aten.round.default) +def round(x): + if is_integer_type(x): + return clone(x) + else: + fn = ops_wrapper("round") + return make_pointwise(fn)(x) + + +@register_lowering(aten.trunc) +def trunc(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("trunc") + return make_pointwise(fn)(x) + + +@register_lowering(aten.expand, type_promotion_kind=None) +def expand(x, sizes): + (x,) = promote_constants([x]) + if isinstance(x, ir.BaseConstant): + return ExpandView.create(x, tuple(sizes)) + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + if tuple(x.get_size()) 
== tuple(sizes): + return x + + if not any(V.graph.sizevars.shape_env.is_unbacked_symint(s) for s in x.get_size()): + x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size())) + # TODO: It would be better to realize the input if any of its sizes + # are unbacked, because typically the size will be non-zero. However, + # this cannot be done directly as below as we'll choke on the size_hint + # here + if x_size_product > 0 and not any( + V.graph.sizevars.shape_env.is_unbacked_symint(s) for s in sizes + ): + # maybe realize input before broadcasting it + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product + ) + return TensorBox(ExpandView.create(x.data, tuple(sizes))) + + +@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None) +def broadcast_in_dim(a, shape, broadcast_dimensions): + s = list(shape) + for broadcast_dimension in broadcast_dimensions: + s[broadcast_dimension] = -1 + + v = a + for idx, x in enumerate(s): + if x != -1: + v = unsqueeze(v, idx) + + return expand(v, shape) + + +@register_lowering(aten.expand_as, type_promotion_kind=None) +def expand_as(x, y): + return expand(x, y.get_size()) + + +@register_lowering(aten.repeat) +def repeat(x, repeats): + old_size = list(x.get_size()) + if len(repeats) > len(old_size): + old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size + x = view(x, list(old_size)) + assert len(repeats) == len(x.get_size()) + + new_size = list(x.get_size()) + + zero_tensor = False + for i in range(len(repeats)): + if repeats[i] == 0: + zero_tensor = True + new_size[i] = new_size[i] * repeats[i] + + if zero_tensor: + return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) + if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): + return expand(x, new_size) + + x_loader: Callable[[Any], Any] + + def inner_fn(index): + assert len(index) == len(repeats) + index = list(index) + for i in range(len(repeats)): + if repeats[i] != 1: + if old_size[i] == 1: + index[i] = sympy.Integer(0) + else: + index[i] = ModularIndexing(index[i], 1, old_size[i]) + return x_loader(index) + + old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) + if old_size_product > 0: + # maybe realize the input + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product + ) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(new_size), + ) + + +@register_lowering(aten._unsafe_view, type_promotion_kind=None) +@register_lowering(aten.view, type_promotion_kind=None) +@register_lowering(aten.reshape, type_promotion_kind=None) +def view(x, sizes): + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + return TensorBox(View.create(x.data, sizes)) + + +@register_lowering(aten.permute, type_promotion_kind=None) +def permute(x, dims): + assert isinstance(x, TensorBox) + assert isinstance(dims, (list, tuple)) + return TensorBox(PermuteView.create(x.data, tuple(dims))) + + +@register_lowering(aten.slice, type_promotion_kind=None) +def slice_(x, dim=0, start=0, end=2**63, step=1): + assert isinstance(x, TensorBox) + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + return TensorBox(ir.SliceView.create(x.data, dim, start, end, step)) + + +@register_lowering(aten.as_strided, type_promotion_kind=None) +def as_strided(x, size, stride, storage_offset=None): + if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView): + # as_strided ignores views 
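# [Editor's note: illustrative only, not part of the upstream source.]
# as_strided is defined against the underlying storage (plus storage_offset),
# not against the current view, which is why the view wrapper is unwrapped
# before the layout is reinterpreted below. A rough eager-mode analogue:
#   >>> base = torch.arange(6)
#   >>> v = base.view(2, 3)
#   >>> v.as_strided((3, 2), (1, 3))  # strides are taken w.r.t. base's storage
#   -> [[0, 3], [1, 4], [2, 5]]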
+ x = x.data.unwrap_view() + x.realize() + if not ir.is_storage_and_layout(x): + raise NotImplementedError(f"unrealized as_strided({x}, ...)") + storage, old_layout = ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + [sympy.expand(s) for s in size], + [sympy.expand(s) for s in stride], + sympy.expand(storage_offset or 0), + ) + return TensorBox(ir.ReinterpretView(storage, new_layout)) + + +@register_lowering(aten.as_strided_, type_promotion_kind=None) +def as_strided_(x, size, stride, storage_offset=None): + assert isinstance(x, TensorBox) + x.data = as_strided(x, size, stride, storage_offset).data + return x + + +@register_lowering(aten.as_strided_copy, type_promotion_kind=None) +def as_strided_copy(x, size, stride, storage_offset=None): + result = as_strided(x, size, stride, storage_offset) + return clone(result) + + +def pointwise_cat(inputs, dim=0): + # (inclusive, exclusive) + inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = [] + prev_end = 0 + for inp in inputs: + inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] + prev_end = inputs_ranges[-1][-1] # type: ignore[assignment] + + inputs_loaders = [inp.make_loader() for inp in inputs] + + def inner_fn(idx): + idx_dim = ops.index_expr(idx[dim], torch.int64) + + masks = [] + masked_loads = [] + for i in range(len(inputs)): + start = ( + ops.constant(0, torch.int64) + if i == 0 + else ops.index_expr(inputs_ranges[i][0], torch.int64) + ) + end = ops.index_expr(inputs_ranges[i][1], torch.int64) + + start_cond = ops.ge(idx_dim, start) + end_cond = ops.lt(idx_dim, end) + if i == 0: + mask = end_cond + elif i == len(inputs) - 1: + mask = start_cond + else: + mask = ops.and_(start_cond, end_cond) + + masks.append(mask) + idx_load = list(idx) + + # if we're concatting [4], [2] + # when we index the second tensor for 5 we want to index 5 - 4 + idx_load[dim] -= inputs_ranges[i][0] + + masked_loads.append( + ops.masked( + mask, + lambda: inputs_loaders[i](idx_load), + 0.0, # this value should be unused + ), + ) + + next_val = masked_loads[-1] + for i in range((len(inputs)) - 2, -1, -1): + next_val = ops.where( + masks[i], + masked_loads[i], + next_val, + ) + return next_val + + new_size = list(inputs[0].get_size()) + new_size[dim] = inputs_ranges[-1][-1] + + return Pointwise.create( + device=inputs[0].get_device(), + dtype=inputs[0].get_dtype(), + inner_fn=inner_fn, + ranges=new_size, + ) + + +@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None) +def quantized_decomposed_quantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + qmin, qmax = _create_constants(quant_min, 
quant_max, dtype=torch.float32) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.int32: + zero_point = ops.to_dtype(zero_point, torch.int32) + inv_scale = ops.reciprocal(scale) + val = ops.round(input * inv_scale) + zero_point + clamped = ops.maximum(qmin, ops.minimum(qmax, val)) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_channel, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.float32: + zero_point = ops.to_dtype(zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering(aten.cat) +def cat(inputs, dim=0): + if all(input.get_dtype() in [torch.int8, torch.uint8] for input in inputs): + # TODO Remove this fallback when we support vectorization + # code gen with uint8 data type directly. 
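# [Editor's note: illustrative only, not part of the upstream source.]
# This branch fires only when every input is int8/uint8, e.g. when lowering
# something like torch.cat([a.to(torch.uint8), b.to(torch.uint8)], dim=1):
# each input is realized into a buffer and the concat is delegated to the
# eager aten.cat kernel, in channels-last layout when all inputs are 4-D.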
+ for input in inputs: + input.realize() + if all(len(input.get_size()) == 4 for input in inputs): + inputs, _ = require_channels_last(aten.cat, *inputs) + return fallback_handler(aten.cat.default)(inputs, dim) + + if len(inputs) == 1: + return clone(inputs[0]) + + dim = _validate_dim(inputs[0], dim, 0) + dtype = get_promoted_dtype( + *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + inputs = [to_dtype(inp, dtype) for inp in inputs] + + def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode: + if isinstance(x, TensorBox): + if isinstance(x.data, ir.BaseView): + return x.data.unwrap_view() + else: + return x.data + + if isinstance(x, ir.StorageBox): + return x.data + + return x + + def should_lower_cat_input(x) -> bool: + # Unrealized inputs will not be storage and layouts, and we don't want to realize + # them in case we want to fuse + if ir.is_storage_and_layout(x): + storage, _ = ir.as_storage_and_layout(x, freeze=False) + return not ir.ConcatKernel.can_realize_into_without_copy(storage) + + if isinstance(x, (TensorBox, ir.StorageBox)): + return should_lower_cat_input(unwrap_tensor(x)) + + if isinstance(x, ir.Pointwise): + return True + + return False + + def is_reduction(t): + return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction) + + def can_fuse_reduction(t): + if isinstance(t, (TensorBox, ir.StorageBox)): + return can_fuse_reduction(unwrap_tensor(t)) + return ( + is_reduction(t) + or isinstance(t, ir.Pointwise) + and any( + can_fuse_reduction(V.graph.get_buffer(read)) + for read in t.get_read_names() + ) + ) + + # fusing reductions into a computed concat buffer can cause regressions. + fusable_reduction = any(can_fuse_reduction(t) for t in inputs) + + # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it. + # We will revisit this later after enabling vectorization on index_expr.
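# [Editor's note: illustrative summary, not part of the upstream source.]
# Rough decision sketch for the code below: CPU inputs, and inputs that could
# fuse a reduction, always take the ir.ConcatKernel path. Otherwise, when the
# number of inputs is small enough (or each input's pointwise body is simple)
# and the inputs/uses are pointwise, the concat is emitted as a single masked
# Pointwise kernel via pointwise_cat. Hypothetical examples:
#   cat of two unrealized pointwise adds on CUDA  -> pointwise_cat
#   cat involving a realized matmul output on CPU -> ir.ConcatKernel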
+ if inputs[0].get_device().type == "cpu" or fusable_reduction: + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + def op_count(x): + if isinstance(x, (TensorBox, ir.StorageBox)): + return op_count(unwrap_tensor(x)) + + # this will correspond to a direct memory read + if not isinstance(x, ir.Pointwise): + return 0 + + count = x.inner_fn_opcount() + for read in x.get_read_names(): + count += op_count(V.graph.get_buffer(read)) + + return count + + # as of inputs increase, possibility for register spilling also increases + # past a certain threshold of inputs we only fuse if the if the input kernels + # are simple + # not sure if we want to expose to users via config since logic may change in future + MAX_COMPLEX_POINTWISE_CAT = 8 + MAX_SIMPLE_OP_COUNT = 2 + + if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or ( + (len(inputs) <= config.max_pointwise_cat_inputs) + and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs) + ): + pointwise_uses = all(is_pointwise_use(use) for use in V.current_node.users) + all_pointwise_inputs = all(should_lower_cat_input(inp) for inp in inputs) + any_pointwise_inputs = any(should_lower_cat_input(inp) for inp in inputs) + + if all_pointwise_inputs or (any_pointwise_inputs and pointwise_uses): + return pointwise_cat(inputs, dim) + + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + +@register_lowering(aten.diagonal, type_promotion_kind=None) +def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + original_shape = input.get_size() + num_dims = len(original_shape) + dim1 = canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = canonicalize_dim(idx=dim2, rank=num_dims) + + check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0)) + if offset_negative: + diag_size = max(min(original_shape[dim1] + offset, original_shape[dim2]), 0) + else: + diag_size = max(min(original_shape[dim1], original_shape[dim2] - offset), 0) + + base_idx = (0, 0) + if offset_negative: + base_idx = (-offset, 0) + else: + base_idx = (0, offset) + + sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)] + sizes.append(diag_size) + + def reindexer(idx): + diag_idx = idx[-1] + original_idx = [0] * len(original_shape) + cur_dim = 0 + for d in range(num_dims): + if d == dim1: + original_idx[d] = diag_idx + base_idx[0] + elif d == dim2: + original_idx[d] = diag_idx + base_idx[1] + else: + original_idx[d] = idx[cur_dim] + cur_dim += 1 + + assert cur_dim == len(original_shape) - 2 + return original_idx + + return TensorBox(ir.GenericView.create(input, sizes, reindexer)) + + +@register_lowering(aten.diagonal_copy, type_promotion_kind=None) +def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + return clone(diagonal(input, offset, dim1, dim2)) + + +@register_lowering(aten.diagonal_scatter, type_promotion_kind=None) +def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): + output = clone(input) + target = diagonal(output, offset, dim1, dim2) + mutate_to(target, src) + return output + + +@register_lowering(aten.select, type_promotion_kind=None) +def select(x, dim, idx): + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) + + +@register_lowering(aten.split, type_promotion_kind=None) +def split(x, sizes, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + if isinstance(sizes, sympy.Expr): + # 
TODO: We don't have to guard on sizes per se, but the number + # of splits must stay constant + sizes = V.graph.sizevars.evaluate_static_shape(sizes) + if isinstance(sizes, (int, sympy.Integer)): + sizes = [sizes] * ((x_size + sizes - 1) // sizes) + result = [] + start = 0 + for size in sizes: + end = start + size + result.append(slice_(x, dim, start, end)) + start = end + return result + + +@register_lowering(aten.split_with_sizes, type_promotion_kind=None) +def split_with_sizes(x, sizes, dim=0): + return split(x, sizes, dim) + + +@register_lowering(aten.unbind, type_promotion_kind=None) +def unbind(x, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + result = [] + for i in range(x_size): + result.append(select(x, dim, i)) + return result + + +@register_lowering(aten.unfold, type_promotion_kind=None) +def unfold(x, dimension, size, step): + sizes = x.get_size() + ndim = len(sizes) + dim = canonicalize_dim(ndim, dimension) + + if ndim == 0: + return slice_(unsqueeze(x, 0), end=size) + + dim_size = sizes[dim] + sizevars = V.graph.sizevars + sizevars.guard_leq(size, dim_size) + sizevars.guard_lt(0, step) # type: ignore[arg-type] + + new_dim_size = FloorDiv(dim_size - size, step) + 1 + if sizevars.size_hint(dim_size) > 0: + x.mark_reuse(sizevars.size_hint(CeilDiv(new_dim_size * size, dim_size))) + + out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size] + + def reindexer(idx): + dim_idx = idx[-1] + idx[dim] * step + return (*idx[:dim], dim_idx, *idx[dim + 1 : -1]) + + return TensorBox(ir.GenericView.create(x, out_size, reindexer)) + + +@register_lowering(aten.unsqueeze, type_promotion_kind=None) +def unsqueeze(x, dim): + dim = _validate_dim(x, dim, 1) + new_shape = list(x.get_size()) + new_shape.insert(dim, sympy.Integer(1)) + return view(x, new_shape) + + +@register_lowering(aten.unsqueeze_, type_promotion_kind=None) +def unsqueeze_(x, dim): + val = unsqueeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +def _validate_dim(x, dim, offset=0): + assert isinstance(dim, int) + ndim = len(x.get_size()) + if dim < 0: + dim += ndim + offset + assert 0 <= dim < ndim + offset + return dim + + +@register_lowering(aten.glu) +def glu(x, dim=-1): + dim = _validate_dim(x, dim, 0) + # TODO: don't guard on static shape here + new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2 + a = slice_(x, dim, 0, new_len) + b = slice_(x, dim, new_len, new_len * 2) + return mul(a, sigmoid(b)) + + +def register_onednn_fusion_ops(): + if torch._C._has_mkldnn: + cpu_needs_realized_inputs = [ + torch.ops.mkldnn._convolution_pointwise, + torch.ops.mkldnn._convolution_pointwise_, + torch.ops.mkldnn._convolution_transpose_pointwise, + torch.ops.mkldnn._linear_pointwise, + aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qconv2d_pointwise, + ] + + @register_lowering(torch.ops.mkldnn._convolution_pointwise) + def convolution_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + ir.ConvolutionUnary.create( + x, + weight, + bias, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary) + def convolution_binary( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + 
unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + ir.ConvolutionBinary.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary) + def convolution_binary_inplace( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + ir.ConvolutionBinaryInplace.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise) + def linear_unary( + x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm + ): + return TensorBox.create( + ir.LinearUnary.create(x, w, b, attr, scalars, algorithm) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) + def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr): + return TensorBox.create(ir.LinearBinary.create(x, y, w, b, attr)) + + @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise) + def convolution_transpose_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + ir.ConvolutionTransposeUnary.create( + x, + weight, + bias, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(aten.mkldnn_rnn_layer.default) + def mkldnn_rnn_layer( + x: TensorBox, + w0: TensorBox, + w1: TensorBox, + w2: TensorBox, + w3: TensorBox, + hx: TensorBox, + cx: TensorBox, + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + return pytree.tree_map( + TensorBox.create, + ir.MkldnnRnnLayer.create( + x, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ), + ) + + @register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None) + def qconvolution_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + ir.QConvPointWisePT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering( + torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None + ) + def qconvolution_binary( + x: TensorBox, + x_scale, + x_zp, + accum: TensorBox, + accum_scale, + accum_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ): + if ( + binary_attr == "sum" + and output_dtype in [torch.float32, torch.bfloat16] + and accum.get_dtype() in [torch.float32, 
torch.bfloat16] + and accum.get_dtype() != output_dtype + ): + # For int8-mixed-bf16 quantization and inplace add, + # there is a case where the accum dtype is float32 but the output dtype is bfloat16. + # Since the accum will be changed in place by the post-op sum, + # we do the accum dtype conversion here. + accum = to_dtype(accum, output_dtype) + return TensorBox.create( + ir.QConvPointWiseBinaryPT2E.create( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ) + ) + + @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None) + def qlinear_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + ir.QLinearPointwisePT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ) + ) + + if torch._C.has_mkl: + cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) + + @register_lowering(torch.ops.mkl._mkl_linear) + def mkl_packed_linear( + x: TensorBox, + packed_w: TensorBox, + orig_w: TensorBox, + b: TensorBox, + batch_size, + ): + result = TensorBox.create( + ir.MKLPackedLinear.create(x, packed_w, orig_w, batch_size) + ) + if b is not None: + result = add(result, b) + return result + + add_needs_realized_inputs(cpu_needs_realized_inputs) + else: + pass + + +register_onednn_fusion_ops() + + +def fallback_handler(kernel, add_to_fallback_set=True): + if add_to_fallback_set: + fallbacks.add(kernel) + + def handler(*args, **kwargs): + return pytree.tree_map( + TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs) + ) + + return handler + + +@functools.lru_cache(None) +def _warn_complex_not_supported(): + warnings.warn( + "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." + ) + + +# There are some types (CPU) which we accept as input but not as +# output. +def unsupported_input_tensor(t: torch._subclasses.FakeTensor, parent=None): + "Do not support reading or writing to this tensor" + if t.is_complex(): + # Complex views are supported with IR ComplexView + if parent and parent.target in ( + torch.ops.aten.view.dtype, + torch.ops.prims.convert_element_type.default, + ): + return False + _warn_complex_not_supported() + return True + return False + + +def unsupported_output_tensor(t: torch._subclasses.FakeTensor, parent=None): + "Do not support writing to this tensor but can read from it" + if unsupported_input_tensor(t, parent): + return True + return t.is_cpu and config.disable_cpp_codegen + + +def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True): + # Custom fallback lowering + if node.target is aten.view_as_complex.default: + return False + + # We should be able to remove this special case once `disable_cpp_codegen` is killed.
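# [Editor's note: illustrative only, not part of the upstream source.]
# The checks below walk the FakeTensor metadata of the node's arguments and of
# the node itself: the node is routed to the eager fallback when any input is
# complex (unless it is only consumed by aten.view.dtype or
# prims.convert_element_type), or when any output is complex or is a CPU
# tensor while config.disable_cpp_codegen is set.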
+ if node.target is aten.lift_fresh_copy.default: + return False + + def check_skip_condition(node, parent, is_output): + if not isinstance(node, torch.fx.Node): + return False + + if "val" not in node.meta: + return False + + for meta in pytree.tree_leaves(node.meta["val"]): + if not isinstance(meta, torch._subclasses.FakeTensor): + continue + + if is_output: + if unsupported_output_tensor(meta, parent): + return True + else: + if unsupported_input_tensor(meta, parent): + return True + + return False + + # only skip codegen if there is a cpu output, not input + for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs): + if check_skip_condition(arg, node, is_output=False): + return True + + return check_skip_condition(node, node, is_output=True) + + +def make_fallback(op, layout_constraint=None, warn=True): + assert op not in decompositions, f"both a fallback and a decomp for same op: {op}" + if ( + warn + and bool(os.getenv("CI")) + and get_decompositions([op]) + # if fallback_random, we allow not decomposing random + and not ( + config.fallback_random + and op in torch._decomp.decompositions_for_rng.extra_random_decomps + ) + ): + # Note: 'warn' is holdover from when this was a warning, but for ops that previously + # set warn=False we do not want a CI error. + # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not + # likely to be triggered preferentially on one CI config over another. + if torch._dynamo.config.suppress_errors: + torch._dynamo.config.suppress_errors = False + log.warning( + "A make_fallback error occurred in suppress_errors config," + " and suppress_errors is being disabled to surface it." + ) + raise AssertionError( + f"make_fallback({op}): a decomposition exists, we should switch to it." + " To fix this error, either add a decomposition to core_aten_decompositions (preferred)" + " or inductor_decompositions, and delete the corresponding `make_fallback` line." + " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.", + ) + + def register_fallback(op_overload): + add_needs_realized_inputs(op_overload) + if layout_constraint is not None: + add_layout_constraint(op_overload, layout_constraint) + return register_lowering(op_overload, type_promotion_kind=None)( + fallback_handler(op_overload) + ) + + if isinstance(op, torch._ops.OpOverloadPacket): + for ol in op.overloads(): + op_overload = getattr(op, ol) + register_fallback(op_overload) + elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + register_fallback(op) + else: + raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}") + + +def philox_rand_offset(shape): + """ + TorchInductor offset calculation differs from PyTorch eager offset + calculation for random ops (tl.rand vs torch.rand). In future, we should + strive for same impl for tl.rand and torch.rand. + """ + numel = 1 + for s in shape: + numel = numel * s + return tensor(numel, dtype=torch.int64) + + +@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None) +def philox_rand(size, seed, offset, stride, device, dtype): + # stride arg is optional and will be used in future for distributed random + # ops. Currently, its unused. + random_pos = ir.FixedLayout( + device, + dtype, + size, + ir.FlexibleLayout.contiguous_strides(size), + ).make_indexer() + seed_loader = seed.make_loader() + offset_loader = offset.make_loader() + + def inner_fn(index): + # Both seed and offset in the philox_rand op are tensors. 
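# [Editor's note: illustrative only, not part of the upstream source.]
# Because they are scalar (0-d) tensors rather than Python ints, they are read
# through loaders with an empty index, i.e. seed_loader([]) / offset_loader([])
# below, instead of being baked into the kernel as constants.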
+ # torch seed and offsets are of type int64, but tl.rand accepts int32 + seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32) + offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32) + # Get the offset'd position + rand_index_expr = ops.add( + ops.index_expr(random_pos(index), torch.int32), offset_index_expr + ) + result = ops.rand( + seed_index_expr, + rand_index_expr, + ) + return ops.to_dtype(result, dtype) + + random_values_node = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + offset_node = philox_rand_offset(size) + return random_values_node, offset_node + + +@register_lowering(aten.native_dropout, type_promotion_kind=None) +def native_dropout(x, p, train): + if config.fallback_random: + return pytree.tree_map( + TensorBox.create, + ir.FallbackKernel.create(aten.native_dropout.default, x, p, train), + ) + else: + raise AssertionError("should be handled in replace_random.py") + + +@register_lowering(aten.bernoulli_, type_promotion_kind=None) +def bernoulli_(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + x.realize() + ir.InplaceBernoulliFallback(x, *args) + return x + + +@register_lowering(aten.bernoulli.p, type_promotion_kind=None) +def bernoulli_p(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + return bernoulli_(clone(x), *args) + + +# This shouldn't be called in general +@register_lowering(aten._foobar) +def _foobar(_): + raise AssertionError() + + +@functools.lru_cache(1) +def _warn_triton_random(salt): + log.info("using triton random, expect difference from eager") + + +def warn_triton_random(): + # only warn once per graph + _warn_triton_random(V.graph.creation_time) + + +fallback_rand_default = fallback_handler(aten.rand.default) +fallback_rand_generator = fallback_handler(aten.rand.generator) +fallback_randn_default = fallback_handler(aten.randn.default) +fallback_randn_generator = fallback_handler(aten.randn.generator) +make_fallback(aten.randint) + + +@register_lowering(aten.rand) +def rand(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_rand_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_rand_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(aten.randn) +def randn(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_randn_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_randn_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None) +def inductor_force_stride_order(input_tensor, stride): + stride_order = ir.get_stride_order(stride) + return ir.ExternKernel.require_stride_order(input_tensor, stride_order) + + +@register_lowering(inductor_prims.seed, type_promotion_kind=None) +def inductor_seed(device: torch.device): + raise AssertionError("should be handled in fuse_seed_creation_pass()") + + +@register_lowering(inductor_prims.seeds, type_promotion_kind=None) +def inductor_seeds(count, device): + warn_triton_random() + return TensorBox.create(ir.RandomSeeds(count, 
decode_device(device))) + + +@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None) +def inductor_lookup_seed(seeds, index): + def inner_fn(_): + return ops.load_seed(seeds.get_name(), index) + + return Pointwise.create( + device=seeds.get_device(), + dtype=seeds.get_dtype(), + inner_fn=inner_fn, + ranges=[], + ) + + +@register_lowering(inductor_prims.random, type_promotion_kind=None) +def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0): + assert not config.fallback_random + assert mode in ("rand", "randn") + size = [*size] + dtype = torch.float32 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return getattr(ops, mode)( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ) + + result = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + result.realize() + return result + + +@register_lowering(inductor_prims.randint, type_promotion_kind=None) +def inductor_randint( + low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0 +): + assert not config.fallback_random + size = [*size] + dtype = torch.int64 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return ops.randint64( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + low, + high, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + + +@register_lowering(aten.bucketize, type_promotion_kind=None) +def bucketize( + input: TensorBox, + boundaries: TensorBox, + *, + out_int32: bool = False, + right: bool = False, +): + assert len(boundaries.get_size()) == 1 + + if not (is_triton(input) and is_triton(boundaries)): + return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)( + input, boundaries, out_int32=out_int32, right=right + ) + + # The entire boundaries tensor needs to be used by ops.bucketize, so we + # need to realize it into global memory; or in other words, we can't + # guarantee that boundaries.get_name() (used below) will exist unless + # we call boundaries.realize(). 
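# [Editor's note: illustrative only, not part of the upstream source.]
# Eager-mode semantics of the op being lowered here:
#   >>> torch.bucketize(torch.tensor([1., 5., 9.]), torch.tensor([2., 4., 8.]))
#   tensor([0, 2, 3])
# i.e. with right=False each output is the index of the first boundary that is
# >= the corresponding input value (len(boundaries) when no boundary matches).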
+ boundaries.realize() + boundaries_size = boundaries.get_size()[0] + boundaries_loader = boundaries.make_loader() + device = input.get_device() + input_loader = input.make_loader() + + index_dtype = torch.int32 if out_int32 else torch.int64 + + def inner_fn(index): + val = input_loader(index) + indices = ops.bucketize( + val, + boundaries.get_name(), + boundaries_size, + index_dtype, + right, + ) + + return indices + + return Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +def require_dense(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs) + ) + return args, kwargs + + +def require_contiguous(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs) + ) + return args, kwargs + + +def require_channels_last(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs) + ) + return args, kwargs + + +def constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, ir.IRNode): + stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) + return ir.ExternKernel.require_stride_order(arg, stride_order) + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# TODO(jansel): we should implement decomps or lowerings for these +# https://github.com/pytorch/torchdynamo/issues/327 +FALLBACK_ALLOW_LIST = { + "torchvision::roi_align", +} + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension] + + def apply_constraint(arg, fx_arg): + if not isinstance(arg, ir.IRNode): + return arg + + meta_val = fx_arg.meta["val"] + if not meta_val.is_cuda: + return arg + + stride_order = ir.get_stride_order(meta_val.stride()) + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + # This is the minimum alignment required by SDPA kernels for attention_bias. 
+ # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask + ALIGNMENT = 8 + + assert isinstance(arg, TensorBox) + if len(arg.get_size()) not in (3, 4): + return arg + + def is_aligned_realized_tensor(x): + aligned_strides = all( + (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0 + for i in range(len(x.get_stride()) - 1) + ) + return ( + V.graph.sizevars.size_hint(x.get_stride()[-1]) + ) == 1 and aligned_strides + + try: + arg.get_stride() + if is_aligned_realized_tensor(arg): + return arg + except AttributeError: + pass + + def is_aligned(x): + return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 + + if isinstance(arg.data, ir.BaseView): + if not is_aligned(arg): + if is_aligned(arg.unwrap_view()): + return arg + + return ir.ExternKernel.require_stride_order(arg, stride_order) + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# WIP +make_fallback(aten.index_reduce) # @pearu +make_fallback(aten._adaptive_avg_pool3d) # @isuruf +make_fallback(aten.adaptive_max_pool3d) # @isuruf +make_fallback(aten.avg_pool3d) # @isuruf +make_fallback(aten.fractional_max_pool3d) # @isuruf +make_fallback(aten.max_pool3d_with_indices) # @isuruf (can this one be implemented?) +make_fallback(aten.cummax) # @isuruf +make_fallback(aten.cummin) # @isuruf + + +# 1) Easy +make_fallback(aten.uniform, warn=False) +make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py) +make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks +make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl? +make_fallback(aten.searchsorted) # bucketized is implemented (see eager impl) + + +# 1.5) Easy or Impossible +make_fallback(aten._cdist_forward) # p=2 should be feasible +make_fallback(aten._cdist_backward) +# See resize_storage_bytes +make_fallback(aten.resize) +make_fallback(aten.resize_) +make_fallback(aten.resize_as) +make_fallback(aten.resize_as_) + + +# 2) Medium +make_fallback(aten.max_unpool2d) +make_fallback(aten.max_unpool3d) +make_fallback(aten._trilinear) + + +# 3) Difficult +# Scans +# See the discussion at +# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19 +make_fallback(aten.segment_reduce.default) +make_fallback(aten._segment_reduce_backward.default) + +# Histogram (need to implement Histogram IR) +make_fallback(aten.histc) +make_fallback(aten.histogram.bin_ct) +make_fallback(aten._histogramdd_bin_edges.default) +make_fallback(aten._histogramdd_from_bin_cts.default) + +# Need templated kernel +make_fallback(aten.addbmm) +make_fallback(aten.addmv, warn=False) +make_fallback(aten._addmm_activation, warn=False) + +# Need templated kernel. 
Probably impossible to write efficiently +make_fallback(aten.convolution_backward, constrain_to_fx_strides) +make_fallback(aten._cudnn_rnn, require_dense) +make_fallback(aten._cudnn_rnn_backward, require_contiguous) + +# Haven't checked but sound difficult / impossible +make_fallback(aten._embedding_bag, require_contiguous) +make_fallback(aten._embedding_bag_forward_only, require_contiguous) +make_fallback(aten._embedding_bag_dense_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._fused_moving_avg_obs_fq_helper) +make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) + + +# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp +make_fallback(aten.avg_pool3d_backward) +make_fallback(aten.max_pool3d_with_indices_backward) +make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) +make_fallback(aten._adaptive_avg_pool3d_backward) +make_fallback(aten.adaptive_max_pool2d_backward) +make_fallback(aten.adaptive_max_pool3d_backward) +make_fallback(aten.fractional_max_pool2d_backward) +make_fallback(aten.fractional_max_pool3d_backward) +make_fallback(aten.replication_pad1d_backward) +make_fallback(aten.replication_pad2d_backward) +make_fallback(aten.upsample_linear1d_backward) +make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) +make_fallback(aten.upsample_trilinear3d_backward) +make_fallback(aten.grid_sampler_2d_backward, require_dense) +make_fallback(aten._pdist_backward) + + +# 5) Impossible (missing triton/CPU features) + +# Sorting / Sorting-like +make_fallback(aten.sort) +make_fallback(aten.sort.stable) +make_fallback(aten.kthvalue) +make_fallback(aten.topk) +make_fallback(aten.mode) +make_fallback(aten.median) +make_fallback(aten.nanmedian) +make_fallback(aten.randperm) + +# Linalg +make_fallback(aten._linalg_det) +make_fallback(aten.linalg_householder_product) +make_fallback(aten.linalg_inv_ex) +make_fallback(aten.linalg_ldl_factor_ex) +make_fallback(aten.linalg_ldl_solve) +make_fallback(aten.linalg_lu) +make_fallback(aten.linalg_lu_factor_ex) +make_fallback(aten.linalg_lu_solve) +make_fallback(aten.linalg_matrix_exp) +make_fallback(aten.linalg_qr) +make_fallback(aten._linalg_slogdet) +make_fallback(aten._linalg_solve_ex) +make_fallback(aten.linalg_solve_triangular) +make_fallback(aten._linalg_svd) +make_fallback(aten.lu_unpack) +make_fallback(aten.ormqr) +make_fallback(aten._linalg_check_errors) +make_fallback(aten.linalg_pinv.atol_rtol_tensor) +make_fallback(aten._linalg_eigh) +make_fallback(aten.triangular_solve) +make_fallback(aten.linalg_cholesky_ex) +make_fallback(aten.cholesky_inverse) +make_fallback(aten.cholesky_solve) +make_fallback(aten.geqrf) +make_fallback(aten._fft_r2c) # needs complex as well + +# Data dependent (are these necessary?) +make_fallback(aten.nonzero.default) + +# Misc +make_fallback(aten.gcd.default, warn=False) +make_fallback(aten._thnn_fused_lstm_cell, require_dense) +make_fallback(torch._prims.rng_prims.run_and_save_rng_state) +make_fallback(torch._prims.rng_prims.run_with_rng_state) + +# Implmented / Half implemented +# Scans. 
Implemented for CUDA, missing CPU +make_fallback(aten.masked_scatter) +make_fallback(aten.masked_scatter_backward) + +# Complex number support +make_fallback(aten.view_as_complex, require_contiguous) +make_fallback(aten.angle) # needs complex + +# Needs efficentzerotensor +make_fallback(aten._efficientzerotensor) + +# Needs Sparse +make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) +make_fallback(aten.to_sparse) +make_fallback(aten._to_sparse) + +# Needs dimname support +make_fallback(aten.zeros.names) + + +# 6) Pattern-matched +make_fallback( + aten._scaled_dot_product_efficient_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_efficient_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback(aten._flash_attention_forward.default, sdpa_constraint) +make_fallback(aten._flash_attention_backward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_forward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_backward.default, sdpa_constraint) +make_fallback(aten._scaled_mm.default, constrain_to_fx_strides) + + +# Register with type_promotion_kind None. +# For example, fp16.copy_(fp32) should **not** promote the first input's dtype. +@register_lowering(aten.copy, type_promotion_kind=None) +def copy(self, src, non_blocking=False): + x = src + if self.get_device() != src.get_device(): + x = to_device(x, self.get_device()) + if self.get_dtype() != src.get_dtype(): + x = to_dtype(x, self.get_dtype()) + + if self.get_size() != src.get_size(): + out = expand(x, self.get_size()) + return clone(out) + return clone(x) + + +@register_lowering(aten.clone) +def clone(x, *, memory_format=None): + # TODO(jansel): memory format + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=list(x.get_size()), + ) + + +def clone_preserve_reinterpret_view(x): + reinterpret_view_layouts = [] + if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView): + x = x.data # unwrap TensorBox + while isinstance(x, ir.ReinterpretView): + reinterpret_view_layouts.append(x.get_layout()) + x = x.data + x = TensorBox(x) + + x = clone(x) + + if reinterpret_view_layouts: + x = x.data # unwrap TensorBox + for layout in reinterpret_view_layouts[::-1]: + x = ir.ReinterpretView(x, layout) + x = TensorBox(x) + + return x + + +if hasattr(aten, "lift_fresh_copy"): + register_lowering(aten.lift_fresh_copy)(clone) + + +@register_lowering(prims.iota) +def iota( + length, + *, + start, + step, + dtype, + device, + requires_grad, +): + def fn(index): + return ops.index_expr(step * index[0] + start, dtype=dtype) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=fn, + ranges=[length], + ) + + +@register_lowering(aten.select_scatter, type_promotion_kind=None) +def select_scatter(x, src, dim: int, index: int): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): + index = index + 
x.get_size()[dim] + V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + src = expand(unsqueeze(src, dim), x.get_size()) + src_loader = src.make_loader() + + def inner_fn(idx): + return ops.where( + ops.eq( + ops.index_expr(idx[dim], torch.int32), + ops.index_expr(index, torch.int32), + ), + src_loader(idx), + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +@register_lowering(aten.slice_scatter, type_promotion_kind=None) +def slice_scatter(x, src, dim=0, start=None, end=None, step=1): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + + start, end = ir.SliceView.normalize_start_end(x, dim, start, end) + + src_size = list(x.get_size()) + src_size[dim] = FloorDiv(end - start + (step - 1), step) + src = expand(src, src_size) + src_loader = src.make_loader() + + def inner_fn(idx): + if start == 0 and end == dim_size and step == 1: + # selecting every element is the same as just src.clone() + return src_loader(idx) + + idx_dim = ops.index_expr(idx[dim], torch.int64) + src_idx = list(idx) + src_idx[dim] = FloorDiv(idx[dim] - start, step) + + mask = [] + if start != 0: + mask.append( + ops.ge( + idx_dim, + ops.index_expr(sympy.expand(start), torch.int64), + ) + ) + if end != dim_size: + mask.append( + ops.lt( + idx_dim, + ops.index_expr(sympy.expand(end), torch.int64), + ) + ) + if step != 1: + mask.append( + ops.eq( + ops.index_expr( + ModularIndexing(idx[dim] - start, 1, step), torch.int64 + ), + ops.constant(0, torch.torch.int64), + ) + ) + assert mask + mask = functools.reduce(ops.and_, mask) + src_val = ops.masked( + mask, + lambda: src_loader(src_idx), + 0 if is_integer_type(x) else 0.0, + ) + return ops.where( + mask, + src_val, + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +def _unwrap(x): + if isinstance(x, (list, tuple)) and len(x) > 0: + return _unwrap(x[0]) + return x + + +@register_lowering([torch.tensor, aten.scalar_tensor]) +def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + if isinstance(_unwrap(data), int): + dtype = dtype or torch.int64 + else: + dtype = dtype or torch.get_default_dtype() + + ranges: List[sympy.Expr] = [] + + if isinstance(data, sympy.Expr): + + def inner_fn(index): + return ops.index_expr(data, dtype) + + elif isinstance(data, (float, int)): + + def inner_fn(index): + return ops.constant(data, dtype) + + elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8: + # inline small tensors + ranges.append(sympy.Integer(len(data))) + + def inner_fn(index): + def binary_search(start, end): + assert start < end + if end - start == 1: + return ops.constant(data[start], dtype) + mid = (end - start) // 2 + start + return ops.where( + ops.lt( + ops.index_expr(index[0], torch.int64), + ops.constant(mid, torch.int64), + ), + binary_search(start, mid), + binary_search(mid, end), + ) + + if len(data) == 0: + return ops.constant(0, dtype) + return binary_search(0, len(data)) + + else: + return V.graph.add_tensor_constant( + torch.tensor(data, dtype=dtype, device=device) + ) + + return Pointwise.create( + device=decode_device(device), + 
dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if isinstance(data, TensorBox): + if dtype is not None: + data = to_dtype(data, dtype) + if device is not None: + data = to_device(data, device) + return data + return tensor(data, dtype=dtype, device=device) + + +@register_lowering(torch.LongTensor) +def long_tensor(data): + return tensor(data, dtype=torch.int64) + + +@register_lowering(aten._local_scalar_dense) +def _local_scalar_dense(data): + # This is interesting! Most lowerings return tensors, so you can just + # return the buffer you allocated and it will get used (or not used, if + # it's dead.) But _local_scalar_dense (aka item) returns an int, + # not a Tensor, so you would have a type mismatch if you return a buffer; + # we are obligated to return a sympy expression instead. However, + # we need to actually codegen the .item() call somehow. We do this + # by registering a faux buffer for the DynamicScalar IR node, which is + # solely responsible for generating this .item(). The buffer is + # not used for anything (notice we discard it); at codegen time, + # the "buffer" just gets assigned None. + sym = V.graph.current_node.meta["val"].node.expr + buffer = ir.DynamicScalar(sym, data) + buffer.name = V.graph.register_buffer(buffer) + return sym + + +@register_lowering(aten._assert_scalar) +def _assert_scalar(data, msg): + buffer = ir.AssertScalar(data, msg) + # This buffer isn't used by anyone (it returns None), so we must explicitly register it + buffer.name = V.graph.register_buffer(buffer) + return buffer + + +def _full(fill_value, device, dtype, size): + value = fill_value + if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): + value = value.value + + if isinstance(value, (int, float)): + + def inner_fn(index): + return ops.constant(value, dtype) + + elif isinstance(value, sympy.Expr): + + def inner_fn(index): + return ops.index_expr(value, dtype) + + else: + assert len(value.get_size()) == 0 + value_loader = value.make_loader() + + def inner_fn(index): + return value_loader([]) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + +@register_lowering(aten.full_like, type_promotion_kind=None) +def full_like(x, fill_value, **kwargs): + return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) + + +def tensor_constructor(fill_value): + # torch.zeros, torch.ones, etc + def inner( + *size, + names=None, + dtype=None, + device=None, + layout=None, + pin_memory=False, + memory_format=None, + ): + assert_nyi(names is None, "named tensors") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + device = decode_device(device) + dtype = dtype or torch.get_default_dtype() + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + # See https://github.com/pytorch/pytorch/issues/118102 + # All sizes at lowering time should be sympy.Symbol, not SymInt! 
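The tensor() lowering a little above inlines lists of at most eight scalars as a balanced tree of where-selects instead of materializing a constant buffer. A plain-eager sketch of that selection tree; inline_select is a hypothetical name, and it uses torch.where rather than inductor's ops handlers:

```python
import torch

def inline_select(data, index: torch.Tensor) -> torch.Tensor:
    # Balanced tree of where-selects over a small constant list, mirroring the
    # binary_search helper inside the tensor() lowering.
    def binary_search(start, end):
        if end - start == 1:
            return torch.tensor(float(data[start]))
        mid = (end - start) // 2 + start
        return torch.where(index < mid, binary_search(start, mid), binary_search(mid, end))

    return binary_search(0, len(data))

data = [3.0, 1.0, 4.0, 1.0, 5.0]
print(inline_select(data, torch.arange(len(data))))  # tensor([3., 1., 4., 1., 5.])
```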
+ for s in size: + assert not isinstance(s, torch.SymInt) + size = [sympy.expand(s) for s in size] + return _full(fill_value, device, dtype, size) + + return inner + + +@register_lowering([torch.empty, aten.empty]) +def empty( + *size, + names=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + assert_nyi(names is None, "named tensors") + device = decode_device(device) + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +def create_tensor_like(creation_fn): + """ + Shim to convert X_like(...) into X(...). For example zeros_like() into zeros(). + """ + + def _constant_like( + x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None + ): + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + if dtype is None: + dtype = x.get_dtype() + else: + dtype = decode_dtype(dtype) + device = device or x.get_device() + size = list(x.get_size()) + return creation_fn( + size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory + ) + + return _constant_like + + +def constant_like(fill_value): + return create_tensor_like(tensor_constructor(fill_value)) + + +empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty)) +ones_like = create_tensor_like(tensor_constructor(1)) +zeros_like = create_tensor_like(tensor_constructor(0)) + + +def new_constant(fill_value): + def _new_constant( + x, size, *, dtype=None, layout=None, device=None, pin_memory=None + ): + assert isinstance(size, (list, tuple)) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or x.get_dtype() + device = device or x.get_device() + size = [sympy.Integer(s) for s in size] + return _full(fill_value, device, dtype, size) + + return _new_constant + + +@register_lowering(aten.new_empty) +def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(aten.empty_strided) +def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + assert isinstance(size, (list, tuple)) + assert isinstance(stride, (list, tuple, type(None))) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or torch.get_default_dtype() + device = device or torch.tensor(0.0).device + pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) + pointwise.realize() + buffer = pointwise.data.data + # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode + buffer.data.ranges = [0] * len(size) + assert isinstance(buffer, ir.ComputedBuffer) + size = [sympy.expand(s) for s in size] + stride = ( + [sympy.expand(s) for s in stride] + if stride + else ir.FlexibleLayout.contiguous_strides(size) + ) + buffer.layout = ir.FixedLayout( + device=device, + dtype=dtype, + size=size, + stride=stride, + ) + return pointwise + + +@register_lowering(aten.new_empty_strided) +def new_empty_strided( + x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + if dtype is None: + dtype = x.get_dtype() + if 
device is None: + device = x.get_device() + return empty_strided( + size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(prims.copy_strided.default) +def copy_strided(x, stride): + stride = [V.graph.sizevars.size_hint(s) for s in stride] + stride_order = sorted(range(len(stride)), key=stride.__getitem__) + return ir.ExternKernel.require_stride_order(x, stride_order) + + +@register_lowering([torch.full, aten.full]) +def full(size, fill_value, **kwargs): + assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition" + return tensor_constructor(fill_value)(size, **kwargs) + + +@register_lowering(aten.gather, type_promotion_kind=None) +def gather(x, dim, index, sparse_grad=False): + # sparse_grad doesn't affect forward computation, + # and backward tracing is taken care of by AOT Autograd + assert isinstance(x, TensorBox) + assert index.get_dtype() == torch.int64 + size = x.get_size() + offset = len(size) == 0 + dim = _validate_dim(x, dim, offset) + + x_loader = x.make_loader() + index_loader = index.make_loader() + + def fn(idx): + idx = list(idx) + if len(idx) != 0: + idx[dim] = ops.indirect_indexing(index_loader(idx), size[dim]) + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + ) + + +@register_lowering(aten.embedding, type_promotion_kind=None) +def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + assert not sparse + assert isinstance(weight, TensorBox) + assert isinstance(indices, TensorBox) + assert "int" in str(indices.get_dtype()) + + weight_loader = weight.make_loader() + indices_loader = indices.make_loader() + indices_ndim = len(indices.get_size()) + weight_size = weight.get_size() + new_size = [*indices.get_size(), *weight_size[1:]] + + def fn(idx): + assert len(idx) == len(new_size), f"{idx} != {new_size}" + var_index = indices_loader(idx[:indices_ndim]) + weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [ + *idx[indices_ndim:] + ] + return weight_loader(weight_idx) + + return Pointwise.create( + device=weight.get_device(), + dtype=weight.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + + +def check_and_broadcast_indices(indices, device): + assert all( + i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) + for i in indices + if i is not None + ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}" + if any( + i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None + ): + raise NotImplementedError("Fallback for bool indices") + + valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)] + assert len(valid_idxs) > 0, "requires at least 1 non-None index" + new_indices = [None] * len(indices) + for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])): + # Eager allows indices to be CPU tensor when running on CUDA + # FIXME: Calling to_device(x, device) should work but + # test_advancedindex_mixed_cpu_devices still fails + if x.get_device() != device: + raise NotImplementedError("Fallback when indices is on a different device") + new_indices[i] = x + return new_indices, valid_idxs + + +def index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + x_loader, + check, +): + # Note that behavior of indexing differs when there are non consecutive + # tensors. 
In this case, the tensor index is pulled to the beginning. + # + # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7) + # x = torch.tensor[1,2] + # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will + # be pulled to the front. + non_consecutive_tensors = False + for previous, current in zip(tensor_indices, tensor_indices[1:]): + if current - previous != 1: + non_consecutive_tensors = True + + output_size = [x_size[i] for i, val in enumerate(indices) if val is None] + output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]] + + first_tensor_index = tensor_indices[0] + if non_consecutive_tensors: + output_size = tensor_size + output_size + else: + output_size = ( + output_size[:first_tensor_index] + + tensor_size + + output_size[first_tensor_index:] + ) + + def fn(idx): + assert len(idx) == len(output_size) + assert len(indices_loaders) == len(indexed_size) + + rank = len(tensor_size) + new_index = [] + first_tensor_index = tensor_indices[0] + start_offset = 0 if non_consecutive_tensors else first_tensor_index + next_idx = 0 + for i in range(tensor_indices[-1] + 1): + if i == start_offset: + next_idx += rank + if indices[i] is None: + assert next_idx < len(idx) + new_index.append(idx[next_idx]) + next_idx += 1 + else: + loader = indices_loaders[i] + assert loader is not None + size = indexed_size[i] + new_index.append( + ops.indirect_indexing( + loader(idx[start_offset : start_offset + rank]), + size, + check=check, + ) + ) + new_index = [ + *new_index, + *idx[next_idx:], + ] + return new_index if x_loader is None else x_loader(new_index) + + return output_size, fn + + +def index_impl(x, indices, check): + assert isinstance(indices, (list, tuple)) + x_loader = x.make_loader() + indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) + assert len(tensor_indices) > 0, "Must have at least one valid idx" + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + # no guards on output size, all the guards are set in broadcast_tensors + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + + x_size = x.get_size() + + indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None] + if 0 in indexed_size and 0 not in tensor_size: + raise IndexError("index is out of bounds for dimension with size 0") + + indexed_size = [x_size[i] for i in range(len(indices))] + output_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + x_loader, + check=check, + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=output_size, + ) + + +@register_lowering(aten.index, type_promotion_kind=None) +def index(x, indices): + try: + return index_impl(x, indices, check=True) + except NotImplementedError: + # Fallback to ATen for boolean indexing + x.realize() + return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)( + x, indices + ) + + +@register_lowering(aten._unsafe_index, type_promotion_kind=None) +def _unsafe_index(x, indices): + return index_impl(x, indices, check=False) + + +# All the indexing decompositions are written in terms of index, index_put, and index_put_ +# We cannot have this lowering as a decomposition as it introduces +# mutation in the graph, which is bad for Aot Autograd. 
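The shape rule described in the comment above can be checked directly in eager mode; this small check is an aside, not part of the vendored source. With non-consecutive tensor indices the broadcast index dimension is pulled to the front of the result, otherwise it stays where the tensor indices were:

```python
import torch

a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7)
x = torch.tensor([1, 2])

# Non-consecutive tensor indices (dims 1 and 3): index dim moves to the front.
print(a[:, x, :, x, :].shape)  # torch.Size([2, 3, 5, 7])
# Consecutive tensor indices (dims 1 and 2): index dim stays in place.
print(a[:, x, x, :, :].shape)  # torch.Size([3, 2, 6, 7])
```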
Aot Autograd runs dead +# code elimination and common subexpression elimination optimizations, which +# assume graphs to be side-effect free. More details at +# https://github.com/pytorch/torchdynamo/issues/1235 +# and +# https://github.com/pytorch/torchdynamo/issues/1863 +@register_lowering(aten.index_put) +def index_put(x, indices, values, accumulate=False): + return index_put_(clone(x), indices, values, accumulate) + + +@register_lowering(aten._unsafe_index_put) +def _unsafe_index_put(x, indices, values, accumulate=False): + return index_put_impl_(clone(x), indices, values, accumulate, check=False) + + +def index_put_as_masked_fill(self, indices, value, accumulate): + if value.get_device() != self.get_device(): + value = to_device(value, self.get_device()) + if accumulate: + value = add(self, value) + return mutate_to(self, where(indices[0], value, self)) + + +def index_put_fallback(self, indices, values, accumulate): + deterministic = torch.are_deterministic_algorithms_enabled() + if is_triton(values) and (accumulate or deterministic): + msg = ( + "index put with accumulate." + if not deterministic + else "deterministic index put." + ) + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate) + return self + + +@register_lowering(aten.index_put_, type_promotion_kind=None) +def index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=True) + + +@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None) +def _unsafe_index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=False) + + +def needs_fallback_due_to_atomic_add_limitations(dtype): + # tl.atomic_add does NOT support the following types + return dtype in {torch.int64, torch.bool, torch.bfloat16} + + +def index_put_impl_(self, indices, values, accumulate, check): + # Dispatch to masked fill for single boolean index with single value + if ( + values.get_numel() == 1 + and len(indices) == 1 + and indices[0].get_dtype() in {torch.bool, torch.uint8} + ): + mask = indices[0] + for _ in range(len(mask.get_size()), len(self.get_size())): + mask = unsqueeze(mask, -1) + return index_put_as_masked_fill(self, [mask], values, accumulate) + + # Fallback in torch deterministic mode + if torch.are_deterministic_algorithms_enabled(): + return index_put_fallback(self, indices, values, accumulate) + + # Fallback if there is a boolean index + for index in indices: + if index is not None and index.get_dtype() in {torch.bool, torch.uint8}: + return index_put_fallback(self, indices, values, accumulate) + + x_size = self.get_size() + x_ndim = len(x_size) + + if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()): + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + self = index_put_fallback(self, indices, values, accumulate) + if x_ndim == 0: + self = view(self, []) + return self + + values = to_dtype(values, self.get_dtype()) + + try: + # Note that code will only get here when dtype is uint32 + indices, tensor_indices = check_and_broadcast_indices( + indices, self.get_device() + ) + except NotImplementedError: + return index_put_fallback(self, indices, values, accumulate) + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + + assert 
isinstance(self, TensorBox) + self.realize() + + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + ) + + values = expand(values, expected_vals_size) + # all guards are set above during broadcast_tensors and expand + + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add" if accumulate else None, + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayout(self), + scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + + if x_ndim == 0: + self = view(self, []) + return self + + +@register_lowering( + inductor_prims.masked_scatter_with_index, type_promotion_kind=None, broadcast=False +) +def masked_scatter_with_index(self, mask, source_idx, source): + self_flat, mask_flat, source_flat = (view(x, (-1,)) for x in (self, mask, source)) + + assert self.get_size() == mask.get_size() + assert mask.get_dtype() in {torch.bool, torch.uint8} + + self_loader = self_flat.make_loader() + mask_loader = mask_flat.make_loader() + source_idx_loader = source_idx.make_loader() + source_loader = source_flat.make_loader() + source_numel = source.get_numel() + + def inner_fn(idx): + self_val = self_loader(idx) + mask_val = ops.to_dtype(mask_loader(idx), torch.bool) + + def load_source_val(): + source_idx_val = source_idx_loader(idx) + i = ops.indirect_indexing(source_idx_val, source_numel) + return source_loader([i]) + + source_val = ops.masked(mask_val, load_source_val, 0) + return ops.where(mask_val, source_val, self_val) + + result_flat = Pointwise.create( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=inner_fn, + ranges=self_flat.get_size(), + ) + return view(result_flat, self.get_size()) + + +@register_lowering(aten.as_strided_scatter, type_promotion_kind=None) +def as_strided_scatter(self, src, size, stride, storage_offset=None): + output = clone(self) + output_view = as_strided(output, size, stride, storage_offset) + copy_(output_view, src) + return output + + +@register_lowering(aten.scatter, type_promotion_kind=None) +def scatter(x, dim: int, index, src, **kwargs): + return scatter_(clone(x), dim, index, src, **kwargs) + + +def scatter_fallback( + fn, + self, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, +): + reduce_ty = "add" if fn == "aten.scatter_" else "sum" + if ( + reduce not in {None, reduce_ty} + or ( + isinstance(src, TensorBox) + and src.get_device().type == torch.device("cuda").type + and needs_fallback_due_to_atomic_add_limitations(src.get_dtype()) + ) + or ( + fn == "aten.scatter_reduce_" + and reduce == "sum" + and isinstance(src, TensorBox) + and src.get_device() == torch.device("cpu") + and config.cpp.fallback_scatter_reduce_sum + and (config.cpp.dynamic_threads or parallel_num_threads() != 1) + ) + or (reduce == reduce_ty and self.get_dtype() in {torch.bool, torch.int64}) + or torch.are_deterministic_algorithms_enabled() + ): + ir.ScatterFallback( + V.graph.current_node.target, + fn, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + return 
self + + return None + + +@register_lowering(aten.scatter_, type_promotion_kind=None) +def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None): + assert reduce in {None, "add", "multiply"} + + fallback_result = scatter_fallback( + "aten.scatter_", self, dim, index, src, reduce=reduce + ) + + if fallback_result: + return fallback_result + + if reduce == "add": + reduce = "sum" + elif reduce == "multiply": + reduce = "prod" + + return scatter_reduce_(self, dim, index, src, reduce) + + +@register_lowering(aten.scatter_add, type_promotion_kind=None) +def scatter_add(x, dim: int, index, src): + return scatter_add_(clone(x), dim, index, src) + + +@register_lowering(aten.scatter_add_, type_promotion_kind=None) +def scatter_add_(x, dim: int, index, src): + return scatter_reduce_(x, dim, index, src, "sum") + + +@register_lowering(aten.scatter_reduce, type_promotion_kind=None) +def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): + return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs) + + +@register_lowering(aten.scatter_reduce_, type_promotion_kind=None) +def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): + assert reduce in {None, "sum", "prod", "mean", "amax", "amin"} + + fallback_result = scatter_fallback( + "aten.scatter_reduce_", + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + + if fallback_result: + return fallback_result + + assert isinstance(self, TensorBox) + assert "int" in str(index.get_dtype()) + + ndim = len(self.get_size()) + if ndim == 0: + self = view(self, [1]) + + if isinstance(src, TensorBox) and len(src.get_size()) == 0: + src = view(src, [1]) + + if isinstance(index, TensorBox) and len(index.get_size()) == 0: + index = view(index, [1]) + + dim = _validate_dim(self, dim) + + self.realize() + index_loader = index.make_loader() + src_loader = src.make_loader() if isinstance(src, TensorBox) else None + + def output_indexer(idx): + # self is captured from the end of the function, so it may have 0 dim + shape = self.get_size() + ndim = len(shape) + indirect_idx = list(idx) + indirect_idx[dim] = ops.indirect_indexing( + index_loader(idx), 1 if ndim == 0 else shape[dim] + ) + return indirect_idx + + def fn(idx): + if src_loader: + return src_loader(idx) + else: + # src is a scalar + return ops.constant(src, self.get_dtype()) + + def backend_reduce_str(reduce): + if reduce == "sum": + return "atomic_add" + else: + # TODO: Need to support more reduction type + assert reduce is None + return None + + if not include_self: + # zero out the corresponding elements first + zero_out = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=lambda index: ops.constant(0, self.get_dtype()), + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=None, + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayout(self), + zero_out, + ) + buffer.name = V.graph.register_buffer(buffer) + + # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=backend_reduce_str(reduce), + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayout(self), + scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + + if ndim == 0: + self = view(self, []) + 
return self + + +def upsample_nearestnd( + x, + output_size, + scales_x: Tuple[Optional[float], ...], + n: int = 2, + exact: bool = False, +): + x.realize_hint() # elements are reused + x_loader = x.make_loader() + i_sizes = x.get_size()[-n:] + batch = x.get_size()[:-n] + i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes] + + assert len(scales_x) == n + o_sizes = output_size + + inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)] + for i, scale in enumerate(scales_x): + if scale is not None: + inv_scales[i] = 1.0 / scale + + def scale_fn(x, scale, size): + # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5) + # = floor(scale * (output_index + 0.5)) + # Nearest: input_index = floor(scale * output_index) + x = ops.index_expr(x, torch.float32) + if exact: + x = ops.add(x, ops.constant(0.5, torch.float32)) + x = ops.mul(x, ops.constant(scale, torch.float32)) + x = ops.to_dtype(x, torch.int32) + return ops.indirect_indexing(x, size, check=False) + + def fn(idx): + x = idx[-n:] + b = idx[:-n] + return x_loader( + [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]] + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=[*batch, *o_sizes], + ) + + +@register_lowering(aten.upsample_nearest1d.default) +def upsample_nearest1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1) + + +@register_lowering(aten._upsample_nearest_exact1d.default) +def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True) + + +@register_lowering(aten.upsample_nearest2d.default) +def upsample_nearest2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2) + + +@register_lowering(aten._upsample_nearest_exact2d.default) +def _upsample_nearest_exact2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True) + + +@register_lowering(aten.upsample_nearest3d.default) +def upsample_nearest3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3) + + +@register_lowering(aten._upsample_nearest_exact3d.default) +def _upsample_nearest_exact3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd( + x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True + ) + + +def _create_constants(*args, dtype): + return tuple(ops.constant(a, dtype) for a in args) + + +@register_lowering(aten.upsample_bicubic2d.default) +def upsample_bicubic2d_default( + x, + output_size, + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + x.realize_hint() + x_loader = x.make_loader() + + N, C, iH, iW = x.get_size() + oH, oW = output_size + + iH = V.graph.sizevars.evaluate_static_shape(iH) + iW = V.graph.sizevars.evaluate_static_shape(iW) + + def get_int_dtype(maxval): + if maxval > torch.iinfo(torch.int32).max: + return torch.int64 + return torch.int32 + + def compute_scale(in_size, out_size, align_corners, scale=None): + if align_corners: + return (in_size - 1) / 
(out_size - 1) if out_size > 1 else 0 + else: + return 1 / scale if scale is not None and scale > 0 else in_size / out_size + + def compute_source_index(scale, dst_index, align_corners): + dst_index_ie = ops.index_expr(dst_index, torch.float32) + scale = ops.constant(scale, torch.float32) + if align_corners: + return ops.mul(scale, dst_index_ie) + else: + half = ops.constant(0.5, torch.float32) + return scale * (dst_index_ie + half) - half + + def cubic_convolution1(x, A): + _Ap2, _Ap3, _1 = _create_constants(A + 2, A + 3, 1, dtype=torch.float32) + return (_Ap2 * x - _Ap3) * x * x + _1 + + def cubic_convolution2(x, A): + _A, _4A, _5A, _8A = _create_constants( + A, 4 * A, 5 * A, 8 * A, dtype=torch.float32 + ) + return ((_A * x - _5A) * x + _8A) * x - _4A + + def get_cubic_upsample_coefficients(t): + A = -0.75 + _1 = ops.constant(1.0, torch.float32) + c0 = cubic_convolution2(ops.add(t, _1), A) + c1 = cubic_convolution1(t, A) + + x2 = ops.sub(_1, t) + c2 = cubic_convolution1(x2, A) + c3 = cubic_convolution2(ops.add(x2, _1), A) + return (c0, c1, c2, c3) + + def cubic_interp1d(xs, t): + cs = get_cubic_upsample_coefficients(t) + # dot product between xs and cs + return xs[0] * cs[0] + xs[1] * cs[1] + xs[2] * cs[2] + xs[3] * cs[3] + + height_scale = compute_scale(iH, oH, align_corners, scales_h) + width_scale = compute_scale(iW, oW, align_corners, scales_h) + + def clamp(v, min, max): + return ops.maximum(min, ops.minimum(max, v)) + + def fn(idx): + n, c, oy, ox = idx + + real_x = compute_source_index(width_scale, ox, align_corners) + in_x = ops.floor(real_x) + t_x = ops.sub(real_x, in_x) + + real_y = compute_source_index(height_scale, oy, align_corners) + in_y = ops.floor(real_y) + t_y = ops.sub(real_y, in_y) + + def load_bounded(fy, fx): + # TODO(Lezcano) Here we may not need to set-up a device_size + _0 = ops.constant(0, torch.int32) + iHm1 = ops.constant(iH - 1, torch.int32) + iWm1 = ops.constant(iW - 1, torch.int32) + iy = ops.indirect_indexing(clamp(fy, _0, iHm1), iH, check=False) + ix = ops.indirect_indexing(clamp(fx, _0, iWm1), iW, check=False) + return x_loader([n, c, iy, ix]) + + iy = ops.to_dtype(in_y, get_int_dtype(iH + 1)) + ix = ops.to_dtype(in_x, get_int_dtype(iW + 1)) + iys_ofs = tuple(ops.add(iy, ofs) for ofs in (-1, 0, 1, 2)) + ixs_ofs = tuple(ops.add(ix, ofs) for ofs in (-1, 0, 1, 2)) + + def get_x_interp(y): + coeffs_x = tuple(load_bounded(y, x) for x in ixs_ofs) + return cubic_interp1d(coeffs_x, t_x) + + coeffs_y = tuple(get_x_interp(y) for y in iys_ofs) + return cubic_interp1d(coeffs_y, t_y) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)], + ) + + +@register_lowering(aten.reflection_pad1d_backward) +@register_lowering(aten.reflection_pad2d_backward) +@register_lowering(aten.reflection_pad3d_backward) +def _reflection_padnd_backward(grad_output, x, padding): + dim = len(padding) // 2 + + dhw = [h - 1 for h in x.get_size()[-dim:]] + grad_loader = grad_output.make_loader() + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + def fn(idx): + b = idx[:-dim] + xyz = idx[-dim:] + + def load_from_output(x): + return grad_loader([*b, *x]) + + def index_range_condition(index_range): + i, lb, ub = index_range + i = ops.index_expr(i, torch.int32) + lb = ops.index_expr(lb, torch.int64) + ub = ops.index_expr(ub, torch.int64) + return ops.and_(ops.ge(i, lb), ops.le(i, ub)) + + # Areas after reflection: + # 
+ # top-left | top | top-right + # ----------------------------------------- + # left | center | right + # ----------------------------------------- + # bottom-left | bottom | bottom-right + # + # The center area is the original matrix. Other areas are reflections. + + center = [xyz[i] + padding_left[i] for i in range(dim)] + left_reflect = [padding_left[i] - xyz[i] for i in range(dim)] + right_reflect = [2 * dhw[i] + padding_left[i] - xyz[i] for i in range(dim)] + + # Accumulate gradients from different areas + # If some of the padding is negative, center load is not always valid + range_c = [ + (center[i], 0, dhw[i] + padding_left[i] + padding_right[i]) + for i in range(dim) + ] + cond = functools.reduce( + ops.and_, [index_range_condition(range_c[i]) for i in range(dim)] + ) + grad = ops.masked(cond, lambda: load_from_output(center), 0.0) + + def accumulate(grad, out, index_ranges): + # If the upper bound is less than the lower bound, we can get rid of one accumulation. + # This happens when the padding size is zero. + for i in range(dim): + upper_less_than_lower = index_ranges[i][2] < index_ranges[i][1] + if isinstance(upper_less_than_lower, bool) and upper_less_than_lower: + return grad + cond = functools.reduce( + ops.and_, + [index_range_condition(index_range) for index_range in index_ranges], + ) + g = ops.masked(cond, lambda: load_from_output(out), 0.0) + return ops.add(grad, g) + + for area in itertools.product(*[[-1, 0, 1] for _ in range(dim)]): + if area == tuple([0] * dim): + # center, this is already done. + continue + + outs = [] + index_ranges = [] + + for i in range(dim): + if area[i] == 0: + out = center[i] + index_range = range_c[i] + elif area[i] == -1: + out = left_reflect[i] + index_range = (xyz[i], 1, padding_left[i]) + elif area[i] == 1: + out = right_reflect[i] + index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1) + + outs.append(out) # type: ignore[possibly-undefined] + index_ranges.append(index_range) # type: ignore[possibly-undefined] + + grad = accumulate(grad, outs, index_ranges) + + return grad + + return Pointwise.create( + device=grad_output.get_device(), + dtype=grad_output.get_dtype(), + inner_fn=fn, + ranges=list(x.get_size()), + ) + + +@register_lowering(prims.rev.default) +def rev(x, dims): + # note - dims pre-canonicalized + x_loader = x.make_loader() + sizes = x.get_size() + + def loader(idx): + idx = list(idx) + assert len(idx) == len(sizes) + for dim in dims: + idx[dim] = (sizes[dim] - 1) - idx[dim] + + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=loader, + ranges=sizes, + ) + + +@register_lowering(aten.constant_pad_nd, type_promotion_kind=None) +def constant_pad_nd(x, padding, fill_value=0): + assert (len(padding) % 2) == 0 + if all(p == 0 for p in padding): + return clone(x) + + sizes = x.get_size() + + bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) + n = len(sizes) - len(bounds) + + # if padding is a complicated expression, hoist it + bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] + for l, h in bounds: + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] + + output_size = list(sizes[:n]) + mask_sizes = [] + for (low, high), size in zip(bounds, sizes[n:]): + mask_sizes.append(size) + output_size.append(sympy.expand(size + low + high)) + assert len(output_size) == len(sizes) + fill_value = dtype_to_type(x.get_dtype())(fill_value) + + def mask(index): + mask = [] + for idx, (low, high), length in 
zip(index[n:], bounds, mask_sizes): + if low != 0: + mask.append(range_mask_low(idx, 0)) + if high != 0: + mask.append(range_mask_high(idx, length)) + mask = functools.reduce(ops.and_, mask) + return ops.masked(mask, lambda: x_loader(index), fill_value) + + def offset_fn(index): + new_index = list(index[:n]) + for idx, (low, high) in zip(index[n:], bounds_precomp): + new_index.append(idx - low) + assert len(new_index) == len(index) + return mask(new_index) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=offset_fn, + ranges=output_size, + ) + + +def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]): + return ops.ge( + ops.index_expr(i, torch.int64), + ops.index_expr(sympy.Integer(low), torch.int64), + ) + + +def range_mask_high(i: sympy.Expr, high: sympy.Expr): + return ops.lt( + ops.index_expr(i, torch.int64), + ops.index_expr(high, torch.int64), + ) + + +def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr): + return ops.and_( + range_mask_low(i, low), + range_mask_high(i, high), + ) + + +def constant_boundary_condition_2d(x, fill_value, padding=None, pad_fill_value=1.0): + *_, h, w = x.get_size() + x_loader = x.make_loader() + padding_h = padding[0] if padding else 0 + padding_w = padding[1] if padding else 0 + + def load(index): + *prefix, ih, iw = index + + mask = ops.and_( + range_mask(ih, h + padding_h, -padding_h), + range_mask(iw, w + padding_w, -padding_w), + ) + return ( + ops.masked( + mask, + lambda: constant_boundary_condition_2d(x, pad_fill_value)( + [*prefix, ih, iw] + ), + fill_value, + ) + if padding + else ops.masked(mask, lambda: x_loader([*prefix, ih, iw]), fill_value) + ) + + return load + + +def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): + x_out = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i] + ) + + if ceil_mode: + x_alt = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i] + ) + if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: + # Sliding windows must start within the input or left padding + x_alt -= 1 # type: ignore[assignment] + V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] + if V.graph.sizevars.size_hint(x_out - x_alt) == 0: + # ceil mode is actually a no-op, lets guard on that + V.graph.sizevars.guard_equals(x_out, x_alt) + ceil_mode = False + else: + x_out = x_alt + return x_out, ceil_mode + + +fallback_max_pool2d_with_indices = fallback_handler( + aten.max_pool2d_with_indices.default, + add_to_fallback_set=False, +) + + +@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None) +def max_pool2d_with_indices( + x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + x.realize_hint() + *batch, h, w = x.get_size() + + h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode) + + if padding[0] or padding[1] or ceil_mode1 or 
ceil_mode2: + x_loader = constant_boundary_condition_2d(x, float("-inf")) + else: + x_loader = x.make_loader() + + new_size = list(batch) + [h_out, w_out] + window_size = kernel_size[0] * kernel_size[1] + + if window_size > 25 or any(d != 1 for d in dilation): + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_max_pool2d_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode + ) + + def fn(idx, return_index): + *prefix, bh, bw = idx + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): + ih = bh * stride[0] + ih - padding[0] + iw = bw * stride[1] + iw - padding[1] + val = x_loader([*prefix, ih, iw]) + if return_index: + index = ops.index_expr(ih * w + iw, torch.int64) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + r1 = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + r2 = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + # TODO(jansel): should we force these to be realized? + return r1, r2 + + +fallback_max_pool2d_with_indices_backward = fallback_handler( + aten.max_pool2d_with_indices_backward.default, + add_to_fallback_set=False, +) + + +@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) +def max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + # we will read this many times, so make sure it is computed + grad_output.realize_hint() + try: + gO_stride = grad_output.get_stride() + except AttributeError: + # some classes don't have `get_stride` + # TODO will need a better way of determining if inputs are channels-last + gO_stride = None + if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] + data = x.data.data # type: ignore[attr-defined] + x_buffer = ir.ComputedBuffer( + name=None, + layout=ir.FlexibleLayout( + device=data.get_device(), + dtype=data.get_dtype(), + size=data.get_size(), + ), + data=data, + ) + x_buffer.decide_layout() + x_stride = x_buffer.get_stride() + else: + try: + x_stride = x.get_stride() + except AttributeError: + x_stride = None + + is_channels_last = (x_stride is not None and x_stride[1] == 1) or ( + gO_stride is not None and gO_stride[1] == 1 + ) + autotune = ( + config.coordinate_descent_tuning + or config.max_autotune + or config.max_autotune_pointwise + ) + if any(d != 1 for d in dilation) or (is_channels_last and not autotune): + # don't codegen channels-last when autotune is not enabled, it's very slow + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + indices.realize_hint() + + *batch, height, width = x.get_size() + *_, pooled_height, pooled_width = grad_output.get_size() + + indices_loader = indices.make_loader() + 
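A self-contained eager reference for the unrolled window scan in max_pool2d_with_indices above, keeping a running maximum together with the flattened input index ih * W + iw. The function name is made up for this sketch, and it assumes no padding or dilation with the window dividing the input evenly:

```python
import torch
import torch.nn.functional as F

def max_pool2d_reference(x, k, s):
    # Unrolled window scan: running max plus flattened index, first maximum wins ties.
    N, C, H, W = x.shape
    Hout, Wout = (H - k) // s + 1, (W - k) // s + 1
    val = x.new_full((N, C, Hout, Wout), float("-inf"))
    idx = torch.zeros(N, C, Hout, Wout, dtype=torch.int64)
    for oh in range(Hout):
        for ow in range(Wout):
            for ih in range(oh * s, oh * s + k):
                for iw in range(ow * s, ow * s + k):
                    v = x[:, :, ih, iw]
                    take = v > val[:, :, oh, ow]      # strictly greater, as in the lowering
                    idx[:, :, oh, ow][take] = ih * W + iw
                    val[:, :, oh, ow] = torch.maximum(v, val[:, :, oh, ow])
    return val, idx

x = torch.randn(1, 1, 6, 6)
ref_v, ref_i = max_pool2d_reference(x, k=2, s=2)
out_v, out_i = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
assert torch.equal(ref_v, out_v) and torch.equal(ref_i, out_i)
```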
grad_loader = grad_output.make_loader() + new_size = list(x.get_size()) + + h_window_size = max( + [ + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ] + ) + w_window_size = max( + [ + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ] + ) + + window_size = h_window_size * w_window_size + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + indices_size = indices.get_size() + + def fn(idx): + *prefix, h, w = idx + index_test = ops.index_expr(h * width + w, torch.int32) + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + grad_index = [ + *prefix, + ops.indirect_indexing( + ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), + indices_size[-2], + check=False, + ), + ops.indirect_indexing( + ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), + indices_size[-1], + check=False, + ), + ] + + index_actual = indices_loader(grad_index) + grad_part = grad_loader(grad_index) + check = ops.eq(index_actual, index_test) + + if gradient is None: + # don't need mask for 0, 0 + gradient = ops.where( + check, grad_part, ops.constant(0.0, torch.float32) + ) + else: + mask = ops.and_( + ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ), + check, + ) + gradient = ops.where(mask, ops.add(gradient, grad_part), gradient) + assert gradient is not None + return gradient + + return Pointwise.create( + device=grad_output.get_device(), + dtype=grad_output.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + + +def pad_adaptive_loader(x, pad_val=0.0): + *_, h, w = x.get_size() + x_loader = x.make_loader() + + def load(prefix, increments, start_indices, end_indices): + ih, iw = increments + h_start_index, w_start_index = start_indices + h_end_index, w_end_index = end_indices + + mask = ops.and_( + ops.lt( + ops.index_expr(h_start_index + ih, torch.int64), + ops.index_expr(h_end_index, torch.int64), + ), + ops.lt( + ops.index_expr(w_start_index + iw, torch.int64), + ops.index_expr(w_end_index, torch.int64), + ), + ) + + return ops.masked( + mask, + lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]), + pad_val, + ) + + return load + + +def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns): + h_start_index_fn, w_start_index_fn = start_index_fns + h_end_index_fn, w_end_index_fn = end_index_fns + + def fn_sum(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = 
w_end_index_fn(bw) + + total = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + if total is None: + total = val + else: + total = ops.add(val, total) + return total + + return fn_sum + + +fallback_adaptive_avg_pool2d = fallback_handler( + aten._adaptive_avg_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten._adaptive_avg_pool2d) +def _adaptive_avg_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + # no-op if the same input and output + if h_in == h_out and w_in == w_out: + return clone(x) + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return avg_pool2d(x, kernel_size) + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) + + w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_adaptive_avg_pool2d(x, output_size) + + fn_sum = _adaptive_pooling_idx_sum( + [h_kernel_max, w_kernel_max], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + ones_loader = pad_adaptive_loader(ones_like(x)) + + def fn(idx): + return ops.truediv( + fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader) + ) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO: should we force these to be realized? + return rv + + +def _adaptive_pooling_idx_max(kernel_maxes, in_sizes, out_sizes, return_index, loader): + # NOTE: There is some duplication between this and addaptive_avg_pool2d and max_pool2d + # Look into refactoring/deduplication after #116418 is merged. 
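The start_index/end_index arithmetic shared by these adaptive-pooling helpers can be validated against eager PyTorch. A small reference sketch, assuming 4-D NCHW input, checked against F.adaptive_avg_pool2d; the function name is invented for this example:

```python
import torch
import torch.nn.functional as F

def adaptive_avg_pool2d_reference(x, out_h, out_w):
    # start = floor(i * in / out), end = ceil((i + 1) * in / out), as in start_index/end_index.
    N, C, H, W = x.shape

    def start(i, out, inp):
        return (i * inp) // out

    def end(i, out, inp):
        return ((i + 1) * inp + out - 1) // out  # ceiling division

    result = x.new_zeros(N, C, out_h, out_w)
    for oh in range(out_h):
        h0, h1 = start(oh, out_h, H), end(oh, out_h, H)
        for ow in range(out_w):
            w0, w1 = start(ow, out_w, W), end(ow, out_w, W)
            result[:, :, oh, ow] = x[:, :, h0:h1, w0:w1].mean(dim=(-2, -1))
    return result

x = torch.randn(2, 3, 10, 7)
torch.testing.assert_close(adaptive_avg_pool2d_reference(x, 4, 3),
                           F.adaptive_avg_pool2d(x, (4, 3)))
```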
+ h_in, w_in = in_sizes + h_out, w_out = out_sizes + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + h_start_index_fn = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index_fn = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) + w_start_index_fn = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index_fn = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + def fn_max(idx): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + index = ops.index_expr( + (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 + ) + if return_index: + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + return fn_max + + +fallback_adaptive_max_pool2d = fallback_handler( + aten.adaptive_max_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.adaptive_max_pool2d) +def adaptive_max_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty( + o_size, dtype=torch.int64, device=x.get_device() + ) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return max_pool2d_with_indices(x, kernel_size) + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
+ return fallback_adaptive_max_pool2d(x, output_size) + + inner_func_max_val = _adaptive_pooling_idx_max( + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + return_index=False, + loader=pad_adaptive_loader(x, float("-inf")), + ) + + inner_func_max_idx = _adaptive_pooling_idx_max( + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + return_index=True, + loader=pad_adaptive_loader(x, float("-inf")), + ) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=inner_func_max_val, + ranges=new_size, + ) + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=inner_func_max_idx, + ranges=new_size, + ) + return rv, ri + + +fallback_fractional_max_pool2d = fallback_handler( + aten.fractional_max_pool2d.default, add_to_fallback_set=False +) + + +def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): + out_sz = out_sz[dim] + in_sz = in_sz[dim] + kernel_sz = kernel_sz[dim] + alpha = (in_sz - kernel_sz) / (out_sz - 1) + samples_loader = samples.make_loader() + + def load(prefix, i): + sample = samples_loader([*prefix, dim]) + i_expr = ops.index_expr(i, samples.get_dtype()) + alpha_expr = ops.index_expr(alpha, samples.get_dtype()) + seq_i = ops.floor((i_expr + sample) * alpha_expr) - ops.floor( + sample * alpha_expr + ) + seq_i = ops.to_dtype(seq_i, torch.int64) + + mask = ops.lt( + i_expr, + ops.index_expr(out_sz - 1, torch.int64), + ) + return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)) + + return load + + +@register_lowering(aten.fractional_max_pool2d) +def fractional_max_pool2d(x, kernel_size, output_size, random_samples): + x.realize_hint() + *batch, inp_h, inp_w = x.get_size() + kernel_h, kernel_w = kernel_size + h_out, w_out = output_size + + if kernel_h * kernel_w >= 25: + return fallback_fractional_max_pool2d( + x, kernel_size, output_size, random_samples + ) + + gen_offsets_for_dim = functools.partial( + _fractional_pooling_offsets, + samples=random_samples, + in_sz=[inp_h, inp_w], + out_sz=output_size, + kernel_sz=kernel_size, + ) + + h_index_fn = gen_offsets_for_dim(dim=0) + w_index_fn = gen_offsets_for_dim(dim=1) + x_loader = x.make_loader() + + def fn(idx, return_index): + *prefix, bh, bw = idx + + h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h) + w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): + val = x_loader([*prefix, h_start_index + ih, w_start_index + iw]) + if return_index: + index = ops.index_expr( + (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64 + ) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where( + ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex + ) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + new_size = list(batch) + [h_out, w_out] + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + return rv, ri + + +@register_lowering(aten.upsample_nearest2d_backward.default) +def upsample_nearest2d_backward( + x, output_size=None, input_size=None, 
scales_h=None, scales_w=None +): + x.realize_hint() + + *batch, inp_h, inp_w = x.get_size() + inp_h = V.graph.sizevars.evaluate_static_shape(inp_h) + inp_w = V.graph.sizevars.evaluate_static_shape(inp_w) + + *batch, out_h, out_w = input_size + + if inp_h % out_h == 0 and inp_w % out_w == 0: + return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1) + + h_kernel_max = ceildiv(inp_h, out_h) + w_kernel_max = ceildiv(inp_w, out_w) + + def start_index(index, out_dim, inp_dim): + return CeilDiv(index * inp_dim, out_dim) + + def end_index(index, out_dim, inp_dim): + return start_index((index + 1), out_dim, inp_dim) + + h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h) + h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h) + + w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w) + w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w) + + fn_sum = _adaptive_pooling_idx_sum( + [h_kernel_max, w_kernel_max], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + def fn(idx): + return fn_sum(idx, pad_adaptive_loader(x)) + + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=list(input_size), + ) + + return rv + + +fallback_avg_pool2d = fallback_handler( + aten.avg_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d, type_promotion_kind=None) +def avg_pool2d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0] + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(x.get_size()) in (3, 4) + + x.realize_hint() + *batch, h, w = x.get_size() + + h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode) + + if padding[0] or padding[1] or ceil_mode1 or ceil_mode2: + x_loader = constant_boundary_condition_2d(x, 0.0) + had_padding = True + else: + x_loader = x.make_loader() + had_padding = False + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = kernel_size[0] * kernel_size[1] + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_avg_pool2d( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def fn_sum(idx, loader): + *prefix, bh, bw = idx + total = None + for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): + ih = bh * stride[0] + ih - padding[0] + iw = bw * stride[1] + iw - padding[1] + val = loader([*prefix, ih, iw]) + if total is None: + total = val + else: + total = ops.add(val, total) + return total + + if not had_padding or divisor_override: + if divisor_override: + scale = 1 / divisor_override + else: + scale = 1.0 / (kernel_size[0] * kernel_size[1]) + + def fn(idx): + return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype)) + + else: + ones_loader = constant_boundary_condition_2d( + ones_like(x), 0.0, padding if count_include_pad else None + ) + + def fn(idx): + # TODO(jansel): optimize to do `int(x 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
+ return fallback_avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(ph, pw): + """ + This computes the scaling factor that we will divide an element + by when `count_include_pad=False` + """ + stride_h = ops.constant(stride[0], torch.int32) + stride_w = ops.constant(stride[1], torch.int32) + pad_h = ops.constant(padding[0], torch.int32) + pad_w = ops.constant(padding[1], torch.int32) + kernel_h = ops.constant(kernel_size[0], torch.int32) + kernel_w = ops.constant(kernel_size[1], torch.int32) + hstart = ops.sub(ops.mul(ph, stride_h), pad_h) + wstart = ops.sub(ops.mul(pw, stride_w), pad_w) + hend = ops.minimum( + ops.add(hstart, kernel_h), + ops.add(ops.index_expr(height, torch.int32), pad_h), + ) + wend = ops.minimum( + ops.add(wstart, kernel_w), + ops.add(ops.index_expr(width, torch.int32), pad_w), + ) + hstart = ops.maximum(hstart, ops.constant(0, torch.int32)) + wstart = ops.maximum(wstart, ops.constant(0, torch.int32)) + hend = ops.minimum(hend, ops.index_expr(height, torch.int32)) + wend = ops.minimum(wend, ops.index_expr(width, torch.int32)) + divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart)) + return divide_factor + + def fn(idx): + *prefix, h, w = idx + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] + else: + scale = compute_pool_size_without_padding(ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where(mask, part, ops.constant(0.0, torch.float32)) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +def _validate_reduction_axis(x, axis): + size = x.get_size() + if isinstance(axis, int): + axis = [axis] + elif not axis: + axis = range(len(size)) + if len(size) == 0: + assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}" + return [] + axis = list(axis) + for i in range(len(axis)): + if axis[i] < 0: + axis[i] += len(size) if len(size) else 1 + assert 0 <= axis[i] < len(size) or (len(size) == 0 and 
axis[i] == 0) + assert len(set(axis)) == len(axis), "reduction axis not unique" + return axis + + +def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = set(_validate_reduction_axis(x, axis)) + + kept_sizes = [] + kept_idx = [] + reduced_sizes = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + reduced_sizes.append(size[i]) + else: + kept_idx.append(i) + kept_sizes.append(size[i]) + + def loader(index, reduction_index): + assert len(reduction_index) == len(reduced_idx) + if keepdims: + assert len(index) == len(size) + index = [index[i] for i in kept_idx] + assert len(index) == len(kept_idx) + new_index = [None] * (len(index) + len(reduction_index)) + for idx, var in itertools.chain( + zip(kept_idx, index), zip(reduced_idx, reduction_index) + ): + new_index[idx] = var + return inner_loader(new_index) + + if keepdims: + new_size = list(size) + for i in reduced_idx: + new_size[i] = sympy.Integer(1) + else: + new_size = kept_sizes + + inner_loader = x.make_loader() + return dict( + device=x.get_device(), + dst_dtype=override_return_dtype or x.get_dtype(), + src_dtype=x.get_dtype(), + inner_fn=loader, + ranges=new_size, + reduction_ranges=reduced_sizes, + ) + + +def make_reduction(reduction_type: str, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + ) + result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) + if isinstance( + result.data.data, Reduction + ): # Only realize if reduction isn't unrolled + result.realize() + return result + + return inner + + +def _make_scan_inner(x, *, axis, dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_dim(x, axis) + + return dict( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + size=x.get_size(), + axis=axis, + ) + + +@register_lowering(aten.mean) +def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + +def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + +def use_two_step_variance(x, axis, keepdim): + # Instead of unrolling 
welford, just unroll the simpler two-step var + axis = _validate_reduction_axis(x, axis) + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + + ranges = kwargs["ranges"] + reduction_numel = sympy_product(kwargs["reduction_ranges"]) + return ( + isinstance(reduction_numel, sympy.Integer) + and int(reduction_numel) < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ) + + +def var_mean_welford_(x, axis, *, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + loader = kwargs.pop("inner_fn") + kwargs.pop("dst_dtype") + kwargs.pop("src_dtype") + + mean, m2, _ = ir.WelfordReduction.create( + inner_fns=(loader,), + reduction_type="welford_reduce", + dtype=x.get_dtype(), + **kwargs, + ) + m2.realize() + + dtype = x.get_dtype() + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + rnumel = sympy_product(size[i] for i in axis) + + def get_constant_or_index_expr(x, dtype): + if isinstance(x, sympy.Expr) and not x.is_number: + return ops.to_dtype(ops.index_expr(x, torch.int64), dtype) + return ops.constant(x, dtype) + + def scale_fn(data): + c = get_constant_or_index_expr(correction, dtype) + N = get_constant_or_index_expr(rnumel, dtype) + zero = ops.constant(0, dtype) + return data / ops.maximum(zero, N - c) + + var = make_pointwise(scale_fn)(m2) + + if return_mean: + mean.realize() + return var, mean + return (var,) + + +def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + if use_two_step_variance(x, axis=axis, keepdim=keepdim) + else var_mean_welford_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + +@register_lowering([aten.var, prims.var]) +def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + +@register_lowering(aten.var_mean) +def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + +def pow_recursive(x, y, dtype): + if y < 0: + return pow_recursive(ops.reciprocal(x), -y, dtype) + if y == 0: + return ops.constant(1, dtype) + if y == 1: + return x + + result = pow_recursive(x, y // 2, dtype) + result = ops.mul(result, result) + if (y % 2) == 1: + result = ops.mul(result, x) + return result + + +@make_pointwise +def pow_native(a, b): + return ops.pow(a, b) + + +fallback_pow_tensor_tensor = fallback_handler( + aten.pow.Tensor_Tensor, add_to_fallback_set=False +) +fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False) +fallback_pow_tensor_scalar = fallback_handler( + aten.pow.Tensor_Scalar, add_to_fallback_set=False +) + + +@register_lowering(aten.pow, broadcast=True) +def pow(a, b): + if isinstance(b, float) and b == int(b): + return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) + elif isinstance(b, int) and b == 1: + return clone(a) + + # Type promotion ensures all tensor arguments have the same type + dtype = 
next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) + is_integer_pow = is_integer_dtype(dtype) + + # Optimize away small fixed powers, or for integers avoid falling back to ATen + embed_exponent = isinstance(b, int) and ( + -32 < b < 32 or (is_integer_pow and b >= 0) + ) + if embed_exponent: + loader = a.make_loader() + + def fn(idx): + return pow_recursive(loader(idx), b, a.get_dtype()) + + return Pointwise.create( + device=a.get_device(), + dtype=a.get_dtype(), + inner_fn=fn, + ranges=a.get_size(), + ) + + if isinstance(a, Number): + if a == 1: + return full_like(b, 1) + if a == 2 and is_float_dtype(b.get_dtype()): + return exp2(b) + + if is_integer_pow: + # ops.pow doesn't work for integers + if isinstance(a, Number): + return fallback_pow_scalar(a, b) + elif isinstance(b, Number): + return fallback_pow_tensor_scalar(a, b) + else: + return fallback_pow_tensor_tensor(a, b) + + return pow_native(a, b) + + +def mutate_to(changed, val, unsafe_alias=False): + if isinstance(changed, TensorBox): + changed_data = changed.data + else: + changed_data = changed + if isinstance(val, TensorBox): + val = val.data + + if not isinstance(val, ir.StorageBox): + # introduce a copy to handle views + val = Pointwise.create( + device=changed.get_device(), + dtype=changed.get_dtype(), + inner_fn=val.make_loader(), + ranges=changed.get_size(), + ).data + assert isinstance(val, ir.StorageBox) + + if isinstance(changed_data, ir.StorageBox) and not ( + changed_data.is_input_buffer() or isinstance(changed_data.data, ir.NopKernel) + ): + # Fast path, just swing the data pointer + val.realize() + changed_data.data = val.data + return changed + + ir.MutationLayout.realize_into(val, changed_data, unsafe_alias=unsafe_alias) + return changed + + +@register_lowering(aten.fill_) +def fill_(x, fill_value): + return mutate_to(x, full_like(x, fill_value)) + + +@register_lowering(aten.copy_, type_promotion_kind=None) +def copy_(dst, src, non_blocking=False): + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@make_pointwise +def floordiv(a, b): + return ops.floordiv(a, b) + + +@make_pointwise +def truncdiv(a, b): + return ops.truncdiv(a, b) + + +@register_lowering(aten.div, broadcast=True) +def div_mode(a, b, rounding_mode=None): + both_integer = is_integer_type(a) and is_integer_type(b) + both_boolean = is_boolean_type(a) and is_boolean_type(b) + + # floordiv and truncdiv need special handling for integer tensors on Triton, + # see the discussion at https://github.com/openai/triton/issues/605 + if rounding_mode == "floor": + assert not both_boolean, "floordiv operands can not be boolean at the same time" + return floordiv(a, b) if both_integer else floor(div(a, b)) + if rounding_mode == "trunc": + assert not both_boolean, "truncdiv operands can not be boolean at the same time" + return truncdiv(a, b) if both_integer else trunc(div(a, b)) + return div(a, b) + + +@register_lowering([aten.mul], broadcast=True) +def mul(a, b): + both_bool = is_boolean_type(a) and is_boolean_type(b) + if both_bool: + return logical_and(a, b) + else: + fn = ops_wrapper(aten.mul.__name__) + return make_pointwise(fn)(a, b) + + +# NOTE: prims.div maps to a / b in C, so performs truncation division on +# integer inputs and true division for floating and complex inputs. 
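+# A concrete illustration of that distinction (sketch only): -7 / 2 is -3 under C-style truncation but -4 under Python-style floor division.
+# Accordingly, div_prim below routes integer/boolean operands through truncdiv, while aten.div with rounding_mode="floor" (div_mode above) uses floordiv.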
+@register_lowering([prims.div], broadcast=True) +def div_prim(a, b): + is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b]) + + if is_integral: + return truncdiv(a, b) + + def fn(*args): + return ops.truediv(*args) + + return make_pointwise(fn)(a, b) + + +@register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def div(a, b): + a, b = promote_constants( + (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return div_prim(a, b) + + +@register_lowering([aten.fmod, prims.fmod], broadcast=True) +def fmod(a, b): + is_integral = is_boolean_type(a) or is_integer_type(a) + + if is_integral: + + def fn(a, b): + return ops.mod(a, b) + + else: + + def fn(a, b): + return ops.fmod(a, b) + + return make_pointwise(fn)(a, b) + + +@register_lowering(aten.rsqrt) +def rsqrt(x): + dtype = x.get_dtype() + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + x = to_dtype(x, torch.get_default_dtype()) + + def _rsqrt(x): + return ops.rsqrt(x) + + return make_pointwise(_rsqrt)(x) + + +@register_lowering([aten.sum, prims.sum]) +def sum_(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("sum", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +fallback_cumsum = fallback_handler(aten.cumsum.default) +fallback_cumprod = fallback_handler(aten.cumprod.default) +fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default) + + +@register_lowering(aten.cumsum) +def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + result = ir.Scan.create(**kwargs, combine_fn=ops.add, init=0) + if result is None: + return fallback_cumsum(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.cumprod) +def cumprod(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + result = ir.Scan.create(**kwargs, combine_fn=ops.mul, init=1) + if result is None: + return fallback_cumprod(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.logcumsumexp) +def logcumsumexp(x, dim): + def log_add_exp_helper(a, b): + min_v = ops.minimum(a, b) + max_v = ops.maximum(a, b) + mask = (min_v != max_v) | (~ops.isinf(min_v)) + return ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a) + + dtype = x.get_dtype() + if len(x.get_size()) == 0: + assert dim in [0, -1] + return clone(x) + + kwargs = _make_scan_inner(x, axis=dim, dtype=dtype) + result = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper, init=float("-inf")) + if result is None: + return fallback_logcumsumexp(x, dim=dim) + return result + + +@register_lowering(aten.prod) +def prod(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("prod", 
override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +@register_lowering(aten.any) +def reduce_any(x, dim=None, keepdim=False): + x = to_dtype(x, torch.bool) + return make_reduction("any")(x, axis=dim, keepdims=keepdim) + + +@register_lowering(aten.max, type_promotion_kind=None) +def reduce_max(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + +@register_lowering(aten.min, type_promotion_kind=None) +def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amin(x, axis=None, keepdims=keepdim) + + +register_lowering(prims.xor_sum)(make_reduction("xor_sum")) +reduce_amax = register_lowering(aten.amax)(make_reduction("max")) +reduce_amin = register_lowering(aten.amin)(make_reduction("min")) +reduce_argmax = register_lowering(aten.argmax)( + make_reduction("argmax", override_return_dtype=torch.int64) +) +reduce_argmin = register_lowering(aten.argmin)( + make_reduction("argmin", override_return_dtype=torch.int64) +) + +add = register_pointwise( + aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or" +) + + +def register_pointwise_numeric(op, name=None, triton_fallback=None): + return register_pointwise( + op, + name=name, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + triton_fallback=triton_fallback, + ) + + +def register_pointwise_numeric_ldf64(op): + return register_pointwise( + op, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + use_libdevice_for_f64=True, + ) + + +exp = register_pointwise_numeric_ldf64(aten.exp) +exp2 = register_pointwise_numeric(aten.exp2) +expm1 = register_pointwise_numeric(aten.expm1) +relu = register_pointwise(aten.relu) +sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid) +sqrt = register_pointwise_numeric_ldf64(aten.sqrt) +square = register_pointwise(aten.square) +sub = register_pointwise(aten.sub, allow_alpha=True) +register_pointwise_numeric_ldf64(aten.cos) +register_pointwise_numeric_ldf64(aten.sin) +abs = register_pointwise(aten.abs) +bitwise_and = register_pointwise(aten.bitwise_and) +bitwise_left_shift = register_pointwise(aten.bitwise_left_shift) +bitwise_not = register_pointwise( + aten.bitwise_not, override_fn_when_input_bool="logical_not" +) +bitwise_or = register_pointwise(aten.bitwise_or) +bitwise_right_shift = register_pointwise(aten.bitwise_right_shift) +bitwise_xor = register_pointwise(aten.bitwise_xor) +register_pointwise_numeric(aten.lgamma) +erf = register_pointwise_numeric(aten.erf) +register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +)(erf) + +register_pointwise_numeric(aten.log1p) +register_pointwise_numeric(aten.tan) +register_pointwise_numeric(aten.tanh) +register_pointwise_numeric_ldf64(aten.log) +logical_and = register_pointwise( + aten.logical_and, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_not = register_pointwise( + aten.logical_not, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_or = register_pointwise( + aten.logical_or, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_xor = register_pointwise( + aten.logical_xor, + 
type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +maximum = register_pointwise(aten.maximum) +minimum = register_pointwise(aten.minimum) +register_lowering(aten.clamp_min)(maximum) +register_lowering(aten.clamp_max)(minimum) +neg = register_pointwise(aten.neg) +abs = register_pointwise(aten.abs) +reciprocal = register_pointwise_numeric(aten.reciprocal) +register_pointwise(aten.remainder) +sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity") +register_pointwise(aten.ceil) +register_pointwise(aten.signbit, override_return_dtype=torch.bool) + +register_lowering(aten._neg_view)(neg) + +register_pointwise(aten.le, override_return_dtype=torch.bool) +register_pointwise(aten.lt, override_return_dtype=torch.bool) +register_pointwise(aten.ge, override_return_dtype=torch.bool) +gt = register_pointwise(aten.gt, override_return_dtype=torch.bool) +register_pointwise(aten.eq, override_return_dtype=torch.bool) +register_pointwise(aten.ne, override_return_dtype=torch.bool) + +register_pointwise_numeric(aten.cosh) +register_pointwise_numeric(aten.sinh) +register_pointwise_numeric(aten.acos) +register_pointwise_numeric(aten.acosh) +register_pointwise_numeric(aten.asin) +register_pointwise_numeric(aten.asinh) +register_pointwise_numeric(aten.atan2) +register_pointwise_numeric(aten.atan) +register_pointwise_numeric(aten.atanh) +register_pointwise_numeric(aten.copysign) +register_pointwise_numeric(aten.erfc) +register_pointwise_numeric(aten.erfinv) +register_pointwise_numeric(aten.hypot) +register_pointwise_numeric(aten.log10) +register_pointwise_numeric(aten.nextafter) + +from .codegen.common import pointwise_overrides_data + + +def _get_pointwise_overrides(ns, name): + data = pointwise_overrides_data[name] + op = getattr(ns, data.name, None) + if op is None: + return + + def make_triton_fallback(op): + if data.triton is None: + return fallback_handler(op) + + if isinstance(op, torch._ops.OpOverloadPacket): + for olname in op.overloads(): + ol = getattr(op, olname) + yield ol, data.type_promotion_kind, make_triton_fallback(ol) + else: + yield op, data.type_promotion_kind, make_triton_fallback(op) + + +for name in pointwise_overrides_data: + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + aten, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + prims, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + +foreach_add_list = register_foreach_pointwise( + aten._foreach_add.List, add, allow_alpha=True +) +foreach_add_scalar = register_foreach_pointwise( + aten._foreach_add.Scalar, add, allow_alpha=True +) +register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True) +foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul) +foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul) +register_foreach_pointwise(aten._foreach_sub.List, sub) +register_foreach_pointwise(aten._foreach_sub.Scalar, sub) +register_foreach_pointwise(aten._foreach_neg.default, neg) +register_foreach_pointwise(aten._foreach_abs.default, abs) +register_foreach_pointwise(aten._foreach_pow.Scalar, pow) +register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow) +foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div) 
+foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div) +register_foreach_pointwise(aten._foreach_sqrt, sqrt) +register_foreach_pointwise(aten._foreach_maximum.List, maximum) +register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum) +register_foreach_pointwise(aten._foreach_minimum.List, minimum) +register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum) +register_foreach_pointwise(aten._foreach_clamp_min.List, maximum) +register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum) +register_foreach_pointwise(aten._foreach_clamp_max.List, minimum) +register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum) +register_foreach_pointwise(aten._foreach_reciprocal, reciprocal) +register_foreach_pointwise(aten._foreach_sign, sign) +register_foreach_pointwise(aten._foreach_copy, copy) + + +# these are only encountered as outputs of the graph +# reinplacing epilogue copies improves compile time +# by removing extra buffers sent to the scheduler. +def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op): + inplaceable_foreach_ops[outplace_aten_op] = aten_op + inplace_foreach_ops.add(aten_op) + + def fn(*args, **kwargs): + results = outplace_op(*args, **kwargs) + mut_results = [] + for arg, result in zip(args[0], results): + mut_results.append(mutate_to(arg, result, unsafe_alias=True)) + + return mut_results + + _register_foreach_lowering(aten_op, fn) + + +register_foreach_inplace( + aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list +) +register_foreach_inplace( + aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar +) +register_foreach_inplace( + aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list +) +register_foreach_inplace( + aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar +) +register_foreach_inplace( + aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list +) +register_foreach_inplace( + aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar +) + + +def register_inplace(aten_op, outplace_op): + @register_lowering(aten_op, type_promotion_kind=None) + def fn(*args, **kwargs): + result = outplace_op(*args, **kwargs) + result = to_dtype(result, args[0].get_dtype()) + return mutate_to(args[0], result) + + return fn + + +register_inplace(aten.add_, add) +register_inplace(aten.bitwise_and_, bitwise_and) +register_inplace(aten.bitwise_left_shift_, bitwise_left_shift) +register_inplace(aten.bitwise_not_, bitwise_not) +register_inplace(aten.bitwise_or_, bitwise_or) +register_inplace(aten.bitwise_right_shift_, bitwise_right_shift) +register_inplace(aten.bitwise_xor_, bitwise_xor) +register_inplace(aten.mul_, mul) +register_inplace(aten.div_.Tensor, div) +register_inplace(aten.div_.Tensor_mode, div_mode) +register_inplace(aten.logical_and_, logical_and) +register_inplace(aten.logical_not_, logical_not) +register_inplace(aten.logical_or_, logical_or) +register_inplace(aten.logical_xor_, logical_xor) +register_inplace(aten.sub_, sub) +register_inplace(aten.relu_, relu) +register_inplace(aten.sigmoid_, sigmoid) + + +register_lowering(aten.__and__)(bitwise_and) +register_lowering(aten.__lshift__)(bitwise_left_shift) +register_lowering(aten.__or__)(bitwise_or) +register_lowering(aten.__rshift__)(bitwise_right_shift) +register_lowering(aten.__xor__)(bitwise_xor) + +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, 
aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) + + +@register_lowering(aten.sym_constrain_range) +def sym_constrain_range(a, min=None, max=None): + tracing_context = torch._guards.TracingContext.try_get() + assert ( + tracing_context is None or a in tracing_context.fake_mode.shape_env.var_to_range + ) + return a + + +@register_lowering(aten.sym_size.int) +def sym_size(a, dim): + val = V.graph.current_node.meta["val"] + # Note [Can val be an int?] + # ~~~~~~~~~~~~~~~~~~~~~~~~~ + # In principle, someone could construct an FX graph where + # a call to size/stride has a val that is a plain int (not + # SymInt). However, we will maintain the invariant that + # this is not possible: if you are constructing an FX graph + # where there is a call to size/stride that returns an + # int, but you KNOW that int must always be a constant, + # then you do not need trace that call at all (and just + # constant propagate the integer as is.) + assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_stride.int) +def sym_stride(a, dim): + val = V.graph.current_node.meta["val"] + # See Note [Can val be an int?] + assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_numel) +def sym_numel(a): + return a.get_numel() + + +for method, func in magic_methods.items(): + register_lowering(method_to_operator(method))(func) + + +@register_lowering(aten._foobar) +def foobar(self, *args, **kwargs): + raise NotImplementedError("Helpful for debugging") + + +@register_lowering(torch.ops._inductor_test.realize) +def _realize(x): + x.realize() + return clone(x) + + +@register_lowering(torch.ops.inductor.resize_storage_bytes_) +def resize_storage_bytes_(variable, new_size): + variable.realize() + ir.ResizeStorageBytes(variable, new_size) + return variable + + +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +make_fallback(auto_functionalized) + + +@register_lowering(triton_kernel_wrapper_mutation) +def triton_kernel_wrap_(*, kernel_idx, grid, kwargs): + ir.UserDefinedTritonKernel(kernel_idx=kernel_idx, grid=grid, kernel_args=kwargs) + return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} + + +@register_lowering(triton_kernel_wrapper_functional) +def triton_kernel_wrap(*, kernel_idx, grid, kwargs, tensors_to_clone): + new_kwargs = {} + for name, value in kwargs.items(): + if isinstance(value, ir.TensorBox): + x = value.data + has_non_rv_views = False + while isinstance(x, ir.BaseView): + if not isinstance(x, ir.ReinterpretView): + has_non_rv_views = True + break + x = x.data + if has_non_rv_views: + # we realize the inputs wrapped into any view which is not + # ReinterpretView to convert them into ReinterpretView during + # realization; all views being ReinterpretView is assumed by + # the downstream code (e.g., preserving ReinterpretView in + # cloning; layout should be available in mutation marking) + value = ir.TensorBox(ir.ExternKernel.realize_input(value)) + if name in tensors_to_clone: + value = clone_preserve_reinterpret_view(value) + new_kwargs[name] = value + + return triton_kernel_wrap_(kernel_idx=kernel_idx, grid=grid, kwargs=new_kwargs) + + +@register_lowering(torch.ops.higher_order.cond) +def cond(pred, true_fn, false_fn, operands): + if is_triton(pred) or any(map(is_triton, operands)): + msg = "control flow operator: torch.cond." 
+ if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.Conditional.create(pred, true_fn, false_fn, operands) + return list(map(TensorBox.create, result)) + + +try: + import torch.distributed._functional_collectives + + c10d_functional = torch.ops.c10d_functional + + @register_lowering(c10d_functional.wait_tensor) + def wait(input): + return TensorBox.create(ir.Wait.create(input)) + + @register_lowering(c10d_functional.broadcast) + def broadcast(input, src, tag, ranks, group_size): + return ir.Broadcast.create(input, src, tag, ranks, group_size) + + @register_lowering(c10d_functional.all_reduce) + def allreduce(input, reduce_op, tag, ranks, group_size): + return ir.AllReduce.create(input, reduce_op, tag, ranks, group_size) + + @register_lowering(c10d_functional.all_gather_into_tensor) + def all_gather_into_tensor(shard, tag, ranks, group_size): + return TensorBox.create( + ir.AllGatherIntoTensor.create( + ir.ExternKernel.require_contiguous(shard), tag, ranks, group_size + ) + ) + + @register_lowering(c10d_functional.reduce_scatter_tensor) + def reduce_scatter_tensor(input, reduce_op, tag, ranks, group_size): + return TensorBox.create( + ir.ReduceScatterTensor.create(input, reduce_op, tag, ranks, group_size) + ) + + @register_lowering(c10d_functional.all_reduce_coalesced) + def all_reduce_coalesced(input, reduce_op, tag, ranks, group_size): + return ir.AllReduceCoalesced.create(input, reduce_op, tag, ranks, group_size) + + @register_lowering(c10d_functional.all_gather_into_tensor_coalesced) + def all_gather_into_tensor_coalesced(self, tag, ranks, group_size): + result = ir.AllGatherIntoTensorCoalesced.create(self, tag, ranks, group_size) + return list(map(TensorBox.create, result)) + + @register_lowering(c10d_functional.reduce_scatter_tensor_coalesced) + def reduce_scatter_tensor_coalesced(self, reduceOp, tag, ranks, group_size): + result = ir.ReduceScatterTensorCoalesced.create( + self, reduceOp, tag, ranks, group_size + ) + return list(map(TensorBox.create, result)) + + @register_lowering(c10d_functional.all_to_all_single) + def all_to_all_single( + self, output_split_sizes, input_split_sizes, tag, ranks, group_size + ): + return TensorBox.create( + ir.AllToAllSingle.create( + self, output_split_sizes, input_split_sizes, tag, ranks, group_size + ) + ) + + _c10d_functional = torch.ops._c10d_functional + + @register_lowering(_c10d_functional.all_reduce) + def _all_reduce(inp, reduce_op, group_name): + inp = clone(inp) + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_lowering(_c10d_functional.all_reduce_) + def _all_reduce_(inp, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_lowering(_c10d_functional.all_reduce_coalesced) + def _all_reduce_coalesced(inputs, reduce_op, group_name): + inputs = [clone(inp) for inp in inputs] + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + @register_lowering(_c10d_functional.all_reduce_coalesced_) + def _all_reduce_coalesced_(inputs, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + 
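+ # Summary of the pattern above (a sketch, not an additional lowering): the out-of-place _c10d_functional lowerings clone their input(s) and then emit the matching in-place collective on the copy,
+ # e.g. _all_reduce is roughly `out = clone(inp); create_inplace(all_reduce_.default, out, ...); return out`, while the trailing-underscore variants mutate and return the input buffer directly.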
@register_lowering(_c10d_functional.all_gather_into_tensor) + def _all_gather_into_tensor(inp, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_gather_into_tensor.default, + inp, + group_size, + group_name, + ) + ) + + @register_lowering(_c10d_functional.all_gather_into_tensor_coalesced) + def _all_gather_into_tensor_coalesced(inputs, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_gather_into_tensor_coalesced.default, + inputs, + group_size, + group_name, + ), + ) + + @register_lowering(_c10d_functional.reduce_scatter_tensor) + def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.reduce_scatter_tensor.default, + inp, + reduce_op, + group_size, + group_name, + ) + ) + + @register_lowering(_c10d_functional.reduce_scatter_tensor_coalesced) + def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.reduce_scatter_tensor_coalesced.default, + inputs, + reduce_op, + group_size, + group_name, + ), + ) + + @register_lowering(_c10d_functional.all_to_all_single) + def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_to_all_single.default, + inp, + output_split_sizes, + input_split_sizes, + group_name, + ) + ) + + @register_lowering(_c10d_functional.broadcast) + def _broadcast(inp, src, group_name): + inp = clone(inp) + ir._CollectiveKernel.create_inplace( + _c10d_functional.broadcast_.default, inp, src, group_name + ) + return inp + + @register_lowering(_c10d_functional.broadcast_) + def _broadcast_(inp, src, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.broadcast_.default, inp, src, group_name + ) + return inp + + @register_lowering(_c10d_functional.wait_tensor) + def _wait_tensor(inp): + ir._WaitKernel.create_wait(_c10d_functional.wait_tensor.default, inp) + return inp + +except ImportError: + log.info( + "Inductor support for distributed collectives depends on building torch.distributed" + ) + +# populate lowerings defined in kernel/* +from . import kernel + +import_submodule(kernel) + +from . import quantized_lowerings + +quantized_lowerings.register_quantized_ops() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..546524d900e81aa930c8e847c36c0013e0d39f26 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py @@ -0,0 +1,53 @@ +import contextlib +import tempfile +import unittest + +from torch._dynamo.test_case import ( + run_tests as dynamo_run_tests, + TestCase as DynamoTestCase, +) + +from torch._inductor import config + + +def run_tests(needs=()): + dynamo_run_tests(needs) + + +class TestCase(DynamoTestCase): + """ + A base TestCase for inductor tests. Enables FX graph caching and isolates + the cache directory for each test. 
+ """ + + _stack: contextlib.ExitStack + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._stack = contextlib.ExitStack() + cls._stack.enter_context(config.patch({"fx_graph_cache": True})) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._stack.close() + + def setUp(self): + super().setUp() + + # For all tests, mock the tmp directory populated by the inductor + # FxGraphCache, both for test isolation and to avoid filling disk. + self._inductor_cache_tmp_dir = tempfile.TemporaryDirectory() + self._inductor_cache_get_tmp_dir_patch = unittest.mock.patch( + "torch._inductor.codecache.FxGraphCache._get_tmp_dir" + ) + mock_get_dir = self._inductor_cache_get_tmp_dir_patch.start() + mock_get_dir.return_value = self._inductor_cache_tmp_dir.name + + def tearDown(self): + super().tearDown() + + # Clean up the FxGraphCache tmp dir. + self._inductor_cache_get_tmp_dir_patch.stop() + self._inductor_cache_tmp_dir.cleanup() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40581c892fc0c7197e4409916d741d7be636d7e6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/jiterator.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/jiterator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16f8b47527e3da2ee84a140d9afd739e5717e051 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/jiterator.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nccl.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nccl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1228e7424df5bff7cf3baee9b320abca772743c5 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nccl.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/random.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/random.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0053fa92b771417172753718308454f06b1095e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/random.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/streams.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/streams.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d94d4d52966f9a02bd43a5bf032a540e9df8e98 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/streams.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..a862acd73184733dca0c811204456adc21394200 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py @@ -0,0 +1,626 @@ +import pickle +import sys +import os +import io +import subprocess +import json +from functools import lru_cache +from typing import Any +from itertools import groupby +import base64 +import warnings + +cache = lru_cache(None) + +__all__ = ["format_flamegraph", "segments", "memory", "compare"] + +def _frame_fmt(f, full_filename=False): + i = f['line'] + fname = f['filename'] + if not full_filename: + fname = fname.split('/')[-1] + func = f['name'] + return f'{fname}:{i}:{func}' + +@cache +def _frame_filter(name, filename): + omit_functions = [ + "unwind::unwind", + "CapturedTraceback::gather", + "gather_with_cpp", + "_start", + "__libc_start_main", + "PyEval_", + "PyObject_", + "PyFunction_", + ] + omit_filenames = [ + "core/boxing", + "/Register", + "/Redispatch", + "pythonrun.c", + "Modules/main.c", + "Objects/call.c", + "Objects/methodobject.c", + "pycore_ceval.h", + "ceval.c", + "cpython/abstract.h", + ] + for of in omit_functions: + if of in name: + return False + for of in omit_filenames: + if of in filename: + return False + return True + +def _frames_fmt(frames, full_filename=False, reverse=False): + if reverse: + frames = reversed(frames) + return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f['name'], f['filename'])] + +def _block_extra_legacy(b): + if 'history' in b: + frames = b['history'][0].get('frames', []) + real_size = b['history'][0]['real_size'] + else: + real_size = b.get('requested_size', b['size']) + frames = [] + return frames, real_size + +def _block_extra(b): + if 'frames' not in b: + # old snapshot format made it more complicated to get frames/allocated size + return _block_extra_legacy(b) + return b['frames'], b['requested_size'] + +def format_flamegraph(flamegraph_lines, flamegraph_script=None): + if flamegraph_script is None: + flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl' + if not os.path.exists(flamegraph_script): + import urllib.request + print(f"Downloading flamegraph.pl to: {flamegraph_script}") + urllib.request.urlretrieve( + 'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script) + subprocess.check_call(['chmod', '+x', flamegraph_script]) + args = [flamegraph_script, '--countname', 'bytes'] + p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8') + assert p.stdin is not None + assert p.stdout is not None + p.stdin.write(flamegraph_lines) + p.stdin.close() + result = p.stdout.read() + p.stdout.close() + p.wait() + assert p.wait() == 0 + return result + +def _write_blocks(f, prefix, blocks): + def frames_fragment(frames): + if not frames: + return "" + return ';'.join(_frames_fmt(frames, reverse=True)) + for b in blocks: + if 'history' not in b: + frames, accounted_for_size = _block_extra(b) + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n') + else: + accounted_for_size = 0 + for h in b['history']: + sz = h['real_size'] + accounted_for_size += sz + if 'frames' in h: + frames = h['frames'] + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n') + else: + f.write(f'{prefix};{b["state"]}; {sz}\n') + gaps = b['size'] - accounted_for_size + if gaps: + f.write(f'{prefix};{b["state"]}; 
{gaps}\n') + +def segments(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot['segments']: + prefix = f'stream_{seg["stream"]};seg_{seg["address"]}' + _write_blocks(f, prefix, seg['blocks']) + return format_flamegraph(f.getvalue()) + +def memory(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot['segments']: + prefix = f'stream_{seg["stream"]}' + _write_blocks(f, prefix, seg['blocks']) + return format_flamegraph(f.getvalue()) + +def compare(before, after, format_flamegraph=format_flamegraph): + def _seg_key(seg): + return (seg['address'], seg['total_size']) + + def _seg_info(seg): + return f'stream_{seg["stream"]};seg_{seg["address"]}' + + f = io.StringIO() + + before_segs = {_seg_key(seg) for seg in before} + after_segs = {_seg_key(seg) for seg in after} + + print(f'only_before = {[a for a,_ in (before_segs - after_segs)]}') + print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}') + + for seg in before: + if _seg_key(seg) not in after_segs: + _write_blocks(f, f'only_before;{_seg_info(seg)}', seg['blocks']) + + for seg in after: + if _seg_key(seg) not in before_segs: + _write_blocks(f, f'only_after;{_seg_info(seg)}', seg['blocks']) + + return format_flamegraph(f.getvalue()) + +def _format_size(num): + # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: + if abs(num) < 1024.0: + return f"{num:3.1f}{unit}B" + num /= 1024.0 + return f"{num:.1f}YiB" + +class Bytes: + def __init__(self, value): + self.value = value + + def __add__(self, rhs): + return Bytes(self.value + rhs) + + def __repr__(self): + return _format_size(self.value) + +def calc_active(seg): + return sum(b['size'] for b in seg['blocks'] if b['state'] == 'active_allocated') + +def _report_free(free_external, free_internal): + total = free_external + free_internal + suffix = '' + if total != 0: + pct = (free_internal / total) * 100 + suffix = f' ({pct:.1f}% internal)' + return f'{Bytes(total)}{suffix}' + +PAGE_SIZE = 1024 * 1024 * 20 +legend = f"""\ + +Legend: + [a ] - a segment in the allocator + ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment + a-z: pages filled with a single block's content + ' ': page is completely free + *: page if completely full with multiple blocks + 0-9: page is partially full with tensors of multiple blocks (9 == 90% full) + (X% internal) - of the free memory, X% is free because we rounded the size of the allocation. +""" + +def segsum(data): + r"""Visually reports how the allocator has filled its segments. + + This printout can help debug fragmentation issues since free fragments + will appear as gaps in this printout. The amount of free space is reported + for each segment. + We distinguish between internal free memory which occurs because the + allocator rounds the allocation size, and external free memory, which are + the gaps between allocations in a segment. 
+ Args: + data: snapshot dictionary created from _snapshot() + """ + segments = [] + out = io.StringIO() + out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n") + total_reserved = 0 + total_allocated = 0 + free_external = 0 + free_internal = 0 + for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))): + total_reserved += seg['total_size'] + + seg_free_external = 0 + seg_free_internal = 0 + seg_allocated = 0 + all_ranges = [] + boffset = 0 + for b in seg['blocks']: + active = b['state'] == 'active_allocated' + if active: + _, allocated_size = _block_extra(b) + all_ranges.append((boffset, allocated_size, True)) + seg_allocated += allocated_size + seg_free_internal += b['size'] - allocated_size + else: + seg_free_external += b['size'] + + boffset += b['size'] + + total_allocated += seg_allocated + free_external += seg_free_external + free_internal += seg_free_internal + + nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1 + occupied = [' ' for _ in range(nseg)] + frac = [0.0 for _ in range(nseg)] + active_size = 0 + for i, (start_, size, active) in enumerate(all_ranges): + active_size += size + finish_ = (start_ + size) + start = start_ // PAGE_SIZE + finish = (finish_ - 1) // PAGE_SIZE + 1 + m = chr(ord('a' if active else 'A') + (i % 26)) + for j in range(start, finish): + s = max(start_, j * PAGE_SIZE) + e = min(finish_, (j + 1) * PAGE_SIZE) + frac[j] += (e - s) / PAGE_SIZE + if occupied[j] != ' ': + occupied[j] = '0123456789*'[int(frac[j] * 10)] + else: + occupied[j] = m + stream = '' if seg['stream'] == 0 else f', stream_{seg["stream"]}' + body = ''.join(occupied) + assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size'] + stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else '' + if seg['total_size'] >= PAGE_SIZE: + out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, ' + f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n') + out.write(f'segments: {len(data["segments"])}\n') + out.write(f'total_reserved: {Bytes(total_reserved)}\n') + out.write(f'total_allocated: {Bytes(total_allocated)}\n') + internal_external = f' ({Bytes(free_internal)} internal + {Bytes(free_external)} external)' if free_internal else '' + out.write(f'total_free: {_report_free(free_external, free_internal)}\n') + out.write(legend) + assert free_internal + free_external + total_allocated == total_reserved + return out.getvalue() + +def trace(data): + out = io.StringIO() + + def format(entries): + segment_intervals : list = [] + segment_addr_to_name = {} + allocation_addr_to_name = {} + + free_names : list = [] + next_name = 0 + + def _name(): + nonlocal next_name + if free_names: + return free_names.pop() + r, m = next_name // 26, next_name % 26 + next_name += 1 + return f'{chr(ord("a") + m)}{"" if r == 0 else r}' + + def find_segment(addr): + for name, saddr, size in segment_intervals: + if addr >= saddr and addr < saddr + size: + return name, saddr + for i, seg in enumerate(data['segments']): + saddr = seg['address'] + size = seg['allocated_size'] + if addr >= saddr and addr < saddr + size: + return f'seg_{i}', saddr + return None, None + count = 0 + out.write(f'{len(entries)} entries\n') + + + total_reserved = 0 + for seg in data['segments']: + total_reserved += seg['total_size'] + + for count, e in enumerate(entries): + if e['action'] == 'alloc': + addr, size = e['addr'], e['size'] + n = _name() + seg_name, seg_addr = find_segment(addr) + if seg_name is None: + seg_name = "MEM" + offset = addr + else: + offset = addr - 
seg_addr + out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n') + allocation_addr_to_name[addr] = (n, size, count) + count += size + elif e['action'] == 'free_requested': + addr, size = e['addr'], e['size'] + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f'del {name} # {Bytes(size)}\n') + elif e['action'] == 'free_completed': + addr, size = e['addr'], e['size'] + count -= size + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f'# free completed for {name} {Bytes(size)}\n') + if name in allocation_addr_to_name: + free_names.append(name) + del allocation_addr_to_name[name] + elif e['action'] == 'segment_alloc': + addr, size = e['addr'], e['size'] + name = _name() + out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n') + segment_intervals.append((name, addr, size)) + segment_addr_to_name[addr] = name + elif e['action'] == 'segment_free': + addr, size = e['addr'], e['size'] + name = segment_addr_to_name.get(addr, addr) + out.write(f'cudaFree({name}) # {Bytes(size)}\n') + if name in segment_addr_to_name: + free_names.append(name) + del segment_addr_to_name[name] + elif e['action'] == 'oom': + size = e['size'] + free = e['device_free'] + out.write(f'raise OutOfMemoryError() # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n') + else: + out.write(f'{e}\n') + out.write(f"TOTAL MEM: {Bytes(count)}") + for i, d in enumerate(data['device_traces']): + if d: + out.write(f'Device {i} ----------------\n') + format(d) + return out.getvalue() + + +_memory_viz_template = r""" + + + + + + + +""" + +def _format_viz(data, viz_kind, device): + if device is not None: + warnings.warn('device argument is deprecated, plots now contain all device') + buffer = pickle.dumps(data) + buffer += b'\x00' * (3 - len(buffer) % 3) + # Encode the buffer with base64 + encoded_buffer = base64.b64encode(buffer).decode('utf-8') + + json_format = json.dumps([{"name": 'snapshot.pickle', "base64": encoded_buffer}]) + return _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind)) \ + .replace('$SNAPSHOT', json_format) + +def trace_plot(data, device=None, plot_segments=False): + """Generate a visualization over time of the memory usage recorded by the trace as an html file. + + Args: + data: Memory snapshot as generated from torch.cuda.memory._snapshot() + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations. + Defaults to False. + + Returns: + str: HTML of visualization + """ + return _format_viz(data, 'Active Memory Timeline' if not plot_segments else 'Active Cached Memory Timeline', device) + + +def _profile_to_snapshot(profile): + import torch + from torch.profiler._memory_profiler import Action, TensorKey + from torch._C._profiler import _EventType + memory_profile = profile._memory_profile() + + allocation_stacks = {} + for event in memory_profile._op_tree.sorted_nodes: + if event.tag == _EventType.Allocation: + parent = event.parent + python_parents = [] + while parent: + if parent.tag in (_EventType.PyCall, _EventType.PyCCall): + python_parents.append(parent) + parent = parent.parent + key = TensorKey.from_allocation(event.extra_fields) + + # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor) + # key will be None. I should add some way to identify these, I just haven't yet. 
+ if key and event.extra_fields.alloc_size > 0: + allocation_stacks[key] = python_parents + + + device_count = torch.cuda.device_count() + snapshot = { + 'device_traces': [[] for _ in range(device_count + 1)], + 'segments': [{'device': device, + 'address': None, + 'total_size': 0, + 'stream': 0, + 'blocks': []} for device in range(device_count + 1)] + } + + def to_device(device): + if device.type == 'cuda': + return device.index + else: + return device_count + + def allocate(size, tensor_key, version, during_trace=True): + device = to_device(tensor_key.device) + addr = tensor_key.storage.ptr + + seg = snapshot['segments'][device] # type: ignore[index] + if seg['address'] is None or seg['address'] > addr: + seg['address'] = addr + seg['total_size'] = max(seg['total_size'], addr + size) # record max addr for now, we will make it the size later + category = memory_profile._categories.get(tensor_key, version) + category = category.name.lower() if category is not None else "unknown" + stack = allocation_stacks.get(tensor_key, ()) + stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack] + r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category} + if during_trace: + snapshot['device_traces'][device].append(r) # type: ignore[index] + return r + + def free(alloc, device): + for e in ('free_requested', 'free_completed'): + snapshot['device_traces'][device].append({'action': e, # type: ignore[index] + 'addr': alloc['addr'], + 'size': alloc['size'], + 'stream': 0, + 'frames': alloc['frames']}) + + kv_to_elem = {} + + + + # create the device trace + for time, action, (tensor_key, version), size in memory_profile.timeline: + if not isinstance(tensor_key, TensorKey): + continue + if action == Action.CREATE: + kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version) + elif action == Action.DESTROY: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + elif action == Action.INCREMENT_VERSION: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1) + elif action == Action.PREEXISTING: + kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False) + + + # create the final snapshot state + blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames']) + for (tensor_key, version), event in kv_to_elem.items()] + for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]): + seg = snapshot['segments'][device] # type: ignore[index] + last_addr = seg['address'] + for _, addr, size, frames in blocks: + if last_addr < addr: + seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'}) + seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames}) + last_addr = addr + size + if last_addr < seg['total_size']: + seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'}) + + snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']] # type: ignore[attr-defined] + for seg in snapshot['segments']: # type: ignore[attr-defined, name-defined, no-redef] + seg['total_size'] -= seg['address'] + if not seg['blocks']: + seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'}) + + return snapshot + +def profile_plot(profile, device=None): + """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an 
html file. + + Args: + profile: profile as generated by `torch.profiler.profile(profile_memory=True)` + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + + Returns: + str: HTML of visualization + """ + snapshot = _profile_to_snapshot(profile) + return _format_viz(snapshot, 'Active Memory Timeline', device) + + +def segment_plot(data: Any, device=None): + return _format_viz(data, 'Allocator State History', device) + +if __name__ == "__main__": + import os.path + thedir = os.path.realpath(os.path.dirname(__file__)) + if thedir in sys.path: + # otherwise we find cuda/random.py as random... + sys.path.remove(thedir) + import argparse + + fn_name = 'torch.cuda.memory._snapshot()' + pickled = f'pickled memory statistics from {fn_name}' + parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}') + + subparsers = parser.add_subparsers(dest='action') + + def _output(p): + p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)') + + description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.' + stats_a = subparsers.add_parser('stats', description=description) + stats_a.add_argument('input', help=pickled) + + description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.' + trace_a = subparsers.add_parser('trace', description=description) + trace_a.add_argument('input', help=pickled) + + description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)' + segments_a = subparsers.add_parser('segments', description=description) + segments_a.add_argument('input', help=pickled) + _output(segments_a) + + description = "Generate a flamegraph the program locations contributing to CUDA memory usage." + memory_a = subparsers.add_parser('memory', description=description) + memory_a.add_argument('input', help=pickled) + _output(memory_a) + + description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \ + 'or removed between two different memorys snapshots.' + compare_a = subparsers.add_parser('compare', description=description) + compare_a.add_argument('before', help=pickled) + compare_a.add_argument('after', help=pickled) + _output(compare_a) + + plots = ( + ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."), + ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.") + ) + for cmd, description in plots: + trace_plot_a = subparsers.add_parser(cmd, description=description) + trace_plot_a.add_argument('input', help=pickled) + help = 'visualize trace from this device (default: chooses the only device with trace info or errors)' + trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help) + help = 'path to save the visualization(default: output.html)' + trace_plot_a.add_argument('-o', '--output', default='output.html', help=help) + if cmd == "trace_plot": + help = 'visualize change to segments rather than individual allocations' + trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help) + + + args = parser.parse_args() + + def _read(name): + if name == '-': + f = sys.stdin.buffer + else: + f = open(name, 'rb') + data = pickle.load(f) + if isinstance(data, list): # segments only... 
+ data = {'segments': data, 'traces': []} + return data + + def _write(name, data): + with open(name, 'w') as f: + f.write(data) + + if args.action == 'segments': + data = _read(args.input) + _write(args.output, segments(data)) + elif args.action == 'memory': + data = _read(args.input) + _write(args.output, memory(data)) + elif args.action == 'stats': + data = _read(args.input) + print(segsum(data)) + elif args.action == 'trace': + data = _read(args.input) + print(trace(data)) + elif args.action == 'compare': + before = _read(args.before) + after = _read(args.after) + _write(args.output, compare(before, after)) + elif args.action == 'trace_plot': + data = _read(args.input) + _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments)) + elif args.action == 'segment_plot': + data = _read(args.input) + _write(args.output, segment_plot(data, device=args.device)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1d0ee8830bd68c23115e7788bb7e1a0c220b1882 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_utils.py @@ -0,0 +1,38 @@ +from typing import Any + +import torch + +# The _get_device_index has been moved to torch.utils._get_device_index +from torch._utils import _get_device_index as _torch_get_device_index + + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: + r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + is a CUDA device. Note that for a CUDA device without a specified index, + i.e., ``torch.device('cuda')``, this will return the current default CUDA + device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default CUDA + device if :attr:`optional` is ``True``. 
+ """ + if isinstance(device, int): + return device + if isinstance(device, str): + device = torch.device(device) + if isinstance(device, torch.device): + if allow_cpu: + if device.type not in ["cuda", "cpu"]: + raise ValueError(f"Expected a cuda or cpu device, but got: {device}") + elif device.type != "cuda": + raise ValueError(f"Expected a cuda device, but got: {device}") + if not torch.jit.is_scripting(): + if isinstance(device, torch.cuda.device): + return device.idx + return _torch_get_device_index(device, optional, allow_cpu) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19f0e09d4eddf7b1078bf07a05db7694cb1a59d6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/common.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dd399f8096609f37ee7545e7046500ab0bc99f6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/common.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/common.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e8c1cc99b00d63672e12f2908a82c899076306 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/common.py @@ -0,0 +1,9 @@ +from importlib.util import find_spec + +import torch + +__all__ = ["amp_definitely_not_available"] + + +def amp_definitely_not_available(): + return not (torch.cuda.is_available() or find_spec("torch_xla")) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/grad_scaler.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebaa9bced2ca0a2758cc4211308b2ad51437833 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/grad_scaler.py @@ -0,0 +1,28 @@ +import torch +from torch.amp.grad_scaler import OptState + +__all__ = ["GradScaler", "OptState"] + + +class GradScaler(torch.amp.GradScaler): + r""" + See :class:`torch.amp.GradScaler`. 
+ ``torch.cuda.amp.GradScaler(args...)`` is equivalent to ``torch.amp.GradScaler("cuda", args...)`` + """ + + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + enabled: bool = True, + ) -> None: + super().__init__( + "cuda", + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + enabled=enabled, + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_compatibility.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..14588fad9a09e1c307c475bda7c551d801dbd731 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_compatibility.py @@ -0,0 +1,34 @@ +from typing import Any, Dict +import textwrap + +_BACK_COMPAT_OBJECTS : Dict[Any, None] = {} +_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {} + +def compatibility(is_backward_compatible : bool): + if is_backward_compatible: + + def mark_back_compat(fn): + docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring += """ +.. note:: + Backwards-compatibility for this API is guaranteed. +""" + fn.__doc__ = docstring + _BACK_COMPAT_OBJECTS.setdefault(fn) + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_back_compat + else: + + def mark_not_back_compat(fn): + docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring += """ +.. warning:: + This API is experimental and is *NOT* backward-compatible. +""" + fn.__doc__ = docstring + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_not_back_compat diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..a4b4bc0d69d7c7e94c3119ed05eab220cfc7aaca --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py @@ -0,0 +1,182 @@ +from contextlib import contextmanager + +from torch.fx import GraphModule +from torch.fx.graph_module import ( + _format_import_block, + reduce_graph_module, + reduce_package_graph_module, +) +from torch.package import PackageExporter, sys_importer +from ._compatibility import compatibility + +_use_lazy_graph_module_flag = False +_force_skip_lazy_graph_module_flag = False + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _force_skip_lazy_graph_module(): + """ + Skip using lazy graph module disregarding the setting of _use_lazy_graph_module. + Use to skip _LazyGraphModule when testing inductor torchscript related backend. 
+ + torch.jit.script a _LazyGraphModule results in following error: + https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69 + """ + try: + global _force_skip_lazy_graph_module_flag + prior = _force_skip_lazy_graph_module_flag + _force_skip_lazy_graph_module_flag = True + yield + finally: + _force_skip_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _use_lazy_graph_module(should_use: bool): + try: + global _use_lazy_graph_module_flag + prior = _use_lazy_graph_module_flag + _use_lazy_graph_module_flag = ( + should_use and not _force_skip_lazy_graph_module_flag + ) + yield + finally: + _use_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +def _get_graph_module_cls(): + return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule + + +def _make_graph_module(*args, graph_module_cls=None, **kwargs): + if graph_module_cls is None: + graph_module_cls = _get_graph_module_cls() + + return graph_module_cls(*args, **kwargs) + + +@compatibility(is_backward_compatible=False) +class _LazyGraphModule(GraphModule): + """ + The main difference between _LazyGraphModule and GraphModule is how recompile happens. + GraphModule will do a 'recompile' call to generate python code and the forward method when it's + constructed. Later on if the graph get updated, recompile method can be called again to refresh + the saved python code and forward method. + + However in some cases especially in inductor, the recompilation can be a waste since we never + check the python code for the graph module or call its forward method. A few more concreate + examples regarding pattern matching fx passes in inductor: + 1. some passes will update the graph to be compiled and then call recompile on the GraphModule. + 2. some passes will trace small pattern function to search it in the graph being compiled and + replace the match with the traced graph of a replacement function. The pattern graph and + replacement graph are quite small but there are large amount of them. Doing GraphModule.recompile + for them in GraphModule.__init__ is also a waste of time. + + However simply skip calling GraphModule.recompile in these scenarios is also dangeruous. + People may want to check the python code or call the GraphModule's forward method for debugging purposes. + + The way _LazyGraphModule solves it is, we override the recompile method to just mark the + need for recompilation but does not do the actual recompilation. Later on if people really + access the compiled python code or call the GraphModule's forward method, we do the real + recompilation. + """ + + @classmethod + def from_graphmodule(cls, gm: GraphModule): + if isinstance(gm, _LazyGraphModule): + return gm + else: + return _LazyGraphModule(gm, gm.graph) + + @staticmethod + def force_recompile(gm): + """ + Sometimes we need force a recompile as a workaround + - we want to do the real recompilation before symbolic_trace to avoid error: + https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + """ + if isinstance(gm, _LazyGraphModule): + gm.real_recompile() + + def real_recompile(self): + if self._needs_recompile(): + self._real_recompile() + + @classmethod + def _needs_recompile(cls): + return cls.forward is cls._lazy_forward + + def _lazy_forward(self, *args, **kwargs): + # Call self.real_recompile() rather than self._real_recompile() here. + # The _lazy_forward method may be saved and call repeatedly. 
+ # Calling self.real_recompile can make sure we skip recompilation if + # we have already done so. + self.real_recompile() + assert not self._needs_recompile() + + # call `__call__` rather than 'forward' since recompilation may + # install a wrapper for `__call__` to provide a customized error + # message. + return self(*args, **kwargs) + + forward = _lazy_forward + + # TODO: we shold handle __reduce_deploy__ the same way as __reduce_package__, + # or __reduce__ by calling _real_recompile. But I don't find a good way + # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule + # will be used in torch::deploy. So it's skipped for now. + + def __reduce_package__(self, exporter: PackageExporter): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. + """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. + """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _real_recompile(self): + return super().recompile() + + @classmethod + def recompile(cls): + cls.forward = cls._lazy_forward + + @property + def code(self) -> str: + self.real_recompile() + return super().code + + def __str__(self) -> str: + """ + str(GraphModule) will access the _code attribute. Make sure recompile + happens so _code attribute is available. 
+ """ + self.real_recompile() + return super().__str__() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_pytree.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..29ab0c8679113b803ed63f9d520a41e0f2fd3327 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/_pytree.py @@ -0,0 +1,102 @@ +from collections import namedtuple +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type + +import torch.return_types + +from torch.utils._pytree import PyTree, TreeSpec + +FlattenFuncSpec = Callable[[PyTree, TreeSpec], List] +FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool] + +SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {} +SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {} + + +def register_pytree_flatten_spec( + cls: Type[Any], + flatten_fn_spec: FlattenFuncSpec, + flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None, +) -> None: + SUPPORTED_NODES[cls] = flatten_fn_spec + SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec + + +def tree_flatten_spec( + pytree: PyTree, + spec: TreeSpec, + exact_structural_match=False, +) -> List[Any]: + if spec.is_leaf(): + return [pytree] + if spec.type not in SUPPORTED_NODES: + raise RuntimeError( + f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with " + "torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make " + "sure that any custom pytrees have been registered before loading it.", + ) + flatten_fn_spec = SUPPORTED_NODES[spec.type] + child_pytrees = flatten_fn_spec(pytree, spec) + if exact_structural_match: + flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type] + if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec( + pytree, + spec, + ): + raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}") + result = [] + for child, child_spec in zip(child_pytrees, spec.children_specs): + flat = tree_flatten_spec(child, child_spec, exact_structural_match) + result += flat + return result + + +def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]: + return [d[k] for k in spec.context] + + +def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match) +register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match) +register_pytree_flatten_spec( + tuple, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, +) +for return_type in torch.return_types.all_return_types: 
+ register_pytree_flatten_spec( + return_type, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, + ) +register_pytree_flatten_spec( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten_spec, + _namedtuple_flatten_spec_exact_match, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/annotate.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/annotate.py new file mode 100644 index 0000000000000000000000000000000000000000..032ce14b6ec701dabc2459c501dfb957be5a1487 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/annotate.py @@ -0,0 +1,21 @@ +from torch.fx.proxy import Proxy +from ._compatibility import compatibility + +@compatibility(is_backward_compatible=False) +def annotate(val, type): + # val could be either a regular value (not tracing) + # or fx.Proxy (tracing) + if isinstance(val, Proxy): + if val.node.type: + raise RuntimeError(f"Tried to annotate a value that already had a type on it!" + f" Existing type is {val.node.type} " + f"and new type is {type}. " + f"This could happen if you tried to annotate a function parameter " + f"value (in which case you should use the type slot " + f"on the function signature) or you called " + f"annotate on the same value twice") + else: + val.node.type = type + return val + else: + return val diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..627b497a9cbce95efce88e09597fed08dd92463a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c72da32e06c88ea389da11ab4b50062c6933cc2c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adeadf137a075b5e5ee04129e431fa1ee10c1277 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35cc0de90b9df3ee97299e7d511ba81b471472c2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ec3653d61ebd6db451b0c5a6425618fbdb39c4f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e62f253cd386ee0d4af162df5e1d239a95ecbde8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c6f1328425b157c7fc52828e7efdefb944db554 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/accelerator_partitioner.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/accelerator_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..c2caf933fd565c33decc55cef954f6b3f923dba6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/accelerator_partitioner.py @@ -0,0 +1,1078 @@ +import operator +from collections import deque +from typing import Dict, List, Set, NamedTuple, Tuple, Deque + +import torch +from torch.fx.passes.graph_manipulation import get_size_of_all_nodes +from torch.fx.experimental.partitioner_utils import ( + Partition, + Device, + PartitionerConfig, + get_partition_to_latency_mapping, + get_latency_of_partitioned_graph, + NodeLatency, + get_extra_size_of, + PartitionMode, +) +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node, map_arg +from torch.fx.passes.split_module import split_module + + +class DAGNode: + """DAGNode class maintains useful information for a partition (submodule), + and its input submodules and output submodules. 
+ """ + + def __init__( + self, + submodule_node: Node, + input_nodes: List[Node], + output_nodes: List[Node], + logical_device_ids: List[int], + size_bytes: int, + ) -> None: + self.submodule_node: Node = submodule_node + self.input_nodes: List[Node] = input_nodes + self.output_nodes: List[Node] = output_nodes + self.logical_device_ids: List[int] = logical_device_ids + self.size_bytes = size_bytes + + def __str__(self) -> str: + return str(self.submodule_node) + + +class DAG: + """DAG class contains all the DAG nodes""" + + def __init__(self) -> None: + self.nodes: List[DAGNode] = [] + + def create_node( + self, + submodule_node: Node, + input_nodes: List[Node], + output_nodes: List[Node], + logical_devices: List[int], + size_bytes: int, + ) -> None: + node = DAGNode( + submodule_node, input_nodes, output_nodes, logical_devices, size_bytes + ) + self.nodes.append(node) + + +class PartitionResult(NamedTuple): + """NameTuple used for returning DAG and a new fx module""" + + dag: DAG + module_with_submodules: GraphModule + + +"""Followings are some helper functions for partition manipulation""" + + +def reset_partition_device(partitions): + for partition in partitions: + partition.logical_device_ids = [] + + +def combine_two_partitions( + partition_0: Partition, partition_1: Partition, partitions: List[Partition] +) -> None: + """Given a list of partitions and its two partitions, + combine these two partitions into a new one appending to the partitions + and remove the previous two partitions from the list of partitions + """ + partition = Partition(len(partitions)) + partition.nodes = partition_0.nodes.union(partition_1.nodes) + partition.recalculate_mem_size() + partitions.append(partition) + partitions.remove(partition_0) + partitions.remove(partition_1) + reorganize_partitions(partitions) + return + + +def set_parents_and_children(partitions: List[Partition]) -> None: + """Given a list of partitions, mark parents and children for each partition""" + # Go through all nodes in a partition. + # If a node's user is in other partition, + # then the other partition is this partition's children. + # This partition is the other partition's parent + for partition in partitions: + partition.children = set() + partition.parents = set() + for partition in partitions: + for node in partition.nodes: + # For each node in the current partition, find its users + users = node.users + for n in users: + # Find which the partition the user node belongs to. 
+ # Note that if the node itself is also belongs to that partition, + # that partition is not the child of the current partition + for p in partitions: + if p != partition and n in p.nodes and node not in p.nodes: + partition.children.add(p) + p.parents.add(partition) + return + + +def reorganize_partitions(partitions: List[Partition]) -> None: + """Given a list of partitions, reorganize partition id, + its parents and its children for each partition + """ + # Rearrange partition ids + for i, partition in enumerate(partitions): + partition.partition_id = i + set_parents_and_children(partitions) + return + + +def get_bfs_level_partition(partitions: List[Partition]) -> None: + """Given a list of partitions, + mark the bfs level for each partition + """ + current_level: Set[Partition] = set() + visited: Set[Partition] = set() + for partition in partitions: + # If a partition has no parent, it should be in root level + if len(partition.parents) == 0: + current_level.add(partition) + next_level: Set[Partition] = set() + level = 0 + # bfs + while current_level: + partition = current_level.pop() + partition.bfs_level = level + visited.add(partition) + children = partition.children + for child in children: + if child not in next_level: + next_level.add(child) + if not current_level: + current_level = next_level.copy() + next_level = set() + level += 1 + return + + +def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]: + """Given a list of partitions,return node to partition mapping""" + node_to_partition: Dict[Node, int] = {} + for partition in partitions: + for node in partition.nodes: + node_to_partition[node] = partition.partition_id + return node_to_partition + + +def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]: + """Get a mapping from device logical ID to Device object.""" + logical_id_to_device: Dict[int, Device] = {} + for d in devices: + logical_id_to_device[d.logical_id] = d + return logical_id_to_device + + +def get_device_partition_stats( + partitions: List[Partition], devices: List[Device] +) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]: + """Given a list of partitions and a list of devices, returns: + 1. A mapping from device to partitions on it; + 2. A mapping from device to its remaining memory size; + 3. A list of partitions that do not have a device. + """ + # logical id to device + logical_id_to_device = get_logical_id_to_device(devices) + # Track partitions on device + device_to_partitions: Dict[Device, List[Partition]] = {} + # Track device's left mem size + device_to_left_mem_bytes: Dict[Device, int] = {} + for d in devices: + device_to_partitions[d] = [] + device_to_left_mem_bytes[d] = d.available_mem_bytes + + # Deal with the partitions that already have a device + # and also collect all partitions without a device (no_device_partitions) + no_device_partitions = [] + for partition in partitions: + if partition.logical_device_ids != []: + for logical_id in partition.logical_device_ids: + device = logical_id_to_device[logical_id] + device_to_partitions[device].append(partition) + device_to_left_mem_bytes[device] -= partition.used_mem_bytes + else: + no_device_partitions.append(partition) + + return ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) + + +def get_device_to_partitions_mapping( + partitions: List[Partition], devices: List[Device] +): + """Given a list of partitions and a list of devices, + map each partition into a device. 
+ """ + + def calculate_extra_mem_bytes_needed_for( + partition: Partition, partitions: List[Partition] + ): + all_nodes: Set[Node] = set() + for p in partitions: + all_nodes = all_nodes.union(p.nodes) + if len(all_nodes) == 0: + return partition.used_mem_bytes + all_nodes = all_nodes.union(partition.nodes) + extra_size_needed = 0 + for node in partition.nodes: + extra_size_needed += get_extra_size_of(node, all_nodes) + return extra_size_needed + + def find_device_for(partition: Partition): + """Given a partition, find a logical device for the partition + The algorithm is to put the partition on the device + that has just enough mem left for that partition. + device_to_left_mem_bytes is a dictionary between device and its left mem size + sorted by its left mem size + """ + for d in device_to_left_mem_bytes: + extra_size_needed = calculate_extra_mem_bytes_needed_for( + partition, device_to_partitions[d] + ) + if extra_size_needed < device_to_left_mem_bytes[d]: + device_to_partitions[d].append(partition) + partition.logical_device_ids.append(d.logical_id) + device_to_left_mem_bytes[d] -= extra_size_needed + return True + return False + + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(partitions, devices) + + # Find devices for all the partitions without a device + found_device = True + for partition in no_device_partitions: + device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=lambda item: item[1])) + found_device = find_device_for(partition) + if not found_device: + break + return found_device + + +def check_dependency(partition): + """Given a partition,check if there is a circular dependency on + this partition using bfs + """ + visited: Set[Partition] = {partition} + queue: Deque[Partition] = deque([partition]) + while queue: + p = queue.popleft() + for child in p.children: + if child == partition: + return True + else: + if child not in visited: + visited.add(child) + queue.append(child) + return False + + +class Partitioner: + """A fx module may not fit into one device. + Partitioner class helps partition one fx module into submodules (partitions), + so that the submodules can be executed crossing different accelerators. + The main function of this class is self.partition_graph. + It partitions the fx module based on the scheme specified in partition_config + A DAG structure is returned + along with a new fx module with submodule nodes. + """ + + def __init__(self) -> None: + self.partitions: List[Partition] = [] + self.node_to_partition: Dict[Node, int] = {} + self.devices: List[Device] = [] + + def partition_graph( + self, + fx_module: GraphModule, + torch_module: torch.nn.Module, + partitioner_config: PartitionerConfig, + ) -> PartitionResult: + """Given the fx module, torch module and partitioner_config, + find the partitions, do the partitions, + and then return a DAG and a new fx module with submodule nodes (partitions) + """ + self.graph_module = fx_module + self.torch_module = torch_module + self.devices = partitioner_config.devices + if len(self.devices) == 0: + raise RuntimeError("No devices") + # Tag the size in bytes to all nodes in the graph_module. 
+ get_size_of_all_nodes(self.graph_module) + # Check if there are op nodes in the fx module + nodes = self.graph_module.graph.nodes + if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes): + raise RuntimeError("No Partition since no operations in the module") + # Calculate total size of the fx module + total_size_of_graph = 0 + for node in nodes: + if node.op == "output": + break + total_size_of_graph += node.size_bytes.total_size + # Find the device with the max mem size + device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) + # AOT based partition + if partitioner_config.mode == PartitionMode.aot_based: + self.aot_based_partition( + partitioner_config.node_to_partition_mapping, + partitioner_config.partition_to_logical_device_mapping, + ) + # Single partition if the whole module can be fit into one device + elif total_size_of_graph <= device_with_max_mem.available_mem_bytes: + self.find_single_partition( + total_size_of_graph, logical_device_id=device_with_max_mem.logical_id + ) + elif total_size_of_graph > sum([d.available_mem_bytes for d in self.devices]): + raise RuntimeError("Devices have no enough memory for the module") + else: + # Sparse nn based partition + if partitioner_config.mode == PartitionMode.sparse_nn: + available_mem_bytes = self.devices[0].available_mem_bytes + if not all( + device.available_mem_bytes == available_mem_bytes + for device in self.devices + ): + raise RuntimeError("All devices must have same memory size!") + # sparse_nn_partition only support same memory size + # TODO: add different size support for sparse_nn_partition + self.sparse_nn_partition(available_mem_bytes) + # Cost aware partition + elif partitioner_config.mode == PartitionMode.cost_aware: + self.cost_aware_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + # KL based partition + elif partitioner_config.mode == PartitionMode.kl_based: + self.kl_based_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + else: + self.size_based_partition() + + # Saturate host if possible. + if partitioner_config.saturate_host: + self.saturate_host() + + # Partition the graph module based on the partition assignment. + module_with_submodules = self.do_partition() + + # The DAG contains DAGNodes with info of each partition's input nodes, output nodes + # and how partitions are connected. + dag = self.dump_dag(module_with_submodules) + ret = PartitionResult(dag, module_with_submodules) + return ret + + def find_single_partition( + self, total_size_of_graph, logical_device_id: int = 0 + ) -> None: + """Fit the whole fx module into one device""" + partition_0 = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op == "output": + # Skip the output node, but there can + # be nodes after the output in certain cases. + continue + partition_0.nodes.add(node) + partition_0.used_mem_bytes = total_size_of_graph + partition_0.logical_device_ids = [logical_device_id] + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def size_based_partition(self) -> None: + """This method is to partition the fx module based on memory size. + It uses greedy approach. The result may not be the best. + The basic idea is: + Step 1: + Find a device which has enough memory to fit the current node, create a empty partition + with the size of that device. 
+ Then keep adding the following nodes into the partition until the partition is full. + Step 2: + Repeat Step 1 until no device left + Step 3: + If some nodes are left, create a partition for each left node (single node partition). + and then try to map those partitions into logical devices with enough mem left. + """ + + def find_device_based_on_size(node) -> Device: + """Given a node, this function is to find a logical device + that could fit the node. + """ + mem_size_needed = get_extra_size_of(node, set()) + device = Device("", -1, -1) + for d in self.devices: + if ( + d not in occupied_devices + and d.available_mem_bytes >= mem_size_needed + ): + device = d + break + if device.available_mem_bytes < 0: + raise RuntimeError(str(node) + "is too large to fit any device") + occupied_devices.append(device) + return device + + # Track partition and its left mem size + partition_to_left_mem_bytes: Dict[Partition, int] = {} + # Track all the devices that have been used + occupied_devices: List[Device] = [] + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if there are devices left + if len(self.partitions) <= len(self.devices): + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + # Check if the current partition is the very first partition + if partition.used_mem_bytes == 0: + # Find a device to fit the first node, return available mem size + device = find_device_based_on_size(node) + occupied_devices.append(device) + # Update partition and its left mem size + partition_to_left_mem_bytes[ + partition + ] = device.available_mem_bytes + # Update available mem for the current partition + partition.logical_device_ids.append(device.logical_id) + else: + # The current partition is not the first partition + # Check if the current node can fit into current partition + if ( + partition_to_left_mem_bytes[partition] + < total_size_of_input_nodes + ): + # Check if no device is left + if len(self.partitions) == len(self.devices): + # No device is left + # Put the previous partitions into a list (non_single_node_partitions) + non_single_node_partitions = self.partitions[:] + # Create the first single node partition for the current node + self.create_single_node_partition(node) + continue + # Some devices are still left + # Create a new partition with a mem size that is enough for the current node + device = find_device_based_on_size(node) + partition = self.create_partition() + total_size_of_input_nodes = get_extra_size_of( + node, partition.nodes + ) + partition_to_left_mem_bytes[ + partition + ] = device.available_mem_bytes + partition.logical_device_ids.append(device.logical_id) + partition.add_node(node) + partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes + # Create single node partitions if no device is left + else: + self.create_single_node_partition(node) + reorganize_partitions(self.partitions) + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + # Mapping all partitions into device + found_partition_to_device_mapping = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_partition_to_device_mapping: + raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping") + return + + def saturate_host(self) -> None: + """Saturate host by assigning replicates to unused devices with enough memory. 
+ It uses a greedy approach to find a next available set of devices to place all split + partitions: For each used device, it searches for an idle device with minimal memory + size that can hold all the partition located on that device; If the search is successful + for all used devices, it then assigns the new devices' logical ID to the corresponding + partition. + """ + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(self.partitions, self.devices) + + assert ( + len(no_device_partitions) == 0 + ), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" + + # Devices that hold partitions + used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] + # Track replicates of the assigned devices + replicated_device_to_used_device: Dict[Device, Device] = {} + + while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len( + self.devices + ): + # Success flag for this round + success = True + # Devices that have not been assigned + idle_devices = [ + d + for d in self.devices + if d not in used_devices and d not in replicated_device_to_used_device + ] + # Temporary mapping from replicated device to original device + temp_replicate_mapping = {} + + # Find a new device to replicate all partitions on an used device + for used_device in used_devices: + # Idle devices that have enough memory + available_devices = [ + d + for d in idle_devices + if d.available_mem_bytes + >= used_device.available_mem_bytes + - device_to_left_mem_bytes[used_device] + ] + if len(available_devices) == 0: + success = False + break + new_device = min(available_devices, key=lambda d: d.available_mem_bytes) + idle_devices.remove(new_device) + temp_replicate_mapping[new_device] = used_device + + if not success: + break + replicated_device_to_used_device.update(temp_replicate_mapping) + + # Update logical device IDs assigned to the partitions + for ( + replicate_device, + original_device, + ) in replicated_device_to_used_device.items(): + logical_id = replicate_device.logical_id + for partition in device_to_partitions[original_device]: + partition.logical_device_ids.append(logical_id) + for p in self.partitions: + print(p.logical_device_ids) + + def do_partition(self) -> GraphModule: + """Return a new fx module with submodule nodes (partitions).""" + module_with_submodules = split_module( + self.graph_module, + self.torch_module, + lambda node: self.node_to_partition[node], + ) + return module_with_submodules + + def dump_dag(self, module_with_submodules: GraphModule) -> DAG: + """Return the dag structure and the new fx module with submodules.""" + dag = DAG() + for node in module_with_submodules.graph.nodes: + if node.op == "output": + break + if node.op in {"placeholder", "get_attr"}: + continue + if node.target == operator.__getitem__: + continue + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # When a node has two or more output nodes, + # it outputs its result to 'getitem' nodes. + # Those 'getitem' nodes are the output node for this node. + # Otherwise, the output node is this node itself. 
+ if len(node.users) > 1: + output_nodes = list(node.users) + else: + output_nodes = [node] + partition_id = int(node.name.rsplit("_", 1)[-1]) + device_ids = self.partitions[partition_id].logical_device_ids + size_bytes = self.partitions[partition_id].used_mem_bytes + dag.create_node( + node, list(input_nodes), output_nodes, device_ids, size_bytes + ) + return dag + + def create_partition(self) -> Partition: + """Create a partition and append it to self.partitions.""" + partition_id = len(self.partitions) + partition = Partition(partition_id) + self.partitions.append(partition) + return partition + + def create_single_node_partition(self, node): + """Create a partition for a single node""" + partition = self.create_partition() + partition.add_node(node) + return + + def sparse_nn_partition(self, available_mem_bytes: int) -> None: + """This method partition a sparse nn module. + It is size based partition but different from size_based_partition, + it only works when all the devices have same memory size (available_mem_bytes). + In the future, devices with different mem sizes will be supported like size_based_partition. + It first traverse all the nodes and do the partitions based on the same memory size. + If the current partition has no enough memory left for a new op node + (call_module, call_method, call_function), a new partition is created. + When crossing the boundary between non-embedding nodes and embedding nodes, + a new partition is created regardlessly. + For example, if the current node is a non-embedding node but the next node is an + embedding node, a new partition is created for the next node. + After the partition, the partitions are combined as much as possible. + The rule is that a non-embedding partition only + combines with another non-embedding one. + So as the embedding partitions. + """ + + def combine_partitions_based_on_size( + partitions: List[Partition], available_mem_bytes: int + ) -> None: + """Combining small partitions together to keep as less partitions as possible. + Here is an example of the algorithm to do this: + Assume some partitions, we first sort them based on partition used memory size. + [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)] + The available memory is 10. + step 1: self.find_partition_to_combine_based_on_size() + First, mark bfs level for each partition + Second, look the smallest partition, partition_4: 10 - 1 = 9 + It means any partition has a used memory equal or less than 9 could combine this partition + We go from the largest and selection partition_0. + Check the bfs level for two partitions, if the level difference is less than 2, + it can be combined. 
+ step 2: repeat step 1 until no partitions can be combined + """ + find_combination = True + while find_combination: + # Sort partitions based on memory size + sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes) + # Mark bfs level + get_bfs_level_partition(self.partitions) + find_combination, partitions = find_partition_to_combine_based_on_size( + sorted_partitions, available_mem_bytes, partitions + ) + return + + def calculate_mem_bytes_needed(p1, p2): + """Given two partitions, calculate how many mem bytes + are needed if two partitions are combined + """ + nodes = p1.nodes.union(p2.nodes) + mem_bytes_needed = 0 + for node in nodes: + mem_bytes_needed += get_extra_size_of(node, nodes) + return mem_bytes_needed + + def find_partition_to_combine_based_on_size( + sorted_partitions: List[Partition], + available_mem_bytes: int, + partitions: List[Partition], + ) -> Tuple[bool, List[Partition]]: + """step 1 in combine_partition_based_on_size()""" + find_combination = False + smallest_partition = sorted_partitions.pop(0) + for p in sorted_partitions[::-1]: + if abs(smallest_partition.bfs_level - p.bfs_level) <= 1: + # Calculate how many bytes needed if combined + mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition) + if mem_bytes_needed <= available_mem_bytes: + combine_two_partitions(p, smallest_partition, self.partitions) + partitions.remove(smallest_partition) + partitions.remove(p) + partitions.append(self.partitions[-1]) + find_combination = True + break + return find_combination, partitions + + def reset_partition_in_sparse_nn(partition, new_partition=True): + """If crossing the boundary between non-embedding nodes and + embedding nodes, create a new partition + """ + if in_embedding_region: + embedding_partitions.append(partition) + else: + non_embedding_partitions.append(partition) + if new_partition: + partition = self.create_partition() + partition.left_mem_bytes = available_mem_bytes + return partition + return None + + def is_embedding_node(node: Node) -> bool: + """Check if a node is an embedding node""" + if node.op == "call_module": + submodule = self.graph_module + for atom in str(node.target).split("."): + if not hasattr(submodule, atom): + raise RuntimeError( + f"Module {submodule} has no attribute {atom}" + ) + submodule = getattr(submodule, atom) + if "Embedding" in str(submodule): + return True + return False + + # Track embedding partitions and non-embedding partitions separately + embedding_partitions: List[Partition] = [] + non_embedding_partitions: List[Partition] = [] + # A Flag to check the boundary + in_embedding_region: bool = False + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if crossing the boundary between embedding nodes and non embedding nodes + if is_embedding_node(node) != in_embedding_region: + # Crossing the boundary + # Check if the current partition is an empty partition + if partition.used_mem_bytes != 0: + # The current partition isn't an empty partition. Create a new one. 
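# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the vendored file or of the
# diff above. It shows, stand-alone, the dotted-target walk performed by
# `is_embedding_node` (defined earlier in this method): a `call_module` node's
# target is a qualified name such as "sparse.emb", and following it atom by
# atom with getattr yields the submodule whose repr tells us whether it is an
# embedding. The module layout here is made up.
import torch.nn as nn

demo_root = nn.Module()
demo_root.sparse = nn.Module()
demo_root.sparse.emb = nn.EmbeddingBag(10, 4)

submodule = demo_root
for atom in "sparse.emb".split("."):
    submodule = getattr(submodule, atom)    # walk one attribute at a time
assert "Embedding" in str(submodule)
# ---------------------------------------------------------------------------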
+ partition = reset_partition_in_sparse_nn(partition) + in_embedding_region = not in_embedding_region + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if ( + total_size_of_input_nodes + partition.used_mem_bytes + > available_mem_bytes + ): + partition = reset_partition_in_sparse_nn(partition) + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if total_size_of_input_nodes > available_mem_bytes: + raise RuntimeError( + node.target + "is too large to fit into a device" + ) + partition.add_node(node) + reset_partition_in_sparse_nn(partition, new_partition=False) + # Set parents and children for partitions + set_parents_and_children(self.partitions) + # Combining non-embedding partitions + combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes) + # Combining embedding partitions + combine_partitions_based_on_size(embedding_partitions, available_mem_bytes) + total_size_of_non_embedding_partitions = 0 + for partition in non_embedding_partitions: + total_size_of_non_embedding_partitions += partition.used_mem_bytes + # Check if devices are enough for all partitions + if len(embedding_partitions) > len(self.devices): + msg = ( + "Need " + + str(len(embedding_partitions)) + + " devices, but only " + + str(len(self.devices)) + + " provided" + ) + raise RuntimeError(msg) + occupied_devices = [] + for i, partition in enumerate(embedding_partitions): + # Check if all non-embedding partitions can fit into embedding partition devices + if ( + total_size_of_non_embedding_partitions + partition.used_mem_bytes + > available_mem_bytes + ): + raise RuntimeError( + "partition_" + + str(partition.partition_id) + + "(embedding partition) and non embedding partitions can not fit into one device" + ) + else: + # Add logical device to the partition + partition.logical_device_ids = [self.devices[i].logical_id] + occupied_devices.append(self.devices[i].logical_id) + # Add logical devices to the non_embedding_partitions + for partition in non_embedding_partitions: + partition.logical_device_ids = occupied_devices + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def cost_aware_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: Dict[Node, NodeLatency], + ) -> None: + """This method is to partition the fx module based on the cost. + The cost is the total latency of running the whole fx module. + In partitioner_utils.py, the cost model is built. + The cost aware partition algorithm is: + #1. At every beginning, each node is a partition. + Then we map all the partitions to the devices + and calculate the cost + #2. Then try to pre-combine any two of the partitions if the two + partitions can be combined. + (the bfs level is less than 2 or two partitions are connected and + can find partition to device mapping) + See if any partition pair could reduce the current cost. + Choose the pair that shows the minimum cost and then combine them + #3. Repeat #2 until the cost cannot be reduced. 
+ """ + + def try_combining_partitions(p0_index, p1_index, partitions) -> float: + """Given two partitions and a list of partitions, combine these two partitions + and see what is the cost of the modified partition list + """ + p0 = partitions[p0_index] + p1 = partitions[p1_index] + """If two partitions' bfs level are less than 2 or two partitions are connected to each other, + then they can be combined + """ + if ( + (abs(p0.bfs_level - p1.bfs_level) <= 1) + or (p0 in p1.parents) + or p0 in (p1.children) + ): + combine_two_partitions(p0, p1, partitions) + # Check if a circular dependency exists after combining + if check_dependency(partitions[-1]): + return float("inf") + # Check if the modified partition list can be mapped to devices after combination + reset_partition_device(partitions) + found_deivce = get_device_to_partitions_mapping( + partitions, self.devices + ) + if not found_deivce: + return float("inf") + # Calculate the new cost + partition_to_latency_mapping = get_partition_to_latency_mapping( + partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + return cost + # If two partition can not be combined, the cost is inf + return float("inf") + + def search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) -> bool: + """Given transfer rate between partitions and each node's latency, + find two partitions to combine so the cost of the partitions can + be reduced. + The algorithm is : + 1. Go through all the partition pairs and see + if any pair of partitions can be combined. + 2. Calculate the cost after the combination. + 3. Select the minimum cost and combine its corresponding partition pair. + """ + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + if len(self.partitions) == 1: + return False + partition_pair: List[int] = [] + for i in range(len(self.partitions) - 1): + for j in range(i + 1, len(self.partitions)): + # Try to combine the partition pair + # and see the new cost after combination + new_cost = try_combining_partitions(i, j, self.partitions[:]) + if new_cost <= cost: + partition_pair = [i, j] + cost = new_cost + reorganize_partitions(self.partitions) + # If a partition pair is found, combine them + if len(partition_pair) != 0: + p0 = self.partitions[partition_pair[0]] + p1 = self.partitions[partition_pair[1]] + combine_two_partitions(p0, p1, self.partitions) + get_bfs_level_partition(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return len(partition_pair) != 0 + + for node in self.graph_module.graph.nodes: + if node.op not in {"placeholder", "get_attr", "output"}: + self.create_single_node_partition(node) + # Set up parent partitions and children partitions for each partition + set_parents_and_children(self.partitions) + # Get bfs level for each partition + get_bfs_level_partition(self.partitions) + find_combination = True + while find_combination: + # Search for a pair partition to generate the minimum new cost, + # then combine them + find_combination = search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) + # Make sure all partitions are set up correctly + reorganize_partitions(self.partitions) + # Set up node to partition mapping + 
self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def kl_based_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: Dict[Node, NodeLatency], + ) -> None: + """This function is a cost aware partition based + on Kernighan-Lin algorithm. + First, the graph is partitioned using size_based_partition. + Then, each node is swapped with any other node in a different + partition, and at the same time, the cost is estimated after + the swapping. + For example, we have nodes n0, n1, n2, n3 and n4. + Using size_based_partition, n0 and n1 are in Partition p0. + n2, n3 and n4 in Partition p1. The current cost is estimated. + We first tried using n0 to swap with n2 from the other partition. + Then we see that swapping n0 and n2 shows a lower cost + than the current cost and it is the minimum among other pairs like + (n0, None)(This means moving n0 to Partition without swapping other nodes), + (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost + as the current cost. + Then We repeat this process for all the other nodes until all swapping pairs + are tried. + """ + + def swap_nodes(n0, n1, p0, p1): + # Either n0 or n1 could be None + # That means we simply move the node + # to another partition + if n0 is not None: + p0.remove_node(n0) + p1.add_node(n0) + if n1 is not None: + p0.add_node(n1) + p1.remove_node(n1) + + def try_swap_nodes( + n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + cost = float("inf") + swap_nodes(n0, n1, p0, p1) + # Reorganize partitions after swapping + reorganize_partitions(self.partitions) + # Check if there is a circular dependency after swapping + if (not check_dependency(p0)) and (not check_dependency(p1)): + reset_partition_device(self.partitions) + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Check if all partitions can be mapped to logical devices after swapping + found_device = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_device: + cost = float("inf") + else: + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Swap back and reset all partitions back to original + swap_nodes(n1, n0, p0, p1) + reorganize_partitions(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return cost + + def swap_node_to_partition( + node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + """This function helps to swap one node from partition p0 + with all the nodes in another partition p1 + """ + p1_nodes = list(p1.nodes) + [None] + min_cost = float("inf") + node_pair: List[Node] = [] + for n1 in p1_nodes: + # Ignore the node if it is not a op node + if n1 is not None and n1.op in {"placeholder", "get_attr"}: + continue + # Try swapping node in p0 with n1 in p1 + cost = try_swap_nodes( + node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ) + if cost < min_cost: + node_pair = [node, n1] + min_cost = cost + return cost, node_pair # type: ignore[possibly-undefined] + + # First use size_base_partition + self.size_based_partition() + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Calculate the cost of the partitions + cost = get_latency_of_partitioned_graph( + self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec + ) + # Keep 
tracking the node pair that shows the better cost + node_pair: List[Node] = [] + # Keep tracking the partition pair of node pair + partition_pair: List[Partition] = [] + # Collect all the op nodes from the graph + op_nodes = [] + for n in self.graph_module.graph.nodes: + if n.op not in {"placeholder", "get_attr", "output"}: + op_nodes.append(n) + for node in op_nodes: + # Find which partition the current node belongs + p0_index = self.node_to_partition[node] + p0 = self.partitions[p0_index] + # Go through all the other partitions to swap + # with other nodes from those partitions + for p1_index, _ in enumerate(self.partitions): + if p0_index != p1_index: + p1 = self.partitions[p1_index] + new_cost, new_node_pair = swap_node_to_partition( + node, + p0, + p1, + node_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Update the cost + # Track the swapped node pair and their partitions + if new_cost < cost: + cost = new_cost + node_pair = new_node_pair + partition_pair = [p0, p1] + # Do the swapping after trying all the nodes from a partition + if len(node_pair) != 0: + swap_nodes( + node_pair[0], node_pair[1], partition_pair[0], partition_pair[1] + ) + reorganize_partitions(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + reorganize_partitions(self.partitions) + # Mapping the device to the partition + get_device_to_partitions_mapping(self.partitions, self.devices) + return + + def aot_based_partition( + self, node_to_partition_mapping, partition_to_logical_device_mapping + ): + """This function helps to rebuild the partitions given the nodes and its + corresponding partition id + """ + partition_id_to_partition_mapping: Dict[int, Partition] = {} + self.node_to_partition = node_to_partition_mapping + for node in self.node_to_partition: + partition_id = self.node_to_partition[node] + # If the requested partition has not been created, create the partition + if partition_id not in partition_id_to_partition_mapping: + partition = Partition(partition_id) + self.partitions.append(partition) + partition_id_to_partition_mapping[partition_id] = partition + partition.logical_device_ids = partition_to_logical_device_mapping[ + partition_id + ] + else: + partition = partition_id_to_partition_mapping[ + self.node_to_partition[node] + ] + # Add the current node into the partition + partition.add_node(node) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/debug.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..bd6fed690914e0f3696fb6c37bb63371bd801f93 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/debug.py @@ -0,0 +1,31 @@ +import torch.fx as fx + +def set_trace(gm: fx.GraphModule) -> fx.GraphModule: + """ + Sets a breakpoint in `gm`'s generated python code. It drops into pdb when + `gm` gets run. + + Args: + gm: graph module to insert breakpoint. It is then recompiled for it to + take effect. + + Returns: + the `gm` with breakpoint inserted. 
+ """ + def insert_pdb(body): + return ["import pdb; pdb.set_trace()\n", *body] + + with gm.graph.on_generate_code( + make_transformer=lambda cur_transform: ( + # new code transformer to register + lambda body: ( + insert_pdb( + cur_transform(body) if cur_transform + else body + ) + ) + ) + ): + gm.recompile() + + return gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/optimization.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..4690ba81b360a8354d837d9e8db29bff231e06f9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/optimization.py @@ -0,0 +1,408 @@ +import torch.fx as fx +from torch.fx.node import Argument, Target +from torch.nn.utils.fusion import fuse_conv_bn_eval +from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.fx.passes.shape_prop import ShapeProp +import copy +from collections import defaultdict +import torch.utils.mkldnn as th_mkldnn +import operator +import time +import logging +from enum import Enum + +def _parent_name(target : str) -> Tuple[str, str]: + """ + Splits a qualname into parent path and last atom. + For example, `foo.bar.baz` -> (`foo.bar`, `baz`) + """ + *parent, name = target.rsplit('.', 1) + return parent[0] if parent else '', name + +# Works for length 2 patterns with 2 modules +def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]): + if len(node.args) == 0: + return False + nodes: Tuple[Any, fx.Node] = (node.args[0], node) + for expected_type, current_node in zip(pattern, nodes): + if not isinstance(current_node, fx.Node): + return False + if current_node.op != 'call_module': + return False + if not isinstance(current_node.target, str): + return False + if current_node.target not in modules: + return False + if type(modules[current_node.target]) is not expected_type: + return False + return True + + +def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): + assert isinstance(node.target, str) + parent_name, name = _parent_name(node.target) + modules[node.target] = new_module + setattr(modules[parent_name], name, new_module) + +def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module: + """ + Fuses convolution/BN layers for inference purposes. Will deepcopy your + model by default, but can modify the model inplace as well. 
+ """ + patterns = [(nn.Conv1d, nn.BatchNorm1d), + (nn.Conv2d, nn.BatchNorm2d), + (nn.Conv3d, nn.BatchNorm3d)] + if not inplace: + model = copy.deepcopy(model) + if not no_trace or not isinstance(model, torch.fx.GraphModule): + fx_model = fx.symbolic_trace(model) + else: + fx_model = model + modules = dict(fx_model.named_modules()) + new_graph = copy.deepcopy(fx_model.graph) + + for pattern in patterns: + for node in new_graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes + continue + conv = modules[node.args[0].target] + bn = modules[node.target] + if not bn.track_running_stats: + continue + fused_conv = fuse_conv_bn_eval(conv, bn) + replace_node_module(node.args[0], modules, fused_conv) + node.replace_all_uses_with(node.args[0]) + new_graph.erase_node(node) + return fx.GraphModule(fx_model, new_graph) + +def remove_dropout(model: nn.Module) -> nn.Module: + """ + Removes all dropout layers from the module. + """ + fx_model = fx.symbolic_trace(model) + + class DropoutRemover(torch.fx.Transformer): + def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + if isinstance(self.submodules[target], nn.Dropout): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + return DropoutRemover(fx_model).transform() + +def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]): + """ + Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. + """ + new_graph = fx.Graph() + env: Dict[fx.Node, fx.Node] = {} + for input in inputs: + new_node = new_graph.placeholder(input.name) + env[input] = new_node + for node in nodes: + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + new_graph.output([env[output] for output in outputs]) + new_graph.lint() + return fx.GraphModule(orig_module, new_graph) + +mkldnn_supported = [ + nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, + torch.relu, torch.transpose, torch.sigmoid, + F.relu, F.avg_pool2d, F.adaptive_avg_pool2d +] +# These are operators that may not be convertible into MKLDNN ops (e.g. the +# args are scalar values). Thus, we only include them in the subgraph if their +# arguments are already in MKLDNN. +# TODO: Determine whether this can be removed after type inference. +mkldnn_supported_unknown = [operator.add, operator.mul] +mkldnn_map = { + nn.Conv2d: th_mkldnn.MkldnnConv2d, + nn.Linear: th_mkldnn.MkldnnLinear, + nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a) +} + + +def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): + """ + For each node, if it's a module that can be preconverted into MKLDNN, + then we do so and create a mapping to allow us to convert from the MKLDNN + version of the module to the original. 
+ """ + old_modules: Dict[nn.Module, nn.Module] = {} + for node in nodes: + if node.op == 'call_module': + assert isinstance(node.target, str) + cur_module = modules[node.target] + if type(cur_module) in mkldnn_map: + new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) + assert isinstance(new_module, nn.Module) + old_modules[new_module] = copy.deepcopy(cur_module) + replace_node_module(node, modules, new_module) + return old_modules + +def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]): + """ + Maps each module that's been changed with `modules_to_mkldnn` back to its + original. + """ + for node in nodes: + if node.op == 'call_module': + assert (isinstance(node.target, str)) + cur_module = modules[node.target] + if cur_module in old_modules: + replace_node_module(node, modules, old_modules[cur_module]) + +class MklSubgraph: + def __init__(self, fx_graph: fx.Graph): + self.fx_graph = fx_graph + self.nodes: List[fx.Node] = [] + self.start_nodes: List[fx.Node] = [] + self.end_nodes: List[fx.Node] = [] + +def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): + """ + This generates a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by running it with the example_inputs. + + Example usage: + heuristic = gen_mkl_autotuner(example_inputs, iters=10) + fast_model = optimization.optimize_for_inference(model, heuristic) + """ + fx_model = None + old_modules = None + + def use_mkl_heuristic(graph: MklSubgraph) -> bool: + nonlocal fx_model, old_modules + input_nodes = graph.start_nodes + if fx_model is None: + fx_model = graph.fx_graph.owning_module + old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] + ShapeProp(fx_model).propagate(example_inputs) + sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] + output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes]) + submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) + + def benchmark(f): + for _ in range(warmup): + f() + begin = time.time() + for _ in range(iters): + out = f() + return time.time() - begin + + mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])]) + + reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules) + no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) + return mkl_time < no_mkl_time + return use_mkl_heuristic + +def use_mkl_length(graph: MklSubgraph) -> bool: + """ + This is a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by checking if there + are more than 2 nodes in it + """ + return len(graph.nodes) > 2 + +class UnionFind: + def __init__(self, n): + self.parent: List[Optional[int]] = [None] * n + self.size: List[int] = [0] * n + + def make_set(self, v: int): + self.parent[v] = v + self.size[v] = 1 + + def find(self, v: int) -> int: + par = self.parent[v] + if v == par: + return v + assert par is not None + self.parent[v] = self.find(par) + return cast(int, self.parent[v]) + + def join(self, a: int, b: int): + a, b = self.find(a), self.find(b) + if a == b: + return a + if self.size[a] < self.size[b]: + a, b = b, a + self.parent[b] = a + self.size[a] += self.size[b] + +def optimize_for_inference( + model: torch.nn.Module, + pass_config: Optional[Dict[str, Any]] = None, + tracer: Type[fx.Tracer] = fx.Tracer +) -> 
torch.nn.Module: + """ + Performs a set of optimization passes to optimize a model for the + purposes of inference. Specifically, the passes that are run are: + 1. Conv/BN fusion + 2. Dropout removal + 3. MKL layout optimizations + + The third optimization takes a function `use_mkl_heuristic` that's used + to determine whether a subgraph should be explicitly run in MKL layout. + + Note: As FX does not currently handle aliasing, this pass currently + assumes nothing aliases. If that isn't true, use at your own risk. + """ + default_pass_config = { + "conv_bn_fuse": True, + "remove_dropout": True, + "mkldnn_layout_optimize": {'heuristic': use_mkl_length}, + } + if pass_config is None: + pass_config = {} + default_pass_config.update(pass_config) + + if default_pass_config["conv_bn_fuse"]: + model = fuse(model) + if default_pass_config["remove_dropout"]: + model = remove_dropout(model) + if default_pass_config["mkldnn_layout_optimize"] is False: + return model + if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict): + raise RuntimeError("mkldnn_layout_optimize config is not a dict") + if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]: + raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config") + use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"] + + cur_tracer = tracer() + fx_graph = cur_tracer.trace(copy.deepcopy(model)) + fx_model = fx.GraphModule(cur_tracer.root, fx_graph) + modules: Dict[str, nn.Module] = dict(model.named_modules()) + + class MklSupport(Enum): + NO = 1 + YES = 2 + UNKNOWN = 3 + + # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node. + # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node. + # However, if it's in `mkldnn_supported_unknown`, then we only treat it as + # a MKLDNN node if its inputs are MKLDNN nodes. 
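# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the vendored file or of the
# diff above. The pass below brackets candidate nodes with `to_mkldnn` /
# `to_dense` conversions, and a later cleanup collapses
# `a -> to_dense -> to_mkldnn -> b` chains. That cleanup is sound because the
# conversion is a lossless layout round-trip for dense float CPU tensors, as
# this stand-alone check illustrates (guarded, since MKLDNN support may not
# be compiled into every build):
import torch

if torch.backends.mkldnn.is_available():
    dense = torch.randn(8, 8)
    assert torch.allclose(dense.to_mkldnn().to_dense(), dense)
# ---------------------------------------------------------------------------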
+ for node in list(fx_graph.nodes): + supports_mkldnn = MklSupport.NO + if node.op == 'call_module': + cur_module = modules[node.target] + if type(cur_module) in mkldnn_supported: + supports_mkldnn = MklSupport.YES + sample_parameter = next(cur_module.parameters(), None) + if sample_parameter is not None: + assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules" + assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules" + elif node.op == 'call_function': + if node.target in mkldnn_supported: + supports_mkldnn = MklSupport.YES + elif node.target in mkldnn_supported_unknown: + supports_mkldnn = MklSupport.UNKNOWN + + if supports_mkldnn != MklSupport.NO: + if supports_mkldnn == MklSupport.UNKNOWN: + if not any(arg.target == 'to_dense' for arg in node.args): + continue + with fx_graph.inserting_before(node): + mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, ))) + + node.args = cast(Tuple[fx.node.Argument], mkldnn_args) + + with fx_graph.inserting_after(node): + dense_x = fx_graph.create_node('call_method', 'to_dense', (node,)) + node.replace_all_uses_with(dense_x) + dense_x.args = (node,) + + # Does pre-conversion of all modules into MKLDNN (when possible) + old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules) + fx_graph.old_modules = old_modules # type: ignore[attr-defined] + + # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b + for node in fx_graph.nodes: + if node.op == 'call_method' and node.target == 'to_dense': + prv_node = node.args[0] + users = list(node.users) + for user in users: + if user.op == 'call_method' and user.target == 'to_mkldnn': + user.replace_all_uses_with(prv_node) + fx_graph.erase_node(user) + if len(node.users) == 0: + fx_graph.erase_node(node) + + + num_nodes = len(fx_graph.nodes) + uf = UnionFind(num_nodes) + + def get_color(n): + if hasattr(n, 'color'): # Current node is part of a MKL subgraph + return uf.find(n.color) + if hasattr(n, 'start_color'): # Current node is input to MKL subgraph + return uf.find(n.start_color) + return None + + + # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists + # of input nodes (which are only `to_mkldnn` calls), output nodes + # (`to_dense` calls), and intermediate nodes, which are run entirely on + # MKLDNN layout tensors. + # + # Specifically, this code does a flood fill on a directed acyclic graph + # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes). + # If every node only had one input, this would be sufficient. However, in + # the case that a node has multiple inputs coming from different start + # nodes (i.e. colors), we need to join these 2 colors into 1. That's done + # using a Disjoint Set Union. 
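# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the vendored file or of the
# diff above. Minimal usage of the `UnionFind` defined earlier in this module,
# mirroring what the coloring loop below does when a node consumes inputs that
# belong to two different MKLDNN subgraphs: the two "colors" get joined.
uf_demo = UnionFind(3)
for color in range(3):
    uf_demo.make_set(color)
uf_demo.join(0, 2)                          # a node with inputs colored 0 and 2
assert uf_demo.find(2) == uf_demo.find(0)   # colors 0 and 2 now share a root
assert uf_demo.find(1) != uf_demo.find(0)   # color 1 stays separate
# ---------------------------------------------------------------------------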
+ for cur_idx, node in enumerate(fx_graph.nodes): + if node.op == 'call_method' and node.target == 'to_mkldnn': + node.start_color = cur_idx + uf.make_set(cur_idx) + elif node.op == 'call_method' and node.target == 'to_dense': + assert get_color(node.args[0]) is not None + node.end_color = get_color(node.args[0]) + else: + cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None] + + if len(cur_colors) == 0: + continue + assert not any(i is None for i in cur_colors) + cur_colors = sorted(cur_colors) + node.color = cur_colors[0] + for other_color in cur_colors[1:]: + uf.join(cur_colors[0], other_color) + + + mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) + for node in fx_graph.nodes: + if hasattr(node, 'color'): + mkldnn_graphs[uf.find(node.color)].nodes.append(node) + if hasattr(node, 'start_color'): + mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) + if hasattr(node, 'end_color'): + mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) + + + # Now that we have all the subgraphs, we need to decide which MKLDNN + # subgraphs we actually want to keep in MKLDNN. + for graph in mkldnn_graphs.values(): + if not use_mkl_heuristic(graph): + for node in graph.start_nodes + graph.end_nodes: + prv = node.args[0] + node.replace_all_uses_with(prv) + fx_graph.erase_node(node) + reset_modules(graph.nodes, modules, old_modules) + + mkldnn_conversions = 0 + for node in fx_graph.nodes: + if node.target == 'to_mkldnn' or node.target == 'to_dense': + mkldnn_conversions += 1 + + logging.getLogger(__name__).info(f"mkldnn conversions: {mkldnn_conversions}") + fx_graph.lint() + result = fx.GraphModule(model, fx_graph) + return result diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..7e4eceeb4f28c214b51a57643f60c50a145a9dac --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py @@ -0,0 +1,1122 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
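# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the vendored file or of the
# diff above. The most common entry point into this module is `make_fx`, which
# uses the proxy-tensor machinery defined below to record ATen-level
# operations into an FX graph. A minimal example of typical usage:
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def sin_times_two(x):
    return torch.sin(x) * 2


traced = make_fx(sin_times_two)(torch.randn(4))   # trace with real tensors
print(traced.graph)                               # aten.sin / aten.mul calls recorded
# ---------------------------------------------------------------------------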
+import contextlib +import functools +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import torch +import torch.utils._pytree as pytree +from torch.fx import Tracer, GraphModule +from torch.fx.graph_module import _assign_attr +from weakref import WeakKeyDictionary +from collections import defaultdict +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake +from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch +import torch.fx as fx +from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from contextlib import contextmanager, nullcontext +import inspect +from dataclasses import dataclass +import weakref +import operator +from torch.utils._stats import count +import logging + +from torch.overrides import TorchFunctionMode + +from torch.utils._python_dispatch import ( + TorchDispatchMode, + _disable_infra_mode, + _push_mode, + _unset_infra_mode, +) + +from ._backward_state import BackwardState +from .sym_node import SymNode +from ._sym_dispatch_mode import SymDispatchMode +from torch.fx import Proxy +import torch.fx.traceback as fx_traceback +from torch import SymInt, SymFloat, SymBool +from torch.utils.weak import WeakTensorKeyDictionary, WeakIdKeyDictionary, _WeakHashRef + +__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"] + +aten = torch.ops.aten +prim = torch.ops.prim + +log = logging.getLogger(__name__) +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + +CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OperatorBase, Callable] = {} + +CONSTANT_NUMEL_LIMIT = 1 + +# We currently convert all SymInt to proxies before we use them. +# This could plausibly be handled at the Dynamo level. +pytree.register_pytree_node( + torch.Size, + lambda xs: (list(xs), None), + lambda xs, _: tuple(xs), + flatten_with_keys_fn=lambda xs: ( + [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)], + None, + ), +) +def fake_signature(fn, nargs): + """FX gets confused by varargs, de-confuse it""" + argnames = ",".join(f"arg{i}" for i in range(nargs)) + return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) + +@contextmanager +def decompose(decomposition_table): + global CURRENT_DECOMPOSITION_TABLE + old_decomposition_table = CURRENT_DECOMPOSITION_TABLE + CURRENT_DECOMPOSITION_TABLE = decomposition_table + try: + yield CURRENT_DECOMPOSITION_TABLE + finally: + CURRENT_DECOMPOSITION_TABLE = old_decomposition_table + +# ensure we cannot collide with other properties +proxy_slot = object() +no_default = object() + +py_sym_types = (SymInt, SymFloat, SymBool) + +def is_sym_node(node): + assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta" + return "val" in node.meta and isinstance(node.meta['val'], py_sym_types) + +def set_proxy_slot(obj, tracer, proxy): + if isinstance(obj, torch.Tensor): + # We DO want to clobber proxies whenever we run an inplace operation + # on a tensor, and it affects the metadata on the proxy. + tracer.tensor_tracker[obj] = proxy + elif isinstance(obj, torch.ScriptObject): + # We DO want to clobber proxies, with a similar rationale as for tensors. + tracer.script_object_tracker[obj] = proxy + else: + # NB: Never clobber pre-existing proxy. 
Although the proxies + # are in principle equivalent, when we do graph partitioning + # we need there not to be spurious dependencies on tangent inputs. + # This works because primals get their SymInts set first, and + # THEN later we allocate tangent inputs. Make sure if a SymInt + # is derivable from a primal that we use that. + assert isinstance(obj, py_sym_types), type(obj) + if obj not in tracer.symnode_tracker: + tracer.symnode_tracker[obj] = proxy + +def has_proxy_slot(obj, tracer): + assert isinstance(obj, (torch.Tensor, SymNode)), type(obj) + return get_proxy_slot(obj, tracer, False, lambda _: True) + +# the default argument is what to return if the slot is not set. +# the transform argument is handy if you need to extract a subfield from +# the successfully looked up result (but NOT the default.) +def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x): + if isinstance(obj, torch.Tensor): + tracker = tracer.tensor_tracker + elif isinstance(obj, torch.ScriptObject): + tracker = tracer.script_object_tracker + else: + assert isinstance(obj, py_sym_types), type(obj) + tracker = tracer.symnode_tracker + + if obj not in tracker: + if default is no_default: + raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}") + return default + return transform(tracker[obj]) + +def snapshot_fake(val): + return val.detach() + +def extract_val(val): + if is_fake(val): + return snapshot_fake(val) + elif isinstance(val, py_sym_types): + return val + elif isinstance(val, torch.ScriptObject): + return val + elif isinstance(val, BackwardState): + return val + elif isinstance(val, (list, tuple)): + return val.__class__([extract_val(x) for x in val]) + elif isinstance(val, torch.Tensor): + if not val.is_sparse: + # NB: Kinda hacky, but we should try to get val as the metadata + # everywhere + # TODO: This doesn't properly track storages. A more robust + # approach would be to maintain a per-trace FakeTensorMode and + # from_real_tensor to create fake values (don't forget to + # snapshot_fake) + fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) + with fake_tensor_mode: + return torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype) + else: + return None + elif isinstance(val, (int, float, bool)): + return val + +# What invariants do we have for the 'val' set on the FX node? It has accurate +# metadata... but only for metadata that exists "below" all other subsystems +# (most notably autograd, but also vmap, functorch transforms, etc). This means +# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad, +# grad_fn, _base (_base actually may be set due to recursive call to +# ADInplaceOrView, but you shouldn't rely on it.) +def set_meta(proxy, val): + proxy.node.meta['val'] = extract_val(val) + # Best effort tensor_meta setting; prefer using val! 
+ if is_fake(val): + proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) + elif isinstance(val, torch.Tensor) and not val.is_sparse: + proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) + return proxy + +def thunkify(f, *args, **kwargs): + """ + Delays computation of f until it's called again + Also caches the result + """ + return functools.lru_cache(1)(functools.partial(f, *args, **kwargs)) + +def track_tensor(tensor, proxy, *, constant, tracer): + def try_set_proxy_slot(outer_s, proxy_callable, *args): + assert callable(proxy_callable) + if isinstance(outer_s, SymInt): + set_proxy_slot(outer_s, tracer, thunkify(proxy_callable, outer_s, *args)) + # The basic idea is that we need to associate each tensor/SymInt + # with a Proxy. How do we setup this association? We just store + # the proxy on the proxy slot of the object, keyed on the tracer + # (so that if we have multiple tracers at the same time, they + # don't clobber each other.) + for i, s in enumerate(tensor.shape): + try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size.int(proxy, i), x), i) + + for i, s in enumerate(tensor.stride()): + try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride.int(proxy, i), x), i) + + try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel.default(proxy), x)) + try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset.default(proxy), x)) + set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant)) + +def track_tensor_tree(inner_res, proxy_res, *, constant, tracer): + def wrap_with_proxy(e, proxy, constant): + if isinstance(e, torch.Tensor): + track_tensor(e, proxy, tracer=tracer, constant=constant) + set_meta(proxy, e) + elif isinstance(e, py_sym_types): + # NB: eagerly set meta here, so that the numbering is in order + set_meta(proxy, e) + set_proxy_slot(e, tracer, lambda: proxy) + elif isinstance(e, torch.ScriptObject): + set_proxy_slot(e, tracer, proxy) + set_meta(proxy, e) + elif isinstance(e, (tuple, list)): + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + # example use case: allreduce_ returns ([tensor], work) + for idx, ee in enumerate(e): + wrap_with_proxy(ee, proxy[idx], get_constant(idx)) + elif isinstance(e, dict): + # In theory we could support const-prop when proxy-tensor-tracing + # operators that returns dicts of tensors, but we have no use case + # for it today (since the only op we currently trace that can + # return a dict is triton_kernel_wrapper_functional/mutation, + # which does not participate in const-prop) + assert constant is None + + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + # example use case: triton_kernel_wrapper takes arguments as kwargs + for key, val in e.items(): + wrap_with_proxy(val, proxy[key], None) + elif isinstance(e, BackwardState): + set_meta(proxy, e) + e.proxy = proxy + else: + # intentionally pass on primitives + pass + + + def get_constant(idx): + if constant is None: + return None + else: + return constant[idx] + + wrap_with_proxy(inner_res, proxy_res, constant) + + return inner_res + + +def maybe_disable_fake_tensor_mode(): + # TODO: figure out if this API generally makes sense and bake it into the + # library + return unset_fake_temporarily() + + +@dataclass +class _ProxyTensor: + proxy: Proxy + constant: Optional[torch.Tensor] + + +def fetch_sym_proxy(tracer): + def inner(e): + n = e.node + if n.constant is not None: + return n.constant + if e.node.expr.is_number: + if isinstance(e, SymBool): + return 
bool(e.node.expr) + elif isinstance(e, SymInt): + return int(e.node.expr) + return float(e.node.expr) + else: + # NB: we REQUIRE all symints to be tracked + return get_proxy_slot(e, tracer)() + return inner + + +def fetch_object_proxy(tracer): + return lambda t: get_proxy_slot(t, tracer, t) + +HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter, FakeTensor) + +def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs): + unrecognized_types = [] + + def can_handle_tensor(x): + r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) + if proxy_mode._allow_fake_constant: + r = r or type(x) in (torch._subclasses.FakeTensor,) + if not r: + unrecognized_types.append(type(x)) + return r + + # If there are any tensor subclasses, we need to handle those tensor subclasses first + # TODO: we could use types to test this + if not pytree.tree_all_only(torch.Tensor, can_handle_tensor, (args, kwargs)): + not_implemented_log.debug("ProxyTensorMode tensors without proxy had unrecognized subclasses: %s", unrecognized_types) + return NotImplemented + + r = maybe_handle_decomp(proxy_mode, func, args, kwargs) + if r is not NotImplemented: + return r + + # For pre-autograd tracing, we do not want to run CompositeImplicit decomps. + if not pre_dispatch and func not in [ + torch.ops.aten.size.default, torch.ops.aten.stride.default, torch.ops.aten.storage_offset.default + ]: + with proxy_mode: + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + + tracer = proxy_mode.tracer + f_args, f_kwargs = pytree.tree_map_only((torch.Tensor, torch.ScriptObject), fetch_object_proxy(tracer), (args, kwargs)) + + # If there are SymInts, we also should not consider this constant. + # However, fake tensor handling of SymInts is sufficiently broken that + # I couldn't write a test for this case + all_constant = ( + pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs)) + # TODO: maybe constant SymInts should also be allowed? Not sure if + # this can happen + and pytree.tree_all_only((SymInt, SymFloat, SymBool), lambda _: False, (args, kwargs)) + ) + + if torch.Tag.data_dependent_output in func.tags: + # Check if all of the Tensor inputs are constants + if all_constant: + const_args, const_kwargs = pytree.tree_map_only( + _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs) + ) + with maybe_disable_fake_tensor_mode(): + return func(*const_args, **const_kwargs) + # If any of the Tensor inputs are "real" (not FakeTensor), we may + # incorrectly burn in constants by allowing this access. Raise + # an error in this case + if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only(torch.Tensor, lambda t: not is_fake(t), (args, kwargs)): + raise RuntimeError( + f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! " + "It's likely that this is caused by data-dependent control flow or similar. " + "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' " + "in your make_fx call." + ) + proxy_args, proxy_kwargs = pytree.tree_map_only( + (SymInt, SymFloat, SymBool), + fetch_sym_proxy(proxy_mode.tracer), + pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs)) + ) + + # When we trace through a torch.tensor invocation, you never actually + # see a torch.ops.aten.tensor call. 
Instead, the way this function is + # implemented internally is that we allocate a plain tensor (this is + # *guaranteed* to be a plain tensor, we disable all modes when doing + # so), and then call at::lift_fresh on it (to give modes a chance to do + # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed + # to be freshly allocated, so we want lift_fresh to be a no-op (directly + # returning the input argument). + # + # Here is the basic problem: when we trace this sequence of executions + # into an FX graph, what happens to this call sequence? Traditionally, + # tensor constants get interned as buffers on the FX GraphModule. But + # this is dangerous. Consider: + # + # x = torch.tensor(1) + # x.add_(2) + # + # Naively, this traces into: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh(t) + # x.add_(2) + # + # If lift_fresh returns t directly, the subsequent add_ call will + # modify the tensor constant. Really, the problem is we've violated + # the invariant the argument to lift is fresh. So what we should + # preserve the invariant by replacing lift_fresh with lift_fresh_copy: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh_copy(t) + # x.add_(2) + # + # This is what the overload modification does. + if func is torch.ops.aten.lift_fresh.default: + func = torch.ops.aten.lift_fresh_copy.default + + + proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs, + name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__)) + + # This makes DCE marginally less likely to DCE inplace operations. + # It is not strictly necessary + # Kind of a hacky way to test if an op is in-place or not + if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_": + if isinstance(args[0], List): + # e.g., c10d::allreduce_ returns a list of tensors as the first element + # in the output. + for i, a in enumerate(args[0]): + a.proxy = proxy_out[0][i] + else: + args[0].proxy = proxy_out + + out = func(*args, **kwargs) + + # In some circumstances, we will be tracing in a situation where a tensor + # is *statically* known to be a constant (currently, this only happens if + # you run torch.tensor; deterministic factory functions like torch.arange + # don't get this treatment). When the tensor in question is small, it's + # helpful to due constant propagation in case we call item() (in which + # case we can return the constant value that is known, rather than give + # an error.) The logic here tests if constant propagation is possible + # (because all of the inputs are constant). If so, we disable fake tensor + # mode (if it is on) and do true compute on the constant. + # + # It's worth highlighting that we're making a policy decision here. + # There is a potential that the tensor is actually quite large, and we + # don't actually want to run the compute. The tensor being quite large + # is one of the reasons why factory functions don't get this treatment + # (since they can be quite large; if a parameter is initialized to a + # constant value it will be!) Similarly, there is also a potential + # to run an operator that blows up the size of a small tensor; we don't + # protect against this case, but we could force, e.g., only single + # element constant computation by testing the numel of the result before + # propagating const-ness. Similarly, we don't require the constant to + # live on CPU, but we could. 
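# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the vendored file or of the
# diff above. The constant propagation described in the comment above is what
# lets `.item()` on a small `torch.tensor(...)` literal succeed during proxy
# tracing instead of hitting the data-dependent-output error: the lifted
# tensor is tracked as a constant, so the scalar value can be computed
# directly. Hedged example (assumes the default "real" tracing mode):
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def add_literal(x):
    return x + torch.tensor(2).item()   # item() resolved via the constant path


traced_add = make_fx(add_literal)(torch.randn(3))
# ---------------------------------------------------------------------------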
+ any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs)) + + constant = None + + # If this is a lift, the input tensor is guaranteed to be a + # constant, so we keep a copy of the original argument along so + # we can query it if we're asked to item() it at some later point + if func is torch.ops.aten.lift_fresh_copy.default and out.numel() <= CONSTANT_NUMEL_LIMIT: + with maybe_disable_fake_tensor_mode(): + constant = args[0].clone() + elif ( + torch.Tag.nondeterministic_seeded not in func.tags + and all_constant + and any_constant + and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out) + ): + # NB: do NOT include factories as constants + with maybe_disable_fake_tensor_mode(): + const_args, const_kwargs = pytree.tree_map_only( + _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs) + ) + constant = func(*const_args, **const_kwargs) + else: + constant = None + + track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) + return out + +class _SymNodeDict: + """ + Wrapper around a dictionary that will hash SymInts with their nodes + """ + def __init__(self): + self.sym_node_dict = {} + + def __setitem__(self, key: py_sym_types, value: Any): + self.sym_node_dict[key.node] = value + + def __getitem__(self, key: py_sym_types): + return self.sym_node_dict[key.node] + + def __contains__(self, key: py_sym_types): + return key.node in self.sym_node_dict + + def get(self, key: py_sym_types, default: Any = None): + return self.sym_node_dict.get(key.node, default) + +class PythonKeyTracer(Tracer): + def __init__(self): + super().__init__(autowrap_modules=()) + self.tensor_tracker = WeakTensorKeyDictionary() + self.symnode_tracker = _SymNodeDict() # type: ignore[var-annotated] + self.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef) + + # In general, we don't want to make modules leaves. In principle, users of + # this tracer might want to override this in order to turn a couple specific + # modules into leaves in the traced graph. + def call_module( + self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: + return forward(*args, **kwargs) + + # We don't want to turn getattr calls into proxies. So we just return the actual value. 
+ def getattr(self, attr, attr_val, parameter_proxy_cache): + return attr_val + + def create_arg(self, a: Any): + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node('get_attr', n, (), {}) + qualname: Optional[str] = None + + if not qualname: + i = 0 + while True: + qualname = f'_param_constant{i}' + if not hasattr(self.root, qualname): + break + i += 1 + setattr(self.root, qualname, a) + + return self.create_node('get_attr', qualname, (), {}) + elif isinstance(a, (SymInt, SymFloat, SymBool)): + assert a.node.constant is not None + return a.node.constant + return super().create_arg(a) + + def unwrap_proxy(self, e): + if isinstance(e, torch.Tensor): + return get_proxy_slot(e, self, e, lambda e: e.proxy) + elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)): + return get_proxy_slot(e, self, e, lambda e: e()) + elif isinstance(e, torch.ScriptObject): + return get_proxy_slot(e, self, e) + else: + return e + + +@torch._disable_dynamo +def dispatch_trace( + root: Union[torch.nn.Module, Callable], + tracer: Tracer, + concrete_args: Optional[Tuple[Any, ...]] = None, +) -> GraphModule: + graph = tracer.trace(root, concrete_args) + from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints + dedupe_symints(graph) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name) + + +def wrap_key(f, tensors, tracer, pre_dispatch: bool): + flat_tensors, tensors_spec = pytree.tree_flatten(tensors) + + @functools.wraps(f) + def wrapped(*proxies): + flat_proxies, proxies_spec = pytree.tree_flatten(proxies) + assert len(flat_proxies) == len(flat_tensors) + with disable_proxy_modes_tracing() as m: + assert isinstance(m, ProxyTorchDispatchMode) + track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) + + out = f(*tensors) + out = pytree.tree_map_only( + torch.Tensor, + lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy), + out + ) + out = pytree.tree_map_only( + (SymInt, SymFloat, SymBool), + lambda t: get_proxy_slot(t, tracer)(), + out + ) + return out + + return wrapped + +ORIGINAL_ATEN = None +@contextmanager +def set_original_aten_op(func): + global ORIGINAL_ATEN + if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta(): + ORIGINAL_ATEN = func + fx_traceback.current_meta['original_aten'] = func + try: + yield + finally: + ORIGINAL_ATEN = None + fx_traceback.current_meta['original_aten'] = None + else: + yield + + + +# This mode is **only** used for pre_dispatch tracing. +# In particular, we need to make sure that autograd/autocast API's +# that do not desugar into dispatcher operators stay in the graph. +class PreDispatchTorchFunctionMode(TorchFunctionMode): + + def __init__(self, tracer): + self.tracer = tracer + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if func in _side_effectful_need_to_be_preserved_pre_dispatch: + # It's for passing the export verifier which needs to verify the meta['val'] + # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier, + # instead of hardcoding it here. + node = self.tracer.create_node("call_function", func, args, {}) + if func is torch._C._set_grad_enabled: + node.meta['val'] = None + return node + # Don't actually run the function! We just want to trace the calls + # into a graph. We don't actualy want to change global autograd state. 
+ return func(*args, **kwargs) + + +class ProxyTorchDispatchMode(TorchDispatchMode): + def __init__(self, tracer, tracing_mode, pre_dispatch=False, _allow_fake_constant=False, _error_on_data_dependent_ops=True): + dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None + super().__init__(dk) + self.tracer = tracer + self.tracing_mode = tracing_mode + self.enable_tracing = True + self.pre_dispatch = pre_dispatch + self._allow_fake_constant = _allow_fake_constant + self._error_on_data_dependent_ops = _error_on_data_dependent_ops + self.sym_mode = ProxySymDispatchMode(tracer) + self.trace_state = {} + self._managers = [] + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.PROXY + # Every time we enter a mode, we maintain a stack telling us what the previous + # ProxyTorchDispatchMode state was (if there was any). + # This lets us properly reset the state on exit. + self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = [] + + @count + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + with self.sym_mode.enable(False), set_original_aten_op(func): + return self.inner_torch_dispatch(func, types, args, kwargs) + + def __enter__(self): + # sym mode first, then us... + m = self.sym_mode.enable(True) + self._managers.append(m) + m.__enter__() + # Stash and store the previous proxy mode (there may or may not be one) + maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + self.enter_stack.append(maybe_prev_proxy_mode) + return super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + m = self._managers.pop() + # ...exit us first, then sym mode + b = super().__exit__(exc_type, exc_value, traceback) + + # Re-enable the previous proxy mode, if there was one. + mb_previous_proxy_mode = self.enter_stack.pop() + if mb_previous_proxy_mode is not None: + _push_mode(mb_previous_proxy_mode) + + if not b: + return m.__exit__(exc_type, exc_value, traceback) + else: + return m.__exit__(None, None, None) + + + def inner_torch_dispatch(self, func, types, args=(), kwargs=None): + if not self.enable_tracing: + return func(*args, **kwargs) + + if func in [prim.device.default]: + return func(*args, **kwargs) + + return proxy_call(self, func, self.pre_dispatch, args, kwargs) + + +class ProxySymDispatchMode(SymDispatchMode): + def __init__(self, tracer): + super().__init__() + self.tracer = tracer + # When false, we don't trace operations. If you do this, you MUST + # call track_tensor/track_tensor_tree on all results of the operation + # to ensure we can adequately track the results + self.enable_tracing = True + + @contextmanager + def enable(self, b): + old = self.enable_tracing + self.enable_tracing = b + try: + yield + finally: + self.enable_tracing = old + + def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat, SymBool]): + n_args = tuple( + get_proxy_slot(a, self.tracer)().node if isinstance(a, py_sym_types) else a + for a in args + ) + + # func doesn't have a __torch_function__ that Proxy can interpose, so + # we gotta do it manually + n_out = self.tracer.create_node("call_function", func, n_args, {}) + p_out = fx.Proxy(n_out, self.tracer) + set_meta(p_out, out) + return p_out + + def __sym_dispatch__(self, func, types, args, kwargs): + if not self.enable_tracing: + return func(*args, **kwargs) + + # Peephole optimize multiply by one + # NB: be careful not to trigger guards here! 
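+        # (e.g. tracing s0 * 1 just returns s0 instead of emitting a mul node into the graph)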
+ if func == operator.mul: + if isinstance(args[1], int) and args[1] == 1: + return args[0] + elif isinstance(args[0], int) and args[0] == 1: + return args[1] + + # For speed, we assume there are no nested data structures + # (otherwise we could use tree_map) + # We also assume there are no keyword arguments. + assert not kwargs + out = func(*args, **kwargs) + + # If func returned a constant, we don't need to trace; we have + # determined that the result is constant (no matter if the inputs + # were symbolic) and it is no longer necessary to trace the + # computation. This could occur if func triggered some guards. + if isinstance(out, py_sym_types): + # Delays tracing out the proxies on this op until we actually need it + p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out) + set_proxy_slot(out, self.tracer, p_out_thunk) + + return out + + +# TODO: I'm not sure what the point of this class is; you can just +# make_fx through a regular Interpreter +class DecompositionInterpreter(torch.fx.Interpreter): + def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, decomposition_table=None, **kwargs): + super().__init__(module, **kwargs) + self.new_graph = new_graph + self.tracer = torch.fx.proxy.GraphAppendingTracer(self.new_graph) + # Blegh + self.tracer.tensor_tracker = WeakTensorKeyDictionary() # type: ignore[attr-defined] + self.tracer.symnode_tracker = weakref.WeakKeyDictionary() # type: ignore[attr-defined] + self.decomposition_table = decomposition_table + if self.decomposition_table is None: + self.decomposition_table = {} + self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") + + def placeholder(self, target, args, kwargs): + out = super().placeholder(target, args, kwargs) + proxy = torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + # TODO handle case where the first character of target is '*' + return out + + def get_attr(self, target, args, kwargs): + out = super().get_attr(target, args, kwargs) + proxy = torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + return out + + # call_function, call_method, call_module get traced automatically by the outer mode. 
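+    # Only placeholder, get_attr and output have to mirror nodes into new_graph by hand; tensor ops are captured by the ProxyTorchDispatchMode entered in run() below.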
+ + def output(self, target, args, kwargs): + out = super().output(target, args, kwargs) + + def unwrap(e): + return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node) + self.new_graph.output(pytree.tree_map(unwrap, out)) + return out + + def run(self, *args, **kwargs): + # Should enter the mode at least once for being able to restore it later + # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025 + with decompose(self.decomposition_table), self.mode: + return super().run(*args, **kwargs) + + +def wrapper_and_args_for_make_fx(func, args, kwargs): + # make_fx doesn't support kwargs, so we need to do this flattening + # and then unflatten the args before calling func + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def wrapped(flat_args): + fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) + return func(*fn_args, **fn_kwargs) + return wrapped, flat_args + +@contextmanager +def disable_autocast_cache(): + old_value = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + try: + yield + finally: + torch.set_autocast_cache_enabled(old_value) + + +class _ModuleStackTracer(PythonKeyTracer): + r"""Customized version of PythonKeyTracer that retains module stack + information in node.meta["nn_module_stack"]. + + FX symbolic trace actually does this already, but it relies on `self.root` + being the actual module being traced. Since make_fx traces a lambda of our + creation, things don't work properly. + + So for this version we hold onto a reference to the original module + (scope_root) and use that to match the path. Also when we see, + A + / \ + B C + \ / + D + we want to record the path as A.B.D by recording only one path. + See Note [Preserving the nn module stack metadata during export non-strict mode] # noqa: W605 + """ + + def __init__(self, scope_root): + super().__init__() + self.scope_root = scope_root + self.proxy_paths = WeakKeyDictionary() + self.proxy_modules = WeakKeyDictionary() + self.counter = 0 + + self.module_id_cache = defaultdict(list) + for name, mod in self.scope_root.named_modules(remove_duplicate=False): + self.module_id_cache[id(mod)].append(name) + + self_ = self + + class AttrProxy: + def __init__(self, base, path): + self.__class__ = type( + base.__class__.__name__, + (self.__class__, base.__class__), + {}, + ) + self.__dict__ = base.__dict__ + self.__class__.__module__ = base.__class__.__module__ + self.__class__.__qualname__ = base.__class__.__qualname__ + self_.proxy_paths[self] = path + self_.proxy_modules[self] = base + + def __getattr__(self, name): + assert isinstance(self, torch.nn.Module) + attr_val = super().__getattr__(name) + if isinstance(attr_val, AttrProxy): + attr_val = self_.proxy_modules[attr_val] + elif not isinstance(attr_val, torch.nn.Module): + return attr_val + return AttrProxy(attr_val, self_.proxy_paths[self] + "." + name) + + @property + def _modules(self): + assert "_modules" in self.__dict__ + submodules = self.__dict__["_modules"] + assert isinstance(submodules, dict) + return { + key: AttrProxy(value, self_.proxy_paths[self] + "." + str(key)) + for key, value in submodules.items() + } + + self.proxy_type = AttrProxy + + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + Use tracked access path during tracing instead of the default BFS behavior. + Still use all the possible module paths to verify the result. 
+ """ + if mod is self.scope_root: + return "" + + if isinstance(mod, self.proxy_type): + return self.proxy_paths[mod] + + return Tracer.path_of_module(self, mod) + + def getattr(self, attr, attr_val, parameter_proxy_cache): + if not isinstance(attr_val, torch.nn.Module) or isinstance(attr_val, torch.fx.GraphModule): + return super().getattr(attr, attr_val, parameter_proxy_cache) + if isinstance(attr_val, self.proxy_type): + return attr_val + return self.proxy_type(attr_val, attr) + + def trace(self, root, concrete_args): + res = super().trace(root, concrete_args) + # Since we are making AttrProxy mimic the original + # submodule, when someone registers a module directly + # to the tracer while tracing, the proxy object gets registered + # first. So we need to replace the proxy modules with the real ones + # This can happen during HOO tracing + proxy_module_names_to_be_replaced = [] + for name, module in self.root.named_modules(): + if module in self.proxy_modules: + proxy_module_names_to_be_replaced.append((name, module)) + + def _delete_proxy_attr(obj, target): + # Copied from fx/graph_module.py + # Customized it for proxy type + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + assert isinstance(obj, torch.nn.Module) + mod = obj + + # Get the parent module + for item in path: + + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, (self.proxy_type, torch.nn.Module)): + return False + + if not hasattr(mod, target_submod): + return False + + # At least the leaf module should be proxy type. + if not isinstance(getattr(mod, target_submod), self.proxy_type): + return False + + delattr(mod, target_submod) + return True + + for (proxy_module_name, proxy_module) in proxy_module_names_to_be_replaced: + _delete_proxy_attr(self.root, proxy_module_name) + actual_module = self.proxy_modules[proxy_module] + _assign_attr(actual_module, self.root, proxy_module_name) + + return res + + + def call_module(self, m, forward, args, kwargs): + """PythonKeyTracer overrides call_module to avoid the scope handling, + but we actually want it. + """ + from torch._dynamo import OptimizedModule + # FIXME (tmanlaibaatar) + # When we call torch.compile inside HOO, we will end up + # invoking a module that is not registered on the root. For + # now, we just inline them. But once we start supporting + # mark_strict in export, we do need to properly handle this. + # Right now, it doesn't matter because current non-strict + # use cases don't need to work with HOO. 
+ if isinstance(m, (OptimizedModule, GraphModule)): + return forward(*args, **kwargs) + return Tracer.call_module(self, m, forward, args, kwargs) + + + def is_leaf_module(self, m, module_qualified_name): + return False + + +def make_fx(f, + decomposition_table=None, + tracing_mode="real", + _allow_non_fake_inputs=False, + *, + pre_dispatch=False, + record_module_stack=False, + _allow_fake_constant=False, + _error_on_data_dependent_ops=True): + assert tracing_mode in ["real", "fake", "symbolic"] + + if decomposition_table is None: + decomposition_table = {} + + if torch.ops.aten.sym_numel.default not in decomposition_table: + decomposition_table = { + **decomposition_table, + torch.ops.aten.sym_numel.default: torch._decomp.decompositions.sym_numel + } + + @functools.wraps(f) + def wrapped(*args): + # Avoid importing sympy at a module level + from .symbolic_shapes import ShapeEnv + + phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined] + + if hasattr(f, "_orig_mod") and record_module_stack: + scope_root = f._orig_mod + fx_tracer = _ModuleStackTracer(scope_root) + else: + fx_tracer = PythonKeyTracer() + fake_tensor_mode: Any = nullcontext() + if tracing_mode == "real": + fake_tensor_mode = nullcontext() + elif tracing_mode == "fake": + import torch._dynamo + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=True, + allow_non_fake_inputs=_allow_non_fake_inputs, + shape_env=ShapeEnv(), + static_shapes=True, + ) + elif tracing_mode == "symbolic": + import torch._dynamo + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + shape_env = ShapeEnv() + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=_allow_non_fake_inputs, + shape_env=shape_env) + else: + shape_env = fake_tensor_mode.shape_env + assert shape_env is not None, "shape_env should be set if tracing with 'symbolic'" + + else: + raise AssertionError(f"Unexpected tracing type: {tracing_mode}") + + python_dispatcher_mode: Any = nullcontext() + pre_dispatch_mode: Any = nullcontext() + # pre-autograd tracing uses per-dispatch-key modes, + # which requires the python dispatcher + if tracing_mode == "symbolic" or pre_dispatch: + python_dispatcher_mode = enable_python_dispatcher() + if pre_dispatch: + pre_dispatch_mode = enable_pre_dispatch() + + proxy_function_mode: Any = nullcontext() + if pre_dispatch: + proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer) + + proxy_mode = ProxyTorchDispatchMode(fx_tracer, + tracing_mode, + pre_dispatch=pre_dispatch, + _allow_fake_constant=_allow_fake_constant, + _error_on_data_dependent_ops=_error_on_data_dependent_ops) + + arg_count = 0 + + def wrap_fake(x): + nonlocal arg_count + # TODO: it would be nice to line these up with the names + # FX will choose for the placeholders, but we don't + # actually know what the names will be at this point yet + # NB: the Source here is actually meaningless + from torch._dynamo.source import ConstantSource + source = ConstantSource(f"input{arg_count}") + if isinstance(x, torch.Tensor): + arg_count += 1 + return fake_tensor_mode.from_tensor(x, source=source) # type: ignore[attr-defined] + # NB: don't match on bools + elif type(x) is int and tracing_mode == "symbolic": + return shape_env.create_symintnode(shape_env.create_symbol(x, source, positive=None), hint=x, source=source) + + return x + + sym_mode = proxy_mode.sym_mode + + wrap_fn_map = { + "real": 
lambda x: x, + "fake": wrap_fake, + "symbolic": wrap_fake, + } + args = pytree.tree_map(wrap_fn_map[tracing_mode], args) + + if not hasattr(inspect.unwrap(f), '__code__') or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS: + # FX doesn't support varargs, so we gotta fake up a wrapper + # TODO: Would be nice to fix this at the source... + func = fake_signature(f, len(phs)) + else: + func = f + + # We disable the autocast cache as the autocast cache causes type conversions on parameters to + # check a cache, which introduces untracked tensors into the graph + # + # We also disable tracing by any other tensor proxy-based tracers except the current. The + # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is + # thus irrelevant to any external functional trace. + with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, proxy_function_mode, \ + sym_mode, proxy_mode, disable_autocast_cache(): + t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs)) + + # TODO: kind of a bad way to do it, should maybe figure out a better way + if tracing_mode == "symbolic": + t.shape_env = shape_env # type: ignore[assignment] + return t + + return wrapped + + +def get_torch_dispatch_modes(): + return torch.utils._python_dispatch._get_current_dispatch_mode_stack() + + +def get_innermost_proxy_mode(): + return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + + +@contextlib.contextmanager +def disable_proxy_modes_tracing(): + return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + + +def maybe_handle_decomp(proxy_mode, op, args, kwargs): + if op in CURRENT_DECOMPOSITION_TABLE: + with proxy_mode: + return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs) + return NotImplemented + + +def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"): + """A helper function used to get the GraphModule for the given func. + + It's expected to be used in the ProxyTensor tracing context. + It detaches the args and kwargs from the current tracer so that the trace of + the current graph module can be created without any side-effects. + """ + wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs) + + with disable_proxy_modes_tracing(): + gm = make_fx(wrapped, tracing_mode=tracing_mode)(all_args) + return gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py new file mode 100644 index 0000000000000000000000000000000000000000..0e27dc869e1d565a8de0db7277d9ccc6fc259bb4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py @@ -0,0 +1,1330 @@ +""" +This file does three things: +- Contains the definition of SymNode +- Installs all the magic methods into SymBool, SymFloat, SymFloat at import time +- Does not depend on sympy at import time + +As this file is imported from within torch/__init__.py we do not want it to depend on SymPy +to avoid having to load SymPy at import time, as doing so is *very* slow. +""" + +import builtins +import itertools +import logging +import math +import operator +import sys +from functools import lru_cache, update_wrapper +from typing import Optional, Type, TYPE_CHECKING, Union + +import torch + +# NB: The sym_* functions are used via getattr() and must be imported here. 
+from torch import ( # noqa: F401 + sym_float, + sym_ite, + sym_max, + sym_min, + sym_not, + SymBool, + SymFloat, + SymInt, +) + +from torch.fx.experimental._sym_dispatch_mode import ( + handle_sym_dispatch, + sym_function_mode, +) + +if TYPE_CHECKING: + from torch.fx.experimental.symbolic_shapes import ShapeEnv + +log = logging.getLogger(__name__) +sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node") + + +__all__ = ["SymNode", "method_to_operator", "magic_methods"] + + +SymTypes = (SymInt, SymFloat, SymBool) + + +def _to_symtype(t): + if t is bool: + return SymBool + if t is int: + return SymInt + if t is float: + return SymFloat + return t + + +# TODO: An incomplete list +# 1. Set variables to be equal when we do equality +# 2. Specialize on 0/1 when we do subtraction +class SymNode: + """ + This is a type erased SymInt/SymFloat which we use to do actual operations. + End users don't touch this. Magic methods are NOT defined on this object. + """ + + def __init__( + self, + expr, + shape_env, + pytype, + hint: Optional[Union[int, float, bool]], + constant=None, + fx_node=None, + ): + self._expr = expr + self.shape_env = shape_env + self.pytype = pytype + # What's the difference between hint and constant? + # + # - A constant is known to be invariant across invocations of the model; + # it will always be this value. We only really know this when we + # encounter an honest-to-goodness literal (when wrapping it into + # a SymNode, we set constant.) Most of the time, constant is None + # + # - A hint is a *particular* value from the particular run we are + # tracing, but it may vary the next time around. It's useful to + # keep this around, as if we need a concrete value from a SymNode, + # we will return the hint and guard on the expression that produced + # it giving the same hint next time around. The hint is not + # guaranteed to be set either: if you have an unbacked SymNode, + # there won't be any hint; it was the result of some tensor-dependent + # computation, but we don't know what it actually is because we + # haven't actually run the tensor computation. + # + # If _hint is None, we will query maybe_evaluate_static(compute_hint=True) + # in hopes that we've learned enough about the unbacked symints to + # discharge the hint; otherwise, you're likely to just error out. + # + # (A previous version of this system had some optimizations to only + # recompute when it was possible we had learned enough about the + # unbacked symint that a hint was now possible, but as we added more + # potential refinements to unbacked symints this got harder to keep + # in sync, so we've deleted it for now.) + if hint is not None: + assert type(hint) is pytype or type(hint) is _to_symtype(pytype), ( + "Cannot create SymNode of type " + f"{pytype} with incompatible hint of type {type(hint)}" + ) + self._hint = hint + self.constant: Optional[Union[int, float, bool]] = constant + + # Record the FX node of the current node if we are doing translation + # validation. They will be used for building the input assertions for + # the translation validation problem. 
+ self.fx_node = ( + fx_node if self.shape_env._translation_validation_enabled else None + ) + + def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode": + return SymNode( + self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node + ) + + @property + def expr(self): + return self.shape_env.replace(self._expr) + + # Recompute the hint and see if we've got it now + # Precondition: self._hint is None + def _update_hint(self): + r = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True) + if r is not None: + self._hint = self.pytype(r) if not isinstance(r, SymTypes) else r + + @property + def hint(self): + if self._hint is None: + self._update_hint() + return self._hint + + def has_hint(self): + if self._hint is None: + self._update_hint() + return self._hint is not None + + def require_hint(self, fallback=None): + if self._hint is None: + self._update_hint() + if self._hint is None: + if fallback is not None: + return fallback + # NB: we expect this to raise + return self.shape_env.size_hint(self.expr) + return self._hint + + def maybe_as_int(self): + if self.expr.is_number: + return int(self.expr) + else: + return None + + def is_int(self): + return self.pytype is int + + def is_float(self): + return self.pytype is float + + def is_bool(self): + return self.pytype is bool + + def is_nested_int(self): + # Unbacked SymInts cannot be nested int today + return ( + self._hint is not None + and isinstance(self._hint, SymInt) + and self._hint.node.is_nested_int() + ) + + def wrap_int(self, num): + assert type(num) is int + import sympy + + return SymNode( + sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num + ) + + def wrap_float(self, num): + assert type(num) is float + import sympy + + return SymNode( + sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num + ) + + def wrap_bool(self, num): + assert type(num) is bool + import sympy + + return SymNode( + sympy.true if num else sympy.false, + self.shape_env, + bool, + num, + constant=num, + fx_node=num, + ) + + def clone(self): + return self + + def str(self): + return f"{self.expr}" + + def __str__(self): + return self.str() + + def __repr__(self): + return self.str() + + # These methods call the metaprogrammed methods, they're hand written + # here so we get good stack traces + def abs(self) -> "SymNode": + return self._abs() # type: ignore[attr-defined] + + def pos(self) -> "SymNode": + return self._pos() # type: ignore[attr-defined] + + def round(self, ndigits=None) -> "SymNode": + return self._round(ndigits) # type: ignore[attr-defined] + + def add(self, other) -> "SymNode": + return self._add(other) # type: ignore[attr-defined] + + def sub(self, other) -> "SymNode": + return self._sub(other) # type: ignore[attr-defined] + + def mul(self, other) -> "SymNode": + return self._mul(other) # type: ignore[attr-defined] + + def mod(self, other) -> "SymNode": + return self._mod(other) # type: ignore[attr-defined] + + def pow(self, other) -> "SymNode": + return self._pow(other) # type: ignore[attr-defined] + + def and_(self, other) -> "SymNode": + return self._and_(other) # type: ignore[attr-defined] + + def or_(self, other) -> "SymNode": + return self._or_(other) # type: ignore[attr-defined] + + def truediv(self, other) -> "SymNode": + return self._truediv(other) # type: ignore[attr-defined] + + def floordiv(self, other) -> "SymNode": + return self._floordiv(other) # type: ignore[attr-defined] + + def lshift(self, other) -> "SymNode": + return self._lshift(other) # type: 
ignore[attr-defined] + + def rshift(self, other) -> "SymNode": + return self._rshift(other) # type: ignore[attr-defined] + + def sym_not(self) -> "SymNode": # noqa: F811 + return self._sym_not() # type: ignore[attr-defined] + + def eq(self, other) -> "SymNode": + return self._eq(other) # type: ignore[attr-defined] + + def ne(self, other) -> "SymNode": + return self._ne(other) # type: ignore[attr-defined] + + def gt(self, other) -> "SymNode": + return self._gt(other) # type: ignore[attr-defined] + + def lt(self, other) -> "SymNode": + return self._lt(other) # type: ignore[attr-defined] + + def le(self, other) -> "SymNode": + return self._le(other) # type: ignore[attr-defined] + + def ge(self, other) -> "SymNode": + return self._ge(other) # type: ignore[attr-defined] + + def floor(self) -> "SymNode": + return self._floor() # type: ignore[attr-defined] + + def is_integer(self) -> "SymNode": + return self._is_integer() # type: ignore[attr-defined] + + def sym_float(self) -> "SymNode": # noqa: F811 + return self._sym_float() # type: ignore[attr-defined] + + def sym_int(self) -> "SymNode": + return self._sym_int() # type: ignore[attr-defined] + + def ceil(self) -> "SymNode": + return self._ceil() # type: ignore[attr-defined] + + def neg(self) -> "SymNode": + return self._neg() # type: ignore[attr-defined] + + def sym_min(self, other) -> "SymNode": # noqa: F811 + return self._sym_min(other) # type: ignore[attr-defined] + + def sym_max(self, other) -> "SymNode": # noqa: F811 + return self._sym_max(other) # type: ignore[attr-defined] + + def sym_ite(self, then_val, else_val) -> "SymNode": + return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] + + def is_contiguous(self, sizes, strides) -> "SymNode": + return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] + + def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode": + return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] + + # Make C++ happy + def sym_or(self, other): + return self.or_(other) + + def sym_and(self, other): + return self.and_(other) + + def is_non_overlapping_and_dense(self, sizes, strides): + return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] + + def int_(self): + return self.guard_int("", 0) # NB: uses Python backtrace + + # You can manually trigger a guard with this function + def guard_int(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) + try: + return int(r) + except Exception: + log.warning("Failed to convert to int: %s", r) + raise + + def guard_float(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr( + 
self.expr, self.hint, fx_node=self.fx_node, expect_rational=False + ) + try: + return float(r) + except Exception: + log.warning("Failed to convert to float: %s", r) + raise + + def guard_bool(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def expect_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + if self.has_hint() and not free_unbacked_symbols(self.expr): + # OK to generate guards + return self.guard_bool(file, line) + # Generate a deferred runtime assert (this might actually end up doing + # a regular guard if we can!) + # TODO: file/line here is very important, because the assert has been + # deferred so you can't backtrace easily + return self.shape_env.defer_runtime_assert( + self.expr, f"{file}:{line}", fx_node=self.fx_node + ) + + def expect_size(self, file, line): + from torch.fx.experimental.symbolic_shapes import _advise_is_size + + b = self.ge(self.wrap_int(0)) + # Generate a deferred runtime assert + r = b.expect_true(file, line) + # Refine compile time range, but only if it's unbacked. + # If you refine range for hinted variables, you can end up making + # improper deductions since compile time reasoning may be + # incompatible with runtime reasoning. + if r and not self.has_hint(): + _advise_is_size(SymInt(self)) + return r + + def guard_size_oblivious(self, file, line): + """ + Like guard_bool, but if we encounter unbacked symbols, if those symbols + are size-like, we will treat them as >= 2 for the purposes of the analysis. + + This CHANGES the runtime semantics, but all size-oblivious sites have been + audited to ensure that the runtime semantics don't change in a material way. + Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping + an unbacked one size, or a tensor reporting as non-contiguous even if it's + contiguous if it would have been reported contiguous due to being empty. 
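+        For example, with an unbacked size-like symbol u0, a question such as u0 == 0 resolves to False under this analysis (u0 is treated as if it were >= 2) rather than forcing a guard on a data-dependent expression.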
+ """ + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr( + self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True + ) + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def bool_(self): + return self.guard_bool("", 0) + + def is_symbolic(self): + return True + + def nested_int(self): + return None + + def is_constant(self): + return False + + +# TODO: this probably needs the sizes-strides eval functions +METHOD_TO_OPERATOR = { + "pos": operator.pos, + "abs": operator.abs, + "add": operator.add, + "and": operator.and_, + "ceil": math.ceil, + "eq": operator.eq, + "floor": math.floor, + "floordiv": operator.floordiv, + "ge": operator.ge, + "gt": operator.gt, + "is_integer": lambda x: x.is_integer(), + "le": operator.le, + "lshift": operator.lshift, + "lt": operator.lt, + "mod": operator.mod, + "mul": operator.mul, + "ne": operator.ne, + "neg": operator.neg, + "or": operator.or_, + "pow": operator.pow, + "round": builtins.round, + "rshift": operator.rshift, + "sub": operator.sub, + "sym_float": sym_float, + "sym_ite": sym_ite, + "sym_max": sym_max, + "sym_min": sym_min, + "sym_not": sym_not, + "truediv": operator.truediv, +} + +unary_magic_methods = { + "abs", + "sym_float", + "ceil", + "floor", + "neg", + "sym_not", + "pos", +} + + +# Adding math ops: sqrt, cos, sin, ... +def _get_sym_node_fn(name): + def fn(self): + return getattr(self, f"_sym_{name}")() + + return fn + + +math_op_names = ( + "sqrt", + "cos", + "cosh", + "sin", + "sinh", + "tan", + "tanh", + "asin", + "acos", + "atan", +) +for name in math_op_names: + sym_name = f"sym_{name}" + priv_sym_name = f"_{sym_name}" + setattr(SymNode, sym_name, _get_sym_node_fn(name)) + METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name) + unary_magic_methods.add(sym_name) + __all__.append(sym_name) + + +# Unary methods that are not magic methods +unary_nonmagic_methods = { + "is_integer", +} + +unary_methods = unary_magic_methods | unary_nonmagic_methods + +# Most methods are only registered on SymInt and SymFloat +# Some methods are only be registered on SymBool +only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} +# Methods that implicitly convert SymBool into SymInt +bool_becomes_int_magic_methods = {"add", "sub", "mul"} +# Methods that are also on SymBool, in addition to on SymInt and SymFloat +also_bool_magic_methods = {"eq"} +bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods + +# Methods that are only for float +only_float_magic_methods = {"is_integer"} + + +magic_methods_on_operator_with_trailing_underscore = {"and", "or"} + + +always_float_magic_methods = {"truediv", "sym_float", "pow"} + +for name in math_op_names: + sym_name = f"sym_{name}" + always_float_magic_methods.add(sym_name) + + +always_int_magic_methods = {"ceil", "floor"} +always_bool_magic_methods = { + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + "and", + "or", + "sym_not", + "is_non_overlapping_and_dense", + "is_integer", +} + +# Methods that have a `__foo__` as well as `__rfoo__` + + +def _sympy_truediv(a, b): + from torch.utils._sympy.functions import TrueDiv + + return TrueDiv(a, b) + + +def _sympy_floordiv(a, b): + from torch.utils._sympy.functions import FloorDiv + + return FloorDiv(a, b) + + +def _sympy_mod(a, b): + from torch.utils._sympy.functions import Mod + + return Mod(a, b) + + +def _sympy_pow(a, b): + from torch.utils._sympy.functions import Pow + + return Pow(a, b) + + +def _sympy_and(a, 
b): + import sympy + + return sympy.And(a, b) + + +def _sympy_or(a, b): + import sympy + + return sympy.Or(a, b) + + +def _sympy_lshift(a, b): + from torch.utils._sympy.functions import LShift + + return LShift(a, b) + + +def _sympy_rshift(a, b): + from torch.utils._sympy.functions import RShift + + return RShift(a, b) + + +reflectable_magic_methods = { + "add": operator.add, + "sub": operator.sub, + "mul": operator.mul, + "mod": _sympy_mod, + "pow": _sympy_pow, + "and": _sympy_and, + "or": _sympy_or, + "truediv": _sympy_truediv, + "floordiv": _sympy_floordiv, + "lshift": _sympy_lshift, + "rshift": _sympy_rshift, +} + + +def _floor_ceil_helper(a, fn): + import sympy + + if isinstance(a, sympy.Mul): + aa = a.args + if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: + coef = sympy.Integer(aa[0]) + if aa[0] == coef: # structural equality test + return coef * aa[1] + if ( + isinstance(a, sympy.Float) + and a == sympy.Integer(a) + or isinstance(a, sympy.Integer) + ): + return sympy.Integer(a) + return fn(a) + + +def _sympy_floor(a): + import sympy + + return _floor_ceil_helper(a, sympy.floor) + + +def _sympy_ceil(a): + import sympy + + return _floor_ceil_helper(a, sympy.ceiling) + + +def _sympy_eq(a, b): + import sympy + + return sympy.Eq(a, b) + + +def _sympy_ne(a, b): + import sympy + + return sympy.Ne(a, b) + + +def _sympy_gt(a, b): + import sympy + + return sympy.Gt(a, b) + + +def _sympy_lt(a, b): + import sympy + + return sympy.Lt(a, b) + + +def _sympy_le(a, b): + import sympy + + return sympy.Le(a, b) + + +def _sympy_ge(a, b): + import sympy + + return sympy.Ge(a, b) + + +def _sympy_min(a, b): + import sympy + + return sympy.Min(a, b) + + +def _sympy_max(a, b): + import sympy + + return sympy.Max(a, b) + + +def _sympy_ite(a, t, f): + import sympy + + return sympy.Piecewise((t, a), (f, True)) + + +current_module = sys.modules[__name__] + + +def _get_sym_math_fn(name): + def fn(a): + import sympy + + return getattr(sympy, name)(a) + + return fn + + +for name in math_op_names: + priv_sympy_name = f"_sympy_{name}" + fn = _get_sym_math_fn(name) + fn.__qualname__ = fn.__name__ = priv_sympy_name + setattr(current_module, priv_sympy_name, fn) + +del fn, name, priv_sympy_name # type: ignore[possibly-undefined] + + +def _sympy_abs(a): + import sympy + + return sympy.Abs(a) + + +def _sympy_round(number, ndigits=None): + from torch.utils._sympy.functions import Round, RoundDecimal + + if ndigits is None: + return Round(number) + else: + return RoundDecimal(number, ndigits) + + +def _sympy_sym_float(a): + # Cannot use sympy.Float(a) here, coz it expects python literals + # Multiply by 1.0 to cast to float. This is needed when the input + # is a SymInt which has the assumption that it is integer and + # SymPy will otherwise assume that return value cannot be a float. 
+ return a * 1.0 + + +def _sympy_is_integer(a): + import sympy + + return sympy.Eq(sympy.floor(a), a) + + +magic_methods = { + **reflectable_magic_methods, + "sym_not": operator.invert, + "pos": operator.pos, + "eq": _sympy_eq, + "ne": _sympy_ne, + "gt": _sympy_gt, + "lt": _sympy_lt, + "le": _sympy_le, + "ge": _sympy_ge, + "floor": _sympy_floor, + "sym_float": _sympy_sym_float, + "ceil": _sympy_ceil, + "neg": operator.neg, + "sym_min": _sympy_min, + "sym_max": _sympy_max, + "sym_ite": _sympy_ite, + "abs": _sympy_abs, + "round": _sympy_round, + "is_integer": _sympy_is_integer, +} + + +for name in math_op_names: + sym_name = f"sym_{name}" + magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}") + +del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined] + + +def sympy_is_contiguous(sizes, strides): + dim = len(sizes) + return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) + + +def sympy_is_contiguous_generic(sizes, strides, dim_order): + import sympy + + dim = len(sizes) + + if len(dim_order) != dim: + return sympy.false + + is_contiguous = sympy.true + z = sympy.Integer(1) + # Contiguous if the strides make sense (or the dim is size 1) + for d in dim_order: + is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z) + z *= sizes[d] + # OR if any size is zero + for d in range(dim): + is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0)) + return is_contiguous + + +# NB: There is a TODO in C++ to allow omitting the batch dim. If that +# happens you will need to refactor this + + +def sympy_is_channels_last_contiguous_2d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_contiguous_3d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): + import sympy + + dim = len(sizes) + + if dim != len(dim_order): + return sympy.false + + m = sympy.Integer(0) + r = sympy.true + + # special case for trivial C dimension. default to NCHW + r &= sympy.Ne(strides[1], 0) + + for d in dim_order: + r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) + # Fallback to NCHW as default layout for ambiguous cases + # This is the flaw of implicit memory_format from strides. + # N111 tensor with identical strides for size 1 dimension; + # Two cases could lead us here: + # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + # b. N11W contiguous Tensor sliced on the W-dimension. + # ([N,1,1,1]@[W,W,W,W]) + if d == 0: + r &= sympy.Ne(m, strides[1]) + # This is necessary to: + # 1. distinguish the memory_format of N1H1; + # [H, 1, 1, 1] channels_last stride + # [H, H, 1, 1] contiguous stride + # 2. 
permutation of 1C1W: + # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as + # channels_last + m = strides[d] * sympy.Max(sizes[d], 1) + + return r + + +def sympy_is_channels_last_strides_2d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_strides_3d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): + from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator + + return IsNonOverlappingAndDenseIndicator(*sizes, *strides) + + +sizes_strides_methods = { + # TODO: These could also be done with indicators, maybe it is better + # for reasoning to do it that way + "is_contiguous": sympy_is_contiguous, + "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d, + "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d, + "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d, + "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d, + "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator, +} + +alternate_impl_if_hinted_methods = { + "sym_min": builtins.min, + "sym_max": builtins.max, +} + + +def to_node(self, num): + if isinstance(num, SymTypes): + return num.node + elif type(num) is bool: + return self.wrap_bool(num) + elif type(num) is int: + return self.wrap_int(num) + elif type(num) is float: + return self.wrap_float(num) + else: + # NotImplemented is important so that Python tries the + # other magic method + return NotImplemented + + +def wrap_node(x): + # TODO: let C++ also take advantage of this + if isinstance(x, SymNode) and x.constant is not None: + return x.constant + if x.is_int(): + return SymInt(x) + elif x.is_float(): + return SymFloat(x) + elif x.is_bool(): + return SymBool(x) + else: + raise AssertionError(f"unrecognized return type {x}") + + +def method_to_operator(method): + return METHOD_TO_OPERATOR[method] + + +def _make_node_magic(method, func): + func = lru_cache(256)(func) + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + + def binary_magic_impl(self, other): + from torch.fx.experimental.symbolic_shapes import safe_expand + + op = method_to_operator(method) + + out_hint = None + if self.hint is not None and other.hint is not None: + out_hint = op(self.hint, other.hint) + + alternate_impl = alternate_impl_if_hinted_methods.get(method) + if alternate_impl and out_hint is not None: + return to_node(self, alternate_impl(wrap_node(self), wrap_node(other))) + + if sym_function_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) + ) + assert isinstance(other, SymNode) + # TODO: consider constant prop here + try: + out = func(self.expr, other.expr) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) + raise + out = safe_expand(out) + sym_node_log.debug("%s %s %s -> %s", func, self.expr, other.expr, out) + pytype: Type + # This is not strictly correct. In Python, a**b may return complex when + # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This + # returns a float while both arguments are ints: 2**(-1). Also, max and + # min do not type promote. 
To avoid having data-dependent control flow + # here, we just set the type to float if one of the args is a float. In + # case of a type mismatch, we assume that it will be detected during + # evaluation. + if method in always_float_magic_methods: + pytype = float + elif method in always_bool_magic_methods: + pytype = bool + elif self.pytype is float or other.pytype is float: + pytype = float + else: + pytype = self.pytype + + if ( + pytype is not None + and out_hint is not None + and not isinstance(out_hint, SymTypes) + ): + out_hint = pytype(out_hint) + + # Create a FX node that corresponds to the operation being applied to + # this node. + fx_node, _ = self.shape_env._create_fx_call_function( + op, (self.fx_node, other.fx_node) + ) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + def unary_magic_impl(self): + from torch.fx.experimental.symbolic_shapes import safe_expand + + op = method_to_operator(method) + if sym_function_mode(): + return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) + # TODO: consider constant prop here + expr = self.expr + if method == "floor" or method == "ceiling": + expr = self.shape_env._simplify_floor_div(expr) + + try: + out = func(expr) + except Exception: + log.warning("failed to eval %s(%s)", method, expr) + raise + sym_node_log.debug("%s %s -> %s", func, expr, out) + out_hint = None + if self.hint is not None: + out_hint = op(self.hint) + out = safe_expand(out) + pytype: Type + if method in always_int_magic_methods: + pytype = int + elif method in always_bool_magic_methods: + pytype = bool + elif method in always_float_magic_methods: + pytype = float + else: + pytype = self.pytype + + fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + if method in unary_methods: + setattr(SymNode, f"_{method_attr}", unary_magic_impl) + elif method == "sym_ite": + + def sym_ite_impl(pred_node, then_node, else_node): + from torch.fx.experimental.symbolic_shapes import safe_expand + + out_hint = then_node.hint if pred_node.hint else else_node.hint + if sym_function_mode(): + return to_node( + pred_node, + handle_sym_dispatch( + sym_ite, + ( + wrap_node(pred_node), + wrap_node(then_node), + wrap_node(else_node), + ), + {}, + ), + ) + + try: + out = func(pred_node.expr, then_node.expr, else_node.expr) + except Exception: + log.warning( + "failed to eval %s(%s, %s, %s)", + method, + pred_node.expr, + then_node.expr, + else_node.expr, + ) + raise + + out = safe_expand(out) + fx_node, _ = pred_node.shape_env._create_fx_call_function( + sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) + ) + return SymNode( + out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node + ) + + setattr(SymNode, f"_{method_attr}", sym_ite_impl) + elif method == "round": + + def round_impl(self, ndigits=None): + from torch.fx.experimental.symbolic_shapes import safe_expand + + op = builtins.round + if sym_function_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {}) + ) + + expr = self.expr + try: + out = func(expr, ndigits) + except Exception: + log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) + raise + out = safe_expand(out) + + pytype = int if ndigits is None else self.pytype + + out_hint = None + if self.hint is not None: + out_hint = op(self.hint, ndigits) + + # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. 
At the + # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here + # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The + # hack down below works, because all round function down the line all take ndigits=None as default in their + # signature. + # TODO: Remove the args construction below if a different sentinel is used by FX. + args = [self.fx_node] + if ndigits is not None: + args.append(ndigits) + fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + setattr(SymNode, f"_{method_attr}", round_impl) + else: + setattr(SymNode, f"_{method_attr}", binary_magic_impl) + + +def _make_node_sizes_strides(method, func): + # NB: don't LRU cache, lots of arguments + + def sizes_strides_impl(self, sizes, strides): + op = getattr(sys.modules[__name__], method) + if sym_function_mode(): + return to_node( + self, + handle_sym_dispatch( + op, + ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), + {}, + ), + ) + size_exprs = [s.expr for s in sizes] + stride_exprs = [s.expr for s in strides] + try: + out = func(size_exprs, stride_exprs) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) + raise + # bool is never expandable + + size_hints = [] + out_hint = None + for s in sizes: + if s.hint is None: + break + size_hints.append(s.hint) + else: + stride_hints = [] + for s in strides: + if s.hint is None: + break + stride_hints.append(s.hint) + else: + out_hint = op(size_hints, stride_hints) + + # NB: This is the indicator function, not the actual bool! + pytype: Type + if method.endswith("_indicator"): + pytype = int + else: + pytype = bool + return SymNode(out, self.shape_env, pytype, out_hint) + + setattr(SymNode, f"_{method}", sizes_strides_impl) + + # TODO: This is technically hotpath, but in the ideal end state + # guards on this will resolve at a higher level so you never + # spend time in this code + def sizes_strides_user(sizes, strides): + import sympy + + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + + for a in itertools.chain(sizes, strides): + if isinstance(a, SymInt): + return wrap_node( + getattr(a.node, method)( + [to_node(a.node, b) for b in sizes], + [to_node(a.node, b) for b in strides], + ) + ) + if method == "is_non_overlapping_and_dense_indicator": + return eval_is_non_overlapping_and_dense(sizes, strides) + else: + # TODO: this is an awful implementation + return bool( + func( + [sympy.sympify(a) for a in sizes], + [sympy.sympify(a) for a in strides], + ) + ) + + # Skip for is_non_overlapping_and_dense_indicator + if not hasattr(sys.modules[__name__], method): + setattr(sys.modules[__name__], method, sizes_strides_user) + + +for method, func in magic_methods.items(): + _make_node_magic(method, func) + +for method, func in sizes_strides_methods.items(): + _make_node_sizes_strides(method, func) + + +def _make_user_magic(method, user_type): + # User magic takes care of wrapping the other operand into a node, + # so that our internal logic can assume everything is nodes + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"sym_{method}" + else: + method_attr = method + + def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): + if isinstance(x, (int, float, bool)): + return x + if isinstance(x, SymBool): + return x.node.guard_bool("", 0) + 
raise AssertionError("expect to be called with constant SymBools") + + def is_constant(x): + if isinstance(x, (int, float, bool)): + return True + if isinstance(x, (SymInt, SymFloat, SymBool)): + return x.node.is_constant() + return False + + if method in bool_becomes_int_magic_methods: + + def promote(x): + """Implements True+True=2, which works in python but not sympy""" + if isinstance(x, SymBool): + return SymInt(x.node.wrap_int(int(x))) + return x + + else: + + def promote(x): + return x + + # Before and after performing the operation, check if any operands are constant. + # If so, extract out the constant values first. If `self` itself is a + # constant, then "redispatch" by calling back into the operator. Sometimes + # this means that operations involving SymBool return plain bools. + # Alternatively, we could also rewrap into constant Symbool (i.e. by + # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that + # today for no particular reason. + def unary_magic_impl(self): + self = promote(self) + if is_constant(self): + return (method_to_operator(method))(get_constant(self)) + return wrap_node(getattr(self.node, method_attr)()) + + def binary_magic_impl(self, other): + sym_node_log.debug("MAGIC %s %s %s", method, self, other) + self = promote(self) + other = promote(other) + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(self.node, method_attr)(other_node)) + return get_constant(ret) if is_constant(ret) else ret + + def rbinary_magic_impl(self, other): + self = promote(self) + other = promote(other) + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(other_node, method_attr)(self.node)) + return get_constant(ret) if is_constant(ret) else ret + + if method in unary_magic_methods: + setattr(user_type, f"__{method}__", unary_magic_impl) + elif method in unary_nonmagic_methods: + orig = getattr(user_type, method) + setattr(user_type, method, update_wrapper(unary_magic_impl, orig)) + elif method == "sym_ite": + + def sym_ite_magic_impl(pred, then_val, else_val): + pred_node = pred.node + then_node = to_node(pred_node, then_val) + else_node = to_node(pred_node, else_val) + if then_node is NotImplemented or else_node is NotImplemented: + return NotImplemented + assert ( + isinstance(then_node, SymNode) + and isinstance(else_node, SymNode) + and then_node.pytype == else_node.pytype + ) + ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) + return get_constant(ret) if ret.node.is_constant() else ret + + setattr(user_type, f"__{method}__", sym_ite_magic_impl) + elif method == "round": + + def round_magic_impl(self, ndigits=None): + if is_constant(self): + return builtins.round(get_constant(self), ndigits) + + return wrap_node(getattr(self.node, method)(ndigits)) + + setattr(user_type, f"__{method}__", round_magic_impl) + else: + setattr(user_type, f"__{method}__", binary_magic_impl) + if method in reflectable_magic_methods: + setattr(user_type, f"__r{method}__", rbinary_magic_impl) + + +for method, func in magic_methods.items(): # type: ignore[assignment] + if method in only_bool_magic_methods: + _make_user_magic(method, 
SymBool) + continue + if method in only_float_magic_methods: + _make_user_magic(method, SymFloat) + continue + if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods: + _make_user_magic(method, SymBool) + _make_user_magic(method, SymInt) + _make_user_magic(method, SymFloat) + +del method +del func diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unify_refinements.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unify_refinements.py new file mode 100644 index 0000000000000000000000000000000000000000..532d2784fb49ae4cd798b2a0706d82b8151a08de --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unify_refinements.py @@ -0,0 +1,120 @@ +from torch.fx.experimental.graph_gradual_typechecker import Refine +from torch.fx.tensor_type import TensorType +from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined] + + +def infer_symbolic_types_single_pass(traced): + """ + Calls our symbolic inferencer once. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + +def infer_symbolic_types(traced): + """ + Calls our symbolic inferencer twice. + This is useful when one pass is not enough + to infer all the information such as the case + for braodcasting. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r.symbolic_relations() + +def convert_eq(list_of_eq): + """ + Convert equality constraints in the right format + to be used by unification library. + """ + lhs = [] + rhs = [] + for eq in list_of_eq: + lhs.append(eq.lhs) + rhs.append(eq.rhs) + return tuple(lhs), tuple(rhs) + + +def unify_eq(list_of_eq): + """ + Apply unification to a set of + equality constraints + """ + lhs, rhs = convert_eq(list_of_eq) + return unify(lhs, rhs) + + +def substitute_solution_one_type(mapping, t): + """ + Apply the most general unifier to a type + """ + if isinstance(t, Var): + if t in mapping.keys(): + return mapping[t] + else: + return t + + elif isinstance(t, TensorType): + new_type = [] + for typ in t.__args__: + if typ in mapping.keys(): + new_type.append(mapping[typ]) + else: + new_type.append(typ) + return TensorType(tuple(new_type)) + + elif isinstance(t, list): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return new_type + + elif isinstance(t, tuple): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return tuple(new_type) + + else: + return t + + +def substitute_all_types(graph, mapping): + """ + Apply the most general unifier to all types in a graph + till reaching a fixed point. If the input and output graph + are the same, we converge. + """ + flag = True + while flag: + flag = False + for k in mapping: + old_mapping_val = mapping[k] + if mapping[k] in mapping.keys(): + new_key = mapping[k] + mapping[k] = mapping[new_key] + if old_mapping_val != mapping[k]: + flag = True + + for n in graph.nodes: + n.type = substitute_solution_one_type(mapping, n.type) + +def check_for_type_equality(g1, g2): + """ + A check equality to be used in fixed points. + We do not use graph equality but instead type + equality. 
+ """ + for n, m in zip(g1.nodes, g2.nodes): + if n.type != m.type: + return False + return True diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/graph.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..590a1497d0d66db2196bf95d80412532ccf16da4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/graph.py @@ -0,0 +1,1653 @@ +from collections import defaultdict +from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name +import torch.utils._pytree as pytree +from . import _pytree as fx_pytree +from ._compatibility import compatibility + +import contextlib +from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type +from dataclasses import dataclass +from contextlib import contextmanager +import copy +import enum +import torch +import keyword +import re +import builtins +import math +import warnings +import inspect + +__all__ = ["PythonCode", "CodeGen", "Graph"] + +if TYPE_CHECKING: + from .graph_module import GraphModule # noqa: F401 + from ._symbolic_trace import Tracer # noqa: F401 + + +# Mapping of builtins to their `typing` equivalent. +_origin_type_map = { + list: List, + dict: Dict, + set: Set, + frozenset: FrozenSet, + tuple: Tuple, +} + + +# Signature for functions thattransforms the body (`list[str]`) of the +# generated code +TransformCodeFunc = Callable[[List[str]], List[str]] + + +class _CustomBuiltin(NamedTuple): + """Additional objs that we add to every graph's globals. + + The repr() for some standard library objects is not valid Python code without + an import. For common objects of this sort, we bundle them in the globals of + every FX graph. + """ + # How to import this object from the standard library. + import_str: str + # The actual object, produced from that import string. 
+ obj: Any + +_custom_builtins: Dict[str, _CustomBuiltin] = {} + + +def _register_custom_builtin(name: str, import_str: str, obj: Any): + _custom_builtins[name] = _CustomBuiltin(import_str, obj) + + +_register_custom_builtin('inf', 'from math import inf', math.inf) +_register_custom_builtin('nan', 'from math import nan', math.nan) +_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None)) +_register_custom_builtin('torch', 'import torch', torch) +_register_custom_builtin('device', 'from torch import device', torch.device) +_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree) +_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree) + + +def _is_magic(x: str) -> bool: + return x.startswith('__') and x.endswith('__') + + +def _snake_case(s: str) -> str: + """ + Transforms the given string ``s`` to a Python-style variable name + + Examples: + ``mod.snake_case`` -> ``mod.snake_case`` + ``mod.pascalCase``-> ``mod.pascal_case`` + ``mod.ALL_CAPS`` -> ``mod.all_caps`` + """ + chars = [] + prev_lower = False + for c in s: + if prev_lower and c.isupper(): + chars.append('_') + chars.append(c.lower()) + prev_lower = c.islower() + return ''.join(chars) + + +def _is_from_torch(obj: Any) -> bool: + module_name = getattr(obj, '__module__', None) + if module_name is not None: + base_module = module_name.partition('.')[0] + return ( + base_module == 'torch' and + not module_name.startswith("torch._dynamo.") and + not module_name.startswith("torch._inductor.") + ) + + name = getattr(obj, '__name__', None) + # exclude torch because torch.torch.torch.torch works. idk mang + if name is not None and name != 'torch': + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is obj: + return True + + return False + + +class _Namespace: + """A context for associating names uniquely with objects. + + The following invariants are enforced: + - Each object gets a single name. + - Each name is unique within a given namespace. + - Names generated do not shadow builtins, unless the object is indeed that builtin. + """ + def __init__(self): + self._obj_to_name: Dict[Any, str] = {} + self._unassociated_names = set() + self._used_names: Set[str] = set() + self._base_count: Dict[str, int] = defaultdict(int) + + self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+') + self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") + + def create_name(self, candidate: str, obj: Optional[Any]) -> str: + """Create a unique name. + + Arguments: + candidate: used as the basis for the unique name, relevant to the user. + obj: If not None, an object that will be associated with the unique name. 
+ """ + if obj is not None and obj in self._obj_to_name: + return self._obj_to_name[obj] + + # delete all characters that are illegal in a Python identifier + candidate = self._illegal_char_regex.sub('_', candidate) + + if not candidate: + candidate = '_unnamed' + + if candidate[0].isdigit(): + candidate = f'_{candidate}' + + match = self._name_suffix_regex.match(candidate) + if match is None: + base = candidate + num = None + else: + base, num_str = match.group(1, 2) + num = int(num_str) + + candidate = base if num is None else f'{base}_{num}' + if not num: + num = self._base_count[base] + + while candidate in self._used_names or self._is_illegal_name(candidate, obj): + num += 1 + candidate = f'{base}_{num}' + + self._used_names.add(candidate) + self._base_count[base] = num + if obj is None: + self._unassociated_names.add(candidate) + else: + self._obj_to_name[obj] = candidate + return candidate + + def associate_name_with_obj(self, name: str, obj: Any): + """Associate a unique name with an object. + + Neither `name` nor `obj` should be associated already. + """ + assert obj not in self._obj_to_name + assert name in self._unassociated_names + self._obj_to_name[obj] = name + self._unassociated_names.remove(name) + + def _is_illegal_name(self, name: str, obj: Any) -> bool: + # 1. keywords are never allowed as names. + if name in keyword.kwlist: + return True + + # 2. Can't shadow a builtin name, unless you *are* that builtin. + if name in builtins.__dict__: + return obj is not builtins.__dict__[name] + + # 3. Can't shadow our custom builtins either + if name in _custom_builtins: + return obj is not _custom_builtins[name].obj + + return False + + def _rename_object(self, obj: Any, name: str): + assert obj in self._obj_to_name + self._obj_to_name[obj] = name + self._used_names.add(name) + +dtype_abbrs = { + torch.bfloat16: 'bf16', + torch.float64: 'f64', + torch.float32: 'f32', + torch.float16: 'f16', + torch.float8_e4m3fn: 'f8e4m3fn', + torch.float8_e5m2: 'f8e5m2', + torch.float8_e4m3fnuz: 'f8e4m3fnuz', + torch.float8_e5m2fnuz: 'f8e5m2fnuz', + torch.complex32: 'c32', + torch.complex64: 'c64', + torch.complex128: 'c128', + torch.int8: 'i8', + torch.int16: 'i16', + torch.int32: 'i32', + torch.int64: 'i64', + torch.bool: 'b8', + torch.uint8: 'u8', + torch.uint32: 'u32', + torch.uint64: 'u64', +} + +@compatibility(is_backward_compatible=True) +@dataclass +class PythonCode: + """ + Represents all the information necessary to exec or save a graph as Python code. + """ + # Python source code for the forward function definition. + src: str + # Values in global scope during execution of `src_def`. + globals: Dict[str, Any] + # Optional mapping from the forward function's line number to + # node index. 
+ _lineno_map: Optional[Dict[int, Optional[int]]] + + +def _format_target(base: str, target: str) -> str: + elems = target.split('.') + r = base + for e in elems: + if not e.isidentifier(): + r = f'getattr({r}, "{e}")' + else: + r = f'{r}.{e}' + return r + +class _InsertPoint: + def __init__(self, graph, new_insert): + self.graph = graph + self.orig_insert, graph._insert = graph._insert, new_insert + + def __enter__(self): + pass + + def __exit__(self, type, value, tb): + self.graph._insert = self.orig_insert + +class _node_list: + def __init__(self, graph: 'Graph', direction: str = '_next'): + assert direction in ['_next', '_prev'] + self.graph = graph + self.direction = direction + + def __len__(self): + return self.graph._len + + def __iter__(self): + root = self.graph._root + if self.direction == "_next": + cur = root._next + while cur is not root: + if not cur._erased: + yield cur + cur = cur._next + else: + assert self.direction == "_prev" + cur = root._prev + while cur is not root: + if not cur._erased: + yield cur + cur = cur._prev + + def __reversed__(self): + return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') + +class _PyTreeInfo(NamedTuple): + """ + Contains extra info stored when we're using Pytrees + """ + orig_args: List[str] + in_spec: pytree.TreeSpec + out_spec: Optional[pytree.TreeSpec] + +@dataclass(frozen=True) +class _ParsedStackTrace: + """ + Represents the top-most frame of a parsed stack trace + """ + file: str + lineno: str + name: str + code: str + +# get File:lineno code from stack_trace +def _parse_stack_trace(stack_trace: str): + if stack_trace is None: + return None + pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") + lines = stack_trace.strip().split('\n') + # stacktrace should have innermost frame last, so we + # iterate backwards to find the first line that starts + # with 'File ' + summary_str = "" + for idx in range(len(lines) - 2, -1, -1): + line = lines[idx].strip() + matches = pattern.match(line) + if matches: + file = matches.group(1) + lineno = matches.group(2) + name = matches.group(3) + # next line should be the code + code = lines[idx + 1].strip() + return _ParsedStackTrace(file, lineno, name, code) + return None + +@compatibility(is_backward_compatible=False) +class CodeGen: + def __init__(self): + self._body_transformer: Optional[TransformCodeFunc] = None + self._func_name: str = "forward" + + def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: + """ + Given the free variables and a return annotation, generates the beginning of the FX function. + By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` + """ + # If the original function didn't have self as its first argument, we + # would have added it. + if len(free_vars) == 0 or free_vars[0] != 'self': + free_vars.insert(0, 'self') + return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + + def generate_output(self, output_args: Argument) -> str: + """ + Given the output arguments, generates the return statement of the FX function. + Note: The returned statement should not be indented. + """ + return f'return {repr(output_args)}' + + def process_inputs(self, *args: Any) -> Any: + """ + Transforms the inputs so that the graph can take them as arguments, as + non-default codegen may result in the inputs to the function being + different from the inputs to the graph. 
+ + If the graph was directly runnable, this invariant should hold true + `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` + """ + return args + + def process_outputs(self, outputs: Any) -> Any: + """ + Transforms the outputs of the graph to be identical to the codegen. + + See ``process_inputs`` for more details. + """ + return outputs + + def additional_globals(self) -> List[Tuple[str, Any]]: + """ + If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. + For example, return ['List', typing.List] if you need ``List`` in the global context. + """ + return [] + + def _gen_python_code( + self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False, + ) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation : List[str] = [''] + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + + We call this for names that reference objects external to the + Graph, like functions or types. + + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o : Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return '()' + + typename = _type_repr(o) + + if hasattr(o, '__origin__'): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, '__args__'): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f'{origin_typename}[{",".join(args)}]' + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _get_repr(arg: Any) -> str: + # Handle NamedTuples (if it has `_fields`) via add_global. 
+ if isinstance(arg, tuple) and hasattr(arg, '_fields'): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + elif isinstance(arg, torch._ops.OpOverload): + qualified_name = _get_qualified_name(arg) + global_name = add_global(qualified_name, arg) + return f"{global_name}" + elif isinstance(arg, enum.Enum): + cls = arg.__class__ + clsname = add_global(cls.__name__, cls) + return f"{clsname}.{arg.name}" + return repr(arg) + + def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: + args_s = ', '.join(_get_repr(a) for a in args) + kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + if args_s and kwargs_s: + return f'{args_s}, {kwargs_s}' + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use : Dict[Node, Node] = {} + user_to_last_uses : Dict[Node, List[Node]] = {} + + def register_last_uses(n : Node, user : Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + def delete_unused_values(user : Node): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == 'placeholder': + return + if user.op == 'output': + body.append('\n') + return + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) + body.append(f'; {to_delete_str}\n') + else: + body.append('\n') + + prev_stacktrace = None + + def append_stacktrace_summary(node : Node): + """ + Append a summary of the stacktrace to the generated code. This is + useful for debugging. 
+ """ + nonlocal prev_stacktrace + + if node.op not in {'placeholder', 'output'}: + if node.stack_trace: + if node.stack_trace != prev_stacktrace: + prev_stacktrace = node.stack_trace + summary_str = "" + + parsed_stack_trace = _parse_stack_trace(node.stack_trace) + + if parsed_stack_trace is not None: + lineno = parsed_stack_trace.lineno + code = parsed_stack_trace.code + name = parsed_stack_trace.name + summary_str = f'File: {parsed_stack_trace.file}:{lineno} in {name}, code: {code}' + + body.append(f'\n# {summary_str}\n') + elif prev_stacktrace != "": + prev_stacktrace = "" + body.append('\n# No stacktrace found for following nodes\n') + + def stringify_shape(shape : torch.Size) -> str: + return f"[{', '.join(str(x) for x in shape)}]" + + def emit_node(node : Node): + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + + if verbose: + # override annotation with more detailed information + from torch._subclasses.fake_tensor import FakeTensor + from torch.fx.experimental.proxy_tensor import py_sym_types + from torch.fx.passes.shape_prop import TensorMetadata + + meta_val = node.meta.get('val', node.meta.get('tensor_meta', None)) + + # use string as annotation, to make it valid python code + if isinstance(meta_val, FakeTensor): + maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' + elif isinstance(meta_val, py_sym_types): + maybe_type_annotation = f': "Sym({meta_val})"' + elif isinstance(meta_val, TensorMetadata): + maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' + + if node.op == 'placeholder': + assert isinstance(node.target, str) + maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') + raw_name = node.target.replace('*', '') + if raw_name != repr(node): + body.append(f'{repr(node)} = {raw_name}\n') + return + elif node.op == 'call_method': + assert isinstance(node.target, str) + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}' + f'({_format_args(node.args[1:], node.kwargs)})') + return + elif node.op == 'call_function': + assert callable(node.target) + # pretty print operators + if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods: + assert isinstance(node.args, tuple) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}') + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods: + body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; ' + f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}') + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if global_name == 'getattr' and \ + isinstance(node.args, tuple) and \ + isinstance(node.args[1], str) and \ + node.args[1].isidentifier() and \ + len(node.args) == 2: + body.append(f'{repr(node)}{maybe_type_annotation} = 
{_format_target(_get_repr(node.args[0]), node.args[1])}') + return + body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') + if node.meta.get('is_wrapped', False): + wrapped_fns.setdefault(global_name) + return + elif node.op == 'call_module': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + return + elif node.op == 'get_attr': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + return + elif node.op == 'output': + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f'node: {node.op} {node.target}') + + for i, node in enumerate(nodes): + # NOTE: emit_node does not emit a string with newline. It depends + # on delete_unused_values to append one + if verbose: + append_stacktrace_summary(node) + # emit a counter comment to keep track of + # node index, which will be deleted later + # after going through _body_transformer + body.append(f"# COUNTER: {i}\n") + emit_node(node) + delete_unused_values(node) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append('pass\n') + + + + if len(wrapped_fns) > 0: + wrap_name = add_global('wrap', torch.fx.wrap) + wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = '' + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + + # remove counter and generate lineno to node index mapping + lineno_map: Dict[int, Optional[int]] = {} + prologue_len = prologue.count('\n') + 1 + new_lines: List[str] = [] + cur_idx = None + for line in ''.join(body).split('\n'): + counter = re.search(r"# COUNTER: (\d+)", line) + if counter and counter.group(1) is not None: + cur_idx = int(counter.group(1)) + else: + lineno_map[len(new_lines) + prologue_len] = cur_idx + new_lines.append(line) + + code = "\n".join(new_lines).lstrip('\n') + code = '\n'.join(' ' + line for line in code.split('\n')) + + fn_code = f""" +{wrap_stmts} + +{prologue} +{code}""" + return PythonCode(fn_code, globals_, _lineno_map=lineno_map) + + +# Ideally, we'd like to refactor all of the pytree logic into this codegen +# class. Unfortunately, there are 3 areas we currently need extra logic in FX. +# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`. +# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec. +# Since we can't access .graph within the FX forward, we need to copy the attribute to the module. +# 3. We currently can't register the pytree imports with `add_global` - not sure why. 
+class _PyTreeCodeGen(CodeGen): + def __init__(self, pytree_info: _PyTreeInfo): + super().__init__() + self.pytree_info: _PyTreeInfo = pytree_info + + def process_inputs(self, *inputs: Any) -> Any: + flat_args = pytree.arg_tree_leaves(*inputs) + return flat_args + + def process_outputs(self, out: Any) -> Any: + if self.pytree_info is None or self.pytree_info.out_spec is None: + return out + if not isinstance(out, (list, tuple)): + out = [out] + assert self.pytree_info.out_spec is not None + return pytree.tree_unflatten(out, self.pytree_info.out_spec) + + def gen_fn_def(self, free_vars, maybe_return_annotation): + # Given a user function/model: + # myargs = [myargs0, myargs1] + # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...} + # def forward(self, mypos, *myargs, mykey=None, **mykwargs): + # + # The generated code flattens all keywords into positional arguments for `forward()` + # e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1): + # + # Within `forward`, `tree_flatten_spec``still parses args and kwargs separately + # e.g. tree_flatten_spec(([mypos, myargs0, myargs1], + # {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}), + # self._in_spec) + # + # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec + # e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec) + if self.pytree_info is None: + return super().gen_fn_def(free_vars, maybe_return_annotation) + + fn_args = self.pytree_info.orig_args + has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False + if has_orig_self: + free_vars.insert(0, 'self') + fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation) + + if len(free_vars) > 0: # pytree has placeholders in it + # when kwargs is present, in_spec is tuple(args, kwargs) + has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \ + self.pytree_info.in_spec.num_children == 2 and \ + self.pytree_info.in_spec.children_specs[0].type == tuple and \ + self.pytree_info.in_spec.children_specs[1].type == dict + fn_kwargs = '{}' + fn_signature = f"[{', '.join(fn_args)}], self._in_spec" + if has_args_kwargs_tuple: + count_args = self.pytree_info.in_spec.children_specs[0].num_children + fn_args = self.pytree_info.orig_args[:count_args] + fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip( + self.pytree_info.in_spec.children_specs[1].context, + self.pytree_info.orig_args[count_args:])) + '}' + fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" + + # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. + # we need to split it to two lines: + # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon) + # one for code: `var1, var2, = function_call()` + without_annotation = [x.split(":")[0] for x in free_vars] + has_annotation = [x + "; " for x in free_vars if ":" in x] + if len(has_annotation) > 0: + fn_definition += "\n " + "".join(has_annotation) + "\n" + fn_definition += f""" + {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" + return fn_definition + + def generate_output(self, output_args): + if self.pytree_info and self.pytree_info.out_spec: + return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)' + else: + return super().generate_output(output_args) + +@compatibility(is_backward_compatible=True) +class Graph: + """ + ``Graph`` is the main data structure used in the FX Intermediate Representation. 
+ It consists of a series of ``Node`` s, each representing callsites (or other + syntactic constructs). The list of ``Node`` s, taken together, constitute a + valid Python function. + + For example, the following code + + .. code-block:: python + + import torch + import torch.fx + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + + Will produce the following Graph:: + + print(gm.graph) + + .. code-block:: text + + graph(x): + %linear_weight : [num_users=1] = self.linear.weight + %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) + %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) + return topk_1 + + For the semantics of operations represented in the ``Graph``, please see :class:`Node`. + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, + tracer_extras: Optional[Dict[str, Any]] = None): + """ + Construct an empty Graph. + """ + self._root : Node = Node(self, '', 'root', '', (), {}) + self._used_names : Dict[str, int] = {} # base name -> number + self._insert = self._root.prepend + self._len = 0 + self._graph_namespace = _Namespace() + self._owning_module = owning_module + self._tracer_cls = tracer_cls + self._tracer_extras = tracer_extras + self._codegen = CodeGen() + self._co_fields : Dict[str, Any] = {} + + @property + def owning_module(self): + return self._owning_module + + @owning_module.setter + def owning_module(self, mod: Optional["GraphModule"]): + self._owning_module = mod + + @property + def nodes(self) -> _node_list: + """ + Get the list of Nodes that constitute this Graph. + + Note that this ``Node`` list representation is a doubly-linked list. Mutations + during iteration (e.g. delete a Node, add a Node) are safe. + + Returns: + + A doubly-linked list of Nodes. Note that ``reversed`` can be called on + this list to switch iteration order. + """ + return _node_list(self) + + @compatibility(is_backward_compatible=True) + def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': + """ + Copy all nodes from a given graph into ``self``. + + Args: + + g (Graph): The source graph from which to copy Nodes. + + val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping + from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed + in with values in it already to override copying of certain values. + + Returns: + + The value in ``self`` that is now equivalent to the output value in ``g``, + if ``g`` had an ``output`` node. ``None`` otherwise. 
+ """ + for node in g.nodes: + if node in val_map: + continue + if node.op == 'output': + rv = map_arg(node.args[0], lambda n: val_map[n]) + return rv if not return_output_node else (rv, node) + val_map[node] = self.node_copy(node, lambda n : val_map[n]) + return None + + def __deepcopy__(self, memo=None) -> 'Graph': + """ + Explicitly implement __deepcopy__ to prevent excessive recursion depth + from the default implementation. This uses graph_copy to copy the nodes + in an iterative way, rather than recursive. It also populates the + memoization table to prevent unnecessary copies (e.g. references to + nodes or other parts of the Graph from a custom GraphModule implementation. + """ + memo = memo if memo else {} + g = Graph(tracer_cls=self._tracer_cls) + output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) + g._codegen = copy.deepcopy(self._codegen) + assert isinstance(output_vals, tuple) + output_val, old_output_node = output_vals + new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None)) + new_output_node.meta = copy.copy(old_output_node.meta) + return g + + @compatibility(is_backward_compatible=True) + def create_node(self, op: str, target: 'Target', + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Create a ``Node`` and add it to the ``Graph`` at the current insert-point. + Note that the current insert-point can be set via :meth:`Graph.inserting_before` + and :meth:`Graph.inserting_after`. + + Args: + op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', + 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are + described in the ``Graph`` docstring. + + args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. + + kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node + + name (Optional[str]): an optional string name for the ``Node``. + This will influence the name of the value assigned to in the + Python generated code. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted node. + """ + assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') + args = () if args is None else args + kwargs = {} if kwargs is None else kwargs + assert isinstance(args, tuple), "args must be a tuple" + assert isinstance(kwargs, dict), "kwargs must be a dict" + + candidate = name if name is not None else self._target_to_str(target) + name = self._graph_namespace.create_name(candidate, None) + n = Node(self, name, op, target, args, kwargs, type_expr) + + self._graph_namespace.associate_name_with_obj(name, n) + + self._insert(n) + self._len += 1 + return n + + @compatibility(is_backward_compatible=False) + def process_inputs(self, *args): + """ + Processes args so that they can be passed to the FX graph. + """ + return self._codegen.process_inputs(*args) + + @compatibility(is_backward_compatible=False) + def process_outputs(self, out): + return self._codegen.process_outputs(out) + + + @compatibility(is_backward_compatible=True) + def erase_node(self, to_erase : Node) -> None: + """ + Erases a ``Node`` from the ``Graph``. Throws an exception if + there are still users of that node in the ``Graph``. + + Args: + + to_erase (Node): The ``Node`` to erase from the ``Graph``. 
+ """ + if len(to_erase.users) > 0: + raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' + f'users in the graph: {to_erase.users}!') + if to_erase.graph != self: + raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") + if to_erase._erased: + warnings.warn(f"erase_node({to_erase}) on an already erased node") + return + + to_erase._remove_from_list() + to_erase._erased = True # iterators may retain handles to erased nodes + self._len -= 1 + + # Null out this Node's argument nodes so that the Nodes referred to + # can update their ``users`` accordingly + new_args = map_arg(to_erase.args, lambda n: None) + assert isinstance(new_args, tuple) + to_erase.args = new_args + new_kwargs = map_arg(to_erase.kwargs, lambda n: None) + assert isinstance(new_kwargs, dict) + to_erase.kwargs = new_kwargs + + @compatibility(is_backward_compatible=True) + def inserting_before(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_before(n): + ... # inserting before node n + ... # insert point restored to what it was previously + g.inserting_before(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert before + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_after(self._root) + assert n.graph == self, "Node to insert before is not in graph." + return _InsertPoint(self, n.prepend) + + @compatibility(is_backward_compatible=True) + def inserting_after(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_after(n): + ... # inserting after node n + ... # insert point restored to what it was previously + g.inserting_after(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert after + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_before(self._root) + assert n.graph == self, "Node to insert after is not in graph." + return _InsertPoint(self, n.append) + + @compatibility(is_backward_compatible=True) + def placeholder(self, name: str, type_expr: Optional[Any] = None, + default_value : Any = inspect.Signature.empty) -> Node: + """ + Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents + a function input. + + Args: + + name (str): A name for the input value. This corresponds to the name + of the positional argument to the function this ``Graph`` represents. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. This is needed in some + cases for proper code generation (e.g. when the function is used + subsequently in TorchScript compilation). + + default_value (Any): The default value this function argument should take + on. 
NOTE: to allow for `None` as a default value, `inspect.Signature.empty` + should be passed as this argument to specify that the parameter does _not_ + have a default value. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + args = () if default_value is inspect.Signature.empty else (default_value,) + return self.create_node('placeholder', name, args=args, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the + fetch of an attribute from the ``Module`` hierarchy. + + Args: + + qualified_name (str): the fully-qualified name of the attribute to be retrieved. + For example, if the traced Module has a submodule named ``foo``, which has a + submodule named ``bar``, which has an attribute named ``baz``, the qualified + name ``foo.bar.baz`` should be passed as ``qualified_name``. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + + Returns: + + The newly-created and inserted ``get_attr`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool: + module_path, _, name = qualified_name.rpartition(".") + + try: + submod: torch.nn.Module = mod.get_submodule(module_path) + except AttributeError: + warnings.warn(f"Failed to fetch module {module_path}!") + return False + + if not hasattr(submod, name): + return False + + res = getattr(submod, name) + + if (not isinstance(res, torch.nn.Module) + and not isinstance(res, torch.nn.Parameter) + and name not in submod._buffers): + return False + + return True + + if (self.owning_module and + not _get_attr_reference_exists(self.owning_module, qualified_name)): + warnings.warn("Attempted to insert a get_attr Node with no " + "underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule, " + "GraphModule.add_parameter to add the " + "necessary Parameter, or " + "nn.Module.register_buffer to add the " + "necessary buffer", stacklevel=2) + return self.create_node('get_attr', qualified_name, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_module(self, + module_name: str, + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node + represents a call to the forward() function of a ``Module`` in the ``Module`` + hierarchy. + + Args: + + module_name (str): The qualified name of the ``Module`` in the ``Module`` + hierarchy to be called. For example, if the traced ``Module`` has a + submodule named ``foo``, which has a submodule named ``bar``, the + qualified name ``foo.bar`` should be passed as ``module_name`` to + call that module. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this should *not* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. 
+ + Returns: + + The newly-created and inserted ``call_module`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + if (self.owning_module and + self.owning_module.get_submodule(module_name) is None): + warnings.warn("Attempted to insert a call_module Node with " + "no underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule") + return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_method(self, + method_name: str, + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node + represents a call to a given method on the 0th element of ``args``. + + Args: + + method_name (str): The name of the method to apply to the self argument. + For example, if args[0] is a ``Node`` representing a ``Tensor``, + then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this *should* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_method`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_function(self, + the_function: Callable[..., Any], + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node + represents a call to a Python callable, specified by ``the_function``. + + Args: + + the_function (Callable[..., Any]): The function to be called. Can be any PyTorch + operator, Python function, or member of the ``builtins`` or ``operator`` + namespaces. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called function. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called function + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_function`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: + """ + Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from + the graph of node to the graph of self. Example:: + + # Copying all the nodes in `g` into `new_graph` + g : torch.fx.Graph = ... 
+ new_graph = torch.fx.graph() + value_remap = {} + for node in g.nodes: + value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) + + Args: + + node (Node): The node to copy into ``self``. + + arg_transform (Callable[[Node], Argument]): A function that transforms + ``Node`` arguments in node's ``args`` and ``kwargs`` into the + equivalent argument in ``self``. In the simplest case, this should + retrieve a value out of a table mapping Nodes in the original + graph to ``self``. + """ + args = map_arg(node.args, arg_transform) + kwargs = map_arg(node.kwargs, arg_transform) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type) + result_node.meta = copy.copy(node.meta) + return result_node + + @compatibility(is_backward_compatible=True) + def output(self, result: 'Argument', type_expr: Optional[Any] = None): + """ + Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents + a ``return`` statement in Python code. ``result`` is the value that should + be returned. + + Args: + + result (Argument): The value to be returned. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + .. note:: + + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) + + def _target_to_str(self, target : Target) -> str: + if callable(target): + op = target.__name__ + else: + assert isinstance(target, str) + op = target + if _is_magic(op): + op = op[2:-2] + op = _snake_case(op) + return op + + @compatibility(is_backward_compatible=True) + def python_code(self, root_module: str, *, verbose: bool = False) -> PythonCode: + """ + Turn this ``Graph`` into valid Python code. + + Args: + + root_module (str): The name of the root module on which to look-up + qualified name targets. This is usually 'self'. + + Returns: + + A PythonCode object, consisting of two fields: + src: the Python source code representing the object + globals: a dictionary of global names in `src` -> the objects that they reference. + """ + # NOTE: [Graph Namespaces] + # + # There are two types of symbols in generated Python source code: + # locals and globals. + # Locals are locally defined by the output of a node in the Graph. + # Globals are references to external objects, like functions or types. + # + # When generating Python code, we need to make sure to name things + # appropriately. In particular: + # - All names should be unique, to avoid weird shadowing bugs. + # - These names need to be consistent, e.g. a object should always be + # referenced by the same name. + # + # To do this, we create a new namespace just for this source. All names + # that get printed must come from this namespace. + # + # Why can't we re-use node.name? Because it was generated within the + # namespace `self._graph_namespace`. In order to provide uniqueness + # over both locals (node.name) *and* globals, we create a completely + # new namespace to put all identifiers in. + namespace = _Namespace() + + # Override Node's repr to generate a valid name within our namespace. + # Since repr() is designed to produce a valid Python expression, it + # makes sense to re-use it. This way, it's easy to print something like + # Tuple[Node, Node] by simply calling repr() on it. 
Node's __repr__ is + # implemented cooperatively to allow this. + def node_repr(n: Node): + return namespace.create_name(n.name, n) + + @contextmanager + def override_node_repr(graph: Graph): + orig_repr_fns = {} + for node in graph.nodes: + orig_repr_fns[node] = node._repr_fn + node._repr_fn = node_repr + try: + yield None + finally: + # restore the original repr functions + for node in graph.nodes: + node._repr_fn = orig_repr_fns[node] + + with override_node_repr(self): + return self._python_code(root_module, namespace, verbose=verbose) + + def _python_code(self, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode: + return self._codegen._gen_python_code(self.nodes, root_module, namespace, verbose=verbose) + + + def __str__(self) -> str: + """ + Return a human-readable (not machine-readable) string representation + of this Graph + """ + placeholder_names : List[str] = [] + # This is a one-element array just so ``format_node`` can modify the closed + # over value + maybe_return_typename : List[str] = [''] + + node_strs = [node.format_node(placeholder_names) for node in self.nodes] + param_str = ', '.join(placeholder_names) + s = f'graph({param_str}){maybe_return_typename[0]}:' + for node_str in node_strs: + if node_str: + s += '\n ' + node_str + return s + + @compatibility(is_backward_compatible=True) + def print_tabular(self): + """ + Prints the intermediate representation of the graph in tabular + format. Note that this API requires the ``tabulate`` module to be + installed. + """ + try: + from tabulate import tabulate + except ImportError: + print("`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + raise + + node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] + for n in self.nodes] + print(tabulate(node_specs, + headers=['opcode', 'name', 'target', 'args', 'kwargs'])) + + @compatibility(is_backward_compatible=True) + def lint(self): + """ + Runs various checks on this Graph to make sure it is well-formed. In + particular: + - Checks Nodes have correct ownership (owned by this graph) + - Checks Nodes appear in topological order + - If this Graph has an owning GraphModule, checks that targets + exist in that GraphModule + """ + + # Check topo order + def check_arg(arg : Node, n : Optional[Node] = None) -> None: + context_str = f' of Node \'{n}\' ' if n else ' ' + if arg.graph is not self: + raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' + f'but was used as an argument! If you are copying nodes from another graph, make ' + f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') + if arg not in seen_values: + raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' + f'defined! 
Please check that Nodes in the graph are topologically ordered\n{self}') + + seen_names : Set[str] = set() + seen_values : Set[Node] = set() + for node in self.nodes: + if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: + raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') + if node.graph is not self: + raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') + map_arg(node.args, lambda arg: check_arg(arg, node)) + map_arg(node.kwargs, lambda arg: check_arg(arg, node)) + seen_values.add(node) + + if node.name in seen_names: + raise RuntimeError(f'Node redefined name {node.name}!') + seen_names.add(node.name) + + # Check targets are legit + if self.owning_module: + for node in self.nodes: + if node.op == 'call_function': + if not callable(node.target): + raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' + 'a Callable is expected') + else: + if not isinstance(node.target, str): + raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' + 'a str is expected') + if node.op in ['get_attr', 'call_module']: + target_atoms = node.target.split('.') + m_itr = self.owning_module + for i, atom in enumerate(target_atoms): + new_m_itr = getattr(m_itr, atom, None) + seen_qualname = '.'.join(target_atoms[:i]) + if new_m_itr is None: + raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' + f'{atom} of {seen_qualname}') + if (node.op == "call_module" + and not isinstance(new_m_itr, torch.nn.Module)): + raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' + 'not reference an nn.Module') + elif (node.op == "get_attr" + and not isinstance(new_m_itr, torch.nn.Module) + and not isinstance(new_m_itr, torch.nn.Parameter) + and atom not in m_itr._buffers): + warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' + 'not reference an nn.Module, nn.Parameter, or buffer, which is ' + 'what \'get_attr\' Nodes typically target') + else: + m_itr = new_m_itr + + @compatibility(is_backward_compatible=True) + def eliminate_dead_code(self): + """ + Remove all dead code from the graph, based on each node's number of + users, and whether the nodes have any side effects. The graph must be + topologically sorted before calling. + + Returns: + bool: Whether the graph was changed as a result of the pass. + + Example: + + Before dead code is eliminated, `a` from `a = x + 1` below has no users + and thus can be eliminated from the graph without having an effect. + + .. code-block:: python + + def forward(self, x): + a = x + 1 + return x + self.attr_1 + + After dead code is eliminated, `a = x + 1` has been removed, and the rest + of `forward` remains. + + .. code-block:: python + + def forward(self, x): + return x + self.attr_1 + + .. warning:: + + Dead code elimination has some heuristics to avoid removing + side-effectful nodes (see Node.is_impure) but in general coverage + is very bad, so you should assume that this method is not sound + to call unless you know that your FX graph consists entirely + of functional operations. + """ + # Lint the graph first to make sure its topologically sorted, otherwise + # DCE below will not behave as expected. + self.lint() + + # Reverse iterate so that when we remove a node, any nodes used as an + # input to that node have an updated user count that no longer reflects + # the removed node. 
+ changed = False + for node in reversed(self.nodes): + if not node.is_impure() and len(node.users) == 0: + self.erase_node(node) + changed = True + + return changed + + @compatibility(is_backward_compatible=False) + def set_codegen(self, codegen: CodeGen): + self._codegen = codegen + + @compatibility(is_backward_compatible=False) + def on_generate_code( + self, + make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc] + ): + """Register a transformer function when python code is generated + + Args: + make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): + a function that returns a code transformer to be registered. + This function is called by `on_generate_code` to obtain the + code transformer. + + This function is also given as its input the currently + registered code transformer (or None if nothing is registered), + in case it is not desirable to overwrite it. This is useful to + chain code transformers together. + + Returns: + a context manager that when used in a `with` statement, to automatically + restore the previously registered code transformer. + + Example: + + .. code-block:: python + + + gm: fx.GraphModule = ... + + # This is a code transformer we want to register. This code + # transformer prepends a pdb import and trace statement at the very + # beginning of the generated torch.fx code to allow for manual + # debugging with the PDB library. + def insert_pdb(body): + return ["import pdb; pdb.set_trace()\\n", *body] + + # Registers `insert_pdb`, and overwrites the current registered + # code transformer (given by `_` to the lambda): + gm.graph.on_generate_code( + lambda _: insert_pdb + ) + + # Or alternatively, registers a code transformer which first + # runs `body` through existing registered transformer, then + # through `insert_pdb`: + gm.graph.on_generate_code( + lambda current_trans: ( + lambda body: insert_pdb( + current_trans(body) if current_trans + else body + ) + ) + ) + + gm.recompile() + gm(*inputs) # drops into pdb + + + This function can also be used as a context manager, with the benefit to + automatically restores the previously registered code transformer: + + .. code-block:: python + + # ... continue from previous example + + with gm.graph.on_generate_code(lambda _: insert_pdb): + # do more stuff with `gm`... + gm.recompile() + gm(*inputs) # drops into pdb + + # now previous code transformer is restored (but `gm`'s code with pdb + # remains - that means you can run `gm` with pdb here too, until you + # run next `recompile()`). 
+ """ + on_gen_code_old = self._codegen._body_transformer + self._codegen._body_transformer = make_transformer(on_gen_code_old) + + @contextlib.contextmanager + def on_generate_code_context_manager(): + try: + yield + finally: + self._codegen._body_transformer = on_gen_code_old + + return on_generate_code_context_manager() + + +reflectable_magic_methods = { + 'add': '{} + {}', + 'sub': '{} - {}', + 'mul': '{} * {}', + 'floordiv': '{} // {}', + 'truediv': '{} / {}', + 'div': '{} / {}', + 'mod': '{} % {}', + 'pow': '{} ** {}', + 'lshift': '{} << {}', + 'rshift': '{} >> {}', + 'and_': '{} & {}', + 'or_': '{} | {}', + 'xor': '{} ^ {}', + 'getitem': '{}[{}]', + 'matmul': '{} @ {}', +} + +magic_methods = dict({ + 'eq': '{} == {}', + 'ne': '{} != {}', + 'lt': '{} < {}', + 'gt': '{} > {}', + 'le': '{} <= {}', + 'ge': '{} >= {}', + 'pos': '+{}', + 'neg': '-{}', + 'invert': '~{}'}, **reflectable_magic_methods) + +inplace_methods = { + 'iadd': '{} += {}', + 'iand': '{} &= {}', + 'ifloordiv': '{} //= {}', + 'ilshift': '{} <<= {}', + 'imod': '{} %= {}', + 'imul': '{} *= {}', + 'imatmul': '{} @= {}', + 'ior': '{} |= {}', + 'ipow': '{} **= {}', + 'irshift': '{} >>= {}', + 'isub': '{} -= {}', + 'itruediv': '{} /= {}', + 'ixor': '{} ^= {}', + 'setitem': '{}[{}] = {}', +} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/graph_module.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..8e806c61c5e471b82ad73b63657aa4a4a0cf9dd5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/graph_module.py @@ -0,0 +1,884 @@ +import contextlib +import copy +import itertools +import linecache +import os +import sys +import traceback +import warnings +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Type, Union + +import torch +import torch.nn as nn +import torch.overrides +from torch.nn.modules.module import _addindent +from torch.package import Importer, PackageExporter, PackageImporter, sys_importer + +from ._compatibility import compatibility +from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode + +__all__ = [ + "reduce_graph_module", + "reduce_package_graph_module", + "reduce_deploy_graph_module", + "GraphModule", +] + +_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes" + +# Normal exec loses the source code, however we can work with +# the linecache module to recover it. +# Using _exec_with_source will add it to our local cache +# and then tools like TorchScript will be able to get source info. +class _EvalCacheLoader: + def __init__(self): + self.eval_cache = {} + self.next_id = 0 + + def cache(self, src: str, globals: Dict[str, Any], co_fields=None): + """Store the source in a private cache, and add a lazy entry in linecache + that allows the source to be retrieved by 'filename'. + + Args: + src (str): The module source to cache + globals (dict): The module globals + + Returns: + str: The cache key (and dummy filename) generated for src. 
+ """ + + key = self._get_key() + if co_fields: + key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" + self.eval_cache[key] = src + + # Don't mutate globals so that this loader is only used + # to populate linecache, and doesn't interact with other modules + # that might check `__loader__` + globals_copy = globals.copy() + globals_copy["__file__"] = key + globals_copy["__name__"] = key + globals_copy["__loader__"] = self + linecache.lazycache(key, globals_copy) + + return key + + # Part of the loader protocol (PEP 302) + # linecache will use this method when trying to find source code + def get_source(self, module_name) -> Optional[str]: + if module_name in self.eval_cache: + return self.eval_cache[module_name] + return None + + def _get_key(self): + key = f".{self.next_id}" + self.next_id += 1 + return key + + +_loader = _EvalCacheLoader() + + +def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None): + key = _loader.cache(src, globals, co_fields) + exec(compile(src, key, "exec"), globals) + + +def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None): + return _method_from_src( + method_name="forward", src=src, globals=globals, co_fields=co_fields + ) + + +def _method_from_src( + method_name: str, src: str, globals: Dict[str, Any], co_fields=None +) -> Callable: + # avoid mutating the passed in dict + globals_copy = globals.copy() + _exec_with_source(src, globals_copy, co_fields) + fn = globals_copy[method_name] + del globals_copy[method_name] + return fn + + +def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: + if name in _custom_builtins: + return _custom_builtins[name].import_str + if _is_from_torch(name): + return "import torch" + module_name, attr_name = importer.get_name(obj) + return f"from {module_name} import {attr_name} as {name}" + + +def _format_import_block(globals: Dict[str, Any], importer: Importer): + import_strs: Set[str] = set() + for name, obj in globals.items(): + import_strs.add(_format_import_statement(name, obj, importer)) + # Sort the imports so we have a stable import block that allows us to + # hash the graph module and get a consistent key for use in a cache. + return "\n".join(sorted(import_strs)) + + +@compatibility(is_backward_compatible=True) +def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: + # BC: attribute name was changed from `code` to `_code` to facilitate + # making `code` into a property and adding a docstring to it + fn_src = body.get("_code") or body["code"] + forward = _forward_from_src(import_block + fn_src, {}) + return _deserialize_graph_module(forward, body) + + +@compatibility(is_backward_compatible=True) +def reduce_package_graph_module( + importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str +) -> torch.nn.Module: + forward = importer.import_module(generated_module_name).forward + return _deserialize_graph_module(forward, body) + + +@compatibility(is_backward_compatible=True) +def reduce_deploy_graph_module( + importer: PackageImporter, body: Dict[Any, Any], import_block: str +) -> torch.nn.Module: + ns = {} + ns["__builtins__"] = importer.patched_builtins + fn_src = body.get("_code") + assert fn_src is not None + forward = _forward_from_src(import_block + fn_src, ns) + return _deserialize_graph_module(forward, body) + + +# We create a dummy class here because symbolic_trace pulls the forward() +# function off of the class, rather than the instance. 
This class is used +# in _deserialize_graph_module() below. +class _CodeOnlyModule(torch.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__ = body + + +def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module: + """ + Deserialize a GraphModule given the dictionary of the original module, + using the code to reconstruct the graph. We delete the actual graph before + saving the dictionary so that changes to the in-memory graph format do not + get serialized. + """ + + # Try to retrieve the forward source in a backward-compatible way + _CodeOnlyModule.forward = forward + + tracer_cls = body.get("_tracer_cls") + if tracer_cls is None: + from ._symbolic_trace import Tracer + + tracer_cls = Tracer + + graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule") + + # This is a workaround for a mypy linter issue related to + # passing base class as an argument - https://github.com/python/mypy/issues/5865. + cls_tracer: Any = tracer_cls + + class KeepModules(cls_tracer): + # we shouldn't trace into any of the submodules, + # because they were not traced in the original GraphModule + def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: + return True + + com = _CodeOnlyModule(body) + + tracer_extras = body.get("_tracer_extras", {}) + graph = KeepModules().trace(com, **tracer_extras) + + # Manually set Tracer class on the reconstructed Graph, to avoid + # referencing the private local subclass KeepModules. + graph._tracer_cls = tracer_cls + from ._lazy_graph_module import _make_graph_module + gm = _make_graph_module(com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls) + + # The GraphModule constructor only retains attributes referenced by the graph. + # In this case, our goal is return a GraphModule as close to identical as the one + # put into the package. If any additional attributes were present in body, + # we should keep them. + for k, v in body.items(): + if not hasattr(gm, k): + setattr(gm, k, v) + return gm + + +# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' +# This installs empty Modules where none exist yet if they are subpaths of target +def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + f = getattr(from_module, item) + t = getattr(to_module, item, None) + if f is t: + # we have already installed one of its parents + # (e.g. target = root.linear.weight, but we have already installed root.linear) + # once we install a parent, we no longer need to copy the children + # since all the needed properties will already be present + return + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + from_module, to_module = f, t + + orig = getattr(from_module, field) + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. 
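The comment above is the rationale for the ``isinstance`` branch that follows in both ``_copy_attr`` and ``_assign_attr``: a plain tensor only travels with ``state_dict()`` if it is registered as a buffer, whereas non-tensor values can remain ordinary attributes. A small illustration, assuming a bare ``nn.Module`` and made-up attribute names::

    import torch

    m = torch.nn.Module()
    m.register_buffer("running_stat", torch.zeros(3))  # tensor -> named buffer
    m.some_flag = True                                  # plain attribute

    print("running_stat" in m.state_dict())  # True: buffers are serialized
    print("some_flag" in m.state_dict())     # False: ordinary attributes are not
    print(list(dict(m.named_buffers())))     # ['running_stat']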
+ if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): + to_module.register_buffer(field, orig) + else: + setattr(to_module, field, orig) + + +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module +# This installs empty Modules where none exist yet if they are subpaths of target +def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + t = getattr(to_module, item, None) + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + to_module = t + + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(from_obj, torch.Tensor) and not isinstance( + from_obj, torch.nn.Parameter + ): + to_module.register_buffer(field, from_obj) + else: + setattr(to_module, field, from_obj) + + +class _WrappedCall: + def __init__(self, cls, cls_call): + self.cls = cls + self.cls_call = cls_call + + # Previously, if an error occurred when valid + # symbolically-traced code was run with an invalid input, the + # user would see the source of the error as coming from + # `File "`, where N is some number. We use + # this function to generate a more informative error message. We + # return the traceback itself, a message explaining that the + # error occurred in a traced Module's generated forward + # function, and five lines of context surrounding the faulty + # line + @staticmethod + def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: + # auxiliary variables (for readability) + err_lineno = frame_summary.lineno + assert err_lineno is not None + line = frame_summary.line + assert line is not None + err_line_len = len(line) + all_src_lines = linecache.getlines(frame_summary.filename) + + # constituent substrings of the error message + tb_repr = traceback.format_exc() + custom_msg = ( + "Call using an FX-traced Module, " + f"line {err_lineno} of the traced Module's " + "generated forward function:" + ) + before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) + marker = "~" * err_line_len + "~~~ <--- HERE" + err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) + + # joined message + return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) + + def __call__(self, obj, *args, **kwargs): + try: + if self.cls_call is not None: + return self.cls_call(obj, *args, **kwargs) + else: + return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] + except Exception as e: + assert e.__traceback__ + topmost_framesummary: traceback.FrameSummary = ( + traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] + ) # type: ignore[arg-type] + if "eval_with_key" in topmost_framesummary.filename: + print( + _WrappedCall._generate_error_message(topmost_framesummary), + file=sys.stderr, + ) + raise e.with_traceback(None) # noqa: TRY200 + else: + raise e + +@compatibility(is_backward_compatible=True) +class GraphModule(torch.nn.Module): + """ + GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a + ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated + from that ``graph``. + + .. warning:: + + When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically + regenerated. 
However, if you edit the contents of the ``graph`` without reassigning + the ``graph`` attribute itself, you must call ``recompile()`` to update the generated + code. + """ + + def __new__(cls: "Type[GraphModule]", *args, **kwargs): + # each instance of a graph module needs its own forward method + # so create a new singleton class for each instance. + # it is a subclass of the user-defined class, the only difference + # is an extra layer to install the forward method + + # address issue described at https://github.com/pytorch/pytorch/issues/63883 + # in other words, traverse class hierarchy to fix the redundant class definition problem + for t in cls.__mro__: + c = t.__qualname__.split(".")[-1] + if c != "GraphModuleImpl": + cls = t + break + + class GraphModuleImpl(cls): # type: ignore[misc, valid-type] + pass + + return super().__new__(GraphModuleImpl) + + @compatibility(is_backward_compatible=True) + def __init__( + self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = "GraphModule", + ): + """ + Construct a GraphModule. + + Args: + + root (Union[torch.nn.Module, Dict[str, Any]): + ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. + In the case that ``root`` is a Module, any references to Module-based objects (via qualified + name) in the Graph's Nodes' ``target`` field will be copied over from the respective place + within ``root``'s Module hierarchy into the GraphModule's module hierarchy. + In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be + looked up directly in the dict's keys. The object mapped to by the Dict will be copied + over into the appropriate place within the GraphModule's module hierarchy. + + graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation + + class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all + error messages will report as originating from ``GraphModule``. It may be helpful to set this + to ``root``'s original name or a name that makes sense within the context of your transform. + """ + super().__init__() + self.__class__.__name__ = class_name + if isinstance(root, torch.nn.Module): + if hasattr(root, "training"): + self.training = root.training + + # When we pickle/unpickle graph module, we don't want to drop any module or attributes. + if isinstance(root, _CodeOnlyModule): + for k, _ in root.named_children(): + _copy_attr(root, self, k) + + for k, _ in root.named_buffers(): + _copy_attr(root, self, k) + + for k, _ in root.named_parameters(): + _copy_attr(root, self, k) + + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + _copy_attr(root, self, node.target) + elif isinstance(root, dict): + targets_to_copy = [] + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + if node.target not in root: + raise RuntimeError( + "Node " + + str(node) + + " referenced target " + + node.target + + " but that target was not provided in ``root``!" + ) + targets_to_copy.append(node.target) + # Sort targets in ascending order of the # of atoms. + # This will ensure that less deeply nested attributes are assigned + # before more deeply nested attributes. For example, foo.bar + # will be assigned before foo.bar.baz. 
Otherwise, we might assign + # the user-provided ``foo.bar`` and wipe out the previously-assigned + # ``foo.bar.baz`` + targets_to_copy.sort(key=lambda t: t.count(".")) + for target_to_copy in targets_to_copy: + _assign_attr(root[target_to_copy], self, target_to_copy) + else: + raise RuntimeError("Unsupported type " + str(root) + " passed for root!") + + self.graph = graph + + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + self._tracer_cls = None + if ( + self.graph._tracer_cls + and "" not in self.graph._tracer_cls.__qualname__ + ): + self._tracer_cls = self.graph._tracer_cls + + self._tracer_extras = {} + if self.graph._tracer_extras: + self._tracer_extras = self.graph._tracer_extras + + # Dictionary to store metadata + self.meta: Dict[str, Any] = {} + self._replace_hook = None + + # TorchScript breaks trying to compile the graph setter because of the + # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 + # + # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway + __jit_unused_properties__ = ["graph"] + + @property + def graph(self) -> Graph: + """ + Return the ``Graph`` underlying this ``GraphModule`` + """ + return self._graph + + @graph.setter + def graph(self, g: Graph) -> None: + """ + Set the underlying ``Graph`` for this ``GraphModule``. This will internally + recompile the ``GraphModule`` so that the generated ``forward()`` function + corresponds to ``g`` + """ + assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}" + self._graph = g + g.owning_module = self + self.recompile() + + @compatibility(is_backward_compatible=False) + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / "state_dict.pt") + tab = " " * 4 + custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()]) + model_str = f""" +import torch +{custom_builtins} + +from torch.nn import * +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + ] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f"{module_name}.pt" + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") + module_str = f"torch.load(r'{module_file}') # {module_repr}" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, 
buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + + model_str += ( + f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + ) + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / "module.py" + module_file.write_text(model_str) + + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") + + if len(blobified_modules) > 0: + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) + + @compatibility(is_backward_compatible=True) + def add_submodule(self, target: str, m: torch.nn.Module) -> bool: + """ + Adds the given submodule to ``self``. + + This installs empty Modules where none exist yet if they are + subpaths of ``target``. + + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) + m: The submodule itself; the actual object we want to + install in the current Module + + Return: + bool: Whether or not the submodule could be inserted. For + this method to return True, each object in the chain + denoted by ``target`` must either a) not exist yet, + or b) reference an ``nn.Module`` (not a parameter or + other attribute) + """ + *prefix, field = target.split(".") + mod: torch.nn.Module = self + + for item in prefix: + + submod = getattr(mod, item, None) + + if submod is None: + submod = torch.nn.Module() + setattr(mod, item, submod) + + if not isinstance(submod, torch.nn.Module): + return False + + mod = submod + + mod.add_module(field, m) + return True + + @compatibility(is_backward_compatible=True) + def delete_submodule(self, target: str) -> bool: + """ + Deletes the given submodule from ``self``. + + The module will not be deleted if ``target`` is not a valid + target. + + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) + + Returns: + bool: Whether or not the target string referenced a + submodule we want to delete. A return value of ``False`` + means that the ``target`` was not a valid reference to + a submodule. + """ + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + mod: torch.nn.Module = self + + # Get the parent module + for item in path: + + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, torch.nn.Module): + return False + + if not hasattr(mod, target_submod): + return False + + if not isinstance(getattr(mod, target_submod), torch.nn.Module): + return False + + delattr(mod, target_submod) + return True + + @compatibility(is_backward_compatible=True) + def delete_all_unused_submodules(self) -> None: + """ + Deletes all unused submodules from ``self``. + + A Module is considered "used" if any one of the following is + true: + 1. It has children that are used + 2. Its forward is called directly via a ``call_module`` node + 3. 
It has a non-Module attribute that is used from a + ``get_attr`` node + + This method can be called to clean up an ``nn.Module`` without + manually calling ``delete_submodule`` on each unused submodule. + """ + used: List[str] = [] + + for node in self.graph.nodes: + + if node.op == "call_module" or node.op == "get_attr": + + # A list of strings representing the different parts + # of the path. For example, `foo.bar.baz` gives us + # ["foo", "bar", "baz"] + fullpath = node.target.split(".") + + # If we're looking at multiple parts of a path, join + # join them with a dot. Otherwise, return that single + # element without doing anything to it. + def join_fn(x: str, y: str) -> str: + return ".".join([x, y] if y else [x]) + + # Progressively collect all the names of intermediate + # modules. For example, if we have the target + # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and + # `foo.bar.baz` to the list. + used.extend(itertools.accumulate(fullpath, join_fn)) + + # For a `call_module` node, also register all recursive submodules + # as used + if node.op == "call_module": + try: + submod = self.get_submodule(node.target) + + for submod_name, _ in submod.named_modules(): + if submod_name != "": + used.append(".".join([node.target, submod_name])) + except AttributeError: + # Node referenced nonexistent submodule, don't need to + # worry about GCing anything + pass + + to_delete = [name for name, _ in self.named_modules() if name not in used] + + for name in to_delete: + self.delete_submodule(name) + + @property + def code(self) -> str: + """ + Return the Python code generated from the ``Graph`` underlying this + ``GraphModule``. + """ + if not hasattr(self, "_code"): + raise RuntimeError( + "Code has not been generated! Please report a bug to PyTorch" + ) + return self._code + + @compatibility(is_backward_compatible=True) + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + if isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + python_code = self._graph.python_code(root_module="self") + self._code = python_code.src + self._lineno_map = python_code._lineno_map + + cls = type(self) + co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped # type: ignore[method-assign] + + return python_code + + # Passing Tracer as argument allows subclasses extending fx.GraphModule + # define their own Tracer (extending fx.Tracer). 
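``delete_all_unused_submodules`` above marks every prefix of a referenced path as used via ``itertools.accumulate``. The trick in isolation, with a made-up path::

    import itertools

    def join_fn(x: str, y: str) -> str:
        return ".".join([x, y] if y else [x])

    fullpath = "foo.bar.baz".split(".")
    print(list(itertools.accumulate(fullpath, join_fn)))
    # ['foo', 'foo.bar', 'foo.bar.baz']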
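``recompile()`` above is what regenerates ``code`` and ``forward`` after in-place graph edits; until it runs, calls still execute the previously generated code. A short sketch of that contract, assuming a trivially traceable function (the function and the edit are made up for illustration)::

    import torch
    import torch.fx

    def f(x):
        return x + 1

    gm = torch.fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        if node.op == "call_function":
            node.args = (node.args[0], 2)  # edit the graph in place

    print(gm(torch.zeros(1)))  # tensor([1.]) -- forward is still the old code
    gm.recompile()
    print(gm(torch.zeros(1)))  # tensor([2.]) -- regenerated from the edited graph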
+ def __reduce_deploy__(self, importer: Importer): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, importer) + return (reduce_deploy_graph_module, (dict_without_graph, import_block)) + + def __reduce_package__(self, exporter: PackageExporter): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Serialization of GraphModule. We serialize only the generated code, not + the underlying ``Graph``. This is because ``Graph`` does not have on-disk + backward-compatibility guarantees, whereas Python source code does. + On the deserialization side, we symbolically trace through the generated + code to regenerate the underlying ``Graph`` + """ + dict_without_graph = self.__dict__.copy() + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _deepcopy_init(self): + return GraphModule.__init__ + + # because __reduce__ is defined for serialization, + # we need to define deepcopy otherwise it will call __reduce__ + # and cause symbolic tracing to occur every time we try to copy the object + def __deepcopy__(self, memo): + res = type(self).__new__(type(self)) + memo[id(self)] = res + fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo)) + self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"]) + # hooks are lost during `GraphModule.__init__`, so we need to copy over + # them explicitly, note right now we are only copying state_dict related + # hooks, to reduce bc-related issues, we can copy forward/backward related + # hooks in the future as well if needed + extra_preserved_attrs = [ + "_state_dict_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_replace_hook", + ] + for attr in extra_preserved_attrs: + if attr in self.__dict__: + setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo)) + res.meta = copy.deepcopy(getattr(self, "meta", {}), memo) + if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta: + for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): + setattr(res, attr_name, attr) + return res + + def __copy__(self): + from ._lazy_graph_module import _make_graph_module + res = _make_graph_module(self, self.graph) + res.meta = getattr(self, "meta", {}) + return res + + @compatibility(is_backward_compatible=False) + def print_readable(self, print_output=True): + """ + Return the Python code generated for current GraphModule and its children GraphModules + """ + verbose_python_code = self._graph.python_code(root_module="self", verbose=True) + module_code = verbose_python_code.src + module_code = module_code.lstrip("\n") + module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 4) + + submodule_code_list = [""] + for 
submodule in self.children(): + if isinstance(submodule, GraphModule): + submodule_code_list.append(submodule.print_readable(print_output=False)) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 4) + + output = module_code + submodule_code + if print_output: + print(module_code + submodule_code) + return output + + def __str__(self) -> str: + orig_str = super().__str__() + print_readable_reminder = ( + "# To see more debug info, please use `graph_module.print_readable()`" + ) + return "\n".join([orig_str, self._code, print_readable_reminder]) + + def _replicate_for_data_parallel(self): + new_gm = self.__copy__() + new_gm._is_replica = True + return new_gm + + @contextlib.contextmanager + def _set_replace_hook(self, f): + """ + Takes a callable which will be called everytime when we replace a node + to a new node, or change the node's name. Callable takes three arguments: + the old node we're changing, and NAME of the new node, followed by the + user node which consumes the old node to be replaced. + """ + assert callable(f), "Replace hook must be a callable." + prev, self._replace_hook = self._replace_hook, f + try: + yield + finally: + self._replace_hook = prev + + +# workarounds for issues in __torch_function__ + +# WAR for __torch_function__ not handling tensor lists, +# fix is in https://github.com/pytorch/pytorch/pull/34725 +# orig_cat = torch.cat +# def patched_cat(*args, **kwargs): +# tensors = args[0] +# for t in tensors: +# if isinstance(t, Proxy): +# return t.__torch_function__(patched_cat, (), args, kwargs) +# return orig_cat(*args, **kwargs) +# patched_cat.__module__ = 'torch' +# patched_cat.__name__ = 'cat' +# torch.cat = patched_cat diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/immutable_collections.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/immutable_collections.py new file mode 100644 index 0000000000000000000000000000000000000000..1e65e286c3db14515ecda4e869c4cbb654cd271a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/immutable_collections.py @@ -0,0 +1,112 @@ +from typing import Any, Dict, Iterable, List, Tuple + +from torch.utils._pytree import ( + _dict_flatten, + _dict_flatten_with_keys, + _dict_unflatten, + _list_flatten, + _list_flatten_with_keys, + _list_unflatten, + Context, + register_pytree_node, +) + +from ._compatibility import compatibility + + +__all__ = ["immutable_list", "immutable_dict"] + +_help_mutation = """\ +If you are attempting to modify the kwargs or args of a torch.fx.Node object, +instead create a new copy of it and assign the copy to the node: + new_args = ... # copy and mutate args + node.args = new_args +""" + + +def _no_mutation(self, *args, **kwargs): + raise NotImplementedError( + f"'{type(self).__name__}' object does not support mutation. 
{_help_mutation}", + ) + + +def _create_immutable_container(base, mutable_functions): + container = type("immutable_" + base.__name__, (base,), {}) + for attr in mutable_functions: + setattr(container, attr, _no_mutation) + return container + + +immutable_list = _create_immutable_container( + list, + [ + "__delitem__", + "__iadd__", + "__imul__", + "__setitem__", + "append", + "clear", + "extend", + "insert", + "pop", + "remove", + ], +) +immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),)) +immutable_list.__hash__ = lambda self: hash(tuple(self)) + +compatibility(is_backward_compatible=True)(immutable_list) + +immutable_dict = _create_immutable_container( + dict, + [ + "__delitem__", + "__setitem__", + "clear", + "pop", + "popitem", + "update", + ], +) +immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),)) +immutable_dict.__hash__ = lambda self: hash(tuple(self.items())) +compatibility(is_backward_compatible=True)(immutable_dict) + + +# Register immutable collections for PyTree operations +def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return _dict_flatten(d) + + +def _immutable_dict_unflatten( + values: Iterable[Any], + context: Context, +) -> Dict[Any, Any]: + return immutable_dict(_dict_unflatten(values, context)) + + +def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return _list_flatten(d) + + +def _immutable_list_unflatten( + values: Iterable[Any], + context: Context, +) -> List[Any]: + return immutable_list(_list_unflatten(values, context)) + + +register_pytree_node( + immutable_dict, + _immutable_dict_flatten, + _immutable_dict_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +) +register_pytree_node( + immutable_list, + _immutable_list_flatten, + _immutable_list_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_list", + flatten_with_keys_fn=_list_flatten_with_keys, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..47a6f5a5bfc9135cb4adbc468ebf60ac5f655925 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py @@ -0,0 +1,512 @@ +from .graph_module import GraphModule +from ._lazy_graph_module import _make_graph_module +from .graph import Graph +from .node import Argument, Node, Target, map_arg, map_aggregate +from .proxy import Proxy +from ._symbolic_trace import Tracer +from ._compatibility import compatibility +from . import config +import torch.fx.traceback as fx_traceback +import torch +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +import inspect +from contextlib import contextmanager +from torch.hub import tqdm + +__all__ = ['Interpreter', 'Transformer'] + +@compatibility(is_backward_compatible=True) +class Interpreter: + """ + An Interpreter executes an FX graph Node-by-Node. This pattern + can be useful for many things, including writing code + transformations as well as analysis passes. + + Methods in the Interpreter class can be overridden to customize + the behavior of execution. 
The map of overrideable methods + in terms of call hierarchy:: + + run() + +-- run_node + +-- placeholder() + +-- get_attr() + +-- call_function() + +-- call_method() + +-- call_module() + +-- output() + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). We could subclass Interpreter like so:: + + class NegSigmSwapInterpreter(Interpreter): + def call_function(self, target : Target, + args : Tuple, kwargs : Dict) -> Any: + if target == torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(n) + + def call_method(self, target : Target, + args : Tuple, kwargs : Dict) -> Any: + if target == 'neg': + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(n) + + def fn(x): + return torch.sigmoid(x).neg() + + gm = torch.fx.symbolic_trace(fn) + input = torch.randn(3, 4) + result = NegSigmSwapInterpreter(gm).run(input) + torch.testing.assert_close(result, torch.neg(input).sigmoid()) + + Args: + module (torch.nn.Module): The module to be executed + garbage_collect_values (bool): Whether to delete values after their last + use within the Module's execution. This ensures optimal memory usage during + execution. This can be disabled to, for example, examine all of the intermediate + values in the execution by looking at the ``Interpreter.env`` attribute. + graph (Optional[Graph]): If passed, the interpreter will execute this + graph instead of `module.graph`, using the provided `module` + argument to satisfy any requests for state. + """ + @compatibility(is_backward_compatible=True) + def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None): + self.module = module + self.submodules = dict(self.module.named_modules()) + if graph is not None: + self.graph = graph + else: + self.graph = self.module.graph + self.env : Dict[Node, Any] = {} + self.name = "Interpreter" + self.garbage_collect_values = garbage_collect_values + self.extra_traceback = True + + if self.garbage_collect_values: + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use : Dict[Node, Node] = {} + self.user_to_last_uses : Dict[Node, List[Node]] = {} + + def register_last_uses(n : Node, user : Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + self.user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(self.graph.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + @compatibility(is_backward_compatible=True) + def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any: + """ + Run `module` via interpretation and return the result. + + Args: + *args: The arguments to the Module to run, in positional order + initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. + This is a dict mapping `Node` to any value. This can be used, for example, to + pre-populate results for certain `Nodes` so as to do only partial evaluation within + the interpreter. + enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and + process_outputs function first before using them. 
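As the ``initial_env`` description above notes, pre-populating the environment lets the interpreter skip nodes whose values are already known. A sketch of that partial-evaluation use, assuming a trivially traceable function (names and values are made up for illustration)::

    import torch
    import torch.fx

    def f(x):
        y = x + 1
        return y * 2

    gm = torch.fx.symbolic_trace(f)
    add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

    # Pretend the add has already been evaluated; run() will short-circuit it.
    interp = torch.fx.Interpreter(gm)
    out = interp.run(torch.zeros(2), initial_env={add_node: torch.tensor([10.0, 10.0])})
    print(out)  # tensor([20., 20.])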
+ + Returns: + Any: The value returned from executing the Module + """ + self.env = initial_env if initial_env is not None else {} + + # Positional function args are consumed left-to-right by + # `placeholder` nodes. Use an iterator to keep track of + # position and extract those values. + if enable_io_processing: + args = self.graph.process_inputs(*args) + self.args_iter : Iterator[Any] = iter(args) + pbar = tqdm(total=len(self.graph.nodes), + desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", + initial=0, position=0, leave=True, disable=config.disable_progress, delay=0) + + for node in self.graph.nodes: + pbar.update(1) + if node in self.env: + # Short circuit if we have this value. This could + # be used, for example, for partial evaluation + # where the caller has pre-populated `env` with + # values for a subset of the program. + continue + + try: + self.env[node] = self.run_node(node) + except Exception as e: + if self.extra_traceback: + msg = f"While executing {node.format_node()}" + msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg) + msg += f"\nOriginal traceback:\n{node.stack_trace}" + e.args = (msg,) + e.args[1:] + if isinstance(e, KeyError): + raise RuntimeError(*e.args) from e + raise + + if self.garbage_collect_values: + for to_delete in self.user_to_last_uses.get(node, []): + del self.env[to_delete] + + if node.op == 'output': + output_val = self.env[node] + return self.graph.process_outputs(output_val) if enable_io_processing else output_val + + @compatibility(is_backward_compatible=True) + def boxed_run(self, args_list): + """ + Run `module` via interpretation and return the result. This uses the "boxed" + calling convention, where you pass a list of arguments, which will be cleared + by the interpreter. This ensures that input tensors are promptly deallocated. + """ + args_iter = iter(args_list) + env = {} + for n in self.graph.nodes: + if n.op == "placeholder": + env[n] = next(args_iter) + args_list.clear() + return self.run(initial_env=env) + + @contextmanager + def _set_current_node(self, node): + with fx_traceback.set_current_meta(node): + yield + + @compatibility(is_backward_compatible=True) + def run_node(self, n : Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + with self._set_current_node(n): + args, kwargs = self.fetch_args_kwargs_from_env(n) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + return getattr(self, n.op)(n.target, args, kwargs) + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + Any: The argument value that was retrieved. + """ + assert isinstance(target, str) + if target.startswith('*'): + # For a starred parameter e.g. `*args`, retrieve all + # remaining values from the args list. 
+ return list(self.args_iter) + else: + try: + return next(self.args_iter) + except StopIteration as si: + if len(args) > 0: + return args[0] + else: + raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si + + @compatibility(is_backward_compatible=True) + def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The value of the attribute that was retrieved + """ + assert isinstance(target, str) + return self.fetch_attr(target) + + @compatibility(is_backward_compatible=True) + def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the function invocation + """ + assert not isinstance(target, str) + + # Execute the function and return the result + return target(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the method invocation + """ + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # Execute the method and return the result + assert isinstance(target, str) + return getattr(self_obj, target)(*args_tail, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``call_module`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the module invocation + """ + # Retrieve executed args and kwargs values from the environment + + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + + return submod(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. 
See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The return value referenced by the output node + """ + return args[0] + + # Helper methods + @compatibility(is_backward_compatible=True) + def fetch_attr(self, target : str): + """ + Fetch an attribute from the ``Module`` hierarchy of ``self.module``. + + Args: + target (str): The fully-qualified name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split('.') + attr_itr = self.module + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + @compatibility(is_backward_compatible=True) + def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: + """ + Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` + from the current execution environment. + + Args: + n (Node): The node for which ``args`` and ``kwargs`` should be fetched. + + Return: + Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``. + """ + args = self.map_nodes_to_values(n.args, n) + assert isinstance(args, tuple) + kwargs = self.map_nodes_to_values(n.kwargs, n) + assert isinstance(kwargs, dict) + return args, kwargs + + @compatibility(is_backward_compatible=True) + def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: + """ + Recursively descend through ``args`` and look up the concrete value + for each ``Node`` in the current execution environment. + + Args: + args (Argument): Data structure within which to look up concrete values + + n (Node): Node to which ``args`` belongs. This is only used for error reporting. + """ + def load_arg(n_arg : Node) -> Any: + if n_arg not in self.env: + raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() ' + f'to diagnose such issues') + return self.env[n_arg] + return map_arg(args, load_arg) + +@compatibility(is_backward_compatible=True) +class Transformer(Interpreter): + """ + ``Transformer`` is a special type of interpreter that produces a + new ``Module``. It exposes a ``transform()`` method that returns + the transformed ``Module``. ``Transformer`` does not require + arguments to run, as ``Interpreter`` does. ``Transformer`` works + entirely symbolically. + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). We could subclass ``Transformer`` like so:: + + class NegSigmSwapXformer(Transformer): + def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + if target == torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(n) + + def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + if target == 'neg': + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(n) + + def fn(x): + return torch.sigmoid(x).neg() + + gm = torch.fx.symbolic_trace(fn) + + transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() + input = torch.randn(3, 4) + torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) + + Args: + module (GraphModule): The ``Module`` to be transformed. 
+ """ + + @compatibility(is_backward_compatible=True) + def __init__(self, module): + super().__init__(module) + self.new_graph = Graph() + self.new_graph.set_codegen(module.graph._codegen) + + class TransformerTracer(Tracer): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.tensor_attrs: Dict[torch.Tensor, str] = {} # type: ignore[assignment] + + def is_leaf_module(self, _, __) -> bool: + return True + + self.tracer = TransformerTracer(self.new_graph) + self.tracer.root = module + + @compatibility(is_backward_compatible=True) + def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + """ + Execute a ``placeholder`` node. In ``Transformer``, this is + overridden to insert a new ``placeholder`` into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + default_value = next(iter(args)) if args else inspect.Signature.empty + return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer) + + @compatibility(is_backward_compatible=True) + def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + """ + Execute a ``get_attr`` node. In ``Transformer``, this is + overridden to insert a new ``get_attr`` node into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + return self.tracer.create_proxy("get_attr", target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + # Override so that the leaf module policy from `self.tracer` is respected. + assert isinstance(target, str) + submod = self.fetch_attr(target) + return self.tracer.call_module(submod, submod.forward, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + # Override so that functions that were wrapped are still wrapped. + return self.tracer.create_proxy('call_function', target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def transform(self) -> GraphModule: + """ + Transform ``self.module`` and return the transformed + ``GraphModule``. 
+ """ + with fx_traceback.preserve_node_meta(): + result = super().run(enable_io_processing=False) + if result is not None: + def strip_proxy(a : Union[Argument, Proxy]) -> Any: + return a.node if isinstance(a, Proxy) else a + self.new_graph.output(map_aggregate(result, strip_proxy)) + return _make_graph_module(self.module, self.new_graph) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/traceback.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/traceback.py new file mode 100644 index 0000000000000000000000000000000000000000..438babe20910316fab5e9b56385fa4dd9b3af5cc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/traceback.py @@ -0,0 +1,99 @@ +import traceback +from contextlib import contextmanager +from typing import List, Any, Dict +from ._compatibility import compatibility + +__all__ = ['preserve_node_meta', 'has_preserved_node_meta', + 'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr', + 'format_stack', 'set_current_meta', 'get_current_meta'] + +current_meta: Dict[str, Any] = {} +should_preserve_node_meta = False + + +@compatibility(is_backward_compatible=False) +@contextmanager +def preserve_node_meta(): + global should_preserve_node_meta + + saved_should_preserve_node_meta = should_preserve_node_meta + try: + should_preserve_node_meta = True + yield + finally: + should_preserve_node_meta = saved_should_preserve_node_meta + + +@compatibility(is_backward_compatible=False) +def set_stack_trace(stack : List[str]): + global current_meta + + if should_preserve_node_meta and stack: + current_meta["stack_trace"] = "".join(stack) + + +@compatibility(is_backward_compatible=False) +def set_grad_fn_seq_nr(seq_nr): + global current_meta + + if should_preserve_node_meta: + # The seq_nr is captured by eager mode in the grad_fn during forward + current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr] + current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1 + + +@compatibility(is_backward_compatible=False) +def reset_grad_fn_seq_nr(): + # NB: reset state properly, this would be helpful towards supporting + # reentrant autograd if we actually wanted to do that. 
+ global current_meta + if should_preserve_node_meta: + current_level = current_meta.get("in_grad_fn", 0) + assert current_level > 0 + if current_level == 1: + del current_meta["in_grad_fn"] + del current_meta["grad_fn_seq_nr"] + else: + current_meta["in_grad_fn"] = current_level - 1 + current_meta["grad_fn_seq_nr"].pop() + + +@compatibility(is_backward_compatible=False) +def format_stack() -> List[str]: + if should_preserve_node_meta: + return [current_meta.get("stack_trace", "")] + else: + # fallback to traceback.format_stack() + return traceback.format_list(traceback.extract_stack()[:-1]) + + +@compatibility(is_backward_compatible=False) +def has_preserved_node_meta() -> bool: + return should_preserve_node_meta + + +@compatibility(is_backward_compatible=False) +@contextmanager +def set_current_meta(node): + global current_meta + if should_preserve_node_meta and node.meta: + saved_meta = current_meta + try: + current_meta = node.meta.copy() + + # Append (node.name, node.target) onto "from_node" for provenance tracking + if "from_node" not in current_meta: + current_meta["from_node"] = [(node.name, node.target)] + elif current_meta["from_node"][-1][0] != node.name: + current_meta["from_node"].append((node.name, node.target)) + + yield + finally: + current_meta = saved_meta + else: + yield + + +@compatibility(is_backward_compatible=False) +def get_current_meta() -> Dict[str, Any]: + return current_meta
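A sketch of how the helpers above fit together for provenance tracking: inside ``preserve_node_meta()``, ``set_current_meta(node)`` exposes a copy of the node's ``meta`` and records ``(name, target)`` pairs under ``"from_node"``. The traced function and the injected ``stack_trace`` value are made up for illustration::

    import torch.fx
    import torch.fx.traceback as fx_traceback

    def f(x):
        return x + 1

    gm = torch.fx.symbolic_trace(f)
    node = next(n for n in gm.graph.nodes if n.op == "call_function")
    node.meta["stack_trace"] = "example trace"

    with fx_traceback.preserve_node_meta():
        with fx_traceback.set_current_meta(node):
            meta = fx_traceback.get_current_meta()
            print(meta["from_node"])            # e.g. [('add', <built-in function add>)]
            print(fx_traceback.format_stack())  # ['example trace']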