diff --git a/.gitattributes b/.gitattributes index 849cfabdd4167a72817582cce9896dffce13ba44..4d62a9c21308ef7af9dc6dbc7963be05388c9bb8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -128,3 +128,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/torch/nn/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/scheduler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f95e7caaf71e95c675d4ea9e467c99d76ebdb842 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py @@ -0,0 +1,179 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict, List, Optional, Tuple + +import torch.fx +import torch.utils._pytree as pytree + + +__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"] + + +def compile( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + options: Optional[Dict[str, Any]] = None, +): + """ + Compile a given FX graph with TorchInductor. This allows compiling + FX graphs captured without using TorchDynamo. + + Args: + gm: The FX graph to compile. + example_inputs: List of tensor inputs. + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Callable with same behavior as gm but faster. 
+ """ + from .compile_fx import compile_fx + + return compile_fx(gm, example_inputs, config_patches=options) + + +def aot_compile( + gm: torch.fx.GraphModule, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + options: Optional[Dict[str, Any]] = None, +) -> str: + """ + Ahead-of-time compile a given FX graph with TorchInductor into a shared library. + + Args: + gm: The FX graph to compile. + args: Example arguments + kwargs: Example keyword arguments + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Path to the generated shared library + """ + from .compile_fx import compile_fx_aot, graph_returns_tuple + + assert graph_returns_tuple(gm), ( + "Graph output must be a tuple(). This is so that we can avoid " + "pytree processing of the outputs. Please change the module to " + "have tuple outputs." + ) + + # We will serialize the pytree info into the .so as constant strings + in_spec = None + out_spec = None + if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen): + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + if codegen.pytree_info.in_spec is not None: + in_spec = codegen.pytree_info.in_spec + if codegen.pytree_info.out_spec is not None: + out_spec = codegen.pytree_info.out_spec + + else: + if hasattr(gm, "_in_spec"): + in_spec = gm._in_spec + if hasattr(gm, "_out_spec"): + out_spec = gm._out_spec + + serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else "" + serialized_out_spec = ( + pytree.treespec_dumps(out_spec) if out_spec is not None else "" + ) + + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, kwargs or {}) + ) + + # Replace non-tensor (constant) inputs with Nones, since these are not being + # used anyways by the graph + flat_example_inputs = [ + x[1] if isinstance(x[1], torch.Tensor) else None for x in flat_args_with_path + ] + + if in_spec is not None and received_spec != in_spec: + 
raise ValueError( # noqa: B904 + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + options = ( + { + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + if options is None + else { + **options, + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + ) + + return compile_fx_aot( + gm, + flat_example_inputs, # type: ignore[arg-type] + config_patches=options, + ) + + +def list_mode_options( + mode: Optional[str] = None, dynamic: Optional[bool] = None +) -> Dict[str, Any]: + r"""Returns a dictionary describing the optimizations that each of the available + modes passed to `torch.compile()` performs. + + Args: + mode (str, optional): The mode to return the optimizations for. + If None, returns optimizations for all modes + dynamic (bool, optional): Whether dynamic shape is enabled. + + Example:: + >>> torch._inductor.list_mode_options() + """ + + mode_options: Dict[str, Dict[str, bool]] = { + "default": {}, + # enable cudagraphs + "reduce-overhead": { + "triton.cudagraphs": True, + }, + # enable max-autotune + "max-autotune-no-cudagraphs": { + "max_autotune": True, + }, + # enable max-autotune + # enable cudagraphs + "max-autotune": { + "max_autotune": True, + "triton.cudagraphs": True, + }, + } + return mode_options[mode] if mode else mode_options # type: ignore[return-value] + + +def list_options() -> List[str]: + r"""Returns a dictionary describing the optimizations and debug configurations + that are available to `torch.compile()`. + + The options are documented in `torch._inductor.config`. 
+ + Example:: + + >>> torch._inductor.list_options() + """ + + from torch._inductor import config + + current_config: Dict[str, Any] = config.shallow_copy_dict() + + return list(current_config.keys()) + + +def cudagraph_mark_step_begin(): + "Indicates that a new iteration of inference or training is about to begin." + from .cudagraph_trees import mark_step_begin + + mark_step_begin() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4d5d08b26d64a90b28880e220903520a7a217af Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b2340c7c0cd725569da67ef62213a0ff9d3989 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/async_compile.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/async_compile.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b5f3d56812252b271754ddaf8f053b550e9f553 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/async_compile.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfe44fde67fd15f07d794335a9e1763545b1bde5 
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e693a28df1896ac4fca5c8278f44a7e69ad655c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fa732b2dadaa3311b114f0de570b8601e1fa7dc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5c70d5708abb87b18e9d3e6c19e4e8b18141d35 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50915d7041fc5efb48455be02eef29dd27b9ac9c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81766f4248a11e4dd045b5bfcba36ac7bfa1fcdb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a929abd129c17ebe9ec0ee7ea960000c785c77e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90c189f1faec7cfae6089c30774a2521278995cc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d26d30becabd232c12b640eecd5eca4c96086587 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e8eb94d2a07d3c5e4c156ab01aaa003020a8b21 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0465e6efa8c9392c287fbee7f62b41d12696d67e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d528b8dc7f6951b307f8c787715f017362202288 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec91a8b41bb10e7acc95e45dac4eb67aa20a1f06 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1b4f5a4942570f6001587e43ed5f9cc5c3f16b8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da6ccddfa1f9e893455a3c09ebfb52c453d2e5be Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b543972bcc709e71648bb20f18c0c55912238eb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1480480cb5e417e247370e2006eb1382d7192b93 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8367311337c47aceeed1a9b1a528007ce21adc32 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a08e0b550c2e4ab43a9f365e83eaf9dbbeb75898 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad060ca634f268180cada512a162240d09e677c0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8eb877cff2fdd2b760b2b55187e2bbc37f2b8d55 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21e4c933d164cf9fa66bbdad5f910ce3a46988dc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/loop_body.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/loop_body.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd589f340ba833c4f35a7c2538835f0186707507 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/loop_body.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49d8db3706399989cdc29553ab8f25fc5f9e6055 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8d87f1c4475febd3c5c845f17be382ce816d521 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8905108acce47c2115ab47e758da53620c6eb286 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad9f3f0db3671884ee93cc0fb969e4003ca5d743 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad2516ca264c8d4b38bbf7ab68b46a5d9a9a6935 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3400e2efa875fdbe0d5e0ff78216829e92799d2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d4064fc2252960a98d5ae7f8fc2d4328b7c63df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/scheduler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..561d5fe97d494e9c879fb32738097621a634df3c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/scheduler.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48fc35b3ba35cd6f3ba02d218d951aa8a531c58ad217a2e94bfb14483e5a78af +size 216212 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5daf464548b903f46f5692a72e3356d3ba3a7151 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..633d2dd7e0f5dd62b26cf797f9d19930740c78e0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4bb79748b8954a7751f1eb899da025e496ab353 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eef70523e59496ac33db955080270559026f8fbe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f98defd790c9227d35c2e81359b8b5968298a97 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..362117f0f9dac8683881530125acda6104796cc2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3d9bc4c8e0af16ba5c1da3148374e45199ab36d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/aoti_eager.py b/.venv/lib/python3.11/site-packages/torch/_inductor/aoti_eager.py new file mode 100644 index 0000000000000000000000000000000000000000..f733ce4fbd5a1748c663892c10cf349166bc9461 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/aoti_eager.py @@ -0,0 +1,298 @@ +import json +import logging +import os +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple +from unittest import mock + +import torch +import torch._export +from torch._inductor.utils import is_cpu_device + +from .runtime.runtime_utils import cache_dir + + +log = logging.getLogger(__name__) + + +def aoti_eager_cache_dir(namespace: str, device: str) -> Path: + return Path(cache_dir()) / "aoti_eager" / namespace / device + + +def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any: + from filelock import FileLock + + # Avoid circular import + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT + + op_conf_lock_file = f"{op_func_name_with_overload}.lock" + lock_dir = get_lock_dir() + return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT) + + +def 
load_aoti_eager_cache( + ns: str, op_func_name_with_overload: str, device_type: str +) -> List[Optional[Dict[str, Any]]]: + device_kernel_cache = aoti_eager_cache_dir(ns, device_type) + op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json" + if not op_conf.exists(): + return [] + + try: + with aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf) as f: + json_data = json.load(f) + for item in json_data: + # Get absolution path for kernel library + kernel_lib_abs_path = device_kernel_cache / item["kernel_path"] + item["kernel_path"] = kernel_lib_abs_path.as_posix() + + # Check if the kernel library exists + if not kernel_lib_abs_path.exists(): + return [] + + for metadata in item["meta_info"]: + if metadata.get("is_dynamic"): + raise NotImplementedError( + "Only support static shape for now" + ) + if ( + "device_type" in metadata + and metadata["device_type"] == "cpu" + ): + metadata["device_index"] = -1 + for dtype_key in ["dtype", "dtype_value"]: + if dtype_key in metadata: + metadata[dtype_key] = getattr( + torch, metadata[dtype_key].split(".")[-1] + ) + if "layout_value" in metadata: + metadata["layout_value"] = getattr( + torch, metadata["layout_value"].split(".")[-1] + ) + if "memory_format_value" in metadata: + metadata["memory_format_value"] = getattr( + torch, metadata["memory_format_value"].split(".")[-1] + ) + + return json_data + except Exception as e: + err_msg = f"Failed to load aoti eager cache: {e}" + log.exception(err_msg) + return [] + + +def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]: + return {int: torch.int32, float: torch.float, bool: torch.bool} + + +def supported_scalar_types() -> Tuple[type, ...]: + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + return tuple(type_to_torch_dtype.keys()) + + +def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]: + metadata: Dict[str, Any] = {} + metadata["is_dynamic"] = dynamic + + assert isinstance(input, 
torch.Tensor) + metadata["device_type"] = f"{input.device.type}" + if is_cpu_device([input]): + metadata["device_index"] = -1 + else: + metadata["device_index"] = input.device.index + metadata["dtype"] = f"{input.dtype}" + metadata["sizes"] = list(input.size()) + metadata["strides"] = list(input.stride()) + metadata["requires_grad"] = input.requires_grad + metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr() + return metadata + + +def extract_tensor_list_metadata( + dynamic: bool, + input: List[torch.Tensor], +) -> Dict[str, Any]: + metadata_list = [] + for item in input: + assert isinstance(item, torch.Tensor) + metadata_list.append(extract_tensor_metadata(dynamic, item)) + + metadata: Dict[str, Any] = {} + metadata["tensor_list"] = metadata_list + return metadata + + +def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]: + assert isinstance(input, supported_scalar_types()) + metadata: Dict[str, Any] = {} + metadata["is_dynamic"] = False + # Scalar tensor + metadata["device_type"] = device_type + metadata["device_index"] = -1 if device_type == "cpu" else 0 + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" + metadata["scalar_value"] = input + return metadata + + +def extract_string_metadata(input: str) -> Dict[str, Any]: + assert isinstance(input, str) + metadata: Dict[str, Any] = {} + metadata["string_value"] = input + return metadata + + +def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]: + assert isinstance(input, torch.dtype) + metadata: Dict[str, Any] = {} + metadata["dtype_value"] = f"{input}" + return metadata + + +def extract_device_metadata(input: torch.device) -> Dict[str, Any]: + assert isinstance(input, torch.device) + metadata: Dict[str, Any] = {} + metadata["device_type_value"] = f"{input.type}" + metadata["device_index_value"] = input.index + return metadata + + +def extract_layout_metadata(input: torch.layout) -> 
Dict[str, Any]: + assert isinstance(input, torch.layout) + metadata: Dict[str, Any] = {} + metadata["layout_value"] = f"{input}" + return metadata + + +def aoti_compile_with_persistent_cache( + ns: str, + op_func_name_with_overload: str, + device_type: str, + dynamic: bool, + f: Callable[..., Any], + args: Tuple[Any], + kwargs: Dict[str, Any], + *, + dynamic_shapes: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, +) -> str: + """ + Compile the given function with persistent cache for AOTI eager mode. + """ + assert not dynamic, "Only support static shape for now" + flattened_inputs = list(args) + list(kwargs.values()) + if not all( + isinstance( + input, + ( + supported_scalar_types(), + torch.Tensor, + list, + str, + torch.dtype, + torch.device, + torch.layout, + ), + ) + for input in flattened_inputs + ): + err_msg = f"Unsupported input types: {flattened_inputs}" + log.exception(err_msg) + raise NotImplementedError(err_msg) + + for input in flattened_inputs: + if isinstance(input, list) and not all( + isinstance(item, torch.Tensor) for item in input + ): + err_msg = f"_impl_with_aoti_compile encounters unsupported input types: {flattened_inputs}" + log.exception(err_msg) + raise NotImplementedError(err_msg) + + persistent_cache = aoti_eager_cache_dir(ns, device_type) + if not persistent_cache.exists(): + persistent_cache.mkdir(parents=True) + + persistent_cache_lib = persistent_cache / "lib" + if not persistent_cache_lib.exists(): + persistent_cache_lib.mkdir() + + with mock.patch.dict( + os.environ, + {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, + ): + try: + kernel_lib_path = torch._export.aot_compile( + f, + args, + kwargs, + dynamic_shapes=dynamic_shapes, + remove_runtime_assertions=remove_runtime_assertions, + disable_constraint_solver=disable_constraint_solver, + # Some operations may have non-Tensor parameters like 
int, float, bool. These + # non-Tensor parameters will not be the input of the graph. Therefore, we do + # need to keep the same signature. + same_signature=False, + ) + + kernel_metadata_items = [] + + for idx, input in enumerate(flattened_inputs): + if isinstance(input, torch.Tensor): + metadata = extract_tensor_metadata(dynamic, input) + elif isinstance(input, list): + assert all(isinstance(item, torch.Tensor) for item in input) + metadata = extract_tensor_list_metadata(dynamic, input) + elif isinstance(input, supported_scalar_types()): + metadata = extract_scalar_metadata(device_type, input) + elif isinstance(input, str): + metadata = extract_string_metadata(input) + elif isinstance(input, torch.dtype): + metadata = extract_dtype_metadata(input) + elif isinstance(input, torch.device): + metadata = extract_device_metadata(input) + elif isinstance(input, torch.layout): + metadata = extract_layout_metadata(input) + else: + raise NotImplementedError(f"Unsupported input type: {type(input)}") + + metadata["arg_order"] = idx + kernel_metadata_items.append(metadata) + + kernel_meta_info: Dict[str, Any] = {} + kernel_meta_info["meta_info"] = kernel_metadata_items + kernel_meta_info["kernel_path"] = ( + Path(kernel_lib_path).relative_to(persistent_cache).as_posix() + ) + + json_data = [] + update_json = True + op_conf = persistent_cache / f"{op_func_name_with_overload}.json" + mode = "r" if op_conf.exists() else "w" + with aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf, mode) as op_conf_file: + try: + json_data = json.load(op_conf_file) + except Exception as e: + json_data = [] + + assert isinstance(json_data, list) + for item in json_data: + assert isinstance(item, dict) + # Same kernel meta info already exists in the json file + if item["meta_info"] == kernel_metadata_items: + update_json = False + break + + if update_json: + json_data.append(kernel_meta_info) + with open(op_conf, "w") as op_conf_file: + json.dump(json_data, op_conf_file, 
indent=4) + + return kernel_lib_path + except Exception as e: + err_msg = f"Failed to compile {op_func_name_with_overload}: {e}" + log.exception(err_msg) + return "" diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59a33366a873b0ab0686d413220a019dca798478 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96cda3ff7b16560133680c814a4f6046908c97ca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d9563b5f1e869860f29aee8c22deadd8daf98c8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff 
--git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9845ca8a9c566127fa4398b0a33764828ca0bd6b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46ebb3b5e5fe3e81325aea4d5ad2b0032ee4a5db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53a2d31a68cb32fd35262efb1d07056d3ace47e2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0975ea3edb4f50fff1bbf6337ed2a63b8a2fc41e Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45c99143d103a4accec508d46a5bc7ef1012dd9b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..629c2c0579a77395f36b9c08b15fa7769a7a28bc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py b/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py new file mode 100644 index 0000000000000000000000000000000000000000..7452f2bb1b62b6862a63bcb8d508032d5e33c7e9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py @@ -0,0 +1,140 @@ +# mypy: allow-untyped-defs +import logging +import operator +from functools import partial +from typing import Any, Callable, Dict + +from sympy import Expr + +import torch +from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges + +from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock +from .utils import cache_on_self, dominated_nodes +from .virtualized import V + + +log = logging.getLogger(__name__) + + +class BoundVars: + """ + Performs Value Range Analysis on LoopBody's fx 
graph by calling BoundVars.run() + It exposes the ranges of the nodes in the `bounds` variable + + Note. A current limitation of this analysis is that it just works on a per-loop basis. + We should be able to propagate the bounds between across the whole graph. This may benefit + the case a bounded variable is returned by a kernel and fed into another. + """ + + def __init__(self, loop_body: LoopBody) -> None: + def upper_bound(v): + return bound_sympy(v).upper if isinstance(v, Expr) else v + + self.loop_body = loop_body + self.replacement_vals = { + k: ValueRanges[Expr](0, upper_bound(v) - 1) + for k, v in loop_body.var_ranges.items() + } + # avoid computing these values, pessimistically assume that they are unbounded + self.unbounded_vars = dominated_nodes( + node + for node in self.loop_body.get_nodes() + if node.target in ["load", "reduction", operator.getitem] + or "masked_subblock" in node.target + ) + # To access this variable call `get_bounds()` + self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {} + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"loop_body={self.loop_body},\n " + f"replacement_vals={self.replacement_vals}, \n" + f"unbounded_vars={self.unbounded_vars}, \n" + f"_bounds={self._bounds})" + ) + + @cache_on_self + def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: + submodules = self.swap_submodules(self.loop_body.submodules) + + # Initialize the environment with the unbounded variables + for node in self.unbounded_vars: + # we need to evaluate masked_subblock to recurse, and we need to set indirect values + if not isinstance(node.target, str) or ( + "masked_subblock" not in node.target + and "set_indirect" not in node.target + ): + self._bounds[node] = ValueRanges[Expr].unknown() + + with V.set_ops_handler(ValueRangeAnalysis()): + interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) + 
interpreter.run(V.get_ops_handler(), initial_env=self._bounds) + return self._bounds + + def swap_submodules( + self, submodules: Dict[str, Callable[..., Any]] + ) -> Dict[str, Callable[..., ValueRanges[Expr]]]: + result: Dict[str, Callable[..., ValueRanges[Expr]]] = {} + for key in submodules.keys(): + if key == "get_index": + result[key] = self.get_index + elif "masked_subblock" in key: + subblock = self.loop_body.subblocks[key] + # The result within the lambda will reference to the final + # set of modules at the end of the for-loop as it stores a reference to it + + # bind subblock in a function because python lambdas close over by reference + # moving the lambda out of make_fn would close over the reference to subblock, + # so all lambdas would have the same subblock reference that is the final + # subblock in the loop + def make_fn(subblock): + return lambda mask, value: self.masked_subblock( + subblock, self._bounds, mask, value, result + ) + + result[key] = make_fn(subblock) + elif "set_indirect" in key: + idx = int(key[len("set_indirect") :]) + var = self.loop_body.indirect_vars[idx] + indirect = partial(self.set_indirect, var) + result[key] = indirect + else: + assert "scan" in key + result[key] = submodules[key] + + return result + + def masked_subblock( + self, + subblock: LoopBodyBlock, + env: Dict[torch.fx.Node, ValueRanges[Expr]], + mask: Any, + value: Any, + submodules: Dict[str, Callable[..., Any]], + ) -> ValueRanges[Expr]: + interp = InterpreterShim(subblock.graph, submodules) + interp.run(V.get_ops_handler(), initial_env=env) + output = [node for node in subblock.graph.nodes if node.target == "output"] + assert len(output) == 1 + # dont bother unioning with value since the load from buffer will be + # pessimistically assumed to be inf anyway + return interp.env[output[0]] + + def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]: + assert isinstance(new, ValueRanges) + self.replacement_vals[old] = new + return new + + 
def get_index(self, name: Expr) -> ValueRanges[Expr]: + expr = self.loop_body.indexing_exprs[name] + bound = self.replacement_vals.get(expr) + if bound is None: + bound = bound_sympy(expr, self.replacement_vals) + # The following assertion is true at the time of this writing + # We don't assert is as to not execute bound_sympy when bound is not None + # assert bound is None or bound == bound_sympy(expr, self.replacement_vals) + self.replacement_vals[name] = bound + return bound diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..80085aa6d18f2d59c0ddb00b77b607ce08441d9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +import re + +import torch +from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE + + +# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like: +# "... +# from ..codecache import CudaKernelParamCache +# ..." +# In such cases, we do not need to hipify_torch the orignial class/file name in codegen/codecache + + +def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str: + if torch.version.hip is None and not force_hipify: + return source_codes + + def c2_repl(m): + return PYTORCH_MAP[m.group(0)] + + # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch, + # it will apply positive lookbehind (?<=\W) to the pattern to avoid matching + # keyword at the beginning of code line. However, this can happen in codegen, + # which will cause the pattern to not match. 
+ + # Note that lookahead (?=\W) is still needed to keep hipification idomponent, for example + # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA" + RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)") + + source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes) + return source_codes diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/codegen_device_driver.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/codegen_device_driver.py new file mode 100644 index 0000000000000000000000000000000000000000..c31017fe6471cb5dca9a41feac23c15d1a71b8aa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/codegen_device_driver.py @@ -0,0 +1,91 @@ +import torch + + +# Provide aoti module launch hip/cuda drivers. This file is also used for unit testing purpose + + +def cuda_kernel_driver() -> str: + source_codes = """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + cuGetErrorString(code, &msg); \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + namespace { + + struct Grid { + Grid(uint32_t x, uint32_t y, uint32_t z) + : grid_x(x), grid_y(y), grid_z(z) {} + uint32_t grid_x; + uint32_t grid_y; + uint32_t grid_z; + + bool is_non_zero() { + return grid_x > 0 && grid_y > 0 && grid_z > 0; + } + }; + + } // anonymous namespace + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, 
funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + if torch.version.hip is not None: + # Adjusting the warp size to GPU supported wavefront size on AMD GPU + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + source_codes = source_codes.replace( + "32*numWarps", str(prop.warp_size) + "*numWarps" + ) + return source_codes + + +def cuda_kernel_header() -> str: + source_codes = """ + #include + #include + #include + """ + return source_codes diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h new file mode 100644 index 0000000000000000000000000000000000000000..b46e772cd6ce11287f5cd04e79c7f06357fd4169 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h @@ -0,0 +1,973 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// WARNING: be extra careful when including more ATen/c10 header files here! +// Because AOTInductor generated code will copy-paste this cpp_prefix.h for +// the CPU backend, we have to make sure the used headers are implemented +// in a header-only way, i.e. all the function and class definitions are +// in .h files instead of .cpp files, to avoid ABI backward-compatiblity breakage. 
+ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) +#define INDUCTOR_USE_VECTOR_TYPES() 1 +#else +#define INDUCTOR_USE_VECTOR_TYPES() 0 +#endif + +#if INDUCTOR_USE_VECTOR_TYPES() +#include +#include +#else +// For calc_erfinv +#include +#endif + +typedef at::Half half; +typedef at::BFloat16 bfloat16; + +typedef at::Float8_e4m3fn float8_e4m3fn; +typedef at::Float8_e5m2 float8_e5m2; + +template +struct Welford { + T mean = T(0); + T m2 = T(0); + // Use weight for tail cases since the index of each element in the vec may be + // different. A single index can not express masked welford reduction. + T weight = T(0); + uint64_t index = 0; +}; + + +template +struct IsVecType: std::false_type {}; + +#if INDUCTOR_USE_VECTOR_TYPES() +template +struct IsVecType>: std::true_type {}; +#endif + +template +struct WeightRecp { + using scalar_t = typename T::value_type; + std::vector weight_recps; + WeightRecp(uint64_t N) { + weight_recps.reserve(N); + for (const auto i : c10::irange(N)) { + weight_recps.push_back( + scalar_t(static_cast(1) / static_cast(i + 1))); + } + } +}; + +template +Welford welford_combine(const Welford& a, const Welford& b, bool use_index=false) { + if (a.index == 0) { + return b; + } + if (b.index == 0) { + return a; + } + auto delta = b.mean - a.mean; + auto a_weight = use_index ? T(a.index) : a.weight; + auto b_weight = use_index ? 
T(b.index) : b.weight; + auto new_weight = a_weight + b_weight; + auto new_index = a.index + b.index; + auto wb_over_w = b_weight / new_weight; + if constexpr (IsVecType::value) { + // Guard against division by zero + wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0)); + } + auto result = Welford{ + a.mean + delta * wb_over_w, + a.m2 + b.m2 + delta * delta * a_weight * wb_over_w, + new_weight, + new_index + }; + return result; +} + +template +Welford welford_combine(const Welford& acc, const T& data, const WeightRecp* w=nullptr) { + // Add a single data point + uint64_t new_index = acc.index + 1; + auto new_weight = acc.weight + T(1); + auto delta = data - acc.mean; + T new_mean; + if constexpr (!IsVecType::value) { + new_mean = acc.mean + delta / new_weight; + } else { + // use new_index to fecth 1 / new_weight to avoid divisions + new_mean = acc.mean + + ((w == nullptr || acc.index >= w->weight_recps.size()) + ? delta / new_weight + : delta * T(w->weight_recps[acc.index])); + } + auto new_delta = data - new_mean; + auto result = Welford{ + new_mean, + acc.m2 + delta * new_delta, + new_weight, + new_index + }; + return result; +} + +template +struct IndexValue { + int64_t index; + T value; + IndexValue(int64_t idx, T val) :index(idx), value(val) {}; + IndexValue() {}; +}; + +#if INDUCTOR_USE_VECTOR_TYPES() +template +Welford welford_combine(const Welford& acc, const T& data, const int64_t tail_size, const WeightRecp* w=nullptr) { + auto out = welford_combine(acc, data, w); + return Welford{ + T::set(acc.mean, out.mean, tail_size), + T::set(acc.m2, out.m2, tail_size), + T::set(acc.weight, out.weight, tail_size), + out.index + }; +} + +template +T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) { + auto out = at::vec::maximum(a, b); + return T::set(a, out, tail_size); +} + +template +T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) { + auto out = at::vec::minimum(a, b); + return T::set(a, out, tail_size); +} + +template 
+T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) { + auto out = a + b; + return T::set(a, out, tail_size); +} + +template +T prod_masked_reduce(const T& a, const T& b, const int64_t tail_size) { + auto out = a * b; + return T::set(a, out, tail_size); +} + +template +T xor_sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) { + auto out = a ^ b; + return T::set(a, out, tail_size); +} +#endif + +// Refer to https://github.com/pytorch/pytorch/blob/b5b36cf0c4e1958f1ff25120f5d4beeef3288187/ +// aten/src/ATen/native/SharedReduceOps.h#L419-L445 +template +inline bool greater_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) { + // If (a == b), then choose the one with lower idx, else max(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a > b); +} + +template +inline bool less_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) { + // If (a == b), then choose the one with lower idx, else min(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? 
idx_a < idx_b : (a < b); +} + +template +inline IndexValue& argmin_combine(IndexValue& a, T next_value, int64_t next_index){ + if(!(less_or_nan(a.value, next_value, a.index, next_index))){ + a.value = next_value; + a.index = next_index; + } + return a; +} +template +inline IndexValue& argmax_combine(IndexValue& a, T next_value, int64_t next_index){ + if(!(greater_or_nan(a.value, next_value, a.index, next_index))){ + a.value = next_value; + a.index = next_index; + } + return a; +} +template +inline IndexValue& argmin_combine(IndexValue& a, const IndexValue& next){ + return argmin_combine(a, next.value, next.index); +} +template +inline IndexValue& argmax_combine(IndexValue& a, const IndexValue& next){ + return argmax_combine(a, next.value, next.index); +} + +#if INDUCTOR_USE_VECTOR_TYPES() + +template +inline at::vec::Vectorized div_floor_floating_vec( + const at::vec::Vectorized& a, + const at::vec::Vectorized& b) { + using vec_t = at::vec::Vectorized; + const auto basic_div = a / b; + vec_t inf(std::numeric_limits::infinity()); + auto mod = a.fmod(b); + // Fixup for a case that isn't properly handled by Sleef_fmod + auto floor = vec_t::blendv(a - mod, a, (basic_div.abs() == inf) & (a.abs() != inf)); + auto div = floor / b; + const auto zero = vec_t(0); + auto mask = (mod != zero) & ((b < zero) ^ (mod < zero)); + const auto one = vec_t(1); + div = vec_t::blendv(div, div - one, mask); + auto floordiv = div.floor(); + mask = (div - floordiv) > vec_t(0.5); + floordiv = vec_t::blendv(floordiv, floordiv + one, mask); + floordiv = vec_t::blendv(floordiv, zero.copysign(basic_div), div == zero); + floordiv = vec_t::blendv(floordiv, basic_div, b == zero); + return floordiv; +}; + +template +inline at::vec::VectorizedN div_floor_floating_vec( + const at::vec::VectorizedN& a, + const at::vec::VectorizedN& b) { + at::vec::VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = div_floor_floating_vec(a[i], b[i]); + } + return 
result; +} + +template +struct IndexValueVec { + at::vec::VectorizedN value; + at::vec::VectorizedN index; + + IndexValueVec(const T _value) { + value = at::vec::VectorizedN(_value); + index = at::vec::VectorizedN(0); + }; + + IndexValueVec() {}; +}; + + +template , int> = 0> +at::vec::VecMask inline get_mask_for_argmin_argmax( + const at::vec::VecMask& vmask, + const IndexValueVec& a, + const at::vec::VectorizedN& value, + const at::vec::VectorizedN& index +){ + /* + vec impl for less_or_nan and greater_or_nan + example for argmin: + a.value = [NaN, NaN, 0, 2, 1, 0] + value = [NaN, 0, 0, 1, 2, NaN] + vmask = [false, false, false, false, true, false] + all_nan_or_equal = [true, false, true, false, false, false] + imask = [a.index[0] < index[0], ..., a.index[-1] < index[-1]] + iv_mask = blendv (vmask, imask, all_nan_or_equal) + [a.index[0] < index[0], false, a.index[2] < index[2], false, true, false] + a_nan_b_not: [false, false, false, false, false, true] + mask = iv_mask | a_nan_b_not + [a.index[0] < index[0], false, a.index[2] < index[2], false, true, true] + */ + using v_t = at::vec::VecMask; + using i_t = at::vec::VecMask; + i_t vmask_itype = vmask.template cast(); + // use itype here since there is vec impl for operator~ for itype + // while there may not vec impl for vtype + v_t isnan_a = a.value.isnan(); + i_t isnan_a_itype = isnan_a.template cast(); + v_t isnan_b = value.isnan(); + i_t isnan_b_type = isnan_b.template cast(); + i_t all_nan_mask = isnan_a_itype & isnan_b_type; + v_t equal_mask = (a.value == value); + i_t equal_mask_itype = equal_mask.template cast(); + i_t all_nan_or_equal = all_nan_mask | equal_mask_itype; + i_t imask(a.index < index); + i_t iv_mask = i_t::blendv(vmask_itype, imask, all_nan_or_equal); + i_t isnan_a_notnan_b = isnan_a_itype & (~isnan_b_type); + return iv_mask | isnan_a_notnan_b; +} + +template , int> = 0> +at::vec::VecMask inline get_mask_for_argmin_argmax( + const at::vec::VecMask& vmask, + const IndexValueVec& a, + const 
at::vec::VectorizedN& value, + const at::vec::VectorizedN& index +){ + using v_t = at::vec::VecMask; + using i_t = at::vec::VecMask; + i_t vmask_itype = vmask.template cast(); + v_t equal_mask = (a.value == value); + i_t equal_mask_itype = equal_mask.template cast(); + i_t imask(a.index < index); + return i_t::blendv(vmask_itype, imask, equal_mask_itype); +} + + +template +inline IndexValueVec& argmin_vec_impl(IndexValueVec& a, at::vec::VectorizedN value, at::vec::VectorizedN index, std::optional tail_size){ + at::vec::VecMask vmask(a.value < value); + at::vec::VecMask final_mask = get_mask_for_argmin_argmax(vmask, a, value, index); + if (tail_size.has_value()) { + a.value = at::vec::VectorizedN::set(a.value, at::vec::minimum(a.value, value), tail_size.value()); + a.index = at::vec::VectorizedN::set(a.index, at::vec::VecMask::blendv(index, a.index, final_mask), tail_size.value()); + } else { + a.value = at::vec::minimum(a.value, value); + a.index = at::vec::VecMask::blendv(index, a.index, final_mask); + } + return a; +} + +template +inline IndexValueVec& argmax_vec_impl(IndexValueVec& a, at::vec::VectorizedN value, at::vec::VectorizedN index, std::optional tail_size){ + at::vec::VecMask vmask(a.value > value); + at::vec::VecMask final_mask = get_mask_for_argmin_argmax(vmask, a, value, index); + if (tail_size.has_value()) { + a.value = at::vec::VectorizedN::set(a.value, at::vec::maximum(a.value, value), tail_size.value()); + a.index = at::vec::VectorizedN::set(a.index, at::vec::VecMask::blendv(index, a.index, final_mask), tail_size.value()); + } else { + a.value = at::vec::maximum(a.value, value); + a.index = at::vec::VecMask::blendv(index, a.index, final_mask); + } + return a; +} + +template +inline at::vec::VectorizedN create_index(int64_t next_index){ + at::vec::VectorizedN next_idx; + if constexpr (horizontal) { + next_idx = at::vec::VectorizedN::arange(next_index, 1); + } else { + next_idx = at::vec::VectorizedN(next_index); + } + return next_idx; +} + 
+template +inline IndexValueVec& argmin_combine_vec(IndexValueVec& a, at::vec::VectorizedN next_value, int64_t next_index, std::optional tail_size = std::nullopt){ + auto next_idx = create_index(next_index); + return argmin_vec_impl(a, next_value, next_idx, tail_size); +} + +template +inline IndexValueVec& argmax_combine_vec(IndexValueVec& a, at::vec::VectorizedN next_value, int64_t next_index, std::optional tail_size = std::nullopt){ + auto next_idx = create_index(next_index); + return argmax_vec_impl(a, next_value, next_idx, tail_size); +} + +template +inline IndexValue argmin_vec_reduce_all(const IndexValueVec& vec){ + constexpr int len = at::vec::VectorizedN::size(); + __at_align__ T tmpval[len]; + __at_align__ int64_t tmpidx[len]; + vec.value.store(tmpval); + vec.index.store(tmpidx); + IndexValue res = IndexValue(tmpidx[0], tmpval[0]); + for (int i = 1; i < len; i++){ + res = argmin_combine(res, tmpval[i], tmpidx[i]); + } + return res; +} + +template +inline IndexValue argmax_vec_reduce_all(const IndexValueVec& vec){ + constexpr int len = at::vec::VectorizedN::size(); + __at_align__ T tmpval[len]; + __at_align__ int64_t tmpidx[len]; + vec.value.store(tmpval); + vec.index.store(tmpidx); + IndexValue res = IndexValue(tmpidx[0], tmpval[0]); + for (int i = 1; i < len; i++){ + res = argmax_combine(res, tmpval[i], tmpidx[i]); + } + return res; +} + +template +inline IndexValueVec& argmin_combine_vec(IndexValueVec& vec_a, const IndexValueVec& vec_b, std::optional tail_size = std::nullopt){ + return argmin_vec_impl(vec_a, vec_b.value, vec_b.index, tail_size); +} + +template +inline IndexValueVec& argmax_combine_vec(IndexValueVec& vec_a, const IndexValueVec& vec_b, std::optional tail_size = std::nullopt){ + return argmax_vec_impl(vec_a, vec_b.value, vec_b.index, tail_size); +} + +template +inline at::vec::Vectorized vec_shuffle_down(at::vec::Vectorized x, size_t n) { + using Vec = at::vec::Vectorized; + alignas(alignof(Vec)) scalar_t array[Vec::size()]; + 
x.store(array); + for (size_t i = 0; i + n < Vec::size(); i += 2 * n) { + array[i] = array[i + n]; + } + return Vec::loadu(array); +} + +#ifdef CPU_CAPABILITY_AVX2 +inline at::vec::Vectorized vec_shuffle_down(at::vec::Vectorized x, size_t n) { + using vec_t = at::vec::Vectorized; +#define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w) + switch (n) { + case 1: + return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3))); + case 2: + return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2))); + case 4: + return vec_t(_mm256_permute2f128_ps(x, x, SHUFFLE_MASK(1, 1, 1, 1))); + } + TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n); +} +#endif + +#ifdef CPU_CAPABILITY_AVX512 +inline at::vec::Vectorized vec_shuffle_down(at::vec::Vectorized x, size_t n) { + using vec_t = at::vec::Vectorized; +#define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w) + switch (n) { + case 1: + return vec_t(_mm512_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3))); + case 2: + return vec_t(_mm512_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2))); + case 4: + return vec_t(_mm512_permutexvar_ps( + _mm512_set_epi32( + 12, 12, 12, 12, 12, 12, 12, 12, 4, 4, 4, 4, 4, 4, 4, 4), + x)); + case 8: + return vec_t(_mm512_permutexvar_ps( + _mm512_set_epi32(8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8), x)); + } + TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n); +} +#endif + +template +Welford welford_vec_reduce_all(Welford> acc) { + using Vec = at::vec::Vectorized; + Welford result; + if (acc.index == 0) { + return result; + } + // if all values of acc.weight are same as index, + // use index to reduce to save the overhead of vec_shuffle_down for acc.weight + bool use_index = (acc.weight - Vec(acc.index)).zero_mask() == static_cast((1 << Vec::size()) - 1); + for (size_t n = 1; n < Vec::size(); n *= 2) { + auto shuffled = Welford{ + vec_shuffle_down(acc.mean, n), + vec_shuffle_down(acc.m2, n), + use_index ? 
Vec(0) : vec_shuffle_down(acc.weight, n), + acc.index}; + acc = welford_combine(acc, shuffled, use_index); + } + + alignas(alignof(Vec)) scalar_t array[Vec::size()]; + acc.mean.store(array); + result.mean = array[0]; + + acc.m2.store(array); + result.m2 = array[0]; + + acc.weight.store(array); + result.weight = array[0]; + result.index = result.weight; + + return result; +} + +template +Welford welford_vec_reduce_all(Welford> acc) { + auto Welford0 = Welford>{ + acc.mean[0], + acc.m2[0], + acc.weight[0], + acc.index + }; + auto Welford1 = Welford>{ + acc.mean[1], + acc.m2[1], + acc.weight[1], + acc.index + }; + return welford_vec_reduce_all(welford_combine(Welford0, Welford1)); +} +#endif + + +template inline typename std::common_type::type mod(T a, U b) { return a % b; } +template <> inline float mod(float a, float b) { return std::fmod(a, b); } +template <> inline double mod(double a, double b) { return std::fmod(a, b); } + +template +inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) { + if (at::_isnan(a)) { + return a; + } + return a > b ? a : b; +} + +template +inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) { + if (at::_isnan(a)) { + return a; + } + return a < b ? 
a : b; +} + +constexpr float uint32_to_uniform_float(uint32_t value) { + // maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + constexpr float scale = 4.6566127342e-10; + return static_cast(value & 0x7FFFFFFF) * scale; +} + +float normalized_rand_cpu(uint32_t seed, uint32_t offset) { + return uint32_to_uniform_float(at::Philox4_32(seed, 0, offset)()); +} + +float randn_cpu(uint32_t seed, uint32_t offset) { + at::Philox4_32 engine(seed, 0, offset); + return engine.randn(10); +} + +int64_t randint64_cpu(uint32_t seed, uint32_t offset, int64_t low, int64_t high) { + auto gen = at::Philox4_32(seed, 0, offset); + uint64_t r0 = gen(); + uint64_t r1 = gen(); + uint64_t result = r0 | (r1 << 32); + return static_cast(result % (high - low)) + low; +} + +template struct AsIntegerType { typedef T type; }; +template <> struct AsIntegerType { typedef uint32_t type; }; +template <> struct AsIntegerType { typedef uint64_t type; }; +template <> struct AsIntegerType { typedef uint16_t type; }; + +template +typename std::enable_if_t, T> +inline fetch_value(volatile T *addr) { + return *addr; +} + +template +typename std::enable_if_t, T> +inline fetch_value(volatile T *addr) { + return T(addr->x, T::from_bits()); +} + +template +typename std::enable_if_t> +atomic_add(volatile T *addr, T offset) { + typedef typename AsIntegerType::type alt_type; + + static_assert(sizeof(std::atomic) == sizeof(T), + "std::atomic issue"); + + alt_type expected; + + alt_type desired; + + std::atomic *atomic_addr = (std::atomic *)addr; + do { + T val = fetch_value(addr); + reinterpret_cast(&expected)[0] = val; + reinterpret_cast(&desired)[0] = val + offset; + } while (!atomic_addr->compare_exchange_weak(expected, desired, + std::memory_order_relaxed)); +} + +// Since C++20 float is supported by fetch_add, but the performance may not +// better than compare_exchange_weak, which can be checked by microbenchmark +// inductor_cpu_atomic.py +template +typename std::enable_if_t> 
+atomic_add(volatile T *addr, T offset) { + static_assert(sizeof(std::atomic) == sizeof(T), + "std::atomic issue"); + std::atomic *atomic_addr = (std::atomic *)addr; + atomic_addr->fetch_add(offset, std::memory_order_relaxed); +} + +#if INDUCTOR_USE_VECTOR_TYPES() +template +void atomic_add_vec(T *addr, at::vec::VectorizedN index, at::vec::VectorizedN offset) { + constexpr int len = at::vec::VectorizedN::size(); + static_assert(len <= at::vec::VectorizedN::size()); + __at_align__ std::array tmpbuf; + __at_align__ std::array tmpidx; + offset.store(tmpbuf.data()); + index.store(tmpidx.data()); + for (int i = 0; i < len; i++){ + atomic_add(addr + tmpidx[i], tmpbuf[i]); + } +} +#endif + +std::tuple, int> _get_factors(int64_t number) { + int count = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + count += 2; + } + } + auto factors = std::shared_ptr(new int64_t[count]); + int index = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + factors[index++] = number / i; + factors[index++] = i; + } + } + return std::make_tuple(factors, count); +} + +std::tuple, int> get_factors(int64_t number) { + thread_local std::map, int>> cache; + auto it = cache.find(number); + if (it != cache.end()) { + return it->second; + } else { + auto factors = _get_factors(number); + cache[number] = factors; + return factors; + } +} + +void _mm_get_thread_blocking( + int num_threads, + int max_k_slices, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, + int64_t& Mt, + int64_t& Nt, + int64_t& Kt) { + // see NOTE [Thread blocking in Cpp GEMM] for heuristics + Mt = Nt = Kt = 0; + + auto get_blocking = [](int64_t m_factor, + int64_t n_factor, + int64_t k_factor, + int64_t m_blocks, + int64_t n_blocks, + int64_t k_blocks) { + int64_t thread_block_k = (k_blocks + k_factor - 1) / k_factor; + int64_t thread_block_n = (n_blocks + n_factor - 1) / n_factor; + int64_t thread_block_m = (m_blocks + m_factor - 1) / 
m_factor; + return std::make_tuple(thread_block_m, thread_block_n, thread_block_k); + }; + + auto is_better_blocking = [=](int64_t Mt_, + int64_t Nt_, + int64_t Kt_, + int64_t Mt, + int64_t Nt, + int64_t Kt) { + return Mt == 0 || Kt_ < Kt || Mt_ * Mr + Nt_ * Nr < Mt * Mr + Nt * Nr; + }; + + int64_t m_blocks = (M + Mr - 1) / Mr; + int64_t n_blocks = (N + Nr - 1) / Nr; + int64_t k_blocks = (K + Kr - 1) / Kr; + + auto [factors, count] = get_factors(num_threads); + assert(count > 0); + + for (int i = 0; i < count; ++i) { + int64_t n_factor = factors[i]; + int64_t m_factor = num_threads / n_factor; + if (n_blocks >= n_factor && m_blocks >= m_factor) { + auto [Mt_, Nt_, Kt_] = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks); + if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) { + std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_); + } + } + } + + if (Mt != 0) { + return; + } + + for (int i = 0; i < count; ++i) { + int64_t k_factor = factors[i]; + if (k_blocks >= k_factor && (max_k_slices == 0 || k_factor <= max_k_slices)) { + auto [mxn_factors, mxn_count] = get_factors(num_threads / k_factor); + for (int j = 0; j < mxn_count; ++j) { + int64_t n_factor = mxn_factors[j]; + int64_t m_factor = num_threads / (k_factor * n_factor); + if (n_blocks >= n_factor && m_blocks >= m_factor) { + auto [Mt_, Nt_, Kt_] = get_blocking( + m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks); + if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) { + std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_); + } + } + } + } + } + + if (Mt != 0) { + return; + } + + for (int i = 0; i < count; ++i) { + int64_t n_factor = factors[i]; + int64_t m_factor = num_threads / n_factor; + if (n_blocks >= n_factor || m_blocks >= m_factor) { + auto [Mt_, Nt_, Kt_] = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks); + if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) { + std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_); + } + } + } + + assert(Mt != 0); +} 
+ +void mm_get_thread_blocking( + int num_threads, + int max_k_slices, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, + int64_t& Mt, + int64_t& Nt, + int64_t& Kt) { + thread_local std::map< + std::tuple, + std::tuple> cache; + auto key = std::make_tuple(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr); + auto it = cache.find(key); + if (it != cache.end()) { + std::tie(Mt, Nt, Kt) = it->second; + return; + } else { + _mm_get_thread_blocking(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr, Mt, Nt, Kt); + cache[key] = std::make_tuple(Mt, Nt, Kt); + } +} + +template +void _mm_get_cache_blocking( + int num_threads, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, + int64_t Mt_blocks, + int64_t Nt_blocks, + int64_t Kt_blocks, + int64_t& Mc_blocks, + int64_t& Nc_blocks, + int64_t& Kc_blocks, + uint32_t L1_cache_size, + uint32_t L2_cache_size) { + // See NOTE [CPP GEMM Cache Blocking Algorithm] for the cache blocking algorithm. + // TODO(jgong5): cache cache blocking results + // TODO: tune the factor here + float L1_limit_factor = 0.8; + float L2_limit_factor = 0.5; + + auto L1 = L1_cache_size * L1_limit_factor; + auto L2 = L2_cache_size * L2_limit_factor; + + constexpr size_t num_byte_A = sizeof(X_t); + constexpr size_t num_byte_B = sizeof(W_t); + + int64_t size_cache_B = Kr * Kt_blocks * Nr * num_byte_B; + Kc_blocks = Kt_blocks; + if (size_cache_B > L1) { + Kc_blocks = (int64_t)std::floor(L1 / (Kr * Nr * num_byte_B)); + } + + float min_Mc_ratio = 2; + int64_t min_Mc_blocks = std::ceil(min_Mc_ratio * Mr / Nr); + auto Kt_bytes = Kt_blocks * Kr * num_byte_A; + if (min_Mc_blocks * Mr * Kt_bytes < L2) { + Mc_blocks = std::min(Mt_blocks, (int64_t)std::floor(L2 / (Mr * Kt_bytes))); + Nc_blocks = 1; + } else { + Mc_blocks = Mt_blocks; + Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks); + auto Nc_bytes = Nc_blocks * Nr * 4; + auto Kc_bytes = Kc_blocks * Kr * num_byte_A; + if (Mc_blocks 
* Mr * (Kc_bytes + Nc_bytes) > L2) { + auto M_max = (std::sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8; + if (M_max < Mc_blocks * Mr) { + Mc_blocks = (int64_t)std::floor(M_max / Mr); + Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks); + } + } + } +} + +template +void mm_get_cache_blocking( + int num_threads, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, + int64_t Mt_blocks, + int64_t Nt_blocks, + int64_t Kt_blocks, + int64_t& Mc_blocks, + int64_t& Nc_blocks, + int64_t& Kc_blocks, + uint32_t L1_cache_size, + uint32_t L2_cache_size) { + thread_local std::map< + std::tuple, + std::tuple> cache; + auto key = std::make_tuple(num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, L1_cache_size, L2_cache_size); + auto it = cache.find(key); + if (it != cache.end()) { + std::tie(Mc_blocks, Nc_blocks, Kc_blocks) = it->second; + return; + } else { + _mm_get_cache_blocking( + num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, Mc_blocks, Nc_blocks, Kc_blocks, L1_cache_size, L2_cache_size); + cache[key] = std::make_tuple(Mc_blocks, Nc_blocks, Kc_blocks); + } +} + +inline void mm_get_thread_blocks( + int thread_id, + int64_t M_blocks, + int64_t N_blocks, + int64_t K_blocks, + int64_t Mt_blocks, + int64_t Nt_blocks, + int64_t Kt_blocks, + int64_t& m_block_start, + int64_t& m_block_end, + int64_t& n_block_start, + int64_t& n_block_end, + int64_t& k_block_start, + int64_t& k_block_end) { + int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks; + k_block_start = (thread_id % num_Kt) * Kt_blocks; + k_block_end = std::min(k_block_start + Kt_blocks, K_blocks); + thread_id /= num_Kt; + int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks; + n_block_start = (thread_id % num_Nt) * Nt_blocks; + n_block_end = std::min(n_block_start + Nt_blocks, N_blocks); + thread_id /= num_Nt; + m_block_start = std::min(thread_id * Mt_blocks, M_blocks); + m_block_end = std::min(m_block_start + 
Mt_blocks, M_blocks); +} + +struct amx_tilecfg { + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[16]; + uint8_t rows[16]; +}; + +class AMXState { + private: + amx_tilecfg tilecfg_; + uint8_t rows_; + uint16_t colsb_; + uint8_t num_tile_rows_; + uint8_t num_tile_columns_; + + public: + AMXState() : rows_(0), colsb_(0), num_tile_rows_(0), num_tile_columns_(0) { + memset(&tilecfg_, 0, sizeof(tilecfg_)); + } + + inline void configure( + uint8_t rows, + uint16_t colsb, + uint8_t num_tile_rows, + uint8_t num_tile_columns, + void (*loadconfig)(const amx_tilecfg&)) { + if (tilecfg_.palette_id == 1 && rows_ == rows && colsb_ == colsb && + num_tile_rows_ == num_tile_rows && + num_tile_columns_ == num_tile_columns) { + return; + } + tilecfg_.palette_id = 1; + rows_ = rows; + colsb_ = colsb; + num_tile_rows_ = num_tile_rows; + num_tile_columns_ = num_tile_columns; + const auto num_c_tiles = num_tile_rows * num_tile_columns; + // For C + for (int i = 0; i < num_c_tiles; i++) { + tilecfg_.rows[i] = rows; + tilecfg_.colsb[i] = 64; + } + // For A + for (int i = 0; i < num_tile_rows; i++) { + tilecfg_.rows[i + num_c_tiles] = rows; + tilecfg_.colsb[i + num_c_tiles] = colsb; + } + // For B + for (int i = 0; i < num_tile_columns; i++) { + tilecfg_.rows[i + num_c_tiles + num_tile_rows] = colsb / 4; + tilecfg_.colsb[i + num_c_tiles + num_tile_rows] = 64; + } + loadconfig(tilecfg_); + } + + inline void release(void (*tile_release)()) { + tilecfg_.palette_id = 0; + tile_release(); + } +}; diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..9eea00fbb8d6b6127e2aa117f8da8be36b556e5b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py @@ -0,0 +1,174 @@ +# mypy: allow-untyped-defs +import functools +from typing import 
Optional + +import torch._inductor.runtime.hints +from torch._inductor import config +from torch._inductor.codegen.simd import IterationRangesRoot +from torch._inductor.codegen.triton import triton_compute_type, TritonKernel +from torch._prims_common import prod +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv + + +class TritonSplitScanKernel(TritonKernel): + """Generates a triton kernel that supports ops.scan calls while also splitting + the reduction dimension over multiple triton programs. + + For this kernel, loop numels will always take the form ``(xdim, rdim)`` + and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication + between blocks occurs within a global memory workspace buffer, which + must be zero-filled before launching the kernel. + + Note that generation for ``ops.reduction`` is not supported. + + For details of the communication strategy, see + https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + """ + + def __init__( + self, + *groups, + index_dtype: str, + mutations: Optional[OrderedSet[str]] = None, + reduction_hint=torch._inductor.runtime.hints.ReductionHint.DEFAULT, + min_elem_per_thread=0, + ) -> None: + super().__init__( + *groups, + index_dtype=index_dtype, + mutations=mutations, + pid_cache=None, + reduction_hint=reduction_hint, + min_elem_per_thread=min_elem_per_thread, + ) + self.no_x_dim = True + + def should_use_persistent_reduction(self) -> bool: + return False + + def initialize_range_tree(self, pid_cache): + prefixes = "yxr" + assert len(self.numels) <= len( + prefixes + ), "z dimension not supported for split scan" + active_prefixes = prefixes[len(prefixes) - len(self.numels) :] + + grid_dims = "rxy" + for numel, prefix in zip(self.numels, active_prefixes): + is_reduction = prefix == "r" + tensor_dim = 0 if is_reduction else None + grid_dim = grid_dims.find(prefix) + self.range_trees.append( + IterationRangesRoot( 
+ f"{prefix}index", + numel, + prefix, + grid_dim, + self, + pid_cache=pid_cache, + is_loop=False, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim=False, + ) + ) + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise NotImplementedError("NYI TritonSplitDimKernel reductions") + + def scan(self, dtypes, combine_fn, values): + import triton.language as tl + + (dtype,) = dtypes + (value,) = values + + compute_type = triton_compute_type(dtype) + compute_type_triton = getattr(tl, compute_type[3:]) + + element_nbits = compute_type_triton.primitive_bitwidth + + scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64" + scratch_type_triton = getattr(tl, scratch_type[3:]) + scratch_elems_per_block = 3 if element_nbits == 64 else 1 + scratch_nbytes_per_block = scratch_elems_per_block * ( + scratch_type_triton.primitive_bitwidth // 8 + ) + + cse_load = functools.partial(self.cse.generate, self.loads) + cse_compute = functools.partial(self.cse.generate, self.compute) + + assert len(self.numels) == 2, "Unexpected tiling" + min_rblock = config.triton.min_split_scan_rblock + max_blocks = prod(self.numels[:-1]) * CeilDiv(self.numels[-1], min_rblock) + nbytes = scratch_nbytes_per_block * max_blocks + scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True) + if offset != 0: + scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}") + runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})") + scratch_base = cse_load( + f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * " + f"{scratch_elems_per_block} * {runtime_rblocks}" + ) + + masks = {f"{tree.prefix}mask" for tree in self.range_trees} + self.filter_masks(masks) + masks = sorted(masks) + assert not self._load_mask, "ops.scan not supported inside ops.masked" + + value = cse_compute(f"{value}.to({compute_type})") + value = cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})") + + combine_helper_fn = 
self._lift_helper(combine_fn, 1) + dim = self.triton_tensor_ndim() - 1 + assert dim == 0, "" + + block_sum = cse_compute(f"tl.reduce({value}, {dim}, {combine_helper_fn})") + exclusive_prefix = self.cse.newvar() + if element_nbits == 64: + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64( + {scratch_base}, + {block_sum}, + {self.iteration_ranges_get_pid(self.range_trees[-1])}, + {combine_helper_fn}, + ) + """, + strip=True, + ) + + else: + assert element_nbits <= 32 + value_as_uint_dtype = f"tl.uint{element_nbits}" + + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback( + {scratch_base}, + {block_sum}, + {self.iteration_ranges_get_pid(self.range_trees[-1])}, + {combine_helper_fn}, + DTYPE_VALUE_AS_UINT={value_as_uint_dtype}, + DTYPE_PACK={scratch_type}, + ) + """, + strip=True, + ) + # Compute final cumsum + block_scan = cse_compute( + f"tl.associative_scan({value}, {dim}, {combine_helper_fn})" + ) + combined_result = cse_compute( + f"{combine_helper_fn}({exclusive_prefix}, {block_scan})" + ) + return ( + cse_compute(f"tl.where(roffset == 0, {block_scan}, {combined_result})"), + ) + + def _get_heuristic(self): + return "split_scan" + + def _get_grid_fn(self): + return "split_scan_grid" diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc422f4c294cb719e1036e72cfc68d32e58ced8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py @@ -0,0 +1,2057 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import dis +import functools +import inspect +import logging +import operator +import re +import tempfile +from itertools import count +from typing import ( + Any, + Callable, + Dict, + 
Iterator, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy +from sympy import Expr + +import torch +import torch._ops +from torch import dtype as torch_dtype +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.codegen.debug_utils import DebugPrinterManager +from torch._inductor.codegen.multi_kernel import MultiKernelState +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes +from torch.fx.node import _get_qualified_name +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .. import async_compile, config, ir +from ..codecache import output_code_log +from ..ir import ReinterpretView +from ..runtime import triton_heuristics +from ..runtime.hints import DeviceProperties +from ..utils import ( + cache_on_self, + get_benchmark_name, + LineContext, + sympy_product, + sympy_str, +) +from ..virtualized import V +from .aoti_hipify_utils import maybe_hipify_code_wrapper +from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter +from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta + + +if TYPE_CHECKING: + import triton + + from ..graph import GraphLowering + + +pexpr = PythonPrinter().doprint + + +ReuseKey = Tuple[torch.device, torch.dtype, str] + + +def buffer_reuse_key(node: ir.Buffer) -> ReuseKey: + return ( + node.get_device(), + node.get_dtype(), + # NB: this is symbolic so that we don't try to reuse a buffer + # for s0 for s1, just because they happen to share the same + # size hint + sympy_str(V.graph.sizevars.simplify(node.layout.storage_size())), + ) + + +def convert_arg_type(arg: torch.Argument) -> str: + from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP + + # use x.real_type instead of x.type so that we get ScalarType instead of int + python_type = repr(arg.real_type) # type: 
ignore[attr-defined] + + if python_type == "Tensor": + # Conversions rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func + if arg.alias_info is not None and arg.alias_info.is_write: + return f"at::{python_type}&" + else: + return f"at::{python_type} const&" + + if python_type in PYTHON_TO_CPP: + cpp_type = PYTHON_TO_CPP[python_type] + return cpp_type + + # Convert args of container types e.g. Optional[*] + for py_container, cpp_container in CONTAINER_PYTHON_TO_CPP.items(): + container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type) + if len(container_match) == 1: + contained_type = container_match[0] + assert ( + contained_type in PYTHON_TO_CPP + ), f"unsupported {py_container} type in convert_arg_type: {contained_type}" + cpp_contained_type = PYTHON_TO_CPP[contained_type] + return f"{cpp_container}<{cpp_contained_type}>" + + raise AssertionError(f"unsupport python_type: {python_type}") + + +def convert_return_type(ret: torch.Argument) -> str: + # use x.real_type instead of x.type so that we get ScalarType instead of int + python_type = repr(ret.real_type) # type: ignore[attr-defined] + python_to_cpp = { + "Tensor": "at::Tensor", + "List[Tensor]": "std::vector", + } + + cpp_type = python_to_cpp.get(python_type, None) + assert cpp_type is not None, f"NYI return type: {python_type}" + # An output aliasing an input is returned by reference only when it's a + # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output + # aliases the input tensor, but the op returns a vector by value. 
+ if python_type == "Tensor" and ret.alias_info is not None: + cpp_type += "&" + return cpp_type + + +def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str: + args = kernel._schema.arguments + returns = kernel._schema.returns + + num_returns = len(returns) + assert num_returns > 0, "must have at least one return value" + + if num_returns == 1: + cpp_return_value = convert_return_type(returns[0]) + elif num_returns > 1: + tuple_returns = ", ".join([convert_return_type(r) for r in returns]) + cpp_return_value = f"std::tuple<{tuple_returns}>" + + cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args] + return f"{cpp_return_value}({', '.join(cpp_arg_type)})" # type: ignore[possibly-undefined] + + +# TODO: Move to a well known place +TritonMetaParams = Dict[str, int] +TritonGrid = Union[ + Tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], Tuple[int, ...]] +] + + +def user_defined_kernel_grid_fn_code( + name: str, + configs: List[triton.Config], # type: ignore[name-defined] + grids: List[TritonGrid], + wrapper: Optional[WrapperCodeGen] = None, +) -> Tuple[str, str]: + output = IndentedBuffer() + + def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr: + return item if isinstance(item, sympy.Expr) else sympy.Integer(item) + + def determine_grid( + grid: TritonGrid, + ): + """ + This function return a tuple of two values: the first one is for the real grid + which is used in the generated code; the second one is an example grid with + concreate values which is used in the autotune block to run the generated + kernels at compile time. 
+ """ + if wrapper is None or callable(grid): + # return as-is when used in eager mode or when grid is callable + return grid, grid + # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen + sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid) + return ( + wrapper.codegen_shape_tuple(sympy_grid), + wrapper.codegen_shape_tuple( + tuple( + wrapper.generate_example_arg_value(g, type(g)) for g in sympy_grid + ) + ) + if config.triton.autotune_at_compile_time + else None, + ) + + def writeline(line: str, example_grid: Optional[str] = None): + output.writeline(line) + if ( + wrapper + and config.triton.autotune_at_compile_time + and name not in wrapper.kernel_autotune_names + ): + wrapper.kernel_autotune_calls.writeline(example_grid or line) + + fn_name = f"grid_wrapper_for_{name}" + writeline(f"def {fn_name}(meta):") + kernel_autotune_calls_indent = ( + wrapper.kernel_autotune_calls.indent() + if wrapper and config.triton.autotune_at_compile_time + else contextlib.nullcontext() + ) + with output.indent(), kernel_autotune_calls_indent: + if len(grids) == 1: + grid, example_grid = determine_grid(grids[0]) + writeline(f"return {grid}", f"return {example_grid}") + else: + assert len(grids) > 1 + assert len(grids) == len(configs) + seen = set() + for grid, c in zip(grids, configs): + guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()] + guards = " and ".join(guards) + grid, example_grid = determine_grid(grid) + statement = f"if {guards}: return {grid}" + if statement in seen: + continue + seen.add(statement) + writeline(statement, f"if {guards}: return {example_grid}") + + return fn_name, output.getvalue() + + +@dataclasses.dataclass +class SymbolicCallArg: + inner: str + # the original symbolic expression represented by inner + inner_expr: sympy.Expr + + def __str__(self): + return str(self.inner) + + +# Default thread stack sizes vary by platform: +# - Linux: 8 MB +# - macOS: 512 KB +# - Windows: 1 MB +# Just pick something 
comfortably smaller than the smallest for now. +MAX_STACK_ALLOCATION_SIZE = 1024 * 100 + + +class MemoryPlanningState: + def __init__(self): + super().__init__() + self.reuse_pool: Dict[ + ReuseKey, List[FreeIfNotReusedLine] + ] = collections.defaultdict(list) + self.total_allocated_buffer_size: int = 0 + + def __contains__(self, key: ReuseKey) -> bool: + return bool(self.reuse_pool.get(key, None)) + + def pop(self, key: ReuseKey) -> FreeIfNotReusedLine: + item = self.reuse_pool[key].pop() + assert not item.is_reused + return item + + def push(self, key: ReuseKey, item: FreeIfNotReusedLine) -> None: + assert not item.is_reused + self.reuse_pool[key].append(item) + + +class WrapperLine: + pass + + +@dataclasses.dataclass +class EnterSubgraphLine(WrapperLine): + wrapper: WrapperCodeGen + graph: GraphLowering + + def __post_init__(self) -> None: + self.wrapper.push_computed_sizes(self.wrapper.computed_sizes) + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.push_codegened_graph(self.graph) + code.do_indent() + + +@dataclasses.dataclass +class ExitSubgraphLine(WrapperLine): + wrapper: WrapperCodeGen + + def __post_init__(self) -> None: + self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes() + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.pop_codegened_graph() + code.do_unindent() + + +@dataclasses.dataclass +class EnterDeviceContextManagerLine(WrapperLine): + device_idx: int + last_seen_device_guard_index: Optional[int] + + def codegen(self, code: IndentedBuffer) -> None: + if V.graph.cpp_wrapper: + code.writeline("\n") + if V.graph.aot_mode: + # In AOT mode, we have a stream provided as a param. A stream is + # associated with a device, so we never expect the device to change. + # CUDAStreamGuard sets the stream and the device. 
+ if self.last_seen_device_guard_index is None: + if config.abi_compatible: + code.writeline( + "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);" + ) + else: + code.writeline( + maybe_hipify_code_wrapper( + "at::cuda::CUDAStreamGuard stream_guard(" + + "at::cuda::getStreamFromExternal(stream, this->device_idx_));" + ) + ) + else: + assert ( + self.last_seen_device_guard_index == self.device_idx + ), "AOTInductor only supports running on one CUDA device" + else: + if self.last_seen_device_guard_index is None: + code.writeline( + f"AOTICudaGuard device_guard({self.device_idx});" + if config.abi_compatible + else maybe_hipify_code_wrapper( + f"at::cuda::CUDAGuard device_guard({self.device_idx});" + ) + ) + else: + code.writeline(f"device_guard.set_index({self.device_idx});") + else: + # Note _DeviceGuard has less overhead than device, but only accepts + # integers + code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:") + code.do_indent() + code.writeline(V.graph.device_ops.set_device(self.device_idx)) + + +class ExitDeviceContextManagerLine(WrapperLine): + def codegen(self, code: IndentedBuffer) -> None: + if not V.graph.cpp_wrapper: + code.do_unindent() + + +@dataclasses.dataclass +class MemoryPlanningLine(WrapperLine): + wrapper: WrapperCodeGen + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + """First pass to find reuse""" + return self + + def codegen(self, code: IndentedBuffer) -> None: + """Second pass to output code""" + + def __str__(self) -> str: + """ + Emits a string representation that fits on one line. 
+ """ + args: List[str] = [] + for field in dataclasses.fields(self): + if field.name == "wrapper": + continue + val = getattr(self, field.name) + args.append( + f"{field.name}={val.get_name() if field.type is ir.Buffer else val}" + ) + return f"{type(self).__name__}({', '.join(args)})" + + +@dataclasses.dataclass +class AllocateLine(MemoryPlanningLine): + node: ir.Buffer + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + + # try to reuse a recently freed buffer + key = buffer_reuse_key(self.node) + if config.allow_buffer_reuse and key in state: + free_line = state.pop(key) + free_line.is_reused = True + return ReuseLine(self.wrapper, free_line.node, self.node) + + if self.node.get_device().type == "cpu": + static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node) + if static_shape is not None: + state.total_allocated_buffer_size += int( + functools.reduce(operator.mul, static_shape, 1) + ) + + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + line = self.wrapper.make_buffer_allocation(self.node) + code.writeline(line) + + +@dataclasses.dataclass +class FreeIfNotReusedLine(MemoryPlanningLine): + node: ir.Buffer + is_reused: bool = False + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if len(self.node.get_inputs_that_alias_output()) > 0: + return self + if isinstance(self.node.layout, ir.MultiOutputLayout): + return self + assert not self.is_reused + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + if config.allow_buffer_reuse: + state.push(buffer_reuse_key(self.node), self) + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + if not self.is_reused: + code.writeline(self.wrapper.make_buffer_free(self.node)) + + +@dataclasses.dataclass 
+class ReuseLine(MemoryPlanningLine): + node: ir.Buffer + reused_as: ir.Buffer + delete_old: bool = True + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + assert self.reused_as.get_name() in V.graph.removed_buffers + return NullLine(self.wrapper) + assert self.reused_as.get_name() not in V.graph.removed_buffers + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + assert self.reused_as.get_name() not in V.graph.removed_buffers + code.writeline( + self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old) + ) + + +class NullLine(MemoryPlanningLine): + pass + + +BufferName = str + + +class WrapperCodeGen(CodeGen): + """ + Generate outer wrapper in Python that calls the kernels. + """ + + def __init__(self): + super().__init__() + self._names_iter: Iterator[int] = count() + self.imports = IndentedBuffer() + self.header = IndentedBuffer() + self.prefix = IndentedBuffer() + self.suffix = IndentedBuffer() + self.wrapper_call = IndentedBuffer() + self.kernel_autotune_defs = IndentedBuffer() + self.kernel_autotune_calls = IndentedBuffer() + self.kernel_autotune_names: Set[str] = set() + # If the generated source code is exactly the same, reuse the + # pre-existing kernel for it + self.src_to_kernel: Dict[str, str] = {} + self.kernel_numel_expr: Set[Tuple[str, GraphLowering]] = set() + self.lines: List[Union[MemoryPlanningLine, LineContext]] = [] + self.declare = "" + self.declare_maybe_reference = "" + self.ending = "" + self.open_bracket = "[" + self.closed_bracket = "]" + self.comment = "#" + self.namespace = "" + self.none_str = "None" + self.size = "size()" + self.stride = "stride()" + self.last_seen_device_guard_index: Optional[int] = None + self.supports_intermediate_hooks = True + self.expr_printer: Callable[[Any], str] = pexpr + self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = 
{} + self.unbacked_symbol_decls: Set[str] = set() # str of sympy.Symbol + self.allow_stack_allocation: Optional[bool] = None + self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {} + self.computed_sizes: Set[sympy.Symbol] = set() + + # this is used for tracking which GraphLowering instance---parent graph + # or (nested) subgraph---is currently codegened; the primary use case is + # including the graph instance into a cache key to avoid cross-graph + # caching during lowering of nested subgraphs + self.codegened_graph_stack = [] + self.computed_sizes_stack = [] + + self.write_header() + self.write_prefix() + self.write_kernel_autotune_defs_header() + + if not V.graph.aot_mode: + for name, hashed in V.graph.constant_reprs.items(): + # include a hash so our code cache puts different constants into different files + self.write_constant(name, hashed) + + self.allocated: Set[BufferName] = set() + self.freed: Set[BufferName] = set() + + # maps from reusing buffer to reused buffer + self.reuses: Dict[BufferName, BufferName] = {} + + self.write_get_raw_stream = functools.lru_cache(None)( # type: ignore[assignment] + self.write_get_raw_stream + ) + + @functools.lru_cache(None) + def add_import_once(line: str) -> None: + self.imports.writeline(line) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(line) + + self.add_import_once = add_import_once + self._metas: Dict[str, str] = {} + self._meta_vars: Set[str] = set() + self.multi_kernel_state = MultiKernelState() + + # intermediate tensor value printing utility + self.debug_printer = DebugPrinterManager( + debug_printer_level=config.aot_inductor.debug_intermediate_value_printer + ) + + def write_constant(self, name: str, hashed: str) -> None: + self.header.writeline(f"{name} = None # {hashed}") + + def write_header(self) -> None: + context = torch._guards.TracingContext.try_get() + aot_config_comment = "" + if context is not None and context.aot_graph_name is not None: + 
aot_config_comment = f"# AOT ID: {context.aot_graph_name}" + self.imports.splice( + f""" + {aot_config_comment} + from ctypes import c_void_p, c_long, c_int + import torch + import math + import random + import os + import tempfile + from math import inf, nan + from torch._inductor.hooks import run_intermediate_hooks + from torch._inductor.utils import maybe_profile + from torch._inductor.codegen.memory_planning import _align as align + from torch import device, empty_strided + from {async_compile.__name__} import AsyncCompile + from torch._inductor.select_algorithm import extern_kernels + from torch._inductor.codegen.multi_kernel import MultiKernelCall + """, + strip=True, + ) + self.header.splice( + """ + aten = torch.ops.aten + inductor_ops = torch.ops.inductor + _quantized = torch.ops._quantized + assert_size_stride = torch._C._dynamo.guards.assert_size_stride + empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu + reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor + alloc_from_pool = torch.ops.inductor._alloc_from_pool + async_compile = AsyncCompile() + """, + strip=True, + ) + + def write_kernel_autotune_defs_header(self) -> None: + self.kernel_autotune_defs.splice( + f""" + import torch + from torch._dynamo.testing import rand_strided + from torch._dynamo.utils import preserve_rng_state + from torch._inductor.select_algorithm import AlgorithmSelectorCache + from {async_compile.__name__} import AsyncCompile + + async_compile = AsyncCompile() + generate_example_value = AlgorithmSelectorCache.generate_example_value + """ + ) + + @cache_on_self + def write_triton_header_once(self) -> None: + import_str = f""" + import triton + import triton.language as tl + from {triton_heuristics.__name__} import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph + """ + self.imports.splice(import_str, 
strip=True) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.splice(import_str) + self.write_get_raw_stream_header_once() + + @cache_on_self + def write_get_raw_stream_header_once(self) -> None: + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + def add_meta_once(self, meta: TritonMetaParams) -> str: + meta = repr(meta) + if meta not in self._metas: + var = f"meta{len(self._metas)}" + self._metas[meta] = var + self.header.writeline(f"{var} = {meta}") + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(f"{var} = {meta}") + self._meta_vars.add(var) + return self._metas[meta] + + @cache_on_self + def get_output_refs(self) -> List[str]: + return [x.codegen_reference(self.wrapper_call) for x in V.graph.graph_outputs] + + def mark_output_type(self) -> None: + return + + def codegen_input_size_asserts(self) -> None: + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. 
+ if sympy_product(buf.get_size()) == 0: + continue + size = self.codegen_shape_tuple(buf.get_size()) + stride = self.codegen_shape_tuple(buf.get_stride()) + self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})") + + def codegen_input_nan_asserts(self) -> None: + self.prefix.writeline("# make sure graph inputs are not nan/inf") + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + line = f"assert not {name}.isnan().any().item()" + self.prefix.writeline(line) + line = f"assert not {name}.isinf().any().item()" + self.prefix.writeline(line) + + def write_prefix(self) -> None: + self.prefix.splice( + """ + + async_compile.wait(globals()) + del async_compile + + def call(args): + """ + ) + with self.prefix.indent(): + if config.triton.debug_sync_graph: + self.prefix.writeline(V.graph.device_ops.synchronize()) + if V.graph.graph_inputs: + lhs = ", ".join(V.graph.graph_input_names) + if len(V.graph.graph_input_names) == 1: + lhs += "," + self.prefix.writeline(f"{lhs} = args") + self.prefix.writeline("args.clear()") + + self.codegen_inputs(self.prefix, V.graph.graph_inputs) + if config.size_asserts: + self.codegen_input_size_asserts() + if config.nan_asserts: + self.codegen_input_nan_asserts() + + # this function (and below) takes a graph as input so + # that stream caching happens per graph instance. this + # is important for nested subgraph codegening. 
+ def write_get_raw_stream(self, device_idx: int, graph=None) -> str: + self.write_get_raw_stream_header_once() + name = f"stream{device_idx}" + self.writeline(f"{name} = get_raw_stream({device_idx})") + return name + + def get_codegened_graph(self): + return self.codegened_graph_stack[-1] + + def push_codegened_graph(self, graph): + self.codegened_graph_stack.append(graph) + + def pop_codegened_graph(self): + return self.codegened_graph_stack.pop() + + def push_computed_sizes(self, computed_sizes): + from copy import deepcopy + + return self.computed_sizes_stack.append(deepcopy(computed_sizes)) + + def pop_computed_sizes(self): + return self.computed_sizes_stack.pop() + + def next_kernel_suffix(self) -> str: + return f"{next(self._names_iter)}" + + def codegen_device_guard_enter(self, device_idx: int) -> None: + self.writeline( + EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index) + ) + if config.triton.autotune_at_compile_time: + # mimic logic of EnterDeviceContextManagerLine.codegen for the autotune code block + self.write_triton_header_once() + self.kernel_autotune_calls.writeline( + f"with {V.graph.device_ops.device_guard(device_idx)}:" + ) + self.kernel_autotune_calls.do_indent() + self.kernel_autotune_calls.writeline( + V.graph.device_ops.set_device(device_idx) + ) + self.kernel_autotune_calls.writeline( + f"stream{device_idx} = get_raw_stream({device_idx})" + ) + self.last_seen_device_guard_index = device_idx + + def codegen_device_guard_exit(self) -> None: + self.writeline(ExitDeviceContextManagerLine()) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.do_unindent() + + def generate_return(self, output_refs: List[str]) -> None: + if output_refs: + self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") + else: + self.wrapper_call.writeline("return ()") + + def generate_before_suffix(self, result: IndentedBuffer) -> None: + return + + def generate_end(self, result: IndentedBuffer) -> None: + 
return + + def generate_fallback_kernel(self, fallback_kernel, args): + self.generate_extern_kernel_alloc(fallback_kernel, args) + + def generate_extern_kernel_alloc(self, extern_kernel, args): + # If it's a NoneLayout then the extern_kernel should essentially be + # treated as if it doesn't return anything + no_return = isinstance(extern_kernel.layout, ir.NoneLayout) + output_name = extern_kernel.get_name() + origin_node = extern_kernel.get_origin_node() + kernel_name = extern_kernel.get_kernel_name() + ending = self.ending + if config.memory_planning and "view_as_complex" in kernel_name: + # view operation fallbacks cause issues since inductor + # doesn't know the memory is still needed and might reuse it. + ending = f".clone(){ending}" + + if no_return: + self.writeline(f"{self.declare}{kernel_name}({', '.join(args)}){ending}") + else: + self.writeline( + f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}" + ) + if ( + self.supports_intermediate_hooks + and config.generate_intermediate_hooks + and origin_node is not None + ): + counters["inductor"]["intermediate_hooks"] += 1 + self.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {output_name})" + ) + + def generate_extern_kernel_out( + self, kernel: str, out: str, out_view: Optional[str], args: List[str] + ): + args.append(f"out={out_view if out_view else out}") + self.writeline(f"{kernel}({', '.join(args)})") + + def generate_user_defined_triton_kernel( + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, + ): + grid_fn, code = user_defined_kernel_grid_fn_code( + kernel_name, configs, grid, wrapper=self + ) + # Must happen after free symbols are already codegened + # Emit the grid wrapper function right before the call + for line in code.split("\n"): + self.writeline(line) + + args = [self.val_to_arg_str(v) for v in raw_args] + arg_types = [ + arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg) + for arg in 
raw_args + ] + self.generate_kernel_call( + kernel_name, args, grid_fn=grid_fn, arg_types=arg_types, raw_args=raw_args + ) + + def generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + ): + line = f"{python_kernel_name}({','.join(map(str, inputs))}" + if python_kernel_name.startswith("aten.scatter_reduce"): + line += ", ".join([""] + kwargs) + else: + if reduce: + line += f", reduce={repr(reduce)}" + line += ")" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}" + args = [x, indices_str, values, accumulate] + self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_extern_kernel_alloc_and_find_schema_if_needed( + self, + buf_name: str, + python_kernel_name: str, + cpp_kernel_name: str, + codegen_args: List[str], + cpp_op_schema: str, + cpp_kernel_key: str, + cpp_kernel_overload_name: str = "", + op_overload: Optional[torch._ops.OpOverload] = None, + raw_args=None, + outputs=None, + ): + self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})") + + def generate(self, is_inference): + with dynamo_timed("WrapperCodeGen.generate"): + return self._generate(is_inference) + + def _generate(self, is_inference): + if config.profile_bandwidth: + self.write_triton_header_once() + result = IndentedBuffer() + result.splice(self.imports) + result.writeline("") + result.splice(self.header) + # We do not want the cpp header for intermediate const graph. Headers would be + # rendered by the main module instead. 
+ if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph: + result = IndentedBuffer() + + with contextlib.ExitStack() as stack: + stack.enter_context(self.wrapper_call.indent()) + if config.profiler_mark_wrapper_call: + self.generate_profiler_mark_wrapper_call(stack) + if config.profile_bandwidth: + self.generate_start_graph() + + # We disable planning during training because it presently increases peak memory consumption. + if is_inference and config.memory_planning: + self.memory_plan() + # TODO: integrate memory planning & stack allocation? + self.allow_stack_allocation = False + else: + self.memory_plan_reuse() + + if config.triton.store_cubin: + self.generate_reset_kernel_saved_flags() + + for line in self.lines: + if isinstance(line, WrapperLine): + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) + + output_refs = self.get_output_refs() + self.mark_output_type() + if config.triton.debug_sync_graph: + self.wrapper_call.writeline(V.graph.device_ops.synchronize()) + + if config.profile_bandwidth: + self.generate_end_graph() + + if config.triton.store_cubin: + self.generate_save_uncompiled_kernels() + + if config.triton.autotune_at_compile_time: + self.generate_and_run_autotune_block() + + self.generate_return(output_refs) + + self.finalize_prefix() + result.splice(self.prefix) + + with result.indent(): + result.splice(self.wrapper_call) + + self.generate_before_suffix(result) + result.splice(self.suffix) + + self.generate_end(result) + + self.add_benchmark_harness(result) + + return result.getvaluewithlinemap() + + def generate_and_run_autotune_block(self): + """ + Compose self.kernel_autotune_defs and self.kernel_autotune_calls into a single block of + code and execute it to trigger Triton kernel compilation and auto-tuning + """ + self.kernel_autotune_defs.splice( + """ + async_compile.wait(globals()) + del async_compile + """ + ) + scope = {} # type: ignore[var-annotated] + tuning_code = ( + 
self.kernel_autotune_defs.getvalue() + self.kernel_autotune_calls.getvalue() + ) + if output_code_log.level == logging.DEBUG: + # Save the autotuning code block into a file + # Create a temporary file + with tempfile.NamedTemporaryFile( + dir=cache_dir(), suffix=".py", delete=False + ) as f: + f.write(tuning_code.encode("utf-8")) + file_path = f.name + output_code_log.debug( + "\nCompile-time auto-tuning code: \n%s\nAuto-tuning code written to %s", + tuning_code, + file_path, + ) + # Execute the code to autotune kernels + exec(tuning_code, scope) + + def memory_plan(self): + from .memory_planning import MemoryPlanner + + self.lines = MemoryPlanner(self).plan(self.lines) + + def memory_plan_reuse(self): + out_names = V.graph.get_output_names() + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + # TODO: this seems legit, NullLine has no node + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + # conservatively use the sum of all allocated buffer sizes + # in potentially nested scopes as the total allocated size + total_allocated_buffer_size = sum( + s.total_allocated_buffer_size for s in past_planning_states + ) + + self.allow_stack_allocation = ( + self.allow_stack_allocation is not False + and config.allow_stack_allocation + and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE + ) + + def 
codegen_input_size_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"{self.declare}{name}_size = {name}.{self.size}{self.ending}") + + def codegen_input_stride_var_decl(self, code: IndentedBuffer, name): + code.writeline( + f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}" + ) + + def codegen_inputs( + self, code: IndentedBuffer, graph_inputs: Dict[str, ir.TensorBox] + ): + """Assign all symbolic shapes to locals""" + + @functools.lru_cache(None) + def sizeof(name): + self.codegen_input_size_var_decl(code, name) + return f"{name}_size" + + @functools.lru_cache(None) + def strideof(name): + self.codegen_input_stride_var_decl(code, name) + return f"{name}_stride" + + # Assign all symbolic shapes needed to local variables + bound_vars: Set[sympy.Symbol] = set() + + def is_expr(x): + return isinstance(x[1], sympy.Expr) + + graph_inputs_expr = list(filter(is_expr, graph_inputs.items())) + graph_inputs_tensors = list( + filter(lambda x: not is_expr(x), graph_inputs.items()) + ) + + for name, shape in graph_inputs_expr: + if isinstance(shape, sympy.Symbol) and shape not in bound_vars: + code.writeline(f"{self.declare}{shape} = {name}{self.ending}") + bound_vars.add(shape) + + for name, value in graph_inputs_tensors: + shapes = value.get_size() + for dim, shape in enumerate(shapes): + if isinstance(shape, sympy.Symbol) and shape not in bound_vars: + code.writeline( + f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}" + ) + bound_vars.add(shape) + + for name, value in graph_inputs_tensors: + shapes = value.get_stride() + for dim, shape in enumerate(shapes): + if isinstance(shape, sympy.Symbol) and shape not in bound_vars: + code.writeline( + f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}" + ) + bound_vars.add(shape) + + def ensure_size_computed(self, sym: sympy.Symbol): + if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): + if sym in self.computed_sizes: + return + 
self.computed_sizes.add(sym) + expr = V.graph.sizevars.inv_precomputed_replacements[sym] + self.writeline( + f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}" + ) + + def finalize_prefix(self): + pass + + def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str: + return pexpr(x, simplify=simplify) + + def codegen_sizevar(self, x: Expr) -> str: + return self.codegen_python_sizevar(x) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + return f"{basename}[{index}]" + + def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: + parts = list(map(self.codegen_python_sizevar, shape)) + if len(parts) == 0: + return "()" + if len(parts) == 1: + return f"({parts[0]}, )" + return f"({', '.join(parts)})" + + def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: + return self.codegen_python_shape_tuple(shape) + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + return "alloc_from_pool({})".format( + ", ".join( + [ + name, + pexpr(offset), # bytes not numel + str(dtype), + self.codegen_shape_tuple(shape), + self.codegen_shape_tuple(stride), + ] + ) + ) + + def codegen_reinterpret_view( + self, data, size, stride, offset, writer, dtype=None + ) -> str: + if ( + size == data.layout.size + and stride == data.layout.stride + and offset == data.layout.offset + ): + if dtype is not None and dtype != data.dtype: + return f"aten.view.dtype({data.get_name()}, {dtype})" + else: + return f"{data.get_name()}" + else: + size = self.codegen_shape_tuple(size) + stride = self.codegen_shape_tuple(stride) + offset = self.codegen_sizevar(offset) + if dtype is not None and dtype != data.dtype: + return f"aten.view.dtype(reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset}), {dtype})" + else: + return ( + f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})" + ) + + def codegen_device_copy(self, src, dst): + self.writeline(f"{dst}.copy_({src})") + + def 
codegen_multi_output(self, name, value): + self.writeline(f"{self.declare}{name} = {value}{self.ending}") + + def codegen_dynamic_scalar(self, node): + (data,) = (t.codegen_reference() for t in node.inputs) + if len(node.keypath) == 0: + self.writeline(f"{node.sym} = {data}.item()") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey): + self.writeline(f"{node.sym} = 1 if {data}.item() else 0") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey): + self.writeline(f"{node.sym}_undivided = {data}.item()") + self.writeline( + f"assert {node.sym}_undivided % {node.keypath[0].divisor} == 0, " + f"f'{{{node.sym}_undivided}} not divisible by {node.keypath[0].divisor}'" + ) + self.writeline( + f"{node.sym} = {node.sym}_undivided // {node.keypath[0].divisor}" + ) + else: + raise AssertionError(f"unrecognized keypath {node.keypath}") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + self.writeline(f"{node.get_name()} = None") + + def benchmark_compiled_module(self, output): + def add_fake_input(name, shape, stride, device, dtype): + output.writeline( + f"{name} = rand_strided(" + f"{self.codegen_python_shape_tuple(shape)}, " + f"{self.codegen_python_shape_tuple(stride)}, " + f"device='{device}', dtype={dtype})" + ) + + def add_expr_input(name, val): + output.writeline(f"{name} = {val}") + + def add_torchbind_input(name, value): + import pickle + + output.writeline(f"{name} = pickle.loads({pickle.dumps(value)!r})") + + output.writelines( + ["", "", "def benchmark_compiled_module(times=10, repeat=10):"] + ) + with output.indent(): + output.splice( + """ + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + """, + strip=True, + ) + + for name, value in V.graph.constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_fake_input( 
+ name, value.size(), value.stride(), value.device, value.dtype + ) + + if len(V.graph.torchbind_constants) > 0: + output.writeline("import pickle") + for name, torchbind_obj in V.graph.torchbind_constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_torchbind_input(name, torchbind_obj) + + for name, value in V.graph.graph_inputs.items(): + if isinstance(value, sympy.Symbol) and isinstance( + V.graph.sizevars.var_to_val.get(value, None), SingletonInt + ): + # Inductor should only work with dense -> dense graph, and + # SingletonInts belong to metadata that should only live on + # the subclass. + continue + if isinstance(value, sympy.Expr): # Don't need to add symbolic + # TODO: this fallback and those below actually will generate possibly + # invalid benchmark code, because it's not guaranteed 42 + # is actually a valid value for the kernel in question. + # See https://github.com/pytorch/pytorch/issues/124686 + add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42)) + else: + shape = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_size() + ] + stride = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_stride() + ] + add_fake_input( + name, + shape, + stride, + value.get_device(), + value.get_dtype(), + ) + + call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])" + output.writeline(f"fn = lambda: {call_str}") + output.writeline("return print_performance(fn, times=times, repeat=repeat)") + + def add_benchmark_harness(self, output): + """ + Append a benchmark harness to generated code for debugging + """ + if not config.benchmark_harness: + return + + self.benchmark_compiled_module(output) + + output.writelines(["", "", 'if __name__ == "__main__":']) + with output.indent(): + output.writelines( + [ + "from torch._inductor.wrapper_benchmark import compiled_module_main", + 
f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)", + ] + ) + + def define_kernel( + self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True + ): + metadata_comment = f"{metadata}\n" if metadata else "" + body = f"\n\n{metadata_comment}{name} = {kernel}" + self.header.splice(body) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_defs.splice(body) + + def define_user_defined_triton_kernel(self, kernel, configs, kwargs): + from torch.utils._triton import patch_triton_dtype_repr + + patch_triton_dtype_repr() + + original_name = kernel.__name__ + + from .common import KernelArgType, SizeArg, TensorArg + + signature: List[KernelArgType] = [] + constants: Dict[int, Any] = {} + non_constant_indices = [] + equal_to_1_arg_idx: List[int] = [] + for idx, key in enumerate(kernel.arg_names): + if key not in kwargs: + continue + arg = kwargs[key] + if idx in kernel.constexprs: + constants[idx] = arg + else: + non_constant_indices.append(idx) + if isinstance(arg, ir.Buffer): + signature.append( + TensorArg( + name=key, + buffer=arg.get_name(), + dtype=arg.get_dtype(), + ) + ) + elif isinstance(arg, ir.ReinterpretView): + # for ReinterpretView we use the underlying + # buffer name and note the (possibly non-zero) + # offset relative to the underlying buffer + signature.append( + TensorArg( + name=key, + buffer=arg.data.get_name(), + dtype=arg.get_dtype(), + offset=arg.layout.offset, + ) + ) + else: + signature.append(SizeArg(key, arg)) + if isinstance( + arg, (int, sympy.Integer) + ) and V.graph.sizevars.statically_known_equals( + arg, 1 # type: ignore[arg-type] + ): + equal_to_1_arg_idx.append(idx) + index_dtype = "tl.int32" + triton_meta = { + "signature": signature_to_meta( + signature, + size_dtype=index_dtype, + indices=non_constant_indices, + ), + "device": DeviceProperties.create( + V.graph.scheduler.get_current_device_or_throw() + ), + # Triton compiler includes equal_to_1 args into constants even + # when they 
are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. + # TODO(aakhundov): add None args to constants, too. currently, this + # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input. + # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + "constants": { + **constants, + **dict.fromkeys(equal_to_1_arg_idx, 1), + }, + "configs": [ + config_of( + signature, + indices=non_constant_indices, + ) + ], + } + + # Distinguish between different functions using function id + cache_key: List[Any] = [id(kernel.fn)] + if len(configs) > 0: + for arg in kwargs.values(): + # We need to key on non tensor arg only in autotune mode + if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)): + cache_key.append(arg) + cache_key.append(str(triton_meta)) + cache_key = tuple(cache_key) + + if cache_key in self.user_defined_kernel_cache: + return self.user_defined_kernel_cache[cache_key] + + name = f"{original_name}_{len(self.user_defined_kernel_cache)}" + # Add to the cache for the next use + self.user_defined_kernel_cache[cache_key] = (name, triton_meta) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") + + from .triton import gen_common_triton_imports, TritonKernel + + compile_wrapper.splice(gen_common_triton_imports()) + + inductor_meta = { + "kernel_name": name, + **TritonKernel.inductor_meta_common(), + } + + configs = [ + { + "kwargs": config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + } + for config in configs + ] + + compile_wrapper.splice( + f""" + @triton_heuristics.user_autotune( + configs={configs!r}, + inductor_meta={inductor_meta!r}, + triton_meta={triton_meta!r}, + filename=__file__, + custom_kernel=True, + ) + @triton.jit + """ + ) + 
compile_wrapper.splice(kernel.src, strip=True) + + # Also include any possible kernel being called indirectly + from triton import JITFunction # type: ignore[name-defined, attr-defined] + from triton.language import constexpr # type: ignore[name-defined] + + # global constexpr vars handled above + symbols_included = {original_name} + + def traverse(cur_kernel): + # here we extract the unqualified names (i.e., not attributes and + # without prepended module name) loaded in the kernel code, which + # are matched with the co_names and __globals__ below to codegen + # the respective imports necessary for the kernel compilation + unqualified_loads = { + inst.argval + for inst in dis.Bytecode(cur_kernel.fn) + if inst.opname == "LOAD_GLOBAL" + } + global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {}) + for symbol_name in cur_kernel.fn.__code__.co_names: + if symbol_name in symbols_included: + continue + if symbol_name in cur_kernel.fn.__globals__: + symbol = cur_kernel.fn.__globals__[symbol_name] + if isinstance(symbol, JITFunction): + compile_wrapper.newline() + compile_wrapper.writeline("@triton.jit") + compile_wrapper.splice(symbol.src, strip=True) + symbols_included.add(symbol_name) + traverse(symbol) + elif isinstance(symbol, (int, str, bool, constexpr)): + compile_wrapper.newline() + if isinstance(symbol, constexpr): + symbol_str = f"tl.constexpr({symbol.value!r})" + else: + symbol_str = f"{symbol!r}" + if annotation := global_annotations.get(symbol_name): + annotion_code = "" + if isinstance(annotation, type): + annotation_code = ( + f": {annotation.__module__}.{annotation.__name__}" + ) + else: + annotation_code = f": {annotation!r}" + compile_wrapper.writeline( + f"{symbol_name}{annotation_code} = {symbol_str}" + ) + else: + compile_wrapper.writeline(f"{symbol_name} = {symbol!r}") + symbols_included.add(symbol_name) + elif ( + symbol_name in unqualified_loads + and symbol_name != "tl" # already imported + and hasattr(symbol, "__module__") + # 
only codegen imports from triton; JITFunctions + # imported from other modules will be codegened + # in the separate branch above + and symbol.__module__.startswith("triton") + ): + # a global symbol imported from triton is referenced + # without module qualification (i.e., `store` instead + # of `tl.store`): need to codegen an import + compile_wrapper.writeline( + f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}" + ) + symbols_included.add(symbol_name) + + traverse(kernel) + + current_device = V.graph.scheduler.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + _, lineno = inspect.getsourcelines(kernel.fn) + srcfile = inspect.getsourcefile(kernel.fn) + metadata = f"# Original path: {srcfile}:{lineno}" + self.define_kernel( + name, + compile_wrapper.getvalue(), + metadata, + ) + return name, triton_meta + + def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None): + expr = f"{kernel_name}_{tree.prefix}numel" + if suffix is not None: + expr += f"_{suffix}" + if (expr, V.graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((expr, V.graph)) + self.writeline( + f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}" + ) + else: + self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}") + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. 
+ return SymbolicCallArg(expr, tree.numel) + + def generate_workspace_allocation(self, nbytes, device, zero_fill): + line = self.make_allocation( + "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,) + ) + self.writeline(line) + if zero_fill: + self.writeline(f"workspace.zero_(){self.ending}") + + def wrap_kernel_call(self, name, call_args): + return f"{name}({', '.join(call_args)}){self.ending}" + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline("from torch.profiler import record_function") + self.wrapper_call.writeline( + f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):" + ) + stack.enter_context(self.wrapper_call.indent()) + + def generate_start_graph(self): + self.wrapper_call.writeline("start_graph()") + + def generate_end_graph(self): + self.wrapper_call.writeline(f"end_graph({config.profile_bandwidth_output!r})") + + def generate_reset_kernel_saved_flags(self): + self.wrapper_call.splice( + f""" + for kernel in globals().values(): + if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): + kernel.cuda_kernel_saved = False + """ + ) + + def generate_save_uncompiled_kernels(self): + """ + Precompile and save the CUBINs of the Triton kernels that haven't + been precompiled and saved as a side effect of running the generated + JIT model (Python wrapper). This can happen when the model contains + control flow: only one pass through the control flow operators covers + the kernels that are saved, the remaining kernels are not launched, + hence not saved. The main purpose of this codegen is to compile and + save the Triton kernels outside the active control flow path for + subsequent AOTInductor code generation and compilation. 
+ """ + self.wrapper_call.splice( + f""" + for kernel in globals().values(): + if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): + if not kernel.cuda_kernel_saved: + if len(kernel.launchers) == 0: + kernel.precompile() + kernel.save_gpu_kernel( + grid=(0, 0, 0), # use dummy grid + stream="stream", # use dummy stream + launcher=kernel.launchers[0], + ) + """ + ) + + def generate_default_grid( + self, + kernel_name: str, + grid: List[Any], + cuda: bool = True, + grid_callable: Optional[Callable[..., Any]] = None, + **grid_extra_kwags, + ): + return grid + + def prepare_triton_kernel_call(self, device_index, call_args): + def wrap_arg(arg): + if isinstance(arg, str): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg + elif isinstance(arg, (int, float, bool, SymbolicCallArg)): + return str(arg) + else: + return self.expr_printer(V.graph.sizevars.simplify(arg)) + + call_args = [wrap_arg(arg) for arg in call_args] + + if device_index is None: + current_device = V.graph.scheduler.get_current_device_or_throw() + device_index = current_device.index + + return device_index, call_args + + def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): + if isinstance(arg_type, torch_dtype): + if V.graph.try_get_buffer(arg) is not None: + buf_name = arg + buf = V.graph.get_buffer(arg) + else: + assert ( + raw_arg is not None + ), "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" + buf_name = f"tmp_arg_{index}" + buf = raw_arg + + size = V.graph.sizevars.size_hints( + buf.get_size(), + fallback=config.unbacked_symint_fallback, + ) + stride = V.graph.sizevars.size_hints( + buf.get_stride(), + fallback=config.unbacked_symint_fallback, + ) + device = buf.get_device() + dtype = buf.get_dtype() + offset = V.graph.sizevars.size_hint( + buf.layout.offset, + fallback=config.unbacked_symint_fallback, + ) + value = 
f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset})" + self.kernel_autotune_calls.writeline(f"{buf_name} = {value}") + return buf_name + elif issubclass(arg_type, sympy.Basic) or isinstance(arg, SymbolicCallArg): + # arg is a symbol or symbolic expression + if isinstance(arg, str): + if arg in self._meta_vars: + return arg + if raw_arg is None: + return "None" + arg = raw_arg + if isinstance(arg, SymbolicCallArg): + arg = arg.inner_expr + if arg in V.graph.sizevars.inv_precomputed_replacements: + arg = V.graph.sizevars.inv_precomputed_replacements[arg] + return str( + V.graph.sizevars.size_hint( + arg, + fallback=config.unbacked_symint_fallback, + ) + ) + elif isinstance(arg, (str, int, float, bool)): + return str(arg) + elif isinstance(arg, list): + return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]" + else: + raise NotImplementedError(f"Unsupported type {type(arg)}") + + def _grid_dim_str(self, grid_per_dim): + if isinstance(grid_per_dim, list): + return ( + "[" + ", ".join(self._grid_dim_str(item) for item in grid_per_dim) + "]" + ) + else: + return pexpr(grid_per_dim) + + def generate_kernel_call( + self, + kernel_name, + call_args, + grid=None, + device_index=None, + cuda=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", + ): + """ + Generates kernel call code. + + cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. 
+ """ + if cuda: + device_index, call_args_str = self.prepare_triton_kernel_call( + device_index, call_args + ) + call_args_str = ", ".join(call_args_str) + stream_name = self.write_get_raw_stream(device_index, V.graph) + if triton: + self.write_triton_header_once() + if grid is None: + grid_str = grid_fn + else: + grid_str = ", ".join(self._grid_dim_str(item) for item in grid) + if grid_extra_kwargs: + grid_str = f"{grid_str}, {grid_extra_kwargs}" + grid_str = f"{grid_fn}({grid_str})" + self.writeline( + f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" + ) + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Create example args for autotune in a separate epilogue + assert arg_types is not None and len(call_args) == len( + arg_types + ), "call_args and arg_types do not match" + + tensor_args = {} + all_args = [] + if raw_args is None: + # create a dummy raw_args for uniform behavior in the following loop + raw_args = [None] * len(call_args) + else: + assert len(raw_args) == len( + call_args + ), "call_args and raw_args do not match" + + for i, (arg, arg_type, raw_arg) in enumerate( + zip(call_args, arg_types, raw_args) + ): + key = None + if isinstance(arg, str) and "=" in str(arg): + # arg may be passed in a kwarg style, and then we need to extract its value + key, arg = arg.split("=") + + if isinstance(arg_type, torch_dtype): + if arg not in tensor_args: + arg_str = self.generate_example_arg_value( + arg, arg_type, raw_arg, i + ) + tensor_args[arg] = arg_str + else: + arg_str = tensor_args[arg] + else: + arg_str = self.generate_example_arg_value( + arg, arg_type, raw_arg, i + ) + all_args.append(arg_str if key is None else f"{key}={arg_str}") + + if grid is None: + grid_str = grid_fn + else: + grid_str = ", ".join( + self.generate_example_arg_value(g, type(g)) for g in grid + ) + if grid_extra_kwargs: + grid_str = f"{grid_str}, {grid_extra_kwargs}" + grid_str = 
f"{grid_fn}({grid_str})" + + self.kernel_autotune_calls.writeline( + f"{kernel_name}.run({', '.join(all_args)}, grid={grid_str}, stream={stream_name})" + ) + self.kernel_autotune_calls.writeline( + f"del {', '.join(arg for arg in tensor_args.values())}\n", + ) + self.kernel_autotune_names.add(kernel_name) + else: + stream_ptr = f"c_void_p({stream_name})" + self.writeline( + f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})" + ) + else: + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + + def writeline(self, line): + self.lines.append(line) + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + def enter_context(self, ctx): + self.lines.append(LineContext(ctx)) + + def val_to_arg_str(self, s, type_=None): + from torch.utils._triton import dtype_to_string, has_triton_package + + if has_triton_package(): + import triton + + if isinstance(s, SymTypes): + return pexpr(s.node.expr) + elif isinstance(s, sympy.Expr): + return pexpr(s) + elif isinstance(s, (tuple, list)): + + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self): + return self.ref + + return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s)) + elif isinstance(s, torch._ops.OpOverload): + return _get_qualified_name(s) + elif isinstance(s, (ir.Buffer, ReinterpretView)): + return s.codegen_reference() + elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined] + return dtype_to_string(s) + else: + return repr(s) + + # The following methods are for memory management + def make_buffer_allocation(self, buffer): + device = buffer.get_device() + dtype = buffer.get_dtype() + shape = tuple(buffer.get_size()) + stride = tuple(buffer.get_stride()) + return self.make_allocation(buffer.get_name(), device, dtype, shape, stride) + + def make_allocation(self, name, device, dtype, shape, stride): + if device.type in ("cpu", "cuda", "xpu"): + # optimized path for faster allocations, saving ~2us versus the 
stuff below + return ( + f"{name} = empty_strided_{device.type}(" + f"{self.codegen_shape_tuple(shape)}, " + f"{self.codegen_shape_tuple(stride)}, " + f"{dtype})" + ) + # all other devices: + return ( + f"{name} = empty_strided(" + f"{self.codegen_shape_tuple(shape)}, " + f"{self.codegen_shape_tuple(stride)}, " + f"device='{device.type}', dtype={dtype})" + ) + + def make_tensor_alias(self, new_name, old_name, comment=""): + return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}" + + def make_buffer_free(self, buffer): + return f"del {buffer.get_name()}" + + def make_free_by_names(self, names_to_del: List[str]): + return f"del {', '.join(name for name in names_to_del)}" + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse" + + def make_buffer_reuse(self, old: ir.Buffer, new: ir.Buffer, delete_old: bool): + assert old.get_dtype() == new.get_dtype() + old_name = old.get_name() + new_name = new.get_name() + del_line = ";" + if old_name not in V.graph.get_output_names() and delete_old: + del_line = f"; {self.make_buffer_free(old)}" + + if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): + if old_name in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + + reinterpret_view = self.codegen_reinterpret_view( + old, new.get_size(), new.get_stride(), 0, self.wrapper_call + ) + if reinterpret_view in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + return f"{self.declare_maybe_reference}{new_name} = {reinterpret_view}{del_line} {self.comment} reuse" + + def codegen_deferred_allocation(self, name, layout): + self.writeline( + DeferredLine( + name, + f"{self.declare_maybe_reference}{name} = {layout.view.codegen_reference()}{self.ending} " + f"{self.comment} 
alias", + ) + ) + + def codegen_allocation(self, buffer: ir.Buffer): + name = buffer.get_name() + + if name in V.graph.removed_buffers or name in self.allocated: + return + self.allocated.add(name) + if isinstance( + buffer.get_defining_op(), + (ir.ExternKernelAlloc, ir.MultiOutput), + ): + return + + layout = buffer.get_layout() + if isinstance(layout, ir.MutationLayoutSHOULDREMOVE): + return + if isinstance(layout, ir.NoneLayout): + return + if isinstance(layout, ir.NonOwningLayout): + assert isinstance( + layout.view, ir.ReinterpretView + ), f"unexpected {type(layout.view)}: {layout.view}" + assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data) + assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data) + self.codegen_allocation(layout.view.data.data) + self.codegen_deferred_allocation(name, layout) + return + + self.writeline(AllocateLine(self, buffer)) + + def codegen_free(self, buffer): + name = buffer.get_name() + + # can be freed but not reused + if isinstance(buffer, ir.InputBuffer): + self.writeline(self.make_buffer_free(buffer)) + return + + if not self.can_reuse(buffer): + return + self.freed.add(name) + + self.writeline(FreeIfNotReusedLine(self, buffer)) + + def can_reuse(self, input_buffer, output_buffer=None): + name = input_buffer.get_name() + return not ( + name in V.graph.removed_buffers + or name in V.graph.graph_inputs + or name in V.graph.constants + or name in V.graph.torchbind_constants + or name in V.graph.never_reuse_buffers + or name in self.freed + ) + + def did_reuse(self, buffer, reused_buffer): + # Check whether a given buffer was reused by a possible reuser in the wrapper codegen + # Can be consulted from inside ir codegen, e.g. 
to determine whether a copy is needed + return ( + buffer.get_name() in self.reuses + and self.reuses[buffer.get_name()] == reused_buffer.get_name() + ) + + def codegen_inplace_reuse(self, input_buffer: ir.Buffer, output_buffer: ir.Buffer): + assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer) + self.codegen_allocation(input_buffer) + self.freed.add(input_buffer.get_name()) + self.allocated.add(output_buffer.get_name()) + self.reuses[output_buffer.get_name()] = input_buffer.get_name() + self.writeline(ReuseLine(self, input_buffer, output_buffer)) + + def codegen_unbacked_symbol_decl(self, symbol): + name = str(symbol) + if name in self.unbacked_symbol_decls: + return name + else: + # When in CppWrapperCpu, we should only generate the declaration once + self.unbacked_symbol_decls.add(name) + return self.declare + name + + def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): + for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs): + self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}") + + def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + self.writeline( + f"{outer_output} = {inner_output.codegen_reference()}{self.ending}" + ) + + def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs) + finally: + self.pop_codegened_graph() + + def codegen_conditional(self, conditional): + name = conditional.get_name() + + self.writeline(f"{name} = [None] * {len(conditional.outputs)}") 
+ + outer_inputs = [buf.codegen_reference() for buf in conditional.operands] + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + + predicate = conditional.predicate.codegen_reference() + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # move the Tensor predicate to host + predicate = f"{predicate}.item()" + + self.writeline(f"{name} = [None] * {len(conditional.outputs)}") + self.writeline(f"if {predicate}:") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("else:") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + + def codegen_while_loop(self, while_loop): + name = while_loop.get_name() + outer_carried_inputs = [ + buf.codegen_reference() for buf in while_loop.carried_inputs + ] + outer_additional_inputs = [ + buf.codegen_reference() for buf in while_loop.additional_inputs + ] + + self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}") + for i, inp in enumerate(outer_carried_inputs): + # set the initial state before the loop + self.writeline(f"{name}[{i}] = {inp}") + + cond_outer_inputs = [ + *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))], + *outer_additional_inputs, + ] + cond_outer_outputs = [f"{name}_cond_result"] + body_outer_inputs = list( + cond_outer_inputs + ) # same inputs for cond_fn and body_fn + # Carry over the state from body_fn. Note: We only carry over + # the carried_inputs part of the inputs, the additional ones + # are passed in as they're before. 
+ body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)] + + self.writeline("while True:") + self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph)) + self.codegen_subgraph( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + self.writeline( + f"if not {cond_outer_outputs[0]}.item(): break" + ) # condition doesn't hold + self.writeline(ExitSubgraphLine(self)) + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + self.codegen_subgraph( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) + + @staticmethod + def statically_known_int_or_none(x): + try: + if getattr(x, "free_symbols", None): + # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but + # the actual codegen will still generate the full expression here. + return None + if isinstance(x, int): + return x + val = V.graph._shape_env._maybe_evaluate_static(x) + return int(val) + except Exception: + return None + + @staticmethod + def statically_known_list_of_ints_or_none(lst): + result = [] + for x in lst: + num = WrapperCodeGen.statically_known_int_or_none(x) + if num is None: + return None + result.append(num) + return result + + @staticmethod + def is_statically_known_list_of_ints(lst): + return WrapperCodeGen.statically_known_list_of_ints_or_none(lst) is not None + + @staticmethod + def static_shape_for_buffer_or_none(buffer): + return WrapperCodeGen.statically_known_list_of_ints_or_none(buffer.get_size()) + + @staticmethod + def can_prove_buffer_has_static_shape(buffer): + return WrapperCodeGen.static_shape_for_buffer_or_none(buffer) is not None diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__main__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6503b7901a5a1b2fb76f03c9550b562b623a88 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__main__.py @@ -0,0 +1,47 @@ +# mypy: allow-untyped-defs +import argparse +import logging +import os +import sys + +from torch._inductor.async_compile import pre_fork_setup +from torch._inductor.compile_worker.subproc_pool import SubprocMain +from torch._inductor.compile_worker.watchdog import _async_compile_initializer +from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path + + +log = logging.getLogger(__name__) + +_set_triton_ptxas_path() + +try: + import triton + + assert triton is not None # preload in parent +except ImportError: + pass + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument("--workers", type=int) + parser.add_argument("--parent", type=int) + parser.add_argument("--read-fd", type=int) + parser.add_argument("--write-fd", type=int) + args = parser.parse_args() + if os.getppid() != args.parent: + sys.exit(0) + read_fd = os.fdopen(args.read_fd, "rb") + write_fd = os.fdopen(args.write_fd, "wb") + + pre_fork_setup() + + _async_compile_initializer(args.parent) + SubprocMain(args.workers, read_fd, write_fd).main() + except Exception: + log.exception("Uncaught exception in compile_worker subprocess") + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14119b58e3d7ccf3051a99a96f8207ee376bd801 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..432b4e00b3f7e523d77830d6fa31259832d85dca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8ecce4e055a9243c8ef81f589985806cb33483f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/watchdog.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/watchdog.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e0a0367147b29c4acf9c10c7a7e56a4ad411b61 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/__pycache__/watchdog.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/subproc_pool.py b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/subproc_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..77938dc2e44ddd4fa70034daf520e00a659b57b5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/subproc_pool.py @@ -0,0 +1,314 @@ +# mypy: allow-untyped-defs +import functools +import 
itertools +import logging +import multiprocessing +import os +import pickle +import struct +import subprocess +import sys +import threading +import traceback +import typing +from concurrent.futures import Future, ProcessPoolExecutor +from concurrent.futures.process import BrokenProcessPool +from typing import Any, Callable, Dict + +from torch._inductor import config +from torch._inductor.compile_worker.watchdog import _async_compile_initializer + + +log = logging.getLogger(__name__) + + +def _pack_msg(job_id, length): + return struct.pack("nn", job_id, length) + + +def _unpack_msg(data): + if not data: + return -1, -1 + return struct.unpack("nn", data) + + +msg_bytes = len(_pack_msg(0, 0)) + + +def _send_msg(write_pipe, job_id, job_data=b""): + length = len(job_data) + write_pipe.write(_pack_msg(job_id, length)) + if length > 0: + write_pipe.write(job_data) + write_pipe.flush() + + +def _recv_msg(read_pipe): + job_id, length = _unpack_msg(read_pipe.read(msg_bytes)) + data = read_pipe.read(length) if length > 0 else b"" + return job_id, data + + +def _get_ld_library_path(): + path = os.environ.get("LD_LIBRARY_PATH", "") + if config.is_fbcode(): + from libfb.py.parutil import get_runtime_path + + runtime_path = get_runtime_path() + if runtime_path: + lib_path = os.path.join(runtime_path, "runtime", "lib") + path = os.pathsep.join([lib_path, path]) if path else lib_path + + return path + + +class _SubprocExceptionInfo: + """ + Carries exception info from subprocesses across the wire. traceback + objects are not pickleable, so we store the trace as a string and + use it for the message in the exception thrown in the main process. + """ + + def __init__(self, details) -> None: + self.details = details + + +class SubprocException(Exception): + """ + Thrown when a job in a subprocess raises an Exception. 
+ """ + + def __init__(self, details) -> None: + super().__init__(f"An exception occurred in a subprocess:\n\n{details}") + + +class SubprocPool: + """ + Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in + a subprocess.Popen() to try to avoid issues with forking/spawning + """ + + def __init__(self, nprocs: int) -> None: + entry = os.path.join(os.path.dirname(__file__), "__main__.py") + + subproc_read_fd, write_fd = os.pipe() + read_fd, subproc_write_fd = os.pipe() + self.write_pipe = os.fdopen(write_fd, "wb") + self.read_pipe = os.fdopen(read_fd, "rb") + + cmd = [ + sys.executable, + entry, + f"--workers={nprocs}", + f"--parent={os.getpid()}", + f"--read-fd={str(subproc_read_fd)}", + f"--write-fd={str(subproc_write_fd)}", + ] + self.process = subprocess.Popen( + cmd, + env={ + **os.environ, + # We need to set the PYTHONPATH so the subprocess can find torch. + "PYTHONPATH": os.pathsep.join(sys.path), + # We don't want to re-warm the pool when the subprocess imports + # torch._inductor.codecache since the warming process is what + # creates the SubprocPool in the first place. + "TORCH_WARM_POOL": "0", + # Some internal usages need a modified LD_LIBRARY_PATH. + "LD_LIBRARY_PATH": _get_ld_library_path(), + }, + pass_fds=(subproc_read_fd, subproc_write_fd), + ) + self.write_lock = threading.Lock() + self.read_thread = threading.Thread(target=self._read_thread, daemon=True) + + self.futures_lock = threading.Lock() + self.pending_futures: Dict[int, Future[Any]] = {} + self.job_id_count = itertools.count() + + self.running = True + + # Start thread last to ensure all member variables are initialized + # before any access. 
+ self.read_thread.start() + + def submit(self, job_fn: Callable[..., Any], *args): + if args: + job_fn = functools.partial(job_fn, *args) + job_data = pickle.dumps(job_fn, pickle.HIGHEST_PROTOCOL) + future: Future[Any] + with self.futures_lock: + job_id = next(self.job_id_count) + self.pending_futures[job_id] = future = Future() + future.set_running_or_notify_cancel() + with self.write_lock: + if not self.running: + raise RuntimeError("submit() on closed pool") + _send_msg(self.write_pipe, job_id, job_data) + return future + + def _read_thread(self): + try: + while True: + job_id, data = _recv_msg(self.read_pipe) + if job_id < 0: + if self.running: + log.warning("SubprocPool unclean exit") + self.read_pipe.close() + return + result = pickle.loads(data) + with self.futures_lock: + if not self.running: + return + if isinstance(result, _SubprocExceptionInfo): + # An exception occurred in the submitted job + self.pending_futures[job_id].set_exception( + SubprocException(result.details) + ) + elif isinstance(result, Exception): + # An exception occurred in some of our subprocess machinery. 
+ self.pending_futures[job_id].set_exception(result) + else: + self.pending_futures[job_id].set_result(result) + del self.pending_futures[job_id] + except Exception: + log.exception("failure in SubprocPool._read_thread") + + def shutdown(self): + try: + with self.write_lock: + if not self.running: + return + self.running = False + _send_msg(self.write_pipe, -1) + self.write_pipe.close() + self.process.wait(10) + except OSError as e: + log.warning("Ignored OSError in pool shutdown: %s", e) + finally: + with self.futures_lock: + for future in self.pending_futures.values(): + if not future.cancel(): + future.set_exception(RuntimeError("SubprocPool closed")) + self.pending_futures.clear() + + +class SubprocMain: + """Communicates with a SubprocPool in the parent process, called by __main__.py""" + + def __init__(self, nprocs, read_pipe, write_pipe) -> None: + self.read_pipe = read_pipe + self.write_pipe = write_pipe + self.write_lock = threading.Lock() + self.nprocs = nprocs + self.pool = self._new_pool(nprocs, True) + self.running = True + + def _new_pool(self, nprocs, warm): + pool = ProcessPoolExecutor( + nprocs, + mp_context=multiprocessing.get_context("fork"), + initializer=functools.partial(_async_compile_initializer, os.getpid()), + ) + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + if warm: + _warm_process_pool(pool, nprocs) + return pool + + def main(self): + while True: + job_id, data = _recv_msg(self.read_pipe) + if job_id < 0: + return self._shutdown() + self.submit(job_id, data) + + def _shutdown(self): + with self.write_lock: + self.running = False + try: + _send_msg(self.write_pipe, -1) + self.write_pipe.close() + except BrokenPipeError: + pass # parent process already shutdown + self.read_pipe.close() + self.pool.shutdown() + + def submit(self, job_id, data): + while self.running: + try: + self._submit_inner(job_id, data) + return + except BrokenProcessPool: + # If any subprocess in the pool crashes, we get a 
BrokenProcessPool + # exception and the whole pool becomes unusable. Handle crashes by + # recreating the pool and resubmitting. + self.pool = self._new_pool(self.nprocs, False) + + def _submit_inner(self, job_id, data): + future = self.pool.submit(functools.partial(SubprocMain.do_job, data)) + + def callback(_): + if not self.running: + return + try: + result = future.result() + except Exception as e: + log.exception("Error in subprocess") + result = pickle.dumps(e, pickle.HIGHEST_PROTOCOL) + assert isinstance(result, bytes) + with self.write_lock: + if self.running: + _send_msg(self.write_pipe, job_id, result) + + future.add_done_callback(callback) + + @staticmethod + def do_job(data): + # do the pickle/unpickle in the sub-subproc + job = pickle.loads(data) + try: + result = job() + except Exception as e: + result = _SubprocExceptionInfo(traceback.format_exc()) + return pickle.dumps(result, pickle.HIGHEST_PROTOCOL) + + +AnyPool = typing.Union[ProcessPoolExecutor, SubprocPool] + + +def _warm_process_pool(pool: AnyPool, n: int): + if isinstance(pool, SubprocPool): + return # no need + assert isinstance(pool, ProcessPoolExecutor) + + # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the + # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread. + + # Examples: + # A simple x + x + x script: 10ms seconds in the middle of the program, 2ms at startup + # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program , 3ms at startup + + # So we want to start the workers early when it is still cheap, and also to allow the workers to get + # ready before we have work for them. + + # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle. + # But if we waited until then fork time will be long and we will be waiting for the processes to initialize. 
+ + # We force them to start here with some YOLOing of the internal methods. + + # TODO(masnesral): Are these still relevant? + if hasattr(pool, "_start_queue_management_thread"): + pool._start_queue_management_thread() + else: + for _ in range(n): + pool._adjust_process_count() + if hasattr(pool, "_start_executor_manager_thread"): + pool._start_executor_manager_thread() + + +class TestException(RuntimeError): + pass + + +def raise_testexc(): + raise TestException diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/watchdog.py b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/watchdog.py new file mode 100644 index 0000000000000000000000000000000000000000..f3956e1272e9b6997a10d8cd8f354094701df3a0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_worker/watchdog.py @@ -0,0 +1,38 @@ +# mypy: allow-untyped-defs +import os +import signal +from threading import Thread +from time import sleep +from typing import Optional + + +# If this process dies abnormally (e.g. segfault) +# it will not shut down the workers. Instead, +# the workers will have their parent reassigned to the +# init process. This launches a separate thread to +# watch for the worker getting reassigned, +# and cleans it up in this case. +# +# This function cannot be an inner function since otherwise mp_context="spawn" would +# not work for ProcessPoolExecutor since inner functions cannot be pickled. +def _async_compile_initializer(orig_ppid) -> None: + def run() -> None: + while True: + sleep(1) + if orig_ppid != os.getppid(): + os.kill(os.getpid(), signal.SIGKILL) + + global _watchdog_thread, _original_parent + _original_parent = orig_ppid + _watchdog_thread = Thread(target=run, daemon=True) + _watchdog_thread.start() + # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam. 
+ signal.signal(signal.SIGINT, signal.SIG_IGN) + + +_watchdog_thread: Optional[Thread] = None +_original_parent: Optional[int] = None + + +def has_parent_changed() -> bool: + return _original_parent != os.getppid() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/cpp_builder.py b/.venv/lib/python3.11/site-packages/torch/_inductor/cpp_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..95a0bff86fd8a4f1963d279d096b82dd33934ab2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/cpp_builder.py @@ -0,0 +1,1511 @@ +# This CPP builder is designed to support both Windows and Linux OS. +# The design document please check this RFC: https://github.com/pytorch/pytorch/issues/124245 + +import copy +import errno +import functools +import json +import logging +import os +import platform +import re +import shlex +import shutil +import subprocess +import sys +import sysconfig +import warnings +from ctypes import cdll +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union + +import torch +from torch._dynamo.utils import dynamo_timed +from torch._inductor import config, exc +from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.torch_version import TorchVersion + + +if config.is_fbcode(): + from triton.fb import build_paths # noqa: F401 + + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: + pass + + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: + pass + + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: + pass + + def use_global_cache() -> bool: + return False + + +# Windows need setup a temp dir to store .obj files. 
+_BUILD_TEMP_DIR = "CxxBuild" + +# initialize variables for compilation +_IS_LINUX = sys.platform.startswith("linux") +_IS_MACOS = sys.platform.startswith("darwin") +_IS_WINDOWS = sys.platform == "win32" + +SUBPROCESS_DECODE_ARGS = ("utf-8",) if _IS_WINDOWS else () + +log = logging.getLogger(__name__) + + +# =============================== toolchain =============================== +@functools.lru_cache(1) +def cpp_compiler_search(search: str) -> str: + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT + + for cxx in search: + try: + if cxx is None: + # gxx package is only available for Linux + # according to https://anaconda.org/conda-forge/gxx/ + if sys.platform != "linux": + continue + # Do not install GXX by default + if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"): + continue + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock( + os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT + ) + with lock: + cxx = install_gcc_via_conda() + subprocess.check_output([cxx, "--version"]) + return cxx + except (subprocess.SubprocessError, FileNotFoundError, ImportError): + continue + raise exc.InvalidCxxCompiler + + +def install_gcc_via_conda() -> str: + """On older systems, this is a quick way to get a modern compiler""" + prefix = os.path.join(cache_dir(), "gcc") + cxx_path = os.path.join(prefix, "bin", "g++") + if not os.path.exists(cxx_path): + log.info("Downloading GCC via conda") + conda = os.environ.get("CONDA_EXE", "conda") + if conda is None: + conda = shutil.which("conda") + if conda is not None: + subprocess.check_call( + [ + conda, + "create", + f"--prefix={prefix}", + "--channel=conda-forge", + "--quiet", + "-y", + "python=3.8", + "gxx", + ], + stdout=subprocess.PIPE, + ) + return cxx_path + + +@functools.lru_cache(None) +def check_compiler_exist_windows(compiler: str) -> None: + """ + Check if compiler is ready, in case end user not activate MSVC environment. 
+ """ + try: + output_msg = ( + subprocess.check_output([compiler, "/help"], stderr=subprocess.STDOUT) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + except FileNotFoundError as exc: + raise RuntimeError(f"Compiler: {compiler} is not found.") from exc + except subprocess.SubprocessError: + # Expected that some compiler(clang, clang++) is exist, but they not support `/help` args. + pass + + +def get_cpp_compiler() -> str: + if _IS_WINDOWS: + compiler = os.environ.get("CXX", "cl") + check_compiler_exist_windows(compiler) + else: + if config.is_fbcode(): + return ( + build_paths.cc() if torch.version.hip is None else build_paths.clang() + ) + if isinstance(config.cpp.cxx, (list, tuple)): + search = tuple(config.cpp.cxx) + else: + search = (config.cpp.cxx,) + compiler = cpp_compiler_search(search) + return compiler + + +@functools.lru_cache(None) +def _is_apple_clang(cpp_compiler: str) -> bool: + version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") + return "Apple" in version_string.splitlines()[0] + + +def _is_clang(cpp_compiler: str) -> bool: + # Mac OS apple clang maybe named as gcc, need check compiler info. + if sys.platform == "darwin": + return _is_apple_clang(cpp_compiler) + elif _IS_WINDOWS: + # clang suite have many compilers, and only clang-cl is supported. + if re.search(r"((clang$)|(clang\+\+$))", cpp_compiler): + raise RuntimeError( + "Please use clang-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." 
+ ) + return bool(re.search(r"(clang-cl)", cpp_compiler)) + return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) + + +def _is_gcc(cpp_compiler: str) -> bool: + if sys.platform == "darwin" and _is_apple_clang(cpp_compiler): + return False + return bool(re.search(r"(gcc|g\+\+)", cpp_compiler)) + + +@functools.lru_cache(None) +def _is_msvc_cl(cpp_compiler: str) -> bool: + if not _IS_WINDOWS: + return False + + try: + output_msg = ( + subprocess.check_output([cpp_compiler, "/help"], stderr=subprocess.STDOUT) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + return "Microsoft" in output_msg.splitlines()[0] + except FileNotFoundError as exc: + return False + + return False + + +@functools.lru_cache(None) +def _is_intel_compiler(cpp_compiler: str) -> bool: + def _check_minimal_version(compiler_version: TorchVersion) -> None: + """ + On Windows: early version icx has `-print-file-name` issue, and can't preload correctly for inductor. + """ + min_version = "2024.2.1" if _IS_WINDOWS else "0.0.0" + if compiler_version < TorchVersion(min_version): + raise RuntimeError( + f"Intel Compiler error: less than minimal version {min_version}." + ) + + try: + output_msg = ( + subprocess.check_output( + [cpp_compiler, "--version"], stderr=subprocess.DEVNULL + ) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + is_intel_compiler = "Intel" in output_msg.splitlines()[0] + if is_intel_compiler: + if _IS_WINDOWS: + if re.search(r"((icx$)|(icx-cc$))", cpp_compiler): + raise RuntimeError( + "Please use icx-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + + # Version check + icx_ver_search = re.search(r"(\d+[.]\d+[.]\d+[.]\d+)", output_msg) + if icx_ver_search is not None: + icx_ver = icx_ver_search.group(1) + _check_minimal_version(TorchVersion(icx_ver)) + + return is_intel_compiler + except FileNotFoundError as exc: + return False + except subprocess.SubprocessError: + # --version args not support. 
+ return False + + return False + + +@functools.lru_cache(None) +def is_gcc() -> bool: + return _is_gcc(get_cpp_compiler()) + + +@functools.lru_cache(None) +def is_clang() -> bool: + return _is_clang(get_cpp_compiler()) + + +@functools.lru_cache(None) +def is_intel_compiler() -> bool: + return _is_intel_compiler(get_cpp_compiler()) + + +@functools.lru_cache(None) +def is_apple_clang() -> bool: + return _is_apple_clang(get_cpp_compiler()) + + +@functools.lru_cache(None) +def is_msvc_cl() -> bool: + return _is_msvc_cl(get_cpp_compiler()) + + +def get_compiler_version_info(compiler: str) -> str: + env = os.environ.copy() + env["LC_ALL"] = "C" # Don't localize output + try: + version_string = subprocess.check_output( + [compiler, "-v"], stderr=subprocess.STDOUT, env=env + ).decode(*SUBPROCESS_DECODE_ARGS) + except Exception as e: + try: + version_string = subprocess.check_output( + [compiler, "--version"], stderr=subprocess.STDOUT, env=env + ).decode(*SUBPROCESS_DECODE_ARGS) + except Exception as e: + return "" + # Mutiple lines to one line string. 
+ version_string = version_string.replace("\r", "_") + version_string = version_string.replace("\n", "_") + return version_string + + +# =============================== cpp builder =============================== +def _append_list(dest_list: List[str], src_list: List[str]) -> None: + for item in src_list: + dest_list.append(copy.deepcopy(item)) + + +def _remove_duplication_in_list(orig_list: List[str]) -> List[str]: + new_list: List[str] = [] + for item in orig_list: + if item not in new_list: + new_list.append(item) + return new_list + + +def _create_if_dir_not_exist(path_dir: str) -> None: + if not os.path.exists(path_dir): + try: + Path(path_dir).mkdir(parents=True, exist_ok=True) + except OSError as exc: # Guard against race condition + if exc.errno != errno.EEXIST: + raise RuntimeError( # noqa: TRY200 (Use `raise from`) + f"Fail to create path {path_dir}" + ) + + +def _remove_dir(path_dir: str) -> None: + if os.path.exists(path_dir): + for root, dirs, files in os.walk(path_dir, topdown=False): + for name in files: + file_path = os.path.join(root, name) + os.remove(file_path) + for name in dirs: + dir_path = os.path.join(root, name) + os.rmdir(dir_path) + os.rmdir(path_dir) + + +def _run_compile_cmd(cmd_line: str, cwd: str) -> bytes: + cmd = shlex.split(cmd_line) + try: + status = subprocess.check_output(args=cmd, cwd=cwd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8") + openmp_problem = "'omp.h' file not found" in output or "libomp" in output + if openmp_problem and sys.platform == "darwin": + instruction = ( + "\n\nOpenMP support not found. 
Please try one of the following solutions:\n" + "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " + "that has builtin OpenMP support;\n" + "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" + "(3) install libomp via brew: `brew install libomp`;\n" + "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" + " with `include/omp.h` under it." + ) + output += instruction + raise exc.CppCompileError(cmd, output) from e + return status + + +def run_compile_cmd(cmd_line: str, cwd: str) -> bytes: + with dynamo_timed("compile_file"): + return _run_compile_cmd(cmd_line, cwd) + + +def normalize_path_separator(orig_path: str) -> str: + if _IS_WINDOWS: + return orig_path.replace(os.sep, "/") + return orig_path + + +class BuildOptionsBase: + """ + This is the Base class for store cxx build options, as a template. + Acturally, to build a cxx shared library. We just need to select a compiler + and maintains the suitable args. + """ + + def __init__( + self, + compiler: str = "", + definitions: Optional[List[str]] = None, + include_dirs: Optional[List[str]] = None, + cflags: Optional[List[str]] = None, + ldflags: Optional[List[str]] = None, + libraries_dirs: Optional[List[str]] = None, + libraries: Optional[List[str]] = None, + passthrough_args: Optional[List[str]] = None, + aot_mode: bool = False, + use_absolute_path: bool = False, + compile_only: bool = False, + ) -> None: + self._compiler = compiler + self._definations: List[str] = definitions or [] + self._include_dirs: List[str] = include_dirs or [] + self._cflags: List[str] = cflags or [] + self._ldflags: List[str] = ldflags or [] + self._libraries_dirs: List[str] = libraries_dirs or [] + self._libraries: List[str] = libraries or [] + # Some args is hard to abstract to OS compatable, passthough it directly. 
+ self._passthough_args: List[str] = passthrough_args or [] + + self._aot_mode: bool = aot_mode + self._use_absolute_path: bool = use_absolute_path + self._compile_only: bool = compile_only + + def _process_compile_only_options(self) -> None: + if self._compile_only: + self._libraries_dirs = [] + self._libraries = [] + + def _remove_duplicate_options(self) -> None: + self._definations = _remove_duplication_in_list(self._definations) + self._include_dirs = _remove_duplication_in_list(self._include_dirs) + self._cflags = _remove_duplication_in_list(self._cflags) + self._ldflags = _remove_duplication_in_list(self._ldflags) + self._libraries_dirs = _remove_duplication_in_list(self._libraries_dirs) + self._libraries = _remove_duplication_in_list(self._libraries) + self._passthough_args = _remove_duplication_in_list(self._passthough_args) + + def _finalize_options(self) -> None: + self._process_compile_only_options + self._remove_duplicate_options + + def get_compiler(self) -> str: + return self._compiler + + def get_definations(self) -> List[str]: + return self._definations + + def get_include_dirs(self) -> List[str]: + return self._include_dirs + + def get_cflags(self) -> List[str]: + return self._cflags + + def get_ldflags(self) -> List[str]: + return self._ldflags + + def get_libraries_dirs(self) -> List[str]: + return self._libraries_dirs + + def get_libraries(self) -> List[str]: + return self._libraries + + def get_passthough_args(self) -> List[str]: + return self._passthough_args + + def get_aot_mode(self) -> bool: + return self._aot_mode + + def get_use_absolute_path(self) -> bool: + return self._use_absolute_path + + def get_compile_only(self) -> bool: + return self._compile_only + + def save_flags_to_file(self, file: str) -> None: + attrs = { + "compiler": self.get_compiler(), + "definitions": self.get_definations(), + "include_dirs": self.get_include_dirs(), + "cflags": self.get_cflags(), + "ldflags": self.get_ldflags(), + "libraries_dirs": 
self.get_libraries_dirs(), + "libraries": self.get_libraries(), + "passthrough_args": self.get_passthough_args(), + "aot_mode": self.get_aot_mode(), + "use_absolute_path": self.get_use_absolute_path(), + "compile_only": self.get_compile_only(), + } + + with open(file, "w") as f: + json.dump(attrs, f) + + +def _get_warning_all_cflag(warning_all: bool = True) -> List[str]: + if not _IS_WINDOWS: + return ["Wall"] if warning_all else [] + else: + return [] + + +def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]: + if _IS_WINDOWS: + """ + On Windows, only c++20 can support `std::enable_if_t`. + Ref: https://learn.microsoft.com/en-us/cpp/overview/cpp-conformance-improvements-2019?view=msvc-170#checking-for-abstract-class-types # noqa: B950 + Note: + Only setup c++20 for Windows inductor. I tried to upgrade all project to c++20, but it is failed: + https://github.com/pytorch/pytorch/pull/131504 + """ + std_num = "c++20" + return [f"std:{std_num}"] + else: + return [f"std={std_num}"] + + +def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]: + if _IS_WINDOWS: + cflags = [ + "wd4819", + "wd4251", + "wd4244", + "wd4267", + "wd4275", + "wd4018", + "wd4190", + "wd4624", + "wd4067", + "wd4068", + "EHsc", + ] + else: + cflags = ["Wno-unused-variable", "Wno-unknown-pragmas"] + if _is_clang(cpp_compiler): + cflags.append("Werror=ignored-optimization-argument") + return cflags + + +def _get_optimization_cflags() -> List[str]: + if _IS_WINDOWS: + return ["O2"] + else: + cflags = ["O0", "g"] if config.aot_inductor.debug_compile else ["O3", "DNDEBUG"] + cflags.append("ffast-math") + cflags.append("fno-finite-math-only") + + if not config.cpp.enable_unsafe_math_opt_flag: + cflags.append("fno-unsafe-math-optimizations") + if not config.cpp.enable_floating_point_contract_flag: + cflags.append("ffp-contract=off") + + if sys.platform != "darwin": + # https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1 + # `-march=native` is 
unrecognized option on M1 + if not config.is_fbcode(): + if platform.machine() == "ppc64le": + cflags.append("mcpu=native") + else: + cflags.append("march=native") + + return cflags + + +def _get_shared_cflag(compile_only: bool) -> List[str]: + if _IS_WINDOWS: + """ + MSVC `/MD` using python `ucrtbase.dll` lib as runtime. + https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170 + """ + SHARED_FLAG = ["DLL", "MD"] + else: + if compile_only: + return ["fPIC"] + if platform.system() == "Darwin" and "clang" in get_cpp_compiler(): + # This causes undefined symbols to behave the same as linux + return ["shared", "fPIC", "undefined dynamic_lookup"] + else: + return ["shared", "fPIC"] + + return SHARED_FLAG + + +def get_cpp_options( + cpp_compiler: str, + compile_only: bool, + warning_all: bool = True, + extra_flags: Sequence[str] = (), +) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + cflags = ( + _get_shared_cflag(compile_only) + + _get_optimization_cflags() + + _get_warning_all_cflag(warning_all) + + _get_cpp_std_cflag() + + _get_os_related_cpp_cflags(cpp_compiler) + ) + + passthough_args.append(" ".join(extra_flags)) + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppOptions(BuildOptionsBase): + """ + This class is inherited from BuildOptionsBase, and as cxx build options. + This option need contains basic cxx build option, which contains: + 1. OS related args. + 2. Toolchains related args. + 3. Cxx standard related args. + Note: + 1. This Options is good for assist modules build, such as x86_isa_help. 
+ """ + + def __init__( + self, + compile_only: bool = False, + warning_all: bool = True, + extra_flags: Sequence[str] = (), + use_absolute_path: bool = False, + ) -> None: + super().__init__() + self._compiler = get_cpp_compiler() + self._use_absolute_path = use_absolute_path + self._compile_only = compile_only + + ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) = get_cpp_options( + cpp_compiler=self._compiler, + compile_only=compile_only, + extra_flags=extra_flags, + warning_all=warning_all, + ) + + _append_list(self._definations, definations) + _append_list(self._include_dirs, include_dirs) + _append_list(self._cflags, cflags) + _append_list(self._ldflags, ldflags) + _append_list(self._libraries_dirs, libraries_dirs) + _append_list(self._libraries, libraries) + _append_list(self._passthough_args, passthough_args) + self._finalize_options() + + +def _get_glibcxx_abi_build_flags() -> List[str]: + if not _IS_WINDOWS: + return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] + else: + return [] + + +def _get_torch_cpp_wrapper_defination() -> List[str]: + return ["TORCH_INDUCTOR_CPP_WRAPPER"] + + +def _use_custom_generated_macros() -> List[str]: + return [" C10_USING_CUSTOM_GENERATED_MACROS"] + + +def _use_fb_internal_macros() -> List[str]: + if not _IS_WINDOWS: + if config.is_fbcode(): + fb_internal_macros = [ + "C10_USE_GLOG", + "C10_USE_MINIMAL_GLOG", + "C10_DISABLE_TENSORIMPL_EXTENSIBILITY", + ] + # TODO: this is to avoid FC breakage for fbcode. 
When using newly + # generated model.so on an older verion of PyTorch, need to use + # the v1 version for aoti_torch_create_tensor_from_blob + create_tensor_from_blob_v1 = "AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1" + + fb_internal_macros.append(create_tensor_from_blob_v1) + return fb_internal_macros + else: + return [] + else: + return [] + + +def _setup_standard_sys_libs( + cpp_compiler: str, + aot_mode: bool, + use_absolute_path: bool, +) -> Tuple[List[str], List[str], List[str]]: + from torch._inductor.codecache import _LINKER_SCRIPT + + cflags: List[str] = [] + include_dirs: List[str] = [] + passthough_args: List[str] = [] + if _IS_WINDOWS: + return cflags, include_dirs, passthough_args + + if config.is_fbcode(): + cflags.append("nostdinc") + # Note that the order of include paths do matter, as a result + # we need to have several branches interleaved here + if torch.version.hip is None: + include_dirs.append(build_paths.sleef()) + include_dirs.append(build_paths.openmp()) + include_dirs.append(build_paths.python()) + if torch.version.hip is not None: + include_dirs.append(build_paths.clang_include()) + include_dirs.append(build_paths.gcc_include()) + include_dirs.append(build_paths.gcc_install_tools_include()) + else: + include_dirs.append(build_paths.cc_include()) + include_dirs.append(build_paths.libgcc()) + include_dirs.append(build_paths.libgcc_arch()) + include_dirs.append(build_paths.libgcc_backward()) + include_dirs.append(build_paths.glibc()) + include_dirs.append(build_paths.linux_kernel()) + include_dirs.append("include") + + if aot_mode and not use_absolute_path: + linker_script = _LINKER_SCRIPT + else: + linker_script = os.path.basename(_LINKER_SCRIPT) + + if _is_clang(cpp_compiler): + passthough_args.append(" --rtlib=compiler-rt") + passthough_args.append(" -fuse-ld=lld") + passthough_args.append(f" -Wl,--script={linker_script}") + passthough_args.append(" -B" + build_paths.glibc_lib()) + passthough_args.append(" -L" + build_paths.glibc_lib()) + + 
return cflags, include_dirs, passthough_args + + +def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> Tuple[List[str], List[str]]: + macros = [] + build_flags = [] + if vec_isa != invalid_vec_isa: + # Add Windows support later. + for x in vec_isa.build_macro(): + macros.append(copy.deepcopy(x)) + + build_flags = [vec_isa.build_arch_flags()] + + if config.is_fbcode(): + cap = str(vec_isa).upper() + macros = [ + f"CPU_CAPABILITY={cap}", + f"CPU_CAPABILITY_{cap}", + f"HAVE_{cap}_CPU_DEFINITION", + ] + + return macros, build_flags + + +def _get_torch_related_args( + include_pytorch: bool, aot_mode: bool +) -> Tuple[List[str], List[str], List[str]]: + from torch.utils.cpp_extension import _TORCH_PATH, TORCH_LIB_PATH + + include_dirs = [ + os.path.join(_TORCH_PATH, "include"), + os.path.join(_TORCH_PATH, "include", "torch", "csrc", "api", "include"), + # Some internal (old) Torch headers don't properly prefix their includes, + # so we need to pass -Itorch/lib/include/TH as well. + os.path.join(_TORCH_PATH, "include", "TH"), + os.path.join(_TORCH_PATH, "include", "THC"), + ] + libraries_dirs = [TORCH_LIB_PATH] + libraries = [] + if sys.platform != "darwin" and not config.is_fbcode(): + libraries = ["torch", "torch_cpu"] + if not aot_mode: + libraries.append("torch_python") + + if _IS_WINDOWS: + libraries.append("sleef") + + # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 + if not config.abi_compatible: + libraries.append("c10") + libraries_dirs.append(TORCH_LIB_PATH) + + return include_dirs, libraries_dirs, libraries + + +def _get_python_include_dirs() -> List[str]: + include_dir = Path(sysconfig.get_path("include")) + # On Darwin Python executable from a framework can return + # non-existing /Library/Python/... 
include path, in which case + # one should use Headers folder from the framework + if not include_dir.exists() and platform.system() == "Darwin": + std_lib = Path(sysconfig.get_path("stdlib")) + include_dir = (std_lib.parent.parent / "Headers").absolute() + if not (include_dir / "Python.h").exists(): + warnings.warn(f"Can't find Python.h in {str(include_dir)}") + return [str(include_dir)] + + +def _get_python_related_args() -> Tuple[List[str], List[str]]: + python_include_dirs = _get_python_include_dirs() + python_include_path = sysconfig.get_path( + "include", scheme="nt" if _IS_WINDOWS else "posix_prefix" + ) + if python_include_path is not None: + python_include_dirs.append(python_include_path) + + if _IS_WINDOWS: + python_path = os.path.dirname(sys.executable) + python_lib_path = [os.path.join(python_path, "libs")] + else: + python_lib_path = [sysconfig.get_config_var("LIBDIR")] + + if config.is_fbcode(): + python_include_dirs.append(build_paths.python()) + + return python_include_dirs, python_lib_path + + +@functools.lru_cache(None) +def is_conda_llvm_openmp_installed() -> bool: + try: + command = "conda list llvm-openmp --json" + output = subprocess.check_output(command.split()).decode("utf8") + return len(json.loads(output)) > 0 + except subprocess.SubprocessError: + return False + + +@functools.lru_cache(None) +def homebrew_libomp() -> Tuple[bool, str]: + try: + # check if `brew` is installed + subprocess.check_output(["which", "brew"]) + # get the location of `libomp` if it is installed + # this is the location that `libomp` **would** be installed + # see https://github.com/Homebrew/brew/issues/10261#issuecomment-756563567 for details + libomp_path = ( + subprocess.check_output(["brew", "--prefix", "libomp"]) + .decode("utf8") + .strip() + ) + # check if `libomp` is installed + omp_available = os.path.exists(libomp_path) + return omp_available, libomp_path + except subprocess.SubprocessError: + return False, "" + + +@functools.lru_cache(None) +def 
perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: + try: + output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode( + "utf8" + ) + omp_path = os.path.join(output.rstrip(), omp_name) + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + omp_module = cdll.LoadLibrary(omp_path) + except subprocess.SubprocessError: + pass + + +@functools.lru_cache(None) +def perload_icx_libomp_win(cpp_compiler: str) -> None: + def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool: + try: + output = subprocess.check_output( + [cpp_compiler, f"-print-file-name={lib_name}"], + stderr=subprocess.DEVNULL, + ).decode(*SUBPROCESS_DECODE_ARGS) + omp_path = output.rstrip() + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + omp_module = cdll.LoadLibrary(omp_path) + return True + except subprocess.SubprocessError: + pass + return False + + """ + Intel Compiler implenmented more math libraries than clang, for performance proposal. + We need preload them like openmp library. 
+ """ + preload_list = [ + "libiomp5md.dll", # openmp + "svml_dispmd.dll", # svml library + "libmmd.dll", # libm + ] + + for lib_name in preload_list: + _load_icx_built_in_lib_by_name(cpp_compiler, lib_name) + + +def _get_openmp_args( + cpp_compiler: str, +) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str]]: + cflags: List[str] = [] + ldflags: List[str] = [] + include_dir_paths: List[str] = [] + lib_dir_paths: List[str] = [] + libs: List[str] = [] + passthough_args: List[str] = [] + if _IS_MACOS: + # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang` + cflags.append("Xclang") + cflags.append("fopenmp") + + # only Apple builtin compilers (Apple Clang++) require openmp + omp_available = not _is_apple_clang(cpp_compiler) + + # check the `OMP_PREFIX` environment first + omp_prefix = os.getenv("OMP_PREFIX") + if omp_prefix is not None: + header_path = os.path.join(omp_prefix, "include", "omp.h") + valid_env = os.path.exists(header_path) + if valid_env: + include_dir_paths.append(os.path.join(omp_prefix, "include")) + lib_dir_paths.append(os.path.join(omp_prefix, "lib")) + else: + warnings.warn("environment variable `OMP_PREFIX` is invalid.") + omp_available = omp_available or valid_env + + if not omp_available: + libs.append("omp") + + # prefer to use openmp from `conda install llvm-openmp` + conda_prefix = os.getenv("CONDA_PREFIX") + if not omp_available and conda_prefix is not None: + omp_available = is_conda_llvm_openmp_installed() + if omp_available: + conda_lib_path = os.path.join(conda_prefix, "lib") + include_dir_paths.append(os.path.join(conda_prefix, "include")) + lib_dir_paths.append(conda_lib_path) + # Prefer Intel OpenMP on x86 machine + if os.uname().machine == "x86_64" and os.path.exists( + os.path.join(conda_lib_path, "libiomp5.dylib") + ): + libs.append("iomp5") + + # next, try to use openmp from `brew install libomp` + if not omp_available: + omp_available, libomp_path = 
homebrew_libomp() + if omp_available: + include_dir_paths.append(os.path.join(libomp_path, "include")) + lib_dir_paths.append(os.path.join(libomp_path, "lib")) + + # if openmp is still not available, we let the compiler to have a try, + # and raise error together with instructions at compilation error later + elif _IS_WINDOWS: + """ + On Windows, `clang` and `icx` have their specific openmp implenmention. + And the openmp lib is in compiler's some sub-directory. + For dynamic library(DLL) load, the Windows native APIs are `LoadLibraryA` and `LoadLibraryExA`, and their search + dependencies have some rules: + https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryexa#searching-for-dlls-and-dependencies + In some case, the rules may not include compiler's sub-directories. + So, it can't search and load compiler's openmp library correctly. + And then, the whole application would be broken. + + To avoid the openmp load failed, we can automatic locate the openmp binary and preload it. + 1. For clang, the function is `perload_clang_libomp_win`. + 2. For icx, the function is `perload_icx_libomp_win`. 
+ """ + if _is_clang(cpp_compiler): + cflags.append("openmp") + libs.append("libomp") + perload_clang_libomp_win(cpp_compiler, "libomp.dll") + elif _is_intel_compiler(cpp_compiler): + cflags.append("Qiopenmp") + libs.append("libiomp5md") + perload_icx_libomp_win(cpp_compiler) + else: + # /openmp, /openmp:llvm + # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ + # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 + cflags.append("openmp") + cflags.append("openmp:experimental") # MSVC CL + else: + if config.is_fbcode(): + include_dir_paths.append(build_paths.openmp()) + + openmp_lib = build_paths.openmp_lib() + fb_openmp_extra_flags = f"-Wp,-fopenmp {openmp_lib}" + passthough_args.append(fb_openmp_extra_flags) + + libs.append("omp") + else: + if _is_clang(cpp_compiler): + # TODO: fix issue, can't find omp.h + cflags.append("fopenmp") + libs.append("gomp") + elif _is_intel_compiler(cpp_compiler): + cflags.append("fiopenmp") + else: + cflags.append("fopenmp") + libs.append("gomp") + + return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthough_args + + +def get_mmap_self_macro(use_mmap_weights: bool) -> List[str]: + macros = [] + if use_mmap_weights: + macros.append(" USE_MMAP_SELF") + return macros + + +def get_cpp_torch_options( + cpp_compiler: str, + vec_isa: VecISA, + include_pytorch: bool, + aot_mode: bool, + compile_only: bool, + use_absolute_path: bool, + use_mmap_weights: bool, +) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + torch_cpp_wrapper_definations = _get_torch_cpp_wrapper_defination() + use_custom_generated_macros_definations = _use_custom_generated_macros() + + ( + sys_libs_cflags, + 
sys_libs_include_dirs, + sys_libs_passthough_args, + ) = _setup_standard_sys_libs(cpp_compiler, aot_mode, use_absolute_path) + + isa_macros, isa_ps_args_build_flags = _get_build_args_of_chosen_isa(vec_isa) + + ( + torch_include_dirs, + torch_libraries_dirs, + torch_libraries, + ) = _get_torch_related_args(include_pytorch=include_pytorch, aot_mode=aot_mode) + + python_include_dirs, python_libraries_dirs = _get_python_related_args() + + ( + omp_cflags, + omp_ldflags, + omp_include_dir_paths, + omp_lib_dir_paths, + omp_lib, + omp_passthough_args, + ) = _get_openmp_args(cpp_compiler) + + cxx_abi_passthough_args = _get_glibcxx_abi_build_flags() + fb_macro_passthough_args = _use_fb_internal_macros() + + mmap_self_macros = get_mmap_self_macro(use_mmap_weights) + + definations = ( + torch_cpp_wrapper_definations + + use_custom_generated_macros_definations + + isa_macros + + fb_macro_passthough_args + + mmap_self_macros + ) + include_dirs = ( + sys_libs_include_dirs + + python_include_dirs + + torch_include_dirs + + omp_include_dir_paths + ) + cflags = sys_libs_cflags + omp_cflags + ldflags = omp_ldflags + libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths + libraries = torch_libraries + omp_lib + passthough_args = ( + sys_libs_passthough_args + + isa_ps_args_build_flags + + cxx_abi_passthough_args + + omp_passthough_args + ) + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppTorchOptions(CppOptions): + """ + This class is inherited from CppTorchOptions, which automatic contains + base cxx build options. And then it will maintains torch related build + args. + 1. Torch include_directories, libraries, libraries_directories. + 2. Python include_directories, libraries, libraries_directories. + 3. OpenMP related. + 4. Torch MACROs. + 5. 
MISC + """ + + def __init__( + self, + vec_isa: VecISA = invalid_vec_isa, + include_pytorch: bool = False, + warning_all: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_absolute_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + ) -> None: + super().__init__( + compile_only=compile_only, + warning_all=warning_all, + extra_flags=extra_flags, + use_absolute_path=use_absolute_path, + ) + + self._aot_mode = aot_mode + + ( + torch_definations, + torch_include_dirs, + torch_cflags, + torch_ldflags, + torch_libraries_dirs, + torch_libraries, + torch_passthough_args, + ) = get_cpp_torch_options( + cpp_compiler=self._compiler, + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + + _append_list(self._definations, torch_definations) + _append_list(self._include_dirs, torch_include_dirs) + _append_list(self._cflags, torch_cflags) + _append_list(self._ldflags, torch_ldflags) + _append_list(self._libraries_dirs, torch_libraries_dirs) + _append_list(self._libraries, torch_libraries) + _append_list(self._passthough_args, torch_passthough_args) + self._finalize_options() + + +def _set_gpu_runtime_env() -> None: + if ( + config.is_fbcode() + and torch.version.hip is None + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = build_paths.cuda() + + +def _transform_cuda_paths(lpaths: List[str]) -> None: + # This handles two cases: + # 1. Meta internal cuda-12 where libs are in lib/cuda-12 and lib/cuda-12/stubs + # 2. 
Linux machines may have CUDA installed under either lib64/ or lib/ + for i, path in enumerate(lpaths): + if ( + "CUDA_HOME" in os.environ + and path.startswith(os.environ["CUDA_HOME"]) + and not os.path.exists(f"{path}/libcudart_static.a") + ): + for root, dirs, files in os.walk(path): + if "libcudart_static.a" in files: + lpaths[i] = os.path.join(path, root) + lpaths.append(os.path.join(lpaths[i], "stubs")) + break + + +def get_cpp_torch_cuda_options( + cuda: bool, + aot_mode: bool = False, + compile_only: bool = False, +) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + if ( + config.is_fbcode() + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = ( + build_paths.rocm() if torch.version.hip else build_paths.cuda() + ) + + _set_gpu_runtime_env() + from torch.utils import cpp_extension + + include_dirs = cpp_extension.include_paths(cuda) + libraries_dirs = cpp_extension.library_paths(cuda) + + if cuda: + definations.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") + + if torch.version.hip is not None: + if config.is_fbcode(): + libraries += ["amdhip64"] + else: + libraries += ["c10_hip", "torch_hip"] + definations.append(" __HIP_PLATFORM_AMD__") + else: + if config.is_fbcode(): + libraries += ["cuda"] + else: + libraries += ["c10_cuda", "cuda", "torch_cuda"] + + if aot_mode: + if config.is_fbcode(): + from torch._inductor.codecache import cpp_prefix_path + + cpp_prefix_include_dir = [f"{os.path.dirname(cpp_prefix_path())}"] + include_dirs += cpp_prefix_include_dir + + if cuda and torch.version.hip is None: + _transform_cuda_paths(libraries_dirs) + + if config.is_fbcode(): + if torch.version.hip is not None: + include_dirs.append(os.path.join(build_paths.rocm(), 
"include")) + else: + include_dirs.append(os.path.join(build_paths.cuda(), "include")) + + if aot_mode and cuda: + if torch.version.hip is None: + if not compile_only: + # Only add link args, when compile_only is false. + passthough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"] + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppTorchCudaOptions(CppTorchOptions): + """ + This class is inherited from CppTorchOptions, which automatic contains + base cxx build options and torch common build options. And then it will + maintains cuda device related build args. + """ + + def __init__( + self, + vec_isa: VecISA = invalid_vec_isa, + include_pytorch: bool = False, + cuda: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_absolute_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + ) -> None: + super().__init__( + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + extra_flags=extra_flags, + ) + + cuda_definations: List[str] = [] + cuda_include_dirs: List[str] = [] + cuda_cflags: List[str] = [] + cuda_ldflags: List[str] = [] + cuda_libraries_dirs: List[str] = [] + cuda_libraries: List[str] = [] + cuda_passthough_args: List[str] = [] + + ( + cuda_definations, + cuda_include_dirs, + cuda_cflags, + cuda_ldflags, + cuda_libraries_dirs, + cuda_libraries, + cuda_passthough_args, + ) = get_cpp_torch_cuda_options( + cuda=cuda, aot_mode=aot_mode, compile_only=compile_only + ) + _append_list(self._definations, cuda_definations) + _append_list(self._include_dirs, cuda_include_dirs) + _append_list(self._cflags, cuda_cflags) + _append_list(self._ldflags, cuda_ldflags) + _append_list(self._libraries_dirs, cuda_libraries_dirs) + _append_list(self._libraries, cuda_libraries) + 
_append_list(self._passthough_args, cuda_passthough_args) + self._finalize_options() + + +def get_name_and_dir_from_output_file_path( + file_path: str, +) -> Tuple[str, str]: + """ + This function help prepare parameters to new cpp_builder. + Example: + input_code: /tmp/tmpof1n5g7t/5c/c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc.cpp + name, dir = get_name_and_dir_from_output_file_path(input_code) + Run result: + name = c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc + dir = /tmp/tmpof1n5g7t/5c/ + + put 'name' and 'dir' to CppBuilder's 'name' and 'output_dir'. + CppBuilder --> get_target_file_path will format output path accoding OS: + Linux: /tmp/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.so + Windows: [Windows temp path]/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.dll + """ + name_and_ext = os.path.basename(file_path) + name, ext = os.path.splitext(name_and_ext) + dir = os.path.dirname(file_path) + + return name, dir + + +class CppBuilder: + """ + CppBuilder is a cpp jit builder, and it supports both Windows, Linux and MacOS. + Args: + name: + 1. Build target name, the final target file will append extension type automatically. + 2. Due to the CppBuilder is supports mutliple OS, it will maintains ext for OS difference. + sources: + Source code file list to be built. + BuildOption: + Build options to the builder. + output_dir: + 1. The output_dir the taget file will output to. + 2. The default value is empty string, and then the use current dir as output dir. + 3. 
Final target file: output_dir/name.ext + """ + + def __get_python_module_ext(self) -> str: + SHARED_LIB_EXT = ".pyd" if _IS_WINDOWS else ".so" + return SHARED_LIB_EXT + + def __get_object_ext(self) -> str: + EXT = ".obj" if _IS_WINDOWS else ".o" + return EXT + + def __init__( + self, + name: str, + sources: Union[str, List[str]], + BuildOption: BuildOptionsBase, + output_dir: str = "", + ) -> None: + self._compiler = "" + self._cflags_args = "" + self._definations_args = "" + self._include_dirs_args = "" + self._ldflags_args = "" + self._libraries_dirs_args = "" + self._libraries_args = "" + self._passthough_parameters_args = "" + + self._output_dir = "" + self._target_file = "" + + self._use_absolute_path: bool = False + self._aot_mode: bool = False + + self._name = name + + # Code start here, initial self internal veriables firstly. + self._compiler = BuildOption.get_compiler() + self._use_absolute_path = BuildOption.get_use_absolute_path() + self._aot_mode = BuildOption.get_aot_mode() + + self._output_dir = output_dir + + self._compile_only = BuildOption.get_compile_only() + file_ext = ( + self.__get_object_ext() + if self._compile_only + else self.__get_python_module_ext() + ) + self._target_file = os.path.join(self._output_dir, f"{self._name}{file_ext}") + + if isinstance(sources, str): + sources = [sources] + + if config.is_fbcode(): + if self._aot_mode and not self._use_absolute_path: + inp_name = sources + # output process @ get_name_and_dir_from_output_file_path + else: + # We need to copy any absolute-path torch includes + inp_name = [os.path.basename(i) for i in sources] + self._target_file = os.path.basename(self._target_file) + + self._sources_args = " ".join(inp_name) + else: + self._sources_args = " ".join(sources) + + for cflag in BuildOption.get_cflags(): + if _IS_WINDOWS: + self._cflags_args += f"/{cflag} " + else: + self._cflags_args += f"-{cflag} " + + for defination in BuildOption.get_definations(): + if _IS_WINDOWS: + self._definations_args += 
f"/D {defination} " + else: + self._definations_args += f"-D {defination} " + + for inc_dir in BuildOption.get_include_dirs(): + if _IS_WINDOWS: + self._include_dirs_args += f"/I {inc_dir} " + else: + self._include_dirs_args += f"-I{inc_dir} " + + for ldflag in BuildOption.get_ldflags(): + if _IS_WINDOWS: + self._ldflags_args += f"/{ldflag} " + else: + self._ldflags_args += f"-{ldflag} " + + for lib_dir in BuildOption.get_libraries_dirs(): + if _IS_WINDOWS: + self._libraries_dirs_args += f'/LIBPATH:"{lib_dir}" ' + else: + self._libraries_dirs_args += f"-L{lib_dir} " + + for lib in BuildOption.get_libraries(): + if _IS_WINDOWS: + self._libraries_args += f'"{lib}.lib" ' + else: + self._libraries_args += f"-l{lib} " + + for passthough_arg in BuildOption.get_passthough_args(): + self._passthough_parameters_args += f"{passthough_arg} " + + def get_command_line(self) -> str: + def format_build_command( + compiler: str, + sources: str, + include_dirs_args: str, + definations_args: str, + cflags_args: str, + ldflags_args: str, + libraries_args: str, + libraries_dirs_args: str, + passthougn_args: str, + target_file: str, + ) -> str: + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/build/walkthrough-compile-a-c-program-on-the-command-line?view=msvc-1704 + # https://stackoverflow.com/a/31566153 + cmd = ( + f"{compiler} {include_dirs_args} {definations_args} {cflags_args} {sources} " + f"{passthougn_args} /LD /Fe{target_file} /link {libraries_dirs_args} {libraries_args} {ldflags_args} " + ) + cmd = normalize_path_separator(cmd) + else: + compile_only_arg = "-c" if self._compile_only else "" + cmd = re.sub( + r"[ \n]+", + " ", + f""" + {compiler} {sources} {definations_args} {cflags_args} {include_dirs_args} + {passthougn_args} {ldflags_args} {libraries_args} {libraries_dirs_args} {compile_only_arg} -o {target_file} + """, + ).strip() + return cmd + + command_line = format_build_command( + compiler=self._compiler, + sources=self._sources_args, + 
include_dirs_args=self._include_dirs_args, + definations_args=self._definations_args, + cflags_args=self._cflags_args, + ldflags_args=self._ldflags_args, + libraries_args=self._libraries_args, + libraries_dirs_args=self._libraries_dirs_args, + passthougn_args=self._passthough_parameters_args, + target_file=self._target_file, + ) + return command_line + + def get_target_file_path(self) -> str: + return normalize_path_separator(self._target_file) + + def build(self) -> Tuple[bytes, str]: + """ + It is must need a temperary directory to store object files in Windows. + After build completed, delete the temperary directory to save disk space. + """ + _create_if_dir_not_exist(self._output_dir) + _build_tmp_dir = os.path.join( + self._output_dir, f"{self._name}_{_BUILD_TEMP_DIR}" + ) + _create_if_dir_not_exist(_build_tmp_dir) + + build_cmd = self.get_command_line() + + status = run_compile_cmd(build_cmd, cwd=_build_tmp_dir) + + _remove_dir(_build_tmp_dir) + return status, self._target_file diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/loop_body.py b/.venv/lib/python3.11/site-packages/torch/_inductor/loop_body.py new file mode 100644 index 0000000000000000000000000000000000000000..2f0b00b724e71fb8620d45f6601635f37772baf4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/loop_body.py @@ -0,0 +1,594 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import itertools +import re +from enum import auto, Enum +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple + +import sympy + +import torch.fx +from torch._dynamo.utils import identity +from torch.utils._sympy.symbol import SymT + +from . 
import config, dependencies +from .codegen.common import index_prevent_reordering +from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs +from .virtualized import ops, V + + +class InterpreterShim(torch.fx.Interpreter): + @staticmethod + @functools.lru_cache(None) + def _dummy_gm(): + return torch.fx.symbolic_trace(identity) + + def __init__(self, graph, submodules): + # call super() with a placeholder to avoid constructing a + # GraphModule which is very expensive (it does codegen). + super().__init__(self._dummy_gm(), garbage_collect_values=False) + self.module = self # type: ignore[assignment] + self.graph = graph + self.submodules = submodules + self.extra_traceback = False + self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign] + self.current_node = None + + def run_node(self, n: torch.fx.Node) -> Any: + self.current_node = n + return super().run_node(n) + + def run(self, *args, **kwargs): + with V.set_interpreter_handler(self): + return super().run(*args, **kwargs) + + +class MemoryEntry(NamedTuple): + index_name: str # LoopBody.indexing_exprs[index_name] + buffer_name: Optional[str] + mode: Optional[str] # V.ops.store(..., mode=mode) + + +class MemoryUsageType(Enum): + # These are 1:1 with the opcode generating the usage + LOAD = auto() + LOAD_SEED = auto() + STORE = auto() + STORE_REDUCTION = auto() + INDEX_EXPR = auto() + CHECK_BOUNDS = auto() + BUCKETIZE = auto() + + +class LoopBody: + """ + Captures the body of a Loops subclass into an FX graph. Persists any + indexing simplifications and makes it easier to analyze loop bodies. 
+ """ + + indexing_exprs: Dict[str, sympy.Expr] + indexing_exprs_name: Dict[sympy.Expr, str] + submodules: Dict[str, Any] + subblocks: Dict[str, LoopBodyBlock] + indirect_vars: List[str] + indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] + root_block: LoopBodyBlock + memory_usage: Dict[MemoryUsageType, List[MemoryEntry]] + + def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): + super().__init__() + + _flat_sizes = tuple(var_ranges.values()) + self.sizes = ( + _flat_sizes[: len(iter_vars)], + _flat_sizes[len(iter_vars) :], + ) + + self.iter_vars = iter_vars + self.reduce_vars = reduce_vars + self.var_ranges = var_ranges + + if isinstance(fn, LoopBody): + self._init_with_copy(fn, args) + else: + self._init_with_tracing(fn, args) + + self.indexing = None + + def _init_with_tracing(self, fn, args): + """Do an FX trace of an arbitrary callable to construct self""" + self.indexing_exprs = {} + self.indexing_exprs_name = {} + self.submodules = {"get_index": self.get_index} + self.subblocks = {} + self.indirect_vars = [] + self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} + self.memory_usage = {t: [] for t in MemoryUsageType} + self.root_block = LoopBodyBlock(self, fn, args) # traces + del self.indexing_exprs_name # not used after _init_with_tracing + + def _init_with_copy(self, other: LoopBody, args): + """ + _init_with_tracing() is slow, so this is a fast path in the case + where we are just reordering/merging/splitting the args of an + existing LoopBody. 
+ """ + indexing_exprs = other.indexing_from_args(args) + self.indexing_exprs = { + name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges) + for name, expr in indexing_exprs.items() + } + self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} + self.indirect_vars = other.indirect_vars + self.indirect_var_ranges = other.indirect_var_ranges + self.memory_usage = other.memory_usage + self.root_block = other.root_block.clone(self) + + submodules = {**other.submodules} + submodules.pop("get_index") + self.submodules = { + "get_index": self.get_index, + **{k: v.clone(self) for k, v in submodules.items()}, # type: ignore[attr-defined] + } + + def merge_loops(self) -> LoopBody: + """ + Merge both iteration and reduction loops and return a new LoopBody. + """ + old_body = self + old_sizes = self.sizes + old_iter_vars, old_reduce_vars = old_body.vars + old_iter_sizes, old_reduce_sizes = old_sizes + + index_exprs = [*old_body.indexing_exprs.values()] + + iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops( + old_iter_vars, + old_iter_sizes, + index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes), + ) + + reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops( + old_reduce_vars, + old_reduce_sizes, + index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes), + ) + + # if iter_sizes == old_iter_sizes: + # # no dimensions get merged. + # return old_sizes, old_body + + # Note: if no dimension get merges, the symbol prefix will + # remain 'y'. But if we merge dimensions, we change prefix to + # 'z'. If this is an issue, we can always retrace the LoopBody + # to change symbol prefix to 'z'. + # + # There is indeed an issue due to symbol name conflicting. + # y0 maybe reused for the y dimension later. 
+ ( + iter_vars, + reduce_vars, + ), var_ranges = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="t" + ) + new_body = LoopBody( + old_body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + + # use the original symbol prefix + # Can try to optimize if this is a bottleneck for compilation time + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="z" + ) + new_body2 = LoopBody( + new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body2 + + def reorder_iter_loops(self, new_order) -> LoopBody: + """ + Reorder iteration loops and return a new LoopBody. + """ + from .ir import same_reorder + + old_body = self + old_sizes = self.sizes + assert len(old_sizes[0]) == len(new_order) + reorder_fn = same_reorder(new_order) + + iter_size, reduce_size = old_sizes + new_iter_size = reorder_fn(iter_size) + + new_sizes = (new_iter_size, reduce_size) + + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="t" # type: ignore[arg-type] + ) + + inverse_order = {b: a for a, b in enumerate(new_order)} + inverse_order = [inverse_order[i] for i in range(len(new_order))] + + def new_body(*indices: Sequence[sympy.Expr]) -> Any: + index = list(itertools.chain(*indices)) + assert len(index) == len(iter_size) + len(reduce_size) + iter_idx = index[: len(iter_size)] + reduce_idx = index[len(iter_size) :] + iter_idx = [iter_idx[i] for i in inverse_order] + return old_body(iter_idx, reduce_idx) + + loop_body = LoopBody( + new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars + ) + + # use the original symbol prefix so we can do multiple round of reordering + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="z" # type: ignore[arg-type] + ) + new_body = LoopBody( + loop_body, (iter_vars2, reduce_vars2), 
var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body + + @property + def vars(self): + assert self.iter_vars is not None + assert self.reduce_vars is not None + return self.iter_vars, self.reduce_vars + + @cache_on_self + def get_nodes(self): + all_graphs = itertools.chain( + (self.root_block.graph,), + (block.graph for block in self.subblocks.values()), + ) + return [node for graph in all_graphs for node in graph.nodes] + + @cache_on_self + def bounds(self): + # Doing a local import to avoid dumping all the code here + from .bounds import BoundVars + + return BoundVars(self) + + def get_read_expr(self, buffer_name): + # reversed to match old behavior + for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_write_expr(self, buffer_name): + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_read_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in self.memory_usage[MemoryUsageType.LOAD] + ] + + def get_write_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ) + ] + + def debug_str(self): + lines = [f"var_ranges = {dict(self.var_ranges)}"] + lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) + lines.extend( + [ + block.debug_str(name) + for name, block in itertools.chain( + [("body", self.root_block)], self.subblocks.items() + ) + ] + ) + return "\n".join(lines) + + def is_memory_copy(self) -> bool: + """ + True of this contains only a single loads and store. + Note, this could involve a layout change. 
+ """ + return ( + len(self.memory_usage[MemoryUsageType.LOAD]) == 1 + and len(self.memory_usage[MemoryUsageType.STORE]) == 1 + and len(self.submodules) == 1 # get_index + and self.root_block.contains_only_ops(("load", "store")) + ) + + __repr__ = debug_str + + def add_index_expr( + self, + expr: sympy.Expr, + mtype: MemoryUsageType, + buffer_name: Optional[str] = None, + mode: Optional[str] = None, + ): + name = self.indexing_exprs_name.get(expr) + if not name: + name = f"index{len(self.indexing_exprs)}" + self.indexing_exprs_name[expr] = name + self.indexing_exprs[name] = expr + self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode)) + return name + + def add_submodule(self, block, prefix): + """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" + if prefix[-1].isnumeric() and prefix not in self.submodules: + name = prefix + else: + name = f"{prefix}{len(self.submodules)}" + self.submodules[name] = block + return name + + def add_indirect(self, size): + var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars)) + assert var not in self.indirect_var_ranges + self.indirect_vars.append(var) + self.indirect_var_ranges[var] = size + return var + + def replace_indirect(self, old, new): + """Swap in a variable used in indirect indexing""" + if str(old) == str(new): + return + assert self.indexing is not None + self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} + + def get_index(self, name): + assert self.indexing is not None + return self.indexing[name] + + def indexing_from_args(self, indices): + index = [*itertools.chain.from_iterable(indices)] + assert len(index) == len(self.var_ranges), (index, self.var_ranges) + assert all( + v not in self.var_ranges for v in index + ), f"{self.var_ranges=}, {indices=}" + replacements = dict(zip(self.var_ranges.keys(), index)) + return { + name: sympy_subs(expr, replacements) + for name, expr in 
self.indexing_exprs.items() + } + + def __call__(self, *indices): + self.indexing = self.indexing_from_args(indices) + result = self.root_block() + self.indexing = None + return result + + def bind_set_indirect_shim(self, var, size, check, wrap_neg): + def set_indirect(new_var): + self.replace_indirect( + var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) + ) + + set_indirect.clone = functools.partial( # type: ignore[attr-defined] + LoopBody.bind_set_indirect_shim, + var=var, + size=size, + check=check, + wrap_neg=wrap_neg, + ) + return set_indirect + + def bind_scan_shim(self, combine_fn): + def shim(dtypes, values): + return V.ops.scan(dtypes, combine_fn, values) + + shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn) # type: ignore[attr-defined] + return shim + + def bind_masked_shim(self, name): + def shim(mask, other): + return V.ops.masked(mask, self.subblocks[name], other) + + shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name) # type: ignore[attr-defined] + return shim + + +class LoopBodyBlock: + """ + Captures the body of a Loops subclass into an FX graph. + In normal cases there will be a 1:1 mapping between LoopBody and + LoopBodyBlock, hower in the case of ops.masked() the masked out + operations will manifest as an extra LoopBodyBlock. 
+ """ + + def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): + self.body = body + + def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs): + return tracer.create_proxy( + "call_module", + "get_index", + (body.add_index_expr(expr, mtype, **kwargs),), + {}, + ) + + class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] + self.name = "CaptureIndexing" + + def load(self, name: str, index: sympy.Expr): + index = add_index(index, MemoryUsageType.LOAD, buffer_name=name) + return self._inner.load(name, index) + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + body.add_index_expr( + sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name + ) + return self._inner.load_seed(name, index) + + def store(self, name, index, value, mode=None): + index = add_index( + index, MemoryUsageType.STORE, buffer_name=name, mode=mode + ) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = add_index( + index, MemoryUsageType.STORE_REDUCTION, buffer_name=name + ) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = self._inner.reduction(dtype, src_dtype, reduction_type, value) + if "welford" in reduction_type: + return tuple(result[i] for i in range(3)) + return result + + def index_expr(self, index, dtype): + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = add_index(index, MemoryUsageType.INDEX_EXPR) + return self._inner.index_expr(index, dtype) + + def check_bounds(self, index, size, lower, upper): + index = add_index(index, MemoryUsageType.CHECK_BOUNDS) + size = add_index(size, MemoryUsageType.CHECK_BOUNDS) + return self._inner.check_bounds(index, size, lower, upper) + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + 
offsets_size = add_index( + offsets_size, MemoryUsageType.BUCKETIZE, buffer_name=offsets_name + ) + return self._inner.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + @staticmethod + def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + name = self.body.add_submodule(None, "masked_subblock") + self.body.submodules[name] = self.body.bind_masked_shim(name) + self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) + return tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + @staticmethod + def scan( + dtype_proxy, + combine_fn: Callable[ + [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...] + ], + value_proxy, + ): + shim = self.body.bind_scan_shim(combine_fn) + name = self.body.add_submodule(shim, "scan") + result = tracer.create_proxy( + "call_module", + name, + (dtype_proxy, value_proxy), + {}, + ) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(value_proxy))) + + def sort(self, dtypes, values, stable, descending): + result = self._inner.sort(dtypes, values, stable, descending) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(values))) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + @staticmethod + def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. 
+ """ + + var = self.body.add_indirect(size) + set_indirect = self.body.bind_set_indirect_shim( + var, size, check, wrap_neg + ) + tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + @staticmethod + def output(result): + tracer.create_proxy("output", "output", (result,), {}) + + tracer = torch.fx.Tracer() + tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) + + from .index_propagation import IndexPropagation + from .sizevars import SimplifyIndexing + + handler: Any = SimplifyIndexing( + CaptureIndexing(proxy_ops), self.body.var_ranges + ) + if config.constant_and_index_propagation: + handler = IndexPropagation( + handler, self.body.var_ranges, self.body.indirect_var_ranges + ) + + with V.set_ops_handler(handler): + # This indirection is just a cute way to get IndexPropagation to + # unwrap the return value. + ops.output(fn(*args)) + self.graph = tracer.graph + + def __call__(self): + graph = self.graph + submodules = self.body.submodules + + return InterpreterShim(graph, submodules).run(V.get_ops_handler()) + + def debug_str(self, name="block"): + code = torch.fx.GraphModule(self.body.submodules, self.graph).code + return re.sub( + # strip `; del var0` suffixes to make output prettier + r";[^\n]*", + "", + code.strip().replace("def forward(", f"def {name}("), + ) + + def contains_only_ops(self, allowed_ops) -> bool: + return all( + node.target in allowed_ops + for node in self.graph.find_nodes(op="call_method") + ) + + def clone(self, body: LoopBody): + """Shallow copy with a new parent LoopBody""" + copy = LoopBodyBlock.__new__(LoopBodyBlock) + copy.__dict__.update({**self.__dict__, "body": body}) + return copy diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py b/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py new file mode 100644 index 
0000000000000000000000000000000000000000..c47ee1026ab919f0d4c619a8e89278268f503d52 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py @@ -0,0 +1,1093 @@ +# mypy: allow-untyped-defs +import itertools +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Literal, + NamedTuple, + Optional, + Tuple, + TypeVar, + Union, +) +from typing_extensions import Protocol +from unittest.mock import patch + +import sympy + +import torch +import torch.utils._pytree as pytree + +from ..utils._ordered_set import OrderedSet +from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str + + +T = TypeVar("T") +StoreMode = Optional[Literal["atomic_add"]] +ReductionType = Literal[ + "argmax", + "argmin", + "welford_reduce", + "welford_combine", + "any", + "max", + "min", + "prod", + "sum", + "xor_sum", +] + + +def _arg_str(a) -> str: + if isinstance(a, sympy.Expr): + return sympy_str(a) + return str(a) + + +# NB: This is not done as a parent class, because our ops handlers +# implementations make heavy use of __getattr__ magic, and pre-existing +# stubs for methods would interfere with this mechanism. +# +# TODO: A superclass that does desugaring for operations like +# reciprocal/square might be useful. +class OpsHandler(Protocol[T]): + """ + Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``, + as well as the contract for op handlers. The type T signifies the domain + of the abstract analysis AKA what all of the functions return / take as arguments + anywhere compute occurs. + + While these operators are typically dtype polymorphic (e.g., you can use mul + on both integers and floats), they do NOT do promotion and usually return the + same dtype as the input. You are expected to have handled type promotion + during ATen decompositions. 
Most operators correspond exactly to pointwise + operations as defined by torch, so when in doubt about semantics, check the + corresponding torch documentation. These are all scalar operations (so they + are defined to operate on a single element at a time.) + + For convenience, many operators take a src_dtype which indicates what the dtype + of the input argument is. Although in principle this can be derived by an + analysis, providing this for ops where it is useful helps avoid having to repeatedly + recompute dtype in code generation. + + Note that this often describes a class of static methods, for stateless + ops handlers. + + Handlers are often defined using ``__getattr__`` metaprogramming, which means + that you cannot declare that a type implements a protocol by inheriting from + it (as the type stubs count as attribute declarations and impede the getattr + magic method from being called). Instead, define a function that casts an + argument of your type to the protocol, which is sufficient to induce mypy to + test that the protocol is implemented correctly. Search for ``_typecheck_`` + in this file to see some examples. If you see an obscure error where a + class doesn't implement a Protocol, but mypy doesn't say why, check to see + that ``__getattr__`` is typed correctly (typically, it is not possible to + type ``__getattr__`` without typing it as ``Callable[..., Any]``) + """ + + def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: + """Produces a scalar constant of type dtype.""" + ... + + def load_seed(self, name: str, offset: T): + """Computes inductor_prims.lookup_seed.""" + ... + + def rand(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" + ... + + def randn(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" + ... 
    def randint64(self, seed: T, offset: T, low: T, high: T) -> T:
        """Computes inductor_prims.randint. offset has dtype int32."""
        ...

    def masked(self, mask: T, body: Callable[[], T], other: T) -> T:
        """
        Computes body, but only perform loads/stores if the boolean mask
        evaluates to true. For example, you would use this if you needed to
        perform an indirect load that may not be valid on some elements;
        without masking, invalid accesses can cause IMAs. When mask is true,
        the result is the result of body; otherwise it is other. Here, `other`
        needs to be a constant.

        Contrast this with ops.where, which can multiplex between two values
        that have been unconditionally computed.
        """
        ...

    def where(self, condition: T, input: T, other: T) -> T:
        """
        Computes torch.where: when condition is true, return input; otherwise return other.
        """
        ...

    def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T:
        """
        Converts a sympy expression into a scalar of type dtype. expr is typically
        an indexing expression, thus the name; however, it can also be used in
        non-indexing situations.
        """
        ...

    def to_dtype(
        self,
        x: T,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types=True,
    ) -> T:
        """
        Convert x to dtype. src_dtype can be optionally set to specify what the original
        dtype of x was, which can improve code generation (used by torch to(dtype=dtype)).
        """
        ...

    def trunc_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with truncation semantics (similar to how the int
        constructor works in Python). In Inductor codegen, this just decays
        to trunc and then to_dtype, but this composite operation helps
        roundtrips for Sympy evaluation.

        dtype is taken as an explicit parameter because the desired output
        dtype is typically the index dtype, which may vary between int32 and
        int64 depending on if we've shown that all the indexing operations can
        be done in int32.
        """
        ...

    def ceil_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with ceiling semantics. See also trunc_to_int.
        """
        ...

    def floor_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with floor semantics. See also trunc_to_int.
        """
        ...

    def round_to_int(self, x: T, dtype: torch.dtype) -> T:
        """
        Convert x to dtype with round-to-even semantics. See also trunc_to_int.
        """
        ...

    def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
        """
        Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
        src_dtype must be the original type of x.
        """
        ...

    def identity(self, x: T) -> T:
        """
        Returns x as is. This is used to trigger CSE.
        """
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These operations are only available in a "kernel" context. Check
    # torch._inductor.codegen.common.CSEProxy for their typical implementation
    # in op handler (routing to their respective implementations in the kernel
    # handler)
    #
    # Importantly, inside a kernel, indexing and mask variables are available
    # in scope, which are typically used by sympy.Expr indexing.

    def indirect_indexing(
        self, x: T, size: sympy.Expr, check: bool = True, wrap_neg=True
    ) -> sympy.Expr:
        """
        Convert an integral x into a sympy.Expr that can be subsequently used in
        indexing computation. 'size' represents an upper bound on the what valid
        indexes can be; when 'check' is True, we check that the x is in bounds.

        NB: This is typically mandatory to implement for any analysis, because you
        MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol).
        """
        ...
    def load(self, name: str, index: sympy.Expr) -> T:
        """
        Load from the memory location 'name', offset by some indexing expression 'index'.
        """
        ...

    def store(
        self,
        name: str,
        index: sympy.Expr,
        value: T,
        mode: StoreMode = None,
    ) -> None:
        """
        Store 'value' to the memory location 'name' offset by 'index'. If
        specified, 'mode' can require the store to be an atomic addition.
        """
        ...

    # TODO: Better explain how the "collective" semantics of these ops;
    # remember that the input value is a scalar, you can't reduce on it in the
    # traditional sense!
    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: T,
    ) -> Union[T, Tuple[T, ...]]:
        """
        Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype',
        using 'dtype' as the accumulation dtype for the reduction. The result
        is an intermediate computation which should be stored to the final
        location using 'ops.store_reduction'.

        Valid reduction types are the members of ReductionType. For Welford
        reduction types, this function returns multiple outputs; consult
        reduction_num_outputs to determine the amount in metaprogramming
        applications.
        """
        ...

    # TODO: in practice, this seems to actually return None, but not returning
    # a T makes common __getattr__ idioms not type correctly. Figure out if
    # this should be returning something.
    def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
        """
        Store the fully accumulated result of 'reduction' to the memory
        location 'name' offset by 'index'.
        """
        ...

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[[Tuple[T, ...], Tuple[T, ...]], Tuple[T, ...]],
        values: Tuple[T, ...],
    ) -> Tuple[T, ...]:
        """
        Perform an associative scan on 'value'.
        """
        # TODO: Improve the description with some pseudocode
        ...

    def sort(
        self,
        dtypes: Tuple[torch.dtype, ...],
        values: Tuple[T, ...],
        stable: bool,
        descending: bool,
    ) -> Tuple[T, ...]:
        """
        Sort values along the reduction dimension.
        """
        ...

    def bucketize(
        self,
        values: T,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> T:
        # See [Note: Inductor bucketize op]
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # The following ops have semantics that correspond exactly to the torch
    # operation with the same corresponding name.

    def abs(self, x0: T) -> T:
        ...

    def exp(self, x0: T) -> T:
        ...

    def exp2(self, x0: T) -> T:
        ...

    def expm1(self, x0: T) -> T:
        ...

    def sqrt(self, x0: T) -> T:
        ...

    def relu(self, x0: T) -> T:
        ...

    def minimum(self, x0: T, x1: T) -> T:
        ...

    def maximum(self, x0: T, x1: T) -> T:
        ...

    def cos(self, x0: T) -> T:
        ...

    def sin(self, x0: T) -> T:
        ...

    def lgamma(self, x0: T) -> T:
        ...

    def erf(self, x0: T) -> T:
        ...

    def cosh(self, x0: T) -> T:
        ...

    def sinh(self, x0: T) -> T:
        ...

    def acos(self, x0: T) -> T:
        ...

    def acosh(self, x0: T) -> T:
        ...

    def asin(self, x0: T) -> T:
        ...

    def asinh(self, x0: T) -> T:
        ...

    def atan2(self, x0: T, x1: T) -> T:
        ...

    def atan(self, x0: T) -> T:
        ...

    def atanh(self, x0: T) -> T:
        ...

    def copysign(self, x0: T, x1: T) -> T:
        ...

    def erfc(self, x0: T) -> T:
        ...

    def erfinv(self, x0: T) -> T:
        ...

    def frexp(self, x0: T):
        ...

    def hypot(self, x0: T, x1: T) -> T:
        ...

    def log10(self, x0: T) -> T:
        ...

    def log2(self, x0: T) -> T:
        ...

    def nextafter(self, x0: T, x1: T) -> T:
        ...

    def logical_and(self, x0: T, x1: T) -> T:
        ...

    def logical_not(self, x0: T) -> T:
        ...

    def logical_or(self, x0: T, x1: T) -> T:
        ...

    def logical_xor(self, x0: T, x1: T) -> T:
        ...

    def bitwise_and(self, x0: T, x1: T) -> T:
        ...
    def bitwise_not(self, x0: T) -> T:
        ...

    def bitwise_or(self, x0: T, x1: T) -> T:
        ...

    def bitwise_xor(self, x0: T, x1: T) -> T:
        ...

    def bitwise_left_shift(self, x0: T, x1: T) -> T:
        ...

    def bitwise_right_shift(self, x0: T, x1: T) -> T:
        ...

    def rsqrt(self, x0: T) -> T:
        ...

    def log1p(self, x0: T) -> T:
        ...

    def tan(self, x0: T) -> T:
        ...

    def tanh(self, x0: T) -> T:
        ...

    def sigmoid(self, x0: T) -> T:
        ...

    def signbit(self, x0: T) -> T:
        ...

    def fmod(self, x0: T, x1: T) -> T:
        ...

    def log(self, x0: T) -> T:
        ...

    def isinf(self, x0: T) -> T:
        ...

    def isnan(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    # This rounds half to even to break ties
    def round(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def floor(self, x0: T) -> T:
        ...

    def sign(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def trunc(self, x0: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def ceil(self, x0: T) -> T:
        ...

    def neg(self, x0: T) -> T:
        ...

    def reciprocal(self, x0: T) -> T:
        ...

    def eq(self, x0: T, x1: T) -> T:
        ...

    def ne(self, x0: T, x1: T) -> T:
        ...

    def lt(self, x0: T, x1: T) -> T:
        ...

    def gt(self, x0: T, x1: T) -> T:
        ...

    def le(self, x0: T, x1: T) -> T:
        ...

    def ge(self, x0: T, x1: T) -> T:
        ...

    def add(self, x0: T, x1: T) -> T:
        ...

    def sub(self, x0: T, x1: T) -> T:
        ...

    def mul(self, x0: T, x1: T) -> T:
        ...

    # NB: this returns a float, like the torch operation
    def pow(self, x0: T, x1: T) -> T:
        ...

    def and_(self, x0: T, x1: T) -> T:
        ...

    def or_(self, x0: T, x1: T) -> T:
        ...

    def xor(self, x0: T, x1: T) -> T:
        ...

    # These are metaprogrammed by MockHandler._init_cls
    def lshift(self, x0: T, x1: T) -> T:
        ...

    def rshift(self, x0: T, x1: T) -> T:
        ...

    def getitem(self, x0: T, x1: T) -> T:
        # TODO: this is probably just illegal lol
        ...

    def matmul(self, x0: T, x1: T) -> T:
        # TODO: this is probably just illegal lol
        ...

    def invert(self, x0: T) -> T:
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These are "special" operators. These only exist if the target
    # language actually supports the operator. Keep this in sync with
    # pointwise_overrides_data.

    def airy_ai(self, x: T) -> T:
        ...

    def bessel_j0(self, x: T) -> T:
        ...

    def bessel_j1(self, x: T) -> T:
        ...

    def bessel_y0(self, x: T) -> T:
        ...

    def bessel_y1(self, x: T) -> T:
        ...

    def digamma(self, x: T) -> T:
        ...

    def erfcx(self, x: T) -> T:
        ...

    def fma(self, x: T, y: T, z: T) -> T:
        ...

    def igamma(self, x: T, y: T) -> T:
        ...

    def igammac(self, x: T, y: T) -> T:
        ...

    def gammainc(self, x: T, y: T) -> T:
        ...

    def gammaincc(self, x: T, y: T) -> T:
        ...

    def i0(self, x: T) -> T:
        ...

    def i0e(self, x: T) -> T:
        ...

    def i1(self, x: T) -> T:
        ...

    def i1e(self, x: T) -> T:
        ...

    def log_ndtr(self, x: T) -> T:
        ...

    def modified_bessel_i0(self, x: T) -> T:
        ...

    def modified_bessel_i1(self, x: T) -> T:
        ...

    def modified_bessel_k0(self, x: T) -> T:
        ...

    def modified_bessel_k1(self, x: T) -> T:
        ...

    def ndtr(self, x: T) -> T:
        ...

    def ndtri(self, x: T) -> T:
        ...

    def polygamma(self, x: T, y: T) -> T:
        ...

    def scaled_modified_bessel_k0(self, x: T) -> T:
        ...

    def scaled_modified_bessel_k1(self, x: T) -> T:
        ...

    def spherical_bessel_j0(self, x: T) -> T:
        ...

    def zeta(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_t(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_u(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_v(self, x: T, y: T) -> T:
        ...

    def chebyshev_polynomial_w(self, x: T, y: T) -> T:
        ...

    def legendre_polynomial_p(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T:
        ...

    def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T:
        ...

    def hermite_polynomial_h(self, x: T, y: T) -> T:
        ...

    def hermite_polynomial_he(self, x: T, y: T) -> T:
        ...

    def laguerre_polynomial_l(self, x: T, y: T) -> T:
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These operators are a bit special, because they are conventionally
    # natively supported in both Python and C, but the semantics differ so
    # care must be taken

    def truncdiv(self, x0: T, x1: T) -> T:
        """C-style trunc division between integers only. Computes the true
        division of two numbers and rounds the result to zero.
        """
        ...

    def floordiv(self, x0: T, x1: T) -> T:
        """Python-style floor division between integers only. Computes the
        true division of two numbers and floors the result. If you want
        floor division for floats, do regular truediv and floor the result.
        """
        ...

    def truediv(self, x0: T, x1: T) -> T:
        """True division between floats. Integer inputs are NOT valid. To
        do Python-style (int, int) -> float division, use int_truediv"""
        ...

    def int_truediv(self, x0: T, x1: T) -> T:
        """True division between integers. This is NOT the same as promoting
        to float and doing integer division, there is a bespoke algorithm for
        doing the division in higher precision than the above.
        """
        ...

    def div(self, x0: T, x1: T) -> T:
        """TODO: to be removed. This renders as / no matter what the backend is
        which is incoherent."""
        ...

    def mod(self, x0: T, x1: T) -> T:
        """C-style modulus, take sign from LHS (x0)."""
        ...

    def remainder(self, x0: T, x1: T) -> T:
        """Python-style modulus, take sign from RHS (x1)."""
        ...

    def round_decimal(self, x0: T, x1: T) -> T:
        """Python-style round with decimal argument"""
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # In CUDA, optimized implementations of other mathematical operations are
    # offered separately via libdevice for double precision computation (in
    # Triton, these go to tl.math rather than tl). We lower to these
    # operators when doing FP64 on CUDA. Note that some operators
    # unconditionally go to tl.math.
    #
    # TODO(ezyang): Is this really the best way to do this? What if we have
    # abs internally route to tl.math automatically when given a double
    # precision input? One reason is that when doing codegen, we often don't
    # know what the dtype of the inputs are! (In principle we do know, but
    # for many analyses it's not conveniently available.)

    def libdevice_abs(self, x0: T) -> T:
        ...

    def libdevice_exp(self, x0: T) -> T:
        ...

    def libdevice_sqrt(self, x0: T) -> T:
        ...

    def libdevice_cos(self, x0: T) -> T:
        ...

    def libdevice_sin(self, x0: T) -> T:
        ...

    def libdevice_sigmoid(self, x0: T) -> T:
        ...

    def libdevice_log(self, x0: T) -> T:
        ...


class NoopHandler:
    """Ops handler that swallows every op and produces None for every value.

    Useful as a base for analyses that only care about a few ops (see
    ExtractConstantsHandler below): all unhandled ops become no-ops.
    """

    def __getattr__(self, name):
        if name == "name":
            return "NoopHandler"

        def inner(*args, **kwargs):
            return None

        return inner

    # Ops with tuple or sympy return shapes must be special-cased so callers
    # that unpack the result still work.
    @staticmethod
    def masked(mask, body, other) -> None:
        return None

    @staticmethod
    def frexp(x) -> Tuple[None, None]:
        return (None, None)

    @staticmethod
    def scan(dtypes, combine_fn, values) -> Tuple[None, ...]:
        return (None,) * len(values)

    @staticmethod
    def sort(dtypes, values, stable, descending) -> Tuple[None, ...]:
        return (None,) * len(values)

    # NOTE(review): annotated as Expr (not Symbol) because the body returns
    # sympy.Integer(0), which is an Expr but not a Symbol.
    @staticmethod
    def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Expr:
        return sympy.Integer(0)


# Use mypy to check protocol implemented correctly
def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]:
    return h


class MockHandler:
    """Ops handler that renders each op as a string like "ops.add(a, b)".

    Used to build human-readable traces of inner functions (see
    KernelFormatterHandler) and as the CSE key generator in SimpleCSEHandler.
    """

    def __getattr__(self, name):
        if name == "name":
            return "MockHandler"

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fargs.extend(f"{k}={v}" for k, v in kwargs.items())
            return f"ops.{name}({', '.join(fargs)})"

        return inner

    @staticmethod
    def masked(mask, body, other) -> str:
        return f"ops.masked({mask}, {body()}, {other})"

    @staticmethod
    def frexp(x):
        return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]")

    @staticmethod
    def scan(dtypes, combine_fn, values):
        return tuple(
            f"ops.scan({dtypes}, {combine_fn}, {values})[{i}]"
            for i in range(len(values))
        )

    @staticmethod
    def sort(dtypes, values, stable, descending):
        return tuple(
            f"ops.sort({dtypes}, {values}, stable={stable}, descending={descending})[{i}]"
            for i in range(len(values))
        )

    @staticmethod
    def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
        return sympy_index_symbol(str(index_var))

    @classmethod
    def _init_cls(cls):
        # Metaprogram the operators that render as infix/prefix Python
        # expressions rather than ops.<name>(...) calls.
        def make_handler(format_string):
            @staticmethod  # type: ignore[misc]
            def inner(*args):
                return format_string.format(*args)

            return inner

        for name, format_string in {
            "add": "{} + {}",
            "sub": "{} - {}",
            "mul": "{} * {}",
            "floordiv": "{} // {}",
            "truediv": "{} / {}",
            "mod": "{} % {}",  # careful, depending on target semantics varies
            "pow": "{} ** {}",
            "lshift": "{} << {}",
            "rshift": "{} >> {}",
            "and_": "{} & {}",
            "or_": "{} | {}",
            "xor": "{} ^ {}",
            "eq": "{} == {}",
            "ne": "{} != {}",
            "lt": "{} < {}",
            "gt": "{} > {}",
            "le": "{} <= {}",
            "ge": "{} >= {}",
            "neg": "-{}",
        }.items():
            setattr(cls, name, make_handler(format_string))


MockHandler._init_cls()


# Use mypy to check protocol implemented correctly
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
    return h


class KernelFormatterHandler:
    def __init__(self, parent_handler):
        self.parent_handler = parent_handler
        self.output = IndentedBuffer(1)
        # monotonically increasing counter used to name tmp variables
        self.var_counter = itertools.count()

    @staticmethod
    def ir_to_string(ir_fn, index, rindex=None) -> str:
        from .ir import FlexibleLayout
        from .virtualized import V

        args = [index, rindex] if rindex is not None else [index]
        names = ["index", "rindex"] if rindex is not None else ["index"]
        formatter = KernelFormatterHandler(MockHandler())

        with formatter.output.indent(-1):
            formatter.output.writeline(f"def inner_fn({', '.join(names)}):")
        for name, arg in zip(names, args):
            if arg:
                lhs = ", ".join(
                    [
                        str("_" if isinstance(v, (int, sympy.Integer)) else v)
                        for v in arg
                    ]
                )
                formatter.output.writeline(f"{lhs} = {name}")

        with V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            result = ir_fn(*args)
            return formatter.getvalue(result)

    def __getattr__(self, name) -> Callable[..., Any]:
        def inner(*args, **kwargs):
            line = getattr(self.parent_handler, name)(*args, **kwargs)
            if name == "indirect_indexing":
                return line

            def write(line):
                # replace line with a new variable name
                varname = f"tmp{next(self.var_counter)}"
                self.output.writeline(f"{varname} = {line}")
                return varname
            return pytree.tree_map(write, line)

        return inner

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[str, Tuple[str, ...]],
    ) -> Union[str, Tuple[str, ...]]:
        line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value)
        num_values = reduction_num_outputs(reduction_type)
        varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)]
        self.output.writeline(f"{','.join(varnames)} = {line}")
        return tuple(varnames) if num_values > 1 else varnames[0]

    def getvalue(self, result):
        self.output.writeline(f"return {result}")
        return self.output.getvalue()


# Use mypy to check protocol implemented correctly
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
    return h


class WrapperHandler(Generic[T]):
    """Base class for handlers that delegate every op to an inner handler."""

    def __init__(self, inner: OpsHandler[T]):
        self._inner = inner

    def __getattr__(self, item):
        return getattr(self._inner, item)


# Use mypy to check protocol implemented correctly
def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]:
    return h


class AddParenHandler(WrapperHandler[T]):
    """Delegating handler that wraps every produced value in parentheses."""

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            val = getattr(self._inner, name)(*args, **kwargs)
            return f"({val})"

        return inner


# Use mypy to check protocol implemented correctly
def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]:
    return h


class OpCountResult(NamedTuple):
    num_ops: int
    used_ops: OrderedSet[str]
    read_buffers: List[str]
    nontrivial_read_count: int


class OpCounterCSE:
    """Shim to count how many ops are used"""

    def __init__(self, inner):
        super().__init__()
        self.parent_handler = inner
        self.op_count = 0
        # maps a produced value to its tmp variable name; duplicates share
        # a name and are therefore counted only once (the "CSE" part)
        self.var_names = {}
        self._used_ops: OrderedSet[str] = OrderedSet()
        self._read_names: List[str] = []
        self._nontrivial_read_count = 0

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            return pytree.tree_map(
                self._update_count, getattr(self.parent_handler, name)(*args, **kwargs)
            )

        self._used_ops.add(name)
        return inner

    def _update_count(self, val):
        varname = self.var_names.get(val)
        if not varname:
            varname = f"tmp{self.op_count}"
            self.op_count += 1
            self.var_names[val] = varname
        return varname

    def indirect_indexing(self, *args, **kwargs):
        self._used_ops.add("indirect_indexing")
        return self.parent_handler.indirect_indexing(*args, **kwargs)

    def load(self, name: str, index: sympy.Expr) -> str:
        val = self.parent_handler.load(name, index)
        if val not in self.var_names:
            self._used_ops.add("load")
            self._read_names.append(name)
            # a load at a non-constant index is a "nontrivial" read
            if not isinstance(index, (sympy.Integer, int)):
                self._nontrivial_read_count += 1
        return self._update_count(val)

    def load_seed(self, name: str, offset: T):
        val = self.parent_handler.load_seed(name, offset)
        if val not in self.var_names:
            self._used_ops.add("load_seed")
            self._read_names.append(name)
        return self._update_count(val)

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        val = self.parent_handler.bucketize(
            values, offsets_name, offsets_size, indexing_dtype, right
        )
        if val not in self.var_names:
            self._used_ops.add("bucketize")
            self._read_names.append(offsets_name)
        return self._update_count(val)

    def getvalue(self):
        return OpCountResult(
            self.op_count, self._used_ops, self._read_names, self._nontrivial_read_count
        )


def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
    return h


class ExtractConstantsHandler(NoopHandler):
    """NoopHandler that additionally materializes ops.constant as ir.Constant."""

    def __init__(self, device):
        self.device = device

    def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant":
        from torch._inductor import ir

        return ir.Constant(value=value, dtype=dtype, device=self.device)


def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) ->
OpsHandler[Any]:
    return h


class SimpleCSEHandler(WrapperHandler[T]):
    """Wraps the underlying handler with a CSE pass

    NOTE: Compared to codegen level CSE this is simplified as it
    doesn't support stores which require load cache invalidation.
    """

    def __init__(self, inner: OpsHandler[T]):
        super().__init__(inner)
        # keyed by the MockHandler's string rendering of each op call
        self.cse_cache: Dict[str, Union[T, Tuple[T, ...]]] = {}
        self.mock = MockHandler()

    def indirect_indexing(self, *args, **kwargs) -> sympy.Expr:
        return super().indirect_indexing(*args, **kwargs)  # type: ignore[misc]

    def store(self, *args, **kwargs) -> T:
        raise NotImplementedError("store not implemented")

    def store_reduction(self, *args, **kwargs) -> T:
        raise NotImplementedError("store not implemented")

    def __getattr__(self, name) -> Callable[..., Any]:
        def inner(*args, **kwargs):
            key = getattr(self.mock, name)(*args, **kwargs)
            val = self.cse_cache.get(key)
            if val is not None:
                return val

            val = getattr(self._inner, name)(*args, **kwargs)
            self.cse_cache[key] = val
            return val

        return inner


def _typecheck_SimpleCSEHandler(h: SimpleCSEHandler[Any]) -> OpsHandler[Any]:
    return h
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py b/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py
new file mode 100644
index 0000000000000000000000000000000000000000..96bf8641f3c9a62b3c61fe769132717ef493cf7f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py
@@ -0,0 +1,120 @@
# mypy: allow-untyped-defs
import math

import sympy

import torch
from torch.utils._sympy.value_ranges import ValueRanges

from .loop_body import LoopBody
from .utils import dominated_nodes


def val_expressable_in_32_bits(val):
    """Return True if the scalar 'val' can be represented exactly in 32 bits
    (int32 for ints/bools, float32 mantissa range for floats)."""
    if getattr(val, "is_Boolean", False):
        return True

    if isinstance(val, sympy.Expr):
        assert val.is_number
        if val.is_Integer or val.is_Boolean:
            val = int(val)
        else:
            val = float(val)

    # bound within mantissa: 2**24 is the largest magnitude at which every
    # integer is exactly representable in float32
    if isinstance(val, float):
        return val <= (2**24) and val >= -(2**24)

    if isinstance(val, int):
        iinfo = torch.iinfo(torch.int32)
        return val <= iinfo.max and val >= iinfo.min

    raise TypeError(f"Unexpected value {val}")


def range_expressable_in_32_bits(range):
    # a ValueRanges interval fits in 32 bits iff both endpoints do
    return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits(
        range.upper
    )


def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals):
    """Downcast a to_dtype(int64) node to int32 in place, if value-range
    analysis shows every dominated use (including any indirect indexing it
    feeds) stays within int32 bounds. Returns early (no mutation) otherwise."""

    # if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
    # then its precision is set for that chain of uses, and we don't need to consider those
    # dominated values
    def skip_filter(node):
        return node.target == "to_dtype" and node.args[2] in (
            torch.int32,
            torch.float32,
            torch.float64,
        )

    # TODO - there are dominated uses whose dtype does not depend on whether
    # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
    # int32 without changing the output precision of the node. this case hasn't shown up
    for dominated in dominated_nodes([node], skip_filter):
        if dominated.target in ["store", "output"]:
            continue

        if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
            idx = int(dominated.target[len("set_indirect") :])
            indirect_var = indirect_vars[idx]

            # We check that we can compute all the indices it's involved in with int32
            for index, expr in indices.items():
                if indirect_var in expr.free_symbols:
                    index_val = replacement_vals[index]

                    if math.isinf(index_val.lower) or math.isinf(index_val.upper):
                        return

                    # all indices are integers, so make sure that we
                    # use the bounds of integers instead of floats.
                    # TODO - not sure if we should be doing int/float casts while tracing,
                    # might interfere with sympy.

                    index_val_int = ValueRanges[sympy.Expr](
                        int(index_val.lower), int(index_val.upper)
                    )
                    if not range_expressable_in_32_bits(index_val_int):
                        return

        if not range_expressable_in_32_bits(bounds[dominated]):
            return

    # all dominated uses fit: rewrite the dtype argument in place
    args = list(node.args)
    args[2] = torch.int32
    node.args = tuple(args)


def indexing_dtype_strength_reduction(loop_body: LoopBody):
    """
    Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
    intermediaries from int64 to int32
    """
    bv = loop_body.bounds()

    int64_dtype_nodes = [
        node
        for node in loop_body.get_nodes()
        if (
            node.target == "to_dtype"
            and node.args[2] == torch.int64
            and node not in bv.unbounded_vars
        )
    ]
    if not int64_dtype_nodes:
        return

    bounds = bv.get_bounds()

    # TODO - if dominated node of one to_dtype is not expressible in int32,
    # we should short circuit another to_dtype node if that node also dominates
    for node in int64_dtype_nodes:
        try_to_reduce_precision(
            node,
            bounds,
            loop_body.indirect_vars,
            loop_body.indexing_exprs,
            bv.replacement_vals,
        )
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d21fd1911dfb6649e38cb63696edf2a47beb18a8
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/pt2_archive_constants.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/pt2_archive_constants.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6f42091884e9062c4e806aff1e5da6bfbc5e4c1
Binary files /dev/null and
b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/pt2_archive_constants.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py b/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..564a9b4ccfd8337e13c52c005c6ba5149fd85825 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py @@ -0,0 +1,3727 @@ +# mypy: disallow-untyped-defs +from __future__ import annotations + +import collections +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import pprint +import textwrap +import traceback +import typing +from typing import ( + Any, + Callable, + Counter, + DefaultDict, + Dict, + Generic, + List, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) + +import sympy + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.symbol import free_symbol_is_type, SymT +from torch.utils._triton import has_triton + +from . 
import comms, config, dependencies, ir, metrics +from .codecache import write_text +from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel +from .comm_analysis import estimate_nccl_collective_runtime +from .dependencies import Dep, MemoryDep, StarDep, WeakDep +from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout +from .loop_body import LoopBody +from .runtime.runtime_utils import green_text, red_text +from .sizevars import SimplifyIndexing +from .utils import ( + cache_on_self, + cmp, + device_need_guard, + get_device_tflops, + get_dtype_size, + get_gpu_dram_gbps, + IndentedBuffer, + is_collective, + is_gpu, + is_wait, + sympy_product, +) +from .virtualized import V + + +log = logging.getLogger(__name__) +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering") + + +@dataclasses.dataclass +class SchedulerBuffer: + scheduler: Scheduler + node: ir.Buffer + defining_op: BaseSchedulerNode + users: List[NodeUser] = dataclasses.field(default_factory=list) + + def __hash__(self) -> int: + return hash(self.node.name) + + def debug_str(self) -> str: + result = IndentedBuffer() + name = self.get_name() + result.writeline(f"{name}: {type(self.node).__name__}") + result.writeline(f"{name}.layout = {self.node.layout}") + if self.get_aliases(): + result.writeline(f"{name}.aliases = {pformat(self.get_aliases())}") + if self.get_mutations(): + result.writeline(f"{name}.mutations = {pformat(self.get_mutations())}") + + if len(self.users) <= 1: + result.writeline(f"{name}.users = {self.users}") + else: + result.writeline(f"{name}.users = [") + with result.indent(1): + for user in self.users: + result.writeline(f"{user},") + result.writeline("]") + return result.getrawvalue() + + def get_name(self) -> str: + return self.node.get_name() + + def allocate(self) -> None: + assert self.node is not None + if not self.node.should_allocate(): + return + + if 
self.node.get_inputs_that_alias_output() or self.node.get_mutation_names(): + V.graph.wrapper_code.codegen_allocation(self.node) + return + + # hacky check for if V.kernel is a real kernel or NullHandler + if ( + hasattr(V.kernel, "args") + and self.get_name() in V.kernel.inplace_update_buffers + ): + V.graph.wrapper_code.codegen_inplace_reuse( + self.scheduler.name_to_buf[ + V.kernel.inplace_update_buffers[self.get_name()] + ].node, + self.node, + ) + else: + V.graph.wrapper_code.codegen_allocation(self.node) + + def can_free(self) -> bool: + # There's no real allocated buffer, no need to free it + assert self.node is not None + if isinstance(self.node.layout, ir.NoneLayout): + return False + for use in self.users: + if isinstance(use.node, OutputNode): + return False + return True + + def set_users(self, users: List[NodeUser]) -> None: + # deduplicate + result: Dict[int, NodeUser] = {} + for use in users: + if id(use.node) in result: + result[id(use.node)] = use.merge(result[id(use.node)]) + else: + result[id(use.node)] = use + self.users = list(result.values()) + + def get_aliases(self) -> Sequence[str]: + assert self.node is not None + return self.node.get_inputs_that_alias_output() + + def get_mutations(self) -> List[str]: + assert self.node is not None + return self.node.get_mutation_names() + + +class BaseSchedulerNode: + group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] + read_writes: dependencies.ReadWrites + unmet_dependencies: OrderedSet[Dep] + # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode. + # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node + # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3. + # For non-"grouped" nodes (i.e. regular SchedulerNode), + # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`. 
+ min_order: int + max_order: int + + def __init__(self, scheduler: Scheduler) -> None: + self.scheduler: Scheduler = scheduler + + def _init_from_node(self, node: ir.Operation) -> None: + self.node: Optional[ir.Operation] = node + self.ancestors: OrderedSet[str] = OrderedSet() + self.last_usage: OrderedSet[ + str + ] = OrderedSet() # buffers that won't be used after this kernel + self.written = False + self.outputs: List[SchedulerBuffer] = [ + SchedulerBuffer( + scheduler=self.scheduler, + node=output, + defining_op=self, + ) + for output in node.get_outputs() + ] + self.outputs_by_name: Dict[str, SchedulerBuffer] = { + buf.get_name(): buf for buf in self.outputs + } + + def __repr__(self) -> str: + return f"{type(self).__name__}(name={self.get_name()!r})" + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + buf = IndentedBuffer() + buf.splice( + f"""\ +{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__}) +{name}.writes = {pformat(self.read_writes.writes)} +{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} +{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} +{name}.outputs = [ + """ + ) + with buf.indent(): + for out in self.get_outputs(): + buf.splice(out.debug_str()) + buf.writeline("]") + + try: + buf.splice(self.debug_str_extra()) + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return buf.getrawvalue().rstrip() + + def debug_str_extra(self) -> str: + return "" + + def debug_str_short(self) -> str: + maybe_data = getattr(self.node, "data", None) + data_str = "" + if isinstance(maybe_data, torch._inductor.ir.Pointwise): + data_str = ", " + maybe_data.str_helper( + [maybe_data.get_size()], shorten=False, multiline=False + ) + elif isinstance(maybe_data, torch._inductor.ir.Reduction): + data_str = ", " + maybe_data.str_helper( + [maybe_data.get_reduction_size(), maybe_data.get_reduction_type()], + 
shorten=False, + multiline=False, + ) + return f"{self}{data_str}" + + def log_details(self) -> None: + log.info( + "%s: unmet_dependencies = %s, writes = %s", + self, + self.unmet_dependencies, + self.read_writes.writes, + ) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + return + + def update_mutated_names(self, renames: Dict[str, str]) -> None: + self.set_read_writes(self.read_writes.rename(renames)) + + def add_fake_dep(self, dep: Dep) -> None: + self.set_read_writes(self.read_writes.with_read(dep)) + + def has_aliasing_or_mutation(self) -> bool: + return any( + buf.get_aliases() or buf.get_mutations() for buf in self.get_outputs() + ) + + def set_read_writes(self, rw: dependencies.ReadWrites) -> None: + self.read_writes = rw + self.unmet_dependencies = self.read_writes.reads + self.prune_deps() + + def set_last_usage( + self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str] + ) -> None: + used_buffers = self.used_or_aliased_buffer_names() + used_buffers = OrderedSet([mutation_real_name.get(k, k) for k in used_buffers]) + self.last_usage = used_buffers - future_used_buffers + + def mark_run(self) -> None: + for buf in self.outputs: + buf.allocate() + + def used_buffer_names(self) -> OrderedSet[str]: + return OrderedSet( + dep.name + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) + ) + + def used_or_aliased_buffer_names(self) -> OrderedSet[str]: + used_names: OrderedSet[str] = OrderedSet() + + deps = [ + dep.name + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) + ] + while len(deps) > 0: + dep = deps.pop() + used_names.add(dep) + if V.graph.name_to_buffer.get(dep): + for alias in V.graph.name_to_buffer[dep].get_inputs_that_alias_output(): + if alias not in used_names: + deps.append(alias) + return used_names + + def prune_deps(self) -> None: + self.unmet_dependencies = OrderedSet( + dep + for dep in self.unmet_dependencies + if 
dep.name not in self.scheduler.available_buffer_names + ) + + def prune_weak_deps(self) -> None: + # Prune weak dependencies on operations that have been removed + def should_prune(dep: Dep) -> bool: + if not isinstance(dep, WeakDep): + return False + op = self.scheduler.name_to_buf[dep.name].defining_op + return op.get_name() in V.graph.removed_operations + + to_remove = OrderedSet( + dep for dep in self.read_writes.reads if should_prune(dep) + ) + self.set_read_writes(self.read_writes.remove_reads(to_remove)) + + def prune_redundant_deps( + self, name_to_fused_node: Dict[str, BaseSchedulerNode] + ) -> None: + _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) + + def get_name(self) -> str: + assert self.node is not None + return self.node.get_operation_name() + + def get_first_name(self) -> str: + return self.get_name() + + def get_operation_names(self) -> OrderedSet[str]: + return OrderedSet(node.get_name() for node in self.get_nodes()) + + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet(out.get_name() for out in self.outputs) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return [self] + + def get_outputs(self) -> Sequence[SchedulerBuffer]: + return self.outputs + + def get_output(self, buf_name: str) -> SchedulerBuffer: + return self.outputs_by_name[buf_name] + + def get_device(self) -> torch.device: + assert self.node is not None + return self.node.get_device() + + def is_reduction(self) -> bool: + return False + + def is_split_scan(self) -> bool: + return False + + def is_template(self) -> bool: + return False + + def is_extern(self) -> bool: + return False + + def is_foreach(self) -> bool: + return False + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + return False + + def has_side_effects(self) -> bool: + return False + + def decide_inplace_update(self) -> None: + """ + Decide if there should be inplace updates for the node + and record the decision in the active kernel. 
+ """ + from .codegen.wrapper import buffer_reuse_key + + if not ( + isinstance(self, (SchedulerNode,)) + and config.inplace_buffers + and V.graph.has_feature(self.get_device(), BackendFeature.INPLACE_BUFFERS) + and ( + not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel) + or getattr(V.kernel, "mutations", None) is not None + ) + # hacky check for if V.kernel is a real kernel or NullHandler + and hasattr(V.kernel, "args") + ): + return + + ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name) + + for buf in self.get_outputs(): + buf_node = buf.node + assert buf_node is not None + if ( + not buf_node.should_allocate() + or buf_node.get_inputs_that_alias_output() + or buf_node.get_mutation_names() + or buf.get_name() in V.graph.removed_buffers + ): + continue + + for read in ordered_reads: + input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get( + read.name + ) + if ( + input_buf + and V.graph.wrapper_code.can_reuse(input_buf, self) + and not isinstance(input_buf.defining_op, NopKernelSchedulerNode) + ): + assert input_buf.users is not None + remaining_uses = [ + x + for x in input_buf.users + if x.node.get_name() not in self.scheduler.completed_operations + ] + if ( + len(remaining_uses) == 1 + and remaining_uses[0].can_inplace + and remaining_uses[0].node is self + and input_buf.node is not None + and not isinstance( + input_buf.node.get_layout(), + ( + ir.MultiOutputLayout, + ir.MutationLayoutSHOULDREMOVE, + ), + ) + and not ( + isinstance( + input_buf.defining_op.node, + (ir.FallbackKernel, ir.MultiOutput), + ) + and len(input_buf.node.get_inputs_that_alias_output()) > 0 + ) + and buffer_reuse_key(input_buf.node) + == buffer_reuse_key(buf.node) + ): + # if there isn't a triton kernel, then we don't need to call triton-specific things. + # but TODO this might be a convenient place to signal to the Collective kernels to inplace + # (and, can we make "kernel" less generic of a name?) 
+ V.kernel.args.make_inplace(input_buf.get_name(), buf.get_name()) + # mutations not tracked in cpp kernels + if isinstance( + V.kernel, torch._inductor.codegen.simd.SIMDKernel + ): + V.kernel.mutations.add(input_buf.get_name()) + V.kernel.mutations.add(buf.get_name()) + + # update last usage of reused node + self.last_usage.discard(input_buf.get_name()) + + V.kernel.inplace_update_buffers[ + buf.get_name() + ] = input_buf.get_name() + break + + def codegen_originating_info( + self, buffer: IndentedBuffer, only_once: bool = True + ) -> None: + if not config.comment_origin: + return + + if only_once and self.written: + return + assert self.node is not None + origins = self.node.get_origins() + out_lines = [] + + for o in origins: + if o.op == "output": + # These are boring and samey + continue + + out_lines.append("") + # TODO(voz): Should the pragma be constant somewhere? + out_lines.append("#pragma CMT ORIGIN:") + op_info_str = f"#pragma CMT {o.op} {o.target}" + if "seq_nr" in o.meta: + op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}" + out_lines.append(op_info_str) + if "stack_trace" in o.meta: + stack_trace = f"{o.meta['stack_trace']}" + stack_trace_last_line = stack_trace.split("|")[-1] + out_lines.append( + "#pragma CMT " + + stack_trace_last_line.replace("{", "{{") + .replace("}", "}}") + .replace("\n", "\\") + ) + out_lines.append("#pragma CMT END ORIGIN") + out_lines.append("") + + if len(out_lines) == 0: + return + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + buffer.writelines(out_lines) + self.written = True + + def get_read_write_buffers_sizes(self) -> int: + """ + Counting the number of bytes accessed for a kernel is + surprisingly tricky. In particular, there is a differentiation + between 'theoretical' memory accesses and practical memory + accesses. 
For example, a layernorm kernel may actually access an + input 3 times, but in theory, it only needs to access its input + once (and may be optimized to do so through say, persistent + reductions) + + Another example is that even though a buffer is passed in, we may + not access the entire buffer. This may occur if we are accessing + a slice of the buffer. Another tricky case is for indirect + indexing, where the amount of bytes accessed depends on the + values of the input. + + What this function aims to compute is the memory accesses for + worst-case inputs, best-case optimization. What this means is + that for each buffer we compute the amount of potential accesses in two ways and take the minimum. + + 1. Numel in ranges multiplied by number of deps the buffer has + 2. The buffer size + """ + if isinstance(self, NopKernelSchedulerNode): + return 0 + if isinstance(self, ExternKernelSchedulerNode) and isinstance( + self.node, MultiOutput + ): + # todo: Calculate this - it's kinda annoying. 
+ return 0 + + def try_size_hint(s: sympy.Expr) -> int: + return V.graph.sizevars.size_hint(s, fallback=0) + + if isinstance(self, SchedulerNode): + node_numel = try_size_hint( + sympy_product(self.get_ranges()[0]) + * sympy_product(self.get_ranges()[1]), + ) + else: + node_numel = int(1e9) + buf_accesses = collections.defaultdict(list) + for dep in self.read_writes.reads | self.read_writes.writes: + buf_accesses[dep.name].append(dep) + + reads = OrderedSet(dep.name for dep in self.read_writes.reads) + writes = OrderedSet(dep.name for dep in self.read_writes.writes) + + def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool: + users = self.scheduler.name_to_buf[buf].users + buf_uses = OrderedSet(user.node for user in users) + return len(buf_uses - OrderedSet(snodes)) > 0 + + if isinstance(self, FusedSchedulerNode): + removed_buffers = OrderedSet( + dep for dep in writes if not is_materialized(dep, self.snodes) + ) + writes = writes - removed_buffers + reads = reads - removed_buffers + node_bytes = 0 + + for buf_name in reads | writes: + buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name]) + buf: Union[ir.Buffer, ir.TensorBox] + if buf_name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[buf_name] + elif buf_name in V.graph.graph_inputs: + buf = V.graph.graph_inputs[buf_name] + else: + continue + + def get_buf_bytes(buf: Optional[Union[ir.Buffer, ir.TensorBox]]) -> int: + if not buf: + return 0 + # Kind of a lazy way to get the MultiOutput nodes corresponding to + # a MultiOutputLayout + if isinstance(buf.layout, MultiOutputLayout): + users = self.scheduler.name_to_buf[buf.get_name()].users + tot = 0 + for user in users: + assert isinstance(user.node, BaseSchedulerNode) + if isinstance(user.node.node, MultiOutput): + for sched_buf in user.node.get_outputs(): + tot += get_buf_bytes(sched_buf.node) + else: + # Buf is a MultiOutputLayout but not all of its + # users are MultiOutputs... 
+ # TODO: Figure out what's going on + return 0 + return tot + elif isinstance(buf.layout, ir.NoneLayout): + return sum( + get_buf_bytes(V.graph.get_buffer(mut_name)) + for mut_name in buf.get_mutation_names() + ) + else: + buf_elems = try_size_hint(sympy_product(buf.get_size())) + return get_dtype_size(buf.get_dtype()) * min( + buf_accessed_elems, buf_elems + ) + + node_bytes += get_buf_bytes(buf) + + return node_bytes + + def get_estimated_runtime(self) -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + buf = self.get_nodes()[0].get_outputs()[0] + layout = buf.node.get_layout() + dtype = buf.node.get_dtype() + + if layout.device is not None and not is_gpu(layout.device.type): + # default to no reordering based on runtime + return 0 + + # Collective kernels + if is_collective(self.node): + assert isinstance(self.node, ir.IRNode) + try: + return estimate_nccl_collective_runtime(self.node) + except ValueError as e: + # We don't know how to estimate runtime for this collective, + # falling back to 0 + log.info(e) + return 0 + + elif is_wait(self.node): + # ir.Wait is only used for collective ops. + # The time needed for the collective op is already estimated and considered + # when we are processing the collective op IR node, so ir.Wait takes 0 time + # since it doesn't take extra time to get the result after the collective is completed. 
+ return 0 + + try: + gpu_memory_bandwidth = get_gpu_dram_gbps() + gpu_flops = get_device_tflops(dtype) * 10**12 + except Exception: + return 0 + + if isinstance(self, ExternKernelSchedulerNode): + assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}" + op = kernel_name_to_op.get( + getattr(self.node, "python_kernel_name", ""), None + ) + + # if there is a resolved op, dry-run using fake mode and record flop count + if op is not None: + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.utils.flop_counter import FlopCounterMode + + if any( + len(free_unbacked_symbols(n.get_numel())) > 0 + for n in self.node.inputs + ): + # Tensor has unbacked symints, we don't know how to estimate + # runtime for that today + return 0 + + with FakeTensorMode() as fake_mode, FlopCounterMode( + display=False + ) as flop_counter_mode, V.set_current_node( + self.node.fx_node + ), V.set_fake_mode( + fake_mode + ): + from .ir import ir_node_to_tensor + + fake_inputs = [ + ir_node_to_tensor(input, guard_shape=False) + for input in self.node.inputs + ] + cls = self.node.__class__ + cls.process_kernel(op, *fake_inputs, **self.node.kwargs) + + # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship + factor = 1.0 + counted_flops = flop_counter_mode.get_total_flops() + counted_bytes = self.get_read_write_buffers_sizes() + compute_time = (factor * counted_flops / gpu_flops) * 1e9 + transfer_time = counted_bytes / gpu_memory_bandwidth + + # Return estimated runtime in nanoseconds + return max(compute_time, transfer_time) + + elif isinstance(self, FusedSchedulerNode) or isinstance( + self.node, ComputedBuffer + ): + # Return estimated runtime in nanoseconds (bytes / gbps) + return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth + + return 0 + + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return None + + +class WhyNoFuse: + # TODO when we drop support for Python < 3.10, we can use + # @dataclass(slots=True) instead 
of manually specifying __slots__. + __slots__ = ["node1", "node2", "reason", "args"] + reason: str + args: Tuple[Any, ...] + + def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None: + self.node1 = node1 + self.node2 = node2 + + def __call__(self, reason: str, *args: Any) -> None: + self.reason = reason + self.args = args + fusion_log.debug(self) + + def __str__(self) -> str: + return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( + self.reason % self.args + ) + + +def pformat(obj: Any) -> str: + if isinstance(obj, OrderedSet): + # pformat has trouble with sets of sympy exprs + obj = sorted(obj, key=str) + result = pprint.pformat(obj, indent=4) + if "\n" in result: + return f"\n{textwrap.indent(result, ' ' * 4)}" + return result + + +class OutputNode: + def __init__(self, dep: StarDep) -> None: + self.unmet_dependencies = OrderedSet([dep]) + + def is_reduction(self) -> bool: + return False + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return () + + def get_name(self) -> str: + return "OUTPUT" + + __repr__ = get_name + + +def _prune_redundant_deps( + node: BaseSchedulerNode, + name_to_fused_node: Dict[str, BaseSchedulerNode], + name_to_buf: Dict[str, SchedulerBuffer], +) -> None: + """ + Prunes weakdeps intended for mutation ordering + on an upstream fused node if after fusion there is another dependency + on the fused upstream node, making the weakdep redundant + + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will + be incrementally removed, enabling other fusions, ensuring they are fused in order. 
+ """ + name_to_dep_count: Counter[str] = collections.Counter() + + for dep in node.unmet_dependencies: + if not isinstance(dep, WeakDep): + op = name_to_buf[dep.name].defining_op + name_to_dep_count[name_to_fused_node[op.get_name()].get_name()] += 1 + + def should_prune(dep: Dep) -> bool: + if isinstance(dep, WeakDep): + op_name = name_to_buf[dep.name].defining_op.get_name() + is_redundant = name_to_dep_count[name_to_fused_node[op_name].get_name()] > 0 + # These can occur because fused nodes always gather deps from their snodes + # If B has a weakdep on A + # B gets fused with C, then any time BC is fused, the weakdep will reappear + is_self_dep = name_to_fused_node[op_name] == node + return is_redundant or is_self_dep + else: + return False + + deps_to_prune = OrderedSet( + dep for dep in node.unmet_dependencies if should_prune(dep) + ) + + if deps_to_prune: + node.unmet_dependencies = node.unmet_dependencies - deps_to_prune + node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) + + +# TODO(xmfan): reuse: an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel +kernel_name_to_op = { + "extern_kernels.convolution": torch.ops.aten.convolution, + "extern_kernels.mm": torch.ops.aten.mm, + "extern_kernels.bmm": torch.ops.aten.bmm, + "extern_kernels.addmm": torch.ops.aten.addmm, +} + + +class ExternKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + + def debug_str_extra(self) -> str: + return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" + + def is_extern(self) -> bool: + return True + + def has_side_effects(self) -> bool: + assert self.node is not None + return hasattr(self.node, "has_side_effects") and self.node.has_side_effects() + + +class NopKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: 
Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + + +class SchedulerNode(BaseSchedulerNode): + def __init__( + self, + scheduler: Scheduler, + node: Union[ir.ComputedBuffer, ir.TemplateBuffer], + ) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self._compute_attrs() + + def _compute_attrs( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> None: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) + self._sizes, self._body = self.node.simplify_and_reorder( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) + + group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn + self.group = (self.node.get_device(), group_fn(self._sizes)) + + # Don't normalize since normalization will merge loops which + # makes it hard to decide new loop orders. + should_normalize = ( + not config.loop_ordering_after_fusion + or self.node.get_device().type != "cuda" + ) + + if isinstance(self.node, ir.TemplateBuffer): + self.set_read_writes( + self.node.extract_read_writes(normalize=should_normalize) + ) + else: + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=should_normalize + ) + ) + + def recompute_size_and_body( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> None: + self._compute_attrs( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) + + def refresh_dependencies(self, normalize: bool) -> None: + # Fake dependencies are added manually. They can not be analyzed from + # extract_read_writes. Find them out and apply manually. 
+ fake_deps = { + dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep)) + } + + # don't normalize since the loop order may need to be further changed + # later + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=normalize + ).with_read(fake_deps) + ) + + def apply_new_loop_order(self, new_order: Sequence[int]) -> None: + self._body = self._body.reorder_iter_loops( + new_order, + ) + self._sizes = self._body.sizes + + self.refresh_dependencies(normalize=False) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + new_order = None + self_sizes = self._sizes[0] + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if new_order: + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for %s with order %s", self.get_name(), new_order + ) + self.apply_new_loop_order(new_order) + else: + loop_ordering_log.debug( + "Don't reordering %s because we can not decide the suitable loop order", + self.get_name(), + ) + + def debug_str_extra(self) -> str: + name = self.get_name() + lines = [ + f"{name}.group.device = {self.group[0]}", + f"{name}.group.iteration = {self.group[1]}", + f"{name}.sizes = {self._sizes}", + ] + for dep in self.read_writes.reads_and_writes(): + if not isinstance(dep, WeakDep): + buf_name = dep.name + buf = V.graph.get_buffer(buf_name) + lines.append(f"{buf_name}_layout = {pformat(buf.layout)}") + if isinstance(self._body, LoopBody): + lines.append(f"class {name}_loop_body:") + lines.append(textwrap.indent(self._body.debug_str(), " ")) + + assert self.node is not None + if ir.is_triton(self.node.get_device()): + lines.extend(debug_triton_code(self)) + + return "\n".join(lines) + + def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]: + return self._sizes + + def is_reduction(self) -> bool: + assert isinstance( + self.node, (ir.ComputedBuffer, 
ir.TemplateBuffer) + ), f"{type(self.node)=}" + return bool(self.node.get_reduction_type()) + + def is_split_scan(self) -> bool: + assert isinstance( + self.node, (ir.ComputedBuffer, ir.TemplateBuffer) + ), f"{type(self.node)=}" + return isinstance(self.node, ir.ComputedBuffer) and isinstance( + self.node.data, ir.SplitScan + ) + + def is_template(self) -> bool: + return isinstance(self.node, ir.TemplateBuffer) + + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return self.node if isinstance(self.node, ir.TemplateBuffer) else None + + def run(self, *index_vars: Sequence[sympy.Expr]) -> None: + self.decide_inplace_update() + self.mark_run() + self.codegen(index_vars) + + def ranges_from_index_vars( + self, index_vars: Sequence[Sequence[sympy.Expr]] + ) -> Dict[sympy.Expr, sympy.Expr]: + sizes = self._sizes + assert sum(map(len, sizes)) == sum(map(len, index_vars)) + var_ranges = dict( + zip( + itertools.chain.from_iterable(index_vars), + itertools.chain.from_iterable(sizes), + ) + ) + return var_ranges + + def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: + var_ranges = self.ranges_from_index_vars(index_vars) + try: + with V.set_ops_handler( + SimplifyIndexing(V.get_ops_handler(), var_ranges) + ), V.kernel.set_current_node(self): + self._body(*index_vars) + except Exception: + log.fatal("Error in codegen for %s", self.node) + raise + + @cache_on_self + def pointwise_read_writes(self) -> dependencies.ReadWrites: + """ + Get the memory dependencies in the non-reduction axis. 
+ """ + sizes, reduction_sizes = self._sizes + return dependencies.extract_read_writes( + self._body, sizes, hidden_args=[[sympy.Integer(0)] * len(reduction_sizes)] + ) + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + if self.is_template(): + return False + if any(out.get_aliases() for out in self.get_outputs()): + return False + if len(self.read_writes.writes) == 1 and isinstance( + read_dep, dependencies.MemoryDep + ): + write_dep = next(iter(self.read_writes.writes)) + assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}" + return read_dep.index == write_dep.index and read_dep.size == write_dep.size + return False + + @cache_on_self + def _get_atomic_add_buffers(self) -> OrderedSet[str]: + buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet() + if isinstance(self._body, LoopBody): + for node in self._body.get_nodes(): + if ( + node.op == "call_method" + and node.target == "store" + and ( + ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add") + or (len(node.args) == 5 and node.args[4] == "atomic_add") + ) + ): + buffers_store_as_atomic_add.add( + node.kwargs["name"] + if "name" in node.kwargs + else (node.args[1] if len(node.args) >= 2 else "") + ) + return buffers_store_as_atomic_add + + +def refresh_group_node_dependencies(group_snode: BaseSchedulerNode) -> None: + snodes = group_snode.snodes # type: ignore[attr-defined] + group_snode.set_read_writes( + dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) + ) + + group_snode.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes]) + if dep.name not in group_snode.get_buffer_names() + ) + - group_snode.read_writes.writes + ) + + +def init_group_node( + group_snode: BaseSchedulerNode, + scheduler: Scheduler, + snodes: List[BaseSchedulerNode], +) -> None: + assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode)) + group_snode.snodes = snodes + group_snode.scheduler = 
scheduler + group_snode.node = None + group_snode.ancestors = OrderedSet.union( + *[x.ancestors for x in snodes if x.ancestors is not None] + ) + + refresh_group_node_dependencies(group_snode) + + group_snode.min_order = min(x.min_order for x in group_snode.snodes) + group_snode.max_order = max(x.max_order for x in group_snode.snodes) + group_snode.outputs_by_name = { + buf.get_name(): buf for buf in group_snode.get_outputs() + } + + +class FusedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be fused together. The way it does this is by maintaining + its unmet dependencies as the union of its constituent nodes. + """ + + snodes: List[BaseSchedulerNode] + + @classmethod + def fuse( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> FusedSchedulerNode: + assert node1.scheduler is node2.scheduler + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes())) + return cls(node1.scheduler, nodes) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + if self.is_template(): + # We can not really reorder loops for a triton template + return + self_sizes = None + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + if self_sizes is not None and self_sizes != snode._sizes[0]: + loop_ordering_log.debug( + "Can not reorder fused node due to different sizes" + ) + return + self_sizes = snode._sizes[0] + new_order = None + + assert self_sizes is not None + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if not new_order: + loop_ordering_log.debug( + "Dont reordering fused node %s because we can not decide the suitable loop order", + self.get_name(), + ) + return + metrics.num_loop_reordering += 1 + 
loop_ordering_log.debug( + "Reorder loops for fused node %s with order %s", self.get_name(), new_order + ) + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + snode.apply_new_loop_order(new_order) # type: ignore[arg-type] + + refresh_group_node_dependencies(self) + + def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: + super().__init__(scheduler) + init_group_node(self, scheduler, snodes) + self.users: List[NodeUser] = [] + self.group = max(snodes, key=lambda x: int(x.is_reduction())).group + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) + + def get_outputs(self) -> List[SchedulerBuffer]: + result: List[SchedulerBuffer] = [] + for node in self.snodes: + result.extend(node.get_outputs()) + return result + + def debug_str_extra(self) -> str: + lines = [ + f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" + for i, node in enumerate(self.snodes) + ] + node = self.snodes[0].node + if node is not None: + device = node.get_device() + if ir.is_triton(device): + lines.extend(debug_triton_code(self)) + + return textwrap.indent("\n".join(lines).rstrip(), " ") + + def debug_str_short(self) -> str: + snodes_str = [node.debug_str_short() for node in self.snodes] + return f"{self}, snodes: {snodes_str}" + + def set_last_usage( + self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str] + ) -> None: + # Set self.last_usage using the global information + # This will be used for inter-kernel optimisations + super().set_last_usage(future_used_buffers, mutation_real_name) + # Set self.last_usage on the snodes + # This will be used for optimisations within the kernel + future_used_buffers: OrderedSet[str] = OrderedSet() + for node in 
reversed(self.snodes): + node.set_last_usage(future_used_buffers, mutation_real_name) + future_used_buffers.update(node.last_usage) + + @cache_on_self + def used_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.used_buffer_names() for x in self.snodes]) + + @cache_on_self + def used_or_aliased_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union( + *[x.used_or_aliased_buffer_names() for x in self.snodes] + ) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return self.snodes + + def __repr__(self) -> str: + return f"{type(self).__name__}(nodes={self.get_name()})" + + @cache_on_self + def is_reduction(self) -> bool: + return any(x.is_reduction() for x in self.snodes) + + @cache_on_self + def is_split_scan(self) -> bool: + return any(x.is_split_scan() for x in self.snodes) + + @cache_on_self + def is_template(self) -> bool: + return any(x.is_template() for x in self.snodes) + + @cache_on_self + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + for node in self.snodes: + if node.is_template(): + return node.get_template_node() + return None + + def get_device(self) -> torch.device: + return self.group[0] + + @cache_on_self + def has_aliasing_or_mutation(self) -> bool: + return any(x.has_aliasing_or_mutation() for x in self.snodes) + + # None of these need to be implemented, as a FusedSchedulerNode is just an + # abstraction for scheduling purposes + def update_mutated_names(self, renames: Dict[str, str]) -> None: + raise NotImplementedError + + def add_fake_dep(self, name: Dep) -> None: + raise NotImplementedError + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + raise NotImplementedError + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + node_typestr = ",".join(type(n).__name__ for n in self.snodes) + buf = IndentedBuffer() + buf.splice( + f"""\ +{name}: {type(self).__name__}({node_typestr}) +{name}.writes = {pformat(self.read_writes.writes)} 
+{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} +{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} +{name}.outputs = [ + """ + ) + with buf.indent(): + for out in self.get_outputs(): + buf.splice(out.debug_str()) + buf.writeline("]") + + try: + buf.splice(self.debug_str_extra()) + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return buf.getrawvalue().rstrip() + + +class ForeachKernelSchedulerNode(FusedSchedulerNode): + """ + This is a schedular node that consists of a set of scheduler nodes that + has no data dependencies among them and can be executed in parallel. + """ + + def get_consumer_subnode_for( + self, producer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: + for buf in producer.get_outputs(): + if buf.get_name() in self.read_to_node: + return self.read_to_node[buf.get_name()] + + return None + + def get_producer_subnode_for( + self, consumer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: + producers = set() + for rd in consumer.read_writes.reads: + if rd.name not in self.scheduler.name_to_buf: + continue + + node_name = self.scheduler.name_to_buf[rd.name].defining_op.get_name() + if node_name in self.name_to_node: + producers.add(self.name_to_node[node_name]) + + # Don't permit fusion if there are multiple subnodes + # that this consumer reads from + if len(producers) == 1: + return next(iter(producers)) + else: + return None + + @classmethod + def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: + why = WhyNoFuse(producer, consumer) + if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + foreach_match = len(producer.snodes) == len(consumer.snodes) + if not foreach_match: + why("foreach do not have same length") + return foreach_match and all( + producer.scheduler.can_fuse(l, r) + for l, r in 
zip(producer.snodes, consumer.snodes) + ) + elif consumer.is_foreach(): + if producer.is_reduction(): + why( + "candidate producer is a reduction, foreach ops cannot be fused with reductions currently" + ) + return False + + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + consumer_subnode = consumer.get_consumer_subnode_for(producer) + if consumer_subnode is not None: + return consumer.scheduler.can_fuse(producer, consumer_subnode) + + why("candidate producer is not dep of any foreach consumer") + return False + + elif producer.is_foreach(): + if consumer.is_reduction(): + why( + "candidate consumer is a reduction, foreach ops cannot be fused with reductions currently" + ) + return False + + producer = typing.cast(ForeachKernelSchedulerNode, producer) + producer_subnode = producer.get_producer_subnode_for(consumer) + if producer_subnode is not None: + return producer.scheduler.can_fuse(producer_subnode, consumer) + + why("candidate consumer has no dep in any foreach producer") + return False + + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node" + ) + + @classmethod + def fuse( + cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode + ) -> ForeachKernelSchedulerNode: + assert producer.is_foreach() or consumer.is_foreach() + if producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + use_custom_partition_algo = producer.use_custom_partition_algo + enable_autotune = producer.enable_autotune + else: + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + use_custom_partition_algo = consumer.use_custom_partition_algo + enable_autotune = consumer.enable_autotune + prev_node_1 = None + prev_node_2 = None + fused_nodes: List[BaseSchedulerNode] + if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + fused_nodes = [ + 
FusedSchedulerNode.fuse(l, r) + for l, r in zip(producer.snodes, consumer.snodes) + ] + elif producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + producer_subnode = producer.get_producer_subnode_for(consumer) + fused_nodes = [] + prev_node_1 = producer + prev_node_2 = None + for node in producer.snodes: + if node is producer_subnode: + new_node = FusedSchedulerNode.fuse(node, consumer) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + + elif consumer.is_foreach(): + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + consumer_subnode = consumer.get_consumer_subnode_for(producer) + fused_nodes = [] + prev_node_1 = consumer + prev_node_2 = None + + for node in consumer.snodes: + if node is consumer_subnode: + new_node = FusedSchedulerNode.fuse(producer, node) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + else: + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node" + ) + + return cls( + producer.scheduler, + fused_nodes, + use_custom_partition_algo=use_custom_partition_algo, + prev_node_1=prev_node_1, + prev_node_2=prev_node_2, + enable_autotune=enable_autotune, + ) + + def __init__( + self, + scheduler: Scheduler, + snodes: List[BaseSchedulerNode], + use_custom_partition_algo: bool, + prev_node_1: Optional[BaseSchedulerNode] = None, + prev_node_2: Optional[BaseSchedulerNode] = None, + enable_autotune: bool = False, + ) -> None: + self.read_to_node = {} + self.name_to_node = {} + + if prev_node_1 is None or prev_node_2 is None: + super().__init__(scheduler, snodes) + + for node in snodes: + for read in node.read_writes.reads: + self.read_to_node[read.name] = node + + for name in node.get_operation_names(): + self.name_to_node[name] = node + else: + self.scheduler = scheduler + self.snodes = snodes + self.node = None + self.users: List[NodeUser] = [] + + self.set_read_writes( + 
dependencies.ReadWrites.merge_list( + [prev_node_1.read_writes, prev_node_2.read_writes] + ) + ) + + self.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union( + prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies + ) + if dep.name not in self.get_buffer_names() + ) + - self.read_writes.writes + ) + + self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) + self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) + + if prev_node_1.is_foreach(): + assert isinstance(prev_node_1, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_1, prev_node_2 + else: + assert isinstance(prev_node_2, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_2, prev_node_1 + + self.ancestors = foreach_node.ancestors + self.ancestors.update(other_node.ancestors) + + self.name_to_node = foreach_node.name_to_node + for name in other_node.get_operation_names(): + self.name_to_node[name] = other_node + + self.use_custom_partition_algo = use_custom_partition_algo + self.group = (snodes[0].get_device(), ((sympy.Expr("combo_kernel"),),)) + self.origins: OrderedSet[torch.fx.Node] = OrderedSet() + self.enable_autotune = enable_autotune + + @classmethod + def combinable_nodes( + cls, nodes: List[BaseSchedulerNode] + ) -> List[BaseSchedulerNode]: + extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)] + if extern: + log.debug( + "ComboKernels: %d external nodes are filtered %s", + len(extern), + [node.node.get_origins() for node in extern if node.node is not None], + ) + filtered_nodes = [ + x + for x in nodes + if not isinstance(x, (NopKernelSchedulerNode, ExternKernelSchedulerNode)) + ] + foreach_nodes = [ + x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode) + ] + if foreach_nodes: + log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes)) + filtered_nodes = [ + x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode) + ] + template_nodes 
= [x for x in filtered_nodes if x.is_template()] + if template_nodes: + log.debug( + "ComboKernels: %d template nodes are filtered", {len(template_nodes)} + ) + filtered_nodes = [x for x in filtered_nodes if x not in template_nodes] + return filtered_nodes + + @staticmethod + def _default_group_nodes_for_combo_kernels( + scheduler: Scheduler, + ) -> List[List[BaseSchedulerNode]]: + """ + Returns a list of lists of nodes that are to be grouped together. + """ + sorted_nodes = scheduler._topological_sort_nodes() + grouped_nodes = [] + max_num_nodes = 8 + for nodes in sorted_nodes: + grouped_nodes.extend( + [ + nodes[i : i + max_num_nodes] + for i in range(0, len(nodes), max_num_nodes) + ] + ) + + return grouped_nodes + + group_algorithm_for_combo_kernels: Callable[ + [Scheduler], List[List[BaseSchedulerNode]] + ] = _default_group_nodes_for_combo_kernels + + @staticmethod + def set_group_algorithm_for_combo_kernels( + custom_group_algorithm: Callable[[Scheduler], List[List[BaseSchedulerNode]]] + ) -> None: + ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = ( + custom_group_algorithm + ) + + @staticmethod + def group_nodes_for_combo_kernels( + scheduler: Scheduler, + ) -> List[List[BaseSchedulerNode]]: + return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler) + + def mark_run(self) -> None: + raise NotImplementedError + + def codegen(self) -> None: + assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}" + self.node.get_store_function()(self.node.make_loader()()) + + def is_foreach(self) -> bool: + return True + + def get_subkernel_nodes(self) -> List[BaseSchedulerNode]: + """Returns a list of nodes which comprise the combo kernel. 
+ These nodes may be vertically fused.""" + return list(self.snodes) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + """Returns all nodes contained in this kernel, unpacking fused nodes + into their constituent scheduler nodes.""" + return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) + + def get_first_name(self) -> str: + return self.snodes[0].get_first_name() + + def prune_redundant_deps( + self, name_to_fused_node: Dict[str, BaseSchedulerNode] + ) -> None: + _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) + + for node in self.snodes: + node.prune_redundant_deps(name_to_fused_node) + + +class GroupedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be *grouped* together (it does not allow another node to be scheduled + in between its constituent nodes, nor does it allow another node to fuse into any of its constituent nodes). + The way it does this is by maintaining its unmet dependencies as the union of its constituent nodes. + Fusion will still happen among the nodes within each GroupedSchedulerNode. + At codegen time, this scheduler node will be unpacked and codegen is called on each constituent node. 
+ """ + + snodes: List[BaseSchedulerNode] + + @classmethod + def create(cls, snodes: List[BaseSchedulerNode]) -> GroupedSchedulerNode: + scheduler = snodes[0].scheduler + assert all(node.scheduler is scheduler for node in snodes) + grouped_snode = cls(scheduler, snodes) # type: ignore[arg-type] + for snode in snodes: + scheduler.name_to_fused_node[snode.get_name()] = grouped_snode + scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode + return grouped_snode + + def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: + super().__init__(scheduler) + init_group_node(self, scheduler, snodes) + + def unpack(self) -> List[BaseSchedulerNode]: + """ + Do fusion among nodes within this GroupedSchedulerNode, + and then unpack this GroupedSchedulerNode into regular nodes. + """ + for snode in self.snodes: + self.scheduler.name_to_fused_node[snode.get_name()] = snode + del self.scheduler.name_to_fused_node[self.get_name()] + return self.scheduler.fuse_nodes(self.snodes) + + def add_fake_dep(self, fake_dep: Dep) -> None: + self.set_read_writes(self.read_writes.with_read(fake_dep)) + self.unmet_dependencies.add(fake_dep) + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) + + def get_outputs(self) -> List[SchedulerBuffer]: + result: List[SchedulerBuffer] = [] + for node in self.snodes: + result.extend(node.get_outputs()) + return result + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return self.snodes + + @classmethod + def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: + # GroupedSchedulerNode cannot be fused with another node + return False + + +def pick_loop_order( + stride_lengths: List[List[int]], + sizes: List[sympy.Expr], + 
priority_idx: Tuple[int, ...] = (),
) -> List[int]:
    """
    A heuristic to decide loop iteration orders. This has not been well
    tuned and may be something we should autotune.
    """

    @functools.cmp_to_key
    def index_cmp(a: int, b: int) -> int:
        # Comparator over dimension indices: smaller-stride (more contiguous)
        # dims sort later so they become the innermost loops.
        if sizes[a] == 1 or sizes[b] == 1:
            # 1-sizes don't matter, just move them to the end
            return cmp(sizes[a] == 1, sizes[b] == 1)

        # Take abs, otherwise flipped dimensions are treated as smaller
        # strides than contiguous dims
        stride_len_a = [abs(sl[a]) for sl in stride_lengths]
        stride_len_b = [abs(sl[b]) for sl in stride_lengths]

        # equivalent to
        # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all()
        # vote across all buffers: count how many prefer a before b and
        # vice versa, then go with the majority
        a_first = sum(
            sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b)
        )
        b_first = sum(
            sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b)
        )
        if a_first > b_first:
            return -1
        if b_first > a_first:
            return 1

        # otherwise contiguous
        return cmp(b, a)

    # Default order: highest dimension index outermost.
    order = list(reversed(range(len(stride_lengths[0]))))
    if len(priority_idx) > 0:
        # if we have priority node, only use that node's order
        stride_lengths = [stride_lengths[pi] for pi in priority_idx]
    if config.pick_loop_orders:
        order.sort(key=index_cmp)
    return order


@dataclasses.dataclass
class NodeUser:
    """A single consumer of a buffer, used for dependency tracking."""

    node: Union[BaseSchedulerNode, OutputNode]
    # Whether the user may reuse the producer's buffer in-place.
    can_inplace: bool = False

    # A weak user must be scheduled after a given node, but doesn't actually
    # use the result
    is_weak: bool = False

    def __hash__(self) -> int:
        # Hash/eq are defined on the node's *name* (not identity) plus flags,
        # so NodeUsers dedupe correctly inside sets.
        return hash((self.node.get_name(), self.can_inplace, self.is_weak))

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, NodeUser)
            and self.get_name() == other.get_name()
            and self.can_inplace == other.can_inplace
            and self.is_weak == other.is_weak
        )

    def get_name(self) -> str:
        return self.node.get_name()

    def merge(self, other: NodeUser) -> NodeUser:
        """Combine two users of the same node (both flags AND-ed)."""
        assert self.node is other.node
return NodeUser( + self.node, + self.can_inplace and other.can_inplace, + self.is_weak and other.is_weak, + ) + + +_post_grad_graph_counter = itertools.count() + + +class Scheduler: + __dep_size_hint_cache: Dict[Dep, int] + + def __init__(self, nodes: List[ir.Operation]) -> None: + with dynamo_timed("Scheduler.__init__"): + self._init(nodes) + + def _init(self, nodes: List[ir.Operation]) -> None: + super().__init__() + self.__dep_size_hint_cache = {} + V.graph.scheduler = self + self.backends: Dict[torch.device, BaseScheduling] = {} + self.post_grad_graph_id = next(_post_grad_graph_counter) + + self.completed_operations: OrderedSet[str] = OrderedSet() + self.available_buffer_names = OrderedSet( + [ + *V.graph.graph_inputs.keys(), + *V.graph.constants.keys(), + *V.graph.torchbind_constants.keys(), + ] + ) + + self.nodes = [self.create_scheduler_node(n) for n in nodes] + self.update_zero_dim_cpu_tensor() + # some new constants could have been created above + self.available_buffer_names.update(V.graph.constants.keys()) + for node in self.nodes: + node.prune_deps() + + self.name_to_node: Dict[str, BaseSchedulerNode] = { + n.get_name(): n for n in self.nodes + } + self.name_to_buf: Dict[str, SchedulerBuffer] = { + buf.get_name(): buf for node in self.nodes for buf in node.get_outputs() + } + self.name_to_fused_node: Dict[str, BaseSchedulerNode] = self.name_to_node.copy() + + # mutation_real_name: Maps back to the original name for codegen + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_real_name = {"buf0" : "buf1"} + # all subsequent uses of buf0 become buf1's usage in dependency graph + self.mutation_real_name: Dict[str, str] = {} + + # We handle mutation by renaming modified versions of the same + # buffer in the dependency graph to prevent cycles. 
+ # mutation_renames: tracks the current name for a given buffer + # (changed once per mutation) + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_renames = {"buf1" : "buf0"} + # in codegen we only use buf0, never buf1 + self.mutation_renames: Dict[str, str] = {} + + self.compute_dependencies() + self.nodes = self.topological_sort_schedule(self.nodes) + self.dead_node_elimination() + self.name_to_fused_node = {n.get_name(): n for n in self.nodes} + self.compute_ancestors() + if config.reorder_for_compute_comm_overlap: + self.nodes = comms.decide_global_ordering_of_comms( + self.nodes, + self.name_to_buf, + self.name_to_fused_node, + ) + + metrics.ir_nodes_pre_fusion += len(self.nodes) + V.debug.ir_pre_fusion(self.nodes) + self.num_orig_nodes = len(self.nodes) + self.create_foreach_nodes() + self.nodes = self.topological_sort_schedule(self.nodes) + self.logged_slow_fusion: OrderedSet[Tuple[str, str]] = OrderedSet() + if config._pre_fusion_custom_pass is not None: + self.nodes = config._pre_fusion_custom_pass(self.nodes) + self.nodes = self.fuse_nodes(self.nodes) + self.merge_loops() + self.finalize_multi_template_buffers() + if config.reorder_for_compute_comm_overlap: + self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) + if config.combo_kernels: + self.create_combo_kernel_nodes(num_ck_nodes=None) + self.process_grouped_nodes() + self.compute_last_usage() + V.debug.ir_post_fusion(self.nodes) + V.debug.graph_diagram(self.nodes) + self.debug_draw_graph() + + # used during codegen: + self.current_device: Optional[torch.device] = None + self.buffer_names_to_free: OrderedSet[str] = OrderedSet() + + # fx graph node to the position it appears in the graph + # for debug attribution + self.origin_to_index: Dict[torch.fx.Node, int] = {} + + get_metric_table("graph_stats").add_row( + lambda: { + "graph_id": self.post_grad_graph_id, + "num_nodes_before_fusion": self.num_orig_nodes, + "num_nodes_after_fusion": len(self.nodes), + } + 
) + + def get_current_device_or_throw(self) -> torch.device: + if device := self.current_device: + return device + else: + raise RuntimeError("No current device") + + def debug_draw_graph(self) -> None: + """Generate an image of the graph for debugging""" + if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": + from .debug import draw_buffers + + draw_buffers(self.nodes, print_graph=True) + + def debug_print_nodes(self, label: str) -> None: + if log.isEnabledFor(logging.INFO): + log.info("%s:", label) + for node in self.nodes: + node.log_details() + + def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode: + assert ( + node.get_origins() is not None + ), "All nodes passed to scheduling must have an origin" + if node.is_no_op(): + return NopKernelSchedulerNode(self, node) + elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): + return SchedulerNode(self, node) + elif isinstance(node, ir.ExternKernel): + return ExternKernelSchedulerNode(self, node) + else: + raise NotImplementedError(node) + + def create_foreach_nodes(self) -> None: + removed_node_names: OrderedSet[str] = OrderedSet() + fe_nodes = [] + kept_node_names = self.name_to_fused_node.keys() + + for names in V.graph.lists.values(): + names = [ + name + for name in names + if name in kept_node_names + and not isinstance(self.name_to_node[name], NopKernelSchedulerNode) + ] + if not names: + # All nodes eliminated + continue + + removed_node_names.update(names) + snodes = [self.name_to_node[name] for name in names] + + enable_autotune = config.combo_kernels_autotune > 1 + fe_node = ForeachKernelSchedulerNode( + self, + snodes, + use_custom_partition_algo=False, + enable_autotune=enable_autotune, + ) + + fe_nodes.append(fe_node) + + for name in names: + self.name_to_fused_node[name] = fe_node + + self.nodes = [ + node for node in self.nodes if node.get_name() not in removed_node_names + ] + list(fe_nodes) + + def compute_dependencies(self) -> None: + """ + Create 
dependency edges between nodes, handling aliasing and + mutation properly. + """ + + T = TypeVar("T") + + class DedupList(Generic[T]): + """ + This data structure behaves like a list except it makes sure the + elements remain unique. + Normally one could use a OrderedSet/dict for this purpose however + the list in question gets elements appended as it is being + iterated over which means that we need to keep the list + semantics. + """ + + def __init__( + self, + items: Optional[List[T]] = None, + membership: Optional[OrderedSet[T]] = None, + ) -> None: + self.items = items or [] + self.membership = membership or OrderedSet() + + def append(self, node_user: T) -> None: + if node_user in self.membership: + return + self.items.append(node_user) + self.membership.add(node_user) + + def __add__(self, other: DedupList[T]) -> DedupList[T]: + new_membership = OrderedSet.union(self.membership, other.membership) + new_items = self.items + [ + x for x in other.items if x not in self.membership + ] + return DedupList(new_items, new_membership) + + name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict( + DedupList + ) + + # handle aliasing by using python aliasing in name_to_users + # if foo aliases bar then we will make name_to_users["foo"] point + # to the same python list as name_to_users["bar"] + for node in self.nodes: + for buf1 in node.get_outputs(): + buf1_name = buf1.get_name() + for buf2_name in buf1.get_aliases(): + if buf1_name in name_to_users and buf2_name in name_to_users: + # merge the two + list1 = name_to_users[buf1_name] + list2 = name_to_users[buf2_name] + combined = list1 + list2 + for key in name_to_users.keys(): + if ( + name_to_users[key] is list1 + or name_to_users[key] is list2 + ): + name_to_users[key] = combined + elif buf1_name in name_to_users: + name_to_users[buf2_name] = name_to_users[buf1_name] + else: + name_to_users[buf1_name] = name_to_users[buf2_name] + + def rename(n: str) -> str: + if n in self.mutation_renames: + 
return rename(self.mutation_renames[n]) + return n + + def add_user( + used_by_name: str, + user_node: Union[BaseSchedulerNode, OutputNode], + can_inplace: bool = False, + is_weak: bool = False, + ) -> None: + name_to_users[rename(used_by_name)].append( + NodeUser(user_node, can_inplace, is_weak) + ) + + unbacked_symbol_to_origin_node: Dict[sympy.Symbol, Optional[str]] = {} + + # NB: None means that the dependency is on an input. Don't actually + # generate a dependency because if we do, Inductor will start trying + # to free the unbacked int but that's pointless + for name, val in V.graph.graph_inputs.items(): + if isinstance(val, sympy.Expr): + for fs in val.free_symbols: + unbacked_symbol_to_origin_node[fs] = None + + for node in self.nodes: + log.debug("scheduling %s", node.node) + + # unbacked symbols don't follow ordinary buffer dependencies, so + # we track their def/uses separately + assert node.node is not None + unbacked_symbol_defs = sorted( + node.node.get_unbacked_symbol_defs(), key=lambda x: x.name + ) + for s in unbacked_symbol_defs: + assert isinstance(s, sympy.Symbol) + # Pick the first definer as canonical. There may be multiple + # because if a MultiOutputLayout buffer propagates an unbacked + # symint to multiple outputs, they will all claim to def it. 
+ if s not in unbacked_symbol_to_origin_node: + unbacked_symbol_to_origin_node[s] = node.get_name() + + unbacked_symbol_uses = sorted( + node.node.get_unbacked_symbol_uses(), key=lambda x: x.name + ) + # if a kernel takes unbacked symints, register dependencies + for s in unbacked_symbol_uses: + assert ( + s in unbacked_symbol_to_origin_node + ), f"{s} not in {unbacked_symbol_to_origin_node}" + if (r := unbacked_symbol_to_origin_node[s]) is not None: + for buf in self.name_to_node[r].get_outputs(): + node.add_fake_dep(StarDep(buf.get_name())) + + if ( + len(node.read_writes.writes) == 1 + and (dep := next(iter(node.read_writes.writes))) + and isinstance(dep, MemoryDep) + ): + node_mode = dep.mode + else: + node_mode = None + + # Handle output mutations + for buf in node.get_outputs(): + # a node will mutate either 0 or 1 buffers + assert len(buf.get_mutations()) <= 1 + for alt_name in buf.get_mutations(): + alt_name = rename(alt_name) + # this node must run after the prior writer + add_user(alt_name, node) + node.add_fake_dep(StarDep(alt_name, mode=node_mode)) + for user in name_to_users[alt_name].items: + if user.get_name() == node.get_name(): + continue + + assert isinstance(user.node, BaseSchedulerNode) + for other_name in user.node.get_buffer_names(): + # this node must run after all prior readers + other_name = rename(other_name) + node.add_fake_dep( + WeakDep(other_name, mutating_buf=buf.get_name()) + ) + add_user(other_name, node, is_weak=True) + + # add normal non-mutation dependencies + for read in node.read_writes.reads: + if not isinstance(read, WeakDep): + add_user(read.name, node, node.can_inplace(read)) + + node.update_mutated_names(self.mutation_renames) + + # update our renaming scheme for the next iteration + for buf in node.get_outputs(): + for alt_name in buf.get_mutations(): + self.mutation_renames[rename(alt_name)] = buf.get_name() + self.mutation_renames[alt_name] = buf.get_name() + self.mutation_real_name[ + buf.get_name() + ] = 
self.mutation_real_name.get(alt_name, alt_name) + + # make sure outputs aren't dead-code-eliminated + for buf_name in V.graph.get_output_names(): + log.debug("scheduling output %s", buf_name) + add_user(buf_name, OutputNode(StarDep(buf_name))) + + # make sure unbacked symints aren't dead-code-eliminated + for out in V.graph.graph_outputs: + for s in out.get_unbacked_symbol_uses(): + assert ( + s in unbacked_symbol_to_origin_node + ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" + if r := unbacked_symbol_to_origin_node[s]: + for buf_name in self.name_to_node[r].get_buffer_names(): + log.debug( + "scheduling output %s for unbacked symint %s", buf_name, s + ) + add_user(buf_name, OutputNode(StarDep(buf_name))) + + # make sure input mutation isn't dead-code-eliminated + for name in self.mutation_renames: + if name in V.graph.graph_inputs: + add_user(name, OutputNode(StarDep(name))) + V.graph.mutated_inputs.add(name) + elif name in V.graph.constants: + # In AOTI, module parameters and buffers are not lifted as graph inputs + add_user(name, OutputNode(StarDep(name))) + + inp_names = { + name: index for index, name in enumerate(V.graph.graph_inputs.keys()) + } + V.graph.mutated_input_idxs = [ + inp_names[name] for name in V.graph.mutated_inputs + ] + + # copy users information onto the nodes + for node in self.nodes: + for buf in node.get_outputs(): + buf.set_users(name_to_users[buf.get_name()].items) + + def dead_node_elimination(self) -> None: + """ + Remove any nodes without users + """ + # self.nodes is in topological order, so by iterating in reverse order + # we have visited (and potentially removed) all users before visiting a + # given node. 
        updated_nodes = []
        for node in reversed(self.nodes):
            # NOTE(review): this closure is loop-invariant and could be
            # hoisted out of the loop; left in place to preserve the code.

            def can_eliminate_user(user: NodeUser) -> bool:
                # A user that is weak, or whose op was already removed,
                # does not keep a buffer alive.
                return user.is_weak or user.get_name() in V.graph.removed_operations

            active_buffers = False
            for buf in node.get_outputs():
                can_eliminate = all(can_eliminate_user(u) for u in buf.users)
                if can_eliminate:
                    log.debug("removed dead buffer: %s", buf.get_name())
                    V.graph.removed_buffers.add(buf.get_name())
                else:
                    active_buffers = True

            # An op survives if it has side effects or any live output buffer.
            can_eliminate = not node.has_side_effects() and not active_buffers

            if not can_eliminate:
                updated_nodes.append(node)
            else:
                # dead code
                log.debug("removed dead operation: %s", node.get_name())
                V.graph.removed_operations.add(node.get_name())

        # Restore topological (forward) order after the reverse sweep.
        self.nodes = list(reversed(updated_nodes))

        # Prune any WeakDeps no longer needed
        for node in self.nodes:
            node.prune_weak_deps()

    def topological_sort_schedule(
        self, nodes: List[BaseSchedulerNode]
    ) -> List[BaseSchedulerNode]:
        """
        Ensure nodes is in topologically sorted order
        """
        seen: OrderedSet[BaseSchedulerNode] = OrderedSet()
        name_to_node: Dict[str, BaseSchedulerNode] = dict()
        result: List[BaseSchedulerNode] = []

        # DFS post-order: a node is appended only after all of its
        # (in-`nodes`) dependencies have been appended.
        def visit(n: BaseSchedulerNode) -> None:
            if n not in seen:
                seen.add(n)
                for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
                    # We only care about doing toposort within `nodes`
                    if dep.name not in name_to_node:
                        continue
                    visit(name_to_node[dep.name])
                result.append(n)

        for node in nodes:
            for name in node.get_buffer_names():
                name_to_node[name] = node
        for node in nodes:
            visit(node)
        return result

    def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]:
        """Return the (fused) scheduler nodes that produce snode's unmet deps."""
        unmet_deps = set()
        if isinstance(
            snode,
            (
                SchedulerNode,
                ExternKernelSchedulerNode,
                NopKernelSchedulerNode,
                FusedSchedulerNode,
            ),
        ):
            for dep in snode.unmet_dependencies:
                unmet_deps.add(dep.name)
        else:
            raise RuntimeError(
                f"get_unmet_dep_nodes is not implemented for
{type(snode)}." + ) + unmet_dep_ops = (self.name_to_buf[dep].defining_op for dep in unmet_deps) + return list({self.name_to_fused_node[n.get_name()] for n in unmet_dep_ops}) + + def _topological_sort_nodes(self) -> List[List[BaseSchedulerNode]]: + """ + Sort nodes by their topological order, return a list of node lists. + """ + order = [] + nodes = dict.fromkeys(self.nodes, 0) + children: Dict[Any, Any] = {} + for node in self.nodes: + deps = self._get_unmet_dep_nodes(node) + nodes[node] = len(deps) + for dep in deps: + c = children.get(dep, []) + c.append(node) + children[dep] = c + + zero_deg_nodes = [n for n, v in nodes.items() if v == 0] + while zero_deg_nodes: + order.append(zero_deg_nodes) + for n in zero_deg_nodes: + for user in children.get(n, []): + nodes[user] -= 1 + nodes.pop(n) + zero_deg_nodes = [n for n, v in nodes.items() if v == 0] + assert not nodes, "Topological sort failed!" + return order + + def compute_ancestors(self) -> None: + """ + Populate each node.ancestors + """ + # note self.nodes is topologically sorted + name_to_ancestors: Dict[str, OrderedSet[str]] = {} + for node in self.nodes: + ancestors: OrderedSet[str] = OrderedSet() + for dep in node.unmet_dependencies: + dep_node_name = self.name_to_buf[dep.name].defining_op.get_name() + ancestors.add(dep_node_name) + ancestors |= name_to_ancestors[dep_node_name] + name_to_ancestors[node.get_name()] = ancestors + node.ancestors = ancestors + + for order, node in enumerate(self.nodes): + node.min_order = order + node.max_order = order + + def merge_loops(self) -> None: + for node in self.nodes: + if not config.loop_ordering_after_fusion: + continue + + # Even for CPU, if we are using the halide backend, we still need + # the merge loops steps below + if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or ( + node.get_device().type != "cuda" and config.cpu_backend != "halide" + ): + continue + for snode in node.get_nodes(): + # merge loops for the scheduler node + if not 
isinstance(snode, SchedulerNode) or snode.is_template(): + continue + + snode._body = snode._body.merge_loops() + snode._sizes = snode._body.sizes + + # merge_loops is called after loop reordering. + # We still need retain fake dependencies since codegen the + # estimated amount of memory access rely on them. + snode.refresh_dependencies(normalize=True) + + # Note that for CPU backend, merging loops will change + # snode.group. It's fine for Triton backend. + # But if we simplify update snode.group like this: + # group_fn = self.get_backend(snode.node.get_device()).group_fn + # snode.group = (snode.node.get_device(), group_fn(snode._sizes)) + # There is still an issue due to different snode in a + # FusedSchedulerNode having different merged loops. + # Skip CPU backend for now. + + def fuse_nodes(self, nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: + """ + Combine eligible nodes into FusedSchedulerNodes. + """ + for i in range(10): + old_len = len(nodes) + fusion_log.debug( + "===== attempting fusion (%d/10): %d nodes =====", + i + 1, + old_len, + ) + nodes = self.fuse_nodes_once(nodes) + new_len = len(nodes) + fusion_log.debug( + "completed fusion round (%d/10): fused %d nodes into %d nodes\n", + i + 1, + old_len, + new_len, + ) + if new_len == old_len or new_len == 1: + fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1) + break + return nodes + + def process_grouped_nodes(self) -> None: + """ + Unpack GroupedSchedulerNode into regular nodes. + """ + new_nodes: List[BaseSchedulerNode] = [] + for node in self.nodes: + new_nodes.extend( + node.unpack() if isinstance(node, GroupedSchedulerNode) else [node] + ) + self.nodes = new_nodes + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> Tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. 
+ """ + assert len(nodes) > 0 + device = nodes[0].get_device() + self.current_device = device + backend = self.get_backend(device) + return backend.benchmark_fused_nodes(nodes) + + def finalize_multi_template_buffers(self) -> None: + def replace_operation_buffer( + orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer + ) -> None: + replaced_buf_name = new_node.get_name() + orig_buf_name = orig_node.get_name() + assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str) + + replaced_op_name = new_node.get_operation_name() + orig_op_name = orig_node.get_operation_name() + assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str) + + del V.graph.name_to_buffer[replaced_buf_name] + new_node.name = orig_buf_name + + del V.graph.name_to_op[replaced_op_name] + new_node.operation_name = orig_op_name + + orig = V.graph.buffers.index(orig_node) + V.graph.buffers.remove(new_node) + V.graph.buffers[orig] = new_node + V.graph.name_to_buffer[orig_buf_name] = new_node + + orig = V.graph.operations.index(orig_node) + V.graph.operations.remove(new_node) + V.graph.operations[orig] = new_node + V.graph.name_to_op[orig_op_name] = new_node + + for i, node in enumerate(self.nodes): + if isinstance(node, SchedulerNode) and isinstance( + node.node, ir.MultiTemplateBuffer + ): + multi_node = node.node + min_node_unfused, _ = multi_node.get_min_choice() + + if isinstance( + min_node_unfused, + torch._inductor.ir.TritonTemplateCallerBase, + ): + node.node.finalize_as_triton_caller(min_node_unfused) + continue + + out_tensorbox = min_node_unfused.output_node() + out_storage = out_tensorbox.data + assert isinstance(out_storage, ir.StorageBox) + out_buffer = out_storage.data + assert isinstance(out_buffer, ir.OperationBuffer) + + out_buffer.layout = multi_node.layout + replace_operation_buffer(multi_node, out_buffer) + new_scheduler_node = self.create_scheduler_node(out_buffer) + + self.nodes[i] = new_scheduler_node + 
self.name_to_node[node.get_name()] = new_scheduler_node + self.name_to_fused_node[node.get_name()] = new_scheduler_node + + for new_out, old_out in zip( + new_scheduler_node.get_outputs(), node.get_outputs() + ): + self.name_to_buf[old_out.get_name()] = new_out + new_out.users = old_out.users + + new_scheduler_node.min_order = node.min_order + new_scheduler_node.max_order = node.max_order + new_scheduler_node.last_usage = node.last_usage + + def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool: + return any( + hasattr(n.node, "data") + and n.node is not None + and hasattr(n.node.data, "scatter_mode") + and n.node.data.scatter_mode == "atomic_add" + for n in node_list + ) + + def speedup_by_fusion( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + If config.benchmark_fusion is False, always return True. + Otherwise, return True if fusion can brings speedup. + """ + + is_multi_template = node1.is_template() and isinstance( + node1.get_template_node(), ir.MultiTemplateBuffer + ) + if not config.benchmark_fusion and not is_multi_template: + return True + + if ( + node1.is_template() + and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer) + or node1.is_foreach() + or node2.is_foreach() + ): + # TODO support benchmarking epilogue fusion + return True + + node_list_1 = node1.get_nodes() + device = node_list_1[0].get_device() + + # don't support benchmark fusion for CPU right now. + if device.type == "cpu": + return True + + node_list_2 = node2.get_nodes() + node_list_fused = list(itertools.chain(node_list_1, node_list_2)) + + # We can not accurately benchmark kernel using atomic_add + # due to how we generate random integer inputs. + # Skip benchmarking them by allowing fusion. 
+ if self._any_atomic_add(node_list_fused): + return True + + from triton.compiler.errors import CompilationError + + why = WhyNoFuse(node1, node2) + + def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: + if fusion_log.isEnabledFor(logging.DEBUG): + if ms_fused < ms1 + ms2: + fusion_log.debug( + "can fuse (benchmark): fusing %s with %s cause %sx speedup", + node1.get_buffer_names(), + node2.get_buffer_names(), + green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), + ) + else: + fusion_log.debug( + "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown", + node1.get_buffer_names(), + node2.get_buffer_names(), + red_text(f"{ms_fused / (ms1 + ms2):.3f}"), + ) + + if isinstance(node1, SchedulerNode) and isinstance( + node1.node, ir.MultiTemplateBuffer + ): + multi_node = node1.node + choice_timings = multi_node.choice_timings + + _, ms1 = multi_node.get_min_choice() + ms2, path2 = self.benchmark_fused_nodes(node_list_2) + + min_ms_fused = float("inf") + ms_fused_choice = None + + triton_choices = 0 + + for choice, unfused_time in sorted( + choice_timings.items(), key=lambda x: x[1] + ): + if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase): + continue + + if unfused_time >= ms1 + ms2: + break + + triton_choices += 1 + if triton_choices > config.max_epilogue_benchmarked_choices: + break + + # TODO - parallel compile triton templates + # TODO - should prune/skip choices that are not within certain % of best choice + with node1.node.swap_as_triton_caller(choice): + ms_fused, _ = self.benchmark_fused_nodes(node_list_fused) + + if ms_fused < min_ms_fused: + min_ms_fused = ms_fused + ms_fused_choice = choice + + log_fusion(min_ms_fused, ms1, ms2) + + # after we do a fusion, we finalize a triton template. 
+ # TODO - could preserve multi template and choices for subsequent fusions + if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None: + node1.node.finalize_as_triton_caller(ms_fused_choice) + return True + else: + return False + else: + try: + ms1, path1 = self.benchmark_fused_nodes(node_list_1) + if math.isinf(ms1): + why("register spilling of the first kernel") + return False + ms2, path2 = self.benchmark_fused_nodes(node_list_2) + if math.isinf(ms2): + why("register spilling of the second kernel") + return False + ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused) + if math.isinf(ms_fused): + why("register spilling of the fused kernel") + return False + except CompilationError as e: + # workaround triton issue: https://github.com/openai/triton/issues/2151 + if "Loop-carried variable" in str(e): + return True # allow fusion + else: + raise + + log_fusion(ms_fused, ms1, ms2) + if ( + is_metric_table_enabled("slow_fusion") + and ms_fused >= ms1 + ms2 + and (path1, path2) not in self.logged_slow_fusion + ): + self.logged_slow_fusion.add((path1, path2)) + get_metric_table("slow_fusion").add_row( + lambda: { + "kernel1_path": path1, + "kernel1_latency": ms1, + "kernel2_path": path2, + "kernel2_latency": ms2, + "fused_kernel_path": path_fused, + "fused_kernel_latency": ms_fused, + "slow_down_ratio": ms_fused / (ms1 + ms2), + } + ) + return ms_fused < ms1 + ms2 + + def fuse_nodes_once( + self, nodes: List[BaseSchedulerNode] + ) -> List[BaseSchedulerNode]: + """ + Combine eligible nodes into FusedSchedulerNodes. 
+ + This relies on two key functions to control the logic: + - self.can_fuse(): checks if a fusion is legal + - self.score_fusion(): assigns priority to a given fusion + """ + fused_nodes = OrderedSet(nodes) + if fusion_log.isEnabledFor(logging.DEBUG): + fusion_log.debug("fuse_nodes_once, candidates:") + for node in fused_nodes: + fusion_log.debug(" " + node.debug_str_short()) # noqa: G003 + for node1, node2 in self.get_possible_fusions(nodes): + node1 = self.name_to_fused_node[node1.get_first_name()] + node2 = self.name_to_fused_node[node2.get_first_name()] + if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle( + node1, node2 + ): + if not self.speedup_by_fusion(node1, node2): + continue + fusion_log.debug( + "fusing %s with %s", node1.get_name(), node2.get_name() + ) + + # above can_fuse asserts that node2 has the same device + device = node1.get_device() + node3 = self.get_backend(device).fuse(node1, node2) + fused_nodes.remove(node1) + fused_nodes.remove(node2) + fused_nodes.add(node3) + self.name_to_fused_node.update( + {n.get_name(): node3 for n in node3.get_nodes()} + ) + nodes = sorted(fused_nodes, key=lambda x: x.min_order) + nodes = self.topological_sort_schedule(nodes) + self.prune_redundant_deps(nodes) + return nodes + + def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None: + """ + Groups parallel nodes + """ + fused_nodes = set(self.nodes) + count = 0 + num_nodes_orig = len(self.nodes) + log.debug("ComboKernels: Generating with num_ck_nodes = %d...", num_ck_nodes) + for num, node_list in enumerate( + ForeachKernelSchedulerNode.group_nodes_for_combo_kernels(self) + ): + node_list = ForeachKernelSchedulerNode.combinable_nodes(node_list) + if len(node_list) < 2: + continue + if num_ck_nodes is not None and count > num_ck_nodes: + break + if not self.speedup_by_combo_kernel(node_list): + log.debug("ComboKernels: Not speeding up %d-th group", num) + continue + count += 1 + enable_autotune = 
config.combo_kernels_autotune > 0 + group_snode = ForeachKernelSchedulerNode( + node_list[0].scheduler, + node_list, + use_custom_partition_algo=True, + enable_autotune=enable_autotune, + ) + log.info( + "ComboKernels: Combining %d nodes for %d-th group", + len(node_list), + num, + ) + for node in node_list: + fused_nodes.remove(node) + fused_nodes.add(group_snode) + self.name_to_fused_node.update( + {n.get_name(): group_snode for n in group_snode.get_nodes()} + ) + self.nodes = sorted(fused_nodes, key=lambda x: x.min_order) + self.nodes = self.topological_sort_schedule(self.nodes) + log.info( + "Generated ComboKernel nodes: %d ComboKernels, totally %d -> %d nodels", + count, + num_nodes_orig, + len(self.nodes), + ) + self.prune_redundant_deps(self.nodes) + + def prune_redundant_deps(self, nodes: List[BaseSchedulerNode]) -> None: + for node in nodes: + node.prune_redundant_deps(self.name_to_fused_node) + + def get_possible_fusions( + self, nodes: List[BaseSchedulerNode] + ) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]: + """ + Helper to find all legal fusion opportunities, sorted by self.score_fusion() + """ + possible_fusions = [] + seen: OrderedSet[Tuple[BaseSchedulerNode, BaseSchedulerNode]] = OrderedSet() + + def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None: + for node1_index, node1 in enumerate(nodes): + for node2 in nodes[node1_index + 1 :]: + key = (node1, node2) + if key in seen: + continue + seen.add(key) + + if self.can_fuse(node1, node2): + possible_fusions.append(key) + elif (node2.is_template() or node2.is_foreach()) and self.can_fuse( + node2, node1 + ): + # foreach fusions and epilogue fusions are order dependent + possible_fusions.append((node2, node1)) + + buffer_names_grouping = collections.defaultdict(list) + for node in nodes: + for buf in node.used_buffer_names(): + buffer_names_grouping[buf].append(node) + for node_grouping in buffer_names_grouping.values(): + check_all_pairs(node_grouping) + + if config.aggressive_fusion: + 
group_grouping = collections.defaultdict(list) + for node in nodes: + group = getattr(node, "group", None) + if group: + group_grouping[group].append(node) + for node_grouping in group_grouping.values(): + check_all_pairs(node_grouping) + + possible_fusions = self.get_possible_fusions_with_highest_priority( + possible_fusions + ) + possible_fusions.sort(key=self.score_fusion_key, reverse=True) + fusion_log.debug("found %d possible fusions", len(possible_fusions)) + return possible_fusions + + def will_fusion_create_cycle( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Finds whether there's a path from node1 to node2 (or vice-versa) + caused indirectly by other fusions. + """ + # since we are just returning boolean here, use slightly faster, unordered set + visited: Set[FusedSchedulerNode] = set() + + def found_path(node: BaseSchedulerNode) -> bool: + # only fused nodes can introduce new ancestors. + if isinstance(node, FusedSchedulerNode) and node not in visited: + visited.add(node) + if node.get_operation_names().issubset(combined_ancestors): + # All fusion outputs are in ancestors of node1 and node2, thus + # cannot introduce new path: + # + # 1. if output is neither descendent of node1 or node2, the + # output cannot introduce a path + # 2. due to [can_fuse]: if WLOG output is descendent of node1, it cannot be + # on path(node1->node2), hence it cannot be ancestor of node2 + # 3. 
due to [acyclic]: if WLOG output is descendent of node1, it cannot be + # ancestor of node1 + return False + else: + # continue DFS of new ancestors introduced by the fusion + return bool(combined_names & node.ancestors) or any( + found_path(self.name_to_fused_node[n]) + for n in node.ancestors - combined_ancestors + ) + return False + + # as above - use slightly faster, unordered set + combined_names = ( + node1.get_operation_names()._dict.keys() + | node2.get_operation_names()._dict.keys() + ) + combined_ancestors = ( + node1.ancestors._dict.keys() | node2.ancestors._dict.keys() + ) - combined_names + cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors) + if cycle: + WhyNoFuse(node1, node2)("will create cycle") + return cycle + + def can_fusion_increase_peak_memory( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + This function prevents fusion for nodes that can increase memory + footprint. This problem is more common in horizontal fusion, where nodes + that are far apart in the original order get fused, lengthening the live + intervals of tensors. This is very evident in models with activation + checkpointing, where the recomputed nodes from different checkpointed + regions get fused and significantly increase the memory footprint. + + The current attempt is a quick, possibly hacky, heuristic to prevent the + fusion of nodes that are far away in the original order. + + A better but difficult to implement heurisitic would be to use live + intervals of the buffers, find region of peak pressure in the original + program and prevent fusion that crosses that peak region. We might need + special care or good approximation in this implementation, as fusion of + node changes live intervals, and re-computing live intervals and peak + memory after each fusion can introduce large compilation overhead. 
+ """ + proximity_score = max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + return proximity_score > 64 + + def decide_fusion_fail_reason( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + common_buf_names: Tuple[str, ...], + ) -> str: + """ + Try to decide reasons why fusion fail due to no shared memory even though + there are common buffers. + """ + reasons = {} + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + for buf_name in common_buf_names: + buf = V.graph.get_buffer(buf_name) + lhs_dep = node1_name2dep[buf_name] + rhs_dep = node2_name2dep[buf_name] + + if lhs_dep.get_numel() != rhs_dep.get_numel(): + reasons[ + buf_name + ] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}" + continue + + # same numel but different MemoryDep.size. Should be broadcasting + if sympy_product(lhs_dep.size) != sympy_product(rhs_dep.size): + reasons[buf_name] = "broadcast" + continue + + if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): + reasons[ + buf_name + ] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}" + continue + + lhs_off = lhs_dep.get_offset() + rhs_off = rhs_dep.get_offset() + if lhs_off != rhs_off: + # One example is in transformer, we use a concatenated linear layer + # to project Q/K/V and then split the result. The 3 splits will + # point to the same buffer with different offsets. + reasons[buf_name] = f"different offset: {lhs_off} v.s. {rhs_off}" + continue + + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + reasons[buf_name] = f"Mismatch loop orders: {lhs_dep} v.s. {rhs_dep}" + continue + + # Add more rules here + reasons[ + buf_name + ] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. 
Layout: {buf.layout}" + + return str(reasons) + + def has_shared_data_after_reordering_loop( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Right now just greedily reorder the loop of node1 to be compatible with node2, + but ideally we should have some heuristics to reorder the loop for node2 + to be compatibile with node1 if that's more efficient. + """ + + # TODO Don't do loop reordering for CPU for now. + # Should debug more why it does not work for CPU codegen + if not config.loop_ordering_after_fusion or any( + n.get_device().type == "cpu" for n in [node1, node2] + ): + return False + + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + # Fast path: no common buffers. + common_buffer_names = node1_buffer_names & node2_buffer_names + if not common_buffer_names: + return False + + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + # Find the commons buffers that has different loop orders + candidates = [] + for buffer_name in common_buffer_names: + lhs_dep = node1_name2dep[buffer_name] + rhs_dep = node2_name2dep[buffer_name] + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + candidates.append( + ( + V.graph.sizevars.size_hint(lhs_dep.get_numel(), fallback=0), + lhs_dep, + rhs_dep, + ) + ) + + if len(candidates) == 0: + return False + + # Pick the largest buffer to guide the loop reordering + numel, lhs_dep, rhs_dep = sorted(candidates, reverse=True, key=lambda x: x[0])[ + 0 + ] + + if lhs_dep.num_vars != rhs_dep.num_vars: + # this can happen due to we don't merge loops. 
+ # We can not do loop reordering in this case right now + # Simply returning true if the two Deps are the same after + # normalization (merging loops) + return lhs_dep.normalize() == rhs_dep.normalize() + + # Only reorder loops for pointwise for now + if not node1.is_reduction(): + node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep) + elif not node2.is_reduction(): + node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep) + else: + loop_ordering_log.debug( + "Don't reorder loops since both nodes are reductions: %s v.s. %s", + node1.get_name(), + node2.get_name(), + ) + + return self.score_fusion_memory(node1, node2) > 0 + + def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: + """ + Determine if it is possible to combine node1 and node2 into a + single fused node. + """ + + if node1 is node2: + return False + + why = WhyNoFuse(node1, node2) + + if isinstance(node1, GroupedSchedulerNode) or isinstance( + node2, GroupedSchedulerNode + ): + why("grouped node must not be fused with other nodes") + return False + if ( + isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node1.is_template() + ): + why("node1 is extern or nop") + return False + if ( + isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node2.is_template() + ): + why("node2 is extern or nop") + return False + + if node2.get_operation_names() & node1.ancestors: + why("node1 must go before node2") + return False + + if node2.is_template(): + why("templates can only fuse epilogues") + return False + if node1.is_template() and ( + node2.has_aliasing_or_mutation() + or node2.is_reduction() + or not config.epilogue_fusion + ): + why("template epilogue not satisfied") + return False + + if ( + node1.get_buffer_names() | node2.get_buffer_names() + ) & V.graph.no_fuse_buffer_names: + why("fusion for buffer explicit disabled") + return False + + device = node1.get_device() + device2 = node2.get_device() + if device != device2: + 
why("device mismatch (%s vs %s)", device, device2) + return False + del device2 + + no_shared_data = self.score_fusion_memory(node1, node2) == 0 + if no_shared_data: + no_shared_data = not self.has_shared_data_after_reordering_loop( + node1, node2 + ) + + loop_ordering_log.debug( + "%s and %s has%s shared data", + node1.get_name(), + node2.get_name(), + " no" if no_shared_data else "", + ) + if no_shared_data and ( + not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() + ): + if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"): + common_buf_names = ( + node1.read_writes.buffer_names() & node2.read_writes.buffer_names() + ) + if len(common_buf_names) > 0: + get_metric_table("fusion_failure_due_to_indexing_mismatch").add_row( + lambda: { + "pre_grad_graph_id": V.graph.graph_id, + "post_grad_graph_id": V.graph.post_grad_graph_id, + "node1_name": node1.get_name(), + "node2_name": node2.get_name(), + "node1_debug_str": write_text(node1.debug_str()), + "node2_debug_str": write_text(node2.debug_str()), + "common_buffer_names": list(common_buf_names), + "failure_reason": self.decide_fusion_fail_reason( + node1, node2, common_buf_names + ), + } + ) + + why("no shared data due to indexing mismatch") + return False + why("no shared data") + return False # heuristic not needed for correctness + + if ( + not node1.is_foreach() + and not node2.is_foreach() + and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size + ): + why("exceeds max fusion") + return False # heuristic not needed for correctness + + if node1.get_operation_names() & node2.ancestors: + # node2 depends on node1 outputs + if not self.can_fuse_vertical(node1, node2): + return False + return self.get_backend(device).can_fuse_vertical(node1, node2) + else: # nodes don't depend on each other, but may have common reads + if self.can_fusion_increase_peak_memory(node1, node2): + why("will increase peak memory") + return False + return 
self.get_backend(device).can_fuse_horizontal(node1, node2) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check if it is legal to fuse a consumer (node2) into a producer (node1). + + We can fuse them if all the reads of node2 either match + corresponding writes in node1, or are written by nodes that can + be scheduled before the fusion of node1 and node2. + """ + node1_buf_names = node1.get_buffer_names() + node1_op_names = node1.get_operation_names() + computed_deps: OrderedSet[Dep] = OrderedSet() + why = WhyNoFuse(node1, node2) + + for cd in node1.read_writes.writes: + if not isinstance(cd, MemoryDep): + continue + for rd in node2.unmet_dependencies: + if self.fusable_read_and_write(rd, cd): + computed_deps.add(rd) + + for dep in node2.unmet_dependencies: + if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): + computed_deps.add(dep) + + remaining_deps = OrderedSet( + dep.name for dep in node2.unmet_dependencies - computed_deps + ) + if remaining_deps & node1_buf_names: + # MemoryDeps didn't match and read different locations of the same buffer. + # Examples here include: + # - MemoryDep("foo", x) != MemoryDep("foo", x + 1) + # - MemoryDep("foo", x) != StarDep("foo") + why("memory deps did not match") + return False + for name in remaining_deps: + op_name = self.name_to_buf[name].defining_op.get_name() + if node1_op_names & self.name_to_fused_node[op_name].ancestors: + why("intermediate nodes between node1 & node2") + return False + + return True + + def fusable_weak_dep( + self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if weak_dep.name not in node1.get_buffer_names(): + return False + + # A weak dep can be fused if and only if the fused operation acts inplace + # on the buffer being mutated. i.e. 
the same index is being read then mutated + mutating_writes = [ + write + for write in node2.read_writes.writes + if write.name == weak_dep.mutating_buf + ] + if len(mutating_writes) != 1: + return False + write = mutating_writes[0] + assert isinstance(write, MemoryDep) + + if free_symbol_is_type(write.index, SymT.TMP): + return False + + real_name = self.mutation_real_name[weak_dep.mutating_buf] + relevant_reads = [ + read for read in node1.read_writes.reads if read.name == real_name + ] + return all( + isinstance(read, MemoryDep) + and not free_symbol_is_type(read.index, SymT.TMP) + and read.index == write.index + and read.size == write.size + for read in relevant_reads + ) + + # StarDep doesn't match MemoryDep, different indices don't match + # However, broadcasting sometimes strips dimensions, and if that's the case + # we still can match unmet dep + # if there's indirect indexing, don't match it + def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: + if isinstance(read, MemoryDep): + if read.mode == write.mode and write.mode is not None: + return True + read_name = self.mutation_renames.get(read.name, read.name) + + if ( + read_name != write.name + or free_symbol_is_type(read.index, SymT.TMP) + or free_symbol_is_type(write.index, SymT.TMP) + ): + return False + + if config.loop_ordering_after_fusion and read.num_vars != write.num_vars: + # Need merge loops if we do loop ordering after fusion since + # we have not merged the loops yet when creating the scheduler + # nodes. 
+ read = read.normalize() + write = write.normalize() + + return ( + read.index == write.index + and len(read.size) >= len(write.size) + and read.size[: len(write.size)] == write.size + ) + elif isinstance(read, StarDep): + read_name = self.mutation_renames.get(read.name, read.name) + write_name = self.mutation_renames.get(write.name, write.name) + if ( + read.mode == write.mode + and write.mode is not None + and read_name == write_name + ): + return True + return False + + def score_fusion( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> Tuple[bool, bool, int, int]: + """ + Assign a score (higher comes first) to the fusion of node1 + and node2. When different fusions conflict with each other, + this is the way we decide what order to run them in. + + Our current score is based on: + - Estimate of the saved memory operations + - Fusions closer together in original order + """ + memory_score = self.score_fusion_memory(node1, node2) + proximity_score = -max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + return ( + node1.is_template() == config.epilogue_fusion_first and memory_score > 0, + node1.is_reduction() == node2.is_reduction() and memory_score > 0, + memory_score, + proximity_score, + ) + + def dep_size_hint(self, dep: Dep) -> int: + res = 0 + if dep not in self.__dep_size_hint_cache: + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.__dep_size_hint_cache[dep] = res + else: + res = self.__dep_size_hint_cache[dep] + return res + + def score_fusion_memory( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + The first term in our fusion score that estimates number of saved + memory operations. 
+ """ + node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes) + node2_dep_len = len(node1.read_writes.reads) + len(node2.read_writes.writes) + + # optimization: iter over smaller set + if max(node1_dep_len, node2_dep_len) * 4 > min(node1_dep_len, node2_dep_len): + if node1_dep_len > node2_dep_len: + tmp = node1 + node1 = node2 + node2 = tmp + + deps = [] + for dep in node1.read_writes.reads | node1.read_writes.writes: + if dep in node2.read_writes.reads or dep in node2.read_writes.writes: + deps.append(dep) + + return sum(self.dep_size_hint(dep) for dep in deps) + + common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( + node2.read_writes.reads | node2.read_writes.writes + ) + return sum(self.dep_size_hint(dep) for dep in common_memory_deps) + + def get_possible_fusions_with_highest_priority( + self, possible_fusions: List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] + ) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]: + # Group the possible fusions based on their priority from the backend. + # Only return the group of possible fusions with highest priority. 
        # Empty input: nothing to group or rank.
        if len(possible_fusions) == 0:
            return possible_fusions
        possible_fusions_group_by_priority: Dict[
            int, List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]
        ] = {}

        for node1, node2 in possible_fusions:
            assert node1.get_device() == node2.get_device()
            device = node1.get_device()
            fusion_pair_priority = int(
                self.get_backend(device).get_fusion_pair_priority(node1, node2)
            )
            if fusion_pair_priority not in possible_fusions_group_by_priority:
                possible_fusions_group_by_priority[fusion_pair_priority] = [
                    (node1, node2),
                ]
            else:
                possible_fusions_group_by_priority[fusion_pair_priority].append(
                    (node1, node2)
                )
        # return the possible fusions with highest priority
        # (a smaller integer means higher priority — see
        # BaseScheduling.get_fusion_pair_priority — hence min())
        possible_fusions_with_highest_priority = min(
            possible_fusions_group_by_priority.items(), key=operator.itemgetter(0)
        )[1]
        assert len(possible_fusions_with_highest_priority) > 0
        return possible_fusions_with_highest_priority

    def score_fusion_key(
        self, nodes: Tuple[BaseSchedulerNode, BaseSchedulerNode]
    ) -> Tuple[bool, bool, int, int]:
        """
        Shim for list.sort(key=...)
        """
        node1, node2 = nodes
        return self.score_fusion(node1, node2)

    def compute_last_usage(self) -> None:
        """
        Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
        """

        future_used_buffers: OrderedSet[str] = OrderedSet(V.graph.get_output_names())

        # Walk the schedule backwards so each node sees the set of buffers
        # still needed by later nodes (seeded with the graph outputs).
        for node in reversed(self.nodes):
            node.set_last_usage(future_used_buffers, self.mutation_real_name)
            future_used_buffers.update(node.last_usage)

    def free_buffers(self) -> None:
        """Free any buffers that are no longer needed"""
        # Skip anything already removed or already freed by the wrapper code.
        for name in sorted(
            self.buffer_names_to_free
            - V.graph.removed_buffers
            - V.graph.wrapper_code.freed
        ):
            if name in self.name_to_buf:
                buf = self.name_to_buf[name]
                if buf.can_free():
                    V.graph.wrapper_code.codegen_free(buf.node)
            elif name in V.graph.graph_inputs:
                # Graph inputs are freed via their underlying storage.
                storage = V.graph.graph_inputs[name].data
                assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer()
                V.graph.wrapper_code.codegen_free(storage.data)

        self.buffer_names_to_free.clear()

    def remove_kernel_local_buffers(self) -> None:
        """
        Any buffers that are both created and have a last use in the
        same kernel can be removed.
        """

        # Names of the ops producing buffers stored by the current kernel.
        fused_node_names = OrderedSet(
            self.name_to_buf[buf].defining_op.get_name()
            for buf in V.kernel.store_buffer_names
            if buf in self.name_to_buf
        )
        names_to_remove = []
        for out_buf in V.kernel.store_buffer_names:
            if out_buf not in self.name_to_buf:
                # Aux buffers created during kernel codegen
                names_to_remove.append(out_buf)
                continue
            users = self.name_to_buf[out_buf].users
            assert users is not None
            users = OrderedSet(user.get_name() for user in users if not user.is_weak)
            # All (non-weak) consumers live inside this kernel -> removable.
            if users.issubset(fused_node_names):
                names_to_remove.append(out_buf)

        def remove_filter(n: str) -> bool:
            # Keep buffers the kernel must retain, kernel inputs, and anything
            # participating in mutation renaming.
            return (
                n not in V.kernel.must_keep_buffers
                and n not in V.kernel.args.input_buffers
                and n not in self.mutation_renames
                and n not in self.mutation_real_name
            )

        names_to_remove = list(filter(remove_filter, names_to_remove))

        for name in names_to_remove:
            if name in V.kernel.args.inplace_buffers:
                buf = V.kernel.args.inplace_buffers[name]
                if isinstance(buf, str) and buf.startswith("REMOVED"):
                    continue
                # An inplace buffer is only removed once every alias of it is
                # also removable.
                remove = all(n in names_to_remove for n in buf.other_names)
                if remove:
                    self.remove_inplace_buffer(name)
                V.kernel.inplaced_to_remove.add(name)
            else:
                self.remove_buffer(name)

    def remove_buffer(self, name: str) -> None:
        # Assign a special value instead of deleting the entry
        # because we still rely on output_buffers's length to
        # generate unique arg name.
        log.debug("remove_buffer(%r)", name)
        V.kernel.args.output_buffers[name] = "REMOVED"
        V.kernel.removed_buffers.add(name)

    def remove_inplace_buffer(self, name: str) -> None:
        # Like remove_buffer, but for in/out pointer args: rewrite the inner
        # name rather than deleting, so arg numbering stays stable.
        log.debug("removing_inplace_buffer(%r)", name)
        inner_name = V.kernel.args.inplace_buffers[name].inner_name
        V.kernel.args.inplace_buffers[name] = inner_name.replace(
            "in_out_ptr", "REMOVED"
        )
        V.kernel.removed_buffers.add(name)

    def flush(self) -> None:
        """Flush every backend's pending kernels, then free dead buffers."""
        for backend in self.backends.values():
            backend.flush()
        self.free_buffers()

    def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None:
        """Generate wrapper code for a single extern kernel call."""
        assert isinstance(scheduler_node, ExternKernelSchedulerNode)
        # 'decide_inplace_update' stores the inplace update decisions in
        # the current kernel from where 'allocate' retrieve those decisions.
        # We have to make sure there is a non-NULL kernel handler to store
        # those inplace update decisions.
        counters["inductor"]["extern_calls"] += 1
        with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
            scheduler_node.decide_inplace_update()
            scheduler_node.mark_run()
        node = scheduler_node.node
        assert isinstance(node, ir.ExternKernel), f"{type(node)=}"
        node.codegen(V.graph.wrapper_code)
        self.free_buffers()

    def create_backend(self, device: torch.device) -> BaseScheduling:
        """Instantiate the per-device codegen backend, validating triton
        availability for GPU devices."""
        assert (
            not is_gpu(device.type) or device.index is not None
        ), f"{device} should have been normalized in lowering"
        V.graph.add_device_info(device)

        device_scheduling = get_scheduling_for_device(device.type)
        if device_scheduling is None:
            raise RuntimeError(f"Unsupported device type: {device.type}")

        if not has_triton():
            if (
                device.type == "cuda"
                and (device_props := torch.cuda.get_device_properties(device)).major < 7
            ):
                raise RuntimeError(
                    f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}"  # noqa: B950
                )
            elif is_gpu(device.type):
                raise RuntimeError(
                    "Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton"  # noqa: B950
                )

        return device_scheduling(self)

    def get_backend(self, device: torch.device) -> BaseScheduling:
        """Return the (cached) backend for ``device``, creating it on first use."""
        if device not in self.backends:
            self.backends[device] = self.create_backend(device)
        return self.backends[device]

    def enter_context(self, node: BaseSchedulerNode) -> None:
        # Find the latest-ordered FX origin of this node and let the wrapper
        # code enter that context (e.g. for debug annotations).
        def get_order(n: torch.fx.Node) -> int:
            # Lazily build the fx-node -> position index once per graph.
            if n not in self.origin_to_index:
                self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
            return self.origin_to_index[n]

        # Use a dict to have ordering
        origins = {
            (get_order(e), e): None
            for n in node.get_nodes()
            if n.node is not None
            for e in n.node.get_origins()
        }
        origins = list(origins.keys())
        if origins:
            _, last = max(origins, key=operator.itemgetter(0))
            V.graph.wrapper_code.enter_context(last)

    def codegen(self) -> None:
        """Timed entry point; the real work happens in _codegen()."""
        with dynamo_timed("Scheduler.codegen"):
            return self._codegen()

    def _codegen(self) -> None:
        # Main codegen loop: walk the scheduled nodes in order, managing
        # device guards, backend flushes, and buffer lifetime as we go.
        if config.check_stack_no_cycles_TESTING_ONLY:
            import torch._dynamo.convert_frame

            stack = traceback.extract_stack()
            seen = set()
            for frame in reversed(stack):
                # This is where maybe_cprofile is
                if (
                    frame.name == "_compile_inner"
                    and frame.filename == torch._dynamo.convert_frame.__file__
                ):
                    break
                key = (frame.filename, frame.lineno)
                assert key not in seen, (
                    f"Duplicate stack frame {frame.filename}:{frame.lineno}; "
                    "did you add a decorator to one of the functions in this stack "
                    "trace? If so, try using a context manager instead."
                )
                seen.add(key)

        for node in self.nodes:
            try:
                log.debug(
                    "Generating code for node %s with estimated runtime %f",
                    node.get_name(),
                    node.get_estimated_runtime(),
                )
            except Exception as e:
                # get_estimated_runtime can fail (e.g. missing hints); logging
                # must never abort codegen.
                log.debug(
                    "Generating code for node %s with estimated runtime 0.0",
                    node.get_name(),
                )

            self.enter_context(node)

            if not isinstance(node, NopKernelSchedulerNode) and (
                device := node.get_device()
            ):
                # Flush pending kernels when crossing a device boundary or
                # before extern/template nodes, which are codegen'd standalone.
                if (
                    device != self.current_device
                    or node.is_extern()
                    or node.is_template()
                ):
                    self.flush()
                if device != self.current_device:
                    if self.current_device and device_need_guard(
                        self.current_device.type
                    ):
                        V.graph.wrapper_code.codegen_device_guard_exit()
                    if device_need_guard(device.type):
                        assert device.index is not None, "device should have an index"
                        V.graph.wrapper_code.codegen_device_guard_enter(device.index)

                    self.current_device = device

            self.buffer_names_to_free.update(node.last_usage)

            # Dispatch by node kind: template, extern, foreach/combo, fused
            # or plain kernel, and finally nop.
            if node.is_template():
                node, *epilogue = node.get_nodes()
                self.get_backend(device).codegen_template(node, epilogue)
            elif node.is_extern():
                node = typing.cast(ExternKernelSchedulerNode, node)
                self.codegen_extern_call(node)
            elif node.is_foreach():
                node = typing.cast(ForeachKernelSchedulerNode, node)
                backend_ = self.get_backend(device)
                from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
                from .codegen.simd import SIMDScheduling

                if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
                    backend = backend_
                else:
                    raise AssertionError(f"{type(self)=}")
                backend.codegen_combo_kernel(node)
            elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
                self.get_backend(device).codegen_node(node)
            else:
                assert isinstance(node, NopKernelSchedulerNode)
                node.mark_run()

            if config.triton.debug_sync_kernel:
                self.get_backend(device).codegen_sync()

            self.available_buffer_names.update(node.get_buffer_names())
            self.completed_operations.update(node.get_operation_names())

            if not isinstance(node, NopKernelSchedulerNode):
                device = node.get_device()
                if device is not None and self.get_backend(device).ready_to_flush():
                    self.flush()

        if self.current_device and device_need_guard(self.current_device.type):
            # exit the outermost CUDA device guard. this is
            # important for nested indentation codegen-ing.
            V.graph.wrapper_code.codegen_device_guard_exit()

        self.flush()

    def benchmark_combo_kernel(
        self, node_list: Sequence[BaseSchedulerNode]
    ) -> Tuple[float, float, str]:
        """
        Benchmark fused list of nodes and return the execution time
        in milliseconds on randomly generated inputs.
        """
        device = node_list[0].get_device()
        V.graph.scheduler = self
        self.current_device = device
        backend = self.get_backend(device)
        return backend.benchmark_combo_kernel(node_list)

    def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool:
        """
        If config.benchmark_fusion is False, always return True.
        Otherwise, return True if fusion can brings speedup.
        """
        if not config.benchmark_combo_kernel:
            return True

        subkernel_nodes = nodes
        device = subkernel_nodes[0].get_device()

        # don't support benchmark fusion for CPU right now.
        if device.type == "cpu":
            return True

        from triton.compiler.errors import CompilationError

        # ms1 accumulates per-subkernel times; path1_list their codegen paths.
        ms1, path1_list = 0.0, []
        for i, snode in enumerate(subkernel_nodes):
            node_list = snode.get_nodes()
            # We can not accurately benchmark kernel using atomic_add
            # due to how we generate random integer inputs.
            if self._any_atomic_add(node_list):
                fusion_log.debug(
                    "ComboKernel: benchmarking may not accurate due to atomic_add"
                )

            try:
                ms, path = self.benchmark_fused_nodes(node_list)
                # inf signals register spilling — reject the combo outright.
                if math.isinf(ms):
                    fusion_log.debug(
                        "ComboKernel benchmark: register spilling of %d-th subkernel",
                        i,
                    )
                    return False
            except CompilationError as e:
                # workaround triton issue: https://github.com/openai/triton/issues/2151
                if "Loop-carried variable" in str(e):
                    fusion_log.debug(
                        "ComboKernel benchmark: return True because of loop-carried variable"
                    )
                    return True  # allow fusion
                else:
                    raise
            ms1 += ms
            path1_list.append(path)

        try:
            ms2, ms2_clone, path2_list = self.benchmark_combo_kernel(subkernel_nodes)
        except CompilationError as e:
            # workaround triton issue: https://github.com/openai/triton/issues/2151
            if "Loop-carried variable" in str(e):
                fusion_log.debug(
                    "ComboKernel benchmark: return True because of loop-carried variable"
                )
                return True  # allow fusion
            else:
                raise

        # small kernels are very likely to have speedup but hard to benchmark. So we skip benchmarking.
        small_kernel = ms2 - ms2_clone < 0.3 or ms1 < 0.3
        if fusion_log.isEnabledFor(logging.DEBUG):
            if ms1 > ms2 or small_kernel:
                fusion_log.debug(
                    "can fuse (benchmark): fusing causes %sx speedup",
                    green_text(f"{ms1 / ms2:.3f}"),
                )
            else:
                fusion_log.debug(
                    "cannot fuse (benchmark): fusing causes %sx slowdown",
                    red_text(f"{ms1 / ms2:.3f}"),
                )
        # ms1 returned by benchmark_fused_nodes discounted clone time
        return ms2 - ms2_clone < ms1 or small_kernel

    def get_buffer_layout(self, buf_name: str) -> ir.Layout:
        """Return the layout of the named buffer (must exist and have a node)."""
        buf = self.name_to_buf[buf_name]
        assert buf.node is not None
        return buf.node.get_layout()

    def update_zero_dim_cpu_tensor(self) -> None:
        # Record every 0-dim CPU tensor read by a GPU node; the wrapper code
        # treats these specially (V.graph.zero_dim_cpu_tensor_list).
        for node in self.nodes:
            if node.get_device() and is_gpu(node.get_device().type):
                for read in node.read_writes.reads:
                    buffer = V.graph.name_to_buffer.get(read.name)
                    if (
                        buffer
                        and buffer.get_device()
                        and buffer.get_device().type == "cpu"
                        and not isinstance(buffer.layout, MultiOutputLayout)
                        and buffer.get_size() == []
                    ):
                        V.graph.zero_dim_cpu_tensor_list.add(read.name)


class BaseScheduling:
    # Abstract interface every per-device codegen backend implements;
    # most methods below raise NotImplementedError by design.
    @classmethod
    def get_backend_features(cls, device: torch.device) -> Sequence[BackendFeature]:
        """Return a set of .codegen.common.BackendFeature()"""
        return ()

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        Check whether node1 and node2 can be vertically fused or not.
        """
        raise NotImplementedError

    def can_fuse_horizontal(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        Check whether node1 and node2 can be horizontally fused or not.
        """
        raise NotImplementedError

    def fuse(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> FusedSchedulerNode:
        """
        Fuse two nodes
        """
        # Foreach nodes need the specialized foreach fusion path.
        if node1.is_foreach() or node2.is_foreach():
            return ForeachKernelSchedulerNode.fuse(node1, node2)
        else:
            return FusedSchedulerNode.fuse(node1, node2)

    def group_fn(
        self, sizes: Sequence[Sequence[sympy.Expr]]
    ) -> Tuple[Tuple[sympy.Expr, ...], ...]:
        """
        Process the iteration sizes in case a transformation needs to be applied.
        """
        raise NotImplementedError

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Sequence[BaseSchedulerNode],
    ) -> Optional[str]:
        """
        Given a template node, generate a kernel.

        This function is only available for triton now. If the third-party backend behaves as a sub-class
        of TritonScheduling, it can override it or reuse it.
        """
        raise NotImplementedError

    def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None:
        """
        Generate a kernel given a list of pre-fused nodes.
        """
        raise NotImplementedError

    def codegen_sync(self) -> None:
        """
        Generate synchronization code for the kernel. This method depends on the hardware characteristics.
        """
        raise NotImplementedError

    def ready_to_flush(self) -> bool:
        """
        Check whether the backend is requesting the scheduler to flush the generated kernel.
        If not supported, please return False.
        """
        return False

    def flush(self) -> None:
        """
        Flush the generated kernel and python wrapper code to the source code file.
        """
        raise NotImplementedError

    def benchmark_fused_nodes(
        self, nodes: Sequence[BaseSchedulerNode]
    ) -> Tuple[float, str]:
        """
        Benchmark fused list of nodes and return the execution time
        in milliseconds on randomly generated inputs.
        """
        raise NotImplementedError

    def get_fusion_pair_priority(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> int:
        """
        Return an unsigned integer which represents the priority of this fusion pair.
        The smaller is with higher priority.
        """
        return 0

    def benchmark_combo_kernel(
        self, node_list: Sequence[BaseSchedulerNode]
    ) -> Tuple[float, float, str]:
        """
        Benchmark the list of nodes to combine and return the execution time
        and memory copy time in milliseconds on randomly generated inputs.
        """
        raise NotImplementedError


def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]:
    """Return human-readable lines showing the Triton code for ``node``."""
    lines = []
    multi_template = node.get_template_node()
    assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer)
    if multi_template and multi_template.make_kernel_render is None:
        lines.append(f"{node.get_name()} Unfinalized multi template buffer")
    else:
        from torch._inductor.codegen.cuda_combined_scheduling import (
            CUDACombinedScheduling,
        )

        from .codegen.simd import SIMDScheduling

        snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes
        device = snodes[0].get_device()
        backend = node.scheduler.get_backend(device)
        assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling))
        V.graph.scheduler.current_device = device

        # Don't increment kernel count when generating debug string.
        # This will confuse some unit tests that check the number of
        # generated kernels.
        old_generated_kernel_count = metrics.generated_kernel_count
        triton_code = backend.generate_kernel_code_from_nodes(snodes).strip()
        metrics.generated_kernel_count = old_generated_kernel_count

        lines.append(f"{node.get_name()} Triton code:")
        lines.append(textwrap.indent(triton_code, "    "))
    return lines
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py b/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d3c6d411d278c2ef893ad06f8195dbd2096572c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py
@@ -0,0 +1,892 @@
# mypy: allow-untyped-defs
import functools
import itertools
import logging
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import sympy
from sympy import Expr

from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges

from .runtime.runtime_utils import is_power_of_2
from .utils import (
    has_free_symbols,
    sympy_index_symbol,
    sympy_index_symbol_with_prefix,
    sympy_subs,
    VarRanges,
)
from .virtualized import V


log = logging.getLogger(__name__)


def evaluate_expr(
    shape_env: ShapeEnv,
    expr: Union[sympy.Basic, bool],
    axioms: Optional[Tuple[sympy.Expr]] = None,
    var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges[Any]]]] = None,
) -> bool:
    """Best-effort static evaluation of ``expr`` via the shape env; returns
    False (never raises) when it cannot be decided."""
    # Plain Python bools short-circuit without touching the shape env.
    if expr in (True, False):
        return bool(expr)

    try:
        simplified = shape_env._maybe_evaluate_static(
            expr,
            axioms=axioms,
            var_to_range=var_to_range,
        )
        if simplified is not None:
            return bool(simplified)
    except Exception:
        # Undecidable or malformed expression: treat as "not statically true".
        log.debug("Could not simplify %s", expr, exc_info=True)

    return False


# 
# This class is a little awkward, because ShapeEnv is doing most of the heavy
# lifting and in some cases we should be directly passing through to ShapeEnv,
# but there is some extra inductor logic that needs to be handled here
class SizeVarAllocator:
    def __init__(self, shape_env=None) -> None:
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        # Aliases into the shape env's state (shared, not copied).
        self.var_to_val = self.shape_env.var_to_val
        self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
        # The basic idea is if we have some complicated sympy expression
        # f(s0), we may choose to precompute it on the host and then replace
        # all occurrences of that sympy expression with ps0, so that when we
        # codegen we simply reference ps0 directly without repeating
        # f(s0).  Unlike regular size variables, ps variables cannot be
        # guarded upon; so if we are asked to guard on a Sympy expression
        # which potentially could have already had a precomputed replacement
        # on it, we are obligated to invert the precomputed replacements
        # (inv_precomputed_replacements).
        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {}
        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {}
        # Cached callables built once per allocator instance.
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()

    def simplify(self, expr: Expr):
        """Expand ``expr`` and apply the shape env's symbol replacements."""
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]:
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: Dict[Tuple[Any, ...], Expr] = {}
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: Dict[Tuple[Any, ...], Any] = {}
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (*index_vars, *sizes, *index_formulas)
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.
        """

        expr = join_dimensions(self.simplify(expr))
        original_expr = expr

        # Build value ranges for every symbol: iteration vars are [0, size-1];
        # anything unknown is assumed non-negative and unbounded.
        var_to_range = dict(self.shape_env.var_to_range)
        var_to_range.update(
            {
                k: ValueRanges(
                    0, max(0, v - 1) if not has_free_symbols([v]) else IntInfinity()
                )
                for k, v in var_ranges.items()
            }
        )
        for var in expr.free_symbols:
            if var not in var_to_range:
                var_to_range[var] = ValueRanges(0, IntInfinity())

        var_to_range_tuple = cast(
            Tuple[Tuple[sympy.Symbol, ValueRanges[sympy.Expr]]],
            tuple(var_to_range.items()),
        )

        axioms = []
        for var, upper_bound in var_ranges.items():
            axioms.append(0 <= var)
            axioms.append(var < upper_bound)
        axioms = tuple(axioms) + self.shape_env.get_axioms()

        def statically_known(expr):
            # True only if the shape env can prove ``expr`` with the axioms
            # and ranges above.
            evaluated = self.shape_env._maybe_evaluate_static(
                expr,
                axioms=axioms,
                var_to_range=var_to_range_tuple,
            )
            return bool(evaluated)

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            if not statically_known(base >= 0):
                return base

            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if statically_known(v < divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            base = remove_zero_terms(base, divisor)

            # (base % (modulus*divisor)) // divisor == base // divisor when
            # base is provably within [0, modulus*divisor).
            can_remove_mod = statically_known(base >= 0) and statically_known(
                base < modulus * divisor
            )

            if can_remove_mod:
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                    sympy.Wild("modulus", integer=True),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                ),
                visit_indexing_div,
            )

        # Re-run until a fixed point is reached.
        if expr != original_expr:
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(
        self, index_vars: List[sympy.Symbol], sizes, index_formulas
    ):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop
        If channel_last = True, we will prevent the last dim fused with other dims
        """
        sizes = list(map(self.simplify, sizes))

        strides = [
            # index_formulas may contain boolean expressions (e.g. s0 < 10),
            # for which "strides" don't make sense so we ignore them here.
            # NOTE: These expressions may still block merging dims in the sound
            # substitution test performed in can_merge_dims.
            self.stride_vars(x, index_vars)
            if isinstance(x, sympy.Expr)
            else [0] * len(index_vars)
            for x in index_formulas
        ]
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

        for i in range(len(sizes)):
            if sizes[i] == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            # Dims a,b are mergeable iff for every formula the strides are
            # contiguous AND a substitution test proves equivalence.
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    m1 = sympy_index_symbol("_merge_tester1")
                    m2 = sympy_index_symbol("_merge_tester2")
                    # NOTE: can't sub vb=0 here in case va * vb appears in the expression,
                    # in which case both expr1 and expr2 would be zero!
                    expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: (m1 + m2)})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            # Map a compressed index back onto the original dims (removed
            # dims become constant 0).
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.Integer(0))
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    # Note - [On Statically Known]
    #
    # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
    # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
    # true, we add a guard and return True, otherwise, False.
    #
    # def maybe_guard_foo(args):
    #   if size_hinted_check(args):
    #     return False # No guard, no optim
    #   guard(args) # Make a guard
    #   return True # Safe to apply optimization
    #
    # The prior system incurred a guard, and green lit an optimization.
    #
    # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
    # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
+ # + # def maybe_guard_foo(args): + # if all_static(args): + # return True # Safe to apply optimization + # else: + # return False # No guard, no optim + + # See Note - [On Statically Known] + + def is_expr_static_and_true(self, expr: Union[sympy.Basic, bool]) -> bool: + return evaluate_expr(self.shape_env, expr) + + def statically_known_equals( + self, left: Union[Expr, int], right: Union[Expr, int] + ) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right are equal. + """ + return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type] + + # See Note - [On Statically Known] + def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right lists are equal. + """ + return len(left) == len(right) and all( + self.statically_known_equals(l, r) for l, r in zip(left, right) + ) + + # See Note - [On Statically Known] + def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than or equal to right. + """ + expr = left <= right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right. + """ + expr = left >= right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than right. 
+ """ + expr = left < right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than right. + """ + expr = left > right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_multiple_of( + self, numerator: Expr, denominator: Union[Expr, int] + ) -> bool: + """ + Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. + """ + if free_unbacked_symbols(numerator) or free_unbacked_symbols(denominator): + return False + expr = sympy.Eq(numerator % denominator, 0) + return self.is_expr_static_and_true(expr) # type: ignore[arg-type] + + # See Note - [On Statically Known] + def statically_known_power_of_2(self, expr: Expr) -> bool: + """ + Returns a bool indicating if x is known to be a power of 2. + """ + return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr)) + + # The guard functions require you to ALREADY KNOW that a particular + # condition holds. If you don't know (you want to guard on an expression + # being a particular value, and then get access to that value), use + # the evaluate functions. 

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        # Caller asserts left == right; we record the guard in the shape env.
        # Precomputed ps* symbols cannot be guarded on, so invert them first.
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
        return left

    def guard_leq(self, left: Expr, right: Expr) -> None:
        # left <= right  <=>  left < right + 1 for integer expressions.
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        assert self.shape_env.evaluate_expr(sympy.Lt(left, right))

    def guarded_order(self, seq):
        """
        Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
        """
        seq = [*map(self.remove_precomputed_replacements, seq)]
        seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)]
        seq.sort()
        order = [-1] * len(seq)
        last_var = None
        # Guard each adjacent pair so the sorted order is pinned down.
        for new_index, (_, orig_index, var) in enumerate(seq):
            order[orig_index] = new_index
            if last_var is not None:
                self.guard_leq(last_var, var)
            last_var = var
        return order

    # The evaluate functions evaluate some symbolic sympy expression
    # (NB: not necessarily an Expr) and return what the concrete result
    # is, guarding on the expression being that result

    # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
    # as this will ensure that you actually have a sympy'ified expression,
    # and will prevent you from incorrectly writing evaluate_expr(a == b)
    # which does the wrong thing if a or b is a sympy expression
    def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
        assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
        return self.shape_env.evaluate_expr(sympy.sympify(left))

    def evaluate_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice"""
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        try:
            lv = self.size_hint(left)
            rv = self.size_hint(right)
        except TypeError:  # unbacked symints
            # No hints available: fall back to static proofs, then a gcd
            # heuristic for trivially-ordered products of the same symbol.
            if left == right or self.statically_known_leq(left, right):
                return left
            if self.statically_known_leq(right, left):
                return right
            gcd = sympy.gcd(left, right)
            if left == gcd:  # handle `min(10*u0, u0)` etc
                return left
            if right == gcd:
                return right
            raise TypeError(
                f"evaluate_min({left}, {right}) with unbacked symints"
            ) from None
        if lv <= rv:
            self.guard_leq(left, right)
            return left
        else:
            self.guard_leq(right, left)
            return right

    def evaluate_max(self, left: Expr, right: Expr) -> Expr:
        """return the larger of left and right, and guard on that choice"""
        # Always choose the opposite of eval min for consistency
        # This means min(a, b) and max(a, b) produce the same guards
        min_val = self.evaluate_min(left, right)
        return right if min_val is left else left

    def evaluate_static_shape(self, left: Union[Expr, int]) -> int:
        # Pin a symbolic size to its current hint (adding a guard) and return it.
        if isinstance(left, int):
            return left
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> List[int]:
        return [self.evaluate_static_shape(x) for x in left]

    def remove_precomputed_replacements(self, expr: Expr) -> Expr:
        # Invert ps* (precomputed-size) symbols back to their defining exprs.
        if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols):  # type: ignore[attr-defined]
            return sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        return expr

    def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]:
        if isinstance(expr, int):
            return expr
        # Substitute all hints into expr, but leave unbacked symints alone
        expr = self.simplify(expr)
        if not isinstance(expr, Expr):
            assert isinstance(expr, int)
            return expr
        free_symbols = expr.free_symbols
        if not free_symbols:
            try:
                return int(expr)  # type: ignore[return-value]
            except TypeError:
                return expr  # inf/nan/I
        expr = self.remove_precomputed_replacements(expr)
        return sympy_subs(expr, self.var_to_val)

    def size_hint(
        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
    ) -> int:
        out = self.symbolic_hint(expr)
        if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
            # Use the provided heuristic fallback hint
            unbacked_sym_vrs = {
                s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols
            }
            if all(vr is not None for vr in unbacked_sym_vrs.values()):
                # Clamp the fallback into the expression's provable value range.
                hint_vr = bound_sympy(out, unbacked_sym_vrs)  # type: ignore[arg-type]
                if isinstance(hint_vr.lower, (int, sympy.Integer)):
                    fallback = max(fallback, int(hint_vr.lower))
                if isinstance(hint_vr.upper, (int, sympy.Integer)):
                    fallback = min(fallback, int(hint_vr.upper))
            return fallback

        try:
            return int(out)
        except Exception:
            log.debug("failed on: %s", out)
            raise

    def size_hints(
        self,
        exprs: Iterable[Expr],
        *,
        fallback: Optional[int] = None,
    ) -> Tuple[int, ...]:
        return tuple(self.size_hint(x, fallback=fallback) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
+ """ + fn_cache = functools.lru_cache(maxsize)(fn) + prior_len = len(self.replacements) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + nonlocal prior_len + if prior_len != len(self.replacements): + prior_len = len(self.replacements) + fn_cache.cache_clear() + return fn_cache(*args, **kwargs) + + return wrapper + + def make_stride_vars_cache(self): + cache = self._lru_cache(self._stride_vars) + + def stride_vars( + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Optional[Sequence[sympy.Symbol]] = None, + ) -> List[Expr]: + if not support_vars: + support_vars = vars + return cache(index, tuple(vars), tuple(support_vars)) + + return stride_vars + + def _stride_vars( + self, + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Sequence[sympy.Symbol], + ) -> List[Expr]: + """Convert an indexing expression back into strides + + NOTE: This is only valid if the index is a standard strided offset + calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a + stride of -10 because the index wraps around after the first element + + """ + strides = [] + index = self.simplify(index) + # remove any offset + index = index - sympy_subs( + index, {v: sympy.Integer(0) for v in support_vars if v != 0} + ) + for i in range(len(vars)): + # drop all the other dims + index_dim = sympy_subs( + index, + { + support_vars[j]: sympy.Integer(0) + for j in range(len(support_vars)) + if vars[i] != support_vars[j] and support_vars[j] != 0 + }, + ) + v = vars[i] + if v == 0: + strides.append(sympy.Integer(0)) + else: + # TODO(jansel): should we use sympy.diff here? 
+ strides.append( + sympy_subs(index_dim, {v: sympy.Integer(1)}) + - sympy_subs(index_dim, {v: sympy.Integer(0)}) + ) + return strides + + def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr: + """Extract offset part of an indexing expression""" + index = self.simplify(index) + return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0}) + + def stride_hints( + self, + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Optional[Sequence[sympy.Symbol]] = None, + ) -> List[int]: + for v in index.free_symbols: + if symbol_is_type(v, SymT.INDIRECT): # type: ignore[attr-defined] + index = sympy_subs(index, {v: 0}) # type: ignore[dict-item] + result = [] + for s in self.stride_vars(index, vars, support_vars): + try: + result.append(self.size_hint(s)) + except TypeError: + result.append(0) + return result + + def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]: + strides = tuple(map(abs, self.stride_hints(index, vars))) + order = list(range(len(strides))) + order.sort(key=lambda x: (strides[x] == 0, strides[x])) + return order + + def lookup_precomputed_size(self, expr: Expr) -> Expr: + if ( + isinstance(expr, (int, sympy.Symbol, sympy.Number)) + or expr.is_number + or expr.is_symbol + ): + return expr + expr = self.remove_precomputed_replacements(expr) + if expr not in self.precomputed_replacements: + sym = sympy_index_symbol_with_prefix( + SymT.PRECOMPUTED_SIZE, len(self.precomputed_replacements) + ) + self.precomputed_replacements[expr] = sym + self.inv_precomputed_replacements[sym] = expr + return self.precomputed_replacements[expr] + + def free_symbols(self) -> Set[sympy.Symbol]: + return set(self.var_to_val.keys()) - set(self.replacements.keys()) + + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: + """ + A pair of special ModularIndexing can be combined. + + E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b) + We can simplify this to ModuleIndexing(x, 1, b), if + 1. 
x is a non-negative integer
+        2. a and b are positive integers
+        3. a is a multiple of b.
+        """
+
+        def _check_args(x, div, mod, is_first):
+            if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer):
+                return False
+            if div != 1:
+                return False
+            if mod <= 0:
+                return False
+
+            if is_first:
+                # the first ModularIndexing should contain a nested ModularIndexing
+                if not isinstance(x, ModularIndexing):
+                    return False
+            else:
+                # the second ModularIndexing should contain a non-negative
+                # symbol
+                if not isinstance(x, sympy.Symbol) or not self.statically_known_geq(
+                    x, 0
+                ):
+                    return False
+            return True
+
+        if isinstance(index, ModularIndexing):
+            x, div, mod = index.args
+
+            if not _check_args(x, div, mod, True):
+                return index
+
+            x2, div2, mod2 = x.args
+
+            if not _check_args(x2, div2, mod2, False):
+                return index
+
+            if mod2 % mod != 0:
+                return index
+
+            return ModularIndexing(x2, 1, mod)
+
+        return index
+
+    def expand_floor_div(
+        self, index: sympy.Expr
+    ) -> Union[bool, Tuple[sympy.Expr, sympy.Expr]]:
+        """
+        Expand the FloorDiv to the entire expression so that the expression may
+        be simplified.
+
+        E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables
+        x1, x2, index expression 'x1 * 2b + x2' can be easily combined.
+        But index expression 'x1 * b + x2 // 2' can not.
+        By expanding the FloorDiv to the entire expression, we get
+        '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops
+        for the numerator!
+
+        Return False if this optimization can not be applied;
+        Return the new expression and the denominator otherwise.
+        The original expression will be equivalent to 'new_expression // denominator'
+        """
+        if not isinstance(index, sympy.Add):
+            return False
+        terms = index.args
+
+        if len(terms) < 2:
+            return False
+        floor_div_index = -1
+        varlist = []
+        factorlist = []
+        for idx, term in enumerate(terms):
+            if isinstance(term, sympy.Mul):
+                # For dynamic shape, term like '2*s1*x1' has 3 child nodes.
+                # - An integer for 2
+                # - A symbol for s1
+                # - A symbol for x1
+                # Skip for now.
+                if len(term.args) != 2:
+                    return False
+                factor, var = term.args
+                varlist.append(var)
+                factorlist.append(factor)
+                if not isinstance(factor, sympy.Integer) or not isinstance(
+                    var, sympy.Symbol
+                ):
+                    return False
+                # It's easier to reason about the correctness of the transformation
+                # for non-negative integers.
+                if not self.statically_known_geq(var, 0):
+                    return False
+            elif isinstance(term, FloorDiv):
+                var, factor = term.args
+                if not isinstance(factor, sympy.Integer) or not isinstance(
+                    var, sympy.Symbol
+                ):
+                    return False
+                if not self.statically_known_geq(var, 0):
+                    return False
+                if floor_div_index >= 0:
+                    # cannot handle multiple FloorDiv terms yet
+                    return False
+
+                floor_div_index = idx
+                varlist.append(var)
+                # this factor is the denominator
+                factorlist.append(factor)
+            else:
+                return False
+
+        if floor_div_index < 0:
+            return False
+
+        # Construct the new expression and remember the denominator
+        denominator = factorlist[floor_div_index]
+        new_index = sympy.Integer(0)
+
+        for var, factor, idx in zip(varlist, factorlist, itertools.count()):
+            if idx == floor_div_index:
+                new_index += var
+            else:
+                new_index += (factor * denominator) * var
+
+        return new_index, denominator
+
+
+def join_dimensions(expr: Expr) -> Expr:
+    if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing):
+        return expr  # fast exit path
+    return _join_dimensions_cached(expr)
+
+
+@functools.lru_cache(256)
+def _join_dimensions_cached(expr: Expr) -> Expr:
+    """
+    ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
+    becomes
+    ModularIndexing(i0, 1, 128)
+    ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
+    becomes i0
+
+
+    This type of pattern can come from view operations
+    """
+    assert isinstance(expr, sympy.Add)
+
+    scale = sympy.Wild("scale", exclude=[0], integer=True)
+    base = sympy.Wild("base", integer=True)
+    divisor = sympy.Wild("divisor", integer=True)
+    mod1 = 
sympy.Wild("modulus", integer=True) + mod2 = sympy.Wild("modulus2", integer=True) + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] + * m1[mod1] + * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2) + ) + if m2 and term1 != term2: + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] + * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2]) + ) + return expr + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1]) + ) + if m2 is not None: # in case of success we get an empty dict here + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] * FloorDiv(m1[base], m1[divisor]) + ) + return expr + return expr + + +class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined] + """ + A wrapper around .virtualize.ops that uses var range information to + simplify ModularIndexing/FloorDiv. 
+ """ + + def __init__(self, inner, var_ranges: VarRanges) -> None: + super().__init__(inner) + self.name = "SimplifyIndexing" + self._simplify: Callable[ + [Expr], Expr + ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(name, self._simplify(index)) + + def store(self, name, index, value, mode=None): + return self._inner.store(name, self._simplify(index), value, mode=mode) + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(name, self._simplify(index), value) + + def index_expr(self, index, dtype): + return self._inner.index_expr(self._simplify(index), dtype) + + def check_bounds(self, index, size, lower, upper): + return self._inner.check_bounds(self._simplify(index), size, lower, upper) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b17f3a68559a60d0d84f814739cc3b325b07594c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py @@ -0,0 +1,2037 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import enum +import functools +import inspect +import io +import itertools +import logging +import math +import operator +import os +import platform +import shutil +import sys +import tempfile +import textwrap +import time +import unittest +from datetime import datetime +from io import StringIO +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + NamedTuple, + Optional, + Protocol, + Sequence, + Set, + TypeVar, + Union, + ValuesView, +) +from typing_extensions import Concatenate, ParamSpec +from unittest import mock + +import sympy + +import torch + + +GPU_TYPES = ["cuda", "xpu"] + + +# defines here before import torch._dynamo is for avoiding circular 
import +# when get_gpu_type is imported from dynamo +@functools.lru_cache(None) +def get_gpu_type(): + avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()] + assert len(avail_gpus) <= 1 + gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop() + return gpu_type + + +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import detect_fake_mode +from torch.autograd import DeviceType +from torch.autograd.profiler_util import EventList +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import ShapeProp +from torch.utils._sympy.functions import ( + CeilDiv, + CleanDiv, + FloorDiv, + Identity, + ModularIndexing, +) +from torch.utils._sympy.symbol import make_symbol, SymT +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from . import config +from .runtime.runtime_utils import ceildiv as runtime_ceildiv + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + +_T = TypeVar("_T") +VarRanges = Dict[sympy.Expr, sympy.Expr] +InputType = Union[torch.Tensor, int] + + +GPU_ALIGN_BYTES = 16 +ALIGNMENT = 16 + +ALIGN_BYTES = 64 +assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" + + +def _align(nbytes): + """Round up to the nearest multiple of ALIGN_BYTES""" + return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES + + +def _is_aligned(v: sympy.Expr): + """v can be statically proven to be a multiple of ALIGN_BYTES""" + if isinstance(v, (sympy.Add, sympy.Max)): + return all(map(_is_aligned, v.args)) + return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES + + +class align(sympy.Function): + """Symbolically round up to the nearest multiple of ALIGN_BYTES""" + + nargs = (1,) + is_integer = True + + @classmethod + def eval(cls, value): + if isinstance(value, (int, sympy.Integer)): + return _align(int(value)) + if _is_aligned(value): + return value + + +def 
do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float: + """ + Returns benchmark results by examining torch profiler events. + This could be more accurate as it doesn't count CPU side overhead. + However, this also requires manually excluding irrelevant event, e.g. + vectorized_elementwise_kernel which is used to fill L2 cache, + various CUDA events, etc, so could also be fragile. + """ + + fn() + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + + # Warm-up + for _ in range(n_warmup): + fn() + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + # Benchmark + for i in range(n_repeat): + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + fn() + # Record clocks + torch.cuda.synchronize() + + log.debug("raw events") + log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1)) + + filtered_events = EventList( + [ + event + for event in p.events() + if event.device_type == DeviceType.CUDA and event.name != "Context Sync" + ] + ) + if len(filtered_events) % n_repeat != 0: + raise RuntimeError( + "Failed to divide all profiling events into #repeat groups. 
" + "#CUDA events: %d, #repeats: %s", + len(filtered_events), + n_repeat, + ) + num_event_per_group = len(filtered_events) / n_repeat + actual_events = EventList( + [ + event + for i, event in enumerate(filtered_events) + if i % num_event_per_group != 0 + ] + ) + actual_events._build_tree() + actual_events = actual_events.key_averages() + + log.debug("profiling time breakdown") + log.debug(actual_events.table(row_limit=-1)) + + res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat + log.debug("profiling results: %s ms", res) + return res + + +@functools.lru_cache(None) +def has_torchvision_roi_align() -> bool: + try: + from torchvision.ops import roi_align # noqa: F401 + + torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta") + return roi_align is not None and hasattr( + getattr(torch.ops, "torchvision", None), "roi_align" + ) + except ImportError: + return False + except RuntimeError as e: + assert "torchvision::nms does not exist" in str(e) + return False + + +def decode_device(device: Union[Optional[torch.device], str]) -> torch.device: + if device is None: + return torch.tensor(0.0).device # default device + if isinstance(device, str): + device = torch.device(device) + if device.type not in ("cpu", "meta") and device.index is None: + device_interface = get_interface_for_device(device.type) + return torch.device(device.type, index=device_interface.Worker.current_device()) + return device + + +def sympy_product(it): + return functools.reduce(operator.mul, it, sympy.Integer(1)) + + +def sympy_dot(seq1, seq2): + assert len(seq1) == len(seq2) + return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) + + +def unique(it: Iterable[_T]) -> ValuesView[_T]: + return {id(x): x for x in it}.values() + + +def ceildiv( + numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] +) -> Union[int, sympy.Expr]: + if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): + return CeilDiv(sympy.sympify(numer), 
sympy.sympify(denom)) + # TODO: There is a bug in a call to this function, to repro: + # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy + # --amp --only YituTechConvBert --dynamic-shapes + assert isinstance(numer, int) and isinstance( + denom, int + ), f"{numer}: {type(numer)}, {denom}: {type(denom)}" + return runtime_ceildiv(numer, denom) + + +def _type_of(key): + # Use the function here to get rid of dependencies on the Triton during the codegen. + # Refer to Triton implementation here: + # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238 + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + dtype_str = str(key).split(".")[-1] + tys = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8e4b15x4": "fp8e4b15x4", + "float8_e4m3fn": "fp8e4nv", + "float8_e5m2": "fp8e5", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + } + # reinterpret can create triton type + for v in list(tys.values()): + tys[v] = v + return key if isinstance(key, str) else f"*{tys[dtype_str]}" + + +def convert_shape_to_inductor( + lst: Iterable[Union[int, torch.SymInt]] +) -> List[sympy.Expr]: + """ + Gets the shape and stride of a tensor. For non-symbolic tensors, this is + trivial. But for symbolic tensors, we need to map from SymIntNode into + sympy.Expr. + """ + return [sympy.sympify(i) for i in lst] + + +def convert_shape_to_symint( + lst: Iterable[Union[int, sympy.Expr]] +) -> List[Union[int, torch.SymInt]]: + """ + Takes a list of shapes from Inductor and converts them into symints (or just + ints if all shapes are static). 
+ """ + from .virtualized import V + + return [ + i + if isinstance(i, int) + else int(i) + if isinstance(i, sympy.Integer) + else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) + for i in lst + ] + + +def is_view(op: torch._ops.OpOverload): + """ + Does this op overload have aliasing + """ + assert isinstance(op, torch._ops.OpOverload) + return any(a.alias_info is not None for a in op._schema.arguments) + + +def is_pointwise_use( + use, is_pointwise_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None +): + """ + Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn` + + Uses in views ops will follow the views uses + """ + + if not use.op == "call_function": + return False + + if not ( + isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem + ): + return False + + if use.target is operator.getitem or is_view(use.target): + return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users) + + return torch.Tag.pointwise in use.target.tags or ( + is_pointwise_fn is not None and is_pointwise_fn(use.target) + ) + + +def gen_gm_and_inputs(target, args, kwargs): + g = torch.fx.Graph() + g_args = [] + a_args = [] + for n, arg in enumerate(args): + if isinstance(arg, torch.Tensor): + g_args.append(g.placeholder(f"arg{n}")) + a_args.append(arg) + else: + g_args.append(arg) + assert all(not isinstance(x, torch.Tensor) for x in kwargs.values()) + node = g.call_function(target, tuple(g_args), kwargs) + if ( + len(target._schema.returns) == 1 + and str(target._schema.returns[0].type) == "Tensor" + ): + node = (node,) # type: ignore[assignment] + g.output(node) + + gm = torch.fx.GraphModule({}, g) + return gm, a_args + + +def synchronize(device: str = "cuda"): + if device == "cpu": + return + device_interface = get_interface_for_device(device) + if device_interface.is_available(): + device_interface.synchronize() + + +def timed( + model: Callable[..., Any], example_inputs, times: int = 1, 
device: str = "cuda" +) -> float: + synchronize(device) + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(times): + result = model(*example_inputs) + synchronize(device) + t1 = time.perf_counter() + # GC the result after timing + assert result is not None # type: ignore[possibly-undefined] + return t1 - t0 + + +def print_performance( + fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda" +): + timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) + took = torch.median(timings) / times + print(f"{took / baseline:.6f}") + return took + + +def precompute_method(obj: Any, method: str): + """Replace obj.method() with a new method that returns a precomputed constant.""" + result = getattr(obj, method)() + setattr(obj, method, lambda: result) + + +def precompute_methods(obj: Any, methods: List[str]): + """Replace methods with new methods that returns a precomputed constants.""" + for method in methods: + precompute_method(obj, method) + + +def cmp(a, b) -> int: + return int(a > b) - int(a < b) + + +def pad_listlike(x, size): + if len(x) == 1: + return type(x)([x[0]]) * size + else: + return x + + +# Used to ensure that iterating over a set is deterministic +def tuple_sorted(x): + if len(x) == 0: + return [] + + def sort_func(elem): + if isinstance(elem, str): + return elem + else: + # We expect `elem` to be `scheduler.BaseSchedulerNode` type here, + # but we are not able to do isinstance assert because of circular dependency + return elem.get_name() + + return sorted(x, key=sort_func) + + +P = ParamSpec("P") +RV = TypeVar("RV", covariant=True) + + +class CachedMethod(Protocol, Generic[P, RV]): + @staticmethod + def clear_cache(self) -> None: + ... + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: + ... 
+ + +# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature +def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: + key = f"__{fn.__name__}_cache" + + @functools.wraps(fn) + def wrapper(self): + if not hasattr(self, key): + setattr(self, key, fn(self)) + return getattr(self, key) + + def clear_cache(self): + if hasattr(self, key): + delattr(self, key) + + wrapper.clear_cache = clear_cache # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + +def aggregate_origins(node_schedule): + from . import ir + + if isinstance(node_schedule, list): + return functools.reduce( + operator.or_, + [ + node.node.origins + for node in node_schedule + if hasattr(node, "node") and node.node + ], + set(), + ) + elif isinstance(node_schedule, ir.ExternKernel): + return node_schedule.origins + else: + return set() + + +def get_fused_kernel_name(node_schedule, descriptive_names): + all_origins = aggregate_origins(node_schedule) + if descriptive_names == "original_aten": + # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions) + sources = [ + origin.meta["original_aten"]._overloadpacket.__name__ + for origin in all_origins + if origin.op == "call_function" + and "original_aten" in origin.meta + and origin.meta["original_aten"] is not None + ] + sources = sorted(set(sources)) + elif descriptive_names == "torch": + # Bases the kernel name off of the top-level "torch" operator (i.e. 
post-dynamo graph) + sources = [] + for origin in all_origins: + if origin.op == "call_function" and "source_fn_stack" in origin.meta: + source_fn = origin.meta["source_fn_stack"][-1] + if isinstance(source_fn[1], str): + sources.append(source_fn[1]) + else: + sources.append(source_fn[1].__name__) + sources = sorted(set(sources)) + elif descriptive_names == "inductor_node": + sources = [ + origin.name for origin in all_origins if origin.op == "call_function" + ] + else: + raise NotImplementedError + sources = sources + return "_".join(["fused"] + sources) + + +def get_kernel_metadata(node_schedule, wrapper): + all_origins = aggregate_origins(node_schedule) + inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] + + from_node_dict = collections.defaultdict(list) + original_aten_dict = collections.defaultdict(list) + + # Attempt to sort `inductor_nodes` topologically. Note that the case + # where `inductor_nodes` contains nodes from multiple graph instances + # is not supported. An example of this is conditional statements. 
+ single_graph = None + if len(inductor_nodes): + unique_graphs = {n.graph for n in inductor_nodes} + if len(unique_graphs) == 1: + single_graph = inductor_nodes[0].graph + # create a map of idx -> node and cache it + if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"): + node_to_idx_map = {} + for idx, n in enumerate(single_graph.nodes): + node_to_idx_map[n] = idx + single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map + inductor_nodes.sort( + key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n] + ) + + for node in inductor_nodes: + if "original_aten" in node.meta and node.meta["original_aten"] is not None: + key = str(node.meta["original_aten"]._overloadpacket) + original_aten_dict[key].append(node.name) + if "from_node" in node.meta: + key = node.meta["from_node"][0][0] + from_node_dict[key].append(node.name) + sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted" + metadata = ( + f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], " + f"Original ATen: [{', '.join(original_aten_dict.keys())}]" + ) + + # trace back to original node here + detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"] + for original_node, nodes in sorted(from_node_dict.items()): + detailed_metadata.append( + f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}" + ) + + # print the aot_autograd graph fragment + if single_graph is not None: + detailed_metadata.append(f"{wrapper.comment} Graph fragment:") + for n in inductor_nodes: + # TODO(future): maybe refactor torch/fx/graph.py to make it easy to + # generate python code for graph fragments + detailed_metadata.append(f"{wrapper.comment} {n.format_node()}") + + return metadata, "\n".join(detailed_metadata) + + +def dominated_nodes( + initial_queue: Iterable[torch.fx.Node], skip_filter=None +) -> Set[torch.fx.Node]: + """Returns the set of nodes whose values depend on those within 
initial_queue""" + initial_queue = list(initial_queue) + dominated_set = set(initial_queue) + + while initial_queue: + node = initial_queue.pop() + for user in node.users: + if skip_filter and skip_filter(user): + continue + if user not in dominated_set: + dominated_set.add(user) + initial_queue.append(user) + + return dominated_set + + +def gather_origins(args, kwargs): + import itertools + + from . import ir + + def is_unrealized_node(n): + if isinstance(n, ir.TensorBox): + return is_unrealized_node(n.data) + if isinstance(n, ir.StorageBox): + return is_unrealized_node(n.data) + return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise) + + kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] + arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] + return set(itertools.chain(*arg_origins, *kwarg_origins)) + + +def sympy_str(expr: sympy.Expr) -> str: + """ + Normal sympy str is very slow, this is a lot faster. The result are + somewhat worse, as it doesn't do as much simplification. So don't + use this for final codegen. + """ + if isinstance(expr, sympy.Symbol): + return expr.name + if isinstance(expr, sympy.Add): + return " + ".join(map(sympy_str, expr.args)) + if isinstance(expr, sympy.Mul): + return " * ".join(map(sympy_str, expr.args)) + + if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)): + return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" + return str(expr) + + +def get_bounds_index_expr(index): + from .virtualized import V + + # If this expression does not come from an FX node, we compute its bounds + if ( + config.compute_all_bounds + and (fx_node := getattr(V.interpreter, "current_node", None)) + and fx_node.target != "index_expr" + ): + return bound_sympy(index) + else: + return ValueRanges.unknown() + + +def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol: + """ + Used to generate an integer-nonnegative symbol. 
+ """ + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert prefix != SymT.SIZE + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return make_symbol(prefix, idx, integer=True, nonnegative=True) + + +def generate_assert(check): + return (check or config.debug_index_asserts) and config.assert_indirect_indexing + + +def sympy_index_symbol(name: str) -> sympy.Symbol: + """ + Used to generate an integer-nonnegative symbol. + """ + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert name[0] != "s" + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return sympy.Symbol(name, integer=True, nonnegative=True) + + +def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr: + """ + When the passed replacement symbol v is a string, it is converted to a symbol with name v that + have the same replaced expression integer and nonnegative properties. 
+ """ + + def to_symbol(replaced, replacement): + assert isinstance(replaced, sympy.Expr) + if isinstance(replacement, str): + return sympy.Symbol( + replacement, + integer=replaced.is_integer, # type: ignore[attr-defined] + nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] + ) + else: + return replacement + + # xreplace is faster than subs, but is way more picky + return sympy.sympify(expr).xreplace( + {k: to_symbol(k, v) for k, v in replacements.items()} + ) + + +def is_symbolic(a: Any) -> bool: + return isinstance(a, torch.SymInt) or ( + isinstance(a, torch.Tensor) + and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride())) + ) + + +def any_is_symbolic(*args: Any) -> bool: + return any(is_symbolic(a) for a in args) + + +def get_first_incompatible_cudagraph_node(gm): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + forbidden_set = { + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten.multinomial.default", + "fbgemm.dense_to_jagged.default", + "fbgemm.jagged_to_padded_dense.default", + "run_and_save_rng_state", + "run_with_rng_state", + "aten._local_scalar_dense", + # Technically, it's not necessary to ban this, because an + # assert_scalar with constant arguments can be validly run + # with CUDA graphs, but the operator is also pointless with + # constant arguments, so might as well ban + "aten._assert_scalar", + } + if torch.are_deterministic_algorithms_enabled(): + forbidden_set.update( + { + "aten._unsafe_index_put.default", + "aten._unsafe_masked_index_put_accumulate.default", + "aten.index_put.default", + "aten.index_put_.default", + "aten.scatter.src", + "aten.scatter.reduce", + "aten.scatter.value_reduce", + "aten.scatter_add_", + "aten.scatter_add.default", + "aten.scatter_reduce.two", + "aten.scatter_reduce_.two", + "aten.scatter_reduce.two_out", + } + ) + for node in gm.graph.nodes: + if str(node.target) in forbidden_set: + return 
node + if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val): + return node + return None + + +def has_incompatible_cudagraph_ops(gm): + return get_first_incompatible_cudagraph_node(gm) is not None + + +def output_node(gm: torch.fx.GraphModule): + """Get the output node from an FX graph""" + last_node = next(iter(reversed(gm.graph.nodes))) + assert last_node.op == "output" + return last_node + + +_registered_caches: List[Any] = [] + + +def clear_on_fresh_inductor_cache(obj: Any): + """ + Use this decorator to register any caches that should be cache_clear'd + with fresh_inductor_cache(). + """ + if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear): + raise AttributeError(f"{obj} does not have a cache_clear method") + + _registered_caches.append(obj) + return obj + + +def clear_inductor_caches(): + """ + Clear all registered caches. + """ + for obj in _registered_caches: + obj.cache_clear() + + +@contextlib.contextmanager +def fresh_inductor_cache(cache_entries=None, dir=None, delete=True): + """ + Contextmanager that provides a clean tmp cachedir for inductor. + + Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes + generated with this cache instance. 
+ """ + clear_inductor_caches() + + inductor_cache_dir = tempfile.mkdtemp(dir=dir) + try: + with mock.patch.dict( + os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} + ): + log.debug("Using inductor cache dir %s", inductor_cache_dir) + triton_cache_dir = os.path.join(inductor_cache_dir, "triton") + with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): + yield + if isinstance(cache_entries, dict): + assert len(cache_entries) == 0, "expected empty cache_entries dict" + if os.path.exists(triton_cache_dir): + files = os.listdir(triton_cache_dir) + cache_entries.update( + { + f: os.path.getsize(os.path.join(triton_cache_dir, f)) + for f in files + if ".lock" not in f + } + ) + if delete: + shutil.rmtree(inductor_cache_dir) + except Exception: + if not _IS_WINDOWS: + """ + Windows can't delete the loaded modules, because the modules binaries are opened. + TODO: discuss if have better solution to handle this issue. + """ + log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) + raise + finally: + clear_inductor_caches() + + +def argsort(seq) -> List[int]: + # preserve original order for equal strides + getter = seq.__getitem__ + a_r = range(len(seq)) + return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413 + + +@functools.lru_cache(8) +def get_dtype_size(dtype): + return torch.empty((), dtype=dtype).element_size() + + +class LineContext(NamedTuple): + context: Any + + +class IndentedBuffer: + tabwidth = 4 + + def __init__(self, initial_indent=0): + self._lines = [] + self._indent = initial_indent + + def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]: + buf = StringIO() + p = 1 + linemap = [] + for line in self._lines: + if isinstance(line, DeferredLineBase): + line = line() + if line is None: + continue + elif isinstance(line, LineContext): + linemap.append((p, line.context)) + continue + assert isinstance(line, str) + buf.write(line) + buf.write("\n") + p += 1 + 
line.count("\n") + return buf.getvalue(), linemap + + def getvalue(self) -> str: + v, _ = self.getvaluewithlinemap() + return v + + def getrawvalue(self) -> str: + buf = StringIO() + for line in self._lines: + if isinstance(line, DeferredLineBase): + line = line() + if line is None: + continue + elif isinstance(line, LineContext): + continue + assert isinstance(line, str) + # backslash implies line continuation + if line.endswith("\\"): + buf.write(line[:-1]) + else: + buf.write(line) + buf.write("\n") + return buf.getvalue() + + def clear(self): + self._lines.clear() + + def __bool__(self): + return bool(self._lines) + + def prefix(self): + return " " * (self._indent * self.tabwidth) + + def newline(self): + self.writeline("\n") + + def writeline(self, line): + if isinstance(line, LineContext): + self._lines.append(line) + elif isinstance(line, DeferredLineBase): + self._lines.append(line.with_prefix(self.prefix())) + elif line.strip(): + self._lines.append(f"{self.prefix()}{line}") + else: + self._lines.append("") + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + def indent(self, offset=1): + @contextlib.contextmanager + def ctx(): + self._indent += offset + try: + yield + finally: + self._indent -= offset + + return ctx() + + def do_indent(self, offset=1): + self._indent += offset + + def do_unindent(self, offset=1): + self._indent -= offset + + def splice(self, other_code, strip=False): + if isinstance(other_code, IndentedBuffer): + dedent = float("inf") + for line in other_code._lines: + if not isinstance(line, LineContext) and line: + dedent = min(dedent, len(line) - len(line.lstrip())) + if math.isinf(dedent): + dedent = 0 + for line in other_code._lines: + if isinstance(line, LineContext): + self._lines.append(line) + else: + IndentedBuffer.writeline(self, line[int(dedent) :]) + else: + other_code = textwrap.dedent(other_code) + if strip: + other_code = other_code.lstrip() + if not other_code: + return + other_code = 
other_code.rstrip() + for line in other_code.split("\n"): + self.writeline(line) + + def map(self, func: Callable[[Any], Any]) -> IndentedBuffer: + res = IndentedBuffer(initial_indent=self._indent) + res._lines = [func(line) for line in self._lines] + return res + + def __repr__(self): + return f"{type(self)}({self.getvalue()})" + + def __add__(self, other): + assert self._indent == other._indent + res = IndentedBuffer(initial_indent=self._indent) + res.writelines(self._lines) + res.writelines(other._lines) + return res + + +class FakeIndentedBuffer(IndentedBuffer): + def __init__(self) -> None: + super().__init__() + + def __getattribute__(self, name): + if name == "__class__": # Allow access to the class attribute + return object.__getattribute__(self, name) + raise RuntimeError( + f"Tried to call self.{name} on FakeIndentedBuffer. This buffer" + "is currently used on TritonTemplateKernel to prevent actual" + "writes to the body without explicitly specifying the body with" + "`TritonTemplateKernel.set_subgraph_body(name)`" + ) + + +@contextlib.contextmanager +def restore_stdout_stderr(initial_stdout, initial_stderr): + try: + yield + finally: + sys.stdout = initial_stdout + sys.stderr = initial_stderr + + +class DeferredLineBase: + """A line that can be 'unwritten' at a later time""" + + def __init__(self, line): + if not line.strip(): + line = "" + self.line = line + + def __call__(self) -> Optional[str]: + """Returns either self.line or None to indicate the line has been 'unwritten'""" + raise NotImplementedError + + def _new_line(self, line: str) -> DeferredLineBase: + """Returns a new deferred line with the same condition""" + raise NotImplementedError + + def with_prefix(self, prefix): + return self._new_line(f"{prefix}{self.line}") + + def lstrip(self): + return self._new_line(self.line.lstrip()) + + def __getitem__(self, index): + return self._new_line(self.line[index]) + + def __bool__(self): + return bool(self.line) + + def __len__(self): + return 
len(self.line) + + +@functools.lru_cache(None) +def is_big_gpu(index) -> bool: + min_sms = 68 # 3080 + avail_sms = torch.cuda.get_device_properties(index).multi_processor_count + if avail_sms < min_sms: + log.warning( + "Not enough SMs to use max_autotune_gemm mode", + extra={"min_sms": min_sms, "avail_sms": avail_sms}, + ) + return False + return True + + +def use_max_autotune() -> bool: + return config.max_autotune or config.max_autotune_gemm + + +def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: + return ( + use_max_autotune() + and layout.device.type == "cuda" + and layout.dtype in allowed_layout_dtypes + and is_big_gpu(layout.device.index or 0) + ) + + +def _use_autotune_backend(backend: str) -> bool: + return backend.upper() in [ + x.strip() for x in config.max_autotune_gemm_backends.upper().split(",") + ] + + +def _use_conv_autotune_backend(backend: str) -> bool: + return backend.upper() in [ + x.strip() for x in config.max_autotune_conv_backends.upper().split(",") + ] + + +def use_triton_template(layout, *, enable_int32=False, enable_float8=False): + from .codegen.common import BackendFeature, has_backend_feature + + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] + if enable_int32: + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] + if enable_float8: + layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2]) + return ( + _use_template_for_cuda(layout, layout_dtypes) + and _use_autotune_backend("TRITON") + and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) + ) + + +def use_cutlass_template(layout, m, n, k): + from .virtualized import V + + gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) + if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: + return False + from .codegen.cuda.cutlass_utils import try_import_cutlass + + # Do not use cutlass template on ROCm + if torch.version.hip: + return False + + layout_dtypes = 
[torch.float16, torch.bfloat16, torch.float32, torch.int32] + res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( + "CUTLASS" + ) + + if res: + if not try_import_cutlass(): + log.warning( + "Failed to import CUTLASS lib. Please check whether " + "_inductor.config.cuda.cutlass_dir is set correctly. " + "Skipping CUTLASS backend for now." + ) + return False + return res + + +@functools.lru_cache(None) +def _rocm_native_device_arch_name(device): + return torch.cuda.get_device_properties(device).gcnArchName + + +@functools.lru_cache(None) +def try_import_ck_lib(): + try: + import ck4inductor # type: ignore[import] + from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library, + gen_ops_preselected, + ) + from ck4inductor.universal_gemm.op import ( # type: ignore[import] + CKGemmOperation, + ) + + package_dirname = os.path.dirname(ck4inductor.__file__) + except ImportError: + + def gen_ops_library(): + return [] + + def gen_ops_preselected(): + return [] + + class CKGemmOperation: # type: ignore[no-redef] + pass + + package_dirname = None + return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation + + +def use_ck_template(layout, m, n, k): + # config knobs check 1 + if not use_max_autotune(): + return False + # config knobs check 2 + if not _use_autotune_backend("CK"): + return False + # platform check + if not torch.version.hip: + return False + # tensors must be on GPU + if not layout.device.type == "cuda": + return False + # hardware check + # if config arch list is not specified, get the native arch from the device properties + native_arch = _rocm_native_device_arch_name(layout.device) + requested_archs = {k.split(":")[0]: k for k in config.rocm.arch} or { + native_arch.split(":")[0]: native_arch + } + requested_supported_archs = [ + requested_archs[k] + for k in requested_archs.keys() & config.rocm.ck_supported_arch + ] + if not requested_supported_archs: + return False + # supported 
input dtypes + if layout.dtype not in [torch.float16, torch.bfloat16]: + return False + # TBD: investigate if we need to disable backend based on number of available CUs similar to `is_big_gpu` + # check if shape is static and gemm size is not 0 + from .virtualized import V + + gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) + if gemm_size <= 0: + return False + # TBD: investigate if backend needs to be disabled for small gemms similar to CUTLASS + + ck_package_dirname, _, _, _ = try_import_ck_lib() + + if not ck_package_dirname: + log.warning("Please pip install Composable Kernel package") + return False + + if not config.rocm.ck_dir: + log.warning("Please set TORCHINDUCTOR_CK_DIR env variable") + return False + + if ck_package_dirname != config.rocm.ck_dir: + log.warning("Invalid path to CK library") + return False + + return True + + +def _use_template_for_cpu(layout): + return use_max_autotune() and layout.device.type == "cpu" + + +def use_cpp_packed_gemm_template(layout, mat1, mat2, mat2_transposed=False): + from . 
import ir + from .codegen.cpp_micro_gemm import create_micro_gemm + from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype + from .kernel.mm_common import mm_args + + if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): + return False + + if not config.cpp.weight_prepack: + return False + + int8_gemm = mat1.get_dtype() == torch.uint8 + layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8] + m, n, k, layout, mat1, mat2 = mm_args( + mat1, + mat2, + out_dtype=layout.dtype if int8_gemm else None, + mat2_transposed=mat2_transposed, + ) + + # TODO(jgong5): support dynamic shapes for n or k + if has_free_symbols((n, k)): + return False + if isinstance(mat2, ir.BaseView): + mat2 = mat2.unwrap_view() + + output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype()) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=mat1.get_dtype(), + input2_dtype=mat2.get_dtype(), + output_dtype=output_dtype, + num_threads=parallel_num_threads(), + ) + + def is_last_dim_stride1(x): + x.freeze_layout() + return x.get_stride()[-1] == 1 + + return ( + layout.dtype in layout_dtypes + and micro_gemm is not None + and is_last_dim_stride1(mat1) # TODO(jgong5): support transposed input + and isinstance(mat2, ir.StorageBox) + and mat2.is_module_buffer() + ) + + +def use_aten_gemm_kernels(): + return not use_max_autotune() or _use_autotune_backend("ATEN") + + +class DebugDirManager: + counter = itertools.count(0) + prev_debug_name: str + + def __init__(self) -> None: + self.id = next(DebugDirManager.counter) + + def __enter__(self): + self.prev_debug_name = torch._dynamo.config.debug_dir_root + self.new_name = f"{self.prev_debug_name}_tmp_{self.id}" + torch._dynamo.config.debug_dir_root = self.new_name + + def __exit__(self, *args): + shutil.rmtree(self.new_name) + torch._dynamo.config.debug_dir_root = self.prev_debug_name + + +def run_and_get_code(fn, *args, **kwargs): + from .graph import 
GraphLowering + + source_codes: List[str] = [] + + def save_output_code(code: str): + source_codes.append(code) + + with mock.patch.object(GraphLowering, "save_output_code", save_output_code): + torch._dynamo.reset() + result = fn(*args, **kwargs) + return result, source_codes + + +def run_fw_bw_and_get_code(fn): + def run_with_backward(): + result = fn() + result.sum().backward() + return result + + return run_and_get_code(run_with_backward) + + +def get_code(fn, *args, **kwargs): + """Get the inductor-generated code, but skip any actual compilation or running.""" + from .graph import GraphLowering + + source_codes: List[str] = [] + + def save_output_code(code: str): + source_codes.append(code) + + def patched_compile_to_module(self: GraphLowering): + class DummyModule: + """This is empty to replace the generated triton module""" + + def __init__(self) -> None: + pass + + def call(self, *args, **kwargs): + # Don't do anything when called + pass + + code, _ = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + # Skip all the actual compiling. 
+ nonlocal save_output_code + save_output_code(code) + + return DummyModule() + + with mock.patch.object( + GraphLowering, "compile_to_module", patched_compile_to_module + ), mock.patch.object(GraphLowering, "save_output_code", save_output_code): + torch._dynamo.reset() + # Note the return here is None + _ = fn(*args, **kwargs) + + return source_codes + + +def get_triton_code(fn, *args, **kwargs): + source_codes = get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert ( + 1 <= len(source_codes) <= 2 + ), f"expected one or two code outputs got {len(source_codes)}" + return source_codes[0] + + +def run_and_get_triton_code(fn, *args, **kwargs): + _, source_codes = run_and_get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert ( + 1 <= len(source_codes) <= 2 + ), f"expected one or two code outputs got {len(source_codes)}" + return source_codes[0] + + +def run_and_get_graph_lowering(fn, *args, **kwargs): + from torch._inductor.codecache import CompiledFxGraph + from torch._inductor.graph import GraphLowering + + real_init = CompiledFxGraph.__init__ + graph_lowerings = [] + + def fake_init(*args, **kwargs): + real_init(*args, **kwargs) + graph = args[2] + assert isinstance(graph, GraphLowering) + graph_lowerings.append(graph) + + with mock.patch.object(CompiledFxGraph, "__init__", fake_init): + result = fn(*args, **kwargs) + + return result, graph_lowerings + + +@contextlib.contextmanager +def override_lowering(aten_op, override_fn): + """ + Override the lowering of aten_op with override_fn. + The first argument of override_fn is the original lowering fn. 
+ """ + from torch._inductor import lowering + + orig_fn = lowering.lowerings[aten_op] + try: + lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn) + yield + finally: + lowering.lowerings[aten_op] = orig_fn + + +def add_scheduler_init_hook(pre_fn, post_fn=None): + """ + Add hook functions to be called at the beginning and end of Scheduler.__init__. + Used for unit tests. + """ + from torch._inductor.scheduler import Scheduler + + orig_fn = Scheduler.__init__ + + def wrapper(scheduler, nodes): + pre_fn(scheduler, nodes) + out = orig_fn(scheduler, nodes) + if post_fn: + post_fn(scheduler, nodes) + return out + + return unittest.mock.patch.object(Scheduler, "__init__", wrapper) + + +def developer_warning(msg): + """ + Warnings that will be actionable for PyTorch developers, but not + end users. Allows us to easily disable them in stable releases but + keep them on for nightly builds. + """ + if config.developer_warnings: + log.warning(msg) + else: + log.info(msg) + + +def get_benchmark_name(): + """ + An experimental API used only when config.benchmark_kernel is true. + + The benchmark name is only available at codegen time. So we can not + directly call it in benchmark_all_kernels which is run after codegen. + + The function assumes the argument after --only is the benchmark name. + It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc + scripts, this function may return None. + + There are 2 flavors of --only argument we need handle: + 1. --only model_name + 2. 
--only=model_name + """ + try: + idx = sys.argv.index("--only") + if ( + idx + 1 < len(sys.argv) + and len(sys.argv[idx + 1]) > 0 + and sys.argv[idx + 1][0] != "-" + ): + return sys.argv[idx + 1] + except ValueError: + pass + + for arg in sys.argv: + if arg.startswith("--only="): + return arg[len("--only=") :] + + +def is_ones(items): + return all(x == 1 for x in items) + + +def is_zeros(items): + return all(x == 0 for x in items) + + +def is_cpu_device(inputs): + return all( + item.device == torch.device("cpu") + for item in inputs + if isinstance(item, torch.Tensor) + ) + + +def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: + assert isinstance( + val, sympy.Expr + ), "only support sympy.Expr as input to get_sympy_Expr_dtype" + if val.is_integer: # type: ignore[attr-defined] + return torch.int64 + else: + return torch.float64 + + +@contextlib.contextmanager +def maybe_profile(should_profile, *args, **kwargs): + if should_profile: + with torch.profiler.profile(*args, **kwargs) as p: + yield p + else: + yield + + +def parallel_num_threads(): + threads = config.cpp.threads + if threads < 1: + threads = torch.get_num_threads() + return threads + + +@functools.lru_cache(None) +def get_device_tflops(dtype): + from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops + + assert dtype in (torch.float16, torch.bfloat16, torch.float32) + + if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): + # Triton API change in https://github.com/openai/triton/pull/2293 + from torch._utils_internal import max_clock_rate + + sm_clock = max_clock_rate() + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype, sm_clock) + + if torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32, sm_clock) + else: + return get_max_simd_tflops(torch.float32, sm_clock) + else: + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype) + + if 
torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32) + else: + return get_max_simd_tflops(torch.float32) + + +@functools.lru_cache(None) +def get_gpu_dram_gbps(): + from triton.testing import get_dram_gbps + + return get_dram_gbps() + + +def get_gpu_shared_memory(): + from triton.runtime import driver + + return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0) + + +def is_welford_reduction(reduction_type): + return reduction_type.startswith("welford") + + +def reduction_num_outputs(reduction_type): + return 3 if is_welford_reduction(reduction_type) else 1 + + +def is_linux() -> bool: + return platform.system() == "Linux" + + +def is_windows(): + return sys.platform == "win32" + + +def has_free_symbols(itr: Iterable[Any]): + return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr) + + +def is_dynamic(*args): + from . import ir + + for t in args: + if isinstance(t, ir.TensorBox): + if has_free_symbols(t.data.get_size()) or ( + hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride()) + ): + return True + elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)): + assert hasattr(t, "get_size") and hasattr(t, "get_stride") + if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()): + return True + elif not isinstance(t, ir.IRNode): + continue + else: + raise TypeError(f"unexpected type for is_dynamic {type(t)}") + + return False + + +# Placeholder strings used in triton codegen. +class Placeholder(enum.Enum): + # The placeholder for the actual name of a triton kernel. + # e.g. for "def triton_" it would be "triton_" + KERNEL_NAME = "KERNEL_NAME" + + # The descriptive name of the triton kernel; when unique_kernel_names = False, this + # placeholder will be replaced with a string with more information. 
+ DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME" + + +def pass_execution_and_save(func, gm, inp, msg): + from .pattern_matcher import stable_topological_sort + + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + ) as f: + before_io = io.StringIO() + after_io = io.StringIO() + ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp) + print(f"Before:\n{gm.graph}", file=f) + print(gm.graph, file=before_io) + start_time = datetime.now() + with GraphTransformObserver(gm, msg, config.trace.log_url_for_graph_xform): + func(gm.graph) + time_elapsed = datetime.now() - start_time + # recompile graph + stable_topological_sort(gm.graph) + gm.graph.lint() + gm.recompile() + + print(f"After:\n{gm.graph}", file=f) + print(gm.graph, file=after_io) + t = before_io.getvalue() == after_io.getvalue() + log.info( + "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s", + msg, + f.name, + t, + time_elapsed, + ) + + +def is_collective(node, op=None): + from . import ir + + return type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op) + + +def is_wait(node): + from . import ir + + return type(node) == ir._WaitKernel + + +def contains_collective(snode): + from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode + + assert isinstance(snode, BaseSchedulerNode) + if isinstance(snode, GroupedSchedulerNode): + return any(contains_collective(x) for x in snode.snodes) + else: + return is_collective(snode.node) + + +def contains_wait(snode): + from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode + + assert isinstance(snode, BaseSchedulerNode) + if isinstance(snode, GroupedSchedulerNode): + return any(contains_wait(x) for x in snode.snodes) + else: + return is_wait(snode.node) + + +def is_fallback_op(node, op): + from . 
import ir + + if isinstance(op, torch._ops.OpOverload): + op = {op} + return isinstance(node, ir.FallbackKernel) and node.op_overload in op + + +def buf_name_to_fused_snode(buf_name, name_to_buf, name_to_fused_node): + return name_to_fused_node[name_to_buf[buf_name].defining_op.get_name()] + + +def find_recursive_deps_of_node( + snode, collected_node_set, name_to_buf, name_to_fused_node, criteria_cb=None +): + if criteria_cb and criteria_cb(snode): + return + collected_node_set.add(snode) + for dep in snode.unmet_dependencies: + defining_op_for_dep = buf_name_to_fused_snode( + dep.name, name_to_buf, name_to_fused_node + ) + if defining_op_for_dep in collected_node_set: + continue + find_recursive_deps_of_node( + defining_op_for_dep, + collected_node_set, + name_to_buf, + name_to_fused_node, + criteria_cb=criteria_cb, + ) + + +def find_recursive_users_of_node( + snode, collected_node_set, name_to_buf, name_to_fused_node, criteria_cb=None +): + if criteria_cb and criteria_cb(snode): + return + collected_node_set.add(snode) + for o in snode.get_outputs(): + for user in o.users: + assert user.node is not None + if user.node.get_name() == "OUTPUT": + continue + if user.node.get_name() not in name_to_fused_node: + continue + user_op = name_to_fused_node[user.node.get_name()] + if user_op in collected_node_set: + continue + find_recursive_users_of_node( + user_op, + collected_node_set, + name_to_buf, + name_to_fused_node, + criteria_cb=criteria_cb, + ) + + +def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int): + "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)" + num_rng_seed_offset_inputs = ( + 2 if torch._functorch.config.functionalize_rng_ops else 0 + ) + # AOT won't lift any parameters if we're inlining NN Modules + # however desugaring subclasses will still add arguments + # resulted in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502 + if ( + 
torch._dynamo.config.inline_inbuilt_nn_modules + and not torch._dynamo.utils.is_parameter_freezing() + ): + return 0 + + return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs + + +def count_tangents(fx_g: torch.fx.GraphModule): + """ + Infers which inputs are static for a backwards graph + """ + + def is_saved_tensor(x): + return ( + "tangents" not in x.name + and "bwd_seed" not in x.name + and "bwd_base_offset" not in x.name + ) + + arg_count = 0 + static_arg_idxs = [] + for n in fx_g.graph.nodes: + if n.op == "placeholder": + if is_saved_tensor(n): + static_arg_idxs.append(arg_count) + arg_count += 1 + + assert static_arg_idxs == list(range(len(static_arg_idxs))) + return len(static_arg_idxs) + + +@dataclasses.dataclass +class BoxedBool: + value: bool + + def __bool__(self): + return self.value + + @staticmethod + def disable(obj): + if isinstance(obj, BoxedBool): + obj.value = False + return obj + return False + + +@contextlib.contextmanager +def collect_defined_kernels(kernel_list): + from .codegen.wrapper import WrapperCodeGen + + orig_define_kernel = WrapperCodeGen.define_kernel + + def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs): + nonlocal kernel_list + kernel_list.append(kernel_code) + return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs) + + with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel): + yield + + +def get_cloned_parameter_buffer_name(name: str): + return name + "__original__" + + +def is_gpu(device: str): + assert isinstance(device, str) or device is None, device + return device in ["cuda", "xpu"] + + +def device_need_guard(device: str): + assert isinstance(device, str) + return is_gpu(device) + + +def needs_fallback_due_to_atomic_add_limitations(dtype): + # tl.atomic_add does NOT support the following types + return dtype in {torch.int64, torch.bool, torch.bfloat16} + + +def use_scatter_fallback( + op_overload: torch._ops.OpOverload, 
+ reduction_type, + self_dtype, + src_dtype, + src_device_type, + src_is_tensor, +): + if ( + op_overload.overloadpacket + in (torch.ops.aten.scatter_reduce_, torch.ops.aten.scatter_reduce) + and reduction_type is None + ): + return False + + reduce_ty = ( + "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum" + ) + + return ( + reduction_type not in {None, reduce_ty} + or ( + src_is_tensor + and is_gpu(src_device_type) + and needs_fallback_due_to_atomic_add_limitations(src_dtype) + ) + or ( + op_overload.overloadpacket == torch.ops.aten.scatter_reduce_ + and reduction_type == "sum" + and src_is_tensor + and src_device_type == "cpu" + and config.cpp.fallback_scatter_reduce_sum + and (config.cpp.dynamic_threads or parallel_num_threads() != 1) + ) + or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64}) + or torch.are_deterministic_algorithms_enabled() + ) + + +def dump_node_schedule(node_schedule): + """ + An API that can be used in pdb to dump a node_schedule. + Right mainly dump the read/write dependencies but can add more as needed. 
+ """ + from torch._inductor.codegen.simd import DisableReduction, EnableReduction + from torch._inductor.scheduler import SchedulerNode + + print(f"Node schedule with {len(node_schedule)} nodes") + for idx, node in enumerate(node_schedule): + print(f" {idx:3}:") + if node is EnableReduction: + print("enable reduction") + elif node is DisableReduction: + print("disable reduction") + elif isinstance(node, SchedulerNode): + is_red = node.is_reduction() + print(f"{'red' if is_red else 'pw'} scheduler node") + if is_red: + assert node.node is not None + print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined] + print("ReadDep:") + for dep in node.read_writes.reads: + print(dep) + print("WriteDep:") + for dep in node.read_writes.writes: + print(dep) + else: + raise RuntimeError(f"Unrecognized node type: {type(node)}") + + +def tensor_is_aligned(tensor: torch.Tensor): + # See Note: [Input Alignment handling in Inductor] + # Right now, we don't try to guard on the alignment of the storage offset. + # When this comment was written, non-symbolic storage_offsets are not guarded on + # but symbolic storage_offsets are. For consistency, we suppress guard creation + # upon performing this check: that ensures that we don't add recompiles when we + # add this logic. + from torch.fx.experimental.symbolic_shapes import statically_known_true + + return statically_known_true( + (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % GPU_ALIGN_BYTES == 0 + ) + + +def should_assume_input_aligned(example_input: torch.Tensor): + # See Note: [Input Alignment handling in Inductor] + + # right now, we only care about alignment for cuda tensors. + if not is_gpu(example_input.device.type): + return False + return config.assume_aligned_inputs or tensor_is_aligned(example_input) + + +def maybe_get_suppress_shape_guards_ctx(): + # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards() + # If it's not available, return a nullcontext. 

    # If we're dealing with cudagraphs, we might not have a tracing_context
    tracing_context = torch._guards.TracingContext.try_get()
    if not tracing_context:
        return contextlib.nullcontext()

    # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
    shape_env = tracing_context.fake_mode.shape_env
    if not shape_env:
        return contextlib.nullcontext()

    return shape_env.suppress_guards()


def run_and_get_cpp_code(fn, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` with ``config.debug`` patched to True while
    capturing the ``output_code_log`` stream; returns ``(result, log_text)``."""
    # We use the patch context manager instead of using it as a decorator.
    # In this way, we can ensure that the attribute is patched and unpatched correctly
    # even if this run_and_get_cpp_code function is called multiple times.
    with unittest.mock.patch.object(config, "debug", True):
        torch._dynamo.reset()
        import io
        import logging

        log_capture_string = io.StringIO()
        ch = logging.StreamHandler(log_capture_string)
        from torch._inductor.codecache import output_code_log

        output_code_log.addHandler(ch)
        prev_level = output_code_log.level
        output_code_log.setLevel(logging.DEBUG)
        # NOTE(review): if fn() raises, the handler and level are never
        # restored — consider wrapping the lines below in try/finally.
        result = fn(*args, **kwargs)
        s = log_capture_string.getvalue()
        output_code_log.setLevel(prev_level)
        output_code_log.removeHandler(ch)
    return result, s


def shape_env_from_inputs(inputs: List[torch.Tensor]):
    """Best-effort lookup of the ShapeEnv associated with ``inputs``.

    Prefers the shape_env of the fake mode detected from the inputs, then
    falls back to the first SymInt input; returns None if neither exists.
    """
    # NOTE(review): this local is never used — every path returns directly.
    shape_env = None
    fake_mode = detect_fake_mode(inputs)

    # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
    # pass in real inputs for now.
    # if len(inputs) > 0:
    #     assert fake_mode is not None, breakpoint()

    if fake_mode is not None:
        return fake_mode.shape_env

    # When there are no tensor inputs, get shape_env from the first SymInt.
    for input in inputs:
        if isinstance(input, torch.SymInt):
            return input.node.shape_env

    # TODO(voz): Should we always have one anyway?
    return None


def align_inputs_from_check_idxs(
    model: Callable[[List[InputType]], Any],
    inputs_to_check: Sequence[int],
) -> Callable[[List[InputType]], Any]:
    """Wrap ``model`` so that inputs at ``inputs_to_check`` are re-aligned
    (cloned if misaligned) before every call; returns ``model`` unchanged when
    there is nothing to check."""
    if len(inputs_to_check) == 0:
        return model

    def run(new_inputs: List[InputType]):
        # Mutates new_inputs in place before invoking the wrapped model.
        copy_misaligned_inputs(new_inputs, inputs_to_check)
        return model(new_inputs)

    return run


def clone_preserve_strides(x: torch.Tensor):
    """Clone ``x`` into freshly allocated memory while preserving its exact
    size and strides (including gaps in the layout)."""
    # Smallest flat buffer length that covers every element addressed by
    # (size, stride): highest linear index + 1.
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
    )
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())


def copy_misaligned_inputs(
    new_inputs: List[InputType], check_inputs_idxs: Sequence[int]
) -> None:
    """Replace, in place, each tensor in ``new_inputs`` at the given indices
    whose data pointer is not ALIGNMENT-aligned with a fresh clone (new
    allocations are expected to be aligned)."""
    for i in check_inputs_idxs:
        _inp = new_inputs[i]
        assert isinstance(_inp, torch.Tensor)
        if _inp.data_ptr() % ALIGNMENT:
            new_inputs[i] = clone_preserve_strides(_inp)


def remove_unaligned_input_idxs(
    inputs: List[InputType],
    static_input_idxs: Sequence[int],
):
    """
    Filter ``static_input_idxs`` down to the indices whose corresponding
    input is a tensor with an ALIGNMENT-aligned data pointer; non-tensor or
    misaligned entries are dropped.
+ """ + aligned_static_input_idxs = [] + for idx in static_input_idxs: + input = inputs[idx] + if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0: + aligned_static_input_idxs.append(idx) + if len(aligned_static_input_idxs) != len(static_input_idxs): + return aligned_static_input_idxs + return static_input_idxs + + +def set_tracing_context_output_strides(example_inputs, compiled_graph): + # Return the output strides to the caller via TracingContext + context = torch._guards.TracingContext.try_get() + if context is not None and context.output_strides is not None: + assert len(context.output_strides) == 0 + shape_env = shape_env_from_inputs(example_inputs) + for exprs in compiled_graph.output_strides: + if exprs is None: + context.output_strides.append(None) + else: + context.output_strides.append( + tuple( + ( + shape_env.evaluate_symexpr(e) + if shape_env is not None + else int(e) + ) + for e in exprs + ) + )