diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2e27775f099134c0fe0d6cba94ee6c4b4c86052 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/error.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/error.py new file mode 100644 index 0000000000000000000000000000000000000000..9fc55092fda9515c953407b44317b678fe85f74f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/error.py @@ -0,0 +1,56 @@ +from enum import Enum + + +class ExportErrorType(Enum): + # User providing invalid inputs to either tracer, or other public facing APIs + INVALID_INPUT_TYPE = 1 + + # User returning values from their models that we don’t support. + INVALID_OUTPUT_TYPE = 2 + + # Generated IR does not conform to Export IR Specification. + VIOLATION_OF_SPEC = 3 + + # User’s code contains types and functionalities we don’t support. + NOT_SUPPORTED = 4 + + # User's code didn't provide necessary details for us to successfully trace and export. + # For example, we use a lot of decorators and ask users to annotate their model. + MISSING_PROPERTY = 5 + + # User is using an API without proper initialization step. + UNINITIALIZED = 6 + + +def internal_assert(pred: bool, assert_msg: str) -> None: + """ + This is exir's custom assert method. It internally just throws InternalError. + Note that the sole purpose is to throw our own error while maintaining similar syntax + as python assert. 
+ """ + + if not pred: + raise InternalError(assert_msg) + + +class InternalError(Exception): + """ + Raised when an internal invariance is violated in EXIR stack. + Should hint users to report a bug to dev and expose the original + error message. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class ExportError(Exception): + """ + This type of exception is raised for errors that are directly caused by the user + code. In general, user errors happen during model authoring, tracing, using our public + facing APIs, and writing graph passes. + """ + + def __init__(self, error_code: ExportErrorType, message: str) -> None: + prefix = f"[{error_code}]: " + super().__init__(prefix + message) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..4f31e71dc1d5f5facf7b85eb78779ffd715fd2f8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py @@ -0,0 +1,435 @@ +import operator +import traceback +import typing +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +from functorch.experimental.control_flow import _unstack_pytree +from torch import fx +from torch._dispatch.python import enable_python_dispatcher +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._subclasses import FakeTensor, UnsupportedFakeTensorException +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx import traceback as fx_traceback +from torch.fx.experimental.proxy_tensor import PythonKeyTracer +from torch.fx.graph import CodeGen +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.fx.passes.shape_prop 
import _extract_tensor_metadata, TensorMetadata +from torch.utils import _pytree as pytree + + +__all__ = ["_ExportPassBaseDeprecatedDoNotUse"] + + +Argument = Any +Value = Any +Fn = Callable[..., Any] +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +_TORCH_SYM_OPS: Set[Callable] = { + torch.sym_int, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, +} + + +class ExportPassBaseError(RuntimeError): + pass + + +class _ExportPassBaseDeprecatedDoNotUse(PassBase): + """ + Interpreter-based pass class to help users maintain the IR spec while writing + transformations. + """ + + @staticmethod + def _create_dummy_node_metadata(): + return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))}) + + + class ExportTracer(PythonKeyTracer): + def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None: + super().__init__() + self.callback = callback + self.root = torch.nn.Module() + self.graph = torch.fx.Graph() + self.graph.set_codegen(codegen) + self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment] + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self.submodules: Dict[torch.nn.Module, str] = {} + + def trace(self) -> None: + raise ExportPassBaseError("ExportTracer doesn't support trace().") + + def create_arg(self, a: Argument) -> torch.fx.Node: + if isinstance(a, torch.nn.Module): + if a not in self.submodules: + name_submodule = f"submodule_{len(self.submodules)}" + self.root.add_module(name_submodule, a) + self.submodules[a] = name_submodule + elif isinstance(a, FakeTensor): + if not hasattr(a, "constant") or a.constant is None: + raise ExportPassBaseError(f"Cannot add {a} to graph.") + a = a.constant + node = super().create_arg(a) + if ( + isinstance(a, torch.Tensor) + and isinstance(node, torch.fx.Node) + and node.op == "get_attr" + ): + self.set_metadata(node, a) + self.callback.on_attr(ProxyValue(a, node)) + return node + + def 
set_metadata( + self, node: torch.fx.Node, value: Argument, + ) -> None: + # propagate the fake tensor or sym nodes + def make_val( + x: Argument, + ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]: + if isinstance(x, FakeTensor): + return x + elif isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + # TODO we should allocate static shapes + # for param/buffer values + if isinstance(x, torch.nn.Parameter): + fake_tensor = self.fake_tensor_mode.from_tensor( + x, static_shapes=True + ) + else: + fake_tensor = self.fake_tensor_mode.from_tensor(x) + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + print( + "Fakeifying a Tensor subclass is not supported \ + right now. Instead a TensorMetadata is used." + ) + fake_tensor = None + return fake_tensor + elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)): + return x + else: + return None + + node.meta["val"] = pytree.tree_map(make_val, value) + + # Set the tensor_metadata for values that do not have a corresponding FakeTensor + def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]: + if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + _ = self.fake_tensor_mode.from_tensor(x) + tensor_meta = None + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + tensor_meta = _extract_tensor_metadata(x) + return tensor_meta + else: + return None + + node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) + + class ExportInterpreter(fx.Interpreter): + def __init__(self, callback: 
"_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None: + super().__init__(gm) + self.callback = callback + self.node: torch.fx.Node = next(iter(gm.graph.nodes)) + + def placeholder( + self, + target: str, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + arg = super().placeholder(target, args, kwargs) + return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta)) + + def output( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + return self.callback.output(args[0], NodeMetadata(self.node.meta)).data + + def call_function( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + meta = NodeMetadata(self.node.meta) + + if target == operator.getitem: + value, key = args + return self.callback.call_getitem(value, key, meta) + elif getattr(target, "__module__", None) in {"_operator", "math"}: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif target in _TORCH_SYM_OPS: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): + return self.callback.call_operator( + target, + args, + kwargs, + meta, + ) + elif target == torch.ops.higher_order.cond: + pred, true_fn, false_fn, inputs = args + return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta) + elif target == torch.ops.higher_order.map_impl: + f, mapped_args, operands = args # type: ignore[assignment] + return self.callback.call_map(f, mapped_args, operands, meta) + # For other unregistered HigherOrderOps, just interpret them blindly + elif isinstance(target, torch._ops.HigherOrderOperator): + return self.callback._fx( + "call_function", + target, + args, + kwargs, + meta, + ) + else: + raise ExportPassBaseError(f"Unsupported target type: {target}") + + def get_attr( + self, target: 
str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> Argument: + return super().get_attr(target, args, kwargs) + + def call_module( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> None: + raise ExportPassBaseError("call_module is not supported.") + + def call_method( + self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> None: + raise ExportPassBaseError("call_method is not supported.") + + def run_node(self, n: torch.fx.Node) -> Argument: + self.node = n + self.callback.node_debug_str = n.format_node() + return super().run_node(n) + + def __init__(self) -> None: + self.interpreter = torch.fx.Interpreter( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + self.tracer = self.ExportTracer(self, CodeGen()) + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self._initialized = True + self.node_debug_str: typing.Optional[str] = None + + def _fx( + self, + kind: str, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + args_data, kwargs_data = pytree.tree_map_only( + ProxyValue, lambda x: x.data, (args, kwargs) + ) + res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data) + args_proxy, kwargs_proxy = pytree.tree_map_only( + ProxyValue, lambda x: x.proxy, (args, kwargs) + ) + + name = None + if isinstance(target, torch._ops.OpOverload): + name = self.tracer.graph._target_to_str(target.overloadpacket.__name__) + + res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name) + res_proxy.node.meta.update(meta.data) + self.tracer.set_metadata(res_proxy.node, res_data) + return ProxyValue(res_data, res_proxy) + + def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]: + # TODO(angelayi): Update this with what we decide to do for metadata in + # the exported graph module + if (args := 
graph_module.meta.get("args", None)) is not None: + return list(args) + + def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]: + if "val" in node.meta: + fake = node.meta["val"] + if hasattr(fake, "constant") and fake.constant is not None: + return fake.constant + return fake + elif tensor_meta := node.meta.get("tensor_meta"): + assert self.fake_tensor_mode is not None + return FakeTensor( + self.fake_tensor_mode, + torch.empty( + tensor_meta.shape, + dtype=tensor_meta.dtype, + device="meta", + requires_grad=tensor_meta.requires_grad, + memory_format=tensor_meta.memory_format, + ), + torch.device("cpu"), + ) + elif len(node.users) == 0: + return None + raise ExportPassBaseError( + f"Cannot construct an input for graph module: {graph_module}.", + ) + + return [ + extract_input(node) + for node in graph_module.graph.nodes + if node.op == "placeholder" + ] + + def on_attr(self, attr: ProxyValue) -> None: + pass + + def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue: + arg_proxy = self.tracer.create_proxy("placeholder", name, (), {}) + arg_proxy.node.meta = meta.data + self.tracer.set_metadata(arg_proxy.node, arg) + return ProxyValue(arg, arg_proxy) + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", op, args, kwargs, meta) + + def call_sym( + self, + target: Fn, + args: Tuple[Argument, ...], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", target, args, {}, meta) + + def call_cond( + self, + pred: ProxyValue, + true_fn: torch.fx.GraphModule, + false_fn: torch.fx.GraphModule, + inputs: List[Argument], + meta: NodeMetadata, + ) -> ProxyValue: + true_branch = self.call_submodule(true_fn, tuple(inputs)) + false_branch = self.call_submodule(false_fn, tuple(inputs)) + assert true_branch is not None + assert false_branch is not None + return self._fx( + "call_function", + 
torch.ops.higher_order.cond, + (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)), + {}, + meta, + ) + + def call_map( + self, + f: torch.fx.GraphModule, + mapped_args: List[ProxyValue], + operands: List[ProxyValue], + meta: NodeMetadata, + ) -> ProxyValue: + xs = _unstack_pytree([arg.data for arg in mapped_args])[0] + f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands])) + assert f_branch is not None + return self._fx( + "call_function", + torch.ops.higher_order.map_impl, + (f_branch.graph_module, mapped_args, operands), + {}, + meta, + ) + + def call_getitem( + self, value: ProxyValue, key: int, meta: NodeMetadata + ) -> ProxyValue: + return self._fx("call_function", operator.getitem, (value, key), {}, meta) + + def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: + return self._fx("output", "output", (results,), {}, meta) + + def call_submodule( + self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] + ) -> PassResult: + prev_tracer, self.tracer = self.tracer, self.ExportTracer( + self, graph_module.graph._codegen + ) + self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode + interpreter = self.ExportInterpreter(self, graph_module) + prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) + with fx_traceback.preserve_node_meta(): + interpreter.run(*inputs_data) + + new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph) + + self.tracer = prev_tracer + self.interpreter = prev_interpreter + return PassResult( + new_graph_module, + True, + ) + + def call(self, graph_module: fx.GraphModule) -> PassResult: + if not getattr(self, "_initialized", False): + raise ExportPassBaseError( + "ExportPass is not initialized with __init__().", + ) + + inputs = self.inputs(graph_module) + + fake_tensor_mode = None + 
for i in inputs: + if isinstance(i, FakeTensor): + assert ( + fake_tensor_mode is None or fake_tensor_mode is i.fake_mode + ), "Multiple fake tensor mode detected." + fake_tensor_mode = i.fake_mode + if fake_tensor_mode is None: + self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True) + fake_tensor_mode = nullcontext() # type: ignore[assignment] + dispatcher_mode = nullcontext() # type: ignore[assignment] + else: + fake_tensor_mode.allow_non_fake_inputs = True + self.tracer.fake_tensor_mode = fake_tensor_mode + dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment] + self.fake_tensor_mode = self.tracer.fake_tensor_mode + + with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr] + result = self.call_submodule(graph_module, tuple(inputs)) + + return result diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bca7faec519cd30a304d179f386f48694828c0ee Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py new file mode 100644 index 0000000000000000000000000000000000000000..66592d48a45efca0851e51df19d07f6346d8a335 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py @@ -0,0 +1,41 @@ +# pyre-strict +from typing import Union + +import torch + + +class ProxyValue: + # pyre-ignore + def __init__(self, data, proxy: Union[torch.fx.Proxy, 
torch.fx.Node]): + # pyre-ignore + self.data = data + self.proxy_or_node = proxy + + @property + def node(self) -> torch.fx.Node: + if isinstance(self.proxy_or_node, torch.fx.Node): + return self.proxy_or_node + assert isinstance(self.proxy_or_node, torch.fx.Proxy) + return self.proxy_or_node.node + + @property + def proxy(self) -> torch.fx.Proxy: + if not isinstance(self.proxy_or_node, torch.fx.Proxy): + raise RuntimeError( + f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}" + ) + return self.proxy_or_node + + def to_tensor(self) -> torch.Tensor: + assert isinstance(self.data, torch.Tensor) + return self.data + + def is_tensor(self) -> bool: + return isinstance(self.data, torch.Tensor) + + # pyre-ignore + def __iter__(self): + yield from self.data + + def __bool__(self) -> bool: + return bool(self.data) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67e7cc3d0658727123f8b9afd19a09618a3f32c8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py new file mode 100644 index 0000000000000000000000000000000000000000..8dfce61f0ab215932e08f4dbc180d36fa08c7a9b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py @@ -0,0 +1,69 @@ +import functools +from dataclasses import fields +from typing import Hashable, Set + + +class _UnionTag(str): + _cls: Hashable + + @staticmethod + def create(t, cls): + tag = 
_UnionTag(t) + assert not hasattr(tag, "_cls") + tag._cls = cls + return tag + + def __eq__(self, cmp) -> bool: + assert isinstance(cmp, str) + other = str(cmp) + assert other in _get_field_names( + self._cls + ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" + return str(self) == other + + def __hash__(self): + return hash(str(self)) + + +@functools.lru_cache(maxsize=None) +def _get_field_names(cls) -> Set[str]: + return {f.name for f in fields(cls)} + + +class _Union: + _type: _UnionTag + + @classmethod + def create(cls, **kwargs): + assert len(kwargs) == 1 + obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type] + obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls) + return obj + + def __post_init__(self): + assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc] + + @property + def type(self) -> str: + try: + return self._type + except AttributeError as e: + raise RuntimeError( + f"Please use {type(self).__name__}.create to instantiate the union type." 
+ ) from e + + @property + def value(self): + return getattr(self, self.type) + + def __getattribute__(self, name): + attr = super().__getattribute__(name) + if attr is None and name in _get_field_names(type(self)) and name != self.type: # type: ignore[arg-type] + raise AttributeError(f"Field {name} is not set.") + return attr + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return f"{type(self).__name__}({self.type}={getattr(self, self.type)})" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb911c2340b2414fa10820dfd38a6cf37af9164f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py @@ -0,0 +1,150 @@ +from typing import Any, Dict, List, Optional + +import torch.fx +import torch.utils._pytree as pytree + +__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"] + + +def compile( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + options: Optional[Dict[str, Any]] = None, +): + """ + Compile a given FX graph with TorchInductor. This allows compiling + FX graphs captured without using TorchDynamo. + + Args: + gm: The FX graph to compile. + example_inputs: List of tensor inputs. + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Callable with same behavior as gm but faster. + """ + from .compile_fx import compile_fx + + return compile_fx(gm, example_inputs, config_patches=options) + + +def aot_compile( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + options: Optional[Dict[str, Any]] = None, +) -> str: + """ + Ahead-of-time compile a given FX graph with TorchInductor into a shared library. + + Args: + gm: The FX graph to compile. + example_inputs: List of tensor inputs. 
+ options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Path to the generated shared library + """ + from .compile_fx import compile_fx_aot + + # We will serialize the pytree info into the .so as constant strings + in_spec = None + out_spec = None + if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen): + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + if codegen.pytree_info.in_spec is not None: + in_spec = codegen.pytree_info.in_spec + if codegen.pytree_info.out_spec is not None: + out_spec = codegen.pytree_info.out_spec + + else: + if hasattr(gm, "_in_spec"): + in_spec = gm._in_spec + if hasattr(gm, "_out_spec"): + out_spec = gm._out_spec + + serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else "" + serialized_out_spec = ( + pytree.treespec_dumps(out_spec) if out_spec is not None else "" + ) + + options = ( + { + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + if options is None + else { + **options, + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + ) + + return compile_fx_aot( + gm, + example_inputs, + config_patches=options, + ) + + +def list_mode_options( + mode: Optional[str] = None, dynamic: Optional[bool] = None +) -> Dict[str, Any]: + r"""Returns a dictionary describing the optimizations that each of the available + modes passed to `torch.compile()` performs. + + Args: + mode (str, optional): The mode to return the optimizations for. + If None, returns optimizations for all modes + dynamic (bool, optional): Whether dynamic shape is enabled. 
+ + Example:: + >>> torch._inductor.list_mode_options() + """ + + mode_options: Dict[str, Dict[str, bool]] = { + "default": {}, + # enable cudagraphs + "reduce-overhead": { + "triton.cudagraphs": True, + }, + # enable max-autotune + "max-autotune-no-cudagraphs": { + "max_autotune": True, + }, + # enable max-autotune + # enable cudagraphs + "max-autotune": { + "max_autotune": True, + "triton.cudagraphs": True, + }, + } + return mode_options[mode] if mode else mode_options # type: ignore[return-value] + + +def list_options() -> List[str]: + r"""Returns a dictionary describing the optimizations and debug configurations + that are available to `torch.compile()`. + + The options are documented in `torch._inductor.config`. + + Example:: + + >>> torch._inductor.list_options() + """ + + from torch._inductor import config + + current_config: Dict[str, Any] = config.shallow_copy_dict() + + return list(current_config.keys()) + + +def cudagraph_mark_step_begin(): + "Indicates that a new iteration of inference or training is about to begin." 
+ from .cudagraph_trees import mark_step_begin + + mark_step_begin() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09f63935450fd3abb816a5e08e277c4095bd49fe Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a13846e822778ef661cbedd4af561d3f305328f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb2fad195416e4d62ab9e618c45691bbf8f4335f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f84631c6ac472509791633862962b29c9ee7192 Binary files 
/dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17135a27aad9e1b37031b5db171076287059f157 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04f3ce6254d9868cf19bbb6c9cd2cdbecfff9ad9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44429e5112bb071971d9f0a7a0ca104946b388ea Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4e03755d326e46159e3e55c06066f8e165812831 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0f13e023a06df0a9539dd818212b16cbf33c342 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc51f8296f02019044e500bc4da6687d0f346de2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..847ab24145623948b68ab7986b9eba4de1806c78 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b398374009f09d23e4074f254d0291d5c3237d0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d50c9e08653cc87b92549bb4c3691e0123ea6e3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40b5fc602a571307389803468db81cca3a772eb4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f530f372ebd1cbc7baa2d08b432f541453fb781 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1de49ccf3ef34bb6bcdece7e238b47b105761633 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..669bc0bb3ccb41ba8c3353c49146570bf9ac501d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h new file mode 100644 index 0000000000000000000000000000000000000000..d7773672c06c23bcdfbbf669e3fa6d6d976059b3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h @@ -0,0 +1,595 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) +#define INDUCTOR_USE_VECTOR_TYPES() 1 +#else +#define INDUCTOR_USE_VECTOR_TYPES() 0 +#endif + +#if INDUCTOR_USE_VECTOR_TYPES() +#include +#include +#include +#endif + +typedef at::Half half; +typedef 
at::BFloat16 bfloat16; + +typedef at::Float8_e4m3fn float8_e4m3fn; +typedef at::Float8_e5m2 float8_e5m2; + +template +struct Welford { + T mean = T(0); + T m2 = T(0); + T weight = T(0); +}; + + +template +struct IsVecType: std::false_type {}; + +#if INDUCTOR_USE_VECTOR_TYPES() +template +struct IsVecType>: std::true_type {}; +#endif + +template +Welford welford_combine(const Welford &a, const Welford &b) { + if constexpr (!IsVecType::value) { + if (a.weight == 0) { + return b; + } + if (b.weight == 0) { + return a; + } + } + auto delta = b.mean - a.mean; + auto new_weight = a.weight + b.weight; + auto wb_over_w = b.weight / new_weight; + if constexpr (IsVecType::value) { + // Guard against division by zero + wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0)); + } + auto result = Welford{ + a.mean + delta * wb_over_w, + a.m2 + b.m2 + delta * delta * a.weight * wb_over_w, + new_weight + }; + return result; +} + +template +Welford welford_combine(const Welford &acc, T data) { + // Add a single data point + auto delta = data - acc.mean; + auto new_weight = acc.weight + T(1); + auto new_mean = acc.mean + delta / new_weight; + auto new_delta = data - new_mean; + auto result = Welford{ + new_mean, + acc.m2 + delta * new_delta, + new_weight + }; + return result; +} + +// Refer to https://github.com/pytorch/pytorch/blob/b5b36cf0c4e1958f1ff25120f5d4beeef3288187/ +// aten/src/ATen/native/SharedReduceOps.h#L419-L445 +template +inline bool greater_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) { + // If (a == b), then choose the one with lower idx, else max(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? 
idx_a < idx_b : (a > b); +} + +template +inline bool less_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) { + // If (a == b), then choose the one with lower idx, else min(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a < b); +} + +#if INDUCTOR_USE_VECTOR_TYPES() +template +inline at::vec::Vectorized vec_shuffle_down(at::vec::Vectorized x, size_t n) { + using Vec = at::vec::Vectorized; + alignas(alignof(Vec)) scalar_t array[Vec::size()]; + x.store(array); + for (size_t i = 0; i + n < Vec::size(); i += 2 * n) { + array[i] = array[i + n]; + } + return Vec::loadu(array); +} + +#ifdef CPU_CAPABILITY_AVX2 +inline at::vec::Vectorized vec_shuffle_down(at::vec::Vectorized x, size_t n) { + using vec_t = at::vec::Vectorized; +#define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w) + switch (n) { + case 1: + return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3))); + case 2: + return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2))); + case 4: + return vec_t(_mm256_permute2f128_ps(x, x, SHUFFLE_MASK(1, 1, 1, 1))); + } + TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n); +} +#endif + +template +Welford welford_vec_reduce_all(Welford> acc) { + using Vec = at::vec::Vectorized; + for (size_t n = 1; n < Vec::size(); n *= 2) { + auto shuffled = Welford{ + vec_shuffle_down(acc.mean, n), + vec_shuffle_down(acc.m2, n), + vec_shuffle_down(acc.weight, n) + }; + acc = welford_combine(acc, shuffled); + } + + Welford result; + alignas(alignof(Vec)) scalar_t array[Vec::size()]; + acc.mean.store(array); + result.mean = array[0]; + + acc.m2.store(array); + result.m2 = array[0]; + + acc.weight.store(array); + result.weight = array[0]; + + return result; +} +#endif + + +template inline typename std::common_type::type mod(T a, U b) { return a % b; } +template <> inline float mod(float a, float b) { return std::fmod(a, b); } +template <> inline double mod(double a, 
double b) { return std::fmod(a, b); } + +template +inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) { + if (at::_isnan(a)) { + return a; + } + return a > b ? a : b; +} + +template +inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) { + if (at::_isnan(a)) { + return a; + } + return a < b ? a : b; +} + +constexpr float uint32_to_uniform_float(uint32_t value) { + // maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + constexpr float scale = 4.6566127342e-10; + return static_cast(value & 0x7FFFFFFF) * scale; +} + +float normalized_rand_cpu(uint32_t seed, uint32_t offset) { + return uint32_to_uniform_float(at::Philox4_32(seed, 0, offset)()); +} + +float randn_cpu(uint32_t seed, uint32_t offset) { + at::Philox4_32 engine(seed, 0, offset); + return engine.randn(10); +} + +int64_t randint64_cpu(uint32_t seed, uint32_t offset, int64_t low, int64_t high) { + auto gen = at::Philox4_32(seed, 0, offset); + uint64_t r0 = gen(); + uint64_t r1 = gen(); + uint64_t result = r0 | (r1 << 32); + return static_cast(result % (high - low)) + low; +} + +template struct AsIntegerType { typedef T type; }; +template <> struct AsIntegerType { typedef uint32_t type; }; +template <> struct AsIntegerType { typedef uint64_t type; }; +template <> struct AsIntegerType { typedef uint16_t type; }; + +template +typename std::enable_if::value, T>::type +inline fetch_value(volatile T *addr) { + return *addr; +} + +template +typename std::enable_if::value, T>::type +inline fetch_value(volatile T *addr) { + return T(addr->x, T::from_bits()); +} + +template +typename std::enable_if::value>::type +atomic_add(volatile T *addr, T offset) { + typedef typename AsIntegerType::type alt_type; + + static_assert(sizeof(std::atomic) == sizeof(T), + "std::atomic issue"); + + alt_type expected; + + alt_type desired; + + std::atomic *atomic_addr = (std::atomic *)addr; + do { + T val = fetch_value(addr); + reinterpret_cast(&expected)[0] = val; + reinterpret_cast(&desired)[0] = val + 
offset; + } while (!atomic_addr->compare_exchange_weak(expected, desired, + std::memory_order_relaxed)); +} + +// Since C++20 float is supported by fetch_add, but the performance may not +// better than compare_exchange_weak, which can be checked by microbenchmark +// inductor_cpu_atomic.py +template +typename std::enable_if::value>::type +atomic_add(volatile T *addr, T offset) { + static_assert(sizeof(std::atomic) == sizeof(T), + "std::atomic issue"); + std::atomic *atomic_addr = (std::atomic *)addr; + atomic_addr->fetch_add(offset, std::memory_order_relaxed); +} + +// This function is used to convert bool or uint8 to float mask for +// vectorization. The caller needs to make sure the src represents TRUE/FALSE +// correctly. +template +inline float flag_to_float_scalar(T src) { + float ret; + *(uint32_t*)(&ret) = src ? 0xFFFFFFFF : 0; + return ret; +} + +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) + +inline at::vec::Vectorized masked_load(const float* src, at::vec::Vectorized mask) { +# if defined(CPU_CAPABILITY_AVX512) + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_loadu_ps(zero_vec, mmask, src); +# elif defined(CPU_CAPABILITY_AVX2) + auto all_ones = _mm256_set1_epi32(0xFFFFFFFF); + auto mmask = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones); + return _mm256_maskload_ps(src, mmask); +# elif defined(CPU_CAPABILITY_ZVECTOR) + auto result = at::vec::Vectorized::loadu(src); + return (result & mask); +# else +# error Unsupported vectorization CPU capability +# endif +} + +template +typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type +inline masked_load(const T* src, at::vec::Vectorized mask) { +# if defined(CPU_CAPABILITY_AVX512) + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = 
_mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp = _mm256_mask_loadu_epi16(zero, mmask, src); + return _mm512_inserti32x8(_mm512_castsi256_si512(temp), zero, 1); +# elif defined(CPU_CAPABILITY_AVX2) + auto all_ones = _mm256_set1_epi32(0xFFFFFFFF); + auto mmask_vec = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones); + __at_align__ uint32_t mmask[8]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(mmask), mmask_vec); + __at_align__ uint16_t result[16]; + for (auto i = 0; i < 8; i++) { + result[i] = mmask[i] == 0xFFFFFFFF ? src[i].x: uint16_t(0); + } + return at::vec::Vectorized::loadu(result); +# elif defined(CPU_CAPABILITY_ZVECTOR) + auto result = at::vec::Vectorized::loadu(src, 8); + uint32_t maskdata[8] = { 0 }; + uint16_t maskdata_dest[16] = { 0 }; + mask.store(maskdata); + for (auto i = 0; i < 8; i++) { + maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 0xFFFF: 0; + } + auto maskvector = at::vec::Vectorized::loadu(maskdata_dest); + return (result & maskvector); +# else +# error Unsupported vectorization CPU capability +# endif +} + +template +typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type +inline masked_load(const T* src, at::vec::Vectorized mask) { +# if defined(CPU_CAPABILITY_AVX512) + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ); + auto zero = _mm_set1_epi8(0); + auto temp = _mm_mask_loadu_epi8(zero, mmask, src); + return _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0); +# elif defined(CPU_CAPABILITY_AVX2) + auto all_ones = _mm256_set1_epi32(0xFFFFFFFF); + auto mmask_vec = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones); + __at_align__ uint32_t mmask[8]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(mmask), mmask_vec); + __at_align__ T result[32]; + for (auto i = 0; i < 8; i++) { + result[i] = mmask[i] == 0xFFFFFFFF ? 
src[i]: T(0); + } + return at::vec::Vectorized::loadu(result); +# elif defined(CPU_CAPABILITY_ZVECTOR) + auto result = at::vec::Vectorized::loadu(src, 8); + uint32_t maskdata[8]; + T maskdata_dest[32] = { 0 }; + mask.store(maskdata); + for (auto i = 0; i < 8; i++) { + maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 0xFF: 0; + } + auto maskvector = at::vec::Vectorized::loadu(maskdata_dest); + return (result & maskvector); +# else +# error Unsupported vectorization CPU capability +# endif +} + +template +inline at::vec::Vectorized flag_to_float_vec(const T* src) { + __at_align__ float dst_tmp[at::vec::Vectorized::size()]; + #pragma unroll + for (int64_t i = 0; i < at::vec::Vectorized::size(); i++) { + dst_tmp[i] = flag_to_float_scalar(src[i]); + } + return at::vec::Vectorized::loadu(dst_tmp); +} + +template +inline at::vec::Vectorized cvt_lowp_fp_to_fp32( + at::vec::Vectorized src) { + at::vec::Vectorized res_vec1(0); + at::vec::Vectorized res_vec2(0); + std::tie(res_vec1, res_vec2) = at::vec::convert_to_float(src); + return res_vec1; +} + +template +inline at::vec::Vectorized cvt_fp32_to_lowp_fp( + at::vec::Vectorized src) { + return at::vec::convert_from_float(src, src); +} + +inline at::vec::Vectorized mask_convert_to_float(at::vec::Vectorized src) { + auto zeros = at::vec::Vectorized(0); + auto ones = at::vec::Vectorized(1); + return at::vec::Vectorized::blendv(zeros, ones, src); +} + +template +inline +typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type +mask_convert_to_lowp(at::vec::Vectorized src) { + auto fp_vec = mask_convert_to_float(src); + return cvt_fp32_to_lowp_fp(fp_vec); +} + +template +inline at::vec::Vectorized vec_convert_to_mask(at::vec::Vectorized src) { + assert( + at::vec::Vectorized::size() == at::vec::Vectorized::size()); + at::vec::Vectorized res_vec(0); + __at_align__ float dst_tmp[at::vec::Vectorized::size()]; + __at_align__ SRC src_tmp[at::vec::Vectorized::size()]; + src.store(src_tmp); + +#pragma unroll + 
for (int i = 0; i < at::vec::Vectorized::size(); i++) { + *(uint32_t*)(dst_tmp + i) = src_tmp[i] ? 0xFFFFFFFF : 0; + } + + return res_vec.loadu(dst_tmp); +} + +template +inline at::vec::Vectorized to_float_mask(at::vec::Vectorized src) { + return vec_convert_to_mask(src); +} + +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) +template <> +inline at::vec::Vectorized to_float_mask(at::vec::Vectorized src) { +#if defined(CPU_CAPABILITY_AVX2) + return at::vec::Vectorized(_mm256_castsi256_ps(src)); +#else + return at::vec::Vectorized(_mm512_castsi512_ps(src)); +#endif +} +#endif + +template <> +inline at::vec::Vectorized to_float_mask(at::vec::Vectorized src) { + return src; +} + +inline at::vec::Vectorized to_float_mask(int src) { + union { + float fmask; + uint32_t imask; + } mask; + mask.imask = src ? 0xFFFFFFFF : 0; + return at::vec::Vectorized(mask.fmask); +} + +inline bool all_zero(at::vec::Vectorized src) { +# if defined(CPU_CAPABILITY_AVX512) + auto src_int = _mm512_castps_si512(src); + __mmask16 mask = _mm512_test_epi32_mask(src_int, src_int); + return mask == 0; +# elif defined(CPU_CAPABILITY_AVX2) + return _mm256_testz_ps(src, src); +# else + __at_align__ int mask[at::vec::Vectorized::size()]; + src.store(mask); + for (int i = 0; i < at::vec::Vectorized::size(); i++) { + if (mask[i] != 0) { + return false; + } + } + return true; +# endif +} + +inline bool vector_lane_mask_check(at::vec::Vectorized src, int lane) { +# if defined(CPU_CAPABILITY_AVX512) + return _mm512_movepi32_mask(_mm512_castps_si512(src)) & (1 << lane); +# elif defined(CPU_CAPABILITY_AVX2) + return _mm256_movemask_ps(src) & (1 << lane); +# else + __at_align__ int mask[at::vec::Vectorized::size()]; + src.store(mask); + return mask[lane] != 0; +# endif +} + +inline at::vec::Vectorized cvt_int64_to_fp32(at::vec::VectorizedN src) { +# if defined(CPU_CAPABILITY_AVX512) + auto low = _mm512_cvtepi64_ps(src[0]); + auto high = _mm512_cvtepi64_ps(src[1]); + return 
_mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1); +# elif defined(CPU_CAPABILITY_AVX2) + auto low_double = at::vec::convert_to_fp_of_same_size(src[0]); + auto low = _mm256_cvtpd_ps(low_double); + auto high_double = at::vec::convert_to_fp_of_same_size(src[1]); + auto high = _mm256_cvtpd_ps(high_double); + return _mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1); +# else + constexpr int float_vec_size = at::vec::Vectorized::size(); + constexpr int int64_vec_size = at::vec::Vectorized::size(); + __at_align__ float result[float_vec_size]; + __at_align__ int64_t src_buf[int64_vec_size]; + for (int i = 0; i < 2; i++) { + src[i].store(src_buf + i * int64_vec_size); + for (int j = 0; j < int64_vec_size; j++) { + result[i * int64_vec_size + j] = static_cast(src_buf[i * int64_vec_size + j]); + } + } + return at::vec::Vectorized::loadu(result); +# endif +} + +inline at::vec::VectorizedN cvt_fp32_to_int64(at::vec::Vectorized src) { + at::vec::VectorizedN result; +# if defined(CPU_CAPABILITY_AVX512) + result[0] = _mm512_cvt_roundps_epi64(_mm512_castps512_ps256(src), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); + result[1] = _mm512_cvt_roundps_epi64(_mm512_extractf32x8_ps(src, 1), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); +# elif defined(CPU_CAPABILITY_AVX2) + auto int32_vec = at::vec::convert_to_int_of_same_size(src); + result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(int32_vec)); + result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(int32_vec, 1)); +# else + constexpr int float_vec_size = at::vec::Vectorized::size(); + constexpr int int64_vec_size = at::vec::Vectorized::size(); + __at_align__ float src_buf[float_vec_size]; + __at_align__ int64_t result_buf[int64_vec_size]; + src.store(src_buf); + for (int i = 0; i < 2; i++) { + for (int j = 0; j < int64_vec_size; j++) { + result_buf[j] = static_cast(src_buf[i * int64_vec_size + j]); + } + result[i] = at::vec::Vectorized::loadu(result_buf); + } +# endif + return result; +} + +inline at::vec::Vectorized 
cvt_int64_to_int32(at::vec::VectorizedN src) { +# if defined(CPU_CAPABILITY_AVX512) + auto low = _mm512_cvtepi64_epi32(src[0]); + auto high = _mm512_cvtepi64_epi32(src[1]); + return _mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1); +# elif defined(CPU_CAPABILITY_AVX2) + auto low = _mm256_shuffle_epi32(src[0], _MM_SHUFFLE(2, 0, 2, 0)); + auto high = _mm256_shuffle_epi32(src[1], _MM_SHUFFLE(2, 0, 2, 0)); + auto low_perm = _mm256_permute4x64_epi64(low, _MM_SHUFFLE(3, 1, 2, 0)); + auto high_perm = _mm256_permute4x64_epi64(high, _MM_SHUFFLE(3, 1, 2, 0)); + return _mm256_blend_epi32(low_perm, high_perm, 0xF0); +# else + constexpr int int32_vec_size = at::vec::Vectorized::size(); + constexpr int int64_vec_size = at::vec::Vectorized::size(); + __at_align__ int32_t result[int32_vec_size]; + __at_align__ int64_t src_buf[int64_vec_size]; + for (int i = 0; i < 2; i++) { + src[i].store(src_buf + i * int64_vec_size); + for (int j = 0; j < int64_vec_size; j++) { + result[i * int64_vec_size + j] = static_cast(src_buf[i * int64_vec_size + j]); + } + } + return at::vec::Vectorized::loadu(result); +# endif +} + +inline at::vec::VectorizedN cvt_int32_to_int64(at::vec::Vectorized src) { + at::vec::VectorizedN result; +# if defined(CPU_CAPABILITY_AVX512) + result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src)); + result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src, 1)); +# elif defined(CPU_CAPABILITY_AVX2) + result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(src)); + result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(src, 1)); +#else + constexpr int int32_vec_size = at::vec::Vectorized::size(); + constexpr int int64_vec_size = at::vec::Vectorized::size(); + __at_align__ int32_t src_buf[int32_vec_size]; + __at_align__ int64_t result_buf[int64_vec_size]; + src.store(src_buf); + for (int i = 0; i < 2; i++) { + for (int j = 0; j < int64_vec_size; j++) { + result_buf[j] = static_cast(src_buf[i * int64_vec_size + j]); + } + result[i] = 
at::vec::Vectorized::loadu(result_buf); + } +# endif + return result; +} + +inline at::vec::VectorizedN mask_convert_to_int64(at::vec::Vectorized src) { + return cvt_fp32_to_int64(mask_convert_to_float(src)); +} + +inline at::vec::Vectorized to_float_mask(at::vec::VectorizedN src) { + return to_float_mask(cvt_int64_to_int32(src)); +} + +#endif diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..7990af0c8668b83afe4c9b1762b15617f100d723 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -0,0 +1,1851 @@ +import functools +import os +import sys +from itertools import count +from typing import List, Optional, Tuple + +import sympy +from sympy import Expr + +import torch +import torch._ops +from .. 
import config, ir + +from ..codecache import CudaKernelParamCache +from ..utils import cache_on_self, sympy_product +from ..virtualized import V +from .common import IndentedBuffer +from .wrapper import EnterSubgraphLine, ExitSubgraphLine, pexpr, WrapperCodeGen + + +class CppWrapperCpu(WrapperCodeGen): + """ + Generates cpp wrapper for running on CPU and calls cpp kernels + """ + + def __init__(self): + if not hasattr(self, "device"): + self.device = "cpu" + super().__init__() + self.declare = "auto " + self.declare_maybe_reference = "decltype(auto) " + self.ending = ";" + self.open_bracket = "{" + self.closed_bracket = "}" + self.comment = "//" + self.namespace = "at::" + self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()" + self.extern_call_ops = set() + self.size = "sizes()" + self.stride = "strides()" + self.cuda = False + self.supports_intermediate_hooks = False + self.outputs_need_copy = set() + self.kernel_callsite_id = count() + self.int_array_id = count() # for int array local variable declarations + self.declared_int_array_vars = set() + self.tmp_tensor_id = count() # for tmp tensor local variable declarations + self.arg_var_id = count() + self.used_cached_devices = set() + self.used_cached_dtypes = set() + self.cached_output_id = count() + self.scalar_to_tensor_id = count() + + from .cpp import cexpr, CppPrinter + + self.expr_printer = cexpr + + # CppPrinter sometimes calls at::native functions which causes problems in + # the ABI-compatible mode. Currently we are hitting this problem when codegen + # Grid computation expressions, but we my need to fix other size computation + # as well. 
+ class GridExprCppPrinter(CppPrinter): + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + assert expr.is_integer, "Expect integers in GridExprPrinter" + return f"({x}/{div})" + + self.grid_expr_printer = GridExprCppPrinter().doprint + + def generate_kernel_call( + self, + name, + call_args, + grid=None, + device_index=None, + cuda=True, + triton=True, + arg_types=None, + grid_fn: str = "grid", + triton_meta=None, + ): + """ + Generates kernel call code. + + cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + if cuda: + return super().generate_kernel_call( + name, + call_args, + grid, + device_index, + cuda, + triton, + arg_types, + grid_fn, + ) + else: + if config.abi_compatible: + assert arg_types is not None and len(call_args) == len( + arg_types + ), "Mismatch call_args and arg_types in generate_kernel_call" + new_args = [] + for idx, arg in enumerate(call_args): + if "*" in arg_types[idx]: + var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"auto* {var_name} = get_data_ptr_wrapper({arg});" + ) + new_args.append(f"({arg_types[idx]})({var_name})") + else: + # arg is a scalar + new_args.append(arg) + self.writeline(self.wrap_kernel_call(name, new_args)) + else: + self.writeline(self.wrap_kernel_call(name, call_args)) + + def write_constant(self, name, hashed): + # include a hash so our code cache gives different constants different files + self.header.writeline(f"// {name} {hashed}") + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. 
+ return + + if V.graph.aot_mode: + for header_cpp_file in ("interface.cpp", "implementation.cpp"): + with open( + os.path.join( + os.path.dirname(__file__), "aoti_runtime", header_cpp_file + ) + ) as f: + self.header.splice(f.read()) + else: + self.header.splice( + """ + import torch + from torch._inductor.codecache import CppWrapperCodeCache + + cpp_wrapper_src = ( + ''' + """ + ) + + if config.abi_compatible: + if config.c_shim_version == "1": + self.header.splice("#include ") + else: + self.header.splice( + f"#include " + ) + self.header.splice( + """ + #include + #include + #include + """ + ) + if V.graph.aot_mode: + self.header.splice( + """ + #include + """ + ) + else: + self.header.splice( + """ + #include + #include + #include + #include + #include + #include + #include + #include + + #define reinterpret_tensor torch::inductor::_reinterpret_tensor + #define alloc_from_pool torch::inductor::_alloc_from_pool + """ + ) + + self.header.splice("#include ") + + if not V.graph.aot_mode: + self.header.splice( + """ + #include + + using namespace torch::aot_inductor; + """ + ) + + from .memory_planning import ALIGN_BYTES + + # Round up to the nearest multiple of ALIGN_BYTES + # ALIGN_BYTES must be a power of 2 + self.header.splice( + f""" + [[maybe_unused]] static int64_t align(int64_t nbytes) {{ + return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES}; + }} + """ + ) + + def mark_output_type(self): + # mark output type to unwrap tensor back to python scalar + from ..ir import ShapeAsConstantBuffer + + output_is_tensor = dict() + for idx, x in enumerate(V.graph.graph_outputs): + if isinstance(x, ShapeAsConstantBuffer): + output_is_tensor[idx] = False + else: + output_is_tensor[idx] = True + + self.output_is_tensor = output_is_tensor + + def write_prefix(self): + if V.graph.is_const_graph: + # We do not write prefix for constant graph, it will be written by main module. 
+ return + + if V.graph.aot_mode: + self.prefix.writeline("namespace torch {") + self.prefix.writeline("namespace aot_inductor {") + + def write_input_output_info( + self, + info_kind: str, + idx: int, + name: str, + ): + self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""") + + @staticmethod + def get_input_cpp_type(input): + assert config.use_minimal_arrayref_interface + from .cpp import DTYPE_TO_CPP + + if isinstance(input, sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype(input) + assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}" + return DTYPE_TO_CPP[dtype] + return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" + + def write_wrapper_decl(self): + inputs_len = len(V.graph.graph_inputs.keys()) + if V.graph.aot_mode: + if config.use_minimal_arrayref_interface and not V.graph.is_const_graph: + from .cpp import DTYPE_TO_CPP + + input_cpp_types = ", ".join( + f"{CppWrapperCpu.get_input_cpp_type(x)}" + for x in V.graph.graph_inputs.values() + ) + + output_arrayref_types = ", ".join( + f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>" + for x in V.graph.graph_outputs + ) + + self.prefix.splice( + f""" + using AOTInductorModelInputs = std::tuple<{input_cpp_types}>; + using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>; + """ + ) + + if V.graph.const_module: + self.header.splice(V.graph.const_module.wrapper_code.header) + self.prefix.splice(V.graph.const_code) + + if V.graph.is_const_graph: + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + else: + if not config.aot_inductor.use_runtime_constant_folding: + # If we do not split the constant graph, we'll just create + # an empty implementation when wrapping the main module. 
+ self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {} + + """ + ) + + run_impl_proto = """ + void AOTInductorModel::run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + if config.use_minimal_arrayref_interface: + self.prefix.splice( + """ + template <> + AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface< + AOTInductorModelInputs, AOTInductorModelOutputs>( + const AOTInductorModelInputs& inputs, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + self.suffix.splice(run_impl_proto) + self.suffix.splice( + """ + AOTInductorModelInputs inputs; + convert_handles_to_inputs(input_handles, inputs); + auto outputs = run_impl_minimal_arrayref_interface( + inputs, stream, proxy_executor); + // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this + // interface to perform well for a DSO using the minimal arrayref interface, all we need + // to do is provide ThreadLocalCachedTensor for each one! 
+ convert_outputs_to_handles(outputs, output_handles); + } + """ + ) + + self.suffix.splice( + """ + extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface( + AOTInductorModelHandle model_handle, + const AOTInductorModelInputs& inputs, + AOTInductorModelOutputs& outputs) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + outputs = model->run_impl_minimal_arrayref_interface( + inputs, + (torch::aot_inductor::DeviceStreamType)nullptr, + nullptr); + }) + } + """ + ) + else: + self.prefix.splice(run_impl_proto) + else: + self.prefix.splice( + """ + void inductor_entry_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed) + ) { + """ + ) + with self.prefix.indent(): + # assign inputs and outputs in both cases so the later codegen can be simplified + if not config.use_minimal_arrayref_interface: + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + self.prefix.splice( + """ + pybind11::gil_scoped_release release; + """ + ) + + if config.abi_compatible: + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) + else: + # This looks dumb, but can avoid creating two versions of code in the AOTInductor runtime. 
+ self.prefix.splice( + f""" + auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, {num_args}); + """ + ) + + if inputs_len != 0: + for idx, input_key in enumerate(V.graph.graph_inputs.keys()): + if config.use_minimal_arrayref_interface: + self.prefix.writeline( + f"auto {input_key} = std::get<{idx}>(inputs);" + ) + continue + # unwrap input tensor back to scalar + if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + from .cpp import DTYPE_TO_CPP + + dtype = may_get_constant_buffer_dtype( + V.graph.graph_inputs[input_key] + ) + assert ( + dtype is not None + ), "Fails to get the dtype of the sympy.Expr" + cpp_dtype = DTYPE_TO_CPP[dtype] + if config.abi_compatible: + self.prefix.writeline(f"{cpp_dtype} {input_key};") + dtype_str = str(dtype).split(".")[-1] + self.prefix.writeline( + f"aoti_torch_item_{dtype_str}(inputs[{idx}], &{input_key});" + ) + else: + self.prefix.writeline( + f"{cpp_dtype} {input_key} = inputs[{idx}].item<{cpp_dtype}>();" + ) + else: + self.prefix.writeline( + f"auto {input_key} = std::move(inputs[{idx}]);" + ) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + if V.graph.aot_mode: + # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. 
+ if config.abi_compatible: + self.prefix.writeline( + f"""auto {constants_key} = constants_->at({idx});""" + ) + else: + self.prefix.writeline( + f"auto {constants_key} = *tensor_handle_to_tensor_pointer(" + + f"""constants_->at({idx}));""" + ) + else: + # Append constants as inputs to the graph + constants_idx = inputs_len + idx + self.prefix.writeline( + f"auto {constants_key} = inputs[{constants_idx}];" + ) + + self.codegen_inputs(self.prefix, V.graph.graph_inputs) + + if V.graph.aot_mode: + if not V.graph.is_const_graph: + if config.use_minimal_arrayref_interface: + # TODO: input shape checking for regular tensor interface as well? + self.codegen_input_numel_asserts() + else: + self.prefix.writeline("inputs.clear();") + self.prefix.writeline( + "auto& kernels = static_cast(*this->kernels_.get());" + ) + + def codegen_input_numel_asserts(self): + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. 
+ if sympy_product(buf.get_size()) == 0: + continue + numel = buf.get_numel() + self.prefix.writeline(f"assert_numel({name}, {numel});") + + def codegen_input_size_var_decl(self, code: IndentedBuffer, name): + if config.abi_compatible: + code.writeline(f"int64_t* {name}_size;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));" + ) + else: + super().codegen_input_size_var_decl(code, name) + + def codegen_input_stride_var_decl(self, code: IndentedBuffer, name): + if config.abi_compatible: + code.writeline(f"int64_t* {name}_stride;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));" + ) + else: + super().codegen_input_stride_var_decl(code, name) + + def codegen_model_kernels(self): + self.prefix.writeline("namespace {") + self.prefix.writeline( + "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {" + ) + self.prefix.writeline(" public:") + declare_kernel = set(self.src_to_kernel.values()) + declare_kernel.update( + entry[0] for entry in self.user_defined_kernel_cache.values() + ) + if V.graph.const_module: + declare_kernel.update( + V.graph.const_module.wrapper_code.src_to_kernel.values() + ) + for kernel in declare_kernel: + self.prefix.writeline(f" CUfunction {kernel}{{nullptr}};") + self.prefix.writeline("};") + self.prefix.writeline("} // namespace") + + def codegen_model_constructor(self): + """ + // Generated code example + AOTInductorModel::AOTInductorModel() + : AOTInductorModelBase(4, 1) { + inputs_info_[0].name = "input0"; + inputs_info_[0].dtype = "torch.float16"; + ... + constants_info_[0].name = "L__self___weight"; + constants_info_[0].dtype = at::kFloat; + constants_info_[0].offset = 0; + constants_info_[0].data_size = 8192; + constants_info_[0].shape = {64, 32}; + constants_info_[0].stride = {32, 1}; + ... 
+ outputs_info_[0].name = "output0"; + outputs_info_[0].dtype = "torch.float16"; + } + """ + + num_inputs = len(V.graph.graph_inputs) + num_outputs = len(V.graph.graph_outputs) + num_constants = len(V.graph.constants) + self.prefix.splice( + f""" + AOTInductorModel::AOTInductorModel(std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string& device_str, + std::optional cubin_dir) + : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{ + """ + ) + + with self.prefix.indent(): + for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()): + assert not isinstance( + inp, sympy.Expr + ), f"input {name=} cannot be symbolic" + self.write_input_output_info("inputs_info_", idx, name) + + for idx, (name, tensor) in enumerate(V.graph.constants.items()): + assert isinstance(tensor, torch.Tensor) + self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""") + self.prefix.writeline( + f"constants_info_[{idx}].dtype = static_cast({self.codegen_dtype(tensor.dtype)});" + ) + self.prefix.writeline( + f"constants_info_[{idx}].offset = {tensor.storage_offset()};" + ) + self.prefix.writeline( + f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};" + ) + from_folded = "true" if name in V.graph.folded_constants else "false" + self.prefix.writeline( + f"constants_info_[{idx}].from_folded = {from_folded};" + ) + + size_str = ", ".join([str(s) for s in tensor.size()]) + self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};") + + stride_str = ", ".join([str(s) for s in tensor.stride()]) + self.prefix.writeline( + f"constants_info_[{idx}].stride = {{{stride_str}}};" + ) + if name in V.graph.dynamo_flat_name_to_original_fqn: + original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get( + name, name + ) + elif name in V.graph.allocated_constant_name: + original_fqn = V.graph.allocated_constant_name[name] + else: + raise AssertionError("original_fqn must be set for 
constant") + self.prefix.writeline( + f"""constants_info_[{idx}].original_fqn = "{original_fqn}";""" + ) + self.prefix.writeline("update_constants_map(std::move(constants_map));") + self.prefix.writeline("update_constants_array(std::move(constants_array));") + + def escape_string(x): + return ( + x.replace("\\", "\\\\") + .replace('"', '\\"') + .replace("\n", "\\n") + .replace("\t", "\\t") + ) + + self.prefix.writeline( + f'in_spec_ = "{escape_string(config.aot_inductor.serialized_in_spec)}";' + ) + self.prefix.writeline( + f'out_spec_ = "{escape_string(config.aot_inductor.serialized_out_spec)}";' + ) + + for idx, output in enumerate(V.graph.graph_outputs): + assert not isinstance( + output, sympy.Expr + ), f"output {name=} cannot be symbolic" + name = f"output{idx}" + self.write_input_output_info("outputs_info_", idx, name) + + self.prefix.writeline( + "this->kernels_ = std::make_unique();" + ) + + self.prefix.writeline("}") + + def codegen_const_run_driver(self): + """ + // Generated code example + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + std::unordered_map folded_constants_map; + std::vector output_handles; + // build up output_handles over here. 
+ _const_run_impl(output_handles, stream, proxy_executor); + // build up folded_constants_map + return folded_constants_map; + } + """ + + self.prefix.splice( + """ + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + """ + ) + if not config.aot_inductor.use_runtime_constant_folding: + self.prefix.splice( + """ + if (!initialization) { + std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: " + << "aot_inductor.use_runtime_constant_folding=False\\n"; + } + return {}; + } + """ + ) + return + + with self.prefix.indent(): + # This is a mapping to the index of constant folding graph's output + const_index_mapping: List[Optional[Tuple[int, str]]] = [None] * len( + V.graph.const_output_index + ) + for idx, (name, _) in enumerate(V.graph.constants.items()): + if name in V.graph.const_output_index: + const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload] + assert ( + None not in const_index_mapping + ), "Not all constant gets mapped for constant folding graph." + + self.prefix.writeline( + f""" + std::unordered_map folded_constants_map; + folded_constants_map.reserve({len(const_index_mapping)}); + std::vector output_handles({len(const_index_mapping)}); + """ + ) + + self.prefix.splice( + """ + // The below assignment of output_handles to constants is not used directly. + // It's only used to memo the correspondence of handle and constants. 
+ """ + ) + + for output_idx, (const_idx, _) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f"output_handles[{output_idx}] = constants_->at({const_idx});" + ) + + self.prefix.writeline( + "_const_run_impl(output_handles, stream, proxy_executor);" + ) + + for output_idx, (_, const_name) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];' + ) + self.prefix.writeline("return folded_constants_map;") + + self.prefix.writeline("}") + + def generate(self, is_inference): + if V.graph.aot_mode and not V.graph.is_const_graph: + self.codegen_model_kernels() + self.codegen_model_constructor() + self.codegen_const_run_driver() + self.write_wrapper_decl() + return super().generate(is_inference) + + def finalize_prefix(self): + cached_dtypes_buffer = IndentedBuffer() + if config.abi_compatible: + for dtype in self.used_cached_dtypes: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});") + for device in self.used_cached_devices: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});") + cached_dtypes_buffer.splice(self.prefix) + self.prefix = cached_dtypes_buffer + + def define_kernel( + self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False + ): + self.header.splice(f"\n{kernel}\n") + + def codegen_scalar_to_tensor(self, output: str): + name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}" + self.wrapper_call.writeline( + f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});" + ) + return name + + @cache_on_self + def get_output_refs(self): + return [ + f"torch::tensor({x.codegen_reference(self.wrapper_call)})" + if isinstance(x, ir.ShapeAsConstantBuffer) and not config.abi_compatible + else x.codegen_reference(self.wrapper_call) + for x in V.graph.graph_outputs + ] + + def generate_return(self, output_refs): + cst_names = V.graph.constants.keys() + arr_iface = ( + not 
V.graph.is_const_graph and config.use_minimal_arrayref_interface + ) # For brevity. + + def use_thread_local_cached_output_tensor(idx, output): + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + cache_type = "Array" if arr_iface else "Tensor" + self.wrapper_call.writeline( + f"thread_local ThreadLocalCachedOutput{cache_type}> " + f"{cached_output_name}({output});" + ) + if arr_iface: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + output_entry = f"std::get<{idx}>(output_arrayref_tensors)" + element_type = f"std::decay_t" + self.wrapper_call.writeline( + f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), " + f"output_handles[{idx}]));" + ) + + if arr_iface: + self.wrapper_call.writeline( + "AOTInductorModelOutputs output_arrayref_tensors;" + ) + for idx, output in enumerate(output_refs): + if config.abi_compatible: + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output_tensor}.release();" + ) + continue + + output_is_tensor_handle_expr = ( + f"std::is_same_v," + "RAIIAtenTensorHandle> || " + f"std::is_same_v," + "AtenTensorHandle> || " + f"std::is_same_v," + "ConstantHandle>" + ) + self.wrapper_call.writeline( + f"if constexpr ({output_is_tensor_handle_expr}) {{" + ) + with self.wrapper_call.indent(): + if arr_iface: + cached_output_name = ( + 
f"cached_output_{next(self.cached_output_id)}" + ) + output_value_type = f"std::decay_t(output_arrayref_tensors).data()[0])>" + self.wrapper_call.writeline( + f"thread_local RAIIAtenTensorHandle {cached_output_name};" + ) + if output in cst_names: + # NOTE(return_constant): In some rare cases where we return + # a constant, we have to return a copy of this constant, + # because (1) constants are not owned by the Model instance + # (2) constants remain the same cross inference runs, + # assuming they are not updated at runtime Basically, we + # cannot release or transfer the ownership of any original + # constant to the user. + self.wrapper_call.writeline( + f"AtenTensorHandle {cached_output_name}_tmp;" + ) + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &{cached_output_name}_tmp);" + ) + self.wrapper_call.writeline( + f"{cached_output_name} = {cached_output_name}_tmp;" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name} = {output}.release();" + ) + self.wrapper_call.writeline( + f"convert_handle_to_arrayref_tensor({cached_output_name}, " + f"std::get<{idx}>(output_arrayref_tensors));" + ) + else: + if output in cst_names: + # See NOTE(return_constant) above. + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + self.wrapper_call.writeline("} else {") + with self.wrapper_call.indent(): + use_thread_local_cached_output_tensor(idx, output) + self.wrapper_call.writeline("}") + + else: + assert ( + not arr_iface + ), "minimal ArrayRef interface is only supported in ABI-compatible mode" + if output in cst_names: + output_expr = f"{output}.clone()" + # See NOTE(return_constant) above. 
+ else: + output_expr = output + self.wrapper_call.writeline( + f"output_handles[{idx}] = reinterpret_cast(" + + f"new at::Tensor({output_expr}));" + ) + if arr_iface: + self.wrapper_call.writeline("return output_arrayref_tensors;") + + def generate_before_suffix(self, result): + if not V.graph.is_const_graph: + if V.graph.aot_mode: + result.writeline("} // AOTInductorModel::run_impl") + else: + result.writeline("} // inductor_entry_impl") + + def generate_end(self, result): + if V.graph.aot_mode: + if V.graph.is_const_graph: + result.writeline("} // AOTInductorModel::_const_run_impl") + else: + result.writeline("} // namespace aot_inductor") + result.writeline("} // namespace torch") + return + + result.writeline("'''\n)") + result.splice( + f""" + inductor_entry = CppWrapperCodeCache.load_pybinding( + ["std::vector"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)}) + """ + ) + + # unwrap output tensor back to python scalar + if all(x for x in self.output_is_tensor.values()): + # If no ShapeAsConstantBuffer in the output, directly return the output as tensors + return_str = "return f(args_tensor)" + else: + outputs = [ + f"outputs[{i}]" if self.output_is_tensor[i] else f"outputs[{i}].item()" + for i in range(len(V.graph.graph_outputs)) + ] + outputs_str = f"[{', '.join(outputs)}]" + return_str = f""" + outputs = f(args_tensor) + return {outputs_str} + """ + + args_str = "args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]" + if V.graph.constants: + # Append constants to the input args for cpp wrapper. + # Python wrapper directly gets the value inside the wrapper call + # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__). + # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly. 
+ assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + constants_str = f"[{', '.join(V.graph.constants.keys())}]" + args_str += f""" + constants_tensor = {constants_str} + args_tensor.extend(constants_tensor) + """ + + # Wrap the func to support setting result._boxed_call = True + result.splice( + f""" + def _wrap_func(f): + def g(args): + {args_str} + {return_str} + return g + call = _wrap_func(inductor_entry) + """ + ) + + def generate_c_shim_extern_kernel_call(self, kernel, args): + # In the abi_compatible mode, we call fallback aten ops through a C shim layer + self.allow_stack_allocation = False + kernel_tokens = kernel.split("::") + kernel_suffix = kernel_tokens[-1] + if kernel_suffix == "call": + kernel_suffix = kernel_tokens[-2] + if config.c_shim_version == "1": + shim_fn = f"aoti_torch_{kernel_suffix}" + else: + shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}" + + # HACK: val_to_arg_str jams multiple arguments together using a comma. If that + # ever breaks, it needs to be reworked to be able to return multiple arguments, + # and the split-on-comma code here needs to be removed. + wrapped_args = [] + for x in args: + pieces = x.split(", ") + for piece in pieces: + # We only really *need* convert_arrayref_tensor_to_tensor for + # ArrayRefTensors. The code flowing into here uses `0` for nullptr, + # which convert_arrayref_tensor_to_tensor would blindly coerce to int, + # so just avoid wrapping integers. 
+ if not piece.isdigit(): + piece = f"convert_arrayref_tensor_to_tensor({piece})" + wrapped_args.append(piece) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));" + ) + + def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args): + # registered output buffer name + name = extern_kernel.name + output_handle_name = f"{name}_handle" + self.writeline(f"AtenTensorHandle {output_handle_name};") + output_arg = f"&{output_handle_name}" + self.generate_c_shim_extern_kernel_call( + extern_kernel.get_kernel_name(), args + [output_arg] + ) + self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") + + def generate_extern_kernel_alloc(self, extern_kernel, args): + if config.abi_compatible: + self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) + else: + super().generate_extern_kernel_alloc(extern_kernel, args) + + def generate_c_shim_fallback_kernel(self, fallback_kernel, args): + output_args = [] + output_raii_handles = [] + output_name_base = fallback_kernel.get_name() + for idx, output in enumerate(fallback_kernel.outputs): + if isinstance(output, ir.MultiOutput): + name = f"{output.get_name()}" + output_handle_name = f"{name}_handle" + if output.indices: + assert ( + output.indices[0][1] == idx + ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}" + self.writeline(f"AtenTensorHandle {output_handle_name};") + output_args.append(f"&{output_handle_name}") + output_raii_handles.append( + f"RAIIAtenTensorHandle {name}({output_handle_name});" + ) + elif isinstance(output, int): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"int64_t {output_name} = {output};") + output_args.append(f"&{output_name}") + elif output is None: + output_args.append("nullptr") + else: + raise NotImplementedError("unsupported type of {output=}") + args = args + output_args + assert ( + fallback_kernel.abi_compatible_kernel is not None + ), f"abi_compatible_kernel is None for 
{fallback_kernel.python_kernel_name=}" + self.generate_c_shim_extern_kernel_call( + fallback_kernel.abi_compatible_kernel, args + ) + for raii_handle in output_raii_handles: + self.writeline(raii_handle) + + def generate_fallback_kernel(self, fallback_kernel, args): + if config.abi_compatible: + self.generate_c_shim_fallback_kernel(fallback_kernel, args) + else: + super().generate_fallback_kernel(fallback_kernel, args) + + def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel): + if output_view: + output_as_strided = f"{output_view.codegen_reference()}" + output_name = f"{output_view.get_name()}_as_strided" + self.writeline(f"auto {output_name} = {output_as_strided};") + + args.insert(0, output_name) + else: + args.insert(0, f"{codegen_reference}") + + if config.abi_compatible: + self.generate_c_shim_extern_kernel_call(kernel, args) + else: + self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_user_defined_triton_kernel( + self, kernel_name, grid, configs, args, triton_meta + ): + assert len(grid) != 0 + if len(grid) == 1: + grid_decision = grid[0] + else: + meta = CudaKernelParamCache.get(kernel_name) + assert meta is not None + grid_decision = None + for i, c in enumerate(configs): + if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()): + grid_decision = grid[i] + break + assert grid_decision is not None + + self.generate_kernel_call( + kernel_name, + args, + grid=grid_decision, + device_index=V.graph.scheduler.current_device.index, + cuda=True, + triton=True, + triton_meta=triton_meta, + ) + + def generate_scatter_fallback( + self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs + ): + # TODO: support other overload for cpp wrapper and remove the below assertions + if config.abi_compatible: + # call the ABI shim function instead of the ATen one + kernel = kernel.replace("at::", "aoti_torch_") + line = f"{kernel}({output}, {','.join(map(str, inputs))}" + if python_kernel_name == 
"aten.scatter_": + if src_is_tensor: + if reduce: + line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" + else: + assert ( + reduce is None + ), "Expect reduce to be None for aten.scatter_ with scalar src" + else: + line += f", {','.join(kwargs)}" + line += f"){self.ending}" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + if V.graph.aot_mode and V.graph.cpp_wrapper and config.abi_compatible: + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding + # tensor prematurely deallocated, thus this std::vector().data() trick here. + indices_str = ( + f"std::vector{{{', '.join(indices)}}}.data()" + ) + args = [x, indices_str, str(len(indices)), values, accumulate] + else: + indices_str = ( + f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}" + ) + args = [x, indices_str, values, accumulate] + + args.insert(0, x) # set x as the output tensor, this fallback mutates x. 
+ self.writeline(self.wrap_kernel_call(kernel, args)) + + def add_benchmark_harness(self, output): + if V.graph.aot_mode: + return + super().add_benchmark_harness(output) + + def codegen_sizevar(self, x: Expr) -> str: + return self.expr_printer(V.graph.sizevars.simplify(x)) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + if config.abi_compatible: + # in the abi_compatible mode, outputs are returned via arguments + return name + else: + return f"std::get<{index}>({basename})" + + def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: + parts = list(map(self.codegen_sizevar, shape)) + if len(parts) == 0: + return "{}" + if len(parts) == 1: + return f"{{{parts[0]}, }}" + return f"{{{', '.join(parts)}}}" + + def codegen_dynamic_scalar(self, node): + from .cpp import DTYPE_TO_ATEN, DTYPE_TO_CPP + + (data,) = (t.codegen_reference() for t in node.inputs) + if config.abi_compatible: + dtype = node.inputs[0].get_dtype() + dtype_str = str(dtype).split(".")[-1] + self.writeline(f"{DTYPE_TO_CPP[dtype]} {node.sym};") + self.writeline(f"aoti_torch_item_{dtype_str}({data}, &{node.sym});") + # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again + self.unbacked_symbol_decls.add(str(node.sym)) + else: + if node.is_bool: + self.writeline(f"bool {node.sym} = {data}.item() ? 
1 : 0;") + else: + convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace( + "at::k", "to" + ) + self.writeline(f"auto {node.sym} = {data}.item().{convert_type}();") + + def can_stack_allocate_buffer(self, buffer): + return ( + self.allow_stack_allocation + and buffer.get_device().type == "cpu" + and self.can_prove_buffer_has_static_shape(buffer) + and ir.is_contiguous_strides_for_shape( + buffer.get_stride(), buffer.get_size() + ) + ) + + def make_buffer_free(self, buffer): + return ( + "" + if isinstance(buffer.get_layout(), ir.MultiOutputLayout) + or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers) + or ( + config.use_minimal_arrayref_interface + and V.graph.aot_mode + and buffer.get_name() in V.graph.graph_inputs + ) + else f"{buffer.get_name()}.reset();" + ) + + def make_free_by_names(self, names_to_del: List[str]): + return " ".join(f"{name}.reset();" for name in names_to_del) + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + if config.abi_compatible: + return f"auto {new_name} = std::move({old_name}); // reuse" + else: + return super().codegen_exact_buffer_reuse(old_name, new_name, del_line) + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline( + 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' + ) + + def write_triton_header_once(self): + pass + + def generate_start_graph(self): + pass + + def generate_end_graph(self): + pass + + def generate_inf_and_nan_checker(self, nodes): + for buf in nodes.get_names(): + # TODO: Add buf name directly into check_inf_and_nan. 
+ self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_check_inf_and_nan({buf}));" + ) + + def codegen_device(self, device): + if config.abi_compatible: + self.used_cached_devices.add(device.type) + return f"cached_torch_device_type_{device.type},{device.index if device.index else 0}" + else: + from .cpp import DEVICE_TO_ATEN + + return ( + f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})" + if device.index is not None + else f"{DEVICE_TO_ATEN[device.type]}" + ) + + def codegen_dtype(self, dtype): + if config.abi_compatible: + dtype_str = str(dtype).split(".")[-1] + self.used_cached_dtypes.add(dtype_str) + return f"cached_torch_dtype_{dtype_str}" + else: + from .cpp import DTYPE_TO_ATEN + + return DTYPE_TO_ATEN[dtype] + + @functools.lru_cache(None) + def codegen_int_array_var( + self, + int_array: str, + writer=None, + known_statically=False, + graph=None, # for per-graph caching + ): + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. 
+ # As a result, the emitted int array declarations may appear in a later + # position of the generated code, so the second pass codegen should not + # reuse int array declarations generated in the first pass + if writer is None: + # The first pass codegen uses `self` as the writer + writer = self + + var = f"int_array_{next(self.int_array_id)}" + if var not in self.declared_int_array_vars: + self.declared_int_array_vars.add(var) + if known_statically: + writer.writeline(f"static constexpr int64_t {var}[] = {int_array};") + else: + writer.writeline(f"int64_t {var}[] = {int_array};") + return var + + def make_buffer_allocation(self, buffer): + return self.make_allocation( + buffer.get_name(), + buffer.get_device(), + buffer.get_dtype(), + buffer.get_size(), + buffer.get_stride(), + buffer if self.can_stack_allocate_buffer(buffer) else None, + ) + + def make_allocation( + self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None + ): + orig_stride = stride + device_str = self.codegen_device(device) + dtype_code = self.codegen_dtype(dtype) + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(orig_stride) + if config.abi_compatible: + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + if buffer_if_can_stack_allocate is not None: + from .cpp import DTYPE_TO_CPP + + self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate + cpp_type = DTYPE_TO_CPP[dtype] + numel = buffer_if_can_stack_allocate.get_numel() + # Note: we don't zero storage because empty_strided doesn't zero either. 
+ self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];") + args = [ + f"{name}_storage", + size_array_var, + stride_array_var, + device_type, + device_idx, + ] + return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});" + + args = [ + str(len(shape)), + size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{name}_handle", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + ) + + return f"RAIIAtenTensorHandle {name}({name}_handle);" + + if V.graph.aot_mode and device_str.startswith("c10::Device("): + tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)" + else: + tensor_device = device_str + + if device.type == "cpu": + return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});" + if device.type == "cuda": + return ( + f"at::Tensor {name} = at::detail::empty_strided_cuda(" + f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);" + ) + return ( + f"{self.declare}{name} = {self.namespace}empty_strided(" + f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}" + ) + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + if config.abi_compatible: + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(stride) + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + name, + pexpr(offset), # bytes not numel + self.codegen_dtype(dtype), + str(len(shape)), + self.codegen_int_array_var( + size, self.wrapper_call, graph=self.get_codegened_graph() + ), + self.codegen_int_array_var( + stride, self.wrapper_call, graph=self.get_codegened_graph() + ), + f"&{tmp_name}", + ] + self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" + ) + 
return f"RAIIAtenTensorHandle({tmp_name})" + + return "alloc_from_pool({})".format( + ", ".join( + [ + name, + pexpr(offset), # bytes not numel + self.codegen_dtype(dtype), + self.codegen_shape_tuple(shape), + self.codegen_shape_tuple(stride), + ] + ) + ) + + def codegen_reinterpret_view( + self, data, size_list, stride_list, offset, writer + ) -> str: + dim = str(len(size_list)) + size = self.codegen_shape_tuple(size_list) + stride = self.codegen_shape_tuple(stride_list) + offset = self.codegen_sizevar(offset) + + if config.abi_compatible: + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. + if writer is None: + writer = self + + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + size, + writer, + known_statically=self.is_statically_known_list_of_ints(size_list), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + stride, + writer, + known_statically=self.is_statically_known_list_of_ints(stride_list), + graph=self.get_codegened_graph(), + ), + offset, + ] + + def gen_reinterpret_call(writer, args): + writer.writeline( + f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});" + ) + + if ( + self.can_stack_allocate_buffer(data) + and self.is_statically_known_list_of_ints(size_list) + and self.is_statically_known_list_of_ints(stride_list) + and ir.is_contiguous_strides_for_shape(stride_list, size_list) + ): + gen_reinterpret_call(writer, args) + return tmp_name + + gen_reinterpret_call(writer, args) + + # NB, the return handle here represents a temporary tensor, which will be automatically + # released. 
+ # Here's a sample usage in the cpp wrapper code: + # ``` + # aoti_torch_addmm_out( + # buf1, + # arg1_1, + # RAIIAtenTensorHandle(tmp_tensor_handle_0), + # buf0, + # 1L, + # 1L)); + # ``` + # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out. + # This could be problematic when it's used in a different pattern, for example: + # ```` + # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}; + # aoti_torch_proxy_executor_call_function(..., tensor_args); + # ```` + # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter + # kernel call. + # + # This is solved by updating the proxy_executor invocation to + # ``` + # aoti_torch_proxy_executor_call_function(..., + # std::vector{ + # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6 + # }.data() + # ); + # ``` + return f"wrap_with_raii_handle_if_needed({tmp_name})" + else: + args = [data.get_name(), size, stride, offset] + return f"reinterpret_tensor({', '.join(args)})" + + def codegen_device_copy(self, src, dst): + if config.abi_compatible: + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));" + ) + else: + self.writeline(f"{dst}.copy_({src});") + + def codegen_multi_output(self, name, value): + # in the abi_compatible mode, outputs are retrieved by passing + # output pointers, so we skip its codegen here. + if not config.abi_compatible: + super().codegen_multi_output(name, value) + + def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): + for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs): + if config.abi_compatible: + # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional + # input (outer_input) into another at::Tensor to be used as a subgraph input + # (inner_input) in the nested scope. 
we can't std::move here, as the codegened + # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we + # can't necessarily std::move it back to the origin (x). + self.writeline(f"AtenTensorHandle {inner_input}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));" + ) + self.writeline( + f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);" + ) + else: + self.writeline( + f"{self.declare}{inner_input} = {outer_input}{self.ending}" + ) + + def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + src = inner_output.codegen_reference() + if config.abi_compatible: + # in ABI-compatible mode, we need to std::move subgraph output (inner_output) + # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy + # constructor is deleted. + src = f"std::move({src})" + self.writeline(f"{outer_output} = {src}{self.ending}") + + def codegen_conditional(self, conditional): + name = conditional.get_name() + outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands] + if config.abi_compatible: + outer_outputs = [] + for out in conditional.outputs: + # in ABI-compatible mode, ir.MultiOutput is not codegened, + # hence pre-declare output variables directly and separately + self.writeline(f"RAIIAtenTensorHandle {out.get_name()};") + outer_outputs.append(out.get_name()) + predicate = f"{conditional.predicate.get_name()}_scalar" + self.writeline(f"bool {predicate};") + # in ABI-compatible mode, we need to use the ABI shim function + # to extract a C++ bool from the unrelying scalar bool Tensor + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({conditional.predicate.codegen_reference()}, &{predicate}));" + ) + else: + # in non-ABI-compatible mode, we can codegen the conditional outputs + # as array of at::Tensor instances, 
as the ir.MultiOutput is codegened + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];") + predicate = f"{conditional.predicate.codegen_reference()}.item()" + + self.writeline(f"if ({predicate}) {{") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("} else {") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + + def generate_extern_kernel_args_decl_if_needed( + self, op_overload, raw_args, output_args + ): + arg_types = [x.real_type for x in op_overload._schema.arguments] + return_types = [x.type for x in op_overload._schema.returns] + + new_tensor_args = [] + new_int_args = [] + + def fill_args(arg, arg_type): + static_arg_types = ( + torch.FloatType, + torch.BoolType, + torch.StringType, + torch.Type, + torch.DeviceObjType, + ) + inductor_tensor_buffers = ( + ir.Buffer, + ir.ReinterpretView, + ) + + if isinstance(arg_type, torch.TensorType): + assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}" + new_tensor_args.append(f"{arg.codegen_reference()}") + elif isinstance(arg_type, torch.IntType): + # int + new_int_args.append(str(arg)) + elif isinstance(arg_type, torch.SymIntType): + # SymInt + expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg + new_int_args.append(self.expr_printer(expr)) + elif isinstance(arg_type, torch.NumberType): + # Scalar of type int + assert isinstance(arg, (int, float, bool)) + # Only treat int Scalar as dynamic + if isinstance(arg, int): + new_int_args.append(str(arg)) + elif isinstance(arg_type, torch.ListType): + assert isinstance(arg, (list, tuple)) + + # List[Tensor] + if 
isinstance(arg_type.getElementType(), torch.TensorType): + new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg]) + # List[Optional[Tensor]] + elif isinstance( + arg_type.getElementType(), torch.OptionalType + ) and isinstance( + arg_type.getElementType().getElementType(), torch.TensorType + ): + new_tensor_args.extend( + [f"{a.codegen_reference()}" for a in arg if a is not None] + ) + # List[int] + elif isinstance(arg_type.getElementType(), torch.IntType): + new_int_args.extend([str(a) for a in arg]) + # List[SymInt] + elif isinstance(arg_type.getElementType(), torch.SymIntType): + expressions = [ + a.node.expr if isinstance(a, torch.SymInt) else a for a in arg + ] + new_int_args.extend( + [self.expr_printer(expr) for expr in expressions] + ) + # List[Scalar] + elif isinstance(arg_type.getElementType(), torch.NumberType): + # Only treat int Scalar as dynamic + is_int_type = [isinstance(a, int) for a in arg] + if any(is_int_type): + assert all( + is_int_type + ), "AOTInductor only supports int scalars of the same type" + new_int_args.extend([str(a) for a in arg]) + else: + assert isinstance( + arg_type.getElementType(), static_arg_types # type: ignore[arg-type] + ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + else: + assert isinstance( + arg_type, static_arg_types # type: ignore[arg-type] + ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + + for arg, arg_type in zip(raw_args, arg_types): + if arg is not None: + if isinstance(arg_type, torch.OptionalType): + fill_args(arg, arg_type.getElementType()) + else: + fill_args(arg, arg_type) + + def fill_output_arg(arg, return_type): + if isinstance(return_type, torch.TensorType): + self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);") + 
new_tensor_args.append(f"{arg}") + elif isinstance(return_type, torch.SymIntType): + raise NotImplementedError("NYI support for return type: SymInt") + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.SymIntType + ): + raise NotImplementedError("NYI support for return type: List[SymInt]") + else: + raise AssertionError(f"Unsupported return type found: {return_type}") + + # TODO: Only support tensor(s) returns for now, SymInt is not implemented yet + for return_type in return_types: + if isinstance(return_type, (torch.TensorType)): + pass + elif isinstance(return_type, torch.OptionalType): + assert isinstance(return_type.getElementType(), torch.TensorType) + elif isinstance(return_type, torch.ListType): + assert isinstance(return_type.getElementType(), torch.TensorType) + else: + raise NotImplementedError( + f"return type {return_type} is not yet supported." + ) + + for output_arg in output_args: + assert output_arg is not None, "Optional return types are not yet supported" + if isinstance(output_arg, (list, tuple)): + for out in output_arg: + fill_output_arg(out, torch.TensorType.get()) + else: + fill_output_arg(output_arg, torch.TensorType.get()) + + return new_tensor_args, new_int_args + + def generate_extern_kernel_alloc_and_find_schema_if_needed( + self, + name, + kernel, + codegen_args, + cpp_op_schema, + cpp_kernel_key, + cpp_kernel_overload_name="", + op_overload=None, + raw_args=None, + outputs=None, + ): + if config.is_fbcode(): + assert op_overload is not None + assert raw_args is not None + assert outputs is not None + + return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode( + name, + cpp_kernel_key, + op_overload, + raw_args, + outputs, + ) + else: + return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss( + name, + kernel, + codegen_args, + cpp_op_schema, + cpp_kernel_key, + cpp_kernel_overload_name, + ) + + def 
generate_extern_kernel_alloc_and_find_schema_if_needed_oss( + self, + name, + kernel, + codegen_args, + cpp_op_schema, + cpp_kernel_key, + cpp_kernel_overload_name="", + ): + if cpp_kernel_key not in self.extern_call_ops: + self.writeline( + f"static auto op_{cpp_kernel_key} = c10::Dispatcher::singleton()" + ) + self.writeline( + f'\t.findSchemaOrThrow("{kernel}", "{cpp_kernel_overload_name}")' + ) + self.writeline(f"\t.typed<{cpp_op_schema}>();") + self.extern_call_ops.add(cpp_kernel_key) + + self.writeline( + f"auto {name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});" + ) + + def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode( + self, + name, + cpp_kernel_key, + op_overload, + raw_args, # contains both args and flatten kwargs + outputs, + ): + def extract_output_name(out): + assert out is not None, "None, i.e. optional output is not supported" + if isinstance(out, ir.MultiOutput): + return out.get_name() + elif isinstance(out, (list, tuple)): + return type(out)(extract_output_name(o) for o in out) + else: + raise AssertionError(f"Unexpected output: {type(out)}") + + # output_args has the same pytree structure as outputs + output_args = extract_output_name(outputs) + if isinstance(output_args, str): + output_args = [output_args] + + ( + tensor_call_args, + int_call_args, + ) = self.generate_extern_kernel_args_decl_if_needed( + op_overload, raw_args, output_args + ) + + tensor_call_args_str = ", ".join(tensor_call_args) + int_call_args_str = ", ".join(int_call_args) + + extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1 + + self.writeline( + f"aoti_torch_proxy_executor_call_function(proxy_executor, " + f"{extern_kernel_node_index}, " + f"{len(int_call_args)}, " + f"std::vector{{{int_call_args_str}}}.data(), " + f"{len(tensor_call_args)}, " + f"std::vector{{{tensor_call_args_str}}}.data());" + ) + + self.extern_call_ops.add(cpp_kernel_key) + + def generate_reset_kernel_saved_flags(self): + pass + + def 
generate_save_uncompiled_kernels(self): + pass + + def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str: + if ( + config.abi_compatible + and not is_legacy_abi + and isinstance(type_, torch.OptionalType) + ): + if val is None: + return "0" # nullptr is not available in C + if not isinstance(type_.getElementType(), torch.TensorType): + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};") + return f"&{var_name}" + elif config.c_shim_version == "2": + # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim + base_handle = self.val_to_arg_str(val) + if "wrap_with_raii_handle_if_needed" in base_handle: + # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to + # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. + tmp_var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};" + ) + base_handle = tmp_var_name + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();") + return f"&{var_name}" + + return self.val_to_arg_str(val) + + def val_to_arg_str(self, val) -> str: + if val is None: + # When None is passed as an argument, it represents an optional that does not contain a value. 
+ if config.abi_compatible: + return "0" # nullptr is not available in C + return "c10::nullopt" + elif isinstance(val, bool): + if config.abi_compatible: + return "1" if val else "0" + else: + return "true" if val else "false" + elif isinstance(val, int): + # uint64_t is long on Linux, but long long on MacOS + return f"{val}LL" if sys.platform == "darwin" else f"{val}L" + elif isinstance(val, str): + return f'"{val}"' + elif isinstance( + val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox) + ): + return val.codegen_reference() + elif isinstance(val, torch.device): + return self.codegen_device(val) + elif isinstance(val, torch.dtype): + return self.codegen_dtype(val) + elif isinstance(val, float) and val in [float("inf"), float("-inf")]: + if val == float("inf"): + return "std::numeric_limits::infinity()" + else: + return "-std::numeric_limits::infinity()" + elif isinstance(val, (list, tuple)): + # FIXME handle embedded optional types? + result = f"{{{', '.join(self.val_to_arg_str(x) for x in val)}}}" + if config.abi_compatible: + static = self.is_statically_known_list_of_ints(val) + # Need to pass the array length because we can't use std::vector + int_var_array = self.codegen_int_array_var( + result, + known_statically=static, + graph=self.get_codegened_graph(), + ) + return f"{int_var_array}, {len(val)}" + else: + return result + else: + return repr(val) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..caa3549b1f9d2062ff0e97e7a4ee9b3857451110 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -0,0 +1,328 @@ +import functools +import os +from itertools import chain, count +from typing import Any, List, Optional, TYPE_CHECKING + 
+import sympy + +from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name + +from .. import config +from ..codecache import CudaKernelParamCache +from ..triton_heuristics import grid as default_grid +from ..virtualized import V +from .cpp_wrapper_cpu import CppWrapperCpu +from .wrapper import SymbolicCallArg + +if TYPE_CHECKING: + from ..graph import GraphLowering + + +def is_int(s: str) -> bool: + # Cpp code gen adds L at the end of ints + # Lets remove it for checking whether we have an int or not + if s and s[-1] == "L": + s = s[:-1] + try: + int(s) + except ValueError: + return False + except TypeError: + return False + return True + + +def is_float(s: str) -> bool: + try: + float(s) + except ValueError: + return False + return True + + +class CppWrapperCuda(CppWrapperCpu): + """ + Generates cpp wrapper for running on GPU and calls CUDA kernels + """ + + def __init__(self): + self.device = "cuda" + super().__init__() + self.grid_id = count() + self.cuda = True + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. 
+ return + + super().write_header() + + self.header.splice("#include ") + if config.abi_compatible: + self.header.splice( + "#include " + ) + else: + self.header.splice( + """ + #include + #include + #include + """ + ) + + self.header.splice( + """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + cuGetErrorString(code, &msg); \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + namespace { + + struct Grid { + Grid(uint32_t x, uint32_t y, uint32_t z) + : grid_x(x), grid_y(y), grid_z(z) {} + uint32_t grid_x; + uint32_t grid_y; + uint32_t grid_z; + + bool is_non_zero() { + return grid_x > 0 && grid_y > 0 && grid_z > 0; + } + }; + + } // anonymous namespace + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + ) + + def write_get_raw_stream(self, index, graph=None): + name = f"stream{index}" + self.writeline(f"cudaStream_t {name};") + self.writeline( + 
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));" + ) + return name + + def define_kernel( + self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True + ): + if not cuda: + return super().define_kernel(name, kernel, metadata, cuda) + + def generate(self, is_inference): + self.prefix.writeline("\n") + if not V.graph.aot_mode: + for kernel in chain( + self.src_to_kernel.values(), + [entry[0] for entry in self.user_defined_kernel_cache.values()], + ): + self.prefix.writeline(f"static CUfunction {kernel} = nullptr;") + self.prefix.writeline("\n") + return super().generate(is_inference) + + @functools.lru_cache(None) + def generate_load_kernel_once( + self, + name: str, + mangled_name: str, + cubin_path: str, + shared_mem: int, + graph: "GraphLowering", # for per-graph caching + ): + if V.graph.aot_mode: + self.writeline(f"if (kernels.{name} == nullptr) {{") + self.writeline( + f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);""" + ) + self.writeline("}") + else: + self.writeline(f"if ({name} == nullptr) {{") + self.writeline( + f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});""" + ) + self.writeline("}") + + def generate_args_decl(self, call_args): + dynamic_symbols = V.graph.sizevars.free_symbols() + # TODO: only works for constant now, need type info + new_args = [] + for arg in call_args: + var_name = f"var_{next(self.arg_var_id)}" + if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)): + self.writeline(f"auto {var_name} = {arg};") + elif isinstance(arg, sympy.Expr): + self.writeline(f"auto {var_name} = {self.expr_printer(arg)};") + elif is_int(arg): + self.writeline(f"int {var_name} = {arg};") + elif is_float(arg): + self.writeline(f"float {var_name} = {arg};") + elif any(str(arg) == s.name for s in dynamic_symbols): + self.writeline(f"auto {var_name} = {arg};") + elif arg == "nullptr": + self.writeline(f"auto {var_name} 
= nullptr;") + elif arg == "c10::nullopt": + self.writeline(f"auto {var_name} = c10::nullopt;") + else: + if config.abi_compatible: + self.writeline(f"CUdeviceptr {var_name};") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast(&{var_name})));" + ) + else: + self.writeline( + f"CUdeviceptr {var_name} = reinterpret_cast({arg}.data_ptr());" + ) + new_args.append(f"&{var_name}") + + return ", ".join(new_args) + + def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True): + """ + Generate grid configs for launching a CUDA kernel using the grid + function from triton_heuristics. + """ + if not cuda: + return grid + assert isinstance(grid, list), f"expected {grid=} to be a list" + grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid] + grid_fn = default_grid(*grid) + params = CudaKernelParamCache.get(name) + assert ( + params is not None + ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}" + block_cfg = { + "XBLOCK": params["x_block"], + "YBLOCK": params["y_block"], + "ZBLOCK": params["z_block"], + } + return grid_fn(block_cfg) + + def generate_kernel_call( + self, + name, + call_args, + grid=None, + device_index=None, + cuda=True, + triton=True, + arg_types=None, + grid_fn: str = "grid", + triton_meta=None, + ): + if not cuda: + # Even in CppWrapperCuda, we may see cpp kernels + return super().generate_kernel_call( + name, call_args, grid, device_index, cuda, triton, arg_types + ) + + params = CudaKernelParamCache.get(name) + assert ( + params is not None + ), f"cuda kernel parameters for {name} should already exist at this moment" + mangled_name = params.get("mangled_name", None) + assert mangled_name is not None, "missing mangled_name" + cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None) + assert cubin_path is not None and os.path.exists( + cubin_path + ), f"cubin file should already exist at 
this moment: {cubin_path}" + shared_mem = params.get("shared_mem", 0) + + self.generate_load_kernel_once( + name, mangled_name, cubin_path, shared_mem, V.graph + ) + + # args with value 1 are added into equal_to_1 and constants + # in triton_meta (in the Python codegen) which makes them + # inlined in the PTX and compiled CUBIN + if ( + triton_meta is not None + and "configs" in triton_meta + and triton_meta["configs"] + ): + equal_to_1 = triton_meta["configs"][0].equal_to_1 + call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1] + + call_args = self.generate_args_decl(call_args) + kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}" + self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};") + stream = ( + "stream" + if V.graph.aot_mode + else self.write_get_raw_stream(device_index, V.graph) + ) + grid_name = f"{name}_grid_{next(self.grid_id)}" + assert isinstance( + grid, (list, tuple) + ), f"expected grid to be a list or tuple but got: {grid=}" + + grid = [V.graph.sizevars.simplify(item) for item in grid] + grid_uses_symbolic_shapes = any(item.free_symbols for item in grid) + grid_args = [self.grid_expr_printer(item) for item in grid] + grid_args_str = ", ".join(grid_args) + self.writeline(f"Grid {grid_name} = Grid({grid_args_str});") + + if grid_uses_symbolic_shapes: + self.writeline(f"if ({grid_name}.is_non_zero()) {{") + kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name + self.writeline( + "launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format( + kernel_var_name, + f"{grid_name}.grid_x", + f"{grid_name}.grid_y", + f"{grid_name}.grid_z", + params["num_warps"], + params["shared_mem"], + kernel_args_var, + stream, + ) + ) + if grid_uses_symbolic_shapes: + self.writeline("}") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ca8ef0e593fc1a3aedbef867c3a0093b7439501 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..b9980ef3d1f25e0dd760a0c811416c9b0c7d64a4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -0,0 +1,374 @@ +import logging +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union + +from ... 
import ir +from ...autotune_process import CUDABenchmarkRequest +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox +from ...select_algorithm import ChoiceCaller +from ...utils import sympy_product +from ...virtualized import V + +from ..common import IndentedBuffer, Kernel, OpOverrides, PrimitiveInfoType +from ..cpp import CppPrinter, DTYPE_TO_CPP + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda.cuda_template import CUDATemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +class CUDAKernel(Kernel): + """ + Baseclass for CUDA / Cutlass based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + +class CUDATemplateKernel(CUDAKernel): + """ + Template kernels defined by CUDA / Cutlass in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" + + def __init__(self, kernel_name): + """ + Initializes a new instance of the CUDATemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + # Mapping from arg name to IRNode. + self.named_nodes: Dict[str, IRNode] = {} + + def arg_name(self, node: IRNode) -> Optional[str]: + """ + Returns arg name of a given input or output node. + """ + if node is None: + return None + return {**self.args.input_buffers, **self.args.output_buffers}.get( + node.get_name(), None + ) + + def check_not_null(self, node: IRNode) -> str: + """ + Generates code to check that a node is not null. 
+ """ + + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + f""" + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """ + ) + return res.getvalue() + + def def_kernel( + self, + inputs: List[IRNode], + outputs: List[IRNode], + names_str: str = "", + input_reorder: Optional[List[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. 
+ """ + + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs() + return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})" + + def call_kernel( + self, name: str, node: "CUDATemplateBuffer", epilogue_nodes: List[ir.Buffer] # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.WrapperCodeGen + + name: Name of kernel function. + node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + _, call_args, _ = self.args.python_argdefs() + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + else: + call_args[i] = f"c_void_p({call_args[i]}.data_ptr())" + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. 
+ call_args.append("None") + + if node.get_workspace_size() > 0: + call_args.append(f"c_void_p({node.get_name()}_workspace.data_ptr())") + else: + call_args.append("None") + + wrapper.generate_kernel_call( + name, + call_args, + device_index=V.graph.scheduler.current_device.index, + cuda=True, + triton=False, + ) + + def dtype(self, node: IRNode) -> Optional[str]: + """ + Generates code which represents dtype of a given node. + """ + + if node is None: + return "void" + return DTYPE_TO_CPP.get(node.get_layout().dtype) + + def offset(self, node: IRNode) -> str: + """ + Generates code which represents offset of a given node. + """ + + if node is None: + return "0" + return str(node.get_layout().offset) + + def ptr(self, node: IRNode) -> str: + """ + Generates code which represents pointer of a given node. + """ + + if node is None: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + offset = self.offset(node) + return arg_name if offset == "0" else f"{arg_name} + {offset}" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. + Generates code which represents size of a given node in [start_index, end_index). + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + + sizes = node.get_size()[start_index : end_index + 1] + if len(sizes) == 0: + return str(default_value) + + val = sympy_product(sizes) + return cexpr(self.rename_indexing(val)) + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. 
+ Generates code which represents stride of a given node at index. + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + return cexpr(self.rename_indexing(stride)) + + def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the row or column stride of an arg. + This is required by some CUTLASS 2.X APIs. + If the node is in row_major, it returns stride[-2]. + If the node is in column_major, it returns stride[-1]. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None or len(node.get_stride()) < 2: + return str(default_value) + + stride0 = node.get_stride()[-1] + stride1 = node.get_stride()[-2] + if stride0 == 1: + return cexpr(self.rename_indexing(stride1)) + elif stride1 == 1: + return cexpr(self.rename_indexing(stride0)) + else: + raise RuntimeError( + f"At least 1 stride should be 1. Strides: {node.get_stride()=}" + ) + + +class CUDATemplateCaller(ChoiceCaller): + """ + CUDATemplateCaller + + This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CUDABenchmarkRequest): The benchmark request for the caller. + template_buffer (CUDATemplateBuffer): The template buffer for the caller. 
+ """ + + def __init__( + self, + name: str, + category: str, + input_nodes: List[Buffer], + layout: Layout, + make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str], + bmreq: CUDABenchmarkRequest, + template: "CUDATemplate", # type: ignore[name-defined] + info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg] + ): + super().__init__(name, input_nodes, layout) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + return self.bmreq.benchmark( + *args, output_tensor=out + ) # @TODO: Hack for ensuring that Cutlass Kernel is preferred + + def __str__(self): + return f"CUDATemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"cuda_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + if self.info_kwargs is not None and "op" in self.info_kwargs: + op: Any = self.info_kwargs["op"] + epilogue_node_names: List[str] = [ + getattr(en, "name", "no_name") + for en in self.info_kwargs.get("epilogue_nodes", []) # type: ignore[union-attr] + ] + epilogue_node_strs: List[str] = [ + str(en) for en in self.info_kwargs.get("epilogue_nodes", []) # type: ignore[union-attr] + ] + return { + "backend": "CUDA", + "op_type": type(op).__name__, + "op_conf_name": str(op.configuration_name()), + "op_arch": str(op.arch), + "tile_shape": str(op.tile_description.tile_shape), + "epilogue_schedule": str(op.epilogue_schedule), + "kernel_schedule": 
str(op.kernel_schedule), + "element_accumulator": str(op.accumulator_type()), + "op_name": str(op.procedural_name()), + "epilogue_node_names": epilogue_node_names, # type: ignore[dict-item] + "epilogue_node_strs": epilogue_node_strs, # type: ignore[dict-item] + "instruction_shape": str( + op.tile_description.math_instruction.instruction_shape + ), + } + else: + return {"backend": "CUDA", "op_type": "unknown"} + + def output_node(self) -> TensorBox: + return TensorBox.create( + CUDATemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bd904b333c38367b885d280d6c55dbbee2b578a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e41ea62bbe5d7bcb5f1e99fc2b16d976a42658eb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..9e76b5291c59e9698ad3c9fdaa2480abd0332f21 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py @@ -0,0 +1,706 @@ +import copy +import logging +import re +from typing import cast, Dict, List, Optional, Tuple + +from ...config import cuda as inductor_cuda_config +from ...ir import Buffer, CUDATemplateBuffer, FixedLayout, IRNode, Layout +from ..common import IndentedBuffer + +from . import cutlass_utils +from .cuda_kernel import CUDATemplateKernel +from .cuda_template import CUTLASSTemplate +from .cutlass_epilogue_gen import ( + CutlassEVTEpilogueArgumentFormatter, + CutlassEVTEpilogueTypeFormatter, +) + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. 
+extern "C" { +{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { + try { + {{kernel.check_not_null(X)}} + {{kernel.check_not_null(W)}} + {{kernel.check_not_null(Bias)}} + {{kernel.check_not_null(Y)}} + int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}}; + int64_t M = {{kernel.size(X, -2)}}; + int64_t K = {{kernel.size(X, -1)}}; + int64_t N = {{kernel.size(W, -1)}}; + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + + +GEMM_ARGS_CUTLASS_2X = r""" + int64_t batch_stride_x = {{kernel.stride(X, -3)}}; + int64_t row_stride_x = {{kernel.row_or_column_stride(X)}}; + int64_t batch_stride_w = {{kernel.stride(W, -3)}}; + int64_t row_stride_w = {{kernel.row_or_column_stride(W)}}; + int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}}; + int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}}; + int64_t batch_stride_y = {{kernel.stride(Y, -3)}}; + int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}}; + // Initialize GemmUniversalInstance arguments. 
+ arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast(M), + static_cast(N), + static_cast(K) + }, // GemmCoord problem_size + {{split_k if split_k > 1 else 'B'}}, // int batch_count + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // void const * ptr_A + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // void const * ptr_B + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // void const * ptr_C + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // void * ptr_D + batch_stride_x, // int64_t batch_stride_A + batch_stride_w, // int64_t batch_stride_B + batch_stride_bias, // int64_t batch_stride_C + batch_stride_y, // int64_t batch_stride_D + row_stride_x, // typename LayoutA::Stride::LongIndex lda + row_stride_w, // typename LayoutB::Stride::LongIndex ldb + row_stride_bias, // typename LayoutC::Stride::LongIndex ldc + row_stride_y, // typename LayoutC::Stride::LongIndex ldd + }; +""" + + +GEMM_ARGS_CUTLASS_3X = r""" + // Initialize GemmUniversal3xInstance arguments. 
+ arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast({{M}}), + static_cast({{N}}), + static_cast(K), + static_cast(B) + }, // ProblemShape problem_shape + { + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A + { + {{template.cute_int(kernel.stride(X, -2), "stride_x0")}}, + {{template.cute_int(kernel.stride(X, -1), "stride_x1")}}, + {{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}} + }, // StrideA dA + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B + { + {{template.cute_int(kernel.stride(W, -1), "stride_w1")}}, + {{template.cute_int(kernel.stride(W, -2), "stride_w0")}}, + {{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}} + }, // StrideB dB + }, // MainloopArguments mainloop + {{epilogue_arguments}} + }; +""" + +GEMM_ARGS_CUTLASS_3X_EPILOGUE = r""" + // see https://tinyurl.com/4rk89z48 + { + {{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C + { + {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}}, + {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}}, + {{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}} + }, // StrideC dC + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D + { + {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}}, + {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}}, + {{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}} + }, // StrideD dD + }, // EpilogueArguments epilogue +""" + + +class CUTLASSGemmTemplate(CUTLASSTemplate): + """ + CUTLASS GEMM template, which is used to generate CUTLASS GEMM kernels + including those which allow flexible fusions with epilogues. 
+ """ + + def __init__( + self, + input_nodes: List[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[List[int]] = None, + can_fuse_epilogue: Optional[bool] = None, + ): + """ + Args: + input_nodes: input nodes of the kernel + layout: layout of the output node + alpha: alpha value of the GEMM operation + beta: beta value of the GEMM operation + input_reorder: reorder of the input nodes + can_fuse_epilogue: If set to True, will only list and use operators capable of flexible epilogue fusions. + If False, it will not use those. If None, both may be listed, but it will not allow fusions. + Defaults to None + """ + super().__init__("cutlass_gemm", input_nodes, layout, input_reorder) + self.alpha = alpha + self.beta = beta + self.can_fuse_epilogue = can_fuse_epilogue + + @staticmethod + def add_cutlass_gemm_choices( + choices, + layout, + input_nodes, + alpha=1, + beta=0, + input_reorder=None, + fuseable=True, + non_fuseable=True, + ): + if non_fuseable: + if fuseable: + # list both fuseable and non-fuseable ops, and treat them all as non-fuseable + can_fuse_epilogue = False + else: + can_fuse_epilogue = None + + cutlass_template = CUTLASSGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + can_fuse_epilogue=can_fuse_epilogue, + ) + ops = cutlass_template.gen_ops() + for op in ops: + cutlass_template.maybe_append_choice( + choices, + op=op, + ) + else: + ops = [] + if fuseable: + cutlass_template_evt = CUTLASSGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + can_fuse_epilogue=True, + ) + # This will list only ops capable of EVT fusion + ops_evt = cutlass_template_evt.gen_ops() + for op in ops_evt: + cutlass_template_evt.maybe_append_choice( + choices, + op=op, + ) + else: + ops_evt = [] + log.debug( + "Added %d cutlass gemm configs and %d fuseable gemm configs.", + len(ops), + len(ops_evt), + ) + + def header(self) -> IndentedBuffer: + res = 
super().header() + res.splice( + """ + #include "cutlass/gemm/gemm.h" + #include "cutlass/gemm/device/gemm_universal.h" + #include "cutlass/gemm/device/gemm_universal_adapter.h" + #include "cutlass/gemm/kernel/gemm_universal.hpp" + #include "cutlass/gemm/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/default_epilogue.hpp" + #include "cutlass/epilogue/thread/linear_combination.h" + #include "cutlass/gemm/dispatch_policy.hpp" + #include "cutlass/gemm/kernel/tile_scheduler.hpp" + #include "cutlass/util/distribution.h" + #include "cutlass/util/packed_stride.hpp" + #include "cutlass/util/tensor_view_io.h" + """ + ) + return res + + @staticmethod + def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821 + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if torch_layout.stride[-1] == 1: + return cutlass_lib.LayoutType.RowMajor + elif torch_layout.stride[-2] == 1: + return cutlass_lib.LayoutType.ColumnMajor + else: + return None + + @staticmethod + def flip_cutlass_layout( + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821 + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if cutlass_layout == cutlass_lib.LayoutType.RowMajor: + return cutlass_lib.LayoutType.ColumnMajor + else: + return cutlass_lib.LayoutType.RowMajor + + @staticmethod + def layout_match(torch_layout, cutlass_layout) -> bool: + return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout + + @staticmethod + def set_alignment(torch_layout, op_element) -> bool: + alignment = cutlass_utils.get_max_alignment(torch_layout) + if alignment < op_element.alignment: + return False + else: + op_element.alignment = alignment + return True + + @staticmethod + def 
has_tma_epilogue(op) -> bool: + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + result = False + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1] + result = epilogue_schedule_str.lower().startswith("tma") + return result + + @staticmethod + def supports_evt(op: "cutlass_library.gemm_op.GemmOperation") -> bool: # type: ignore[name-defined] # noqa: F821 + """ + returns True if the op is capable of flexible epilogue fusions + using epilogue visitor trees. + + See https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L283-L285 # noqa: B950 + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if op.gemm_kind != cutlass_lib.GemmKind.Universal3x: + return False + if op.epilogue_schedule not in ( + cutlass_lib.EpilogueScheduleType.TmaWarpSpecialized, + cutlass_lib.EpilogueScheduleType.TmaWarpSpecializedCooperative, + ): + return False + + return True + + def render_evt_epilogue_declaration( + self, + template_output_node_name: str, + evt_type_name: str, + epilogue_nodes: List[IRNode], + ) -> str: + """Generates the epilogue for the EVT epilogue fusion""" + return CutlassEVTEpilogueTypeFormatter.ir_to_evt_string( + template_output_node_name, evt_type_name, epilogue_nodes + ) + + def define_gemm_instance( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + output_buffer_name: str, + epilogue_nodes: Optional[List[IRNode]] = None, + ) -> Tuple[str, str]: + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + from torch._inductor.codegen.cuda.cutlass_lib_extensions.gemm_operation_extensions import ( + EmitGemmUniversal3xInstanceWithEVT, + ) + + if op.gemm_kind == 
cutlass_lib.GemmKind.Universal3x: + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + emitter = EmitGemmUniversal3xInstanceWithEVT() + op.epilogue_functor = lambda epilogue_functor_type_name: self.render_evt_epilogue_declaration( + output_buffer_name, epilogue_functor_type_name, epilogue_nodes + ) + else: + emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance() + op_def = emitter.emit(op) + pattern = re.compile(r"\s*struct\s(.*?)\s:") + decl = [line for line in op_def.split("\n") if "struct " in line][-1] + else: + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + raise RuntimeError( + "EVT epilogue fusion is not supported for Cutlass 2.x ops." + ) + emitter = cutlass_gemm_op.EmitGemmInstance() + op_def = emitter.emit(op) + op_def = op_def.replace( + "cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal" + ) + op_def = op_def.replace("false,", "") + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = op_def.split("\n")[2] + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n" + op_type = f"{op_type}_device_type" + return op_def, op_type + + @staticmethod + def should_swap_XW( + bias: IRNode, + beta: float, + ) -> bool: + return True + + # TODO(ipiszy): Check whether it's necessary to swap X/W. + # strides = bias.get_stride() + # if strides[-1] != 1: + # return True + # for stride in strides[:-1]: + # if stride != 0: + # return True + # return False + + @staticmethod + def swap_XW( + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + # Swap X and W in GemmOperation. 
+ new_op = copy.deepcopy(op) + new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout) + new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout) + new_op.A, new_op.B = new_op.B, new_op.A + new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout) + new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout) + return new_op + + def filter_op( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + # Skip simt kernels + if ( + op.tile_description.math_instruction.opcode_class + == cutlass_lib.OpcodeClass.Simt + ): + return None + + # Only keep GemmUniversal kernels + if op.gemm_kind not in { + cutlass_lib.GemmKind.Universal, + cutlass_lib.GemmKind.Universal3x, + }: + return None + # Filter ops by dtypes. + X = self.input_nodes[0] + W = self.input_nodes[1] + accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()], + ) + if not ( + cutlass_utils.dtype_match(X.get_dtype(), op.A.element) + and cutlass_utils.dtype_match(W.get_dtype(), op.B.element) + and cutlass_utils.dtype_match( + self.output_node.get_layout().dtype, op.C.element + ) + and cutlass_utils.dtype_match( + accumulator_torch_dtype, op.accumulator_type() + ) + ): + return None + + # Filter ops by input layouts. + if not ( + self.layout_match(X.get_layout(), op.A.layout) + and self.layout_match(W.get_layout(), op.B.layout) + ): + return None + + # Update op. + op = copy.deepcopy(op) + + # Set output layout. + op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) + + # Filter ops by alignments and set alignments. 
+ if not ( + self.set_alignment(X.get_layout(), op.A) + and self.set_alignment(W.get_layout(), op.B) + and self.set_alignment(self.output_node.get_layout(), op.D) + ): + return None + + # Set epilogue. + # TODO: update epilogue functor according to epilogues. + op.element_epilogue = op.accumulator_type() + + # Set bias layout and alignment. + if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None: + Bias = self.input_nodes[2] + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + if op.gemm_kind != cutlass_lib.GemmKind.Universal3x: + if bias_layout != op.D.layout: + # For cutlass2, bias and output layout must match + return None + else: + op.C.layout = bias_layout + if not self.set_alignment(Bias.get_layout(), op.C): + return None + else: + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op.C.element = cutlass_lib.DataType.void + else: + op.C.layout = op.D.layout + supports_evt: bool = self.supports_evt(op) + if (self.can_fuse_epilogue is not None) and ( + self.can_fuse_epilogue != supports_evt + ): + return None + if inductor_cuda_config.cutlass_only_evt_capable_ops and not supports_evt: + return None + return op + + def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm] + res: Dict[str, cutlass_gemm_op.GemmOperation] = dict() + num_3x_ops = 0 + num_2x_ops = 0 + for op_dict in ops.values(): + for op_list in op_dict.values(): + for op in op_list: + assert isinstance(op, cutlass_gemm_op.GemmOperation) + filter_res = self.filter_op(op) + if ( + filter_res is not None + and res.get(filter_res.configuration_name(), None) is None + ): + res[filter_res.configuration_name()] = filter_res + for op in res.values(): + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + 
num_3x_ops += 1 + else: + num_2x_ops += 1 + log.debug( + "Got cutlass configs: total number of ops: %d, " + "total number of 3x ops: %d, total number of 2x ops: %d", + len(res), + num_3x_ops, + num_2x_ops, + ) + return list(res.values())[: inductor_cuda_config.cutlass_max_profiling_configs] + + def gemm_mode(self) -> str: + sizes = self.output_node.get_size() + if len(sizes) > 2: + return "cutlass::gemm::GemmUniversalMode::kBatched" + else: + return "cutlass::gemm::GemmUniversalMode::kGemm" + + def render_gemm_arguments( + self, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) + + if epilogue_template is not None: + if should_swap_xw: + # Swap + def clone_with_transposed_stride(node: IRNode) -> IRNode: + old_layout = node.get_layout() + new_stride = list(old_layout.stride) + new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2] + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(old_layout.size), + new_stride, + old_layout.offset, + ) + return Buffer(node.get_name(), new_layout) + + new_X = clone_with_transposed_stride(X) + new_W = clone_with_transposed_stride(W) + new_Bias = clone_with_transposed_stride(Bias) + new_Y = clone_with_transposed_stride(Y) + options["X"], options["W"], options["Bias"], options["Y"] = ( + new_W, + new_X, + new_Bias, + new_Y, + ) + options["M"], options["N"] = "N", "M" + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + else: + arguments = self._template_from_string(GEMM_ARGS_CUTLASS_2X).render( + 
split_k=1, **options + ) + return arguments + + def render( # type: ignore[override] + self, + kernel: CUDATemplateKernel, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + template_buffer_node: Optional[CUDATemplateBuffer] = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, + ) -> str: + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + assert self.can_fuse_epilogue and CUTLASSGemmTemplate.supports_evt( + op + ), "op does not support EVT epilogue fusion" + assert ( + template_buffer_node is not None + ), "Template node is required for epilogue fusion" + assert isinstance( + template_buffer_node, CUDATemplateBuffer + ), f"Template node has to be a CUDATemplateBuffer, is type {type(template_buffer_node)}" + assert ( + template_buffer_node.name is not None + ), "Output node has to be a Buffer with a name" + # This is the name of the output of the Matmul, before epilogues are applied. + # it is not necessarily materialized in global memory if we have an epilogue + + template_output_node_name = ( + template_buffer_node.name if template_buffer_node is not None else None + ) + + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + assert isinstance( + op, cutlass_gemm_op.GemmOperation + ), "op argument is required and has to be an instance of GemmOperation" + if template_buffer_node is not None: + self.output_node = template_buffer_node + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + self.output_node = cast(Buffer, epilogue_nodes[-1]) + + assert len(self.input_nodes) >= 2 and self.output_node is not None + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + epilogue_template: Optional[str] = None + should_swap_xw: bool = False + epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), 
ElementComputeEpilogue({self.beta})}}" + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + if Bias is not None and self.has_tma_epilogue(op): + if self.should_swap_XW(Bias, self.beta): + # TMA epilogue requires bias vector in column major to get best perf. + op = self.swap_XW(op) + should_swap_xw = True + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + epilogue_args = ( + CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string( + cast(str, template_output_node_name), epilogue_nodes + ) + ) + epilogue_template = GEMM_ARGS_CUTLASS_3X_EPILOGUE + argument_template = GEMM_ARGS_CUTLASS_3X + else: + # TODO: Support split_k. + argument_template = GEMM_ARGS_CUTLASS_2X + + instance_definition, instance_type = self.define_gemm_instance( + op, cast(str, template_output_node_name), epilogue_nodes + ) + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + epilogue_template=epilogue_template, + argument_template=argument_template, + should_swap_xw=should_swap_xw, + template=self, + kernel=kernel, + instance_definition=instance_definition, + instance_type=instance_type, + input_reorder=self.input_reorder, + epilogue_args=epilogue_args, + ) + res = self._template_from_string(GEMM_TEMPLATE).render(**options) + return res diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py new file mode 100644 index 0000000000000000000000000000000000000000..d94e4723dba3828cff886148eacc5707def8ab5c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py @@ -0,0 +1,799 @@ +from __future__ import annotations + +import collections +import dataclasses +import itertools +import pprint +from typing import Any, Dict, Iterable, List, Optional, Protocol + +import sympy + +import torch +from .. 
import config, ir +from ..utils import cache_on_self, CachedMethod, IndentedBuffer +from ..virtualized import V + +from .wrapper import ( + AllocateLine, + FreeIfNotReusedLine, + MemoryPlanningLine, + NullLine, + ReuseLine, +) + + +ALIGN_BYTES = 64 +assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" + + +def _align(nbytes): + """Round up to the nearest multiple of ALIGN_BYTES""" + return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES + + +def _is_aligned(v: sympy.Expr): + """v can be statically proven to be a multiple of ALIGN_BYTES""" + if isinstance(v, (sympy.Add, sympy.Max)): + return all(map(_is_aligned, v.args)) + return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES + + +class align(sympy.Function): + """Symbolically round up to the nearest multiple of ALIGN_BYTES""" + + nargs = (1,) + is_integer = True + + @classmethod + def eval(cls, value): + if isinstance(value, (int, sympy.Integer)): + return _align(int(value)) + if _is_aligned(value): + return value + + +@dataclasses.dataclass +class LiveRange: + """ + A range where a given tensor is live. Begin and end are both counters + representing points in the program of grouped memory operations. + Begin is inclusive, end is exclusive. + + Invariant: begin <= end + """ + + begin: float # int | ±inf + end: float # int | ±inf + + def contains(self, other: LiveRange): + """Is other entirely within self""" + return self.begin <= other.begin and other.end <= self.end + + def join(self, other: LiveRange): + """Combine two ranges using a union operation""" + return LiveRange(min(self.begin, other.begin), max(self.end, other.end)) + + def __len__(self): + return self.end - self.begin + + +class LiveRanges: + """ + A collection of LiveRange regions, allowing for non-contiguous + live regions. 
+ + Invariant: LiveRanges.ranges is in sorted order and non-overlapping + """ + + def __init__(self, ranges: Iterable[LiveRange]): + ranges = [*sorted(ranges, key=lambda x: x.begin)] + self.ranges = ranges[:1] + for r in ranges[1:]: + assert self.ranges[-1].begin <= r.begin + if self.ranges[-1].end >= r.begin: + self.ranges[-1] = LiveRange.join(self.ranges[-1], r) + else: + self.ranges.append(r) + + def overlaps(self, other: LiveRanges): + """Check if any pair of ranges in self and other overlap""" + left = collections.deque(self.ranges) + right = collections.deque(other.ranges) + while left and right: + if left[0].begin > right[0].begin: + left, right = right, left + assert left[0].begin <= right[0].begin + if left[0].end > right[0].begin: + return True + left.popleft() + return False + + @property + def begin(self): + return self.ranges[0].begin + + @property + def end(self): + return self.ranges[-1].end + + def __repr__(self): + return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])" + + +class AllocationTreeNode: + """ + Abstract base class for nodes in allocation pool. + """ + + def allocate(self, block: Allocation, is_last: bool) -> bool: + """ + Try to assign block to a memory location in this bool. Return True if + an assignment was made. + """ + return False + + def get_live_ranges(self) -> LiveRanges: + """Aggregate LiveRanges for all objects below this in tree""" + raise NotImplementedError() + + def get_size_hint(self) -> int: + """Number of bytes used for example inputs""" + raise NotImplementedError() + + def get_symbolic_size(self) -> sympy.Expr: + """Number of bytes needed at runtime""" + raise NotImplementedError() + + def finalize(self, pool, offset) -> AllocationTreeNode: + """Called after all allocations have been made""" + return self + + def is_empty(self): + return False + + +@dataclasses.dataclass +class Allocation(AllocationTreeNode): + """ + Represents memory allocated to a given node in the allocation pool. 
+ """ + + node: ir.Buffer + live_range: LiveRange + size_hint: int + symbolic_size: sympy.Expr + allocated: bool = False + pool: Optional[AllocationPool] = None + offset: Optional[sympy.Expr] = None + + @property + def device(self): + return self.node.get_device() + + def get_live_ranges(self): + return LiveRanges([self.live_range]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return self.symbolic_size + + def mark_allocated(self): + assert not self.allocated + self.allocated = True + + def finalize(self, pool, offset): + assert self.pool is None and self.offset is None + self.pool = pool + self.offset = offset + return self + + def codegen_alloc_from_pool(self, wrapper): + assert self.pool + node = self.node + shape = tuple(node.get_size()) + stride = tuple(node.get_stride()) + return wrapper.codegen_alloc_from_pool( + self.pool.name, self.offset, node.get_dtype(), shape, stride + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"node={self.node.get_name()}, " + f"live_range={self.live_range}, " + f"size_hint={self.size_hint}, " + f"symbolic_size={self.symbolic_size}, " + f"pool={self.pool.name if self.pool else None}, " + f"offset={self.offset})" + ) + + +@dataclasses.dataclass +class Empty(AllocationTreeNode): + """ + Placeholder to represent empty space in the allocation pool. + Only exists to get the size_hint correct in parent nodes. + """ + + size_hint: int + + def get_live_ranges(self): + return LiveRanges([]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return 0 + + def is_empty(self): + return True + + +class MemorySplitProtocol(Protocol): + get_live_ranges: CachedMethod[[], LiveRanges] + get_size_hint: CachedMethod[[], int] + get_symbolic_size: CachedMethod[[], sympy.Expr] + + def _allocate(self, block: Allocation, is_last: bool) -> bool: + ... 
+ + +class ClearCacheOnAllocateMixin(MemorySplitProtocol): + """ + Helper to assist in caching get_live_ranges, get_size_hint, and + get_symbolic_size. + """ + + def allocate(self, block: Allocation, is_last: bool): + is_allocated = self._allocate(block, is_last) + if is_allocated: + self.clear_cache() + return is_allocated + + def clear_cache(self): + self.get_live_ranges.clear_cache(self) + self.get_size_hint.clear_cache(self) + self.get_symbolic_size.clear_cache(self) + + +@dataclasses.dataclass +class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains a list of allocations not overlapping in LiveRanges. + + Invariant: no pair (a,b) in self.allocations will have: + a.get_live_ranges().overlaps(b.get_live_ranges()) + """ + + allocations: List[AllocationTreeNode] + + def _allocate(self, block: Allocation, is_last: bool): + slot_size = self.get_size_hint() + block_size = block.get_size_hint() + if not is_last and block_size > slot_size: + return False # doesn't fit + + block_live = block.get_live_ranges() + overlapping = [ + s for s in self.allocations if s.get_live_ranges().overlaps(block_live) + ] + if len(overlapping) > 1: + # TODO(jansel): we could try harder here by merging overlapping in space + return False + elif len(overlapping) == 1: + return overlapping[0].allocate(block, is_last) + else: + block.mark_allocated() + + if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty): + self.allocations.pop() + + if slot_size == block_size: + # perfect fit + self.allocations.append(block) + elif slot_size > block_size: + self.allocations.append( + SpatialSplit.create(block, slot_size - block_size) + ) + else: # grow this allocation + assert is_last + self.allocations = [ + *( + SpatialSplit.create(a, block_size - slot_size) + for a in self.allocations + ), + block, + ] + return True + + @cache_on_self + def get_live_ranges(self) -> LiveRanges: + return LiveRanges( + itertools.chain.from_iterable( + 
x.get_live_ranges().ranges for x in self.allocations + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + if not self.allocations: + return 0 + return max(x.get_size_hint() for x in self.allocations) + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + if not self.allocations: + return 0 # type: ignore[return-value] + return sympy.Max(*[x.get_symbolic_size() for x in self.allocations]) + + def is_empty(self): + return len(self.allocations) == 1 and self.allocations[0].is_empty() + + def finalize(self, pool, offset): + self.allocations = [block.finalize(pool, offset) for block in self.allocations] + self.clear_cache() + if len(self.allocations) == 1: + return self.allocations[0] + return self + + +@dataclasses.dataclass +class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains two allocations, left and right, that do not overlap in space. + Right will be allocated immediately after left in memory. + """ + + left: TemporalSplit + right: TemporalSplit + + @staticmethod + def create(left, extra_space): + assert isinstance(left, AllocationTreeNode) + assert isinstance(extra_space, int) and extra_space >= 1 + return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)])) + + def _allocate(self, block: Allocation, is_last: bool): + return self.left.allocate(block, False) or self.right.allocate(block, is_last) + + @cache_on_self + def get_live_ranges(self): + return LiveRanges( + itertools.chain( + self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + return _align(self.left.get_size_hint()) + self.right.get_size_hint() + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size() + + def finalize(self, pool, offset): + self.left = self.left.finalize(pool, offset) + self.right = self.right.finalize( + pool, offset + 
align(self.left.get_symbolic_size()) + ) + self.clear_cache() + if self.right.is_empty(): + return self.left + return self + + +@dataclasses.dataclass +class AllocationPool: + """ + Represents a pool of allocations that will be generated by a single + call to torch.empty. + """ + + device: torch.device + root: TemporalSplit + can_expand: bool = True + restrict_live_range: Optional[LiveRange] = None + name: Optional[str] = None + names_to_del: List[str] = dataclasses.field(default_factory=list) + creation_cache: Dict[str, str] = dataclasses.field(default_factory=dict) + + def allocate(self, block: Allocation, is_last: bool): + if self.restrict_live_range and not self.restrict_live_range.contains( + block.live_range + ): + return False + + is_last = self.can_expand and is_last + if self.root.allocate(block, is_last): + return True + + if is_last: + return self.allocate_at_end(block) + + return False + + def allocate_at_end(self, block): + block.mark_allocated() + self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) + return True + + def finalize(self, name): + assert not self.name + self.name = name + self.names_to_del.append(name) + self.root.finalize(self, 0) + + def codegen_create(self, wrapper, code: IndentedBuffer): + assert self.name + nbytes = self.root.get_symbolic_size() + for block in self.root.allocations: + if isinstance(block, Allocation) and nbytes == block.get_symbolic_size(): + # optimization: fuse first allocation and pool creation + node = block.node + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=node.get_dtype(), + shape=tuple(node.get_size()), + stride=tuple(node.get_stride()), + ) + ) + self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name + return + else: + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=torch.uint8, + shape=(nbytes,), + stride=(1,), + ) + ) + + def codegen_destroy(self, wrapper, code: IndentedBuffer): + 
code.writeline(wrapper.make_free_by_names(self.names_to_del)) + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + +@dataclasses.dataclass +class AllocationPools: + """ + Collection of many AllocationPool objects grouped by device. + """ + + device_to_pools: Dict[torch.device, List[AllocationPool]] = dataclasses.field( + default_factory=dict + ) + + def get_pools(self, block): + if block.device not in self.device_to_pools: + self.device_to_pools[block.device] = [] + return self.device_to_pools[block.device] + + def allocate(self, block: Allocation): + pools = self.get_pools(block) + + for pool in pools: + if pool.allocate(block, is_last=pool is pools[-1]): + return + + # everything is full, make a new pool + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool != "none", + ) + ) + block.mark_allocated() + + def allocate_output(self, block: Allocation): + """Outputs get different pools so memory gets freed properly""" + pools = self.get_pools(block) + if pools and config.memory_pool in ("outputs", "combined"): + pools[-1].allocate_at_end(block) + else: + # create a new pool + block.mark_allocated() + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool == "combined", + ) + ) + + def finalize(self): + """Called at the end of allocation process""" + for i, pool in enumerate( + itertools.chain.from_iterable(self.device_to_pools.values()) + ): + pool.finalize(f"pool{i}") + + def pprint(self): + for pool in itertools.chain.from_iterable(self.device_to_pools.values()): + print() + print(pool.name) + print(pool.root.get_live_ranges()) + pprint.pprint(pool.root) + + +class BufferGroup: + """ + Due to inplace reuse an allocated buffer can have many names. + This tracks these collections of buffers sharing underlying memory. 
+ """ + + def __init__(self, node: ir.Buffer): + self.node = node + self.names = [node.get_name()] + self.is_output = False + self.allocation: Optional[Allocation] = None + self.live_range = LiveRange(float("inf"), -float("inf")) + + def update_usage(self, timestep: int): + """Expand self.live_range to include timestep""" + self.live_range = LiveRange( + min(timestep, self.live_range.begin), + max(timestep, self.live_range.end), + ) + + def sym_nbytes(self): + return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize + + def make_allocation(self): + assert not self.allocation, "multiple allocations" + assert isinstance(self.live_range.begin, int), "live ranges not computed" + nbytes = self.sym_nbytes() + # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have + # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored. + size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64) + self.allocation = Allocation( + self.node, + self.live_range, + size_hint=size_hint, + symbolic_size=nbytes, + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, " + f"live_range={self.live_range}" + ) + + +@dataclasses.dataclass +class PoolMemoryPlanningLine(MemoryPlanningLine): + """Abstract base class for {Alloc,Dealloc}FromPoolLine""" + + group: BufferGroup + timestep: Optional[int] = None + + @property + def node(self): + return self.group.node + + +@dataclasses.dataclass +class AllocFromPoolLine(PoolMemoryPlanningLine): + """Similar to AllocationLine, but takes memory from a pool""" + + is_first_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + allocation = self.group.allocation + assert allocation and allocation.pool + pool = allocation.pool + name = self.node.get_name() + + if self.is_first_pool_usage: + pool.codegen_create(self.wrapper, code) + + 
pool.names_to_del.extend(self.group.names) + alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper) + if alloc_from_pool in pool.creation_cache: + code.writeline( + self.wrapper.make_tensor_alias( + name, pool.creation_cache[alloc_from_pool], "alloc" + ) + ) + else: + pool.creation_cache[alloc_from_pool] = name + code.writeline( + f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}" + ) + + +@dataclasses.dataclass +class DeallocFromPoolLine(PoolMemoryPlanningLine): + """Similar to FreeIfNotReusedLine, but takes memory from a pool""" + + is_last_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + if self.is_last_pool_usage: + assert self.group.allocation and self.group.allocation.pool + self.group.allocation.pool.codegen_destroy(self.wrapper, code) + + +@dataclasses.dataclass +class MemoryPlanner: + """ + Coordination object to run memory planning passes during wrapper + codegen. + """ + + wrapper: Any + pools: AllocationPools = dataclasses.field(default_factory=AllocationPools) + buffer_groups: Optional[List[BufferGroup]] = None + + def plan(self, lines: List[Any]) -> List[Any]: + """Call all the memory planning passes in sequence""" + lines = [*lines] + self.drop_removed_buffers(lines) + self.convert_to_pool_lines(lines) + self.compute_live_ranges(lines) + self.allocate_groups() + self.mark_first_last_usage(lines) + return lines + + def drop_removed_buffers(self, lines): + """ + Replace any memory planning lines in V.graph.removed_buffers with NullLine + """ + # drop any removed buffers + for i, line in enumerate(lines): + if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)): + if line.node.get_name() in V.graph.removed_buffers: + lines[i] = NullLine(self.wrapper) + + def compute_buffer_groups(self, lines): + """ + Populates self.buffer_groups with BufferGroup objects that join + allocations with common storage (due to inplace reuse) into a + single object. 
+ """ + name_to_group = {} + for line in lines: + if isinstance(line, AllocateLine): + name = line.node.get_name() + assert name not in name_to_group + name_to_group[name] = BufferGroup(line.node) + elif isinstance(line, ReuseLine): + old_name = line.node.get_name() + new_name = line.reused_as.get_name() + assert new_name not in name_to_group + # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc + if old_name in name_to_group: + name_to_group[old_name].names.append(new_name) + name_to_group[new_name] = name_to_group[old_name] + + outputs = set(V.graph.get_output_names()) + unique_groups = [*{id(g): g for g in name_to_group.values()}.values()] + for group in unique_groups: + group.is_output = any(x in outputs for x in group.names) + + assert self.buffer_groups is None + self.buffer_groups = unique_groups + return name_to_group + + def convert_to_pool_lines(self, lines): + """ + Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their + pool-based counterparts. 
+ """ + name_to_group = self.compute_buffer_groups(lines) + for i, line in enumerate(lines): + if isinstance(line, AllocateLine): + if line.node.get_name() in name_to_group: + lines[i] = AllocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, FreeIfNotReusedLine): + assert not line.is_reused + if line.node.get_name() in name_to_group: + lines[i] = DeallocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, ReuseLine): + if line.node.get_name() in name_to_group: + line.delete_old = False + + def compute_live_ranges(self, lines): + """Populate every BufferGroup.live_ranges field based on first/last usage""" + timestep = 0 + worklist = collections.deque(lines) + while worklist: + if isinstance(worklist[0], MemoryPlanningLine): + timestep += 1 + while worklist and isinstance(worklist[0], MemoryPlanningLine): + line = worklist.popleft() + if isinstance(line, PoolMemoryPlanningLine): + line.group.update_usage(timestep) + line.timestep = timestep + else: + worklist.popleft() + + timestep += 1 + assert self.buffer_groups is not None + for group in self.buffer_groups: + if group.is_output: + group.update_usage(timestep) + + def allocate_groups(self): + """ + Assign every allocation to a specific location in a specific AllocationPool. 
+ """ + assert config.memory_pool in ("none", "intermediates", "outputs", "combined") + assert self.buffer_groups is not None + + for group in self.buffer_groups: + group.make_allocation() + + outputs: List[Allocation] = [] + intermediates: List[Allocation] = [] + for group in self.buffer_groups: + assert group.allocation + if group.is_output and config.memory_pool != "combined": + outputs.append(group.allocation) + else: + intermediates.append(group.allocation) + + for block in sorted( + outputs, + key=lambda x: ( + x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate_output(block) + + for block in sorted( + intermediates, + key=lambda x: ( + -x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate(block) + + self.pools.finalize() + + def mark_first_last_usage(self, lines): + """ + Populate the AllocFromPoolLine.is_first_pool_usage and + DeallocFromPoolLine.is_last_pool_usage fields so that pools + are created/destroyed. + """ + seen = set() + for line in lines: + if isinstance(line, AllocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_first_pool_usage = True + seen.add(pool) + + seen = set() + for line in reversed(lines): + if isinstance(line, DeallocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_last_pool_usage = ( + pool.root.get_live_ranges().end <= line.timestep + ) + seen.add(pool) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..e03ca8eca95b50e02eaed7c5db2ebd77b5edb440 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py @@ -0,0 +1,413 @@ 
+import logging +import os +from typing import Any, List + +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled + +from .. import config +from ..codecache import PyCodeCache, TritonFuture +from ..utils import cache_on_self, do_bench +from ..virtualized import V +from .common import TensorArg + +log = logging.getLogger(__name__) + + +def get_kernel_argdefs(kernel): + arg_defs, _, _ = kernel.args.python_argdefs() + return arg_defs + + +def _get_all_args(args_list): + all_args = max(args_list, key=len)[:] + for args in args_list: + assert set(args).issubset(set(all_args)), f"{args} v.s. {all_args}" + + return all_args + + +def get_all_kernel_argdefs(kernels): + """ + The logic here must match with `get_all_call_args`. + """ + argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels] + + return _get_all_args(argdefs_list) + + +def get_all_call_args(call_args_list): + """ + Passed in the call_args for each subkernel and return the call_args for the + combined multi-kernel. + + Note an algorithm as follows does not always work: + ``` + all_call_args: Dict[ + Any, None + ] = {} # use a dict rather than set to maintain insertion order + for call_args in call_args_list: + all_call_args.update({arg: None for arg in call_args}) + + all_call_args = list(all_call_args.keys()) + ``` + It will fail if any kernel has the same argument passed in multiple times. + Check test_pass_same_arg_multi_times in test_multi_kernel.py + + Instead, we pick the longest call args and assert that otehr call args are + a subset of it. + """ + return _get_all_args(call_args_list) + + +def get_numel_argdefs(kernel): + numel_argdefs = [] + for tree in kernel.range_trees: + if tree.prefix != "r" or kernel.inside_reduction: + numel_argdefs.append(f"{tree.prefix}numel") + + return numel_argdefs + + +class MultiKernelState: + """ + Maintain state of multi-kernel compilation so we don't define duplicated + multi-kernel for the same set of sub-kernels. 
+ + V.graph.wrapper_code has a reference to MultiKernelState instance. + """ + + def __init__(self): + self.subkernel_to_kernel_name = {} + + def define_kernel(self, kernels): + """ + Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}". + This has some minor issue. + + E.g. for persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca , + there are 2 flavors of non-persistent reduction: + https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4 + and + https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd + + The only different is cache eviction policy. + + We should name the multi-kernel differently in these 2 cases. + """ + kernel_names = tuple(k.kernel_name for k in kernels) + if kernel_names in self.subkernel_to_kernel_name: + return self.subkernel_to_kernel_name[kernel_names] + + # name the multi kernel based on the first kernel + multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}" + self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name + + if V.graph.cpp_wrapper: + # we should not generate any python code for multi-kernel during + # the second pass of cpp-wrapper. + return multi_kernel_name + + wrapper = V.graph.wrapper_code + + kernel_call_def_code = "\n".join( + [ + f""" + def call{idx}(need_clone_args=False): + args = [{', '.join(get_kernel_argdefs(kernels[idx]))}] + if need_clone_args: + args, _ = multi_kernel_call.kernels[{idx}].clone_args(*args) + multi_kernel_call.kernels[{idx}].run(*args, {', '.join(get_numel_argdefs(kernels[idx]))}, grid=grid, stream=stream) + """.format( + idx + ).strip( + "\n" + ) + for idx in range(len(kernels)) + ] + ) + + # add subkernel src code hashes to the multi-kernel source code so changing a + # subkernel implementation will result in a differnt py file for + # multi-kernel. This makes cache implementation straightforward since + # we can decide cache file name based on multi-kernel py file name + # directly. 
+ # + # Without the hash added for subkernels, the cache file may be shared by + # different subkernels which is incorrect. + subkernel_hashes = "\n".join( + f"# subkernel{i} code hash: {kernel.code_hash}" + for i, kernel in enumerate(kernels) + ) + + src_code = f""" +{subkernel_hashes} +def run(multi_kernel_call, {', '.join(get_all_kernel_argdefs(kernels))}, {', '.join(get_numel_argdefs(kernels[0]))}, grid, stream): +{kernel_call_def_code} + multi_kernel_call.run_with_argless_kernels([call0, call1]) + """ # noqa: B950 line too long + wrapper.header.splice( + f""" + {multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [ + {", ".join(kernel_names)}, + ], + ''' + """ + ) + wrapper.header.splice(src_code) + wrapper.header.splice( + """ + ''' + ) + """ + ) + + return multi_kernel_name + + +class MultiKernel: + """ + This class maintains the compile time state for multi kernels. + + Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2. + The generated definition for the multi-kernel will looks like: + ``` + multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code) + ``` + + Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 + """ + + def __init__(self, kernels): + assert len(kernels) >= 2 + + self.kernels = kernels + self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel( + kernels + ) + + # need this since some code in inductor check if the kernel object has an args + # attribute to decide if it's a non-null kernel. + self.args = object() + + def call_kernel(self, kernel_name): + """ + Collect the union of arguments from all subkernels as the arguments + for the multi-kernel. 
+ """ + assert kernel_name == self.kernel_name + call_args_list = [kernel.get_call_args() for kernel in self.kernels] + + all_call_args = get_all_call_args(call_args_list) + grid: List[Any] = [] + + if V.graph.cpp_wrapper: + # for the second pass of cpp-wrapper codegen, we should call + # the fast kernel directly + picked_kernel = MultiKernelCall.lookup_choice(kernel_name) + kernel_name = self.kernels[picked_kernel].kernel_name + final_call_args = call_args_list[picked_kernel] + else: + final_call_args = all_call_args + + # numels for all subkernels should be the same. Use kernels[0] here + self.kernels[0].add_numel_to_call_args_and_grid( + kernel_name, final_call_args, grid + ) + + grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid) + + V.graph.wrapper_code.generate_kernel_call( + kernel_name, + final_call_args, + grid, + V.graph.scheduler.current_device.index, + ) + + def codegen_nan_check(self): + wrapper = V.graph.wrapper_code + seen = set() + for k in self.kernels: + _, call_args, arg_types = k.args.python_argdefs() + for arg, arg_type in zip(call_args, arg_types): + if arg in seen: + continue + seen.add(arg) + if isinstance(arg_type, TensorArg): + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + @property + def removed_buffers(self): + return set.intersection(*[k.removed_buffers for k in self.kernels]) + + @property + def inplaced_to_remove(self): + return set.intersection(*[k.inplaced_to_remove for k in self.kernels]) + + @property + @cache_on_self + def inplace_update_buffers(self): + """ + Make sure all kernels have the same inplace update mappings. 
+ """ + for k in self.kernels[1:]: + assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers + return self.kernels[0].inplace_update_buffers + + def warn_mix_layout(self, kernel_name: str): + pass + + +class MultiKernelCall: + """ + This class is called at run time to actually run the kernel + """ + + def __init__(self, multi_kernel_name, kernels, src_code): + assert len(kernels) >= 2 + self._kernels = kernels + self.multi_kernel_name = multi_kernel_name + + self._run = PyCodeCache.load(src_code).run + self.disable_cache = os.environ.get( + "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE" + ) == "1" or is_metric_table_enabled("persistent_red_perf") + + self.picked_kernel = None + if config.triton.multi_kernel > 1: + # manually force a subkernel to ease perf testing + picked_by_config = config.triton.multi_kernel - 2 + assert picked_by_config < len(self._kernels) + self.picked_kernel = picked_by_config + elif not self.disable_cache: + self.load_cache() + + self._recorded = False + + def cache_file_path(self): + py_file_path = self._run.__globals__["__file__"] + return os.path.splitext(py_file_path)[0] + ".picked_kernel" + + def load_cache(self): + assert self.picked_kernel is None + path = self.cache_file_path() + if os.path.exists(path): + with open(path) as fd: + self.picked_kernel = int(fd.read()) + assert self.picked_kernel >= 0 and self.picked_kernel < len( + self._kernels + ) + log.debug( + "Load picked kernel %d from cache file %s", self.picked_kernel, path + ) + + def store_cache(self): + assert self.picked_kernel is not None + path = self.cache_file_path() + with open(path, "w") as fd: + fd.write(str(self.picked_kernel)) + log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path) + + @property + def kernels(self): + """ + Read results from future. + + This should be called after parallel compilation is done. + In case you call this before compilation is done, + it may slow down the parallel compilation. 
+ """ + for i, kernel in enumerate(self._kernels): + if isinstance(kernel, TritonFuture): + self._kernels[i] = kernel.result() + + return self._kernels + + def run(self, *args, **kwargs): + self._run(self, *args, **kwargs) + + @staticmethod + def benchmark_sub_kernels(kernel_calls): + """ + Benchmark all the sub kernels and return the execution time + (in milliseconds) for each of time. + + Unit test may mock this method to force a specific kernel to + be picked. + """ + return [ + do_bench(lambda: kernel_call(True), rep=40, fast_flush=True) + for kernel_call in kernel_calls + ] + + # record_choice and lookup_choice are helper functions for cpp-wrapper + # codegen. The first pass use record_choice to keep the choice and + # the second pass do lookup by calling lookup_choice. + # + # An alternative that reused the multi-kernel cache does not work well + # since during codegen of the second pass, it's very hard to know the + # path for the cache file. Also reading the cache file need do some IO + # which can be slower. + @staticmethod + def record_choice(multi_kernel_name, choice): + """ + Record the multi-kernel choice for cpp-wrapper first pass codegen + for the second pass. + + We should do nothing if this function is not called during codegen. 
+ """ + from torch._inductor.graph import GraphLowering + + if not isinstance(V.graph, GraphLowering): + return + + if not V.graph.record_multi_kernel_choice: + return + + V.graph.multi_kernel_to_choice[multi_kernel_name] = choice + + @staticmethod + def lookup_choice(multi_kernel_name): + # this should always been done during cpp-wrapper codegen + assert V.graph.record_multi_kernel_choice + # there should be no miss + return V.graph.multi_kernel_to_choice[multi_kernel_name] + + def run_with_argless_kernels(self, kernel_calls): + if self.picked_kernel is None: + timings = self.benchmark_sub_kernels(kernel_calls) + self.picked_kernel = timings.index(min(timings)) + k0 = self.kernels[0] + log.debug( + "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s", + self.picked_kernel, + [k.inductor_meta.get("kernel_name") for k in self.kernels], + k0.size_hints, + k0.inductor_meta.get("reduction_hint"), + timings, + ) + + def get_kernel_path(k): + return k.fn.fn.__code__.co_filename + + get_metric_table("persistent_red_perf").add_row( + lambda: { + "kernel1_name": get_kernel_path(self.kernels[0]), + "kernel2_name": get_kernel_path(self.kernels[1]), + "kernel1_latency": timings[0], + "kernel2_latency": timings[1], + "size_hints": k0.size_hints, + "reduction_hint": k0.inductor_meta.get("reduction_hint"), + "speedup": timings[1] / timings[0], + } + ) + + if not self.disable_cache: + self.store_cache() + + if not self._recorded: + self._recorded = True + self.record_choice(self.multi_kernel_name, self.picked_kernel) + kernel_calls[self.picked_kernel]() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_foreach.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_foreach.py new file mode 100644 index 0000000000000000000000000000000000000000..449af125d89e2165c6b9850fa37e3e5c663398ff --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_foreach.py @@ -0,0 +1,250 @@ +import itertools +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Tuple + +from sympy import Integer + +import torch + +from .. import metrics +from ..scheduler import SchedulerNode +from ..utils import ceildiv, Placeholder +from ..virtualized import V +from .common import IndentedBuffer, Kernel +from .triton import gen_common_triton_imports, TritonKernel +from .triton_utils import config_of, signature_to_meta + + +@dataclass +class PartitionState: + partitions: List[ + List[Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]] + ] + cur_partition: List[ + Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer] + ] + cur_count: int + + def finalize(self): + if self.cur_partition: + self.partitions.append(self.cur_partition) + + +class ForeachKernel(Kernel): + MAX_NUM_ARGS = 250 # number where I would no longer get triton errors + + @staticmethod + def _update_partition(partition_state, node_rw_count, node_info): + if partition_state.cur_count + node_rw_count > ForeachKernel.MAX_NUM_ARGS: + partition_state.partitions.append(partition_state.cur_partition) + partition_state.cur_partition = [node_info] + partition_state.cur_count = node_rw_count + else: + partition_state.cur_count += node_rw_count + partition_state.cur_partition.append(node_info) + + @staticmethod + def horizontal_partition(subkernel_nodes, triton_scheduling): + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + assert len(subkernel_nodes) >= 1 + + partition_state_1d = PartitionState([], [], 0) + yelem_to_partition_state_2d: Dict[Integer, PartitionState] = defaultdict( + lambda: 
PartitionState([], [], 0) + ) + + for node in subkernel_nodes: + fused_nodes = node.get_nodes() + _, (numel, rnumel) = max( + fused_nodes, key=lambda x: int(x.is_reduction()) + ).group + tiled_groups = triton_scheduling.select_tiling(fused_nodes, numel, rnumel) + node_info = fused_nodes, tiled_groups, numel, rnumel + + read_writes = node.read_writes + read_write_count = len(read_writes.reads) + len(read_writes.writes) + + if tiled_groups[1] == 1: + ForeachKernel._update_partition( + partition_state_1d, read_write_count, node_info + ) + else: + y_elem = tiled_groups[0] + partition_state_2d = yelem_to_partition_state_2d[y_elem] + ForeachKernel._update_partition( + partition_state_2d, read_write_count, node_info + ) + + partition_state_1d.finalize() + all_partitions = partition_state_1d.partitions + for partition_state_2d in yelem_to_partition_state_2d.values(): + partition_state_2d.finalize() + all_partitions.extend(partition_state_2d.partitions) + + return all_partitions + + def __init__(self): + super().__init__() + self.blocking_2d = False + self.block_size_1d = 1024 # Try tuning this value + self.block_size_2d = 32 + self.num_warps = 8 + self.sub_kernels = [] + self.iter_vars_count = itertools.count() + self.x_block_count = 0 + self.y_block_count = 0 + + def get_block_size(self): + if self.blocking_2d: + return self.block_size_2d + else: + return self.block_size_1d + + @staticmethod + def codegen_pid_offsets(code, block_count, lower_bound, prefix): + if block_count == 0: + code.splice(f"{prefix}pid_offset = {prefix}pid") + else: + code.splice(f"{prefix}pid_offset = {prefix}pid - {lower_bound}") + + def codegen_pid_range(self, code, x_elems): + num_x_blocks = ceildiv(x_elems, self.get_block_size()) + upper_bound_x_pid = self.x_block_count + num_x_blocks + lower_bound_x_pid = self.x_block_count + + if self.x_block_count == 0: + cond = "if" + else: + cond = "elif" + + x_pid_bounds_check = ( + f"xpid >= {lower_bound_x_pid} and xpid < {upper_bound_x_pid}" + ) + 
code.splice(f"{cond} {x_pid_bounds_check}:") + + with code.indent(): + ForeachKernel.codegen_pid_offsets( + code, num_x_blocks, lower_bound_x_pid, "x" + ) + self.x_block_count += num_x_blocks + + def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint): + sub_kernel = TritonKernel( + *groups, + index_dtype=index_dtype, + mutations=mutations, + pid_cache={ + "tl.program_id(0)": "xpid_offset", + "tl.program_id(1)": "ypid", + }, + reduction_hint=reduction_hint, + ) + if self.blocking_2d: + assert len(groups) == 3 + + self.blocking_2d |= groups[1] != 1 and len(groups) == 3 + metrics.generated_kernel_count -= 1 + sub_kernel.args = self.args + sub_kernel.iter_vars_count = self.iter_vars_count + sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids + self.sub_kernels.append(sub_kernel) + return sub_kernel + + def jit_lines(self): + can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) + size_dtype = "tl.int32" if can_use_32bit else "tl.int64" + _, _, signature = self.args.python_argdefs() + triton_meta = { + "signature": signature_to_meta(signature, size_dtype=size_dtype), + "device": V.graph.scheduler.current_device.index, + "device_type": V.graph.scheduler.current_device.type, + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + inductor_meta = { + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + } + return f""" + @triton_heuristics.foreach( + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + @triton.jit + """ + + def grid(self): + return ( + self.x_block_count, + ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d) + if self.blocking_2d + else 1, + 1, + ) + + def codegen_kernel(self, name=None): + code = IndentedBuffer() + + code.splice(gen_common_triton_imports()) + argdefs, _, _ = self.args.python_argdefs() + code.splice(self.jit_lines()) + code.writeline( + f"def {name or 
str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):" + ) + + with code.indent(): + code.splice("xpid = tl.program_id(0)") + if self.blocking_2d: + code.splice("ypid = tl.program_id(1)") + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}") + code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}") + else: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}") + + for sub_kernel in self.sub_kernels: + assert len(sub_kernel.numels) <= 3 + # TODO mlazos: support dynamic shapes + numel_ind = 0 if not self.blocking_2d else 1 + self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind])) + with code.indent(): + if self.blocking_2d: + code.splice(f"ynumel = {sub_kernel.numels[0]}") + code.splice(f"xnumel = {sub_kernel.numels[1]}") + else: + code.splice(f"xnumel = {sub_kernel.numels[0]}") + + sub_kernel.codegen_body() + code.splice(sub_kernel.body) + + code.splice("else:") + with code.indent(): + code.splice("pass") + + return code.getvalue() + + def call_kernel(self, code, name: str): + _, call_args, _ = self.args.python_argdefs() + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + if V.graph.cpp_wrapper: + V.graph.wrapper_code.generate_kernel_call( + name, + call_args, + device_index=V.graph.scheduler.current_device.index, + grid=self.grid(), + ) + else: + # TODO: refactor generate_kernel_call + call_args_str = ", ".join(call_args) + stream_name = code.write_get_raw_stream( + V.graph.scheduler.current_device.index + ) + code.writeline( + f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})" + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py new file mode 100644 index 
0000000000000000000000000000000000000000..7c52b1c0ad329ca0a633c7d66ea5e21ea80c360a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py @@ -0,0 +1,180 @@ +import functools + +from typing import Optional, Set + +from torch._inductor import config, ir + +from torch._inductor.codegen.triton import ( + IterationRangesRoot, + triton_compute_type, + TritonKernel, + TritonKernelOverrides, +) + +from torch._prims_common import prod + +from torch.utils._sympy.functions import CeilDiv + + +class TritonSplitScanKernel(TritonKernel): + """Generates a triton kernel that supports ops.scan calls while also splitting + the reduction dimension over multiple triton programs. + + For this kernel, loop numels will always take the form ``(xdim, rdim)`` + and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication + between blocks occurs within a global memory workspace buffer, which + must be zero-filled before launching the kernel. + + Note that generation for ``ops.reduction`` is not supported. 
+ + For details of the communication strategy, see + https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + """ + + def __init__( + self, + *groups, + index_dtype: str, + mutations: Optional[Set[str]] = None, + reduction_hint=ir.ReductionHint.DEFAULT, + min_elem_per_thread=0, + ): + super().__init__( + *groups, + index_dtype=index_dtype, + mutations=mutations, + pid_cache=None, + reduction_hint=reduction_hint, + min_elem_per_thread=min_elem_per_thread, + ) + self.no_x_dim = True + + def initialize_range_tree(self, pid_cache): + prefixes = "yxr" + assert len(self.numels) <= len( + prefixes + ), "z dimension not supported for split scan" + active_prefixes = prefixes[len(prefixes) - len(self.numels) :] + + grid_dims = "rxy" + for numel, prefix in zip(self.numels, active_prefixes): + is_reduction = prefix == "r" + tensor_dim = 0 if is_reduction else None + grid_dim = grid_dims.find(prefix) + self.range_trees.append( + IterationRangesRoot( + f"{prefix}index", + numel, + prefix, + grid_dim, + self, + pid_cache=pid_cache, + is_loop=False, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + ) + ) + for tree in self.range_trees: + tree.codegen_header(self.body) + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise NotImplementedError("NYI TritonSplitDimKernel reductions") + + def scan(self, dtype, combine_fn, value, init): + import triton.language as tl + + compute_type = triton_compute_type(dtype) + compute_type_triton = getattr(tl, compute_type[3:]) + + element_nbits = compute_type_triton.primitive_bitwidth + + scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64" + scratch_type_triton = getattr(tl, scratch_type[3:]) + scratch_elems_per_block = 3 if element_nbits == 64 else 1 + scratch_nbytes_per_block = scratch_elems_per_block * ( + scratch_type_triton.primitive_bitwidth // 8 + ) + + cse_load = functools.partial(self.cse.generate, self.loads) + cse_compute = 
functools.partial(self.cse.generate, self.compute) + + assert len(self.numels) == 2, "Unexpected tiling" + min_rblock = config.triton.min_split_scan_rblock + max_blocks = prod(self.numels[:-1]) * CeilDiv(self.numels[-1], min_rblock) + nbytes = scratch_nbytes_per_block * max_blocks + scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True) + if offset != 0: + scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}") + runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})") + scratch_base = cse_load( + f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * " + f"{scratch_elems_per_block} * {runtime_rblocks}" + ) + + masks = {f"{tree.prefix}mask" for tree in self.range_trees} + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + + value = cse_compute(f"{value}.to({compute_type})") + value = cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})") + init = cse_compute(f"tl.full([], {init}, {compute_type})") + if masks: + cond = " & ".join(masks) + masked_value = cse_compute(TritonKernelOverrides.where(cond, value, init)) + else: + masked_value = value + + combine_helper_fn = self._lift_helper(combine_fn, 2) + dim = self.triton_tensor_ndim() - 1 + assert dim == 0, "" + + block_sum = cse_compute( + f"tl.reduce({masked_value}, {dim}, {combine_helper_fn})" + ) + exclusive_prefix = self.cse.newvar() + if element_nbits == 64: + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64( + {scratch_base}, + {block_sum}, + {self.range_trees[-1].get_pid()}, + {combine_helper_fn}, + {init}, + ) + """, + strip=True, + ) + + else: + assert element_nbits <= 32 + value_as_uint_dtype = f"tl.uint{element_nbits}" + + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback( + {scratch_base}, + {block_sum}, + {self.range_trees[-1].get_pid()}, + {combine_helper_fn}, + {init}, + 
DTYPE_VALUE_AS_UINT={value_as_uint_dtype}, + DTYPE_PACK={scratch_type}, + ) + """, + strip=True, + ) + # Compute final cumsum + block_scan = cse_compute( + f"tl.associative_scan({masked_value}, {dim}, {combine_helper_fn})" + ) + return cse_compute(f"{combine_helper_fn}({exclusive_prefix}, {block_scan})") + + def _get_heuristic(self): + return "split_scan" + + def _get_grid_fn(self): + return "split_scan_grid" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f58187ca4a4ee3c78b926a1262aee9c079c01b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py @@ -0,0 +1,130 @@ +from typing import Any, Dict, List, Optional + +import torch + +from .. import config +from ..utils import _type_of, instance_descriptor +from ..virtualized import V +from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg + + +def signature_of(arg: KernelArgType, *, size_dtype: str) -> str: + if isinstance(arg, TensorArg): + # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes. + # Related PR: https://github.com/openai/triton/pull/2279/ + if arg.dtype == torch.float8_e4m3fn: + tye = "*fp8e4nv" + elif arg.dtype == torch.float8_e5m2: + tye = "*fp8e5" + elif arg.dtype == torch.float8_e4m3fnuz: + tye = "*fp8e4b8" + elif arg.dtype == torch.float8_e5m2fnuz: + tye = "*fp8e5b16" + else: + tye = _type_of(arg.dtype) + if V.graph.is_unspec_arg(arg.buffer): + # had unwrapped 0d tensor as scalar + new_tye = tye.lstrip("*") + if new_tye in ["fp16", "bf16"]: + return "fp32" + else: + return new_tye + else: + return tye + if isinstance(arg, SizeArg): + if arg.expr is None: + # From triton/runtime/jit.py + # `None` is nullptr. Implicitly convert to *i8. 
+ return "*i8" + elif isinstance(arg.expr, float): + return "fp32" + if size_dtype == "tl.int32": + return "i32" + elif size_dtype == "tl.int64": + return "i64" + else: + raise NotImplementedError(f"unhandled size_dtype {size_dtype}") + if isinstance(arg, WorkspaceArg): + return "*i8" + raise NotImplementedError(f"unhandled {type(arg)}: {arg}") + + +def signature_to_meta( + signature: List[KernelArgType], + *, + size_dtype: str, + indices: Optional[List[int]] = None, +) -> Dict[int, str]: + if indices is None: + indices = list(range(len(signature))) + return { + i: signature_of(arg, size_dtype=size_dtype) + for i, arg in zip(indices, signature) + } + + +def config_of( + args: List[KernelArgType], + *, + indices: Optional[List[int]] = None, +) -> Any: + if indices is None: + indices = list(range(len(args))) + + def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: + """ + Roughly follow triton code here: + https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222 + """ + if isinstance(x, TensorArg): + if include_tensor: + offset_aligned = V.graph.sizevars.statically_known_multiple_of( + x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type] + ) + return offset_aligned and not V.graph.scheduler.is_unaligned_buffer( + x.buffer + ) + else: + return False + if isinstance(x, SizeArg): + # TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with + # _maybe_evaluate_static... 
+ if x.name.startswith("load_seed_offset"): + return False + if x.expr is None: + return False + if isinstance(x.expr, float): + return False + return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type] + if isinstance(x, WorkspaceArg): + return V.graph.sizevars.statically_known_multiple_of(x.nbytes, alignment) # type: ignore[arg-type] + raise NotImplementedError(f"unhandled {type(x)}: {x}") + + if config.triton.divisible_by_16: + divisible_by_16 = tuple( + i + for i, arg in zip(indices, args) + if is_aligned(arg, alignment=16, include_tensor=True) + ) + else: + divisible_by_16 = () + divisible_by_8 = tuple( + i + for i, arg in zip(indices, args) + if is_aligned(arg, alignment=8, include_tensor=False) + ) + + equal_to_1 = tuple( + i + for i, arg in zip(indices, args) + if isinstance(arg, SizeArg) + and arg.expr is not None + and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type] + ) + # ids_of_folded_args is set from equal_to_1 + # and None args by the Triton compiler + ids_of_folded_args = tuple(equal_to_1) + + return instance_descriptor( + divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8 + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..fa5274f28bbe0034747112fba7e732570249eb8c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py @@ -0,0 +1,1543 @@ +import collections +import contextlib +import dataclasses +import functools +import inspect +import operator +import re +from itertools import count +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy +from sympy import Expr + +import torch +import 
torch._ops +from torch._dynamo.utils import counters, dynamo_timed + +from torch._inductor.codegen.multi_kernel import MultiKernelState +from torch.fx.experimental.symbolic_shapes import SymTypes +from torch.fx.node import _get_qualified_name +from torch.utils._sympy.singleton_int import SingletonInt + +from .. import codecache, config, ir +from ..ir import ReinterpretView +from ..utils import ( + cache_on_self, + get_benchmark_name, + LineContext, + sympy_product, + sympy_str, +) +from ..virtualized import V +from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter +from .triton_utils import config_of, signature_to_meta + +if TYPE_CHECKING: + import triton + + from ..graph import GraphLowering + + +pexpr = PythonPrinter().doprint + + +ReuseKey = Tuple[torch.device, torch.dtype, str] + + +def buffer_reuse_key(node: ir.Buffer) -> ReuseKey: + return ( + node.get_device(), + node.get_dtype(), + # NB: this is symbolic so that we don't try to reuse a buffer + # for s0 for s1, just because they happen to share the same + # size hint + sympy_str(V.graph.sizevars.simplify(node.layout.storage_size())), + ) + + +def convert_arg_type(arg: torch.Argument) -> str: + from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP + + # use x.real_type instead of x.type so that we get ScalarType instead of int + python_type = repr(arg.real_type) # type: ignore[attr-defined] + + if python_type == "Tensor": + # Conversions rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func + if arg.alias_info is not None and arg.alias_info.is_write: + return f"at::{python_type}&" + else: + return f"at::{python_type} const&" + + if python_type in PYTHON_TO_CPP: + cpp_type = PYTHON_TO_CPP[python_type] + return cpp_type + + # Convert args of container types e.g. 
Optional[*] + for py_container, cpp_container in CONTAINER_PYTHON_TO_CPP.items(): + container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type) + if len(container_match) == 1: + contained_type = container_match[0] + assert ( + contained_type in PYTHON_TO_CPP + ), f"unsupported {py_container} type in convert_arg_type: {contained_type}" + cpp_contained_type = PYTHON_TO_CPP[contained_type] + return f"{cpp_container}<{cpp_contained_type}>" + + raise AssertionError(f"unsupport python_type: {python_type}") + + +def convert_return_type(ret: torch.Argument) -> str: + # use x.real_type instead of x.type so that we get ScalarType instead of int + python_type = repr(ret.real_type) # type: ignore[attr-defined] + python_to_cpp = { + "Tensor": "at::Tensor", + "List[Tensor]": "std::vector", + } + + cpp_type = python_to_cpp.get(python_type, None) + assert cpp_type is not None, f"NYI return type: {python_type}" + # An output aliasing an input is returned by reference only when it's a + # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output + # aliases the input tensor, but the op returns a vector by value. 
+ if python_type == "Tensor" and ret.alias_info is not None: + cpp_type += "&" + return cpp_type + + +def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str: + args = kernel._schema.arguments + returns = kernel._schema.returns + + num_returns = len(returns) + assert num_returns > 0, "must have at least one return value" + + if num_returns == 1: + cpp_return_value = convert_return_type(returns[0]) + elif num_returns > 1: + tuple_returns = ", ".join([convert_return_type(r) for r in returns]) + cpp_return_value = f"std::tuple<{tuple_returns}>" + + cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args] + return f"{cpp_return_value}({', '.join(cpp_arg_type)})" # type: ignore[possibly-undefined] + + +# TODO: Move to a well known place +TritonMetaParams = Dict[str, int] +TritonGrid = Union[ + Tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], Tuple[int, ...]] +] + + +def user_defined_kernel_grid_fn_code( + name: str, + configs: List["triton.Config"], + grids: List[TritonGrid], + wrapper: Optional["WrapperCodeGen"] = None, +) -> Tuple[str, str]: + output = IndentedBuffer() + + def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr: + return item if isinstance(item, sympy.Expr) else sympy.Integer(item) + + def determine_grid(grid: TritonGrid): + if wrapper is None or callable(grid): + # return as-is when used in eager mode or when grid is callable + return grid + # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen + sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid) + return wrapper.codegen_shape_tuple(sympy_grid) + + fn_name = f"grid_wrapper_for_{name}" + output.writeline(f"def {fn_name}(meta):") + with output.indent(): + if len(grids) == 1: + grid = determine_grid(grids[0]) + output.writeline(f"return {grid}") + else: + assert len(grids) > 1 + assert len(grids) == len(configs) + seen = set() + for grid, c in zip(grids, configs): + guards = [f"meta['{name}'] == {val}" for name, val in 
c.kwargs.items()] + guards = " and ".join(guards) + grid = determine_grid(grid) + statement = f"if {guards}: return {grid}" + if statement in seen: + continue + seen.add(statement) + output.writeline(statement) + + return fn_name, output.getvalue() + + +@dataclasses.dataclass +class SymbolicCallArg: + inner: str + # the original symbolic expression represented by inner + inner_expr: sympy.Expr + + def __str__(self): + return str(self.inner) + + +# Default thread stack sizes vary by platform: +# - Linux: 8 MB +# - macOS: 512 KB +# - Windows: 1 MB +# Just pick something comfortably smaller than the smallest for now. +MAX_STACK_ALLOCATION_SIZE = 1024 * 100 + + +class MemoryPlanningState: + def __init__(self): + super().__init__() + self.reuse_pool: Dict[ + ReuseKey, List[FreeIfNotReusedLine] + ] = collections.defaultdict(list) + self.total_allocated_buffer_size: int = 0 + + def __contains__(self, key: ReuseKey) -> bool: + return bool(self.reuse_pool.get(key, None)) + + def pop(self, key: ReuseKey) -> "FreeIfNotReusedLine": + item = self.reuse_pool[key].pop() + assert not item.is_reused + return item + + def push(self, key: ReuseKey, item: "FreeIfNotReusedLine") -> None: + assert not item.is_reused + self.reuse_pool[key].append(item) + + +class WrapperLine: + pass + + +@dataclasses.dataclass +class EnterSubgraphLine(WrapperLine): + wrapper: "WrapperCodeGen" + graph: "GraphLowering" + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.push_codegened_graph(self.graph) + code.do_indent() + + +@dataclasses.dataclass +class ExitSubgraphLine(WrapperLine): + wrapper: "WrapperCodeGen" + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.pop_codegened_graph() + code.do_unindent() + + +@dataclasses.dataclass +class EnterDeviceContextManagerLine(WrapperLine): + device_idx: int + last_seen_device_guard_index: Optional[int] + + def codegen(self, code: IndentedBuffer) -> None: + if V.graph.cpp_wrapper: + code.writeline("\n") + if V.graph.aot_mode: + # 
In AOT mode, we have a stream provided as a param. A stream is + # associated with a device, so we never expect the device to change. + # CUDAStreamGuard sets the stream and the device. + if self.last_seen_device_guard_index is None: + if config.abi_compatible: + code.writeline( + "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);" + ) + else: + code.writeline( + "at::cuda::CUDAStreamGuard stream_guard(" + + "at::cuda::getStreamFromExternal(stream, this->device_idx_));" + ) + else: + assert ( + self.last_seen_device_guard_index == self.device_idx + ), "AOTInductor only supports running on one CUDA device" + else: + if self.last_seen_device_guard_index is None: + code.writeline( + f"AOTICudaGuard device_guard({self.device_idx});" + if config.abi_compatible + else f"at::cuda::CUDAGuard device_guard({self.device_idx});" + ) + else: + code.writeline(f"device_guard.set_index({self.device_idx});") + else: + # Note _DeviceGuard has less overhead than device, but only accepts + # integers + code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:") + code.do_indent() + code.writeline(V.graph.device_ops.set_device(self.device_idx)) + + +class ExitDeviceContextManagerLine(WrapperLine): + def codegen(self, code: IndentedBuffer) -> None: + if not V.graph.cpp_wrapper: + code.do_unindent() + + +@dataclasses.dataclass +class MemoryPlanningLine(WrapperLine): + wrapper: "WrapperCodeGen" + + def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine": + """First pass to find reuse""" + return self + + def codegen(self, code: IndentedBuffer) -> None: + """Second pass to output code""" + pass + + def __str__(self) -> str: + """ + Emits a string representation that fits on one line. 
+ """ + args: List[str] = [] + for field in dataclasses.fields(self): + if field.name == "wrapper": + continue + val = getattr(self, field.name) + args.append( + f"{field.name}={val.get_name() if field.type is ir.Buffer else val}" + ) + return f"{type(self).__name__}({', '.join(args)})" + + +@dataclasses.dataclass +class AllocateLine(MemoryPlanningLine): + node: ir.Buffer + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + + # try to reuse a recently freed buffer + key = buffer_reuse_key(self.node) + if config.allow_buffer_reuse and key in state: + free_line = state.pop(key) + free_line.is_reused = True + return ReuseLine(self.wrapper, free_line.node, self.node) + + if self.node.get_device().type == "cpu": + static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node) + if static_shape is not None: + state.total_allocated_buffer_size += int( + functools.reduce(operator.mul, static_shape, 1) + ) + + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + line = self.wrapper.make_buffer_allocation(self.node) + code.writeline(line) + + +@dataclasses.dataclass +class FreeIfNotReusedLine(MemoryPlanningLine): + node: ir.Buffer + is_reused: bool = False + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if isinstance(self.node.layout, (ir.AliasedLayout, ir.MultiOutputLayout)): + return self + assert not self.is_reused + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + if config.allow_buffer_reuse: + state.push(buffer_reuse_key(self.node), self) + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + if not self.is_reused: + code.writeline(self.wrapper.make_buffer_free(self.node)) + + +@dataclasses.dataclass +class ReuseLine(MemoryPlanningLine): + node: 
ir.Buffer + reused_as: ir.Buffer + delete_old: bool = True + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + assert self.reused_as.get_name() in V.graph.removed_buffers + return NullLine(self.wrapper) + assert self.reused_as.get_name() not in V.graph.removed_buffers + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + assert self.reused_as.get_name() not in V.graph.removed_buffers + code.writeline( + self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old) + ) + + +class NullLine(MemoryPlanningLine): + pass + + +BufferName = str + + +class WrapperCodeGen(CodeGen): + """ + Generate outer wrapper in Python that calls the kernels. + """ + + def __init__(self): + super().__init__() + self._names_iter: Iterator[int] = count() + self.header = IndentedBuffer() + self.prefix = IndentedBuffer() + self.suffix = IndentedBuffer() + self.wrapper_call = IndentedBuffer() + # If the generated source code is exactly the same, reuse the + # pre-existing kernel for it + self.src_to_kernel: Dict[str, str] = {} + self.kernel_numel_expr: Set[Tuple[str, "GraphLowering"]] = set() + self.lines: List[Union[MemoryPlanningLine, LineContext]] = [] + self.declare = "" + self.declare_maybe_reference = "" + self.ending = "" + self.open_bracket = "[" + self.closed_bracket = "]" + self.comment = "#" + self.namespace = "" + self.none_str = "None" + self.size = "size()" + self.stride = "stride()" + self.last_seen_device_guard_index: Optional[int] = None + self.supports_intermediate_hooks = True + self.expr_printer = pexpr + self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {} + self.unbacked_symbol_decls: Set[str] = set() # str of sympy.Symbol + self.allow_stack_allocation: Optional[bool] = None + self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {} + self.computed_sizes: Set[sympy.Symbol] = set() + 
+ # this is used for tracking which GraphLowering instance---parent graph + # or (nested) subgraph---is currently codegened; the primary use case is + # including the graph instance into a cache key to avoid cross-graph + # caching during lowering of nested subgraphs + self.codegened_graph_stack = [V.graph] + + self.write_header() + self.write_prefix() + + if not V.graph.aot_mode: + for name, hashed in V.graph.constant_reprs.items(): + # include a hash so our code cache puts different constants into different files + self.write_constant(name, hashed) + + self.allocated: Set[BufferName] = set() + self.freed: Set[BufferName] = set() + + # maps from reusing buffer to reused buffer + self.reuses: Dict[BufferName, BufferName] = dict() + + self.write_get_raw_stream = functools.lru_cache(None)( # type: ignore[assignment] + self.write_get_raw_stream + ) + + @functools.lru_cache(None) + def add_import_once(line: str) -> None: + self.header.writeline(line) + + self.add_import_once = add_import_once + self._metas: Dict[str, str] = {} + self.multi_kernel_state = MultiKernelState() + + def write_constant(self, name: str, hashed: str) -> None: + self.header.writeline(f"{name} = None # {hashed}") + + def write_header(self) -> None: + self.header.splice( + f""" + from ctypes import c_void_p, c_long + import torch + import math + import random + import os + import tempfile + from math import inf, nan + from torch._inductor.hooks import run_intermediate_hooks + from torch._inductor.utils import maybe_profile + from torch._inductor.codegen.memory_planning import _align as align + + from torch import device, empty_strided + from {codecache.__name__} import AsyncCompile + from torch._inductor.select_algorithm import extern_kernels + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + aten = torch.ops.aten + inductor_ops = torch.ops.inductor + assert_size_stride = torch._C._dynamo.guards.assert_size_stride + empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu 
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + alloc_from_pool = torch.ops.inductor._alloc_from_pool + reinterpret_tensor = torch.ops.inductor._reinterpret_tensor + async_compile = AsyncCompile() + + """ + ) + + @cache_on_self + def write_triton_header_once(self) -> None: + self.header.splice( + """ + import triton + import triton.language as tl + from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph + {} + """.format( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + ) + + def add_meta_once(self, meta: TritonMetaParams) -> str: + meta = repr(meta) + if meta not in self._metas: + var = f"meta{len(self._metas)}" + self._metas[meta] = var + self.header.writeline(f"{var} = {meta}") + return self._metas[meta] + + @cache_on_self + def get_output_refs(self) -> List[str]: + return [x.codegen_reference(self.wrapper_call) for x in V.graph.graph_outputs] + + def mark_output_type(self) -> None: + return + + def codegen_input_size_asserts(self) -> None: + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. 
+ if sympy_product(buf.get_size()) == 0: + continue + size = self.codegen_shape_tuple(buf.get_size()) + stride = self.codegen_shape_tuple(buf.get_stride()) + self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})") + + def codegen_input_nan_asserts(self) -> None: + self.prefix.writeline("# make sure graph inputs are not nan/inf") + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + line = f"assert not {name}.isnan().any().item()" + self.prefix.writeline(line) + line = f"assert not {name}.isinf().any().item()" + self.prefix.writeline(line) + + def write_prefix(self) -> None: + self.prefix.splice( + """ + + async_compile.wait(globals()) + del async_compile + + def call(args): + """ + ) + with self.prefix.indent(): + if config.triton.debug_sync_graph: + self.prefix.writeline(V.graph.device_ops.synchronize()) + if V.graph.graph_inputs: + lhs = ", ".join(V.graph.graph_input_names) + if len(V.graph.graph_input_names) == 1: + lhs += "," + self.prefix.writeline(f"{lhs} = args") + self.prefix.writeline("args.clear()") + + self.codegen_inputs(self.prefix, V.graph.graph_inputs) + if config.size_asserts: + self.codegen_input_size_asserts() + if config.nan_asserts: + self.codegen_input_nan_asserts() + + # this function (and below) takes a graph as input so + # that stream caching happens per graph instance. this + # is important for nested subgraph codegening. 
+ def write_get_raw_stream(self, device_idx: int, graph=None) -> str: + self.write_triton_header_once() + name = f"stream{device_idx}" + self.writeline(f"{name} = get_raw_stream({device_idx})") + return name + + def get_codegened_graph(self): + return self.codegened_graph_stack[-1] + + def push_codegened_graph(self, graph): + self.codegened_graph_stack.append(graph) + + def pop_codegened_graph(self): + return self.codegened_graph_stack.pop() + + def next_kernel_suffix(self) -> str: + return f"{next(self._names_iter)}" + + def codegen_device_guard_enter(self, device_idx: int) -> None: + self.writeline( + EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index) + ) + self.last_seen_device_guard_index = device_idx + + def codegen_device_guard_exit(self) -> None: + self.writeline(ExitDeviceContextManagerLine()) + + def generate_return(self, output_refs: List[str]) -> None: + if output_refs: + self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") + else: + self.wrapper_call.writeline("return ()") + + def generate_before_suffix(self, result: IndentedBuffer) -> None: + return + + def generate_end(self, result: IndentedBuffer) -> None: + return + + def generate_fallback_kernel(self, fallback_kernel, args): + self.generate_extern_kernel_alloc(fallback_kernel, args) + + def generate_extern_kernel_alloc(self, extern_kernel, args): + output_name = extern_kernel.get_name() + origin_node = extern_kernel.get_origin_node() + kernel_name = extern_kernel.get_kernel_name() + ending = self.ending + if config.memory_planning and "view_as_complex" in kernel_name: + # view operation fallbacks cause issues since inductor + # doesn't know the memory is still needed and might reuse it. 
+ ending = f".clone(){ending}" + self.writeline( + f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}" + ) + if ( + self.supports_intermediate_hooks + and config.generate_intermediate_hooks + and origin_node is not None + ): + counters["inductor"]["intermediate_hooks"] += 1 + self.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {output_name})" + ) + + def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel): + if output_view: + args.append(f"out={output_view.codegen_reference()}") + else: + args.append(f"out={codegen_reference}") + self.writeline(f"{kernel}({', '.join(args)})") + + def generate_user_defined_triton_kernel( + self, kernel_name, grid, configs, args, triton_meta + ): + grid, code = user_defined_kernel_grid_fn_code( + kernel_name, configs, grid, wrapper=self + ) + # Must happen after free symbols are already codegened + # Emit the grid wrapper function right before the call + for line in code.split("\n"): + self.writeline(line) + + stream_name = self.write_get_raw_stream( + V.graph.scheduler.current_device.index, V.graph + ) + self.writeline( + f"{kernel_name}.run({', '.join(args)}, grid={grid}, stream={stream_name})" + ) + + def generate_scatter_fallback( + self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs + ): + line = f"{kernel}({','.join(map(str, inputs))}" + if kernel == "aten.scatter_": + if reduce: + line += f", reduce={repr(reduce)}" + else: + line += ", ".join([""] + kwargs) + line += f"){self.ending}" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}" + args = [x, indices_str, values, accumulate] + self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_extern_kernel_alloc_and_find_schema_if_needed( + self, + name, + kernel, + codegen_args, + cpp_op_schema, + cpp_kernel_key, + cpp_kernel_overload_name="", + 
@dynamo_timed
def generate(self, is_inference):
    """Assemble the complete wrapper source from the buffered pieces.

    The result is built from ``self.header``, ``self.prefix``, the body
    accumulated in ``self.wrapper_call`` and ``self.suffix``, in that order,
    followed by the optional benchmark harness.

    Args:
        is_inference: when truthy and ``config.memory_planning`` is enabled,
            ``self.memory_plan()`` is used and stack allocation is turned
            off; otherwise ``self.memory_plan_reuse()`` is used.

    Returns:
        Whatever ``result.getvaluewithlinemap()`` returns for the assembled
        buffer.
    """
    if config.profile_bandwidth:
        self.write_triton_header_once()
    result = IndentedBuffer()
    result.splice(self.header)

    # ExitStack keeps every indentation context entered below open until
    # the wrapper body has been fully written.
    with contextlib.ExitStack() as stack:
        stack.enter_context(self.wrapper_call.indent())
        if config.profiler_mark_wrapper_call:
            # May push an additional indentation level onto ``stack``.
            self.generate_profiler_mark_wrapper_call(stack)
        if config.profile_bandwidth:
            self.generate_start_graph()

        # We disable planning during training because it presently increases peak memory consumption.
        if is_inference and config.memory_planning:
            self.memory_plan()
            # TODO: integrate memory planning & stack allocation?
            self.allow_stack_allocation = False
        else:
            self.memory_plan_reuse()

        if config.triton.store_cubin:
            self.generate_reset_kernel_saved_flags()

        # Flush the recorded lines: WrapperLine objects render themselves,
        # anything else is written as-is.
        for line in self.lines:
            if isinstance(line, WrapperLine):
                line.codegen(self.wrapper_call)
            else:
                self.wrapper_call.writeline(line)

        output_refs = self.get_output_refs()
        self.mark_output_type()
        if config.triton.debug_sync_graph:
            self.wrapper_call.writeline(V.graph.device_ops.synchronize())

        if config.profile_bandwidth:
            self.generate_end_graph()

        if config.triton.store_cubin:
            self.generate_save_uncompiled_kernels()

        self.generate_return(output_refs)

    # The prefix is finalized only after the body has been generated.
    self.finalize_prefix()
    result.splice(self.prefix)

    with result.indent():
        result.splice(self.wrapper_call)

    self.generate_before_suffix(result)
    result.splice(self.suffix)

    self.generate_end(result)

    self.add_benchmark_harness(result)

    return result.getvaluewithlinemap()
def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
    """Write the declaration of ``<name>_stride`` into ``code``.

    The variable is initialised from ``<name>.<self.stride>``; the statement
    is prefixed with ``self.declare`` and terminated with ``self.ending``.
    """
    declared_var = f"{self.declare}{name}_stride"
    stride_source = f"{name}.{self.stride}"
    code.writeline(f"{declared_var} = {stride_source}{self.ending}")
def ensure_size_computed(self, sym: sympy.Symbol):
    """Emit the definition of a ``ps``-prefixed size symbol at most once.

    Does nothing unless ``sym`` is a ``sympy.Symbol`` whose name starts with
    ``"ps"`` and which is not yet in ``self.computed_sizes``. Otherwise the
    symbol is recorded and a line assigning it the printed form of its entry
    in ``V.graph.sizevars.inv_precomputed_replacements`` is written.
    """
    if not isinstance(sym, sympy.Symbol):
        return
    if not sym.name.startswith("ps"):
        return
    if sym in self.computed_sizes:
        # Already defined earlier in the generated code.
        return

    self.computed_sizes.add(sym)
    definition = V.graph.sizevars.inv_precomputed_replacements[sym]
    self.writeline(
        f"{self.declare}{sym} = {self.expr_printer(definition)}{self.ending}"
    )
def codegen_dynamic_scalar(self, node):
    """Emit code that reads a scalar out of the node's single input.

    ``node.sym`` is assigned ``<input>.item()``, mapped to ``1``/``0`` when
    ``node.is_bool`` is set. The node's own buffer name is then assigned
    ``None``.
    """
    # The node must have exactly one input; unpacking enforces that.
    [input_ref] = [inp.codegen_reference() for inp in node.inputs]
    item_expr = f"{input_ref}.item()"
    rhs = f"1 if {item_expr} else 0" if node.is_bool else item_expr
    self.writeline(f"{node.sym} = {rhs}")
    # No one should ever use this buffer, but for uniformity
    # define the variable and assign it None
    self.writeline(f"{node.get_name()} = None")
def define_kernel(
    self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
):
    """Append a kernel definition ``<name> = <kernel>`` to ``self.header``.

    Args:
        name: identifier the kernel source is bound to.
        kernel: source text of the kernel definition.
        metadata: optional text placed on its own line directly above the
            definition.
        cuda: accepted but not used by this implementation.
    """
    pieces = ["\n\n"]
    if metadata:
        pieces.append(f"{metadata}\n")
    pieces.append(f"{name} = {kernel}")
    self.header.splice("".join(pieces))
TensorArg + + signature: List[KernelArgType] = [] + constants: Dict[int, Any] = {} + non_constant_indices = [] + equal_to_1_arg_idx: List[int] = [] + for idx, key in enumerate(kernel.arg_names): + if key not in kwargs: + continue + arg = kwargs[key] + if idx in kernel.constexprs: + constants[idx] = arg + else: + non_constant_indices.append(idx) + if isinstance(arg, ir.Buffer): + signature.append( + TensorArg( + name=key, + buffer=arg.get_name(), + dtype=arg.get_dtype(), + ) + ) + elif isinstance(arg, ir.ReinterpretView): + # for ReinterpretView we use the underlying + # buffer name and note the (possibly non-zero) + # offset relative to the underlying buffer + signature.append( + TensorArg( + name=key, + buffer=arg.data.get_name(), + dtype=arg.get_dtype(), + offset=arg.layout.offset, + ) + ) + else: + signature.append(SizeArg(key, arg)) + if arg is not None and V.graph.sizevars.statically_known_equals(arg, 1): # type: ignore[arg-type] + equal_to_1_arg_idx.append(idx) + index_dtype = "tl.int32" + triton_meta = { + "signature": signature_to_meta( + signature, + size_dtype=index_dtype, + indices=non_constant_indices, + ), + "device": V.graph.scheduler.current_device.index, + "device_type": V.graph.scheduler.current_device.type, + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. + # TODO(aakhundov): add None args to constants, too. currently, this + # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input. 
+ # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + "constants": { + **constants, + **{idx: 1 for idx in equal_to_1_arg_idx}, + }, + "configs": [ + config_of( + signature, + indices=non_constant_indices, + ) + ], + } + + # Distinguish between different functions using function id + cache_key: List[Any] = [id(kernel.fn)] + if len(configs) > 0: + for arg in kwargs.values(): + # We need to key on non tensor arg only in autotune mode + if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)): + cache_key.append(arg) + cache_key.append(str(triton_meta)) + cache_key = tuple(cache_key) + + if cache_key in self.user_defined_kernel_cache: + return self.user_defined_kernel_cache[cache_key] + + name = f"{original_name}_{len(self.user_defined_kernel_cache)}" + # Add to the cache for the next use + self.user_defined_kernel_cache[cache_key] = (name, triton_meta) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") + + from .triton import gen_common_triton_imports + + compile_wrapper.splice(gen_common_triton_imports()) + + inductor_meta = { + "kernel_name": name, + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + } + + configs = [ + { + "kwargs": config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + } + for config in configs + ] + + compile_wrapper.splice( + f""" + @triton_heuristics.user_autotune( + configs={configs!r}, + inductor_meta={inductor_meta!r}, + triton_meta={triton_meta!r}, + filename=__file__, + custom_kernel=True, + ) + @triton.jit + """ + ) + compile_wrapper.splice(kernel.src, strip=True) + + # Also include any possible kernel being called indirectly + from triton import JITFunction + + symbols_included = {original_name} + + def traverse(cur_kernel): + for symbol_name in cur_kernel.fn.__code__.co_names: + if 
symbol_name in symbols_included: + continue + if symbol_name in cur_kernel.fn.__globals__: + symbol = cur_kernel.fn.__globals__[symbol_name] + if isinstance(symbol, JITFunction): + compile_wrapper.newline() + compile_wrapper.writeline("@triton.jit") + compile_wrapper.splice(symbol.src, strip=True) + symbols_included.add(symbol_name) + traverse(symbol) + elif isinstance(symbol, (int, str, bool)): + compile_wrapper.newline() + compile_wrapper.writeline(f"{symbol_name} = {symbol!r}") + symbols_included.add(symbol_name) + + traverse(kernel) + + compile_wrapper.writeline( + f"''', device_str='{V.graph.scheduler.current_device.type}')" + ) + _, lineno = inspect.getsourcelines(kernel.fn) + srcfile = inspect.getsourcefile(kernel.fn) + metadata = f"# Original path: {srcfile}:{lineno}" + self.define_kernel( + name, + compile_wrapper.getvalue(), + metadata, + ) + return name, triton_meta + + def generate_numel_expr(self, kernel_name: str, tree): + expr = f"{kernel_name}_{tree.prefix}numel" + if (expr, V.graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((expr, V.graph)) + self.writeline( + f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}" + ) + else: + self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}") + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. 
def generate_workspace_allocation(self, nbytes, device, zero_fill):
    """Emit the allocation of a ``uint8`` buffer named ``workspace``.

    Args:
        nbytes: length of the one-dimensional buffer.
        device: device passed through to ``self.make_allocation``.
        zero_fill: when truthy, also emit ``workspace.zero_()``.
    """
    statements = [
        self.make_allocation(
            "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,)
        )
    ]
    if zero_fill:
        statements.append(f"workspace.zero_(){self.ending}")
    for statement in statements:
        self.writeline(statement)
def generate_kernel_call(
    self,
    name,
    call_args,
    grid=None,
    device_index=None,
    cuda=True,
    triton=True,
    arg_types=None,
    grid_fn: str = "grid",
    triton_meta=None,
):
    """
    Generates kernel call code.

    cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.

    triton: Defines whether the GPU backend uses Triton for codegen.
            Otherwise it uses the CUDA language for codegen.
            Only valid when cuda == True.

    ``device_index``, ``arg_types`` and ``triton_meta`` are accepted but not
    read by this implementation.
    """
    if not cuda:
        # CPU backend: a plain function call.
        self.writeline(self.wrap_kernel_call(name, call_args))
        return

    call_args_str = ", ".join(pexpr(arg) for arg in call_args)
    # The stream variable must be emitted before the launch line.
    stream_name = self.write_get_raw_stream(
        V.graph.scheduler.current_device.index, V.graph
    )

    if not triton:
        stream_ptr = f"c_void_p({stream_name})"
        self.writeline(f"{name}.{name}({call_args_str}, {stream_ptr})")
        return

    grid_args = ", ".join(pexpr(dim) for dim in grid)
    grid_str = f"{grid_fn}({grid_args})"
    self.writeline(
        f"{name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
    )
def make_allocation(self, name, device, dtype, shape, stride):
    """Return the source line that allocates buffer ``name``.

    For ``cpu`` and ``cuda`` devices the device-specific
    ``empty_strided_<type>`` helper is emitted; every other device type uses
    the generic ``empty_strided`` call with an explicit ``device=`` keyword.
    """
    shape_code = self.codegen_shape_tuple(shape)
    stride_code = self.codegen_shape_tuple(stride)
    device_type = device.type

    if device_type == "cpu" or device_type == "cuda":
        # optimized path for faster allocations, saving ~2us versus the stuff below
        return (
            f"{name} = empty_strided_{device_type}("
            f"{shape_code}, {stride_code}, {dtype})"
        )

    # all other devices:
    return (
        f"{name} = empty_strided("
        f"{shape_code}, {stride_code}, "
        f"device='{device_type}', dtype={dtype})"
    )
def codegen_allocation(self, buffer):
    """Record ``buffer`` as allocated and emit its allocation if one is needed.

    Nothing is done for buffers listed in ``V.graph.removed_buffers`` or
    already in ``self.allocated``. ``ExternKernelAlloc`` / ``MultiOutput``
    buffers and buffers with a ``MutationLayout`` are recorded but emit
    nothing. A buffer with an ``AliasedLayout`` first allocates the data
    behind its view and then emits a deferred alias line. Everything else
    gets an ``AllocateLine``.
    """
    assert (
        buffer.get_workspace_size() == 0
    ), "Only support zero workspace size for now!"

    buf_name = buffer.get_name()
    already_handled = (
        buf_name in V.graph.removed_buffers or buf_name in self.allocated
    )
    if already_handled:
        return

    # Mark before recursing so a buffer is never processed twice.
    self.allocated.add(buf_name)

    emits_nothing = (ir.ExternKernelAlloc, ir.MultiOutput)
    if isinstance(buffer, emits_nothing):
        return

    layout = buffer.get_layout()
    if isinstance(layout, ir.MutationLayout):
        return

    if not isinstance(layout, ir.AliasedLayout):
        self.writeline(AllocateLine(self, buffer))
        return

    assert isinstance(
        layout.view, ir.ReinterpretView
    ), f"unexpected {type(layout.view)}: {layout.view}"
    # The aliased storage has to exist before the alias itself is emitted.
    self.codegen_allocation(layout.view.data)
    self.codegen_deferred_allocation(buf_name, layout)
def did_reuse(self, buffer, reused_buffer):
    """Return whether ``buffer`` is recorded as reusing ``reused_buffer``.

    Looks ``buffer``'s name up in ``self.reuses`` and compares the stored
    name with ``reused_buffer``'s name.
    """
    # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
    # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
    reuser_name = buffer.get_name()
    if reuser_name not in self.reuses:
        return False
    return self.reuses[reuser_name] == reused_buffer.get_name()
): + self.writeline( + f"{outer_output} = {inner_output.codegen_reference()}{self.ending}" + ) + + def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs) + finally: + self.pop_codegened_graph() + + def codegen_conditional(self, conditional): + name = conditional.get_name() + outer_inputs = [buf.codegen_reference() for buf in conditional.operands] + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + + self.writeline(f"{name} = [None] * {len(conditional.outputs)}") + self.writeline(f"if {conditional.predicate.codegen_reference()}.item():") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("else:") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + + @staticmethod + def statically_known_int_or_none(x): + try: + val = V.graph._shape_env._maybe_evaluate_static(x) + return int(x) + except Exception: + return None + + @staticmethod + def statically_known_list_of_ints_or_none(lst): + result = [] + for x in lst: + num = WrapperCodeGen.statically_known_int_or_none(x) + if num is None: + return None + result.append(num) + return result + + @staticmethod + def is_statically_known_list_of_ints(lst): + return WrapperCodeGen.statically_known_list_of_ints_or_none(lst) is not None + + @staticmethod + def static_shape_for_buffer_or_none(buffer): + return 
def get_gpu_type() -> NVIDIA_GPU_TYPE:
    """Classify the local GPU by searching the collected GPU info string.

    The info string is checked for ``V100``, ``A100`` and ``H100`` in that
    order; when none matches (or no info is available) Ampere is assumed.
    """
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
    # Ordered: the first marker found in the info string wins.
    marker_to_type = (
        ("V100", NVIDIA_GPU_TYPE.VOLTA),
        ("A100", NVIDIA_GPU_TYPE.AMPERE),
        ("H100", NVIDIA_GPU_TYPE.HOPPER),
    )
    for marker, gpu_type in marker_to_type:
        if marker in gpu_info:
            return gpu_type
    # for other gpu types, assume Ampere
    return NVIDIA_GPU_TYPE.AMPERE
def get_collective_input_size_bytes(node: ir.IRNode) -> int:
    """Return the summed byte size of all inputs of a collective node.

    For each entry of ``node.inputs`` the element count is the product of
    ``layout.size``; a symbolic count is replaced by its size hint. The count
    is multiplied by the byte size of ``layout.dtype``.
    """
    sz_bytes = 0
    for inp in node.inputs:  # type: ignore[attr-defined]
        numel = sympy_product(inp.layout.size)
        if isinstance(numel, sympy.Integer):
            # For ease of testing
            numel = int(numel)
        else:
            # Symbolic element count: fall back to the size hint.
            numel = V.graph.sizevars.size_hint(numel)
        sz_bytes += numel * get_dtype_size(inp.layout.dtype)
    return sz_bytes
array instead of tensor to prevent incompatibility with fake mode +baseLat = [ + # Tree + [ + 6.8, # LL + ], + # Ring + [ + 6.6, # LL + ], +] + +# Latencies in us +# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO) +hwLat = [ + # NVLINK + [ + [0.6], # Tree (LL) + [0.6], # Ring (LL) + ], + # PCI + [ + [1.0], # Tree (LL) + [1.0], # Ring (LL) + ], + # NET + [ + [5.0], # Tree (LL) + [2.7], # Ring (LL) + ], +] + + +# LL128 max BW per channel +llMaxBws = [ + # Volta-N1/Intel-N2/Intel-N4 + [ + 39.0, + 39.0, + 20.4, + ], + # Ampere-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], + # Hopper-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], +] + + +def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: + """ + Returns estimated NCCL collective runtime in nanoseconds (ns). + + The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. + We aim to estimate the runtime as accurately as possible. + + Assumptions: + - only ring algorithm (NCCL_ALGO_RING) is used + - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used + - 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + - collective is one of: allreduce, reducescatter, allgather + """ + tensor_storage_size_bytes = get_collective_input_size_bytes(node) + # Convert bytes to GB + tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024 + + # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus. + # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. 
+ num_gpus_per_node = 8 + group_size = get_collective_group_size(node) + nNodes = math.ceil(group_size / num_gpus_per_node) + nRanks = group_size # this is total # of gpus globally that participate in this collective op + + if nRanks <= 1: + return 0 + + # Assumes ring algorithm + nccl_algo = NCCL_ALGO.RING + nccl_proto = NCCL_PROTO.LL + coll = get_collective_type(node) + + # =============== bandwidth computation =============== + # First compute bandwidth in GB/s; then at the end, convert it to GB/ns + + bwIntra = torch._inductor.config.intra_node_bw + bwInter = torch._inductor.config.inter_node_bw + + compCapIndex = get_gpu_type() + index2 = nNodes - 1 if nNodes <= 2 else 2 + # LL: for single node, we look at GPU type; for multi-node, we look at CPU type + index1 = compCapIndex if nNodes == 1 else 0 + llMaxBw = llMaxBws[index1][index2] + + # NOTE: each step of ring algorithm is synchronized, + # and is bottlenecked by the slowest link which is the inter-node interconnect. + # hence when nNodes >= 2, bw is inter-node bandwidth. + # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc + # have this as `if nNodes <= 2` which seems wrong. Corrected it here. 
+ bw = bwIntra if nNodes == 1 else bwInter + nChannels = 2 # Assume # channels is 2 + busBw = nChannels * bw + + # Various model refinements + busBw = min( + llMaxBw, + busBw + * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0), + ) + + if coll == NCCL_COLL.ALL_REDUCE: + nsteps = 2 * (nRanks - 1) + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nsteps = nRanks - 1 + + # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time) + ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined] + bandwidth = busBw * ratio + # Convert GB/s to GB/ns + bandwidth_GB_per_ns = bandwidth / 1e9 + + # =============== latency computation =============== + intraHw = NCCL_HW.NVLINK + hw = intraHw if nNodes == 1 else NCCL_HW.NET + + if coll == NCCL_COLL.ALL_REDUCE: + if nNodes > 1: + nInterSteps = 2 * nNodes + else: + nInterSteps = 0 + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nInterSteps = nNodes - 1 + + # First compute latency in us; then at the end, convert it to ns + latency = baseLat[nccl_algo][nccl_proto] + intraLat = hwLat[intraHw][nccl_algo][nccl_proto] + interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto] + + # Inter-node rings still have to launch nsteps * net overhead. 
+ netOverhead = 0.0 + if nNodes > 1: + netOverhead = 1.0 # getNetOverhead(comm); + intraLat = max(intraLat, netOverhead) + latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined] + # Convert us to ns + latency_ns = latency * 1e3 + + # =============== final result =============== + transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns + return transport_ns + latency_ns + + +################################################################################################################ +# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +################################################################################################################ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb19bb72ad7143c82e1dadb67e26ec72f393fcb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py @@ -0,0 +1,2159 @@ +""" +CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, +which share the same memory pool. Sharing a memory pool is an extremely +important optimization when chaining multiple CUDA graphs together, as it +prevents you from needing to copy intermediate tensors from one graph to the +next, and reduces overall memory usage by allowing dead memory from the first +pool to be reused in the second. + +The standard graph/make_graph_callables support sharing memory pool, but +with a lot of caveats. CUDA graph trees remove these restrictions: + +* Previously, if you recorded graphs A, B, you had to replay A, B in that + order. 
With CUDA graph trees, after replaying A, you can change your + mind and record/replay a different graph B'; we will support efficient + execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In + other words: we support arbitrary trees of CUDA graph operations, not just + sequences (this is why this feature is called CUDA graph trees.) + +* Previously, if you executed graph A, some non-CUDA graph code, and then + graph B, after executing graph B, it was not safe to retain any references + to intermediates produced by A. With CUDA graph trees, we track if any +outputs of graph A are still live by the time graph B is run, and make + sure graph B doesn't clobber there memory when reusing the CUDA graphs + pool. You'll get a separate recording of B depending on what tensors + stay live or dead. + +CUDA graph trees are flexible enough to be used in Dynamo across graph breaks, +which is their primary use case. + +The ability to switch from replay to record is fairly nontrivial: remember that +when you replay a CUDA graph, you only replay CUDA operations; no CPU side state +is updated. In particular, the CPU-side book-keeping for the allocator is not +reconstructed. However, to record a new child CUDA graph, we must restore this +book-keeping. This is what checkpoint pool state is used for. 
+""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import gc +import itertools +import operator +import sys +import threading +import traceback +import warnings +import weakref +from collections import defaultdict + +from enum import auto, Enum +from typing import ( + Any, + Callable, + cast, + Dict, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch.fx +from torch import Tensor +from torch._dynamo.mutation_guard import GenerationTracker +from torch._dynamo.utils import preserve_rng_state +from torch._inductor.compile_fx import ( + align_inputs_from_check_idxs, + copy_misaligned_inputs, + get_expanded_dims, + get_input_idxs_to_check, + index_expanded_dims, + remove_unaligned_input_idxs, + static_input, +) +from torch.multiprocessing.reductions import StorageWeakRef +from torch.storage import UntypedStorage +from torch.types import _bool +from torch.utils import _pytree as pytree +from torch.utils.weak import TensorWeakRef + +StorageWeakRefPointer = int +StorageDataPtr = int +NBytes = int + +if torch.backends.cuda.is_built(): + from torch._C import ( + _cuda_CUDAAllocator_AllocatorState as AllocatorState, + _set_cached_tensors_enabled as _set_cached_tensors_enabled, + ) +else: + + class AllocatorState: # type: ignore[no-redef] + pass + + def _set_cached_tensors_enabled(enabled: _bool) -> None: + pass + + +log = torch._logging.getArtifactLogger(__name__, "cudagraphs") + + +from . 
import config + + +@dataclasses.dataclass(frozen=True) +class GraphID: + "Unique counter of a cuda graph recording" + id: int + + +@dataclasses.dataclass(frozen=True) +class FunctionID: + "Unique counter of a function wrapped in cudagraphify_impl" + id: int + + +@dataclasses.dataclass(frozen=True) +class WrappedFunction: + """ + Represents a function that you want to record for CUDA graph replay, + with a little more metadata so we can identify if we have an applicable + CUDA graph in our CUDA graph tree for it. + """ + + model: Callable[..., Any] + static_input_idxs: Sequence[int] + id: FunctionID + constants: Tuple[torch.Tensor, ...] + + +def clear_cublass_cache(): + """ + Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for + doing warmup within a CUDAGraph private pool because we do not want persistent allocations from + one one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors + from the previous generation are freed. This frees them the memory pool, but not elsewhere. + A tensor in the cublas workspace would continue to be in use the workspace but would also get allocated + in the next run. The memory would be in use in two places. + + To solve this, we clear cublas caches before and after warming up or recording. If a workspace is required + it will be allocated to the cudagraph private pool and accounted for in the allocator for the duration of the + program. There is no overhead to this on replay since cudagraphs removes allocation overhead. 
+ """ + torch._C._cuda_clearCublasWorkspaces() + + +@contextlib.contextmanager +def clear_cublas_manager(): + "Context manager around clearing cublas caches that will clear on enter and exit" + clear_cublass_cache() + try: + yield + finally: + clear_cublass_cache() + + +@contextlib.contextmanager +def disable_conv_cache_emptying(): + prev = torch._C._cuda_get_conv_benchmark_empty_cache() + torch._C._cudnn_set_conv_benchmark_empty_cache(False) + try: + yield + finally: + torch._C._cudnn_set_conv_benchmark_empty_cache(prev) + + +@contextlib.contextmanager +def enable_history_recording(): + "Turns on history recording in the CUDA Caching Allocator" + enabled = torch._C._cuda_isHistoryEnabled() + try: + if not enabled: + torch.cuda.memory._record_memory_history() + yield + finally: + if not enabled: + torch.cuda.memory._record_memory_history(None) + + +def get_history_recording(): + # TODO - remove, prevents cleanup + if not config.triton.cudagraph_trees_history_recording: + return contextlib.nullcontext() + return enable_history_recording() + + +class TreeManagerContainer: + """ + Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator, + the tree and its corresponding memory pool should be kept alive as long as any outstanding + graph or tensor which is an output of a graph remains alive. + + There is a single tree manager container per device. + + The lifecycle of a tree_manager is: + - Is constructed, no graph, no fns, no tensors + - Tree manager is fetched, resulting in tree manager being allocated + - We generate a bunch of functions, calling add_strong_reference + - These functions die, calling finalize_reference + - When all the functions die, we finalize_tree_manager. 
+ + TODO: in the future, we would like to do the following once storage weak refs land + - We look for all the live storages and add references to THOSE + - We count as storages die + - All the storages are dead, we deallocate the tree manager + """ + + def __init__(self, device_index): + # This class keeps a strong reference to tree_manager, + # but upon all other strong references to the tree_manager will reset it to None. + # We need a strong reference so that we can still access its attributes upon cleanup. + self.tree_manager: Optional[CUDAGraphTreeManager] = None + + # Number of outstanding references to the current tree manager + self.live_cudagraphify_fns = 0 + + self.device_index = device_index + + # Following two objects are only set in the case that Tensor outputs outlive + # the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from + # deallocation. + self.live_storages_count = 0 + self.graph: Optional[torch.cuda.CUDAGraph] = None + + self.lock = threading.Lock() + + def _finalize_tensor(self): + with self.lock: + self.live_storages_count -= 1 + if self.live_storages_count == 0: + self.graph = None + + # manager was used again after existing cleanup, + # we shouldnt set it to None + if self.live_cudagraphify_fns == 0: + self.tree_manager = None + + def finalize_cudagraphify_fn(self): + with self.lock: + self.live_cudagraphify_fns -= 1 + if self.live_cudagraphify_fns == 0: + self._finalize_tree_manager() + + def _finalize_tree_manager(self): + assert self.lock.locked() + self.tree_manager = None + + # TODO - when issue #91395 is landed, we can set a weakref on + # storages and trigger a deallocation when all outputs of the + # cudagraph are dead. 
+ + # live_storages = list( + # tree_manager.live_cudagraph_pool_storages_in_curr_execution() + # ) + + # # Maintain reference to graph to keep tensors alive + # assert len(tree_manager.roots) > 0, "expected at least one use" + # root = next(tree_manager.get_roots()) + # self.graph = root.graph + # seen_storages = set() + # for stor in live_storages: + # if stor in seen_storages: + # continue + # seen_storages.add(stor) + # self.live_storages_count += 1 + # . weakref.finalize(stor, self._finalize_tensor) + + def add_strong_reference(self, fn: Callable[..., Any]): + with self.lock: + self.live_cudagraphify_fns += 1 + + weakref.finalize(fn, self.finalize_cudagraphify_fn) + + def get_tree_manager(self) -> CUDAGraphTreeManager: + with self.lock: + if self.tree_manager is None: + self.tree_manager = CUDAGraphTreeManager(self.device_index) + return self.tree_manager + + +local = threading.local() + +# one tree manager per device +local.tree_manager_containers = {} +local.tree_manager_locks = defaultdict(threading.Lock) + + +# only incremented by user call of mark_step_begin +class MarkStepBox: + mark_step_counter = 0 + + +# We need to register this as an object that will be copied over as TLS when new +# threads are created in autograd +torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers) +torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks) + + +def mark_step_begin(): + "Indicates that a new iteration of inference or training is about to begin." 
+ + # iterate down to distinguish from GenerationTracking counter + MarkStepBox.mark_step_counter -= 1 + + +def reset_cudagraph_trees(): + "Clear all cudagraph trees" + # see shutdown below for why this is necessary + container_dict = get_obj(local, "tree_manager_containers") + locks_dict = get_obj(local, "tree_manager_locks") + for device, lock in locks_dict.items(): + with lock: + container = container_dict.get(device) + if not container or not container.tree_manager: + continue + + container.tree_manager.shutdown() + + _set_cached_tensors_enabled(False) + container_dict.clear() + + MarkStepBox.mark_step_counter = 0 + + +def get_obj(local, attr_name): + if hasattr(local, attr_name): + return getattr(local, attr_name) + else: + assert torch._C._is_key_in_tls(attr_name) + return torch._C._get_obj_in_tls(attr_name) + + +def get_container(device_index: int): + container_dict = get_obj(local, "tree_manager_containers") + lock = get_obj(local, "tree_manager_locks")[device_index] + + with lock: + if device_index not in container_dict: + container_dict[device_index] = TreeManagerContainer(device_index) + + return container_dict[device_index] + + +def get_manager( + device_index: int, create_if_none_exists=True +) -> Optional[CUDAGraphTreeManager]: + if create_if_none_exists: + return get_container(device_index).get_tree_manager() + return get_container(device_index).tree_manager + + +def cudagraphify_impl(model, inputs, static_input_idxs, *args, **kwargs): + fn_cache: Dict[Tuple[int, ...], Callable[..., Any]] = {} + + # Detect int inputs: we need to index on these + int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)] + get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None + + del inputs + + def deferred_cudagraphify(inputs): + int_key = get_ints(inputs) + fn = fn_cache.get(int_key) + if fn is not None: + return fn(inputs) + + if int_key is None: + log.info("recording cudagraph tree for graph without symints") + else: + 
log.info("recording cudagraph tree for symint key %s", int_key) + + # first get indices we need to check to align, then update our static inputs, + # and finally copy + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) + new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) + copy_misaligned_inputs(inputs, check_input_idxs) + + fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs) + fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs) + fn_cache[int_key] = fn + + return out + + return deferred_cudagraphify + + +def cudagraphify( + model, + inputs, + static_input_idxs=(), + *, + device_index: int, + is_backward: bool, + is_inference: bool, + stack_traces: Optional[StackTraces] = None, + constants: Tuple[torch.Tensor, ...] = (), +): + manager = get_container(device_index).get_tree_manager() + assert not (is_backward and is_inference) + mode = ( + CompilationMode.BACKWARD + if is_backward + else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD) + ) + + return manager.add_function( + model, + inputs, + static_input_idxs, + stack_traces, + mode, + constants, + ) + + +class StorageWeakRefWrapper: + """ + Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked. + """ + + __slots__ = ["ref", "_data_ptr", "extra_ref_check"] + + storage_ref: Optional[StorageWeakRef] + + def __init__( + self, + inp: Union[Tensor, UntypedStorage], + extra_ref_check: Optional[Callable[[], None]] = None, + ): + """ + extra_ref_check is an additional check we need to run to check if the + weak ref has expired. in checking storage use count we assume extra_ref_check + will hold an additional reference to the storage. 
+ """ + if isinstance(inp, Tensor): + stor = inp.untyped_storage() + else: + assert isinstance(inp, UntypedStorage) + stor = inp + self.ref = StorageWeakRef(stor) + self._data_ptr = stor.data_ptr() + self.extra_ref_check = extra_ref_check + + @classmethod + def from_weakref_and_data_ptr(cls, cdata, data_ptr, extra_ref_check=None): + instance = cls.__new__(cls) + instance._data_ptr = data_ptr + instance.ref = StorageWeakRef.from_weakref(cdata) + instance.extra_ref_check = extra_ref_check + return instance + + def __call__(self) -> Optional[StorageWeakRefPointer]: + if self.expired(): + return None + + return self.ref.cdata + + def swap_weakref(self, cdata): + self.ref.__del__() + self.ref.cdata = cdata + + def data_ptr(self) -> int: + "NB: returns the data ptr even if the storage has expired" + return self._data_ptr + + def remove_extra_reference(self): + self.extra_ref_check = None + + def expired(self): + if self.extra_ref_check is not None and not self.extra_ref_check(): + return False + + # if extra_ref_check is not None we expect an additional reference + stor_count = torch._C._storage_Use_Count(self.ref.cdata) + return (stor_count - (self.extra_ref_check is not None)) == 0 + + def __repr__(self): + if self.ref is None or self.ref.expired(): + return f"StorageWeakRefWrapper to {self.data_ptr()}; dead" + else: + return f"StorageWeakRefWrapper to {self.data_ptr()}; alive" + + +def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool: + return maybe_deref(weak_ref) is not None + + +def maybe_deref( + weak_ref: Optional[StorageWeakRefWrapper], +) -> Optional[Tuple[StorageWeakRefPointer, int]]: + if weak_ref is None: + return None + r = weak_ref() + if r is None: + return None + # NB: r.data_ptr() does not necessarily equal weak_ref.data_ptr() + return r, weak_ref.data_ptr() + + +@contextlib.contextmanager +def _use_cuda_memory_pool_manager(device, mem_pool, stream): + """ + Context manager to use cuda graph pool for new allocations. 
If you use this manager + all cudagraph tensors in use should be reflected in the allocator or they will be overwritten. + existing_graph should already have been used in a capture, and the mem_pool must already exist, + because this manager will not preserve a reference to the pool which keeps it alive. + """ + torch.cuda.synchronize() + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream), torch.device(device): + torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool) + try: + yield + finally: + torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool) + torch._C._cuda_releasePool(device, mem_pool) + + torch.cuda.current_stream().wait_stream(stream) + + +def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]: + if not isinstance(t, torch.Tensor): + assert t is None + return None + return StorageWeakRefWrapper(t) + + +# A path index of (depth, offset) indices into a graph that is `depth`` number of nodes from the root +# at graph output offset +PathOutputIndex = Tuple[int, int] + +# For each node in the path, for each output, is the output alive +PathLiveness = List[List[bool]] + +StackTraces = List[Optional[str]] + + +class CUDAWarmupNode: + """ + Simplified Wrapper around A CUDA Model that wraps outputs in storage refs and exposes + apis to get the live storages in the current chain of warmup. + + A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have + CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable + memory addresses. + + CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes. + - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of first recording. In the + first instance of warmup, these are not finalized yet. 
+ - All Inputs to the RecordedFunction must be copied over to the cuda graph memory pool, this is unnecessary in warmup. + - CUDAWarmup is only used once and so does not need to optimize as much bookkeeping. It is much simpler. + + NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and + `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility. + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + parent, + cuda_graphs_pool: Tuple[int, int], + existing_cuda_graph: Optional[torch.cuda.CUDAGraph], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + already_warm: bool, + ): + self.wrapped_function = wrapped_function + self.parent = parent + self.cuda_graphs_pool = cuda_graphs_pool + self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = [] + self.tensor_weakrefs: List[Optional[TensorWeakRef]] = [] + self.existing_cuda_graph = existing_cuda_graph + self.has_run = False + self.device_index = device_index + self.stack_traces = stack_traces + self.stream = stream + self.already_warm = already_warm + + def run(self, new_inputs): + assert not self.has_run, "Wrapped function should never be run twice" + + # See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created + # storages in path_live_weakrefs. 
+ existing_path_data_ptrs = { + t.data_ptr() for t in self.path_live_weakrefs() if t() + } + + def get_non_cudagraph_inps(): + non_cudagraph_inps = set() + for t in itertools.chain(new_inputs, self.wrapped_function.constants): + if ( + isinstance(t, torch.Tensor) + and t.untyped_storage().data_ptr() not in existing_path_data_ptrs + ): + non_cudagraph_inps.add(t.untyped_storage().data_ptr()) + return non_cudagraph_inps + + non_cudagraph_inps = get_non_cudagraph_inps() + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, refs) + + with torch.cuda.device( + self.device_index + ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager( + self.device_index, self.cuda_graphs_pool, self.stream + ), get_history_recording(): + out = self.wrapped_function.model(new_inputs) + + assert len(new_inputs) == 0 + + # sdpa returns cpu tensors when not recording cuda graph + def add_ref(o): + return ( + o is not None + and isinstance(o, torch.Tensor) + and o.is_cuda + and o.untyped_storage().data_ptr() not in non_cudagraph_inps + and o.untyped_storage().data_ptr() != 0 + ) + + self.outputs_weakrefs.extend( + [map_to_ref(o) if add_ref(o) else None for o in out] + ) + self.tensor_weakrefs.extend( + [TensorWeakRef(o) if add_ref(o) else None for o in out] + ) + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + out_refs = self.path_live_weakrefs() + new_storages = [ + t for t in out_refs if t.data_ptr() not in non_cudagraph_inps + ] + check_memory_pool(self.device_index, self.cuda_graphs_pool, new_storages) + + return out + + @property + def _path_from_root(self): + nodes = [] + node = self + while node: + nodes.append(node) + node = node.parent + + yield from reversed(nodes) + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + "Returns all live storages weakrefs that created by nodes in this 
path" + for node in self._path_from_root: + for output in node.outputs_weakrefs: + if is_live(output): + yield output + + def all_outputs_are_dead(self): + return not list(self.path_live_weakrefs()) + + +# Aliases for List that say what the indices denote +InputList = List # input indexes +OutputList = List # output indexes +LevelList = List # levels (distance from root of tree) + + +class OutputAliasInfo: + pass + + +class _UnaliasedStorage(OutputAliasInfo): + "Singleton to mark that the graph output constructs a new alias or is None" + pass + + +UnaliasedStorage = _UnaliasedStorage() + + +class AliasesPriorGraphOutput(OutputAliasInfo): + "Marks that the graph output aliases an output of a prior graph" + __slots__ = ["index"] + + index: PathOutputIndex + + def __init__(self, index: PathOutputIndex): + assert isinstance(index, tuple) + self.index = index + + +class AliasesNewOutput(OutputAliasInfo): + "Marks that the graph output aliases an index in the new, returned outputs" + + __slots__ = ["index"] + + index: int + + def __init__(self, index): + assert isinstance(index, int) + self.index = index + + +class CUDAGraphNode: + """ + A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool + and are structured into a tree, where there is a single recording that can precede it (parent) and multiple + subsequent recordings that may follow (children). A node will have no parent if it is the first recording + in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which + would force a dependency. + + On first recording, all of the live tensors in the current CUDA Graph Node path will be + reflected in the corresponding private pool. On subsequent executions, the caching allocator + is unaffected when the graph is replayed. 
+ + In order to support recording a subsequent cuda graph recording after execution of this graph, + we checkpoint the state of the memory pool so that it may later be resumed. + + WrappedFunction should have already been warmed up prior to invocation. + + See [setCheckpointPoolState] for further explanation, as well as + https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + id: GraphID, + parent: Optional[CUDAGraphNode], + inputs: List[Tensor], + cuda_graphs_pool: Tuple[int, int], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + ): + assert isinstance(inputs, (list, tuple)) + + self.wrapped_function = wrapped_function + self.id = id + self.device = device_index + self.stack_traces = stack_traces + self.stream = stream + + # if this is a root parent will be None. use weakref to prevent reference cycle + self._parent = weakref.ref(parent) if parent is not None else None + # reference to the shared memory pool for the entire cuda graphs tree + self.cuda_graphs_pool = cuda_graphs_pool + + # A single wrapped function may be recorded multiple times if memory patterns or + # invariants change from one execution to the next + self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) + + # StorageWeakRef maintains whether the Storage C++ object remains allocated, + # not whether the corresponding memory has been deallocated. In order + # to use them to track memory deallocations we must maintain a single StorageWeakRef + # for all Storages that reference that memory (even if we are constructing Storages + # that do not have a deallocator function). We maintain one single storage_cache + # as we execute any tree path. When we retrieve a storage from the cache we + # check that it is still alive, and we hash based on observed recording data ptr + # and storage cdata. 
+ + # we preserve a single reference to executed outputs that is then referenced + # in children to avoid children having to chase parent pointers in the hot path + # DO NOT reassign output_weakrefs, only call `clear()` + # Path is a series of nodes from root to the current node + self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = [] + self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [ + node.outputs_weakrefs for node in self._path_from_root + ] + self.path_stacktraces: LevelList[StackTraces] = [ + node.stack_traces for node in self._path_from_root + ] + self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = [] + + # tensors which are outputs of previous graphs in the tree + self.cudagraph_managed_idxs: List[int] = [ + idx + for idx, t in enumerate(inputs) + if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t) + ] + + self.static_input_idxs: List[int] = list( + set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs) + ) + + self.static_input_data_ptrs: InputList[Optional[int]] = [ + ( + inputs[i].data_ptr() + if isinstance(inputs[i], torch.Tensor) and i in self.static_input_idxs + else None + ) + for i in range(len(inputs)) + ] + + # When we checkpoint, and free generations, we will be manually freeing the outputs + # of CUDAGraphNodes. We should not be freeing parameters, not do we need to account for + # their liveness (they are static), so we need to compute which outputs are aliases of + # parameters. Some static inputs are saved tensors from the forward that die in the backward. + # Their locations are static but lifetimes are not. We only include the persistent static + # data ptrs below because the non persistent data ptrs may be outputs of this record and + # fresh allocations. 
+ + # precompute expanded dims to avoid computing in the hot path + self.expanded_dims: List[List[int]] = [ + get_expanded_dims(x) + if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs + else [] + for idx, x in enumerate(inputs) + ] + + # For each node in path, which outputs were observed to be live + # before invoking graph recording, and after graph recording + self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = [] + self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = [] + + # List of Tuples of (depth, output_index) that index into node at depth + # number of nodes from root and output_index of outputs. Will index into + # path_weakrefs. + self.expected_dead_indices_before_graph: List[PathOutputIndex] = [] + self.expected_dead_indices_after_graph: List[PathOutputIndex] = [] + + # all live indices after graph recording + self.live_indices_after_graph: List[PathOutputIndex] = [] + + if self.parent is not None: + previous_liveness = self.parent.recorded_liveness_after_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + different_indices = self._get_different_indices( + previous_liveness, curr_liveness + ) + + self.recorded_liveness_before_graph = curr_liveness + self.expected_dead_indices_before_graph = different_indices + + recording_inputs = self._allocate_and_copy_recording_inputs(inputs) + # recording inputs will copy over memory, so we can free non recording inputs + inputs.clear() + del inputs + + # graph used for recording model invocation + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + + # we allocate non-static inputs within the same memory pool as the CUDAGraph + # which we will record the model with. For memory efficiency, it is important + # to reclaim the input memory when the inputs are no longer live. To accomplish this, + # we reconstruct tensors at the correct data pointers of our inputs which are + # non owning and do not prevent deallocation. 
On subsequent executions, input values + # will be copied over to these tensors. + self.reconstructed_inputs: InputList[Union[Tensor, int]] = [ + self._reconstruct_from_tensor_metadata(self._tensor_metadata(x)) + if isinstance(x, torch.Tensor) + else x + for x in recording_inputs + ] + + # DO THE RECORDING!!! + # We record the CUDA graph in the constructor of CUDAGraphNode, which + # gives you what the CPU side compute of the function would do. We + # don't throw the recording outputs away: their memory is + # correctly accounted for in the CUDAGraphs caching allocator. This + # means on the very FIRST run of the CUDA graph node, we can directly + # do more recording, because we have a valid caching allocator state. + # NB: This relies on run() being called immediately after the + # constructor, otherwise this optimization would not be valid. + + # initialized below in _record + + self.checkpointed_caching_state: Optional[AllocatorState] = None + + # Output Storage Alias information, can be: + # - A new, unaliased storage, or the output is None + # - An alias of an output of a prior graph + # - An alias of an output already created in the reconstructed outputs + # This is None if the output in question is an int + self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = [] + + # is the output Storage unaliased in subsequent outputs, of all subsequent paths + # if it is, we cached the output tensor and adjust storage liveness tracking to also + # check if the output tensor does not have an additional python reference. + # If a descendent node discovers it has an alias of a prior output, then the output + # will no longer be cached in the ancestor. 
+ # The large majority of tensors are unaliased, and preserving aliased output tensors would add + # significant additional complexity with marginal gains + # The cached tensor outputs are added on the first execution, and cleared whenever we need + # to do subsequent recording + self.unaliased_in_all_paths: OutputList[bool] = [] + self.cached_tensor_outputs: OutputList[Optional[Tensor]] = [] + + # if an output aliases a static, persistent input then the corresponding Tensor will + # be set here. These are different than cached tensors, because they are tensors that + # are aliases of parameters that are always live. + self.static_output_tensors: OutputList[Optional[Tensor]] = [] + + # Cleared after recording + self.recording_outputs: Optional[ + OutputList[Union[torch.Tensor, int]] + ] = self._record(wrapped_function.model, recording_inputs) + self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = [] + + # As with inputs, we do not want to keep the outputs permanently alive because that would prevent + # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata + # needed to reconstruct instead. 
+ assert self.recording_outputs is not None + for out in self.recording_outputs: + if isinstance(out, torch.Tensor): + self.outputs_metadata.append( + self._tensor_metadata(out, ignore_storage_offset=False) + ) + else: + assert isinstance(out, (int, type(None))), type(out) + self.outputs_metadata.append(out) + + self.graph.replay() + + def _copy_input(self, idx, dst, src): + expanded_dims = self.expanded_dims[idx] + dst = index_expanded_dims(dst, expanded_dims) + src = index_expanded_dims(src, expanded_dims) + # TODO - one jit kernel across multiple inputs + dst.copy_(src) + + def run_first_inputs(self, new_inputs): + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_before_invocation() + + # graph is already invoked in the __init__ + # inputs are copied over in _allocate_recording_inputs and subsequently cleared + assert len(new_inputs) == 0 + outputs = self.recording_outputs + self.recording_outputs = None + return outputs + + def run(self, new_inputs): + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_before_invocation() + + assert len(self.static_input_data_ptrs) == len(new_inputs) + # NB: this ranges over non-static inputs too + for idx, data_ptr in enumerate(self.static_input_data_ptrs): + if idx in self.cudagraph_managed_idxs: + continue + if not isinstance(new_inputs[idx], torch.Tensor): + pass + elif data_ptr is not None: + # static input, e.g., parameter + assert data_ptr == new_inputs[idx].data_ptr() + else: + # non-static input, need to copy it into CUDA graph + dst = self.reconstructed_inputs[idx] + src = new_inputs[idx] + self._copy_input(idx, dst, src) + + new_inputs.clear() + self.run_graph() + + outputs = self.reconstruct_outputs() + self.debug_check_invariants_after_invocation() + + return outputs + + def reconstruct_outputs(self): + "Reconstruct output tensors according to their saved metadata and alias information" + + # Cached tensors will not yet be set on the first execution + # They are 
also cleared in checkpointing, so if we checkpoint this node + # and then execute it again we will need to repopulate cached tensors + if not self.cached_tensor_outputs: + self._initialize_cached_tensors() + + outputs: List[Optional[Union[int, torch.Tensor]]] = [] + + for i, (storage_info, metadata) in enumerate( + zip(self.output_storage_alias, self.outputs_metadata) + ): + if not isinstance(metadata, dict): # tensor metadata + assert isinstance(metadata, (int, type(None))) + outputs.append(metadata) + continue + + cached_t = self.cached_tensor_outputs[i] + if cached_t is not None: + # No need to update weakrefs, already correctly initialized + outputs.append(cached_t) + continue + + static_t = self.static_output_tensors[i] + if static_t is not None: + assert self.outputs_weakrefs[i] is None + outputs.append(static_t) + continue + + storage = self.prepare_alias_info_for_tensor_construction( + storage_info, metadata + ) + + if isinstance(storage, UntypedStorage) or storage is None: + out = self._reconstruct_from_tensor_metadata(metadata, storage) + else: + assert isinstance(storage, int) + out = self._reconstruct_from_tensor_metadata( + metadata, cast(torch.Tensor, outputs[storage]).untyped_storage() + ) + + outputs.append(out) + w = self.outputs_weakrefs[i] + assert w is not None + w.swap_weakref(out.untyped_storage()._weak_ref()) + + return outputs + + def prepare_alias_info_for_tensor_construction( + self, + out_alias_info: Optional[OutputAliasInfo], + metadata: Union[Dict[str, Any], int, None], + ) -> Union[UntypedStorage, None, int]: + if ( + isinstance(metadata, (int, type(None))) + or out_alias_info is UnaliasedStorage + ): + return None + + if isinstance(out_alias_info, AliasesPriorGraphOutput): + depth, existing_output_index = out_alias_info.index + ref = self.path_weakrefs[depth][existing_output_index] + assert ref is not None + return torch.UntypedStorage._new_with_weak_ptr(ref()) + + assert isinstance(out_alias_info, AliasesNewOutput) + return 
out_alias_info.index + + def prepare_storages_for_construction( + self, + ) -> List[Union[UntypedStorage, None, int]]: + output_storages = [] + for output_storage_alias, metadata in zip( + self.output_storage_alias, self.outputs_metadata + ): + output_storages.append( + self.prepare_alias_info_for_tensor_construction( + output_storage_alias, metadata + ) + ) + + return output_storages + + def run_graph(self): + assert self.graph is not None + self.graph.replay() + + def all_outputs_are_dead(self): + "All outputs of the path from this node to its root are dead" + for depth, output_index in self.live_indices_after_graph: + if is_live(self.path_weakrefs[depth][output_index]): + return False + return True + + def _record(self, model, inputs): + "Record the model" + + def static_input_iter(): + for i in self.wrapped_function.static_input_idxs: + if isinstance( + inputs[i], torch.Tensor + ) and not self._is_cuda_graph_recorded_tensor(inputs[i]): + yield inputs[i] + + # see: output_is_alias_of_persistent_static_inputs above + static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = { + inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp) + for inp in itertools.chain( + static_input_iter(), self.wrapped_function.constants + ) + } + + if config.triton.slow_path_cudagraph_asserts: + # need to use parent live weakrefs because live_indices isnt set yet + memory = ( + [] if self.parent is None else list(self.parent.path_live_weakrefs()) + ) + memory += [ + StorageWeakRefWrapper(elem) + for i, elem in enumerate(inputs) + if isinstance(elem, torch.Tensor) + and i not in self.wrapped_function.static_input_idxs + and elem.untyped_storage().data_ptr() != 0 + ] + check_memory_pool(self.device, self.cuda_graphs_pool, memory) + + with preserve_rng_state(), torch.cuda.device( + self.device + ), clear_cublas_manager(), torch.cuda.graph( + self.graph, + stream=self.stream, + pool=self.cuda_graphs_pool, + capture_error_mode="thread_local", + ), 
get_history_recording(): + static_outputs = model(inputs) + + # running model should reclaim memory + assert len(inputs) == 0 + + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs) + + return static_outputs + + def _add_first_outputs( + self, + outputs, + static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper], + ): + "Add the outputs from the first invocation of the node and set up metadata" + + # getting liveness before we have added the outputs to path, so the length + # of the two lists is equal + prev_liveness = self.recorded_liveness_before_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + delta = self._get_different_indices(prev_liveness, curr_liveness) + self.expected_dead_indices_after_graph = delta + + assert len(self.outputs_weakrefs) == 0 + # index from data pointer to index in outputs + output_new_storages_index: Dict[StorageDataPtr, int] = {} + + self.unaliased_in_all_paths = [False for _ in range(len(outputs))] + self.static_output_tensors = [None for _ in range(len(outputs))] + + for i, o in enumerate(outputs): + if o is None or not isinstance(o, torch.Tensor): + self.output_storage_alias.append(UnaliasedStorage) + continue + + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. 
Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" + ), + ), + + ref = static_input_persistent_storage_ptrs.get( + o.untyped_storage().data_ptr(), None + ) + # also treat empty storages as static outputs because we do not need to manage their lifetime + # and they should not participate in checkpointing + is_empty_storage = o.untyped_storage().data_ptr() == 0 + if (ref and ref() is not None) or is_empty_storage: + self.output_storage_alias.append(None) + self.static_output_tensors[i] = o + continue + + path_ref = self._is_alias_of_live_recorded_tensor(o) + if path_ref is not None: + self._mark_prior_graph_output_as_aliased(path_ref) + self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref)) + continue + + if o.untyped_storage().data_ptr() in output_new_storages_index: + index = output_new_storages_index[o.untyped_storage().data_ptr()] + self.unaliased_in_all_paths[index] = False + self.output_storage_alias.append(AliasesNewOutput(index)) + continue + + output_new_storages_index[o.untyped_storage().data_ptr()] = i + self.output_storage_alias.append(UnaliasedStorage) + self.unaliased_in_all_paths[i] = True + + if self.stack_traces is None: + self.stack_traces = [None for _ in range(len(outputs))] + else: + assert len(self.stack_traces) == len( + outputs + ), "Wrong number of stack traces passed in" + + assert not self.outputs_weakrefs + for out, static_output_tensor in zip(outputs, self.static_output_tensors): + if not isinstance(out, torch.Tensor) or static_output_tensor is not None: + self.outputs_weakrefs.append(None) + self.tensor_weakrefs.append(None) + else: + self.outputs_weakrefs.append(StorageWeakRefWrapper(out)) + self.tensor_weakrefs.append(TensorWeakRef(out)) + + self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs) + self.checkpointed_caching_state = torch._C._cuda_getCheckpointState( + self.device, self.cuda_graphs_pool + ) + + # now, get liveness with outputs added + for depth in 
range(len(self.path_weakrefs)): + for output_index in range(len(self.path_weakrefs[depth])): + if is_live(self.path_weakrefs[depth][output_index]): + self.live_indices_after_graph.append((depth, output_index)) + + self.debug_check_invariants_after_invocation() + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs()) + ) + + def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex): + "Remove a graph output from the unaliased, cached tensors in an ancestor node" + depth, output_index = index + node = list(self._path_from_root)[depth] + node.unaliased_in_all_paths[output_index] = False + x = self.path_weakrefs[depth][output_index] + assert x is not None + x.remove_extra_reference() + + def _initialize_cached_tensors(self): + # we should not be clearing output_weakrefs, and they should be set in the first + # record run + assert len(self.outputs_weakrefs) == len(self.outputs_metadata) + + for i, (storage_info, metadata, make_cached) in enumerate( + zip( + self.output_storage_alias, + self.outputs_metadata, + self.unaliased_in_all_paths, + ) + ): + if not make_cached: + self.cached_tensor_outputs.append(None) + continue + + assert storage_info is UnaliasedStorage + assert isinstance(metadata, dict) + s = self.create_storage(metadata) + out = self._reconstruct_from_tensor_metadata(metadata, storage=s) + + # XXX: let autograd know that there will be an additional reference to the tensor + # that can be ignored when deciding whether to do gradient buffer inplacing. + # Otherwise, inplacing could differ between tracing and subsequent execution. + # For some models we tested this led to inputs no longer being in cudagraph pools, + # leading to spurious re-recordings. + # It also tells AMP cache that even though the tensor impls cannot be cached + # in dtype conversions. 
+ + torch._C._add_cached_tensor(out) + + self_ref = weakref.ref(self) + + # one reference in our array, and calling sys.getrefcount bumps the refcount by one + def check_refcount(i): + self_loc = self_ref() + if self_loc is None: + return False + return self_loc.get_output_refcount(i) == 2 + + check = functools.partial(check_refcount, i=i) + + self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check) + self.cached_tensor_outputs.append(out) + + def get_output_refcount(self, index): + return sys.getrefcount(self.cached_tensor_outputs[index]) + + @property + def parent(self): + "unwraps the weakref to _parent" + return self._parent() if self._parent is not None else None + + @property + def _path_to_root(self): + "Returns all nodes in the path starting at self and ending at root" + node = self + while node: + yield node + node = node.parent + + @property + def _path_from_root(self): + "Returns all nodes in the path starting at the root and ending at self" + nodes = reversed(list(self._path_to_root)) + yield from nodes + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor): + "Is this tensor an output of a node in this path" + for output_refs in self.path_weakrefs: + for storage_weak_ref in output_refs: + if storage_weak_ref is None: + continue + # don't need to check liveness of storage since the cuda graph managed + # memory is never released. 
+ data_ptr = storage_weak_ref.data_ptr() + if t.untyped_storage().data_ptr() == data_ptr: + return True + + return False + + def _is_alias_of_live_recorded_tensor( + self, t: torch.Tensor + ) -> Optional[PathOutputIndex]: + for depth, output_refs in enumerate(self.path_weakrefs): + for output_index, storage_ref in enumerate(output_refs): + if (storage_and_ptr := maybe_deref(storage_ref)) is not None: + storage, ptr = storage_and_ptr + if ptr == t.untyped_storage().data_ptr(): + return (depth, output_index) + + return None + + @staticmethod + def _check_liveness( + indices: List[PathOutputIndex], + output_refs: List[List[Optional[StorageWeakRefWrapper]]], + ): + "Check that all of the indices specified are dead references" + for depth, output_index in indices: + w = output_refs[depth][output_index] + assert w is not None + if w() is not None: + return False + return True + + def add_child(self, function_id: FunctionID, node: CUDAGraphNode): + "Adds node as a a child of self" + self.children[function_id].append(node) + + @staticmethod + def _get_different_indices( + prev: List[List[bool]], curr: List[List[bool]] + ) -> List[PathOutputIndex]: + "Find indices where the two lists differ." 
+ dead_indices = [] + assert len(prev) <= len(curr) + for i, (outputs1, outputs2) in enumerate(zip(prev, curr)): + assert len(outputs1) == len(outputs2) + for j, (output1, output2) in enumerate(zip(outputs1, outputs2)): + if output1 != output2: + dead_indices.append((i, j)) + + return dead_indices + + @staticmethod + def _get_liveness( + weakrefs: List[List[Optional[StorageWeakRefWrapper]]], + ) -> List[List[bool]]: + "Maps weakrefs to true if the reference is alive and false otherwise" + if len(weakrefs) == 0: + return [] + + return [pytree.tree_map(is_live, outputs) for outputs in weakrefs] + + def debug_assert_invariants( + self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex] + ): + if not config.triton.fast_path_cudagraph_asserts: + return + + for i, node in enumerate(self._path_from_root): + assert self.path_weakrefs[i] is node.outputs_weakrefs + + nodes = list(self._path_from_root) + + live_blocks = get_block_addrs(self.cuda_graphs_pool) + + live_storage_data_ptrs = set() + live_storage_weak_ptrs = set() + + for depth, outputs_liveness in enumerate(expected_liveness): + for output_idx, output_liveness in enumerate(outputs_liveness): + # tensor can die early, but it can't be alive when it should be dead + w = self.path_weakrefs[depth][output_idx] + if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None: + assert output_liveness + stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr + assert (stor_data_ptr in live_storage_data_ptrs) == ( + stor_weak_ptr in live_storage_weak_ptrs + ) + live_storage_data_ptrs.add(stor_data_ptr) + live_storage_weak_ptrs.add(stor_weak_ptr) + + is_persistent_alias = ( + nodes[depth].static_output_tensors[output_idx] is not None + ) + + if is_persistent_alias: + assert stor_data_ptr not in live_blocks + + for depth, output_index in newly_dead: + assert not is_live(self.path_weakrefs[depth][output_index]) + + def debug_check_invariants_before_invocation(self): + self.debug_assert_invariants( + 
self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph + ) + + def debug_check_invariants_after_invocation(self): + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph + ) + + def data_ptrs_dead_since_invocation(self) -> List[int]: + """ + Since this node was invoked, return data ptrs of all tensor outputs that have died + in the current executing tree path. + """ + curr_liveness = self._get_liveness(self.path_weakrefs) + _get_different_indices = self._get_different_indices( + self.recorded_liveness_after_graph, curr_liveness + ) + + path = list(self._path_from_root) + ptrs_to_deallocate = [] + for depth, output_index in _get_different_indices: + ptrs_to_deallocate.append( + path[depth].outputs_metadata[output_index]["data_ptr"] + ) + + return ptrs_to_deallocate + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + for i, j in self.live_indices_after_graph: + out = self.path_weakrefs[i][j] + if out is not None and is_live(out): + yield out + + def remove_node_cached_tensors(self): + for t in self.cached_tensor_outputs: + if t is not None: + torch._C._remove_cached_tensor(t) + self.cached_tensor_outputs.clear() + + for i, unaliased in enumerate(self.unaliased_in_all_paths): + if unaliased: + n = self.outputs_weakrefs[i] + assert n is not None + n.remove_extra_reference() + + def remove_path_cached_tensors(self): + for node in self._path_from_root: + node.remove_node_cached_tensors() + + def clear_path_state(self): + "Clear the path state in this current executing node" + # this doesnt actually do anything right now, leaving it as placeholder + pass + + @staticmethod + def _tensor_metadata(x, ignore_storage_offset=True): + assert isinstance(x, torch.Tensor) + # We ignore the storage offset for inputs, but not for outputs + # TODO: - should we make the storage resizable ? 
+ return { + "nbytes": x.untyped_storage().nbytes(), + "data_ptr": x.untyped_storage().data_ptr(), + "size": x.shape, + "stride": x.stride(), + "dtype": x.dtype, + "device": x.device, + "storage_offset": x.storage_offset() if not ignore_storage_offset else 0, + } + + def _reconstruct_from_tensor_metadata( + self, metadata: Dict[str, Any], storage=None + ) -> Tensor: + s = self.create_storage(metadata) if storage is None else storage + return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) + + def create_storage(self, metadata): + return torch._C._construct_storage_from_data_pointer( + metadata["data_ptr"], metadata["device"], metadata["nbytes"] + ) + + def _allocate_and_copy_recording_inputs( + self, inputs + ) -> List[Union[torch.Tensor, int]]: + """ + Allocate inputs for non static, non cudagraph managraphed managed tensors in the memory pool + and copy over the tensor values. + """ + + torch.cuda.synchronize() + self.stream.wait_stream(torch.cuda.current_stream()) + recording_inputs: List[Union[Tensor, int]] = [] + + with warnings.catch_warnings(record=True), torch.cuda.device( + self.device + ), _use_cuda_memory_pool_manager( + self.device, + mem_pool=self.cuda_graphs_pool, + stream=self.stream, + ): + for i, inp in enumerate(inputs): + if not isinstance(inp, torch.Tensor): + assert isinstance(inp, int) + recording_inputs.append(inp) + elif i not in self.static_input_idxs: + # static_input does an allocation! + recording_inputs.append(static_input(inp)) + # copy over and clear non recording input + self._copy_input(i, recording_inputs[-1], inp) + inputs[i] = None + del inp + else: + recording_inputs.append(inp) + + return recording_inputs + + def check_invariants(self, inputs: List[Tensor]) -> bool: + """ + Checks if this node can be run. The same pattern of tensor liveness and tensors + managed in the cudagraph private pool must remain stable. 
+ """ + + # previously managed data pointers remain stable + for idx in self.cudagraph_managed_idxs: + if inputs[idx].data_ptr() != self.static_input_data_ptrs[idx]: + return False + + if not self._check_liveness( + self.expected_dead_indices_before_graph, self.path_weakrefs + ): + return False + + # the cudagraph managed tensors which died upon recording must also die upon + # this invocation. it is too late to check after we've replayed the graph, + # because we would have already written over their memory. + for idx in self.cudagraph_managed_idxs: + inputs[idx] = None # type: ignore[call-overload] + + torch._check( + self._check_liveness( + self.expected_dead_indices_after_graph, self.path_weakrefs + ), + lambda: "TODO: graph recording observed an input tensor deallocate during graph " + " recording that did not occur during replay. Please file an issue.", + ) + return True + + def num_descendants(self) -> int: + "Total number of descendents of this node" + num_desc = 0 + for children in self.children.values(): + for child in children: + num_desc += 1 + num_desc += child.num_descendants() + return num_desc + + +def get_cudagraph_segments(pool_id): + segments = torch.cuda.memory_snapshot() + return [segment for segment in segments if segment["segment_pool_id"] == pool_id] + + +def get_block_addrs(pool_id, live_only=True): + blocks = [] + + for segment in get_cudagraph_segments(pool_id): + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated" or not live_only: + blocks.append(addr) + + addr += block["size"] + + return blocks + + +def format_tb(frames): + formatted_traceback = [] + + for entry in frames: + formatted_traceback.append( + traceback.FrameSummary(entry["filename"], entry["line"], entry["name"]) + ) + + return "".join(traceback.format_list(formatted_traceback)) + + +def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]): + assert all( + isinstance(elem, 
StorageWeakRefWrapper) for elem in live_storages_ptrs + ) # noqa: C419 + unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} + + # check if there is a divergence first, then do the expensive snapshot call after + # we know it will error + if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages): + return + + # at this point we are past the fast-path. we have seen rare cases where a dead tensor is dead, + # but hasn't been gc'd yet, and gives false positive for allocated_not_in_live_storages + gc.collect() + + segments = get_cudagraph_segments(pool_id) + + allocated_not_in_live_storages = {} + + for segment in segments: + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated": + if addr not in unique_storages: + allocated_not_in_live_storages[addr] = block + else: + unique_storages.remove(addr) + + addr += block["size"] + + torch._check( + len(unique_storages) == 0, + lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}", + ) + + if allocated_not_in_live_storages != 0: + formatted = [] + for dp, block in allocated_not_in_live_storages.items(): + trace = format_tb(block.get("frames", [])) + formatted.append(f"Data Pointer: {dp}, history: \n{trace}") + formatted_s = "\n".join(formatted) + msg = ( + f"These live storage data ptrs are in the cudagraph pool but not " + f"accounted for as an output of cudagraph trees: \n\n{formatted_s}" + ) + raise RuntimeError(msg) + + +class ExecutionState(Enum): + """ + Represents the state of the CUDAGraph Tree. Will be None if there is no live current memory allocated + in the cuda graph pool. Otherwise will reflect the state of the most recently executed node. 
+ """ + + NONE = auto() + WARMUP = auto() + RECORDING = auto() + EXECUTION = auto() + + +class CompilationMode(Enum): + FORWARD = auto() + BACKWARD = auto() + INFERENCE = auto() + + +class CUDAGraphTreeManager: + """ + Groups individual recordings or executions of cuda graphs into a tree of recordings, + and checks required invariants, and manages warmups of graphs. + + When graphs are recorded in the same tree, it enforces subsequent execution + to follow the same order and have the same output tensor livespans. To remove + unnecessary coupling of cuda graphs (and additional imposed invariants), + the tree manager will end a currently recording tree whenever it is valid - when + the memory pool no longer has any live allocations. + + We ignore outputs from a previous generation that correspond to prior model outputs. + Currently this is hardcoded `GenerationTracker.generation` tracked in torch dynamo. + # TODO: make generation increment configurable, warn on overwrite. + + We run graph warmups in the cudagraph memory pool and return the result on the first invocation + of a function. For many models it is important to reclaim activations as you run the backward. + If we were to warm up the model and keep an extra copy of the inputs around to subsequently + use for recording, we would incur a memory penalty. Additionally, if we are part way through training + your model and need to recompile, memory will be allocated to the cuda graph pool, so we run this + warmup run in the cuda graph memory pool. As for recording, warm up needs the state of live tensors + to be accurately reflected so we checkpoint the allocator state if we need to warm up following graph + replay. + """ + + def __init__(self, device_index: int): + # roots are functions which have no dependencies on an other node. 
I.e., + # when they are first invoked, none of their inputs are outputs + # of another node, nor are there any live outputs of another node whose + # liveness would create a dependency. + self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) + + # mapping from function id to wrapped function + self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {} + + self.ids_to_stack_traces: Dict[FunctionID, StackTraces] = {} + + self.warmed_up_functions: Set[FunctionID] = set() + # if we fail to increment generation, and are stuck warming up, + # only warn on each function once + self.warned_functions: Set[FunctionID] = set() + torch._C._set_cached_tensors_enabled(True) + + # NB: cuda caching allocator will remember the stream a segment is allocated to + # and only allocate that segment to the same stream. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be reused; separate recordings would use the same memory pool, but not + # the same memory. + + with torch.cuda.device(device_index): + torch.cuda.synchronize() + self.stream = torch.cuda.Stream() + self.stream.wait_stream(torch.cuda.current_stream()) + + # Keeps Memory Pool Alive + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle() + + with warnings.catch_warnings(record=True), torch.cuda.graph( + self.graph, + pool=self.cuda_graphs_thread_pool, + stream=self.stream, + capture_error_mode="thread_local", + ): + pass + + self.graph_counter = itertools.count(0) + self.func_counter = itertools.count(0) + + # whether the current node is in a state of warmup, recording, or execution. If + # there is no current node the state will be ExecutionState.NONE. + self.path_state = ExecutionState.NONE + self.device_index = device_index + + # the most recently invoked cudagraph wrapping of a function.
Will be None + # when there is no output from a previous recording or execution whose memory + # we need to respect in the cuda caching allocation. If you incremented generation, + # this will also be None, as we ignore those allocations. + self.current_node: Optional[CUDAGraphNode] = None + + # current generation of cudagraph invocations. when torch.compile is run + # we increment the current generation. we are willing to ignore live outputs + # of a previous generation in checking liveness. + self.current_gen: int = -1 + + # number of instances we are in execution and failed to match to an + # existing child + self.debug_fail_counter = 0 + # number of instances we had to checkpoint the function + self.debug_checkpointing_counter = 0 + + self.id_to_mode: Dict[FunctionID, CompilationMode] = {} + + # Note: [Backward Generation Handling] + # We generally perform a sequence of forward executions followed by backward executions. + # If multiple torch.compile wrapped forwards are executed with their backwards pending, + # we should not disregard the outputs from a prior torch.compile since the entire training + # loop hasn't completed. Occasionally, a backward pass corresponding to a forward pass may + # not be executed, so we cannot wait for the backwards of all pending forwards + # to have been invoked. Instead we wait for a single backward + # invocation. Triggering a backward pass typically doesn't lead to another torch.compile + # invocation, making it less likely for the generation to increase between multiple + # backward calls. The following use case is covered by this approach: + # mod1 = torch.compile(...) + # mod2 = torch.compile(...)
+ # mod2(mod1(x)).sum().backward() + + self.running_forwards_with_pending_backwards = False + + def run(self, new_inputs: List[Tensor], function_id: FunctionID): + assert self.graph is not None, "Running CUDAGraph after shutdown" + out = self._run(new_inputs, function_id) + + # The forwards are only pending following invocation, not before + mode = self.id_to_mode[function_id] + if mode == CompilationMode.FORWARD: + self.running_forwards_with_pending_backwards = True + elif mode == CompilationMode.BACKWARD: + self.running_forwards_with_pending_backwards = False + + return out + + def set_to_running_backward(self): + self.running_forwards_with_pending_backwards = False + + def _run(self, new_inputs: List[Tensor], function_id: FunctionID): + # we will try to end the current execution lazily, since + # we dont want to do unnecessary checking of the existing outputs + # on the hot path, but both recording and warmup only happen once + # so we check up front + if self.in_recording: + self.try_end_curr_recording(function_id) + + if self.in_warmup: + self.try_end_curr_warmup(function_id) + + # warming up a function and subsequentally recording may use different memory addresses + # because both depend on the state of the caching allocator. if we warm up graph A, + # then warm up graph B and make more allocations, the subsequent recording of A will not + # necessarily use the same addresses as in the warm up. Thus any warm up of a node can only + # be followed by warm up runs. + if ( + not ( + function_id in self.warmed_up_functions + or config.triton.skip_cudagraph_warmup + ) + ) or self.in_warmup: + # If we are in the middle of executing cuda graphs, then we need to checkpoint memory state. 
+ # Both Recording and Warmup will be reflected in the allocator and dont need changes + if self.path_state == ExecutionState.EXECUTION: + self.apply_checkpoint_execution_state_in_allocator() + + return self.run_eager(new_inputs, function_id) + + child_nodes = ( + self.roots if self.current_node is None else self.current_node.children + ) + + if not self.in_recording: + for child in child_nodes[function_id]: + # here we are checking memory consistency between recording and execution, + # as well as things like stability of tensor locations, etc + # and other + if child.check_invariants(new_inputs): + return self.execute_node(child, new_inputs) + + # now that we know the new function can't be run as a child of the + # current node, if it is a root, try to end the current execution. + # as noted above, we want to do this lazily to avoid having to + # check all existing outputs + if self.current_node is not None and function_id in self.roots: + self.try_end_curr_execution() + + # run again to hit the root matching case which must succeed + if self.current_node is None: + return self.run(new_inputs, function_id) + + # at this point, we necessarily will do a new recording + self.debug_fail_counter += 1 + + self.try_end_curr_execution() + if self.current_node is not None: + self.apply_checkpoint_execution_state_in_allocator() + + # now, we are in a recording state ! + return self.record_function(new_inputs, function_id) + + def shutdown(self): + """ + Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn + might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown + to avoid a reference cycle. 
+ """ + nodes = [] + for roots in self.roots.values(): + nodes.extend(roots) + + while nodes: + node = nodes.pop() + for children in node.children.values(): + nodes.extend(children) + node.remove_node_cached_tensors() + node.graph = None + + self.graph = None + self.roots = None # type: ignore[assignment] + self.current_node = None + + def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]: + graph_id = self.new_graph_id() + log.debug( + "Recording function %d of graph recording id %d", + function_id.id, + graph_id.id, + ) + torch.cuda.synchronize() + node = CUDAGraphNode( + self.ids_to_funcs[function_id], + graph_id, + self.current_node, + new_inputs, + self.cuda_graphs_thread_pool, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + ) + if self.current_node is None: + self.roots[function_id].append(node) + else: + self.current_node.add_child(function_id, node) + self.current_node = node + self.path_state = ExecutionState.RECORDING + self.update_generation() + torch.cuda.synchronize() + return node.run_first_inputs(new_inputs) + + def execute_node(self, node: CUDAGraphNode, new_inputs) -> List[Optional[Tensor]]: + self.current_node = node + self.path_state = ExecutionState.EXECUTION + self.update_generation() + return node.run(new_inputs) + + def run_eager(self, new_inputs, function_id: FunctionID): + # this is only stored on current node, because when we start a new path, + # we will deallocate it + already_warm = function_id in self.warmed_up_functions + if not already_warm: + log.debug("Running warmup of function %d", function_id.id) + else: + log.debug( + "Running eager of function %d because ancestor needed to warm up", + function_id.id, + ) + self.warmed_up_functions.add(function_id) + node = CUDAWarmupNode( + self.ids_to_funcs[function_id], + self.current_node, + self.cuda_graphs_thread_pool, + self.graph, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + already_warm, + ) + 
self.current_node = node + self.path_state = ExecutionState.WARMUP + self.update_generation() + return node.run(new_inputs) + + def new_graph_id(self) -> GraphID: + return GraphID(next(self.graph_counter)) + + def new_func_id(self) -> FunctionID: + return FunctionID(next(self.func_counter)) + + def add_function( + self, + model, + inputs, + static_input_idxs, + stack_traces, + mode, + constants, + ) -> Tuple[Callable[..., Any], List[Optional[Tensor]]]: + id = self.new_func_id() + self.ids_to_stack_traces[id] = stack_traces + self.ids_to_funcs[id] = WrappedFunction( + model, + static_input_idxs, + id, + tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda), + ) + self.id_to_mode[id] = mode + fn = functools.partial(self.run, function_id=id) + + # container needs to set clean up when fn dies + get_container(self.device_index).add_strong_reference(fn) + return fn, fn(inputs) + + @property + def in_recording(self): + return self.path_state == ExecutionState.RECORDING + + @property + def in_warmup(self): + return self.path_state == ExecutionState.WARMUP + + def get_roots(self) -> Iterator[CUDAGraphNode]: + for nodes in self.roots.values(): + yield from nodes + + @property + def current_node(self): + return self._current_node + + @current_node.setter + def current_node(self, value): + self._current_node = value + if value is None: + self.path_state = ExecutionState.NONE + + def update_generation(self): + self.current_gen = self.get_curr_generation() + + @staticmethod + def get_curr_generation() -> int: + if MarkStepBox.mark_step_counter != 0: + return MarkStepBox.mark_step_counter + + return GenerationTracker.generation + + @staticmethod + def user_invoked_mark_step(): + return MarkStepBox.mark_step_counter != 0 + + def can_start_new_generation(self) -> bool: + if not self.in_new_torch_compile_invocation(): + return False + + if self.user_invoked_mark_step(): + return True + + return not self.running_forwards_with_pending_backwards + + def 
in_new_torch_compile_invocation(self): + return self.current_gen != self.get_curr_generation() + + def try_end_curr_recording(self, function_id: FunctionID) -> None: + """ + Check if the current recording can be terminated, either because all outputs of the + previously recorded node are dead or because it was executed in a different + generation. Will set current_node to None and in_recording to False if successful. + """ + assert self.in_recording + assert self.current_node is not None + + # multiple invocations, allow overwriting the previous generation + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def try_end_curr_execution(self) -> None: + """ + Check if the current executing node can be terminated, either because all outputs of the + previously executed node are dead or because it was executed in a different generation. + Will set current_node to None if successful. 
+ """ + + assert not self.in_recording + if self.current_node is None: + return + + if self.can_start_new_generation(): + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + + def try_end_curr_warmup(self, function_id: FunctionID): + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.current_node = None + return + + if self.current_node.all_outputs_are_dead(): + self.current_node = None + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def check_warn_on_unable_to_start_executing(self, function_id: FunctionID): + "Warn if we in a potential loop where we are unable to hit fast path" + if ( + function_id in self.warned_functions + or not self.in_new_torch_compile_invocation() + ): + return + + existing_nodes = [ + node + for node in self.current_node._path_from_root + if node.wrapped_function.id == function_id + ] + + if len(existing_nodes) <= 1: + return + + # repeated same pattern + parents = { + n.parent.wrapped_function.id + for n in itertools.chain(existing_nodes, (self.current_node,)) + if n.parent is not None + } + if len(parents) == len(existing_nodes): + return + + self.warned_functions.add(function_id) + warnings.warn( + "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. " + "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() " + "before each model invocation" + ) + + def dealloc_current_path_weakrefs(self): + # TODO: we could also allow the these weak refs to continue to be allocated, + # but that adds some complications. 
+ for node in self.current_node._path_from_root: + assert len(node.tensor_weakrefs) == len(node.stack_traces) + for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces): + ten = None if t is None else t() + if ten is None: + continue + + stack_trace = ( + stack_trace.strip() + if stack_trace + else "[Could not find stack trace]" + ) + msg = ( + "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. " + f"Stack trace: {stack_trace}. " + "To prevent overwriting, clone the tensor outside of torch.compile() " + "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation." + ) + torch._C._set_storage_access_error_msg(ten, msg) + + deleted = set() + for storage_ref in self.current_node.path_live_weakrefs(): + if storage_ref() and storage_ref.data_ptr() not in deleted: + deleted.add(storage_ref.data_ptr()) + torch._C._free_And_Remove_DeleterFn(storage_ref()) + + def clear_current_path_state_and_set_to_none(self): + self.current_node.clear_path_state() + self.current_node = None + + def apply_checkpoint_execution_state_in_allocator(self): + """ + Checkpoint the current execution state in the caching allocator so that + additional cudagraph recordings can be made respecting existent live storages. + """ + self.debug_checkpointing_counter += 1 + log.debug( + "Checkpointing cuda caching allocator state. 
Number of checkpoints %d", + self.debug_checkpointing_counter, + ) + + state = self.current_node.checkpointed_caching_state + device = self.current_node.device + assert state is not None and device is not None + + # currently we deallocate on instead of allowing stale recordings + stale_storages: List[int] = [] + + # remove cached tensors, otherwise they would prevent memory from being + # reclaimed in subsequent recordings + self.current_node.remove_path_cached_tensors() + live_storages_wrappers = list(self.current_node.path_live_weakrefs()) + + live_storages_weak_refs = [t() for t in live_storages_wrappers] + ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation() + torch._C._cuda_setCheckpointPoolState( + device, state, stale_storages, live_storages_weak_refs + ) + + # NB: deduplicate aliased outputs + for ptr in set(ptrs_to_deallocate): + torch._C._cuda_cudaCachingAllocator_raw_delete(ptr) + + # Now the live blocks should be exactly equal to the live storages in private pool + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers + ) + for wrapper in live_storages_wrappers: + assert wrapper() + assert torch._C._has_Standard_Deleter(wrapper()) + assert wrapper.data_ptr() not in ptrs_to_deallocate + + def live_cudagraph_pool_storages_in_curr_execution( + self, + ) -> List[StorageWeakRefPointer]: + if self.current_node is None: + return [] + # explicitly ignoring previous recorded outputs from past path + return [t() for t in self.current_node.path_live_weakrefs()] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..2b558f4350a79235b5e28f91bee24655822a7933 --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py @@ -0,0 +1,28 @@ +import contextlib +from typing import Callable, List, TYPE_CHECKING + +if TYPE_CHECKING: + import torch + +# Executed in the order they're registered +INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = [] + + +@contextlib.contextmanager +def intermediate_hook(fn): + INTERMEDIATE_HOOKS.append(fn) + try: + yield + finally: + INTERMEDIATE_HOOKS.pop() + + +def run_intermediate_hooks(name, val): + global INTERMEDIATE_HOOKS + hooks = INTERMEDIATE_HOOKS + INTERMEDIATE_HOOKS = [] + try: + for hook in hooks: + hook(name, val) + finally: + INTERMEDIATE_HOOKS = hooks diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..9a25edfa7d960ecd2df3c54cac468ac9e1ed3a7f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py @@ -0,0 +1,655 @@ +import itertools +from typing import Any, Callable, Generic, Literal, Optional, Tuple, TypeVar, Union +from unittest.mock import patch + +import sympy +from typing_extensions import Protocol + +import torch +import torch.utils._pytree as pytree +from torch.fx.graph import inplace_methods, magic_methods +from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str + +T = TypeVar("T") +StoreMode = Optional[Literal["atomic_add"]] +ReductionType = Literal[ + "argmax", + "argmin", + "welford_reduce", + "welford_combine", + "any", + "max", + "min", + "prod", + "sum", + "xor_sum", +] + + +def _arg_str(a) -> str: + if isinstance(a, sympy.Expr): + return sympy_str(a) + return str(a) + + +# NB: This is not done as a parent class, because our ops handlers +# implementations make heavy use of __getattr__ magic, and pre-existing +# stubs for 
methods would interfere with this mechanism. +# +# TODO: A superclass that does desugaring for operations like +# reciprocal/square might be useful. +class OpsHandler(Protocol[T]): + """ + Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``, + as well as the contract for op handlers. The type T signifies the domain + of the abstract analysis AKA what all of the functions return / take as arguments + anywhere compute occurs. + + While these operators are typically dtype polymorphic (e.g., you can use mul + on both integers and floats), they do NOT do promotion and usually return the + same dtype as the input. You are expected to have handled type promotion + during ATen decompositions. Most operators correspond exactly to pointwise + operations as defined by torch, so when in doubt about semantics, check the + corresponding torch documentation. These are all scalar operations (so they + are defined to operate on a single element at a time.) + + For convenience, many operators take a src_dtype which indicates what the dtype + of the input argument is. Although in principle this can be derived by an + analysis, providing this for ops where it is useful helps avoid having to repeatedly + recompute dtype in code generation. + + Note that this often describes a class of static methods, for stateless + ops handlers. + + Handlers are often defined using ``__getattr__`` metaprogramming, which means + that you cannot declare that a type implements a protocol by inheriting from + it (as the type stubs count as attribute declarations and impede the getattr + magic method from being called). Instead, define a function that casts an + argument of your type to the protocol, which is sufficient to induce mypy to + test that the protocol is implemented correctly. Search for ``_typecheck_`` + in this file to see some examples. 
If you see an obscure error where a + class doesn't implement a Protocol, but mypy doesn't say why, check to see + that ``__getattr__`` is typed correctly (typically, it is not possible to + type ``__getattr__`` without typing it as ``Callable[..., Any]``) + """ + + def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: + """Produces a scalar constant of type dtype.""" + ... + + def load_seed(self, name: str, offset: T): + """Computes inductor_prims.lookup_seed.""" + ... + + def rand(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" + ... + + def randn(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" + ... + + def randint64(self, seed: T, offset: T, low: T, high: T) -> T: + """Computes inductor_prims.randint. offset has dtype int32.""" + ... + + def masked(self, mask: T, body: Callable[[], T], other: T) -> T: + """ + Computes body, but only perform loads/stores if the boolean mask + evaluates to true. For example, you would use this if you needed to + perform an indirect load that may not be valid on some elements; + without masking, invalid accesses can cause IMAs. When mask is true, + the result is the result of body; otherwise it is other. + + Contrast this with ops.where, which can multiplex between two values + that have been unconditionally computed. + """ + ... + + def where(self, condition: T, input: T, other: T) -> T: + """ + Computes torch.where: when condition is true, return input; otherwise return other. + """ + ... + + def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T: + """ + Converts a sympy expression into a scalar of type dtype. expr is typically + an indexing expression, thus the name; however, it can also be used in + non-indexing situations. + """ + ... + + def to_dtype( + self, x: T, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None + ) -> T: + """ + Convert x to dtype. 
src_dtype can be optionally set to specify what the original + dtype of x was, which can improve code generation (used by torch to(dtype=dtype)). + """ + ... + + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: + """ + Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) + src_dtype must be the original type of x. + """ + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operations are only available in a "kernel" context. Check + # torch._inductor.codegen.common.CSEProxy for their typical implementation + # in op handler (routing to their respective implementations in the kernel + # handler) + # + # Importantly, inside a kernel, indexing and mask variables are available + # in scope, which are typically used by sympy.Expr indexing. + + def indirect_indexing( + self, x: T, size: sympy.Expr, check: bool = True + ) -> sympy.Expr: + """ + Convert an integral x into a sympy.Expr that can be subsequently used in + indexing computation. 'size' represents an upper bound on the what valid + indexes can be; when 'check' is True, we check that the x is in bounds. + + NB: This is typically mandatory to implement for any analysis, because you + MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol). + """ + ... + + def load(self, name: str, index: sympy.Expr) -> T: + """ + Load from the memory location 'name', offset by some indexing expression 'index'. + """ + ... + + def store( + self, + name: str, + index: sympy.Expr, + value: T, + mode: StoreMode = None, + ) -> None: + """ + Store 'value' to the memory location 'name' offset by 'expr'. If + specified, 'mode' can require the store to be an atomic addition. + """ + ... + + # TODO: Better explain how the "collective" semantics of these ops; + # remember that the input value is a scalar, you can't reduce on it in the + # traditional sense! 
+ def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: T, + ) -> Union[T, Tuple[T, ...]]: + """ + Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype', + using 'dtype' as the accumulation dtype for the reduction. The result + is an intermediate computation which should be stored to the final + location using 'ops.store_reduction'. + + Valid reduction types are . For Welford reduction types, this + function returns multiple outputs; consult reduction_num_outputs to + determine the amount in metaprogramming applications. + """ + ... + + # TODO: in practice, this seems to actually return None, but not returning + # a T makes common __getattr__ idioms not type correctly. Figure out if + # this should be returning something. + def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T: + """ + Store the fully accumulated result of 'reduction' to the memory + location 'name' offset by 'expr'. + """ + ... + + def scan( + self, dtype: torch.dtype, combine_fn: Callable[[T, T], T], value: T, init: int + ) -> T: + """ + Perform an associative scan on 'value'. + """ + # TODO: Improve the description with some pseudocode + ... + + def bucketize( + self, + values: T, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> T: + # See [Note: Inductor bucketize op] + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # The following ops have semantics that correspond exactly to the torch + # operation with the same corresponding name. + + def abs(self, x0: T) -> T: + ... + + def exp(self, x0: T) -> T: + ... + + def exp2(self, x0: T) -> T: + ... + + def expm1(self, x0: T) -> T: + ... + + def sqrt(self, x0: T) -> T: + ... + + def relu(self, x0: T) -> T: + ... + + def minimum(self, x0: T, x1: T) -> T: + ... + + def maximum(self, x0: T, x1: T) -> T: + ... + + def cos(self, x0: T) -> T: + ... + + def sin(self, x0: T) -> T: + ... 
+ + def lgamma(self, x0: T) -> T: + ... + + def erf(self, x0: T) -> T: + ... + + def cosh(self, x0: T) -> T: + ... + + def sinh(self, x0: T) -> T: + ... + + def acos(self, x0: T) -> T: + ... + + def acosh(self, x0: T) -> T: + ... + + def asin(self, x0: T) -> T: + ... + + def asinh(self, x0: T) -> T: + ... + + def atan2(self, x0: T, x1: T) -> T: + ... + + def atan(self, x0: T) -> T: + ... + + def atanh(self, x0: T) -> T: + ... + + def copysign(self, x0: T, x1: T) -> T: + ... + + def erfc(self, x0: T) -> T: + ... + + def erfinv(self, x0: T) -> T: + ... + + def frexp(self, x0: T): + ... + + def hypot(self, x0: T, x1: T) -> T: + ... + + def log10(self, x0: T) -> T: + ... + + def nextafter(self, x0: T, x1: T) -> T: + ... + + def logical_and(self, x0: T, x1: T) -> T: + ... + + def logical_not(self, x0: T) -> T: + ... + + def logical_or(self, x0: T, x1: T) -> T: + ... + + def logical_xor(self, x0: T, x1: T) -> T: + ... + + def bitwise_and(self, x0: T, x1: T) -> T: + ... + + def bitwise_not(self, x0: T) -> T: + ... + + def bitwise_or(self, x0: T, x1: T) -> T: + ... + + def bitwise_xor(self, x0: T, x1: T) -> T: + ... + + def bitwise_left_shift(self, x0: T, x1: T) -> T: + ... + + def bitwise_right_shift(self, x0: T, x1: T) -> T: + ... + + def rsqrt(self, x0: T) -> T: + ... + + def log1p(self, x0: T) -> T: + ... + + def tan(self, x0: T) -> T: + ... + + def tanh(self, x0: T) -> T: + ... + + def sigmoid(self, x0: T) -> T: + ... + + def signbit(self, x0: T) -> T: + ... + + def fmod(self, x0: T, x1: T) -> T: + ... + + def log(self, x0: T) -> T: + ... + + def isinf(self, x0: T) -> T: + ... + + def isnan(self, x0: T) -> T: + ... + + def round(self, x0: T) -> T: + ... + + def floor(self, x0: T) -> T: + ... + + def sign(self, x0: T) -> T: + ... + + def to_int(self, x0: T) -> T: + ... + + def trunc(self, x0: T) -> T: + ... + + def truncdiv(self, x0: T, x1: T) -> T: + ... + + def ceil(self, x0: T) -> T: + ... + + def neg(self, x0: T) -> T: + ... 
+ + def reciprocal(self, x0: T) -> T: + ... + + def eq(self, x0: T, x1: T) -> T: + ... + + def ne(self, x0: T, x1: T) -> T: + ... + + def lt(self, x0: T, x1: T) -> T: + ... + + def gt(self, x0: T, x1: T) -> T: + ... + + def le(self, x0: T, x1: T) -> T: + ... + + def ge(self, x0: T, x1: T) -> T: + ... + + def add(self, x0: T, x1: T) -> T: + ... + + def sub(self, x0: T, x1: T) -> T: + ... + + def mul(self, x0: T, x1: T) -> T: + ... + + def floordiv(self, x0: T, x1: T) -> T: + ... + + def truediv(self, x0: T, x1: T) -> T: + ... + + def div(self, x0: T, x1: T) -> T: + ... + + def mod(self, x0: T, x1: T) -> T: + ... + + def pow(self, x0: T, x1: T) -> T: + ... + + def and_(self, x0: T, x1: T) -> T: + ... + + def or_(self, x0: T, x1: T) -> T: + ... + + def xor(self, x0: T, x1: T) -> T: + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # In CUDA, optimized implementations of other mathematical operations are + # offered separately via libdevice for double precision computation (in + # Triton, these go to tl.math rather than tl). We lower to these + # operators when doing FP64 on CUDA. Note that some operators + # unconditional go to tl.math. + # + # TODO(ezyang): Is this really the best way to do this? What if we have + # abs internally route to tl.math automatically when given a double + # precision input? One reason is that when doing codegen, we often don't + # know what the dtype of the inputs are! (In principle we do know, but + # for many analyses it's not conveniently available.) + + def libdevice_abs(self, x0: T) -> T: + ... + + def libdevice_exp(self, x0: T) -> T: + ... + + def libdevice_sqrt(self, x0: T) -> T: + ... + + def libdevice_cos(self, x0: T) -> T: + ... + + def libdevice_sin(self, x0: T) -> T: + ... + + def libdevice_sigmoid(self, x0: T) -> T: + ... + + def libdevice_log(self, x0: T) -> T: + ... 
+ + +class MockHandler: + def __getattr__(self, name): + if name == "name": + return "MockHandler" + + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fargs.extend(f"{k}={v}" for k, v in kwargs.items()) + return f"ops.{name}({', '.join(fargs)})" + + return inner + + @staticmethod + def masked(mask, body, other) -> str: + return f"ops.masked({mask}, {body()}, {other})" + + @staticmethod + def frexp(x): + return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]") + + @staticmethod + def indirect_indexing(index_var, size, check=True) -> sympy.Symbol: + return sympy_index_symbol(f"({str(index_var)})") + + @classmethod + def _init_cls(cls): + def make_handler(format_string): + @staticmethod # type: ignore[misc] + def inner(*args): + return format_string.format(*args) + + return inner + + for name, format_string in itertools.chain( + magic_methods.items(), inplace_methods.items() + ): + setattr(cls, name, make_handler(format_string)) + + +MockHandler._init_cls() + + +# Use mypy to check protocol implemented correctly +def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]: + return h + + +class KernelFormatterHandler: + def __init__(self, parent_handler): + self.parent_handler = parent_handler + self.output = IndentedBuffer(1) + self.var_counter = itertools.count() + + @staticmethod + def ir_to_string(ir_fn, index, rindex=None) -> str: + from .ir import FlexibleLayout + from .virtualized import V + + args = [index, rindex] if rindex is not None else [index] + names = ["index", "rindex"] if rindex is not None else ["index"] + formatter = KernelFormatterHandler(MockHandler()) + + with formatter.output.indent(-1): + formatter.output.writeline(f"def inner_fn({', '.join(names)}):") + for name, arg in zip(names, args): + if arg: + lhs = ", ".join( + [ + str("_" if isinstance(v, (int, sympy.Integer)) else v) + for v in arg + ] + ) + formatter.output.writeline(f"{lhs} = {name}") + + with V.set_ops_handler(formatter), patch.object( + FlexibleLayout, 
"allow_indexing", True + ): + result = ir_fn(*args) + return formatter.getvalue(result) + + def __getattr__(self, name) -> Callable[..., Any]: + def inner(*args, **kwargs): + line = getattr(self.parent_handler, name)(*args, **kwargs) + if name == "indirect_indexing": + return line + + def write(line): + # replace line with a new variable name + varname = f"tmp{next(self.var_counter)}" + self.output.writeline(f"{varname} = {line}") + return varname + + return pytree.tree_map(write, line) + + return inner + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[str, Tuple[str, ...]], + ) -> Union[str, Tuple[str, ...]]: + line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value) + num_values = reduction_num_outputs(reduction_type) + varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)] + self.output.writeline(f"{','.join(varnames)} = {line}") + return tuple(varnames) if num_values > 1 else varnames[0] + + def getvalue(self, result): + self.output.writeline(f"return {result}") + return self.output.getvalue() + + +# Use mypy to check protocol implemented correctly +def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: + return h + + +class WrapperHandler(Generic[T]): + def __init__(self, inner: OpsHandler[T]): + self._inner = inner + + def __getattr__(self, item): + return getattr(self._inner, item) + + +# Use mypy to check protocol implemented correctly +def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]: + return h + + +class OpCounterCSE: + """Shim to count how many ops are used""" + + def __init__(self, inner): + super().__init__() + self.parent_handler = inner + self.op_count = 0 + self.var_names = {} + + def __getattr__(self, name): + def inner(*args, **kwargs): + val = getattr(self.parent_handler, name)(*args, **kwargs) + if name == "indirect_indexing": + return val + + def count(val): + if val not in 
self.var_names: + varname = f"tmp{self.op_count}" + self.op_count += 1 + self.var_names[val] = varname + return varname + else: + return self.var_names[val] + + return pytree.tree_map(count, val) + + return inner + + +def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]: + return h diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..680659dc4f1d9b232125fa15b291a6bed82d821b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py @@ -0,0 +1,118 @@ +import math + +import sympy + +import torch +from torch.utils._sympy.value_ranges import ValueRanges +from .ir import LoopBody +from .utils import dominated_nodes + + +def val_expressable_in_32_bits(val): + if getattr(val, "is_Boolean", False): + return True + + if isinstance(val, sympy.Expr): + assert val.is_number + if val.is_Integer or val.is_Boolean: + val = int(val) + else: + val = float(val) + + # bound within mantissa + if isinstance(val, float): + return val <= (2**24) and val >= -(2**24) + + if isinstance(val, int): + iinfo = torch.iinfo(torch.int32) + return val <= iinfo.max and val >= iinfo.min + + raise Exception(f"Unexpected value {val}") + + +def range_expressable_in_32_bits(range): + return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits( + range.upper + ) + + +def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals): + # if a downstream use of a node explicitly converts to int32, or float16/float32/float64, + # then it's precision is set for that chain of uses, and we don't need to consider those + # dominated values + def skip_filter(node): + return node.target == "to_dtype" and node.args[2] in ( + torch.int32, + torch.float32, + torch.float64, + ) + + # 
TODO - there are dominated uses whose dtype does not depend on whether + # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to + # int32 without changing the output precision of the node. this case hasn't shown up + for dominated in dominated_nodes([node], skip_filter): + if dominated.target in ["store", "output"]: + continue + + if isinstance(dominated.target, str) and "set_indirect" in dominated.target: + idx = int(dominated.target[len("set_indirect") :]) + indirect_var = indirect_vars[idx] + + # We check that we can compute all the indices it's involved in with int32 + for index, expr in indices.items(): + if indirect_var in expr.free_symbols: + index_val = replacement_vals[index] + + if math.isinf(index_val.lower) or math.isinf(index_val.upper): + return + + # all indices are integers, so make sure that we + # use the bounds of integers instead of floats. + # TODO - not sure if we should be doing int/float casts while tracing, + # might interfere with sympy. 
+ + index_val_int = ValueRanges[sympy.Expr]( + int(index_val.lower), int(index_val.upper) + ) + if not range_expressable_in_32_bits(index_val_int): + return + + if not range_expressable_in_32_bits(bounds[dominated]): + return + + args = list(node.args) + args[2] = torch.int32 + node.args = tuple(args) + + +def indexing_dtype_strength_reduction(loop_body: LoopBody): + """ + Performs Value Range Analysis on LoopBody's fx graph to reduce precision of + intermediaries from int64 to int32 + """ + bv = loop_body.bounds() + + int64_dtype_nodes = [ + node + for node in loop_body.get_nodes() + if ( + node.target == "to_dtype" + and node.args[2] == torch.int64 + and node not in bv.unbounded_vars + ) + ] + if not int64_dtype_nodes: + return + + bounds = bv.get_bounds() + + # TODO - if dominated node of one to_dtype is not expressible in int32, + # we should short circuit another to_dtype node if that node also dominates + for node in int64_dtype_nodes: + try_to_reduce_precision( + node, + bounds, + loop_body.indirect_vars, + loop_body.indexing_exprs, + bv.replacement_vals, + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_heuristics.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..a7b9cb5ec72a75210c45227e6cc7613c3e535933 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_heuristics.py @@ -0,0 +1,1527 @@ +import builtins +import copy +import functools +import hashlib +import inspect +import json +import logging +import math +import operator +import os +import os.path +import re +import threading +from enum import auto, Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple + +import torch + +import torch.autograd.profiler as autograd_profiler +from torch._dynamo.device_interface import get_interface_for_device +from 
torch._dynamo.utils import dynamo_timed, get_first_attr +from torch.utils._triton import has_triton_package + +from . import config +from .codecache import cache_dir, CudaKernelParamCache +from .coordinate_descent_tuner import CoordescTuner + +from .ir import ReductionHint, TileHint +from .utils import ( + ceildiv, + conditional_product, + create_bandwidth_info_str, + do_bench, + get_max_y_grid, + get_num_bytes, + next_power_of_2, + triton_config_to_hashable, +) + + +log = logging.getLogger(__name__) + +if has_triton_package(): + import triton + from triton import Config + from triton.runtime.autotuner import OutOfResources + from triton.runtime.jit import KernelInterface + + try: + from triton.compiler.compiler import ASTSource + except ImportError: + ASTSource = None +else: + Config = object + triton = None + KernelInterface = object + OutOfResources = object + ASTSource = None + + +_NUM_THREADS_PER_WARP = 32 + + +class HeuristicType(Enum): + PERSISTENT_REDUCTION = auto() + POINTWISE = auto() + REDUCTION = auto() + SPLIT_SCAN = auto() + TEMPLATE = auto() + USER_AUTOTUNE = auto() + + +class AutotuneHint(Enum): + ELEMENTS_PER_WARP_32 = 0 + + # Triton codegen tries to codegen set of AutotuneHints. + # Enum.__repr__ looks like """ + # which isn't valid python. + # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32". + __repr__ = Enum.__str__ + + +def autotune_hints_to_configs( + hints: Set[AutotuneHint], size_hints, block_size: int +) -> List[Config]: + """ + AutotuneHints can be attached to the metadata of triton kernels for providing + suggestions about what to try for autotuning. One reason to do this is if there are + some configs that are only useful in specific scenarios, in which case we can avoid + wasting compile time on autotuning unless we know we are in one of those scenarios. + + Based on those hints, this function will generate a list of additional autotuning + configs to try. 
+ """ + xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...] + configs = [] + + for hint in hints: + if hint == AutotuneHint.ELEMENTS_PER_WARP_32: + if len(size_hints) == 1: + xyz_options = ((block_size // 4, None, None),) + elif len(size_hints) == 2: + xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None)) + elif len(size_hints) == 3: + xyz_options = ( + (block_size // 4, 1, 1), + (1, block_size // 4, 1), + (1, 1, block_size // 4), + ) + for xyz in xyz_options: + configs.append( + triton_config( + size_hints, + *xyz, + num_elements_per_warp=32, + ) + ) + + return configs + + +def disable_pointwise_autotuning(): + # Autotuning can give different benchmarking results from run to run, and + # therefore we disable autotuning when use_deterministic flag is on. + if torch.are_deterministic_algorithms_enabled(): + return True + return not config.triton.autotune_pointwise + + +class CachingAutotuner(KernelInterface): + """ + Simplified version of Triton autotuner that has no invalidation + key and caches the best config to disk to improve cold start times. + Unlike the main triton Autotuner, this version can precompile all + configs, and does not rely on the Triton JIT. 
+ """ + + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names, + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + ): + super().__init__() + + assert len(configs) > 0, "Non-empty TritonConfig list required for compiling" + self.fn = fn + self.triton_meta = triton_meta + self.inductor_meta = {} if inductor_meta is None else inductor_meta + self.save_cache_hook = save_cache_hook + self.mutated_arg_names = mutated_arg_names + self.configs = configs + self.heuristic_type = heuristic_type + self.custom_kernel = custom_kernel + self.cuda_kernel_saved = False + + # Align the default design that default as cuda + self.device_type = ( + triton_meta["device_type"] if "device_type" in triton_meta else "cuda" + ) + self.gpu_device = get_interface_for_device(self.device_type) + + if log.isEnabledFor(logging.DEBUG): + log.debug( + "CachingAutotuner gets %d configs for %s", + len(self.configs), + self.fn.__name__, + ) + for c in self.configs: + log.debug(c) + + self.launchers = [] + self.lock = threading.Lock() + if os.getenv("TRITON_CACHE_DIR") is None: + os.environ["TRITON_CACHE_DIR"] = os.path.join( + cache_dir(), + "triton", + str(self.triton_meta.get("device", 0)), + ) + + self.size_hints = size_hints + self.coordesc_tuner = CoordescTuner( + is_mm=False, name=self.fn.__name__, size_hints=size_hints + ) + + # pre-create the profiler context manager to reduce latency + self.record_function_ctx = torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel") + ) + + def precompile(self, warm_cache_only_with_cc=None): + with self.lock: + if self.launchers: + return + self.launchers = [] + compiled_binaries = [] + if not self.configs: + raise RuntimeError("No triton configs are available") + + for c in self.configs: + try: + compiled_binary, launcher = 
self._precompile_config( + c, warm_cache_only_with_cc + ) + except OutOfResources: + # Skip the config if we run out of resource + continue + self.launchers.append(launcher) + compiled_binaries.append(compiled_binary) + + if len(self.launchers) == 0: + raise RuntimeError( + "No valid triton configs. Report a fatal compilation error" + ) + + seen_configs = set(self.configs) + + device_prop = self.gpu_device.Worker.get_device_properties( + self.triton_meta["device"] + ) + if ( + config.dynamic_scale_rblock + and self.heuristic_type == HeuristicType.REDUCTION + and self.size_hints is not None + # Disable for AMDGPU as Triton is not ready to return n_regs for a compiled_binary. + and torch.version.hip is None + and device_prop.major >= 8 + ): + for triton_config, compiled_binary in zip( + self.configs, compiled_binaries + ): + assert len(self.size_hints) == 2 + xblock = triton_config.kwargs.get("XBLOCK", 1) + rblock = triton_config.kwargs["RBLOCK"] + total_block = (self.size_hints[0] + xblock - 1) // xblock + nreg = getattr(compiled_binary, "n_regs", None) + if nreg is None: + continue + + # make sure rblock is not too small + if rblock <= 64: + continue + + # each SM of A100 has 65536 32-bit registers. To maximize + # the theoretical occupancy, we need run 2048 threads on each + # SM. So each thread should use no more than 65536 / 2048 + # = 32 registers. In cases where occupancy matters, and each + # thread uses too many registers, reduce RBLOCK to reduce + # the register usage. + # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd + # from PLBartForCausalLM, latency improve from + # 7.795ms to 4.883ms. 
+ # + if ( + nreg + <= device_prop.regs_per_multiprocessor + // device_prop.max_threads_per_multi_processor + ): + continue + + nreg_per_warp = nreg * 32 + nreg_per_block = nreg_per_warp * triton_config.num_warps + + # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)' + # The formula below is a tighter upper bound since we have the assumption that + # nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor + # due to the if condition above and: + # regs_per_multiprocessor / nreg_per_block + # = regs_per_multiprocessor / (nreg * 32 * num_warps) + # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps) + # = max_threads_per_multi_processor / (32 * num_warps) + # Using a tigher upper bound can reveal more optimization opportunities. + max_blocks_per_sm = max( + device_prop.regs_per_multiprocessor // nreg_per_block, 1 + ) + + if ( + total_block + <= max_blocks_per_sm * device_prop.multi_processor_count + ): + # no need to improve occupancy + continue + new_config = copy.deepcopy(triton_config) + new_config.kwargs["RBLOCK"] = rblock // 2 + if new_config in seen_configs: + continue + seen_configs.add(new_config) + self.launchers.append( + self._precompile_config(new_config, warm_cache_only_with_cc)[1] + ) + self.configs = None + + def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]): + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.triton_meta) + for k, v in cfg.kwargs.items(): + compile_meta["constants"][self.fn.arg_names.index(k)] = v + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + compile_meta["debug"] = ( + config.assert_indirect_indexing and torch.version.hip is None + ) + + # Setting device_type="hip" required on ROCm to pass down to triton + compile_meta["device_type"] = ( + self.device_type if torch.version.hip is None else 
"hip" + ) + + if warm_cache_only_with_cc: + cc = warm_cache_only_with_cc + else: + # Use device_type 'cuda' for both cuda and hip devices to retrieve + # the compute capability. + device_type = self.device_type if torch.version.hip is None else "cuda" + device_id = compile_meta["device"] + device = torch.device(device_type, device_id) + cc = self.gpu_device.get_compute_capability(device) + + compile_meta["cc"] = cc + + if ASTSource: + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + compile_meta["configs"][0], + ), + ) + + target = (compile_meta["device_type"], cc) + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + } + compile_kwargs = { + "target": target, + "options": options, + } + else: + compile_args = (self.fn,) + compile_kwargs = compile_meta + + if warm_cache_only_with_cc: + return ( + triton.compile(*compile_args, **compile_kwargs), + None, + ) + + # load binary to the correct device + with self.gpu_device.device(compile_meta["device"]): # type: ignore[attr-defined] + # need to initialize context + self.gpu_device.synchronize(self.gpu_device.current_device()) + + try: + binary = triton.compile(*compile_args, **compile_kwargs) + except Exception: + log.exception( + "Triton compilation failed: %s\n%s\nmetadata: %s", + self.inductor_meta.get("kernel_name", "triton_"), + self.fn.src, + compile_meta, + ) + raise + binary._init_handles() + + call_args = [ + arg + for i, arg in enumerate(self.fn.arg_names) + if i not in self.fn.constexprs + ] + def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs] + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": binary.launch_enter_hook, + "launch_exit_hook": binary.launch_exit_hook, + "metadata": binary.metadata, + "torch": torch, + "set_device": self.gpu_device.set_device, + "current_device": self.gpu_device.current_device, + } + + 
scope["runner"] = get_first_attr(binary, "run", "c_wrapper") + scope["function"] = get_first_attr(binary, "function", "cu_function") + scope["cta_args"] = ( + (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims")) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ) + scope["num_warps"] = ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ) + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + scope["shared"] = binary_shared + + exec( + f""" + def launcher({', '.join(def_args)}, grid, stream): + if callable(grid): + grid_0, grid_1, grid_2 = grid(grid_meta) + else: + grid_0, grid_1, grid_2 = grid + + runner(grid_0, grid_1, grid_2, num_warps, + *cta_args, shared, + stream, function, + launch_enter_hook, + launch_exit_hook, + metadata, + {', '.join(call_args)}) + return bin + """.lstrip(), + scope, + ) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.store_cubin = config.triton.store_cubin + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = self.fn + launcher.bin = binary + + return binary, launcher + + def bench(self, launcher, *args, grid, **kwargs): + """Measure the performance of a given launcher""" + # we don't skip configs wiht spilled registers when auto-tuning custom + # (user-written) Triton kernels, as (i) we don't have any knowledge or + # control over the kernel code; (ii) there is empirical evidence that + # for some (complicated) custom Triton kernels, a register-spilling + # config may yield the best latency. 
+ if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold: + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + stream = self.gpu_device.get_raw_stream( # type: ignore[call-arg] + self.gpu_device.current_device() + ) + + def kernel_call(): + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} + ) + + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=grid, + stream=stream, + ) + + return do_bench(kernel_call, rep=40, fast_flush=True) + + def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: + from .compile_fx import clone_preserve_strides + + # clone inplace buffers to avoid autotune contaminating them if + # the kernel does in-place stores. avoid cloning other buffers because + # it leads to increase memory use + cloned_args = [] + for i, arg in enumerate(args): + if self.fn.arg_names[i] in self.mutated_arg_names: + assert isinstance(arg, torch.Tensor) + cloned_args.append(clone_preserve_strides(arg)) + else: + cloned_args.append(arg) + + cloned_kwargs: Dict[str, Any] = {} + for name, arg in kwargs.items(): + if name in self.mutated_arg_names: + assert isinstance(arg, torch.Tensor) + cloned_kwargs[name] = clone_preserve_strides(arg) + else: + cloned_kwargs[name] = arg + + return cloned_args, cloned_kwargs + + @dynamo_timed + def benchmark_all_configs(self, *args, **kwargs): + timings = { + launcher: self.bench(launcher, *args, **kwargs) + for launcher in self.launchers + } + + for k, v in timings.items(): + self.coordesc_tuner.cache_benchmark_result(k.config, v) + + if log.isEnabledFor(logging.DEBUG): + log.debug("Benchmark all input configs for %s, get:", self.fn.__name__) + for k, v in timings.items(): + log.debug( + "%s: %f, nreg %d, nspill %d, #shared-mem %s", + k.config, + v, + k.n_regs, + 
k.n_spills, + k.shared, + ) + + return timings + + def autotune_to_one_config(self, *args, **kwargs): + """Do the actual autotuning""" + timings = self.benchmark_all_configs(*args, **kwargs) + self.launchers = [builtins.min(timings, key=timings.get)] + if self.save_cache_hook: + self.save_cache_hook(self.launchers[0].config) + + def save_cuda_kernel(self, grid, stream, launcher): + if callable(grid): + grid_x, grid_y, grid_z = grid(launcher.config.kwargs) + else: + grid_x, grid_y, grid_z = grid + + key = self.inductor_meta.get("kernel_name", None) # unique kernel name + assert key is not None, "kernel_name can not be None" + params = { + "mangled_name": launcher.bin.metadata.name + if hasattr(launcher.bin.metadata, "name") + else launcher.bin.metadata["name"], + "grid_x": grid_x, + "grid_y": grid_y, + "grid_z": grid_z, + "x_block": launcher.config.kwargs.get("XBLOCK", 1), + "y_block": launcher.config.kwargs.get("YBLOCK", None), + "z_block": launcher.config.kwargs.get("ZBLOCK", None), + "num_warps": launcher.bin.num_warps + if hasattr(launcher.bin, "num_warps") + else launcher.bin.metadata.num_warps, + "shared_mem": launcher.bin.shared + if hasattr(launcher.bin, "shared") + else launcher.bin.metadata.shared, + "stream": stream, + # User defined triton kernels will have arbitrary kwarg names + "meta": launcher.config.kwargs, + } + + if torch.version.hip is None: + CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"]) + else: + # There is some divergence between CUDA and ROCm here. + # On ROCm's triton we only have the the path to the binary, not the binary itself. 
+ # For ROCm we will copy the binary to the new location instead of writing to file + import pathlib + + launcher.bin.asm["hsaco"] = pathlib.Path( + launcher.bin.asm["hsaco_path"] + ).read_bytes() + CudaKernelParamCache.set(key, params, launcher.bin.asm["hsaco"]) + + self.cuda_kernel_saved = True + + def coordinate_descent_tuning(self, launcher, *args, **kwargs): + """ + Coordinate descent tuning can be run with or without max-autotune. + + The only difference between these two is the starting config for coordinate_descent tuning. + E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4 + and max-autotune figure out C3 is the best. + + Then if coordinate descnt tuning is run with max-autotune disabled, it will start from C1; + while if coordinate descent tuning is run with max-autotune enabled, it will start from C3. + """ + if ( + self.heuristic_type == HeuristicType.TEMPLATE + or self.heuristic_type == HeuristicType.USER_AUTOTUNE + ): + # skip triton template + return launcher + + cloned_args, _ = self.clone_args(*args) + config2launcher = {launcher.config: launcher} + + def benchmark_one_config(config): + with self.lock: + _, launcher = self._precompile_config(config, None) + config2launcher[config] = launcher + + out = self.bench(launcher, *cloned_args, **kwargs) + log.debug( + "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d", + launcher.config, + out, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + ) + return out + + assert not ( + self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION + and "RBLOCK" in launcher.config.kwargs + ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK" + best_config = self.coordesc_tuner.autotune( + benchmark_one_config, launcher.config, None + ) + best_config.found_by_coordesc = True + + if self.save_cache_hook: + self.save_cache_hook(best_config, found_by_coordesc=True) + return 
config2launcher.get(best_config) + + def run(self, *args, grid, stream, **kwargs): + if len(self.launchers) != 1: + if len(self.launchers) == 0: + self.precompile() + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, grid=grid, **kwargs) + + if ( + not getattr(self.launchers[0].config, "found_by_coordesc", False) + and config.coordinate_descent_tuning + ): + self.launchers = [ + self.coordinate_descent_tuning( + self.launchers[0], *args, grid=grid, **kwargs + ) + ] + + (launcher,) = self.launchers + if launcher.store_cubin: + self.save_cuda_kernel(grid, stream, launcher) + + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**dict(zip(self.arg_names, args)), **launcher.config.kwargs, **kwargs} + ) + + # guard the record_function_ctx and only call it if profiling is currently + # in progress, to reduce latency when profiler is not turned on. Note that + # the "if" statement (instead of, say, a contextlib.nullcontext) is intentional; + # it is faster than entering and exiting a context manager, even if the context + # manager is a nullcontext. 
+ if autograd_profiler._is_profiler_enabled: + with self.record_function_ctx: + return launcher( + *args, + **kwargs, + grid=grid, + stream=stream, + ) + else: + return launcher( + *args, + **kwargs, + grid=grid, + stream=stream, + ) + + +def _find_names(obj): + import gc + import inspect + + frame = inspect.currentframe() + while frame is not None: + frame.f_locals + frame = frame.f_back + obj_names = [] + for referrer in gc.get_referrers(obj): + if isinstance(referrer, dict): + for k, v in referrer.items(): + if v is obj: + obj_names.append(k) + return obj_names + + +collected_calls: List[Any] = [] + + +def start_graph(): + collected_calls.clear() + + +def end_graph(): + if len(collected_calls) == 0: + return + overall_time = sum(call[0] for call in collected_calls) + overall_gb = sum(call[1] for call in collected_calls) + cur_file = inspect.stack()[1].filename + summary_str = ( + f"SUMMARY ({cur_file})\n" + f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s" + ) + print(summary_str) + print() + output_file = config.profile_bandwidth_output + if output_file is not None: + # sort perf numbers in descending order, i.e. 
placing the + # most runtime-heavy kernels at the top of the list + sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True) + try: + with open(output_file, "a") as file: + log.debug("Save profile bandwidth results to %s", output_file) + file.write("====================\n") + file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n") + for ms, num_gb, gb_per_s, kernel_name in sorted_calls: + # also display the runtime percentage for each kernel + percentage = f"{ms/overall_time*100:.2f}%" + suffix = f" \t {percentage} \t {kernel_name}" + bw_info_str = create_bandwidth_info_str( + ms, + num_gb, + gb_per_s, + suffix=suffix, + color=False, + ) + file.write(bw_info_str + "\n") + file.write(f"{summary_str}\n\n") + except Exception as e: + log.warning( + "failed to write profile bandwidth result into %s: %s", + output_file, + e, + ) + + +class DebugAutotuner(CachingAutotuner): + def __init__(self, *args, regex_filter="", **kwargs): + self.regex_filter = regex_filter + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, grid, stream): + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + super().run(*args, grid=grid, stream=stream) + (launcher,) = self.launchers + + if self.cached is None: + ms = self.bench(launcher, *args, grid=grid) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = self.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = (ms, num_gb, gb_per_s, kernel_name) + else: + ms, num_gb, gb_per_s, kernel_name = self.cached + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + print( + create_bandwidth_info_str(ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}") + ) + + +def hash_configs(configs: 
List[Config]): + """ + Hash used to check for changes in configurations + """ + hasher = hashlib.sha256() + for cfg in configs: + hasher.update( + f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode() + ) + return hasher.hexdigest() + + +def load_cached_autotuning( + best_config, + configs_hash: str, + configs: List[Config], +): + if best_config is None: + return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + + if config.coordinate_descent_tuning and best_config.pop("found_by_coordesc", False): + num_warps = best_config.pop("num_warps") + num_stages = best_config.pop("num_stages") + triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages) + triton_config.found_by_coordesc = True + return triton_config + + matching_configs = [ + cfg + for cfg in configs + if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) + and cfg.num_warps == best_config.get("num_warps") + and cfg.num_stages == best_config.get("num_stages") + ] + if len(matching_configs) != 1: + return None + + return matching_configs[0] + + +def cached_autotune( + size_hints: Optional[List[int]], + configs: List[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. 
+ """ + configs = unique_configs(configs) + assert len(configs) == 1 or filename + save_cache_hook: Optional[Callable[[Any, Any], Any]] + inductor_meta = {} if inductor_meta is None else inductor_meta + + # on disk caching logic and/or remote caching + if filename is not None and (len(configs) > 1 or config.coordinate_descent_tuning): + configs_hash = hash_configs(configs) + + cache_filename = None + remote_cache = None + remote_cache_key = None + if config.use_autotune_local_cache: + cache_filename = os.path.splitext(filename)[0] + ".best_config" + if config.use_autotune_remote_cache or ( + config.is_fbcode() + and torch._utils_internal.justknobs_check( + "pytorch/autotune_remote_cache:enable" + ) + ): + backend_hash = inductor_meta.get("backend_hash", None) + if backend_hash is not None: + key = backend_hash + configs_hash + "autotune-best-config" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + try: + if config.is_fbcode(): + remote_cache = ( + triton.runtime.fb_memcache.FbMemcacheRemoteCacheBackend( + key, is_autotune=True + ) + ) + else: + remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key) + except Exception: + remote_cache = None + log.warning("Unable to create a remote cache", exc_info=True) + # we already sha256 hash the source contents + remote_cache_key = os.path.basename(filename) + else: + log.debug( + "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" + ) + + best_config = None + if cache_filename is not None and os.path.exists(cache_filename): + with open(cache_filename) as fd: + best_config = json.loads(fd.read()) + elif remote_cache is not None and remote_cache_key is not None: + cache_outs = remote_cache.get([remote_cache_key]) + cache_out = cache_outs.get(remote_cache_key, None) + best_config = json.loads(cache_out) if cache_out else None + + best_config = load_cached_autotuning(best_config, configs_hash, configs) + if best_config: + configs = [best_config] + + def save_cache_hook(cfg, 
found_by_coordesc=False): + data = json.dumps( + { + **cfg.kwargs, + "num_warps": cfg.num_warps, + "num_stages": cfg.num_stages, + "configs_hash": configs_hash, + "found_by_coordesc": found_by_coordesc, + } + ) + if cache_filename is not None: + with open(cache_filename, "w") as fd: + fd.write(data) + if remote_cache is not None and remote_cache_key is not None: + remote_cache.put(remote_cache_key, data) + + if log.isEnabledFor(logging.DEBUG): + type_str = "coordesc" if found_by_coordesc else "heuristic" + log.debug("Save %s tuning result to %s", type_str, cache_filename) + + else: + save_cache_hook = None + + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + + def decorator(fn): + # Remove XBLOCK from config if it's not a function argument. + # This way, coordinate descent tuning will not try to tune it. + # + # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1. + import inspect + + if "XBLOCK" not in inspect.signature(fn.fn).parameters: + for tconfig in configs: + if "XBLOCK" in tconfig.kwargs: + assert tconfig.kwargs["XBLOCK"] == 1 + tconfig.kwargs.pop("XBLOCK") + + if config.profile_bandwidth: + return DebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=config.profile_bandwidth_regex, + configs=configs, + save_cache_hook=save_cache_hook, + mutated_arg_names=mutated_arg_names, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + ) + return CachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=save_cache_hook, + mutated_arg_names=mutated_arg_names, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + ) + + return decorator + + +def unique_configs(configs: List[Config]): + """Remove duplicate configurations""" + seen = set() + pruned_configs = [] + + for cfg in configs: + key = triton_config_to_hashable(cfg) + if key not in seen: + seen.add(key) + 
pruned_configs.append(cfg) + return pruned_configs + + +def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None): + for numel, label in zip((xnumel, ynumel, znumel), "XYZ"): + if numel is None: + continue + block = cfg[f"{label}BLOCK"] + if numel == 1: + assert block == 1, ( + f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1" + f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})." + ) + max_block = config.triton.max_block[label] + max_block_str = f'config.triton.max_block["{label}"]' + assert max_block % block == 0, ( + f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}" + f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})." + ) + + +def triton_config( + size_hints, + x, + y=None, + z=None, + num_stages=1, + num_elements_per_warp=256, + min_elem_per_thread=0, +) -> Config: + """ + Construct a pointwise triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + + num_elements_per_warp is a suggestion for controlling how many warps + the triton config should contain. e.g.: if x=16, y=8, z=4 then + num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128, + we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's + just a suggestion, and sometimes other adjustment heuristics will + override the num_elements_per_warp. + + min_elem_per_thread controls the minimum number of elements + processed by each thread. It's always enforced. 
+ """ + # Ideally we want to read this from some device config + + # for a 2d size_hints [a, b], a should be mapped to YBLOCK rather than XBLOCK + size_hints = list(reversed(size_hints)) + + maxGridSize = [2147483647, 65535, 65535] + + target = conditional_product(x, y, z) + if conditional_product(*size_hints) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints[0]) + if y: + y = min(y, size_hints[1]) + if z: + z = min(z, size_hints[2]) + + # if we are below original block size, scale up where we can; + # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension + while x < min(size_hints[0], config.triton.max_block["X"]) and ( + x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target + ): + x *= 2 + while ( + y + and y < min(size_hints[1], config.triton.max_block["Y"]) + and ( + y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target + ) + ): + y *= 2 + while ( + z + and z < min(size_hints[2], config.triton.max_block["Z"]) + and ( + z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target + ) + ): + z *= 2 + + num_warps = next_power_of_2( + min(max(conditional_product(x, y, z) // num_elements_per_warp, 1), 8) + ) + # we are going to arrive at 2 warps only if bs was too small due to + # numel being too small. However to workaround some ptx bugs we still + # want at least 4 warps if there's enough elements per thread + # given that this is a rare situation, don't expect this to affect perf + # in general + # see https://github.com/pytorch/pytorch/pull/97950 + num_warps = max(num_warps, 4) if conditional_product(x, y, z) >= 128 else num_warps + xnumel = size_hints[0] + ynumel = size_hints[1] if y else None + znumel = size_hints[2] if z else None + + # Increase x to satisfy min_elem_per_thread requirements. 
+ block_size = max( + conditional_product(x, y, z), + min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps, + ) + x *= math.ceil(block_size / conditional_product(x, y, z)) + + cfg = {"XBLOCK": x} + if y: + cfg["YBLOCK"] = y + if z: + cfg["ZBLOCK"] = z + check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def triton_config_reduction(size_hints, x, r, num_stages=1, num_warps=None) -> Config: + """ + Construct a reduction triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + """ + + target = conditional_product(x, r) + if conditional_product(*size_hints) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints[0]) + r = min(r, size_hints[1]) + + # if we are below original block size, scale up where we can + while x < size_hints[0] and conditional_product(x, r) < target: + x *= 2 + while r < size_hints[1] and conditional_product(x, r) < target: + r *= 2 + + cfg = {"XBLOCK": x, "RBLOCK": r} + if num_warps is None: + num_warps = conditional_product(x, r) // 128 + num_warps = next_power_of_2(min(max(num_warps, 2), 8)) + check_config(cfg, xnumel=size_hints[0]) + assert ( + r <= config.triton.max_block["R"] + ), f"increase config.triton.MAX_BLOCK['r'] to {r}" + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1): + """ + Construct a tile reduction triton config with some adjustment + heuristics based on size_hints. Size_hints is a tuple of numels in + each tile dimension and will be rounded up to the nearest power of 2. 
+ """ + + target = conditional_product(x, y, r) + if conditional_product(*size_hints) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints[0]) + y = min(y, size_hints[1]) + r = min(r, size_hints[2]) + + # if we are below original block size, scale up where we can + while x < size_hints[0] and conditional_product(x, y, r) < target: + x *= 2 + while r < size_hints[2] and conditional_product(x, y, r) < target: + r *= 2 + while y < size_hints[1] and conditional_product(x, y, r) < target: + y *= 2 + + cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r} + num_warps = next_power_of_2(min(max(conditional_product(x, y, r) // 256, 1), 8)) + check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1]) + assert ( + r <= config.triton.max_block["R"] + ), f"increase config.triton.MAX_BLOCK['r'] to {r}" + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def pointwise( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + """ + Construct @triton.heuristics() based on size_hints. 
+ """ + inductor_meta = {} if inductor_meta is None else inductor_meta + assert not inductor_meta.get("no_x_dim") + + numel = functools.reduce(operator.mul, size_hints) + bs = max(256, min(numel // 128, 1024)) + + hinted_configs = autotune_hints_to_configs( + inductor_meta.get("autotune_hints", set()), size_hints, bs + ) + + triton_config_with_settings = functools.partial( + triton_config, min_elem_per_thread=min_elem_per_thread + ) + + if len(size_hints) == 1: + if disable_pointwise_autotuning() and not ( + config.max_autotune or config.max_autotune_pointwise + ): + return cached_autotune( + size_hints, + [triton_config_with_settings(size_hints, bs)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + else: + return cached_autotune( + size_hints, + [ + triton_config_with_settings( + size_hints, bs, num_elements_per_warp=256 + ), + triton_config_with_settings( + size_hints, bs // 2, num_elements_per_warp=64 + ), + *hinted_configs, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + if len(size_hints) == 2: + if (disable_pointwise_autotuning() or tile_hint == TileHint.SQUARE) and not ( + config.max_autotune or config.max_autotune_pointwise + ): + return cached_autotune( + size_hints, + [triton_config_with_settings(size_hints, 32, 32)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + return cached_autotune( + size_hints, + [ + triton_config_with_settings(size_hints, 32, 32), + triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 + triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 16, 256), + triton_config_with_settings(size_hints, bs, 1), + triton_config_with_settings(size_hints, 1, bs), + *hinted_configs, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + 
filename=filename, + heuristic_type=HeuristicType.POINTWISE, + ) + if len(size_hints) == 3: + if disable_pointwise_autotuning(): + return cached_autotune( + size_hints, + [triton_config_with_settings(size_hints, 16, 16, 16)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + return cached_autotune( + size_hints, + [ + triton_config_with_settings(size_hints, 16, 16, 16), + triton_config_with_settings(size_hints, 64, 8, 8), + triton_config_with_settings(size_hints, 8, 64, 8), + triton_config_with_settings(size_hints, 8, 8, 64), + triton_config_with_settings(size_hints, bs, 1, 1), + triton_config_with_settings(size_hints, 1, bs, 1), + triton_config_with_settings(size_hints, 1, 1, bs), + *hinted_configs, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.POINTWISE, + ) + raise NotImplementedError(f"size_hints: {size_hints}") + + +def _reduction_configs( + *, size_hints: List[int], inductor_meta: Dict[str, Any] +) -> List[Config]: + reduction_hint = inductor_meta.get("reduction_hint", None) + assert len(size_hints) == 2 + rnumel = size_hints[-1] + + contiguous_config = triton_config_reduction( + size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048) + ) + outer_config = triton_config_reduction(size_hints, 64, 8) + tiny_config = triton_config_reduction( + size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, 2048) + ) + if config.max_autotune or config.max_autotune_pointwise: + pass # skip all these cases + elif reduction_hint == ReductionHint.INNER: + return [contiguous_config] + elif reduction_hint == ReductionHint.OUTER: + return [outer_config] + elif reduction_hint == ReductionHint.OUTER_TINY: + return [tiny_config] + if disable_pointwise_autotuning(): + return [triton_config_reduction(size_hints, 32, 128)] + return [ + contiguous_config, + outer_config, + tiny_config, + triton_config_reduction(size_hints, 64, 
64), + triton_config_reduction(size_hints, 8, 512), + # halve the XBLOCK/RBLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + triton_config_reduction(size_hints, 64, 4, num_warps=8), + ] + + +def reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints = [1, *size_hints[1:]] + + assert triton_meta is not None + rnumel = size_hints[-1] + if len(size_hints) != 2: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + + +def persistent_reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints = [1, *size_hints[1:]] + + xnumel, rnumel = size_hints + + configs = [ + triton_config_reduction(size_hints, xblock, rnumel) + for xblock in (1, 8, 32, 128) + if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel) + ] + + # TODO(jansel): we should be able to improve these heuristics + if reduction_hint == ReductionHint.INNER and rnumel >= 256: + configs = configs[:1] + elif reduction_hint == ReductionHint.OUTER: + configs = configs[-1:] + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = [ + triton_config_reduction( + size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, 
rnumel + ) + ] + for c in configs: + # we don't need RBLOCK for persistent reduction + c.kwargs.pop("RBLOCK") + + if disable_pointwise_autotuning(): + configs = configs[:1] + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def split_scan( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """Heuristic for TritonSplitScanKernel""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints = [1, *size_hints[1:]] + + assert triton_meta is not None + rnumel = size_hints[-1] + if len(size_hints) != 2: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + + # Fixup configs to enforce the minimum RBLOCK size + min_rblock = config.triton.min_split_scan_rblock + for cfg in configs: + if cfg.kwargs["RBLOCK"] < min_rblock: + cfg.kwargs["RBLOCK"] = min_rblock + + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.SPLIT_SCAN, + filename=filename, + ) + + +def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None): + """ + Compile a triton template + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +def user_autotune( + configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False +): + """ + Compile a user defined triton kernel + """ + defaults = inspect.signature(triton.Config).parameters + default_num_stages = defaults["num_stages"].default + default_num_warps = 
defaults["num_warps"].default + + if len(configs) == 0: + configs = [ + triton.Config( + {}, num_stages=default_num_stages, num_warps=default_num_warps + ) + ] + else: + configs = [ + triton.Config( + c.get("kwargs", {}), + num_stages=c.get("num_stages", default_num_stages), + num_warps=c.get("num_warps", default_num_warps), + ) + for c in configs + ] + + return cached_autotune( + None, + configs, + triton_meta=triton_meta, + heuristic_type=HeuristicType.USER_AUTOTUNE, + filename=filename, + inductor_meta=inductor_meta, + custom_kernel=custom_kernel, + ) + + +def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=1, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +def grid(*numels): + """Helper function to compute triton grids""" + if len(numels) == 1: + xnumel, ynumel, znumel = numels[0], None, None + elif len(numels) == 2: + xnumel, ynumel, znumel = numels[1], numels[0], None + elif len(numels) == 3: + xnumel, ynumel, znumel = numels[2], numels[1], numels[0] + else: + raise AssertionError(f"invalid size for numels {len(numels)}") + + def get_grid_dim(numel, block): + if numel is None: + return 1 + if block is None: + return numel + return ceildiv(numel, block) + + max_grid_dims = config.triton.max_tiles + + def grid_fn(meta): + x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1)) + y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None)) + + MAX_Y_GRID = get_max_y_grid() + if znumel is None and max_grid_dims <= 2: + div = ceildiv(y_grid, MAX_Y_GRID) + y_grid = y_grid // div + z_grid = div + else: + z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None)) + torch._check( + y_grid <= MAX_Y_GRID, + lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. 
File issue", + ) + + return ( + x_grid, + y_grid, + z_grid, + ) + + return grid_fn + + +def split_scan_grid(xnumel, rnumel): + def grid_fn(meta): + assert meta.get("XBLOCK", 1) == 1 + return (ceildiv(rnumel, meta.get("RBLOCK", 1)), xnumel, 1) + + return grid_fn diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27c308fe760aef488a65e3350ec4ec9cb58bcb14 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54fb1441aa2605275704b1c78a048daeee113612 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f03b52303107bd2a2a164853be60988fc098897 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aa6d0f2a920cedf2bdd468f9fe668067ddfd383 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..08e066a9a72f8633beb5982bffa2a02d472315e3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__init__.py @@ -0,0 +1,1985 @@ +from __future__ import annotations + +import operator +import warnings +import weakref + +from contextlib import nullcontext +from enum import Enum +from functools import cmp_to_key, reduce +from typing import ( + Any, + Callable, + cast, + List, + NamedTuple, + Optional, + overload, + Sequence, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +from typing_extensions import TypeAlias + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. 
+ + import sympy + +import torch +from torch import sym_float, sym_int, sym_max + + +ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]] +StrideType: TypeAlias = Union[List[int], Tuple[int, ...]] +DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]] +DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]] +# TODO: Type[torch.SymInt], Type[torch.SymFloat] +NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]] +# TODO: This needs a lot more type annotations +# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat] +NumberType: TypeAlias = Union[bool, int, float, complex] +RealNumberType: TypeAlias = Union[bool, int, float] + +Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat) +# I don't call it Integral because numbers.Integral includes bool, but IntLike +# does not +Dim = int +IntLike = (int, torch.SymInt) +FloatLike = (float, torch.SymFloat) +IntWithoutSymInt = int +FloatWithoutSymFloat = float +DeviceLikeType: TypeAlias = Union[str, torch.device, int] +Tensor = torch.Tensor + + +torch_function_passthrough = { + torch.device, + torch.sym_not, + torch.sym_float, + torch.sym_int, + torch.sym_max, + torch.sym_min, + torch._sym_sqrt, # type: ignore[attr-defined] + torch.sym_ite, + torch.Tensor.dim, + torch.Tensor.ndim.__get__, # type: ignore[attr-defined] + torch.Tensor.numel, + torch.Tensor.size, + torch.Tensor.storage_offset, + torch.Tensor.stride, + torch.Tensor.dtype.__get__, # type: ignore[attr-defined] + torch.Tensor.is_sparse.__get__, # type: ignore[attr-defined] + torch.Tensor.shape.__get__, # type: ignore[attr-defined] + torch.Tensor.device.__get__, # type: ignore[attr-defined] + torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] + torch.Tensor.layout.__get__, # type: ignore[attr-defined] + torch.Tensor.is_contiguous, + # For TorchRefsMode only + torch.Tensor.__format__, + torch.Tensor.__repr__, + torch.Tensor.requires_grad.__get__, # type: 
ignore[attr-defined] +} + + +TensorLikeType = torch.Tensor +TensorLike = torch.Tensor +TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]] +TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType] + +CustomOutParamAnnotation = "__custom_out_param__" + + +def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if len(a) != len(b): + return False + + for x, y in zip(a, b): + if allow_rhs_unbacked: + # TODO: We should check that the symbols are consistent + # with each other + if isinstance(y, torch.SymInt): + continue + # NB: Naively, you would not expect to have to do an oblivious guard + # here because there is seemingly no broadcasting here, but in fact we + # use this in some situations to determine if we need to do an expand + # on the tensor because they don't line up, so you can definitely end + # up trying to prove u0 != 1 in this situation. See + # python test/test_proxy_tensor.py -k test_cumsum_unbacked + if guard_size_oblivious(x != y): + return False + + return True + + +def _maybe_get_pytype(t): + if t is torch.SymFloat: + return float + elif t is torch.SymInt: + return int + elif t is torch.SymBool: + return bool + else: + return t + + +# TODO: look at using torch.testing.assert_close instead with an option +# to just compare metadata +def compare_tensor_meta( + a: TensorLikeType, + b: TensorLikeType, + check_strides=False, + *, + allow_rhs_unbacked=False, + check_conj=True, +): + """ + Checks that two tensor likes have the same shape, + dtype and device. + + In the future this will validate additional metadata, like + strides. + """ + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + + if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked): + msg = f"Shapes {a.shape} and {b.shape} are not equal!" 
+ raise AssertionError(msg) + + if a.dtype != b.dtype: + msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!" + raise AssertionError(msg) + + if a.device != b.device: + # Handles special cuda:0 vs cuda case + # TODO: we should review why this happens and see about fixing it + if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and ( + str(b.device) == "cuda:0" or str(b.device) == "cuda" + ): + pass + else: + msg = f"Devices {a.device} and {b.device} are not equal!" + raise AssertionError(msg) + + # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050 + if check_strides: + same_strides, idx = check_significant_strides(a, b) + if not same_strides: + msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!" + raise RuntimeError(msg) + + if a.storage_offset() != b.storage_offset(): + msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!" + raise RuntimeError(msg) + + if check_conj: + if a.is_conj() != b.is_conj(): + raise RuntimeError( + f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}" + ) + + if a.is_neg() != b.is_neg(): + raise RuntimeError( + f"Neg mismatch! 
is_neg is set to {a.is_neg()} and {b.is_neg()}" + ) + + +def _check_strides_helper( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True +) -> Tuple[bool, Optional[int]]: + # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch + # See https://github.com/pytorch/pytorch/issues/77553 + # Only compares strides that are "meaningful" -- strides for dimensions with length > 1 + # and for tensors with more than one element + if ( + not only_cuda or a.device.type == "cuda" or b.device.type == "cuda" + ) and a.numel() > 0: + for idx in range(a.ndim): + check = not significant_only or a.shape[idx] > 1 + if a.stride()[idx] != b.stride()[idx] and check: + return False, idx + + return True, None + + +def check_significant_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True +) -> Tuple[bool, Optional[int]]: + return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True) + + +def check_all_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True +) -> Tuple[bool, Optional[int]]: + return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False) + + +# This function is equivalent to compute_contiguous() from TensorImpl.cpp +def is_contiguous(a: TensorLikeType) -> bool: + """ + Tests whether a tensor is contiguous or not. + + Tensors are contiguous when they have no elements, + one element, or when they have "nested" strides. 
+ """ + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(a.numel() < 2): + return True + + expected_stride = 1 + for x, y in reversed(tuple(zip(a.shape, a.stride()))): + # Skips checking strides when a dimension has length 1 + if guard_size_oblivious(x == 1): + continue + + if y != expected_stride: + return False + expected_stride = expected_stride * x + + return True + + +# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp +def is_channels_last_contiguous_2d(a: Tensor) -> bool: + # NHWC or not channels last 2D contiguous + if a.ndim != 4: + return False + + expected_stride = 1 + for idx in (1, 3, 2, 0): + length = a.shape[idx] + if length == 1: + continue + + stride = a.stride()[idx] + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +def is_channels_last_contiguous_3d(a: Tensor) -> bool: + # NDHWC or not channels last 3D contiguous + if a.ndim != 5: + return False + + expected_stride = 1 + for idx in (1, 4, 3, 2, 0): + length = a.shape[idx] + if length == 1: + continue + + stride = a.stride()[idx] + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +_memory_formats = { + torch.contiguous_format, + torch.preserve_format, + torch.channels_last, + torch.channels_last_3d, +} + + +def validate_memory_format(memory_format: torch.memory_format): + torch._check( + memory_format in _memory_formats, + lambda: f"Received unknown memory format {memory_format}!", + ) + + +def is_contiguous_for_memory_format( # type: ignore[return] + a: Tensor, *, memory_format: torch.memory_format +) -> bool: + validate_memory_format(memory_format) + + if memory_format == torch.contiguous_format: + return is_contiguous(a) + if memory_format == torch.channels_last: + return is_channels_last_contiguous_2d(a) + if memory_format == torch.channels_last_3d: + return is_channels_last_contiguous_3d(a) + + torch._check( + 
False, + lambda: f"is_contiguous received unsupported memory format {memory_format}", + ) + + +# NOTE: that tensors with no elements and channels last is ??? +def is_channels_last_contiguous(a: Tensor) -> bool: + """ + True when a tensor is channels-last contiguous. + + This requires that: + + - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions + - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the + stride of the 'C' dimension (Cs) is 1 and the strides corresponding to + each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are + "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension, + for example. + """ + return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a) + + +def is_non_overlapping_and_dense(a: Tensor) -> bool: + """ + True when a tensor is non-overlapping and dense. + + A tensor is non-overlapping and dense when there exists a permutation of + its dimensions that is contiguous. + """ + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if a.is_sparse: + return False + + # Short-circuits if the tensor is already contiguous or channels-last contiguous + if is_contiguous(a) or is_channels_last_contiguous(a): + return True + + # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + if a.ndim == 1: + return a.stride()[0] == 1 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + # + # This sort is done in a size-oblivious way, which helps if we do a + # comparison like 2048*u0 > u0; we just want this to return True + # (and not worry about what if u0 is zero). 
+ class K(NamedTuple): + size: int + stride: int + + def __lt__(self, other): + return guard_size_oblivious(self.stride < other.stride) + + def __gt__(self, other): + return guard_size_oblivious(self.stride > other.stride) + + def __le__(self, other): + return guard_size_oblivious(self.stride <= other.stride) + + def __ge__(self, other): + return guard_size_oblivious(self.stride >= other.stride) + + def __eq__(self, other): + return guard_size_oblivious(self.stride == other.stride) + + lengths_and_strides = sorted(map(K, a.shape, a.stride())) + + expected_stride = 1 + for length, stride in lengths_and_strides: + if guard_size_oblivious(length == 1): + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +# NOTE: Based on the implementation in TensorIterator.cpp, but note that +# the note [Computing output strides] is incorrect, because it +# says that strides will be preserved even if they are not +# "non overlapping and dense", but this is incorrect. The +# output of elementwise operations are always given +# non overlapping and dense strides. +# This is also INCORRECT because it does not model TensorIterator's +# short-circuit, which can cause different strides. +def compute_elementwise_output_logical_to_physical_perm( + *tensors, _skip_checks=False +) -> List[int]: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if not _skip_checks and len(tensors) == 0: + msg = "Can't compute elementwise output strides for zero tensors!" + raise ValueError(msg) + + if not _skip_checks: + check_same_shape(*tensors, allow_cpu_scalar_tensors=True) + + # Filters the tensors to actual tensors + if not _skip_checks: + tensors = tuple( + a + for a in tensors + if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) + ) + + # Short-circuits for CPU scalar case + if len(tensors) == 0: + return [] + + # Short-circuits for shapes with zero or one dimensions + # TODO: are these necessary? 
+ ndim = tensors[0].ndim + if ndim == 0: + return [] + if ndim == 1: + return [0] + + # Short-circuits if contiguous, following the fake fast path. + # This reduces the number of guards we end up making + # TODO: do channels last too + is_contiguous = True + for t in tensors: + is_contiguous = is_contiguous and t.is_contiguous( + memory_format=torch.contiguous_format + ) + + if is_contiguous: + return list(range(ndim)) + + shape = tensors[0].shape + + def should_swap(idx_a, idx_b): + for tensor in tensors: + stride_a = tensor.stride()[idx_a] + stride_b = tensor.stride()[idx_b] + + if guard_size_oblivious(stride_a == 0) or guard_size_oblivious( + stride_b == 0 + ): + continue + + if guard_size_oblivious(stride_a < stride_b): + return -1 + + if guard_size_oblivious(stride_a > stride_b): + return 1 + + # stride_a == stride_b + if guard_size_oblivious(shape[idx_a] > shape[idx_b]): + return 1 + + # Note: this case is hit if all strides are zero, + # or all strides are equal and all dimensions have the same length + return 0 + + # The "sort" order for the permutation is back-to-front, but + # the natural order for permutations is front-to-back. Do the + # sorting back-to-front and then reverse it on output. + # + # also, note this returns the logical to physical shape permutation + perm = list(reversed(range(ndim))) + + # insertion sort with support for ambiguous comparisons + for i in range(1, ndim): + dim1 = i + for dim0 in reversed(range(i)): + comparison = should_swap(perm[dim0], perm[dim1]) + if comparison > 0: + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + dim1 = dim0 + elif comparison < 0: + break + + return list(reversed(perm)) + + +def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]: + """ + Computes the output strides for elementwise operations. + """ + if len(tensors) == 0: + msg = "Can't compute elementwise output strides for zero tensors!" 
+ raise ValueError(msg) + + check_same_shape(*tensors, allow_cpu_scalar_tensors=True) + + # Filters the tensors to actual tensors + tensors = tuple( + a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) + ) + + # Short-circuits for CPU scalar case + if len(tensors) == 0: + return () + + ndim = tensors[0].ndim + shape = tensors[0].shape + + if ndim == 0: + return () + if ndim == 1: + return (1,) + + logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm( + *tensors, _skip_checks=True + ) + permuted_shape = apply_perm(shape, logical_to_physical_perm) # to physical + + new_strides = make_contiguous_strides_for(permuted_shape) + permuted_strides = apply_perm( + new_strides, invert_perm(logical_to_physical_perm) + ) # to logical + + return tuple(permuted_strides) + + +# Identity permutation is [0, 1, 2] +def apply_perm(inp, perm): + ndim = len(inp) + permuted_inp = [-1] * ndim + for idx, x in enumerate(perm): + permuted_inp[idx] = inp[x] + return permuted_inp + + +def invert_perm(perm): + ndim = len(perm) + new_perm = [-1] * ndim + for idx, x in enumerate(perm): + new_perm[x] = idx + return new_perm + + +# +# Common helper functions +# + + +def validate_dim_length(length: int): + """ + Validates that an object represents a valid + dimension length. + """ + + if isinstance(length, (int, torch.SymInt)): + torch._check_is_size(length) + else: + # sometimes called with sympy expression by inductor + assert length >= 0 + + +def validate_shape(shape: ShapeType): + """ + Validates that a sequence represents a valid shape. + """ + + assert isinstance(shape, Sequence), type(shape) + for l in shape: + validate_dim_length(l) + + +def validate_strides(strides: StrideType): + """ + Verifies the object specifies valid strides. + """ + + assert isinstance(strides, Sequence) + for stride in strides: + assert stride >= 0 + + +def validate_idx(rank: int, idx: int): + """ + Validates that idx is a valid index for the given shape. 
+ Assumes the index is already canonicalized. + """ + + assert isinstance(idx, Dim) + assert isinstance(rank, Dim) + + assert idx >= 0 and idx < rank or idx == 0 + + +def validate_dimension_indices(rank: int, indices: DimsSequenceType): + for idx in indices: + validate_idx(rank, idx) + + +def validate_exclusive_idx(rank: int, ex_idx: int): + """ + Validates that ex_idx is a valid exclusive index + for the given shape. + """ + + assert isinstance(ex_idx, Dim) + assert isinstance(rank, Dim) + assert ex_idx > 0 and ex_idx <= rank + + +# "Wraps" a dim (up to one time) for the given rank, allowing dims to be +# specified using negative indices. If `wrap_scalar` is true then scalar +# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise, +# idx should be in the range [-rank, rank-1]. +def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int: + if rank < 0: + msg = f"Rank cannot be negative but got {rank}" + raise IndexError(msg) + + if rank == 0: + if not wrap_scalar: + msg = f"Dimension specified as {idx} but tensor has no dimensions" + raise IndexError(msg) + rank = 1 + + if idx >= 0 and idx < rank: + return idx + + if idx < 0: + _idx = idx + rank + else: + _idx = idx + + if _idx < 0 or _idx >= rank: + # Same error message as in aten/src/ATen/WrapDimUtils.h:49 + msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})" + raise IndexError(msg) + + return _idx + + +# Takes a dimension or sequence of dimensions and "wraps" them, +# mapping negative offsets to positive ones +@overload +def canonicalize_dims( + rank: int, indices: Sequence[int], wrap_scalar: bool = True +) -> Tuple[int, ...]: + pass + + +@overload +def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int: + pass + + +def canonicalize_dims(rank, indices, wrap_scalar=True): + if isinstance(indices, Dim): + return canonicalize_dim(rank, indices, wrap_scalar) + + return tuple(canonicalize_dim(rank, x, 
wrap_scalar) for x in indices) + + +def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool: + """ + Validates that perm is a permutation of length rank. + """ + + if not isinstance(perm, Sequence): + return False + + if not (tuple(sorted(perm)) == tuple(range(0, rank))): + return False + + return True + + +def is_same_shape(a: Sequence, b: Sequence) -> bool: + """ + Compares two shapes a and b, returning True if they are the same + (their ranks and corresponding lengths match) and False otherwise. + """ + + return tuple(a) == tuple(b) + + +def is_cpu_scalar_tensor(a: Any) -> bool: + return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu" + + +def check_same_device(*args, allow_cpu_scalar_tensors): + """ + Checks that all Tensors in args have the same device. + + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True + """ + # Short-circuits if all (one or fewer) arguments are trivially on the same device + if len(args) <= 1: + return + + # Note: cannot initialize device to the first arg's device (it may not have one) + device = None + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + continue + + if device is None: + device = arg.device + + if device != arg.device: + msg = ( + "Tensor on device " + + str(arg.device) + + " is not on the expected device " + + str(device) + + "!" + ) + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same device, " + str(type(arg)) + "!" 
+ ) + raise RuntimeError(msg) + + +def canonicalize_device(device: DeviceLikeType) -> torch.device: + if isinstance(device, torch.device): + return device + + assert isinstance(device, str) + return torch.device(device) + + +# Asserts if any of the following are true: +# - a non-scalar or non-Tensor is given +# - the shape of any tensors is distinct +def check_same_shape(*args, allow_cpu_scalar_tensors: bool): + """ + Checks that all Tensors in args have the same shape. + + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensor objects in args have different devices + """ + shape = None + + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + continue + + if shape is None: + shape = arg.shape + + if not is_same_shape(shape, arg.shape): + msg = f"Shape {arg.shape} is not the expected shape {shape}!" + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same shape, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +# Acquires a common shape, if it exists, from one or more tensor arguments, +# filtering number arguments +def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: + shape = None + scalar_shape = None + + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + scalar_shape = arg.shape + continue + + if shape is None: + shape = arg.shape + + if not is_same_shape(shape, arg.shape): + return None + else: + return None + + return shape if shape is not None else scalar_shape + + +# Extracts dimensions that might be passed either as a list/tuple or as varargs. +# A typical case is Tensor.permute . 
+def extract_dims_from_varargs( + dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]] +) -> DimsSequenceType: + if dims and isinstance(dims[0], Sequence): + assert len(dims) == 1 + dims = cast(Tuple[DimsSequenceType], dims) + return dims[0] + else: + return cast(DimsSequenceType, dims) + + +def extract_shape_from_varargs( + shape: Union[ShapeType, Tuple[ShapeType]], + validate=True, +) -> Tuple[int, ...]: + """ + Returns a shape from varargs. + + In PyTorch, operations that accept shapes often accept them as varargs, like + foo(*shape). However a user can pass the shape as a sequence of integers, + like this: + + foo(1, 2, 3) + + or as a sequence of integers + + foo((1, 2, 3)) + + In the first case shape will be a tuple of integers, and in the second case it's a tuple + containing a tuple of integers. This validates those inputs and canonicalizes them + to a tuple of integers. + """ + + # Handles tuple unwrapping + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = shape[0] + + if validate: + validate_shape(shape) # type: ignore[arg-type] + return shape # type: ignore[return-value] + + +def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]: + ndim = max(len(a), len(b)) + expandedSizes = [0] * ndim + + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = len(a) - 1 - offset + dimB = len(b) - 1 - offset + sizeA = a[dimA] if dimA >= 0 else 1 + sizeB = b[dimB] if dimB >= 0 else 1 + + torch._check( + (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1), + lambda: ( + f"The size of tensor a ({sizeA}) must match the size of " + f"tensor b ({sizeB}) at non-jagged dimension {i}" + ), + ) + + # 1s map to the other size (even 0) + expandedSizes[i] = sizeB if sizeA == 1 else sizeA + + return tuple(expandedSizes) + + +def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]: + """ + Infers the size of a dim with size -1, if it exists. + Also checks that new shape is compatible with the number of elements. 
+ """ + dim = None + newsize = 1 + for i, d in enumerate(shape): + if d == -1: + torch._check(dim is None, lambda: "only one dimension can be inferred") + dim = i + elif d >= 0: + newsize *= d + else: + torch._check(False, lambda: f"invalid shape dimension {d}") + if dim is None: + torch._check( + numel == newsize, + lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", + ) + else: + from torch.fx.experimental.symbolic_shapes import definitely_true + + torch._check( + newsize != 0, + lambda: ( + f"cannot reshape tensor of 0 elements into shape {list(shape)} because the " + f"unspecified dimension size -1 can be any value and is ambiguous" + if definitely_true(numel == 0) + else f"shape '{list(shape)}' is invalid for input of size {numel}" + ), + ) + torch._check( + numel % newsize == 0, + lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", + ) + # Convert to list to produce a compatible error message with core + # PyTorch, which prints sequences in square brackets. + shape = list(shape) + shape[dim] = numel // newsize + # NB: This is pretty important when you have unbacked SymInts. + # Suppose you have (i0, 12) resizing into (2, -1, 12). The old + # range for i0 is typically [2, inf], which means if you divide + # by two the new range should be [1, inf]. But this is bad news + # if you have an unbacked SymInt: we need to reapply the unsound + # assumption that the size is >= 2. 
+ torch._check_is_size(shape[dim]) + return tuple(shape) + + +_integer_dtypes = ( + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, +) +_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) +_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128) + + +def is_boolean_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype is torch.bool + + +def is_integer_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _integer_dtypes + + +def is_low_precision_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _low_precision_dtypes + + +def is_float_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype.is_floating_point + + +def is_complex_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _complex_dtypes + + +def is_grad_dtype(dtype: torch.dtype) -> bool: + """ + Checks if the dtype can require a gradient. + """ + return dtype.is_floating_point or is_complex_dtype(dtype) + + +_complex_to_real_dtype_map = { + torch.complex128: torch.float64, + torch.complex64: torch.float32, + torch.complex32: torch.float16, +} + +_real_to_complex_dtype_map = { + torch.float16: torch.complex32, + torch.bfloat16: torch.complex64, + torch.float32: torch.complex64, + torch.float64: torch.complex128, +} + + +def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype: + return _complex_to_real_dtype_map[dtype] + + +def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype: + return _real_to_complex_dtype_map[dtype] + + +def dtype_to_type(dtype: torch.dtype) -> type: + """ + Computes the corresponding Python type (AKA "type kind") for the + given dtype. 
+ """ + assert isinstance(dtype, torch.dtype) + + if dtype is torch.bool: + return bool + if dtype in _integer_dtypes: + return int + if dtype.is_floating_point: + return float + if dtype in _complex_dtypes: + return complex + + raise ValueError("Invalid dtype!") + + +def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]: + """ + Computes the corresponding Python type constructor for the + given dtype. + """ + assert isinstance(dtype, torch.dtype) + + if dtype is torch.bool: + return lambda x: bool(x) + if dtype in _integer_dtypes: + return sym_int + if dtype.is_floating_point: + return sym_float + if dtype in _complex_dtypes: + # TODO: type error here is real, replace with sym_complex + return lambda x: complex(x) # type: ignore[arg-type] + + raise ValueError("Invalid dtype!") + + +def type_to_dtype(typ: type) -> torch.dtype: + """ + Computes the corresponding dtype for a Number type. + """ + + assert isinstance(typ, type) + + if typ is bool: + return torch.bool + if typ in [int, torch.SymInt]: + return torch.long + if typ in [float, torch.SymFloat]: + return torch.get_default_dtype() + # TODO: sym_complex_float? + if typ is complex: + return corresponding_complex_dtype(torch.get_default_dtype()) + + raise ValueError("Invalid type!") + + +def get_dtype(x: Union[torch.Tensor, NumberType]): + if isinstance(x, torch.Tensor): + return x.dtype + else: + return type_to_dtype(type(x)) + + +_ordered_types = (bool, int, float, complex) + + +def check_fp_or_complex( + dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True +): + """ + Checks whether the input is floating point or complex. + If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32 + """ + torch._check( + is_float_dtype(dtype) or is_complex_dtype(dtype), + lambda: f"{fn_name}: Expected a floating point or complex tensor as input. 
Got {dtype}", + ) + torch._check( + allow_low_precision_dtypes or not is_low_precision_dtype(dtype), + lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}", + ) + + +def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"): + torch._check( + len(A.shape) >= 2, + lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", + ) + + +def get_higher_type(a: type, b: type) -> type: + """ + Returns the higher of the two given Number types. + + The types are ordered bool -> int -> float -> complex. + """ + a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) + # Type checking + if a not in _ordered_types or b not in _ordered_types: + raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") + + if a is b: + return a + + for typ in _ordered_types: + if a is typ: + return b + if b is typ: + return a + + raise ValueError("Unknown Python scalar type!") + + +# Returns the higher of two torch datatypes a and b or, if the two +# are not ordered relative to each other, the next +# higher datatype +def get_higher_dtype( + a: Optional[Union[torch.dtype, TensorLikeType, NumberType]], + b: Optional[Union[torch.dtype, TensorLikeType, NumberType]], +) -> Optional[torch.dtype]: + """ + Computes the "lowest" datatype that is weakly + "higher" than both a and b. 
+ """ + + # Type checking + assert a is None or isinstance(a, (torch.dtype, TensorLike, Number)) + assert b is None or isinstance(b, (torch.dtype, TensorLike, Number)) + + def _extract_dtype( + x: Optional[Union[torch.dtype, TensorLikeType, NumberType]] + ) -> Optional[torch.dtype]: + if x is None: + return None + if isinstance(x, torch.dtype): + return x + if isinstance(x, TensorLike): + return x.dtype + if isinstance(x, Number): + return type_to_dtype(type(x)) + + raise RuntimeError("Unexpected type given to _extract_dtype!") + + a, b = _extract_dtype(a), _extract_dtype(b) + + if a is b: + return a + + if a is None: + return b + + if b is None: + return a + + ordered_datatypes = ( + (torch.bool,), + (torch.uint8, torch.int8), + (torch.int16,), + (torch.int32,), + (torch.int64,), + (torch.float16, torch.bfloat16), + (torch.float32,), + (torch.float64,), + (torch.complex32,), + (torch.complex64,), + (torch.complex128,), + ) + + for idx, dtypes in enumerate(ordered_datatypes): + if a in dtypes and b in dtypes: + return ordered_datatypes[idx + 1][0] + if a in dtypes: + return b + if b in dtypes: + return a + + raise RuntimeError("Unexpected termination!") + + +def check_pin_memory(pin_memory: bool): + torch._check_not_implemented( + not pin_memory, lambda: "PrimTorch does not support pinned memory" + ) + + +def check_layout(layout: torch.layout): + torch._check_not_implemented( + layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}" + ) + + +# TODO: maybe unify with can_cast_to? +def is_weakly_lesser_type(a: type, b: type) -> bool: + """ + Compares two types, a and b, returning True if a is weakly "less" than b. + + The comparison is determined by the following type ordering: bool, int, float, complex. 
+ """ + + a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) + + if a not in _ordered_types or b not in _ordered_types: + raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") + + for typ in _ordered_types: + if a == typ: + return True + if b == typ: + return False + + raise RuntimeError("Unexpected termination!") + + +def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool: + for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype): + if fn(cast_to): + return True + if fn(cast_from): + return False + + raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!") + + +def check_same_dtype(*args): + """ + Checks that all Tensors in args have the same device and that all Numbers have the + same corresponding Python type. + + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensors objects in args have different dtypes + - two Number objects in args have different types + - there are Tensors and Numbers in args, and one of those Tensors corresponding + Python types is different from the type of one of those Numbers + """ + full_dtype = None + scalar_type = None + + for arg in args: + if isinstance(arg, Number): + # Scalar type checking is disabled (and may be removed in the future) + continue + # if scalar_type is None: + # scalar_type = type(arg) + + # if scalar_type is not type(arg): + # msg = ( + # "Scalar of type " + # + str(type(arg)) + # + " is not the expected type of " + # + str(scalar_type) + # + "!" + # ) + # raise RuntimeError(msg) + elif isinstance(arg, TensorLike): + if full_dtype is None: + full_dtype = arg.dtype + if scalar_type is None: + scalar_type = dtype_to_type(arg.dtype) + + if full_dtype is not arg.dtype: + msg = ( + "Tensor with dtype " + + str(arg.dtype) + + " is not the expected dtype of " + + str(full_dtype) + + "!" 
+ ) + raise RuntimeError(msg) + + arg_type = dtype_to_type(arg.dtype) + if arg_type is not scalar_type: + msg = ( + "Tensor with corresponding Python type " + + str(arg_type) + + " is not the expected type of " + + str(scalar_type) + + "!" + ) + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same dtype, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +# Maps datatypes to their computation types for elementwise operations +_computation_dtype_map = { + torch.bfloat16: torch.float32, + torch.float16: torch.float32, + torch.complex32: torch.complex64, +} + + +def get_computation_dtype(dtype: torch.dtype) -> torch.dtype: + return _computation_dtype_map.get(dtype, dtype) + + +_cpu_acc_type_map = { + torch.bfloat16: torch.float64, + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.complex32: torch.complex128, + torch.complex64: torch.complex128, +} + + +def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype: + # Equivalent to at::toAccumulateType, prefer computation_dtype where possible + if device.type == "cpu": + return _cpu_acc_type_map.get(dtype, dtype) + else: + return get_computation_dtype(dtype) + + +class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + NO_OPMATH = (1,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + BOOL_TO_LONG = (5,) + + +class REDUCTION_OUTPUT_TYPE_KIND(Enum): + SAME = (0,) + COMPLEX_TO_FLOAT = (1,) # for complex types outputs corresponding real type + KEEP_PROMOTED_TYPE = (2,) # keep output in opmath type, needed for mean + ALWAYS_BOOL = (3,) + + +# Describes the return type of the primitive: +# +# - NEW, a new tensor is created +# - VIEW, a view of an input tensor is returned +# - INPLACE, one or more input tensors is modified +# +# these descriptors are mututally exclusive and exhaustive. 
+class RETURN_TYPE(Enum): + NEW = (0,) + VIEW = (1,) + INPLACE = (2,) + + +# TODO: when NumberType contains the sym types, can simplify this +def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type: + if isinstance(x, torch.SymInt): + return int + elif isinstance(x, torch.SymFloat): + return float + else: + return type(x) + + +def expr_type(x: sympy.Expr) -> Type: + if x.is_integer: # type: ignore[attr-defined] + return int + else: + # NB: Not strictly correct, but we don't support SymPy complex or bool. + return float + + +# TODO: document type promotion kinds +def elementwise_dtypes( + *_args, + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, +) -> Tuple[torch.dtype, torch.dtype]: + """ + Computes the computation and result dtypes for elementwise type promotion + on the given arguments and with the given elementwise type promotion kind. + + Note that not all inputs to an elementwise operation necessarily participate in type promotion. + For example, the "alpha" parameter of torch.add does not participate in type promotion, + although it may be cast to the Python type corresponding to the computation dtype that + the type promotion algorithm determines. + + Default elementwise type promotion, which all other type promotion kinds tweak (see below), + first decides which of four ordered types to use: + + bool -> integer -> floating point -> complex + + The selected type is the "lowest" type in the above list such that all number arguments + have a weakly "lower" type and all tensor arguments have a weakly lower corresponding + type for their dtype. + + Once the type is determined, the particular result dtype is found. 
The dtypes are + partially ordered as follows: + + bool -> uint8, int8 -> int16 -> int32 -> int64 -> + float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128 + + The result dtype is selected by: + - if no tensor's dtype has the same corresponding type as the one selected, + then the result dtype is the (default) dtype corresponding to the selected type + (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype) + - if the result type is complex then the dtype is: + - the default complex dtype if there are no floating point or complex tensors + - if there are floating point or complex tensors with one or more dimensions, then + the complex dtype corresponding to the highest corresponding complex dtype among those tensors + (for example, double + cfloat -> cdouble) + - if there are only floating point or complex tensors with zero dimensions, then + the complex dtype corresponding to the highest corresponding complex dtype among those tensors + - if the first two cases do not apply, the result dtype is the highest dtype among + all tensors with one or more dimensions of the output type, and if there are no such + tensors then it's the highest dtype among all tensors with zero dimensions of the output type + (for example, long + half -> half, even if the half tensor has zero dimensions) + + The "corresponding complex dtypes" are: + float16 -> complex32 + bfloat16 -> complex64 + float32 -> complex64 + float64 -> complex128 + complex32 -> complex32 + complex64 -> complex64 + complex128 -> complex128 + + The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation + dtype by mapping low precision floating point and complex dtypes as follows: + + float16 -> float32 + bfloat16 -> float32 + complex32 -> complex64 + + This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the + computation dtype the same as the result dtype 
when it's selected. NO_OPMATH is appropriate for kernels + which perform no mathematical operations on their tensors (see below for examples). + + The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype, + and computation dtypes to the appropriate op math dtype. + + The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this + mapping: + + complex32 -> float16 + complex64 -> float32 + complex128 -> float64 + + Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does. + + The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long. + + The ALWAYS_BOOL type promotion kind always sets the result dtype to bool. + + Example operators for each type promotion option: + DEFAULT : add + NO_OPMATH : where, nextafter, cat + INT_TO_FLOAT : sin + COMPLEX_TO_FLOAT : abs + BOOL_TO_LONG : pow + ALWAYS_BOOL : eq + + """ + + args = tuple(x for x in _args if x is not None) + + highest_type: type = bool + + # Import sympy locally, as importing it eagerly at a module level is too slow + # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589 + import sympy + + for x in args: + if not isinstance(x, (Number, TensorLike, sympy.Expr)): + msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!" 
+ raise ValueError(msg) + + if isinstance(x, Number): + highest_type = get_higher_type(highest_type, number_type(x)) + elif isinstance(x, sympy.Expr): + highest_type = get_higher_type(highest_type, expr_type(x)) + else: + # x is a TensorLike + highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype)) + + result_dtype = None + + def _find_highest_dtype_filtered( + args, filter, *, float_as_complex=False + ) -> Optional[torch.dtype]: + zero_dim_tensor_dtype = None + one_plus_dim_tensor_dtype = None + for x in args: + if isinstance(x, TensorLike) and filter(x.dtype): + _dtype = x.dtype + if float_as_complex and is_float_dtype(_dtype): + _dtype = corresponding_complex_dtype(_dtype) + if x.ndim == 0: + zero_dim_tensor_dtype = get_higher_dtype( + zero_dim_tensor_dtype, _dtype + ) + else: + # x.ndim > 0 + one_plus_dim_tensor_dtype = get_higher_dtype( + one_plus_dim_tensor_dtype, _dtype + ) + + # Prefers dtype of tensors with one or more dimensions + if one_plus_dim_tensor_dtype is not None: + return one_plus_dim_tensor_dtype + + return zero_dim_tensor_dtype + + if highest_type is float: + result_dtype = _find_highest_dtype_filtered(args, is_float_dtype) + result_dtype = ( + torch.get_default_dtype() if result_dtype is None else result_dtype + ) + elif highest_type is complex: + result_dtype = _find_highest_dtype_filtered( + args, + lambda x: is_float_dtype(x) or is_complex_dtype(x), + float_as_complex=True, + ) + if result_dtype is None: + result_dtype = corresponding_complex_dtype(torch.get_default_dtype()) + elif highest_type is int: + result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype) + result_dtype = torch.long if result_dtype is None else result_dtype + else: + # highest_type is bool + result_dtype = torch.bool + + if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT: + return get_computation_dtype(result_dtype), result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH: + return result_dtype, 
result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT: + if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype): + result_dtype = torch.get_default_dtype() + return get_computation_dtype(result_dtype), result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT: + # NOTE: computation can still occur in a complex dtype + computation_dtype = get_computation_dtype(result_dtype) + if is_complex_dtype(result_dtype): + result_dtype = corresponding_real_dtype(result_dtype) + return computation_dtype, result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG: + if is_boolean_dtype(result_dtype): + return torch.long, torch.long + return get_computation_dtype(result_dtype), result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL: + return get_computation_dtype(result_dtype), torch.bool + else: + raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}") + + +def reduction_dtypes( + arg, + output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, + dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.dtype, Optional[torch.dtype]]: + # even though some reductions, like amin or amax, don't strictly require type promotion, + # all the math ops (including comparisons) are still defined only for a computation type, + # so promotion will still happen. 
We are doing it explicitly here + inp_dtype = dtype if dtype is not None else arg.dtype + computation_dtype = get_computation_dtype(inp_dtype) + if ( + output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME + or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ): + result_dtype = dtype if dtype else arg.dtype + if ( + output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + and is_complex_dtype(result_dtype) + ): + result_dtype = corresponding_real_dtype(result_dtype) + elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE: + result_dtype = None + else: # ALWAYS_BOOL + result_dtype = torch.bool + return computation_dtype, result_dtype + + +# This function's logic is borrowed from the following functions defined in C++: +# batched_matrix_contiguous_strides and contiguous_strides +def make_contiguous_strides_for( + shape: ShapeType, row_major: bool = True +) -> Tuple[int, ...]: + """ + Returns the strides of a contiguous tensor if row_major + If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices + This is often used when calling external libraries like BLAS/LAPACK/cuSolver... 
+ """ + # contiguous_strides from c10/util/strides.h + validate_shape(shape) + if not shape: + return () + + from torch.fx.experimental.symbolic_shapes import is_nested_int + + multiplier = 1 + strides = [] + for l in reversed(shape): + strides.append(multiplier) + multiplier *= l if is_nested_int(l) else sym_max(l, 1) + + result = tuple(reversed(strides)) + + # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h + if row_major: + return result + else: + if len(shape) < 2: + return result + return result[:-2] + (1, max(shape[-2], 1)) + + +def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + torch._check( + len(shape) == 3, + lambda: "Only tensors of rank 3 can use the channels_last_1d memory format", + ) + + multiplier = 1 + strides = [0] * 3 + for idx in (1, -1, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5? 
+ torch._check( + len(shape) == 4, + lambda: "Only tensors of rank 4 can use the channels_last memory format", + ) + + multiplier = 1 + strides = [0] * 4 + for idx in (1, -1, -2, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + torch._check( + len(shape) == 5, + lambda: "Only tensors of rank 5 can use the channels_last_3d memory format", + ) + + multiplier = 1 + strides = [0] * 5 + for idx in (1, -1, -2, -3, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]: + ndim = len(shape) if isinstance(shape, Sequence) else 1 + if ndim == 3: + return make_channels_last_1d_strides_for(shape) + elif ndim == 4: + return make_channels_last_2d_strides_for(shape) + elif ndim == 5: + return make_channels_last_3d_strides_for(shape) + else: + raise RuntimeError( + f"no channels last format strides exist in {ndim} dimensions" + ) + + +def compute_reduction_output_shape( + shape: ShapeType, dimensions: Sequence +) -> Tuple[int, ...]: + for idx in dimensions: + validate_idx(len(shape), idx) + + new_shape = [] + for idx in range(len(shape)): + if idx in dimensions: + continue + + new_shape.append(shape[idx]) + + return tuple(new_shape) + + +def validate_no_repeating_dims(dims: Sequence): + if len(dims) != len(set(dims)): + raise RuntimeError("duplicate value in the list of dims") + + +def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]: + if dims is None: + return tuple(range(len(shape))) + dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims) + validate_no_repeating_dims(dims) + return dims + + +def set_correction( + 
unbiased: Optional[bool] = None, + correction: Optional[NumberType] = None, +) -> float: + if correction is not None and unbiased is not None: + raise RuntimeError("cannot specify both correction and unbiased arguments") + elif correction is None and unbiased is None: + correction = 1.0 + elif correction is None and unbiased is not None: + correction = 0.0 if unbiased is False else 1.0 + # NB: we don't actually support symint here, but it's harmless to accept + if not isinstance(correction, (IntLike, FloatLike)): + raise ValueError("correction argument should be integer or float") + if correction < 0: + raise ValueError("correction argument should be non-negative") + return sym_float(correction) + + +def compute_required_storage_length( + shape: ShapeType, strides: StrideType, storage_offset: int +) -> int: + """Computes the minimum storage size to hold the given tensor geometry. + + Example + ======= + + This is the size of a newly allocated tensor's storage, in units of elements + + >>> t = torch.empty((10, 20)) + >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset()) + 200 + + >>> # xdoctest: +SKIP(failing) + >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11)) + >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset()) + >>> size == t.storage().size() + True + + A valid tensor may have a larger storage size, but never smaller + + >>> slice = torch.empty(100)[20:40] + >>> slice.storage().size() + 100 + + >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset()) + 40 + + """ + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # Short-circuits if the shape has no elements + if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0): + return 0 + + max_offset = sum((x - 1) * y for x, y in zip(shape, strides)) + # +1 to account for the first element which offsets are taken from + return 1 + storage_offset + max_offset + + +def check_in_bounds_for_storage( + a: 
torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int +): + """ + Determines if the given shape, strides, and offset are valid for the given storage. + """ + + required_length = compute_required_storage_length(shape, strides, storage_offset) + if a.size() < required_length: + msg = ( + "Can't view a storage of size {} with an offset of {}, shape of {}, and strides of {}, " + "which requires a storage of size {}".format( + a.size(), storage_offset, str(shape), str(strides), required_length + ) + ) + raise ValueError(msg) + + +# NOTE: This function should ideally be removed, but some Meta internal models +# packaged with `torch.package` are using it, so it will have to be removed +# at some point in the future when those models no longer use this function. +def check( + b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError +) -> None: + """ + Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails. + Error message is a callable producing a string (to avoid wasting time + string formatting in non-error case, and also to make it easier for torchdynamo + to trace.) + + .. note:: This function is planned for removal in the future. Please use + `torch._check*` functions instead. + """ + warnings.warn( + DeprecationWarning( + "'torch._prims_common.check' will be removed in the future. 
Please use " + "'torch._check*' functions instead" + ) + ) + torch._check_with(exc_type, b, s) + + +# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in +# c10/core/MemoryFormat.h into one function +def are_strides_like_channels_last( + shape: Sequence[int], strides: Sequence[int] +) -> bool: + ndim = len(shape) + + if ndim == 4: + # Check for channels_last_2d + dim_order = [1, 3, 2, 0] + elif ndim == 5: + # Check for channels_last_3d + dim_order = [1, 4, 3, 2, 0] + else: + return False + + if strides[1] == 0: + return False + + min = 0 + for d in dim_order: + if shape[d] == 0: + return False + if strides[d] < min: + return False + if d == 0 and min == strides[1]: + return False + min = strides[d] + if strides[d] > 1: + min *= shape[d] + return True + + +def suggest_memory_format(x: TensorLikeType) -> torch.memory_format: + if x.layout != torch.strided: + return torch.contiguous_format + + if are_strides_like_channels_last(x.shape, x.stride()): + return torch.channels_last if x.ndim == 4 else torch.channels_last_3d + + return torch.contiguous_format + + +def prod(xs: Sequence[NumberType]) -> NumberType: + """Product of elements in input sequence. Returns 1 for empty sequence""" + return reduce(operator.mul, xs, 1) + + +def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool: + """Checks if a shape can be expanded to another shape. + This is equivalent to checking if the two shapes are broadcastable. + """ + # This is a Python implementation of + # aten/src/ATen/ExpandUtils.h:is_expandable_to + if len(shape) > len(desired): + return False + for i in range(len(shape)): + if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1: + return False + return True + + +def mask_tensor(mask: TensorLikeType, t: TensorLikeType): + """ + Similar to torch.where(mask, t, 0) but if t is boolean, + result is also boolean and not promoted to int. 
+ """ + # torch.where(mask, t, False) is equivalent + # but feels hacky and might break in the future + if t.dtype is torch.bool: + return mask.logical_and(t) + else: + return torch.where(mask, t, 0) + + +def get_aten_op(fn: Callable, name: str): + """ + Given the __module__ of reference and its name, it returns + (our best guess of) the ATen name of the associated operation + + Note: In ATen, the __name__ of a function within a module often + starts by the module name. E.g. linalg_eigh, or special_zeta + """ + module = fn.__module__ + prefix = "torch._refs" + assert module.startswith(prefix) + module = module[len(prefix) :] + # We want to go from .special / .nn.functional + # to special and special_ / nn_functional_ + if module: + module = module[1:] + module = module.replace(".", "_") + module = module + "_" + return getattr(torch._ops.ops.aten, f"{module}{name}") + + +def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype: + return dtype if dtype is not None else torch.get_default_dtype() + + +def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType: + return device if device is not None else torch.device("cpu") + + +def layout_or_default(layout: Optional[torch.layout]) -> torch.layout: + return layout if layout is not None else torch.strided + + +def clone_preserve_strides(x): + needed_size = compute_required_storage_length( + x.size(), x.stride(), x.storage_offset() + ) + # Our eager implementations for *_scatter ops are all primitives w.r.t autograd, + # so these as_strided() calls are not seen by autograd. + # We need to mimic this behavior in our ref/prim implementations. 
+ # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided" + # We should revisit this when we add a compositional as_strided op, + # and also as part of https://github.com/pytorch/pytorch/issues/90507 + try: + old = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, True + ) + buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone() + return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) + finally: + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, old + ) + + +def alert_not_deterministic(caller: str): + if torch.are_deterministic_algorithms_enabled(): + if torch.is_deterministic_algorithms_warn_only_enabled(): + warnings.warn( + f"{caller} does not have a deterministic implementation, but you set " + f"'torch.use_deterministic_algorithms(True, warn_only=True)'. " + f"You can file an issue at https://github.com/pytorch/pytorch/issues " + f"to help us prioritize adding deterministic support for this operation." + ) + else: + torch._check( + False, + lambda: ( + f"{caller} does not have a deterministic implementation, but you set " + f"'torch.use_deterministic_algorithms(True)'. You can turn off " + f"determinism just for this operation, or you can use the " + f"'warn_only=True' option, if that's acceptable for your application. " + f"You can also file an issue at https://github.com/pytorch/pytorch/issues " + f"to help us prioritize adding deterministic support for this operation." 
+ ), + ) + + +class CUDARngStateHelper: + @staticmethod + def get_torch_state_as_tuple(fake_mode=nullcontext()): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA not available") + + with fake_mode: + seed = torch.tensor(torch.cuda.initial_seed()) + offset = torch.tensor(torch.cuda._get_rng_state_offset()) + return seed, offset + + @staticmethod + def set_torch_state_tensor(seed, offset): + # Rng state is [64-bit seed, 64-bit offset] + seed_portion = seed.reshape([1]).view(torch.uint8) + offset_portion = offset.reshape([1]).view(torch.uint8) + new_state = torch.cat([seed_portion, offset_portion]) + torch.cuda.set_rng_state(new_state) + + @staticmethod + def set_new_offset(relative_offset): + torch.cuda._set_rng_state_offset(relative_offset.item()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9896390f12434108cd43bd2e897b9aab7cb2832 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__init__.py @@ -0,0 +1,89 @@ +r''' +FX is a toolkit for developers to use to transform ``nn.Module`` +instances. FX consists of three main components: a **symbolic tracer,** +an **intermediate representation**, and **Python code generation**. 
A +demonstration of these components in action: + +:: + + import torch + # Simple module for demonstration + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return self.linear(x + self.param).clamp(min=0.0, max=1.0) + + module = MyModule() + + from torch.fx import symbolic_trace + # Symbolic tracing frontend - captures the semantics of the module + symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) + + # High-level intermediate representation (IR) - Graph representation + print(symbolic_traced.graph) + """ + graph(): + %x : [num_users=1] = placeholder[target=x] + %param : [num_users=1] = get_attr[target=param] + %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {}) + %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {}) + %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0}) + return clamp + """ + + # Code generation - valid Python code + print(symbolic_traced.code) + """ + def forward(self, x): + param = self.param + add = x + param; x = param = None + linear = self.linear(add); add = None + clamp = linear.clamp(min = 0.0, max = 1.0); linear = None + return clamp + """ + +The **symbolic tracer** performs "symbolic execution" of the Python +code. It feeds fake values, called Proxies, through the code. Operations +on these Proxies are recorded. More information about symbolic tracing +can be found in the :func:`symbolic_trace` and :class:`Tracer` +documentation. + +The **intermediate representation** is the container for the operations +that were recorded during symbolic tracing. It consists of a list of +Nodes that represent function inputs, callsites (to functions, methods, +or :class:`torch.nn.Module` instances), and return values. 
More information +about the IR can be found in the documentation for :class:`Graph`. The +IR is the format on which transformations are applied. + +**Python code generation** is what makes FX a Python-to-Python (or +Module-to-Module) transformation toolkit. For each Graph IR, we can +create valid Python code matching the Graph's semantics. This +functionality is wrapped up in :class:`GraphModule`, which is a +:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a +``forward`` method generated from the Graph. + +Taken together, this pipeline of components (symbolic tracing -> +intermediate representation -> transforms -> Python code generation) +constitutes the Python-to-Python transformation pipeline of FX. In +addition, these components can be used separately. For example, +symbolic tracing can be used in isolation to capture a form of +the code for analysis (and not transformation) purposes. Code +generation can be used for programmatically generating models, for +example from a config file. There are many uses for FX! + +Several example transformations can be found at the +`examples <https://github.com/pytorch/examples/tree/master/fx>`__ +repository. 
+''' + +from .graph_module import GraphModule +from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta +from .graph import Graph, CodeGen +from .node import Node, map_arg, has_side_effect +from .proxy import Proxy +from .interpreter import Interpreter as Interpreter, Transformer as Transformer +from .subgraph_rewriter import replace_pattern diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__init__.pyi b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..750cda338856eb808e136a09f339f224c9627d45 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__init__.pyi @@ -0,0 +1,11 @@ +from ._symbolic_trace import ( + symbolic_trace as symbolic_trace, + Tracer as Tracer, + wrap as wrap, +) +from .graph import Graph as Graph +from .graph_module import GraphModule as GraphModule +from .interpreter import Interpreter as Interpreter, Transformer as Transformer +from .node import has_side_effect as has_side_effect, map_arg as map_arg, Node as Node +from .proxy import Proxy as Proxy +from .subgraph_rewriter import replace_pattern as replace_pattern diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d6454d3299c37f4c33acefed78043b9b03c6247 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20218d1f12484988a5a00b5e329285be8837c25b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0918c2fd5ca7c1f4a40135f7f8674284a7673f76 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1941100c11b7340570cd557f846128df16698d3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py new 
file mode 100644 index 0000000000000000000000000000000000000000..9c742431857c33af22dbc1ad73b5bdfcf6124b9c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py @@ -0,0 +1,27 @@ +import torch.fx + + +class BackwardState: + """ + BackwardState is used to pass Python hooks from the forwards pass + into the backwards pass in Dynamo+Compiled Autograd. + + It is created by TorchDynamo and has special handling there. + Dynamo will pass an empty BackwardState to the forwards, then populate + members on it (via setattr) only after the forwards graph is finished. + Later on, in CompileAutograd we will inline and add the needed guards + on the BackwardState. + + BackwardState is identified and has special handling in AOTAutograd. + During AOTAutograd: + 1) BackwardState is an input to the forwards graph + 2) It must only be used in the backwards + 3) It will be empty in the forwards + 4) In the forwards we add a wrapper to save it + 5) In the backwards it becomes an input + 6) There can only be one per graph + + BackwardState requires CompiledAutograd. 
+ """ + + proxy: torch.fx.Proxy diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/merge_matmul.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/merge_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..bd56694773e9b97087b9a2f83b175fa7ec990b04 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/merge_matmul.py @@ -0,0 +1,171 @@ +import torch + +from torch.fx.node import Node +from torch.fx._symbolic_trace import symbolic_trace +from torch.fx.passes.tools_common import legalize_graph +import itertools +import operator + +from typing import Dict, List, Tuple + + +def split_result_tensors( + result: torch.Tensor, inputs: List[torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + """ + A free function for use in the merge_matmul graph transformation below that + splits the output from a merged matmul into the individual results for each + input tensor. + + Arguments: + result: The merged matmul result tensor. + inputs: The list of inputs that were merged into one for the matmul. + + Returns: + List of matmul results for each input tensor. + """ + # When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we + # need an int even when tracing + if isinstance(result, torch.fx.Proxy): + splits = [0] * len(inputs) + else: + splits = [x.shape[0] for x in inputs] + + return torch.split(result, splits) + + +def may_depend_on(a: Node, b: Node, search_depth: int = 6): + """ + Determine if one node depends on another in a torch.fx.Graph. + + Arguments: + a: The node that may have a dependency on b. + b: The node that a may have a dependency on. + search_depth: In the case of an indirect dependency, this function + searches upto this many nodes away in search of a + data dependency. If none is found, the function + makes the conservative assumption that there is a + dependency. 
+ + Returns: + True if a may depend on b, False if it definitely does not. + """ + # Equivalence is defined as dependence. + if a == b: + return True + + # If a has no inputs, it cannot depend on b. + if len(a.all_input_nodes) == 0: + return False + + # If the search depth has been exhausted and no conclusion has been + # reached, assume that there is a data dependency. + if search_depth == 0: + return True + + # Recursively check all inputs of a. + for inp in a.all_input_nodes: + if may_depend_on(inp, b, search_depth - 1): + return True + + return False + + +def are_nodes_independent(nodes: List[Node]): + """ + Check if all of the given nodes are pairwise-data independent. + + Arguments: + nodes: The nodes to check for data dependencies. + + Returns: + True if any pair in nodes has a data dependency. + """ + # For each pair in nodes: + for i, j in itertools.combinations(nodes, 2): + if may_depend_on(i, j) or may_depend_on(j, i): + return False + + return True + + +def merge_matmul(in_mod: torch.nn.Module): + """ + A graph transformation that merges matrix multiplication operations that share the same right-hand + side operand into one large matrix multiplication. + ____ _________ _________ + ---- | | | | M| A * C | + M| A | T| B | * K| C | = |---------| + ---- , | | | | T| B * C | + K ---- --------- --------- + K R R + """ + gm = symbolic_trace(in_mod) + + rhs_users: Dict[Node, List[Node]] = {} + lhs_users: Dict[Node, List[Node]] = {} + + # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to + # the matmul of which they are the LHS/RHS. + for node in gm.graph.nodes: + if node.op != "call_function" or node.target is not torch.matmul: + continue + + lhs, rhs = node.args + + # TODO: Properly handle aliasing caused by get_attr. For now, + # use the attribute name as the operand if the node is a + # get_attr. 
+ lhs = lhs.target if lhs.op == "get_attr" else lhs + rhs = rhs.target if rhs.op == "get_attr" else rhs + + lhs_users.setdefault(lhs, []).append(node) + rhs_users.setdefault(rhs, []).append(node) + + for rhs, mms in rhs_users.items(): + # There must be at least matmuls for a merge to make sense. + if len(mms) < 2: + continue + + # All matmuls must not depend on each other directly or indirectly + # in order for the merge to be possible. + if not are_nodes_independent(mms): + continue + + lhs_vals = [mm.args[0] for mm in mms] + + # Merge the matmul. + # Collect a list of LHS operands and the single RHS operand. + lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] + rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs + + # Concatenate all the LHS operands. + merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) + + # Multiply the concatenated LHS operands with the one RHS. This will produce + # the same results as all the individual matmuls involving rhs in the original graph, + # but they will all be concatenated together. + merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) + + # Split the result of the merged matmul using the shapes of the LHS operands + # to ascertain how large each chunk should be. + merge_mm_split = gm.graph.call_function( + split_result_tensors, (merge_mm, lhs), {} + ) + merge_mm_res = [ + gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) + for out in range(len(lhs)) + ] + + # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. + for old, new in zip(mms, merge_mm_res): + old.replace_all_uses_with(new) + gm.graph.erase_node(old) + + # All of the new nodes created above were inserted at the end, so we need to sort + # the nodes topologically to make sure all definitions precede uses. 
+ legalize_graph(gm) + + gm.recompile() + gm.graph.lint() + return gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/meta_tracer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/meta_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..be19e7b93ac8b850cc3619d983ef748b66cfa0fa --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/meta_tracer.py @@ -0,0 +1,268 @@ +import torch +import torch.fx +import warnings +import functools +import builtins + +from typing import Any, Callable, Dict, Optional, Union + +def embedding_override(self, input): + return torch.empty(*input.shape, self.weight.shape[-1], device='meta') + + +def nn_layernorm_override(self, input): + return input + + +def torch_relu_override(x): + return x + + +def torch_nn_relu_override(self, x): + return x + + +def functional_relu_override(x, inplace=False): + assert not inplace, 'dont support inplace functional.relu for metatensor analysis' + return x + + +def torch_where_override(condition, x, y): + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta') + + +def torch_abs_override(input, *, out=None): + assert out is None, 'Dont support in-place abs for MetaTensor analysis' + return input + +manual_meta_overrides : Dict[Callable, Callable] = { + torch.nn.Embedding: embedding_override, + torch.nn.LayerNorm: nn_layernorm_override, + torch.relu: torch_relu_override, + torch.nn.functional.relu: functional_relu_override, + torch.nn.ReLU: torch_nn_relu_override, + torch.where: torch_where_override, + torch.abs: torch_abs_override, +} + +def gen_constructor_wrapper(target): + @functools.wraps(target) + def wrapper(*args, **kwargs): + proxy = None + + def check_has_proxy(v): + if isinstance(v, torch.fx.Proxy): + nonlocal 
proxy + proxy = v + torch.fx.node.map_aggregate(args, check_has_proxy) + torch.fx.node.map_aggregate(kwargs, check_has_proxy) + + if proxy is not None: + return proxy.tracer.create_proxy('call_function', target, args, kwargs) + else: + return target(*args, **kwargs) + return wrapper, target + +class MetaProxy(torch.fx.Proxy): + def install_tensor_meta(self, tensor_meta): + self._tensor_meta = tensor_meta + + def size(self, dim=None): + if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + return self._tensor_meta.size(*[dim] if dim else []) + return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {}) + + def dim(self): + if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + return self._tensor_meta.dim() + return self.tracer.create_proxy('call_method', 'dim', (self,), {}) + + @property + def shape(self): + if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + return self._tensor_meta.shape + return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {}) + + @property + def dtype(self): + if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + return self._tensor_meta.dtype + return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {}) + + @property + def device(self): + # Hack so we can track when devices are used. 
During meta-tensor propagation, + # replace these values with a constant 'meta' + return MetaDeviceAttribute(self, 'device') + + def __getattr__(self, k): + if k == '_tensor_meta': + return self.__getattribute__(k) + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return MetaAttribute(self, k) + +class MetaAttribute(MetaProxy): + def __init__(self, root, attr: str): + + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + +class MetaDeviceAttribute(MetaAttribute): + pass + +def proxys_to_metas(v): + if isinstance(v, MetaDeviceAttribute): + return 'meta' + if isinstance(v, torch.fx.Proxy): + assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}' + assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta' + return v._tensor_meta + return v + +class MetaTracer(torch.fx.Tracer): + allow_insert_stateless_mods : bool = True + + _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye'] + + def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None): + rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + + if kind == 'placeholder' and target in self.meta_args: + rv.install_tensor_meta(self.meta_args[target]) + return rv + + if target in self.orig_fns: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. 
If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if 'device' in kwargs: + kwargs['device'] = 'meta' + + try: + args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas) + kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas) + + if kind == 'call_function': + meta_target = manual_meta_overrides.get(target, target) + meta_out = meta_target(*args_metas, **kwargs_metas) + elif kind == 'call_method': + meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) + elif kind == 'call_module': + assert hasattr(self, 'orig_forward') + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if mod_type in manual_meta_overrides: + meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) + else: + meta_out = self.orig_forward(*args_metas, **kwargs_metas) + finally: + self._disable_module_getattr = False + elif kind == 'get_attr': + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split('.') + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + assert isinstance(attr_itr, torch.Tensor) + meta_out = attr_itr.to(device='meta') + finally: + self._disable_module_getattr = False + else: + return rv + + # TODO + assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet' + rv.install_tensor_meta(meta_out) + except Exception as e: + warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}') + + return rv + + def getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, '_disable_module_getattr', False): + return attr_val + else: + return super().getattr(attr, attr_val, parameter_proxy_cache) + + def call_module(self, m, forward, args, kwargs): + self.orig_forward = forward + return super().call_module(m, forward, args, kwargs) + + def 
_insert_module_as_submodule(self, mod: torch.nn.Module) -> str: + """ + Helper method which tries to insert a module that was not declared as submodule. + """ + idx = 0 + mod_name = mod.__class__.__name__.lower() + path = f"{mod_name}_{idx}" + while hasattr(self.root, path): + path = f"{mod_name}_{idx}" + idx += 1 + + self.root.add_module(path, mod) + return path + + def path_of_module(self, mod: torch.nn.Module) -> str: + try: + return super().path_of_module(mod) + except NameError as e: + if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0: + path = self._insert_module_as_submodule(mod) + self.prev_module = path + return path + raise + + def proxy(self, node): + return MetaProxy(node, self) + + def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): + assert isinstance(meta_args, dict) + self.meta_args = meta_args + + self.patched_torch_methods = { + target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH + } + self.orig_fns = set() + + for name, (wrapper, orig) in self.patched_torch_methods.items(): + setattr(torch, name, wrapper) + self.orig_fns.add(orig) + + try: + graph = super().trace(root, concrete_args) + graph._tracer_extras = {'meta_args': meta_args} + return graph + finally: + for name, (_, orig) in self.patched_torch_methods.items(): + setattr(torch, name, orig) + + +def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], + meta_args : Optional[Dict[str, torch.Tensor]] = None, + concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule: + tracer = MetaTracer() + graph = tracer.trace(root, meta_args, concrete_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + gm = torch.fx.GraphModule(tracer.root, graph, name) + return gm diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69ad2f0cad21c00ea878b796e123b5c465033624 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9398cad46c6637f29c81a2306ca67a65af9f59f3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py new file mode 100644 index 0000000000000000000000000000000000000000..560ceb588924d69e0721f261c107d17ee494ef95 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py @@ -0,0 +1,118 @@ +from collections.abc import Iterator # type: ignore[import] +from functools import partial + +from .unification_tools import assoc # type: ignore[import] +from .utils import transitive_get as walk +from .variable import isvar +from .dispatch import dispatch + +__all__ = ["reify", "unify"] + +############### +# Reification # +############### + +@dispatch(Iterator, dict) 
@dispatch(seq, seq, dict)
def _unify(u, v, s):
    """ Unify two sequences element by element under substitution ``s``

    Returns the extended substitution, or ``False`` as soon as the lengths
    differ or any pair of elements fails to unify.
    """
    # ``seq`` includes Iterator, but iterators do not support len(); the
    # length check below used to raise TypeError for them. Materialize
    # iterators first (this consumes them, which zip() would do anyway).
    if isinstance(u, Iterator):
        u = tuple(u)
    if isinstance(v, Iterator):
        v = tuple(v)

    if len(u) != len(v):
        return False
    for uu, vv in zip(u, v):  # avoiding recursion
        s = unify(uu, vv, s)
        if s is False:
            return False
    return s
@dispatch(object, object)  # type: ignore[no-redef]
def unify(u, v):
    """ Unify ``u`` and ``v`` starting from an empty substitution """
    empty_substitution = {}
    return unify(u, v, empty_substitution)
def reify_object(o, s):
    """ Reify a Python object with a substitution

    Objects exposing ``__slots__`` are handled by ``_reify_object_slots``;
    everything else goes through ``_reify_object_dict``.

    >>> # xdoctest: +SKIP
    >>> class Foo(object):
    ...     def __init__(self, a, b):
    ...         self.a = a
    ...         self.b = b
    ...     def __str__(self):
    ...         return "Foo(%s, %s)"%(str(self.a), str(self.b))
    >>> x = var('x')
    >>> f = Foo(1, x)
    >>> print(f)
    Foo(1, ~x)
    >>> print(reify_object(f, {x: 2}))
    Foo(1, 2)
    """
    reifier = _reify_object_slots if hasattr(o, '__slots__') else _reify_object_dict
    return reifier(o, s)
@dispatch(slice, slice, dict)
def _unify(u, v, s):
    """ Unify two ``slice`` objects via their (start, stop, step) triples """
    left = (u.start, u.stop, u.step)
    right = (v.start, v.stop, v.step)
    return unify(left, right, s)
def supercedes(a, b):
    """ A is consistent and strictly more specific than B

    ``a`` and ``b`` are type signatures (sequences of types); the last entry
    of either may be variadic. Equal-length signatures are compared
    position by position with ``issubclass``.
    """
    len_a, len_b = len(a), len(b)

    if len_a < len_b:
        # only case is if a is empty and b is variadic
        return not a and len_b == 1 and isvariadic(b[-1])

    if len_a == len_b:
        return all(map(issubclass, a, b))

    # len(a) > len(b): walk both signatures with independent cursors
    i = 0
    j = 0
    while i < len_a and j < len_b:
        type_a = a[i]
        type_b = b[j]
        if isvariadic(type_a):
            # a variadic entry must close the signature
            assert i == len_a - 1
            return j == len_b - 1 and issubclass(type_a, type_b)
        if isvariadic(type_b):
            # b's variadic tail absorbs further entries of a
            assert j == len_b - 1
            if not issubclass(type_a, type_b):
                return False
            i += 1
            continue
        if not issubclass(type_a, type_b):
            return False
        i += 1
        j += 1
    return j == len_b - 1 and i == len_a
def super_signature(signatures):
    """ A signature that would break ambiguities

    For every argument position, picks the type whose MRO is longest among
    the given signatures. All signatures must have the same length.
    """
    n = len(signatures[0])
    assert all(len(s) == n for s in signatures)

    chosen = []
    for position in range(n):
        deepest_mro = max((type.mro(sig[position]) for sig in signatures), key=len)
        chosen.append(deepest_mro[0])
    return chosen
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8ed78e52e364852ce557f18a633b45e87ee2b0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py @@ -0,0 +1,83 @@ +import inspect +import sys + +from .dispatcher import Dispatcher, MethodDispatcher + +global_namespace = {} # type: ignore[var-annotated] + +__all__ = ["dispatch", "ismethod"] + +def dispatch(*types, **kwargs): + """ Dispatch function on the types of the inputs + Supports dispatch on all non-keyword arguments. + Collects implementations based on the function name. Ignores namespaces. + If ambiguous type signatures occur a warning is raised when the function is + defined suggesting the additional method to break the ambiguity. + + Example: + >>> # xdoctest: +SKIP + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + >>> @dispatch(float) + ... def f(x): + ... return x - 1 + >>> # xdoctest: +SKIP + >>> f(3) + 4 + >>> f(3.0) + 2.0 + >>> # Specify an isolated namespace with the namespace keyword argument + >>> my_namespace = {} + >>> @dispatch(int, namespace=my_namespace) + ... def foo(x): + ... return x + 1 + >>> # Dispatch on instance methods within classes + >>> class MyClass(object): + ... @dispatch(list) + ... def __init__(self, data): + ... self.data = data + ... @dispatch(int) + ... def __init__(self, datum): + ... 
def ismethod(func):
    """ Is func a method?

    Note that this has to work as the method is defined but before the class is
    defined. At this stage methods look like functions, so the check is purely
    by name: ``func`` counts as a method when it has a parameter called
    ``self``.

    Returns:
        ``True`` if ``func``'s signature contains a ``self`` parameter,
        ``False`` otherwise.
    """
    # inspect.signature exists on every Python 3 interpreter, so the former
    # getargspec/getfullargspec fallback was unreachable (and
    # inspect.getargspec no longer exists as of Python 3.11).
    return 'self' in inspect.signature(func).parameters
def ambiguity_warn(dispatcher, ambiguities):
    """ Raise warning when ambiguity is detected

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher

    See Also:
        Dispatcher.add
        warning_text
    """
    warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)


def halt_ordering():
    """Deprecated interface to temporarily disable ordering.

    Only emits a ``DeprecationWarning``; it has no other effect.
    """
    warn(
        'halt_ordering is deprecated, you can safely remove this call.',
        DeprecationWarning,
    )


def restart_ordering(on_ambiguity=ambiguity_warn):
    """Deprecated interface to temporarily resume ordering.

    Only emits a ``DeprecationWarning``; ``on_ambiguity`` is accepted for
    backward compatibility and is not used.
    """
    # The adjacent literals previously concatenated to "...eagerly orderthe
    # dispatchers..."; the trailing space on the first fragment fixes that.
    warn(
        'restart_ordering is deprecated, if you would like to eagerly order '
        'the dispatchers, you should call the ``reorder()`` method on each'
        ' dispatcher.',
        DeprecationWarning,
    )
def variadic_signature_matches(types, full_signature):
    """Return True when ``types`` satisfies the variadic ``full_signature``.

    Consumes ``variadic_signature_matches_iter`` and stops at the first
    mismatch it reports.
    """
    # No arguments always matches a variadic signature
    assert full_signature
    for matched in variadic_signature_matches_iter(types, full_signature):
        if not matched:
            return False
    return True
@f.register(tuple) + ... def reverse(x): + ... return x[::-1] + >>> f(1) + 2 + >>> f(1.0) + 0.0 + >>> f([1, 2, 3]) + [3, 2, 1] + """ + def _df(func): + self.add(types, func, **kwargs) # type: ignore[call-arg] + return func + return _df + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return sig.parameters.values() + + @classmethod + def get_func_annotations(cls, func): + """ get annotations of function positional parameters + """ + params = cls.get_func_params(func) + if params: + Parameter = inspect.Parameter + + params = (param for param in params + if param.kind in + (Parameter.POSITIONAL_ONLY, + Parameter.POSITIONAL_OR_KEYWORD)) + + annotations = tuple( + param.annotation + for param in params) + + if all(ann is not Parameter.empty for ann in annotations): + return annotations + + def add(self, signature, func): + """ Add new types/method pair to dispatcher + >>> # xdoctest: +SKIP + >>> D = Dispatcher('add') + >>> D.add((int, int), lambda x, y: x + y) + >>> D.add((float, float), lambda x, y: x + y) + >>> D(1, 2) + 3 + >>> D(1, 2.0) + Traceback (most recent call last): + ... + NotImplementedError: Could not find signature for add: + >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback + >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs + >>> # as inputs. See ``ambiguity_warn`` for an example. 
+ """ + # Handle annotations + if not signature: + annotations = self.get_func_annotations(func) + if annotations: + signature = annotations + + # Handle union types + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func) + return + + new_signature = [] + + for index, typ in enumerate(signature, start=1): + if not isinstance(typ, (type, list)): + str_sig = ', '.join(c.__name__ if isinstance(c, type) + else str(c) for c in signature) + raise TypeError(f"Tried to dispatch on non-type: {typ}\n" + f"In signature: <{str_sig}>\n" + f"In function: {self.name}") + + # handle variadic signatures + if isinstance(typ, list): + if index != len(signature): + raise TypeError( + 'Variadic signature must be the last element' + ) + + if len(typ) != 1: + raise TypeError( + 'Variadic signature must contain exactly one element. ' + 'To use a variadic union type place the desired types ' + 'inside of a tuple, e.g., [(int, str)]' + ) + new_signature.append(Variadic[typ[0]]) + else: + new_signature.append(typ) + + self.funcs[tuple(new_signature)] = func + self._cache.clear() + + try: + del self._ordering + except AttributeError: + pass + + @property + def ordering(self): + try: + return self._ordering + except AttributeError: + return self.reorder() + + def reorder(self, on_ambiguity=ambiguity_warn): + self._ordering = od = ordering(self.funcs) + amb = ambiguities(self.funcs) + if amb: + on_ambiguity(self, amb) + return od + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + try: + func = self._cache[types] + except KeyError as e: + func = self.dispatch(*types) + if not func: + raise NotImplementedError( + f'Could not find signature for {self.name}: <{str_signature(types)}>') from e + self._cache[types] = func + try: + return func(*args, **kwargs) + + except MDNotImplementedError as e: + funcs = self.dispatch_iter(*types) + next(funcs) # burn first + for func in funcs: + try: + return 
func(*args, **kwargs) + except MDNotImplementedError: + pass + + raise NotImplementedError( + "Matching functions for " + f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e + + def __str__(self): + return f"" + __repr__ = __str__ + + def dispatch(self, *types): + """Determine appropriate implementation for this type signature + This method is internal. Users should call this object as a function. + Implementation resolution occurs within the ``__call__`` method. + >>> # xdoctest: +SKIP + >>> from multipledispatch import dispatch + >>> @dispatch(int) + ... def inc(x): + ... return x + 1 + >>> implementation = inc.dispatch(int) + >>> implementation(3) + 4 + >>> print(inc.dispatch(float)) + None + See Also: + ``multipledispatch.conflict`` - module to determine resolution order + """ + + if types in self.funcs: + return self.funcs[types] + + try: + return next(self.dispatch_iter(*types)) + except StopIteration: + return None + + def dispatch_iter(self, *types): + + n = len(types) + for signature in self.ordering: + if len(signature) == n and all(map(issubclass, types, signature)): + result = self.funcs[signature] + yield result + elif len(signature) and isvariadic(signature[-1]): + if variadic_signature_matches(types, signature): + result = self.funcs[signature] + yield result + + def resolve(self, types): + """ Determine appropriate implementation for this type signature + .. 
deprecated:: 0.4.4 + Use ``dispatch(*types)`` instead + """ + warn("resolve() is deprecated, use dispatch(*types)", + DeprecationWarning) + + return self.dispatch(*types) + + def __getstate__(self): + return {'name': self.name, + 'funcs': self.funcs} + + def __setstate__(self, d): + self.name = d['name'] + self.funcs = d['funcs'] + self._ordering = ordering(self.funcs) + self._cache = {} + + @property + def __doc__(self): + docs = [f"Multiply dispatched method: {self.name}"] + + if self.doc: + docs.append(self.doc) + + other = [] + for sig in self.ordering[::-1]: + func = self.funcs[sig] + if func.__doc__: + s = f'Inputs: <{str_signature(sig)}>\n' + s += '-' * len(s) + '\n' + s += func.__doc__.strip() + docs.append(s) + else: + other.append(str_signature(sig)) + + if other: + docs.append('Other signatures:\n ' + '\n '.join(other)) + + return '\n\n'.join(docs) + + def _help(self, *args): + return self.dispatch(*map(type, args)).__doc__ + + def help(self, *args, **kwargs): + """ Print docstring for the function corresponding to inputs """ + print(self._help(*args)) + + def _source(self, *args): + func = self.dispatch(*map(type, args)) + if not func: + raise TypeError("No function found") + return source(func) + + def source(self, *args, **kwargs): + """ Print source code for the function corresponding to inputs """ + print(self._source(*args)) + + +def source(func): + s = f'File: {inspect.getsourcefile(func)}\n\n' + s = s + inspect.getsource(func) + return s + + +class MethodDispatcher(Dispatcher): + """ Dispatch methods based on type signature + See Also: + Dispatcher + """ + __slots__ = ('obj', 'cls') + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return itl.islice(sig.parameters.values(), 1, None) + + def __get__(self, instance, owner): + self.obj = instance + self.cls = owner + return self + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + func = 
def str_signature(sig):
    """ String representation of type signature

    Joins the ``__name__`` of every type in ``sig`` with ``', '``.

    >>> str_signature((int, float))
    'int, float'
    """
    type_names = [typ.__name__ for typ in sig]
    return ', '.join(type_names)
+from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set +from ._compatibility import compatibility +from .immutable_collections import immutable_dict, immutable_list +import torch +import builtins +import types +import inspect +import warnings +from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair +from .._ops import ops as _ops + +if TYPE_CHECKING: + from .graph import Graph + +__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"] + +BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, + torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload] +base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] + +Target = Union[Callable[..., Any], str] + +Argument = Optional[Union[ + Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + range, + 'Node', + BaseArgumentTypes +]] + +_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, +} + +# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs, +# or add logic to correctly mark all inplace ops as side effectful. 
@compatibility(is_backward_compatible=False)
def has_side_effect(fn: Callable) -> Callable:
    """
    Register ``fn`` as side-effectful and return it unchanged.

    ``fn`` is added to the module-level ``_side_effectful_functions`` set.
    Because the same callable is returned, this can also be applied as a
    decorator.

    Args:
        fn: The callable to register.

    Returns:
        ``fn`` itself (the previous ``-> None`` annotation was incorrect).
    """
    _side_effectful_functions.add(fn)
    return fn
def _get_qualified_name(func: Callable[..., Any]) -> str:
    """
    Build a ``module.name`` string for ``func``.

    Builtins are returned by bare name and ``torch.Tensor`` descriptors as
    ``torch.Tensor.<name>``. For everything else the module is looked up via
    ``_find_module_of_method``.

    Raises:
        RuntimeError: if ``func`` is a lambda whose source cannot be
            retrieved.
    """
    # things like getattr just appear in builtins
    if getattr(builtins, func.__name__, None) is func:
        return func.__name__
    # torch.Tensor.{fn}
    if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType))
       and func is getattr(torch.Tensor, func.__name__, None)):
        return f"torch.Tensor.{func.__name__}"
    name = func.__name__
    # Python names every lambda "<lambda>". The check used to compare against
    # the empty string, so this branch never ran and lambdas were emitted as
    # "module.<lambda>".
    if name == "<lambda>":
        # For lambdas, try to get their defining name in the module
        try:
            name = inspect.getsource(func).split("=")[0].strip()
        except Exception as e:
            raise RuntimeError("Unable to represent lambda") from e
    module = _find_module_of_method(func)
    module = module.replace('torch._ops', 'torch.ops')  # WAR for bug in how torch.ops assigns module
    # Fixup segment_reduce mismatch
    if module == "torch" and name == "segment_reduce":
        name = "_" + name
    return f'{module}.{name}'
operations within + a ``Graph``. For the most part, Nodes represent callsites to various entities, + such as operators, methods, and Modules (some exceptions include nodes that + specify function inputs and outputs). Each ``Node`` has a function specified + by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: + + - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. + - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care + - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, + following the Python calling convention + - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*. + - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* + - ``output`` contains the output of the traced function in its ``args[0]`` attribute. 
This corresponds to the "return" statement + in the Graph printout. + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', + args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], + return_type : Optional[Any] = None) -> None: + """ + Instantiate an instance of ``Node``. Note: most often, you want to use the + Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather + than instantiating a ``Node`` directly. + + Args: + graph (Graph): The ``Graph`` to which this ``Node`` should belong. + + name (str): The name to which the output of this ``Node`` should be assigned + + op (str): The opcode for this ``Node``. Can be one of 'placeholder', + 'call_method', 'call_module', 'call_function', 'get_attr', + 'output' + + target ('Target'): The target this op should call. See the broader + ``Node`` docstring for more details. + + args (Tuple['Argument']): The args to be passed to ``target`` + + kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` + + return_type (Optional[Any]): The python type expression representing the + type of the output of this node. This field can be used for + annotation of values in the generated code or for other types + of analyses. 
+ """ + self.graph = graph + self.name = name # unique name of value being created + assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] + self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + if op == 'call_function': + if not callable(target): + raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' + 'but a Callable is expected') + else: + if not isinstance(target, str): + raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' + 'but a str is expected') + self.target = target # for method/module/function, the name of the method/module/function/attr + # being invoked, e.g add, layer1, or torch.add + + # All `Node`-valued inputs. Key is the Node, value is don't-care. + # The public API for this is `all_input_nodes`, this private attribute + # should not be accessed directly. + self._input_nodes : Dict[Node, None] = {} + self.__update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore[arg-type] + + # All of the nodes that use the value produced by this Node + # Note one user may correspond to several uses, e.g. the node fo ``x + x`` + # would appear once here, but represents two uses. + # + # Is a dict to act as an "ordered set". Keys are significant, value dont-care + self.users : Dict[Node, None] = {} + # Type expression representing the output value of this node. + # This should contain the same class of Type objects that would appear + # as type annotations for function inputs/outputs. + # + # For placeholder nodes, this value will be used to type-annotate the + # generated function parameters. + # For the return node, this value will be used to type-annotate the + # generated function return type. (Note this is a special case. ``return`` + # does not produce a value, it's more of a notation. 
@compatibility(is_backward_compatible=True)
def prepend(self, x: 'Node') -> None:
    """
    Move ``x`` so that it sits immediately before this node in the graph's
    node list. Example::

        Before: p -> self
                bx -> x -> ax
        After:  p -> x -> self
                bx -> ax

    Args:
        x (Node): The node to place before this node. Must belong to the same graph.
    """
    assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
    if self == x:
        warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.")
    else:
        # Unlink first: the predecessor must be read after ``x`` has left the
        # list, in case ``x`` currently is that predecessor.
        x._remove_from_list()
        before = self._prev
        before._next = x
        x._prev = before
        x._next = self
        self._prev = x
@kwargs.setter
def kwargs(self, k : Dict[str, Argument]):
    """
    Set the dict of kwargs to this Node. The interpretation of arguments
    depends on the node's opcode. See the ``fx.Graph`` docstring for more
    information.

    Args:
        k (Dict[str, Argument]): The new keyword arguments. The node's
            positional arguments are left as they are.
    """
    # DO NOT CALL `__update_args_kwargs` directly. The correct way to
    # set `kwargs` is via direct assignment, i.e. `node.kwargs = new_kwargs`
    #
    # `map_arg` with an identity function rebuilds the container structure of
    # `k` before it is stored; `__update_args_kwargs` then refreshes the
    # input-node / user bookkeeping.
    self.__update_args_kwargs(self._args, map_arg(k, lambda x: x))  # type: ignore[arg-type]
+ """ + return list(self._input_nodes.keys()) + + @compatibility(is_backward_compatible=True) + def update_arg(self, idx : int, arg : Argument) -> None: + """ + Update an existing positional argument to contain the new value + ``arg``. After calling, ``self.args[idx] == arg``. + + Args: + + idx (int): The index into ``self.args`` of the element to update + arg (Argument): The new argument value to write into ``args`` + """ + args = list(self.args) + args[idx] = arg + self.args = tuple(args) + + @compatibility(is_backward_compatible=True) + def insert_arg(self, idx : int, arg : Argument) -> None: + """ + Insert an positional argument to the argument list with given index. + + Args: + + idx (int): The index of the element in ``self.args`` to be inserted before. + arg (Argument): The new argument value to insert into ``args`` + """ + assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)" + args_left = self.args[:idx] + args_right = self.args[idx:] + + self._args = args_left + (arg,) + args_right + + _new_input_nodes = {} + map_arg(arg, _new_input_nodes.setdefault) + + for new_use in _new_input_nodes.keys(): + if new_use not in self._input_nodes: + self._input_nodes.setdefault(new_use) + new_use.users.setdefault(self) + + @compatibility(is_backward_compatible=True) + def update_kwarg(self, key : str, arg : Argument) -> None: + """ + Update an existing keyword argument to contain the new value + ``arg``. After calling, ``self.kwargs[key] == arg``. + + Args: + + key (str): The key in ``self.kwargs`` of the element to update + arg (Argument): The new argument value to write into ``kwargs`` + """ + kwargs = dict(self.kwargs) + kwargs[key] = arg + self.kwargs = kwargs + + @property + def stack_trace(self) -> Optional[str]: + """ + Return the Python stack trace that was recorded during tracing, if any. + When traced with fx.Tracer, this property is usually populated by + `Tracer.create_proxy`. 
def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']):
    """
    Internal helper; do *not* call it directly.

    Stores the new args/kwargs, unregisters this node from the ``users`` of
    every previous input node, then re-collects the input nodes from the new
    arguments and registers this node as a user of each.
    """
    self._args, self._kwargs = new_args, new_kwargs

    # Drop this node from the users of everything it used to consume.
    for stale_input in self._input_nodes:
        stale_input.users.pop(self)

    # Re-collect inputs; the dict doubles as an ordered set (values unused).
    fresh_inputs = {}
    self._input_nodes = fresh_inputs
    remember = fresh_inputs.setdefault
    map_arg(self._args, remember)
    map_arg(self._kwargs, remember)

    for current_input in fresh_inputs:
        current_input.users.setdefault(self)
@compatibility(is_backward_compatible=True)
def format_node(self,
                placeholder_names: Optional[List[str]] = None,
                maybe_return_typename: Optional[List[str]] = None) -> Optional[str]:
    """
    Return a descriptive string representation of ``self``.

    This method can be used with no arguments as a debugging
    utility.

    This function is also used internally in the ``__str__`` method
    of ``Graph``. Together, the strings in ``placeholder_names``
    and ``maybe_return_typename`` make up the signature of the
    autogenerated ``forward`` function in this Graph's surrounding
    GraphModule. ``placeholder_names`` and ``maybe_return_typename``
    should not be used otherwise.

    Args:
        placeholder_names: A list that will store formatted strings
            representing the placeholders in the generated
            ``forward`` function. Internal use only. Only a non-empty
            list is appended to (see the note in the body).
        maybe_return_typename: A single-element list that will store
            a formatted string representing the output of the
            generated ``forward`` function. Internal use only.

    Returns:
        Optional[str]: ``None`` when ``self`` is a placeholder and its
        formatted name was appended to ``placeholder_names``; otherwise
        a descriptive string representation of the current Node.
    """
    if self.op == 'placeholder':
        assert isinstance(self.target, str)
        arg_str = self.target
        # Append the annotation exactly once. The previous expression
        # `arg_str += arg_str + ...` repeated the target name ("xx: T").
        if self.type:
            arg_str += f': {_type_repr(self.type)}'
        # NOTE(review): this is a truthiness test, so an empty list behaves
        # like ``None`` and the placeholder is formatted inline below. Left
        # unchanged because switching to ``is not None`` would alter what
        # callers passing an empty list get back.
        if placeholder_names:
            placeholder_names.append(arg_str)
            return None
        maybe_typename = f'{_type_repr(self.type)} ' if self.type else ''
        default_val = '(default=' + str(self.args[0]) + ')' if self.args else ''
        return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}'
    elif self.op == 'get_attr':
        maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
        return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
               f'{self.op}[target={self._pretty_print_target(self.target)}]'
    elif self.op == 'output':
        # The return annotation is reported through the caller-supplied
        # one-element list rather than through the returned string.
        if self.type and maybe_return_typename:
            maybe_return_typename[0] = f' -> {_type_repr(self.type)}'
        return f'return {self.args[0]}'
    else:
        maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
        return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
               f'{self.op}[target={self._pretty_print_target(self.target)}](' \
               f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'
@compatibility(is_backward_compatible=False)
def is_impure(self):
    """
    Report whether this node is impure.

    A node is impure when it is a ``placeholder`` or ``output``, when it is a
    ``call_function`` whose target is in ``_side_effectful_functions``, or
    when it is a ``call_module`` whose submodule has a truthy ``_is_impure``
    attribute.

    Returns:

        bool: If the op is impure or not.
    """
    opcode = self.op

    if opcode == "call_function":
        # Impure functions are those registered in the module-level set.
        return self.target in _side_effectful_functions

    if opcode == "call_module":
        owner = self.graph.owning_module
        assert (
            owner is not None
        ), "self.graph.owning_module not set for purity check"
        submodule = owner.get_submodule(self.target)
        assert (
            submodule is not None
        ), f"Did not find expected submodule target {self.target}"
        return getattr(submodule, "_is_impure", False)

    # Graph inputs and outputs are always impure; everything else is pure.
    return opcode in {"placeholder", "output"}
+ """ + if self.op == 'call_function': + assert callable(self.target) + return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type] + elif self.op == 'call_module': + assert isinstance(self.target, str) + return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type] + + return None + + @compatibility(is_backward_compatible=True) + def replace_input_with(self, old_input: 'Node', new_input: 'Node'): + """ + Loop through input nodes of ``self``, and replace all instances of + ``old_input`` with ``new_input``. + + Args: + + old_input (Node): The old input node to be replaced. + new_input (Node): The new input node to replace ``old_input``. + """ + def maybe_replace_node(n : Node) -> Node: + return new_input if n == old_input else n + + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + m._replace_hook(old=old_input, new=new_input.name, user=self) + + new_args = map_arg(self.args, maybe_replace_node) + new_kwargs = map_arg(self.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + self.__update_args_kwargs(new_args, new_kwargs) + + def _rename(self, candidate: str): + if candidate == self.name: + return + name = self.graph._graph_namespace.create_name(candidate, None) + self.name = name + self.graph._graph_namespace._rename_object(self, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name' and hasattr(self, "name"): + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + assert isinstance(value, str) + for user in self.users: + m._replace_hook(old=self, new=value, user=user) + object.__setattr__(self, name, value) + + +@compatibility(is_backward_compatible=True) +def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. 
@compatibility(is_backward_compatible=True)
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
    """
    Apply fn to every leaf value of ``a`` and rebuild the surrounding containers.

    ``a`` may be a tuple, list, dict or slice, nested arbitrarily; anything
    else is treated as a leaf and passed to ``fn`` regardless of its type.

    The rebuilt structure mirrors the input with these conversions:

    - tuples stay tuples, and objects exposing ``_fields`` (NamedTuples) are
      re-created with their original type;
    - lists become ``immutable_list``;
    - dicts become ``immutable_dict``; only the values are mapped, keys are
      kept as they are;
    - slices are rebuilt from their mapped ``start``, ``stop`` and ``step``.
    """
    if isinstance(a, tuple):
        t = tuple(map_aggregate(elem, fn) for elem in a)
        # Support NamedTuple (if it has `_fields`) by repacking into original type.
        return t if not hasattr(a, '_fields') else type(a)(*t)
    elif isinstance(a, list):
        return immutable_list(map_aggregate(elem, fn) for elem in a)
    elif isinstance(a, dict):
        return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items())
    elif isinstance(a, slice):
        return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn))
    else:
        return fn(a)
def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """
    Make every ``call_module`` / ``get_attr`` target in ``gm.graph`` resolvable
    on ``gm``, copying missing attributes over from ``replacement``.

    Unused submodules of ``gm`` are deleted first. If ``replacement`` is a
    ``GraphModule`` its graph is linted. ``gm.graph`` is linted at the end.

    Args:
        gm (GraphModule): The module being rewritten; modified in place.
        replacement (torch.nn.Module): Source of attributes that ``gm`` lacks.

    Raises:
        RuntimeError: If a node's target exists on neither ``gm`` nor
            ``replacement``.
    """
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_attr(root: torch.nn.Module, target: str) -> Optional[Any]:
        # Resolve "a.b.c" as attribute "c" on submodule "a.b". A missing
        # submodule or attribute yields None instead of raising.
        module_path, _, attr_name = target.rpartition(".")
        try:
            mod: torch.nn.Module = root.get_submodule(module_path)
        except AttributeError:
            return None
        return getattr(mod, attr_name, None)

    for node in gm.graph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        # CASE 1: This target already exists as an attribute in our
        # result GraphModule. Whether or not it exists in
        # `replacement`, the existing submodule takes precedence.
        if try_get_attr(gm, node.target) is not None:
            continue

        replacement_attr = try_get_attr(replacement, node.target)

        # CASE 2: The target exists as an attribute in `replacement`
        # only, so we need to copy it over.
        if replacement_attr is not None:
            new_attr = copy.deepcopy(replacement_attr)
            if isinstance(replacement_attr, torch.nn.Module):
                gm.add_submodule(node.target, new_attr)
            else:
                setattr(gm, node.target, new_attr)
            continue

        # CASE 3: The target doesn't exist as an attribute in `gm`
        # or `replacement`.
        # Build ONE message string. Passing the fragments as separate
        # positional arguments made the exception text render as a tuple.
        raise RuntimeError(
            f'Attempted to create a "{node.op}" node during subgraph rewriting '
            f"with target {node.target}, but "
            "the referenced attribute does not "
            "exist in the replacement GraphModule"
        )

    gm.graph.lint()
code-block:: python + + import torch + from torch.fx import symbolic_trace, subgraph_rewriter + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, w1, w2): + m1 = torch.cat([w1, w2]).sum() + m2 = torch.cat([w1, w2]).sum() + return x + torch.max(m1) + torch.max(m2) + + def pattern(w1, w2): + return torch.cat([w1, w2]).sum() + + def replacement(w1, w2): + return torch.stack([w1, w2]) + + traced_module = symbolic_trace(M()) + + subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) + + The above code will first match ``pattern`` in the ``forward`` + method of ``traced_module``. Pattern-matching is done based on + use-def relationships, not node names. For example, if you had + ``p = torch.cat([a, b])`` in ``pattern``, you could match + ``m = torch.cat([a, b])`` in the original ``forward`` function, + despite the variable names being different (``p`` vs ``m``). + + The ``return`` statement in ``pattern`` is matched based on its + value only; it may or may not match to the ``return`` statement in + the larger graph. In other words, the pattern doesn't have to extend + to the end of the larger graph. + + When the pattern is matched, it will be removed from the larger + function and replaced by ``replacement``. If there are multiple + matches for ``pattern`` in the larger function, each non-overlapping + match will be replaced. In the case of a match overlap, the first + found match in the set of overlapping matches will be replaced. + ("First" here being defined as the first in a topological ordering + of the Nodes' use-def relationships. In most cases, the first Node + is the parameter that appears directly after ``self``, while the + last Node is whatever the function returns.) + + One important thing to note is that the parameters of the + ``pattern`` Callable must be used in the Callable itself, + and the parameters of the ``replacement`` Callable must match + the pattern. 
The first rule is why, in the above code block, the + ``forward`` function has parameters ``x, w1, w2``, but the + ``pattern`` function only has parameters ``w1, w2``. ``pattern`` + doesn't use ``x``, so it shouldn't specify ``x`` as a parameter. + As an example of the second rule, consider replacing + + .. code-block:: python + + def pattern(x, y): + return torch.neg(x) + torch.relu(y) + + with + + .. code-block:: python + + def replacement(x, y): + return torch.relu(x) + + In this case, ``replacement`` needs the same number of parameters + as ``pattern`` (both ``x`` and ``y``), even though the parameter + ``y`` isn't used in ``replacement``. + + After calling ``subgraph_rewriter.replace_pattern``, the generated + Python code looks like this: + + .. code-block:: python + + def forward(self, x, w1, w2): + stack_1 = torch.stack([w1, w2]) + sum_1 = stack_1.sum() + stack_2 = torch.stack([w1, w2]) + sum_2 = stack_2.sum() + max_1 = torch.max(sum_1) + add_1 = x + max_1 + max_2 = torch.max(sum_2) + add_2 = add_1 + max_2 + return add_2 + """ + match_and_replacements = _replace_pattern(gm, pattern, replacement) + return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements] + + +# Experimental API, not backward compatible +@compatibility(is_backward_compatible=False) +def replace_pattern_with_filters( + gm: GraphModule, + pattern: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule], + match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, + ignore_literals: bool = False, +) -> List[ReplacedPatterns]: + """ + See replace_pattern for documentation. This function is an overload with an additional match_filter argument. + + Args: + ``match_filters``: A list of functions that take in + (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating + whether the match satisfies the condition. + See matcher_utils.py for definition of InternalMatch. 
def _replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, Graph, GraphModule],
    replacement: Union[Callable, Graph, GraphModule],
    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
    """
    Find every match of ``pattern`` in ``gm.graph`` and splice a copy of
    ``replacement`` in its place, mutating ``gm`` in place.

    Args:
        gm: The GraphModule whose graph is rewritten and then recompiled.
        pattern: Subgraph to search for. A ``GraphModule`` contributes its
            ``.graph``, a ``Graph`` is used directly, and anything else is
            passed through ``symbolic_trace``.
        replacement: Subgraph to insert, normalised the same way as ``pattern``.
        match_filters: Optional predicates called as
            ``f(match, original_graph, pattern_graph)``; a match is kept only
            if all of them return a truthy value. ``None`` means no filtering.
        ignore_literals: Forwarded unchanged to ``SubgraphMatcher``.

    Returns:
        One ``ReplacedPatterns`` per replaced match, in the order the matches
        were processed.
    """

    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch

    if match_filters is None:
        match_filters = []

    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph: Graph = gm.graph

    if isinstance(pattern, GraphModule):
        pattern_graph = pattern.graph
    elif isinstance(pattern, Graph):
        pattern_graph = pattern
    else:
        pattern_graph = symbolic_trace(pattern).graph

    if isinstance(replacement, GraphModule):
        replacement_graph = replacement.graph
    elif isinstance(replacement, Graph):
        replacement_graph = replacement
    else:
        replacement_graph = symbolic_trace(replacement).graph

    matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
                              remove_overlapping_matches=True, ignore_literals=ignore_literals)
    _matches: List[InternalMatch] = matcher.match(original_graph)

    # Filter out matches that don't match the filter
    _matches = [
        m for m in _matches
        if all(match_filter(m, original_graph, pattern_graph)
               for match_filter in match_filters)
    ]

    replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]

    # As we progressively replace nodes, we'll need to keep track of how the match results should change
    # (maps a replaced original node to the node that took its place)
    match_changed_node: Dict[Node, Node] = {}

    match_and_replacements = []
    for match in _matches:

        # Build connecting between replacement graph's input and original graph input producer node

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        assert len(match.placeholder_nodes) == len(replacement_placeholders)
        val_map: Dict[Node, Node] = {}
        for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
            if isinstance(gn, Node):
                # An earlier replacement may have replaced `gn`; use its successor if so.
                val_map[rn] = match_changed_node.get(gn, gn)
                if gn != val_map[rn]:
                    # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn
                    gn_ind = match.placeholder_nodes.index(gn)
                    match.placeholder_nodes[gn_ind] = match_changed_node[gn]
                    # Reverse lookup: find the pattern node whose mapped value is `gn`.
                    map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)]
                    match.nodes_map[map_key] = match_changed_node[gn]
            else:
                # Non-Node placeholder values are mapped through unchanged.
                val_map[rn] = gn

        # Copy the replacement graph over
        user_nodes: Set[Node] = set()
        for n in match.returning_nodes:
            for user in n.users:
                user_nodes.add(user)
        assert user_nodes, "The returning_nodes should have at least one user node"

        if len(user_nodes) == 1:
            first_user_node = next(iter(user_nodes))
        else:
            # If there are multiple user nodes, we need to find the first user node
            # in the current execution order of the `original_graph`
            # NOTE(review): `first_user_node` stays unbound if none of the users
            # appears in `original_graph.nodes`; presumably that cannot happen -- confirm.
            for n in original_graph.nodes:
                if n in user_nodes:
                    first_user_node = n
                    break

        with original_graph.inserting_before(first_user_node):  # type: ignore[possibly-undefined]
            copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)

        # Normalise a single returned Node to a 1-tuple so it can be zipped below.
        if isinstance(copied_returning_nodes, Node):
            copied_returning_nodes = (copied_returning_nodes, )

        # Get a list of nodes that have been replaced into the graph
        replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes]

        # Hook the output Node of the replacement subgraph in to the
        # original Graph at the correct location
        assert len(match.returning_nodes) == len(copied_returning_nodes)
        for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
            gn.replace_all_uses_with(copied_node)
            match_changed_node[gn] = copied_node
        # Remove the original nodes
        # (pattern nodes are walked in reverse order before erasing their matched counterparts)
        for node in reversed(pattern_graph.nodes):
            if node.op != "placeholder" and node.op != "output":
                gn = match.nodes_map[node]
                gm.graph.erase_node(gn)

        match_and_replacements.append(
            ReplacedPatterns(
                anchor=match.anchors[0],
                nodes_map=match.nodes_map,
                replacements=replacement_nodes
            )
        )

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_attributes(gm, replacement)

    return match_and_replacements
import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] + +log = logging.getLogger(__name__) + + +class ProcessException(Exception): + __slots__ = ["error_index", "error_pid"] + + def __init__(self, msg: str, error_index: int, pid: int): + super().__init__(msg) + self.msg = msg + self.error_index = error_index + self.pid = pid + + def __reduce__(self): + return type(self), (self.msg, self.error_index, self.pid) + + +class ProcessRaisedException(ProcessException): + """Exception raised when a process failed due to an exception raised by the code.""" + + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + ): + super().__init__(msg, error_index, error_pid) + + +class ProcessExitedException(ProcessException): + """Exception raised when a process failed due to signal or exited with a specific code.""" + + __slots__ = ["exit_code"] + + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + exit_code: int, + signal_name: Optional[str] = None, + ): + super().__init__(msg, error_index, error_pid) + self.exit_code = exit_code + self.signal_name = signal_name + + def __reduce__(self): + return ( + type(self), + (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name), + ) + + +def _wrap(fn, i, args, error_file): + # prctl(2) is a Linux specific system call. + # On other systems the following function call has no effect. + # This is set to ensure that non-daemonic child processes can + # terminate if their parent terminates before they do. 
+ _prctl_pr_set_pdeathsig(signal.SIGINT) + + try: + fn(i, *args) + except KeyboardInterrupt: + pass # SIGINT; Killed by parent, do nothing + except Exception: + # Propagate exception to parent process, keeping original traceback + import traceback + + with open(error_file, "wb") as fh: + pickle.dump(traceback.format_exc(), fh) + sys.exit(1) + + +class ProcessContext: + def __init__(self, processes, error_files): + self.error_files = error_files + self.processes = processes + self.sentinels = { + process.sentinel: index for index, process in enumerate(processes) + } + + def pids(self): + return [int(process.pid) for process in self.processes] + + def join(self, timeout=None): + r"""Join one or more processes within spawn context. + + Attempt to join one or more processes in this spawn context. + If one of them exited with a non-zero exit status, this function + kills the remaining processes and raises an exception with the cause + of the first process exiting. + + Returns ``True`` if all processes have been joined successfully, + ``False`` if there are more processes that need to be joined. + + Args: + timeout (float): Wait this long before giving up on waiting. + """ + # Ensure this function can be called even when we're done. + if len(self.sentinels) == 0: + return True + + # Wait for any process to fail or all of them to succeed. + ready = multiprocessing.connection.wait( + self.sentinels.keys(), + timeout=timeout, + ) + + error_index = None + for sentinel in ready: + index = self.sentinels.pop(sentinel) + process = self.processes[index] + process.join() + if process.exitcode != 0: + error_index = index + break + + # Return if there was no error. + if error_index is None: + # Return whether or not all processes have been joined. + return len(self.sentinels) == 0 + + # Assume failure. Terminate processes that are still alive. + # Try SIGTERM then SIGKILL if the process isn't going down. 
+ # The reason is related to python signal handling is limited + # to main thread and if that is in c/c++ land and stuck it won't + # to handle it. We have seen processes getting stuck not handling + # SIGTERM for the above reason. + timeout: int = 30 + for process in self.processes: + if process.is_alive(): + log.warning("Terminating process %s via signal SIGTERM", process.pid) + process.terminate() + end = time.monotonic() + timeout + for process in self.processes: + time_to_wait = max(0, end - time.monotonic()) + process.join(time_to_wait) + for process in self.processes: + if process.is_alive(): + log.warning( + "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL", + process.pid, + ) + process.kill() + process.join() + + # The file will only be created if the process crashed. + failed_process = self.processes[error_index] + if not os.access(self.error_files[error_index], os.R_OK): + exitcode = self.processes[error_index].exitcode + if exitcode < 0: + try: + name = signal.Signals(-exitcode).name + except ValueError: + name = f"" + raise ProcessExitedException( + "process %d terminated with signal %s" % (error_index, name), + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode, + signal_name=name, + ) + else: + raise ProcessExitedException( + "process %d terminated with exit code %d" % (error_index, exitcode), + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode, + ) + + with open(self.error_files[error_index], "rb") as fh: + original_trace = pickle.load(fh) + msg = "\n\n-- Process %d terminated with the following error:\n" % error_index + msg += original_trace + raise ProcessRaisedException(msg, error_index, failed_process.pid) + + +class SpawnContext(ProcessContext): + def __init__(self, processes, error_files): + warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.") + super().__init__(processes, error_files) + + +# Note: [start_processes] +# mp.start_processes 
handles both start_method='spawn' and 'fork'. It's supposed to be a +# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the +# CUDA compatible start_method. However, in environments like Ipython notebooks, 'fork' +# works better than 'spawn'. Every helper function we created for mp.spawn is indeed +# general enough, and backends like XLA can reuse them in Colab notebooks as well. +# Currently we only add this API first, we can consider adding it to documentation as +# needed in the future. +def start_processes( + fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn" +): + mp = multiprocessing.get_context(start_method) + error_files = [] + processes = [] + for i in range(nprocs): + # Each process is assigned a file to write tracebacks to. We + # use the file being non-empty to indicate an exception + # occurred (vs an expected shutdown). Note: this previously + # used a multiprocessing.Queue but that can be prone to + # deadlocks, so we went with a simpler solution for a one-shot + # message between processes. + tf = tempfile.NamedTemporaryFile( + prefix="pytorch-errorfile-", suffix=".pickle", delete=False + ) + tf.close() + os.unlink(tf.name) + process = mp.Process( + target=_wrap, + args=(fn, i, args, tf.name), + daemon=daemon, + ) + process.start() + error_files.append(tf.name) + processes.append(process) + + context = ProcessContext(processes, error_files) + if not join: + return context + + # Loop on join until it returns True or raises an exception. + while not context.join(): + pass + + +def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"): + r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``. + + If one of the processes exits with a non-zero exit status, the + remaining processes are killed and an exception is raised with the + cause of termination. 
In the case an exception was caught in the + child process, it is forwarded and its traceback is included in + the exception raised in the parent process. + + Args: + fn (function): Function is called as the entrypoint of the + spawned process. This function must be defined at the top + level of a module so it can be pickled and spawned. This + is a requirement imposed by multiprocessing. + + The function is called as ``fn(i, *args)``, where ``i`` is + the process index and ``args`` is the passed through tuple + of arguments. + + args (tuple): Arguments passed to ``fn``. + nprocs (int): Number of processes to spawn. + join (bool): Perform a blocking join on all processes. + daemon (bool): The spawned processes' daemon flag. If set to True, + daemonic processes will be created. + start_method (str): (deprecated) this method will always use ``spawn`` + as the start method. To use a different start method + use ``start_processes()``. + + Returns: + None if ``join`` is ``True``, + :class:`~ProcessContext` if ``join`` is ``False`` + + """ + if start_method != "spawn": + msg = ( + "This method only supports start_method=spawn (got: %s).\n" + "To use a different start_method use:\n\t\t" + " torch.multiprocessing.start_processes(...)" % start_method + ) + warnings.warn(msg) + return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")