diff --git a/.gitattributes b/.gitattributes index 4d62a9c21308ef7af9dc6dbc7963be05388c9bb8..79fae07f586a29ba2802c1f8fe8eebc37337f771 100644 --- a/.gitattributes +++ b/.gitattributes @@ -129,3 +129,11 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/torch/nn/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/scheduler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 b/.venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 new file mode 100644 index 0000000000000000000000000000000000000000..22e94ba55aff546be82878d928d976bfe4d1b3e3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78df2f31f6db8142ec546a1e5a31cb066f7892d12d2f665b448f8069a08ef807 +size 251616632 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27e7f80e18d0c34cfa6c4b13aed98fce7e66a633 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/codecache.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1ba20726a513f57e01fc1fbf9c3744defdeda5d64e6e3a00d7d3911f4f598d2 +size 164293 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b360fda268932cc9f2d7c0e65e7241655f516cc3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:495965e46b513011b3387880294e810069bb3299277002dd35d6e15e1a3d6508 +size 118734 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3986445829b133672a729892e6f87d53d8cb707f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/lowering.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c863587cdf0f8eef657d2fa0f0ebf9ddddc19a24d5670869719203bc7d877e48 +size 337621 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ade93a58f260924a8a8631655c117653b3a490a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd3ce0e8ac0de613f90615aaf063bff822e142ca75c5993718647f82d9d0add5 +size 109858 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d0cd85c643bef156e59b51e835e7f71f75e8be9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a247c49f4e32c680bff1ed7aa611b6eae0c91d995c632fbc3fa35605649b638b +size 109445 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc9aa91a94948dafe9edb19fc9dc2360f50c5bf1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/codegen_device_driver.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/codegen_device_driver.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17b3816624846a2de77d823a1854f4c31dd79dc1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/codegen_device_driver.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79b927b81da7fc6867d346e7b5f85363854ea3f6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..615ff7336945643e8e59c4fd208b12cde5049a62 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43d6724126c4dc2e5daa8ce1ea6544023f9b88bc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5ad67bab40bdcf8dc642920904a8c78d53d0a8b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec5850e561e77562a9c59adad20883936f10e20e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44af9d168874402c9d78a768f58505961a21b8b0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b2308f5d3221845307cb0660bb89d0995c744b6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d32bd0ffb919dd1c666438577d9108f8665068d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a6012be684b8ff75d268124512fbee7f83f0974 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f26808e54638ab846fca73b6b1e4b6307e1cd49 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d93048784e81cb8ad37f68dcb71fe6b229f7cec Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05d39b20ef62245ae1e281efc4fca5b7ad7f3436 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5df3b274d253e58e3be12a8b46dc78f3803aabbb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0273aa9aa8df0e57a27e6c8577646620c3159430 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp @@ -0,0 +1,87 @@ +// NOTE: Like interface.cpp, this file will be copied into AOTInductor +// generated output. This file is intended to keep implementation +// details separate from the implementation of the AOTI public +// interface. Note also that #includes should go into interface.cpp +// for simplicity of maintenance. 
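+//
+// A minimal usage sketch (illustrative only: `run_sketch`, the single
+// float input/output, and the numel of 16 are assumptions, not part of
+// the generated interface). Generated AOTI model code is expected to
+// call the helpers defined below roughly like this:
+//
+//   void run_sketch(
+//       AtenTensorHandle* input_handles,
+//       AtenTensorHandle* output_handles) {
+//     std::tuple<ArrayRefTensor<float>> inputs;
+//     convert_handles_to_inputs(input_handles, inputs);
+//     assert_numel(std::get<0>(inputs), 16);
+//     // ... run kernels; here we just pass the input through ...
+//     ArrayRefTensor<float> out = std::get<0>(inputs);
+//     convert_outputs_to_handles(std::make_tuple(out), output_handles);
+//   }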
+
+namespace torch {
+namespace aot_inductor {
+template <typename T>
+void convert_output_to_handle(
+    const ArrayRefTensor<T>& output,
+    AtenTensorHandle& handle) {
+  handle = output.expensiveCopyToTensor();
+}
+
+template <typename... Ts, std::size_t... Is>
+void convert_outputs_to_handles_helper(
+    const std::tuple<ArrayRefTensor<Ts>...>& outputs,
+    AtenTensorHandle* output_handles,
+    std::index_sequence<Is...>) {
+  (convert_output_to_handle(std::get<Is>(outputs), output_handles[Is]), ...);
+}
+template <typename... Ts>
+void convert_outputs_to_handles(
+    const std::tuple<ArrayRefTensor<Ts>...>& outputs,
+    AtenTensorHandle* output_handles) {
+  convert_outputs_to_handles_helper(
+      outputs, output_handles, std::make_index_sequence<sizeof...(Ts)>());
+}
+
+template <typename T>
+void convert_handle_to_arrayref_tensor(
+    AtenTensorHandle handle,
+    ArrayRefTensor<T>& input) {
+  void* data_ptr;
+  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr));
+  int64_t dim;
+  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim));
+  int64_t numel;
+  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel));
+  int64_t* sizes;
+  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes));
+  int64_t* strides;
+  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides));
+  int32_t dtype;
+  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype));
+  int32_t device_type;
+  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type));
+  int32_t device_index;
+  AOTI_TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_device_index(handle, &device_index));
+
+  input = ArrayRefTensor<T>(
+      MiniArrayRef<T>(reinterpret_cast<T*>(data_ptr), numel),
+      MiniArrayRef<const int64_t>(sizes, dim),
+      MiniArrayRef<const int64_t>(strides, dim),
+      device_type,
+      device_index);
+}
+
+template <typename... Ts, std::size_t... Is>
+void convert_handles_to_inputs_helper(
+    AtenTensorHandle* input_handles,
+    std::tuple<ArrayRefTensor<Ts>...>& inputs,
+    std::index_sequence<Is...>) {
+  (convert_handle_to_arrayref_tensor(input_handles[Is], std::get<Is>(inputs)),
+   ...);
+}
+
+template <typename... Ts>
+void convert_handles_to_inputs(
+    AtenTensorHandle* input_handles,
+    std::tuple<ArrayRefTensor<Ts>...>& inputs) {
+  convert_handles_to_inputs_helper(
+      input_handles, inputs, std::make_index_sequence<sizeof...(Ts)>());
+}
+
+template <typename T>
+void assert_numel(const ArrayRefTensor<T>& tensor, uint64_t numel) {
+  if (tensor.numel() != numel) {
+    std::stringstream err;
+    err << "incorrect numel for input tensor. expected " << numel << ", got "
+        << tensor.numel();
+    throw std::runtime_error(err.str());
+  }
+}
+} // namespace aot_inductor
+} // namespace torch
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py new file mode 100644 index 0000000000000000000000000000000000000000..f80d1f71e380d08c20e6358a17addc8d009d4efe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py @@ -0,0 +1,2167 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import re +from enum import auto, Enum +from itertools import chain +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import sympy +from sympy.printing.printer import Printer + +import torch +import torch.fx +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges + +from .. import config, metrics +from ..utils import ( + DeferredLineBase, + generate_assert, + IndentedBuffer, + sympy_dot, + sympy_subs, + unique, +) +from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V + + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + + +def data_type_logger(msg): + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Data type propagation: %s", msg) + + +@dataclasses.dataclass +class WorkspaceArg: + """A temporary buffer used for a single kernel, then discarded. + + Not registered as a traditional buffer since there are no users, + so it would be dead code eliminated. + """ + + nbytes: sympy.Expr + zero_fill: bool + + +@dataclasses.dataclass +class TensorArg: + name: str + buffer: str + dtype: torch.dtype + offset: sympy.Expr = sympy.Integer(0) # c++ only + alias_of: Optional[str] = None # halide only + + +@dataclasses.dataclass +class SizeArg: + name: str + expr: sympy.Expr + + @property + def alias_of(self): + return None + + +@dataclasses.dataclass +class DeviceCodegen: + scheduling: Any + wrapper_codegen: type + cpp_wrapper_codegen: type = type(None) + + +KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg] + +device_codegens: Dict[str, DeviceCodegen] = {} + + +class DeviceOpOverrides: + def import_get_raw_stream_as(self, name): + raise NotImplementedError + + def set_device(self, device_idx): + raise NotImplementedError + + def synchronize(self): + raise NotImplementedError + + def device_guard(self, device_idx): + raise NotImplementedError + + +device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {} + + +# The code generated by Inductor consists of two main parts: kernel code and wrapper code. +# For any new backend looking to integrate with Inductor, customization of these two main +# parts is necessary to generate its specific code. +# +# Kernel code generation is determined by the Scheduling in use. Consequently, a new +# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, +# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
+# +# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code +# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen, +# and override specific member functions to create backend-specific Python wrapper code. +# +# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part +# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces +# provide flexibility to the backend. A backend can choose to implement these classes from scratch, +# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, +# register_backend_for_device, to equip a new backend at runtime. +# +# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. +# This backend can be used as a reference: +# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 +def register_backend_for_device( + device: str, + device_scheduling: Any, + device_wrapper_codegen: type, + device_cpp_wrapper_codegen: type = type(None), +): + device_codegens[device] = DeviceCodegen( + device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen + ) + + +class BackendFeature(Enum): + FOREACH = auto() + BUCKETIZE = auto() + INPLACE_BUFFERS = auto() + MASKED_SCATTER_WITH_INDEX = auto() + SCAN = auto() + SORT = auto() + TUPLE_REDUCTION = auto() + PREFER_STORE_LOOP_ORDER = auto() + TRITON_TEMPLATES = auto() + REDUCE_TO_SINGLE_ELEMENT = auto() + + +def get_backend_features(device: Union[torch.device, str]): + init_backend_registration() + if isinstance(device, torch.device): + device_type = device.type + else: + assert isinstance(device, str) + device_type = device + device = torch.device(device_type) + scheduling = get_scheduling_for_device(device_type) + return scheduling(None).get_backend_features(device) + + +def has_backend_feature(device, feature): + """See also V.graph.has_feature""" + assert isinstance(feature, BackendFeature) + return feature in get_backend_features(device) + + +def get_scheduling_for_device(device: str): + return device_codegens[device].scheduling if device in device_codegens else None + + +def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): + if device in device_codegens: + wrapper_codegen_obj: DeviceCodegen = device_codegens[device] + return ( + wrapper_codegen_obj.cpp_wrapper_codegen + if cpp_wrapper + else wrapper_codegen_obj.wrapper_codegen + ) + else: + return None + + +@functools.lru_cache(None) +def init_backend_registration(): + from .cpp import CppScheduling + from .cpp_wrapper_cpu import CppWrapperCpu + from .cpp_wrapper_cuda import CppWrapperCuda + from .cuda_combined_scheduling import CUDACombinedScheduling + from .halide import HalideScheduling + from .triton import TritonScheduling + from .wrapper import WrapperCodeGen + + if get_scheduling_for_device("cpu") is None: + cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling} + register_backend_for_device( + "cpu", + lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs), + WrapperCodeGen, + CppWrapperCpu, + ) + + if get_scheduling_for_device("cuda") is None: + # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation + cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling} + register_backend_for_device( + "cuda", + 
lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs), + WrapperCodeGen, + CppWrapperCuda, + ) + + if get_scheduling_for_device("xpu") is None: + register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen) + + private_backend = torch._C._get_privateuse1_backend_name() + if ( + private_backend != "privateuseone" + and get_scheduling_for_device(private_backend) is None + ): + from torch.utils.backend_registration import _get_custom_mod_func + + try: + device_scheduling = _get_custom_mod_func("Scheduling") + wrapper_codegen = _get_custom_mod_func("WrapperCodeGen") + cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen") + if device_scheduling and wrapper_codegen and cpp_wrapper_codegen: + register_backend_for_device( + private_backend, + device_scheduling, + wrapper_codegen, + cpp_wrapper_codegen, + ) + except RuntimeError: + pass + + +def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes): + from ..ir import FlexibleLayout + + # added contiguous index prevents reordering + return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] + + +def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides): + device_op_overrides_dict[device] = device_op_overrides + + +def get_device_op_overrides(device: str): + assert isinstance(device, str) + + if not device_op_overrides_dict.keys(): + from .cuda import device_op_overrides # noqa: F401 + from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 + + if device in device_op_overrides_dict.keys(): + return device_op_overrides_dict[device] + + +@functools.lru_cache(None) +def boolean_ops(): + return ( + "isinf", + "isnan", + "logical_not", + "signbit", + "le", + "lt", + "ge", + "gt", + "eq", + "ne", + ) + + +DTYPE_TO_COMPUTATION_DTYPE = { + torch.bfloat16: torch.float, + torch.float16: torch.float, + **{ + dtype: dtype + for dtype in [ + torch.bool, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ] + }, +} + + +def deduce_output_dtype_by_name( + op_name: str, + *args, + **kwargs, +) -> Optional[torch.dtype]: + """ + Given op name and a list of input dtypes, deduce the output dtype + """ + if op_name in boolean_ops(): + return torch.bool + elif op_name in ( + "to_dtype", + "index_expr", + ): + return kwargs["dtype"] if "dtype" in kwargs else args[-1] + elif op_name in ( + "rand", + "randn", + ): + return torch.float + elif op_name in ( + "get_index", + "randint64", + "load_seed", + ): + return torch.int64 + elif op_name == "reduction": + return kwargs["dtype"] if "dtype" in kwargs else args[1] + elif op_name == "constant": + dtype = kwargs["dtype"] if "dtype" in kwargs else args[-1] + return DTYPE_TO_COMPUTATION_DTYPE[dtype] # type: ignore[index] + elif op_name in ( + "load", + "store", + "store_reduction", + ): + buf_name = args[1] + return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + elif op_name == "to_dtype_bitcast": + return kwargs["dtype"] if "dtype" in kwargs else args[-2] + return None + + +class DataTypePropagation: + def __init__(self, body) -> None: + self.body = body + self.graphs: Dict[Union[Callable[..., Any], str], Any] = { + "root": body.root_block.graph + } + for k, v in body.subblocks.items(): + self.graphs[k] = v.graph + + def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): + inputs = node.all_input_nodes + input_nodes = [ + n for n in inputs if isinstance(n, torch.fx.Node) and n.op != 
"placeholder" + ] + if len(input_nodes) == 0: + return None + + all_input_nodes_propagated = all( + OptimizationContext.key in n.meta + and n.meta[OptimizationContext.key].dtype is not None + for n in input_nodes + ) + if not all_input_nodes_propagated: + return None + + return functools.reduce( + torch.promote_types, + [n.meta[OptimizationContext.key].dtype for n in input_nodes], + ) + + def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node): + sub_graph = self.graphs[node.target] + dtype = self.propagate_graph(sub_graph) + assert dtype + return dtype + + def deduce_node_dtype(self, node: torch.fx.Node): + if node.op == "placeholder": + return None + + if node.target == "output" and len(node.args) != 1: + # we can infer output node if it only have 1 arg + return None + + if node.target == operator.getitem: + return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type] + + assert isinstance(node.target, str) + + if node.target.startswith("masked_subblock"): + return self.deduce_node_dtype_by_subgraph(node) + + if ( + output_dtype := deduce_output_dtype_by_name( + node.target, + *node.args, + **node.kwargs, + ) + ) is not None: + return output_dtype + + return self.deduce_node_dtype_by_inputs(node) + + def propagate_graph(self, graph: torch.fx.Graph): + assert graph.nodes + graph_dtype = None + # For masked_subblock, we use output's dtype to represent + # the dtype of this subgraph. For other cases, graph_dtype + # might be None + for node in graph.nodes: + if OptimizationContext.key in node.meta: + opt_ctx = node.meta[OptimizationContext.key] + else: + opt_ctx = OptimizationContext() + + opt_ctx.dtype = self.deduce_node_dtype(node) + node.meta[OptimizationContext.key] = opt_ctx + if node.target == "output": + graph_dtype = opt_ctx.dtype + return graph_dtype + + def propagate(self): + self.propagate_graph(self.graphs["root"]) + + @classmethod + def propagate_loopbody(cls, body): + return cls(body).propagate() + + @classmethod + def propagate_scheduler_node(cls, node): + from ..loop_body import LoopBody + from ..scheduler import SchedulerNode + + assert isinstance(node, SchedulerNode) + assert isinstance(node._body, LoopBody) + DataTypePropagation.propagate_loopbody(node._body) + + +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python +class ExprPrinter(Printer): + @staticmethod + def paren(string): + def all_in_parens(string): + if string[0] != "(" or len(string) < 2: + return False + count = 1 + for i, char in enumerate(string[1:]): + if char == "(": + count += 1 + elif char == ")": + count -= 1 + if count == 0 and i != len(string) - 2: + return False + assert count == 0 + return True + + if ( + isinstance(string, CSEVariable) + or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE) + or re.match(r"^\([^)]*\)$", string, re.IGNORECASE) + or string == "" + ): + return string + # don't put extra parens for strings that are already wrapped in parens + if all_in_parens(string): + return string + return f"({string})" + + def _print_Relational(self, expr): + return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) + + def _print_Mul(self, expr): + return "*".join(map(self.paren, map(self._print, expr.args))) + + def _print_Add(self, expr): + return " + ".join(map(self.paren, map(self._print, expr.args))) + + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent + def _print_Mod(self, expr): + return " % ".join(map(self.paren, map(self._print, 
expr.args))) + + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _print_CleanDiv(self, expr): + return self._print_FloorDiv(expr) + + def _print_Identity(self, expr): + return self._print(expr.args[0]) + + def _print_GreaterThan(self, expr): + # GreaterThan: >= + # StrictlyGreaterThan: > + # Go figure... + return " >= ".join(map(self.paren, map(self._print, expr.args))) + + # NB: The C implementation is injected into codegen at + # torch/_inductor/codegen/wrapper.py + def _print_align(self, expr): + assert len(expr.args) == 1 + return f"align({self._print(expr.args[0])})" + + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. + # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. + def _print_Pow(self, expr): + base, exp = expr.args + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + assert exp >= 0 + if exp > 0: + return "*".join([self.paren(base)] * exp) + else: # exp == 0 + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. + + def _print_ToFloat(self, expr): + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr): + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr): + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr): + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr): + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr): + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr): + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr): + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr): + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. 
You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr): + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + + def doprint(self, expr, *, simplify: bool = True): + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + + +class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"float({self._print(expr.args[0])})" + + def _print_ModularIndexing(self, expr): + x, div, mod = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + mod = self.paren(self.doprint(mod)) + if div != "1": + x = f"({x} // {div})" + return f"{x} % {mod}" + + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"({x} // {div})" + + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _helper_sqrt(self, expr): + return f"math.sqrt({self._print(expr)})" + + def _print_OpaqueUnaryFn_sqrt(self, expr): + return self._helper_sqrt(expr.args[0]) + + def _print_FloatPow(self, expr): + base, exp = expr.args + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr): + base, exp = expr.args + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_FloorToInt(self, expr): + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + # This also could have been int(), they'll do the same thing for float + return f"math.trunc({self._print(expr.args[0])})" + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"abs({self._print(expr.args[0])})" + + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion + def _print_Max(self, expr): + assert len(expr.args) >= 2 + return f"max({', '.join(map(self._print, expr.args))})" + + def _print_Min(self, expr): + assert len(expr.args) >= 2 + return f"min({', '.join(map(self._print, expr.args))})" + + def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"math.cos({self._print(expr.args[0])})" + + def 
_print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"math.cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"math.acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert len(expr.args) == 1 + return f"math.sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"math.sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"math.asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"math.tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"math.tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"math.atan({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr): + assert len(expr.args) == 1 + return f"round({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + assert isinstance(ndigits, sympy.Integer) + return f"round({self._print(number)}, {ndigits})" + + +class OpOverrides: + def __init__(self, parent): + super().__init__() + self._parent = parent + + def __getattr__(self, item): + return getattr(self._parent, item) + + @staticmethod + def identity(value): + # used to trigger cse + return value + + @staticmethod + def constant(value, dtype): + return repr(value) + + @staticmethod + def reciprocal(x): + return ops.truediv(ops.constant(1, torch.int32), x) + + @staticmethod + def square(x): + return ops.mul(x, x) + + @staticmethod + def erfc(x): + return ops.sub(ops.constant(1, torch.float32), ops.erf(x)) + + @staticmethod + def erfcx(x): + return ops.mul(ops.exp(ops.square(x)), ops.erfc(x)) + + @staticmethod + def expm1(x): + return ops.sub(ops.exp(x), ops.constant(1, torch.float32)) + + @staticmethod + def log10(x): + return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32)) + + @staticmethod + def log2(x): + return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32)) + + @staticmethod + def exp2(x): + return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32))) + + @staticmethod + def log1p(x): + return ops.log(ops.add(x, ops.constant(1, torch.int32))) + + @staticmethod + def sigmoid(x): + one = ops.constant(1, torch.int32) + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) + + @staticmethod + def libdevice_sigmoid(x): + one = ops.constant(1, torch.int32) + return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x)))) + + @staticmethod + def relu(x): + return ops.maximum(x, ops.constant(0, torch.int32)) + + @staticmethod + def libdevice_abs(x): + return ops.abs(x) + + @staticmethod + def libdevice_sqrt(x): + return ops.sqrt(x) + + @staticmethod + def libdevice_cos(x): + return ops.cos(x) + + @staticmethod + def libdevice_sin(x): + return ops.sin(x) + + @staticmethod + def libdevice_log(x): + return ops.log(x) + + @staticmethod + def libdevice_exp(x): + return ops.exp(x) + + @staticmethod + def bitwise_not(x): + return f"~{ExprPrinter.paren(x)}" + + @staticmethod + def logical_not(a): + return f"{ExprPrinter.paren(a)} == 0" + + @staticmethod + def bitwise_and(x, y): + return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_or(x, y): + return 
f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_xor(x, y): + return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_left_shift(x, y): + return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_right_shift(x, y): + return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" + + @staticmethod + def remainder(a, b): + r = ops.mod(a, b) + cond = ops.and_( + ops.ne(r, ops.constant(0, torch.int32)), + ops.ne(ops.signbit(r), ops.signbit(b)), + ) + return ops.where(cond, ops.add(r, b), r) + + @staticmethod + def trunc_to_int(a, dtype): + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def floor_to_int(a, dtype): + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a, dtype): + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def round_to_int(a, dtype): + return ops.to_dtype(ops.round(a), dtype) + + @staticmethod + def int_truediv(a, b): + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + + @staticmethod + def load_seed(name, offset): + return ops.load(name, sympy.Integer(offset)) + + @classmethod + def _initialize_pointwise_overrides(cls, target): + assert target in {"triton", "cpp", "cppvec"}, target + + for funcname, data in pointwise_overrides_data.items(): + impl = getattr(data, target) + if impl is None: + continue + setattr(cls, funcname, staticmethod(impl)) + + +@dataclasses.dataclass +class OverridesData: + name: str + cpp: Callable[..., str] + # None when not impl in libdevice/triton + triton: Optional[Callable[..., str]] = None + # None when not impl in aten/.../vec + cppvec: Optional[Callable[..., str]] = None + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +# NB: if you add a new special function, don't forget to update +# torch._inductor.ops_handler too +pointwise_overrides_data: Dict[str, OverridesData] = dict( + airy_ai=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"airy_ai_forward({x})", + name="special_airy_ai", + ), + bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j0_forward({x})", + triton=lambda x: f"libdevice.j0({x})", + name="special_bessel_j0", + ), + bessel_j1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j1_forward({x})", + triton=lambda x: f"libdevice.j1({x})", + name="special_bessel_j1", + ), + bessel_y0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y0_forward({x})", + triton=lambda x: f"libdevice.y0({x})", + name="special_bessel_y0", + ), + bessel_y1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y1_forward({x})", + triton=lambda x: f"libdevice.y1({x})", + name="special_bessel_y1", + ), + digamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_digamma({x})", + cppvec=lambda x: f"{x}.digamma()", + name="digamma", + ), + # no cpp nor triton implementation for entr, it is defined as decomposition + # erf, erfc + erfcx=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_erfcx({x})", + triton=lambda x: f"libdevice.erfcx({x})", + 
name="special_erfcx", + ), + fma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})", + cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})", + triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})", + name="fma", + ), + # erfinv, exp2, expit, gammaln + igamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="igamma", + ), + igammac=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="igammac", + ), + gammainc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="special_gammainc", + ), + gammaincc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="special_gammaincc", + ), + i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i0({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + cppvec=lambda x: f"{x}.i0()", + name="i0", + ), + i0e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i0e({x})", + cppvec=lambda x: f"{x}.i0e()", + name="special_i0e", + ), + i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_i1", + ), + i1e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1e({x})", + name="special_i1e", + ), + log_ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_log_ndtr({x})", + name="special_log_ndtr", + ), + # logit + modified_bessel_i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_i0_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + name="special_modified_bessel_i0", + ), + modified_bessel_i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_i1_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_modified_bessel_i1", + ), + modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k0_forward({x})", + name="special_modified_bessel_k0", + ), + modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k1_forward({x})", + name="special_modified_bessel_k1", + ), + # multigamma + ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtr({x})", + name="special_ndtr", + ), + ndtri=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtri({x})", + name="special_ndtri", + ), + polygamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_polygamma({y}, {x})", + name="polygamma", + ), + # psi - alias to digamma + # round + scaled_modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})", + 
name="special_scaled_modified_bessel_k0", + ), + scaled_modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})", + name="special_scaled_modified_bessel_k1", + ), + # sinc + spherical_bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"spherical_bessel_j0_forward({x})", + name="special_spherical_bessel_j0", + ), + zeta=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"zeta({x}, {y})", + name="special_zeta", + ), + chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})", + name="special_chebyshev_polynomial_t", + ), + chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})", + name="special_chebyshev_polynomial_u", + ), + chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})", + name="special_chebyshev_polynomial_v", + ), + chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})", + name="special_chebyshev_polynomial_w", + ), + legendre_polynomial_p=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})", + name="special_legendre_polynomial_p", + ), + shifted_chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_t", + ), + shifted_chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_u", + ), + shifted_chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_v", + ), + shifted_chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_w", + ), + hermite_polynomial_h=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})", + name="special_hermite_polynomial_h", + ), + hermite_polynomial_he=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})", + name="special_hermite_polynomial_he", + ), + laguerre_polynomial_l=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})", + name="special_laguerre_polynomial_l", + ), +) + + +# Use mypy to check protocol implemented correctly +def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]: + return h + + +class DeferredLine(DeferredLineBase): + """A line that can be 'unwritten' by adding 
name to V.graph.removed_buffers""" + + def __init__(self, name, line): + super().__init__(line) + self.name = name + assert not isinstance(line, DeferredLineBase) + + def __call__(self): + if all( + self.name not in x + for x in ( + V.graph.removed_buffers, + V.kernel.removed_buffers, + V.graph.inplaced_to_remove, + V.kernel.inplaced_to_remove, + ) + ): + return self.line + return None + + def _new_line(self, line): + return DeferredLine(self.name, line) + + +class BracesBuffer(IndentedBuffer): + def indent(self, offset=1): + @contextlib.contextmanager + def ctx(): + for _ in range(offset): + self.writeline("{") + self._indent += 1 + for _ in range(-offset): + self._indent -= 1 + self.writeline("}") + yield + for _ in range(-offset): + self.writeline("{") + self._indent += 1 + for _ in range(offset): + self._indent -= 1 + self.writeline("}") + + return ctx() + + +class InplacedBuffer(NamedTuple): + inner_name: str + other_names: List[str] + + +class KernelArgs: + @staticmethod + def _lookup(prefix, odict, name): + assert isinstance(name, (str, sympy.Symbol)) + if name not in odict: + odict[name] = f"{prefix}{len(odict)}" + return odict[name] + + def __init__(self, sizevars=None): + self.input_buffers = {} + self.output_buffers = {} + self.inplace_buffers = {} + self.sizevars = sizevars or {} + self.workspace_arg = None + + def __repr__(self): + return "KernelArgs({})".format( + ", ".join( + map( + repr, + [ + self.input_buffers, + self.output_buffers, + self.inplace_buffers, + self.sizevars, + ], + ) + ) + ) + + def _buffer_is_marked_removed(self, name): + return isinstance(name, str) and name.startswith("REMOVED") + + def input(self, name): + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.output_buffers: + return self.output_buffers[name] + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name + if name.startswith("seed"): + return self._lookup("seed", self.input_buffers, name) + return self._lookup("in_ptr", self.input_buffers, name) + + def output(self, name): + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name + return self._lookup("out_ptr", self.output_buffers, name) + + def make_inplace(self, input_name, output_name): + assert output_name not in self.inplace_buffers + if input_name in self.inplace_buffers: + buf = self.inplace_buffers[input_name] + buf.other_names.append(output_name) + self.inplace_buffers[output_name] = buf + else: + buf = InplacedBuffer( + f"in_out_ptr{len(unique(self.inplace_buffers.values()))}", + [input_name, output_name], + ) + self.inplace_buffers[input_name] = buf + self.inplace_buffers[output_name] = buf + + def workspace(self, nbytes: sympy.Expr, zero_fill: bool): + if self.workspace_arg is None: + self.workspace_arg = WorkspaceArg(nbytes, zero_fill) + return "ws_ptr", 0 + + offset = self.workspace_arg.nbytes + zero_fill = zero_fill or self.workspace_arg.zero_fill + self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill) + return "ws_ptr", offset + + def seed_offset(self, name, value): + if value in self.sizevars: + return self.sizevars[value] + if name in self.sizevars.values(): + name = ( + f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" + ) + self.sizevars[value] = name + return name + + def size(self, name): + if str(name) == 
"seed": + self.sizevars["seed"] = "seed" + return "seed" + return self._lookup("ks", self.sizevars, name) + + def call_names(self): + return chain( + self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() + ) + + def wrap_ptr_arg(self, buf, dtype): + return buf + + def wrap_size_arg(self, size): + return str(size) + + def cpp_argdefs(self): + from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE + + call_args = [] + arg_defs = [] + arg_types = [] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + outer = inplaced.other_names[-1] + inner = inplaced.inner_name + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.input_buffers.items(): + if outer in self.inplace_buffers: + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"const {cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"const {cpp_dtype}*") + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.sizevars.items(): + arg_defs.append(f"const {INDEX_TYPE} {inner}") + call_args.append(self.wrap_size_arg(outer)) + arg_types.append(f"const {INDEX_TYPE}") + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + assert self.workspace_arg is None, "Workspace not supported on CPU " + return arg_defs, call_args, arg_types + + def python_argdefs(self): + arg_defs: List[str] = [] + call_args: List[str] = [] + arg_types: List[torch.dtype] = [] + precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = [] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + arg_defs.append(inplaced.inner_name) + call_args.append(inplaced.other_names[-1]) + arg_types.append(V.graph.get_dtype(inplaced.other_names[-1])) + precompile_args.append( + TensorArg( + name=inplaced.inner_name, + buffer=inplaced.other_names[-1], + dtype=V.graph.get_dtype(inplaced.other_names[-1]), + ) + ) + for outer, inner in chain( + self.input_buffers.items(), self.output_buffers.items() + ): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + arg_defs.append(inner) + call_args.append(outer) + arg_types.append(V.graph.get_dtype(outer)) + precompile_args.append( + TensorArg( + name=inner, + buffer=outer, + dtype=V.graph.get_dtype(outer), + ) + ) + for outer, inner in self.sizevars.items(): + arg_defs.append(inner) + call_args.append(outer) + arg_types.append(type(outer)) # type: ignore[arg-type] + precompile_args.append(SizeArg(inner, outer)) + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + if self.workspace_arg is not None: + arg_defs.append("ws_ptr") + call_args.append("workspace") + precompile_args.append(self.workspace_arg) + return arg_defs, call_args, precompile_args, arg_types + + def aliases(self): + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + for other in inplaced.other_names: + if ( + other in 
V.graph.inplaced_to_remove + or other in V.kernel.inplaced_to_remove + ): + continue + if other in self.input_buffers: + yield self.input_buffers[other], inplaced.inner_name + if other in self.output_buffers: + yield self.output_buffers[other], inplaced.inner_name + + def is_removed(self, name): + def _is_removed(name, buffers): + return name not in buffers or self._buffer_is_marked_removed(buffers[name]) + + return _is_removed(name, self.output_buffers) and _is_removed( + name, self.inplace_buffers + ) + + # Includes inplace buffers, excludes removed buffers. Essentially, + # after you do a call into this kernel, which buffers actually contain + # updated data? Modeled off of python_argdefs. + def live_output_buffers(self): + live_outs = OrderedSet() # type: ignore[var-annotated] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + live_outs.add(inplaced.other_names[-1]) + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + live_outs.add(outer) + return live_outs + + +class CSEVariable: + """A CSEVariable is just a name for an expression, but it is useful to be able to annotate variables on a backend-dependent basis. + To do so, backends can simply overload `Kernel.create_cse_var`. + The `CSEVariable.update_on_args` method gives you a hook for annotations. + See TritonCSEVariable in triton.py for an example. + """ + + def __init__(self, name, bounds: ValueRanges[Any]): + assert isinstance(bounds, ValueRanges) + self.name = name + self.bounds = bounds + self.use_count = 1 # track how many times this expression is used + + def __str__(self): + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other) -> bool: + return type(other) == type(self) and other.name == self.name + + def update_on_args(self, name, args, kwargs): + pass + + def __repr__(self): + return f"{self.__class__.__name__}({self.name!r})" + + +class CppWrapperKernelArgs(KernelArgs): + def wrap_ptr_arg(self, buf, dtype): + from .cpp_utils import DTYPE_TO_CPP + + if config.abi_compatible: + # In the abi_compatible mode, we just return the buf here. + # We will form the correct call args later in wrapper.generate_kernel_all.
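+ # Example (illustrative names): for a float32 buffer "buf0", the two
+ # branches below render the call arg as "buf0" (abi_compatible) or as
+ # "(float*)(buf0.data_ptr())" in the non-ABI-compatible mode.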
+ return buf + else: + return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" + + def wrap_size_arg(self, size): + return f"{size}" + + +class CSE: + """Common subexpression elimination""" + + def __init__( + self, + prefix="", + suffix="", + name_prefix="tmp", + iter_buffers=None, + store_cache=None, + reduction_cache=None, + varname_map=None, + ): + self.prefix = prefix + self.suffix = suffix + self.cache = {} + self.name_prefix = name_prefix + self.store_cache = store_cache or {} + self.reduction_cache = reduction_cache or {} + self.iter_buffer_ids = iter_buffers or itertools.count() + self.invalidated_stores = OrderedSet() # type: ignore[var-annotated] + self.varname_map = varname_map or {} + + def invalidate(self, keep_vars: OrderedSet[str]): + for name, tmp in list(self.store_cache.items()): + if tmp not in keep_vars: + del self.store_cache[name] + self.invalidated_stores.add(name) + self.cache = {k: v for k, v in self.cache.items() if v in keep_vars} + + def clone(self): + # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional + return CSE( + prefix=self.prefix, + suffix=self.suffix, + name_prefix=self.name_prefix, + iter_buffers=self.iter_buffer_ids, + store_cache=self.store_cache, + varname_map=self.varname_map, + ) + + def generate( + self, + buffer: IndentedBuffer, + expr: Union[str, CSEVariable, OpsValue, IndentedBuffer], + *, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + write=True, + assignment=True, + ) -> CSEVariable: + if isinstance(expr, OpsValue): + expr = expr.value + + assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr) + assert write or assignment + if isinstance(expr, CSEVariable): + # If the expressions were always created with all the information, we could + # assert expr.bounds == bounds, but sometimes the expression is created + # with the loose ValueRanges.unknown(), so we need to tighten the bounds + expr.bounds = expr.bounds.tighten(bounds) + expr.use_count += 1 + return expr + cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr + var = self.cache.get(cache_key, None) + if not var: + var = self.newvar(bounds) + self.cache[cache_key] = var + if write: + if V.kernel.current_node: + V.kernel.current_node.codegen_originating_info( + buffer, only_once=True + ) + if isinstance(expr, IndentedBuffer): + if assignment: + buffer.writeline(f"{self.prefix}{var} =") + buffer.splice(expr) + buffer.writeline(self.suffix) + else: + if assignment: + line = f"{self.prefix}{var} = {expr}{self.suffix}" + else: + line = f"{expr}{self.suffix}" + buffer.writeline(line) + else: + var.bounds = var.bounds.tighten(bounds) + var.use_count += 1 + + return var + + def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name, bounds) + self.varname_map[var_name] = var + return var + + +class CodeGen: + def __init__(self) -> None: + super().__init__() + self.exit_stack = contextlib.ExitStack() + + def __enter__(self): + self.exit_stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + +class ScopedDict: + def __init__(self, original_dict): + self.original_dict = original_dict + self.new_items = {} + + def __getitem__(self, key): + if key in self.new_items: + return self.new_items[key] + return self.original_dict[key] + + def __setitem__(self, key, value): + self.new_items[key] = value + + def 
__contains__(self, key): + return key in self.new_items or key in self.original_dict + + def get(self, key, default=None): + if key in self.new_items: + return self.new_items[key] + return self.original_dict.get(key, default) + + +class Kernel(CodeGen): + newvar_prefix = "" + suffix = "" + overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None + # TODO: these look dead, but with all the getattr it's hard to tell... + load_format: None = None + store_format: None = None + + def __init__(self, args=None, increase_kernel_count=True): + super().__init__() + if increase_kernel_count: + metrics.generated_kernel_count += 1 + self.args = args or KernelArgs() + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + + self.num_load = 0 + self.num_reduction = 0 + + self.cse: CSE = CSE(self.newvar_prefix, self.suffix) + self.must_keep_buffers = OrderedSet() # type: ignore[var-annotated] + self.store_buffer_names = OrderedSet() # type: ignore[var-annotated] + self._load_mask = None + self._load_other = None + # set in set_current_node + self.current_node = None + self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None + + self.removed_buffers = OrderedSet() # type: ignore[var-annotated] + self.inplaced_to_remove = OrderedSet() # type: ignore[var-annotated] + + # key: the buffer to write + # value: the buffer to read and whose memory can be reused for + # the buffer specified by key + self.inplace_update_buffers = {} + # Set minimum number of elements processed per thread. + self.min_elem_per_thread = 1 + self.kernel_name = None + + @contextlib.contextmanager + def set_current_node(self, node): + prior = self.current_node + self.current_node = node + self.node_to_bounds = node._body.bounds().get_bounds() + try: + yield + finally: + self.current_node = prior + + @contextlib.contextmanager + def swap_buffers(self, lb, cb=None, sb=None): + def scope_cse(cse): + new_cse = cse.clone() + new_cse.cache = ScopedDict(cse.cache) + new_cse.reduction_cache = ScopedDict(cse.reduction_cache) + new_cse.store_cache = ScopedDict(cse.store_cache) + return new_cse + + if cb is None: + cb = lb + loads = self.loads + compute = self.compute + stores = self.stores + cse = self.cse + self.loads = lb + self.compute = cb + self.stores = sb + self.cse = scope_cse(cse) + try: + yield + finally: + self.loads = loads + self.compute = compute + self.stores = stores + self.cse = cse + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + raise NotImplementedError + + def indirect_load(self, name: str, index: sympy.Expr): + """A load that depends on an index we have read""" + prior = self.loads + try: + # put the load in the compute section as it might have deps + self.loads = self.compute + return self.load(name, index) + finally: + self.loads = prior + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + raise NotImplementedError + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + raise NotImplementedError + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + raise NotImplementedError + + def scan( + self, + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[ + [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
+ ], + values: Tuple[CSEVariable, ...], + ) -> Tuple[CSEVariable, ...]: + raise NotImplementedError + + def sort( + self, + dtypes: Tuple[torch.dtype, ...], + values: Tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> Tuple[CSEVariable, ...]: + raise NotImplementedError + + def var_ranges(self): + raise NotImplementedError + + def bucketize( + self, + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + raise NotImplementedError + + @property + def assert_function(self) -> str: + raise NotImplementedError + + def indirect_assert( + self, + var: Union[CSEVariable, str], + lower: Optional[str], + upper: Optional[str], + mask: Optional[Union[CSEVariable, str]] = None, + ) -> str: + if isinstance(var, CSEVariable): + var = str(var) + assert isinstance(var, str) + assert lower is None or isinstance(lower, str) + assert upper is None or isinstance(upper, str) + if lower and upper: + # The conditions need to be in parens because of Python's operator precedence. + # It'd be less error-prone to use and/or/not, which is supported by triton. + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower} <= {var} < {upper}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = cond + else: + assert upper + cond = f"{var} < {upper}" + cond_print = cond + + if mask: + cond = f"({cond}) | ~({mask})" + + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + raise NotImplementedError + + def index_to_str(self, index: sympy.Expr) -> str: + raise NotImplementedError + + def __enter__(self): + # TODO: hoist this to top level + class CSEProxy: + self.name = "CSEProxy" + vr_analysis = ValueRangeAnalysis() + + @staticmethod + def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] + def inner(*args, **kwargs): + bounds = CSEProxy._bound_variable(name, *args, **kwargs) + + value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + + def do_cse(v): + csevar = V.kernel.cse.generate( + V.kernel.compute, v, bounds=bounds + ) + csevar.update_on_args(name, args, kwargs) + return csevar + + return pytree.tree_map(do_cse, value) + + return inner + + @staticmethod + def _bound_variable(name, *args, **kwargs): + """ + If the variable comes from an FX node, we forward the bound we have already computed. + Otherwise, if the variable is created while codegen'ing another op, we try to compute its bounds. + """ + from ..select_algorithm import TritonTemplateKernel + + if isinstance(V.kernel, TritonTemplateKernel): + return ValueRanges.unknown() + + fx_node = V.interpreter.current_node + if fx_node.target == name and self.node_to_bounds is not None: + assert isinstance(self.node_to_bounds, dict) + return self.node_to_bounds.get(fx_node, ValueRanges.unknown()) + elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): + # These create lots of inner strings. We would need to compute the bounds at the ops. + # We will also likely not get much from computing VRs on these nodes. + if any( + s in fx_node.target + for s in ("set_indirect", "reduction", "scan") + ): + return ValueRanges.unknown() + + # We assume that the inputs come from `ops.` and are not strings. If you want to generate + # intermediary strings, wrap them in CSE variables with properly initialised bounds.
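+ # Example (pseudo code, illustrative): for ops.add(a, b) with
+ # a.bounds == ValueRanges(0, 4) and b.bounds == ValueRanges(1, 2),
+ # CSEProxy.vr_analysis.add(a.bounds, b.bounds) yields ValueRanges(1, 6).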
+ + # If there is no FX bound, but we know how to compute one, we do so + assert not kwargs + + def arg_to_bound(x): + if isinstance(x, CSEVariable): + return x.bounds + elif isinstance(x, sympy.Expr): + return bound_sympy(x) + else: + return x + + arg_bounds = list(map(arg_to_bound, args)) + return getattr(CSEProxy.vr_analysis, name)(*arg_bounds) + else: + return ValueRanges.unknown() + + @staticmethod + def indirect_indexing( + var: CSEVariable, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg=True, + ): + if isinstance(size, int): + size = sympy.Integer(size) + assert isinstance(size, sympy.Expr), size + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: # type: ignore[operator] + if wrap_neg: + stm = ops.add(var, ops.index_expr(size, torch.long)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: # type: ignore[operator] + lt = ops.lt(var, 0) + stm = ops.where(lt, stm, var) + else: + stm = var + + # Propagate bounds as we know how to compute them properly + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance( + size, sympy.Number + ): + # Take the negative part of the bound and add size to it. + # Then take the union of that and the positive part. + # This is a tighter bound than that of a generic ops.where, as we have info on the cond. + neg_bounds = var.bounds & ValueRanges(-int_oo, -1) + new_bounds = ValueRanges( + neg_bounds.lower + size, neg_bounds.upper + size + ) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: # type: ignore[operator] + pos = var.bounds & ValueRanges(0, int_oo) + new_bounds = new_bounds | pos + + var = self.cse.generate(self.compute, stm, bounds=new_bounds) + + sympy_var = parent_handler.indirect_indexing(var, size, check) + if generate_assert(check): + assert_lower = not (var.bounds.lower >= 0) + # value ranges cannot prove x < s when x and s are symbols + assert_upper = not isinstance(size, sympy.Number) or not ( + var.bounds.upper < size + ) + self.check_bounds(sympy_var, size, assert_lower, assert_upper) + return sympy_var + + @staticmethod + def check_bounds( + expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + return self.check_bounds(expr, size, lower, upper) + + @staticmethod + def load(name: str, index: sympy.Expr) -> CSEVariable: + if name in self.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_is_type(index, SymT.TMP): + return self.indirect_load(name, index) + store_cache = self.cse.store_cache + if name in store_cache: + return store_cache[name] + out = self.load(name, index) + # count a load only if it is not in the store_cache and also not in the + # cse cache.
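+ # A cse cache hit returns an existing variable with use_count > 1, so
+ # only a genuinely fresh load (use_count == 1) bumps num_load below.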
+ if out.use_count == 1: + self.num_load += 1 + return out + + @staticmethod + def _update_store_cache(name: str, value: CSEVariable): + self.cse.store_cache[name] = value + if self.current_node and name in V.graph.name_to_buffer: + buf = self.current_node.get_output(name) + for other_name in buf.get_mutations(): + self.cse.store_cache[other_name] = value + + @staticmethod + def store( + name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.store_buffer_names.add(name) + if mode is None: + CSEProxy._update_store_cache(name, value) + if name not in V.graph.removed_buffers: + return self.store(name, index, value, mode=mode) + else: + return None # type: ignore[return-value] + + @staticmethod + def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): + self.store_buffer_names.add(name) + CSEProxy._update_store_cache(name, value) + + if name not in V.graph.removed_buffers: + return self.store_reduction(name, index, value) + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + self.num_reduction += 1 + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def scan( + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[ + [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], + Tuple[CSEVariable, ...], + ], + values: Tuple[CSEVariable, ...], + ) -> Tuple[CSEVariable, ...]: + return self.scan(dtypes, combine_fn, values) + + @staticmethod + def sort( + dtypes: Tuple[torch.dtype, ...], + values: Tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> Tuple[CSEVariable, ...]: + return self.sort(dtypes, values, stable, descending) + + @staticmethod + def bucketize( + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + [Note: Inductor bucketize op] + + Given values (tensor) and offsets_name (reference to the name of a 1D + tensor), calculate the bucket that each value belongs to. + + e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True + return = [ 0, 1, 1, 1, 1, 3, 3, 4]. + + When right == False, bucket i refers to range (offsets[i], offsets[i+1]]. + When right == True, bucket i refers to range [offsets[i], offsets[i+1]). + + Offsets must be non-decreasing or the result is undefined. + """ + return self.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + # Use mypy to check protocol implemented correctly + def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: + return h + + super().__enter__() + assert self.overrides + parent_handler = self.overrides(V.get_ops_handler()) + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Note that V.graph.scheduler can be None when codegening triton template + kernels. 
+ """ + if V.graph.scheduler: + V.graph.scheduler.remove_kernel_local_buffers() + super().__exit__(exc_type, exc_val, exc_tb) + + def rename_indexing(self, index) -> sympy.Expr: + # adds the necessary kernel args for index expressions + # and renames variables in index expressions to kernel arg names + if isinstance(index, (list, tuple)): + return [self.rename_indexing(x) for x in index] # type: ignore[return-value] + index = V.graph.sizevars.simplify(index) + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) + replacements = { + x: self.args.size(x) + for x in sorted_symbols + if symbol_is_type( + x, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + ), + ) + } + return sympy_subs(index, replacements) + + def create_cse_var(self, *args, **kwargs): + return CSEVariable(*args, **kwargs) + + +@dataclasses.dataclass +class OptimizationContext: + key: ClassVar[str] = "opt_ctx" + + dtype: Optional[torch.dtype] = None + ops_name: str = "" + + +@functools.lru_cache(None) +def jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class KernelTemplate: + """ + Base class for defining kernel templates. + + Children classes: TritonTemplate, CUDATemplate + """ + + @staticmethod + def indent_except_first(source: str, num_indents: int, indents_spacing=4): + lines = source.splitlines(True) + if len(lines) > 1: + lines[1:] = [ + (" " * indents_spacing * num_indents) + line for line in lines[1:] + ] + return "".join(lines) + + @staticmethod + def _template_from_string(source): + env = jinja2_env() + if env is not None: + env.filters["indent_except_first"] = KernelTemplate.indent_except_first + from jinja2 import TemplateSyntaxError + + class DetailedTemplateSyntaxError(TemplateSyntaxError): + def __init__(self, original_error): + super().__init__( + original_error.message, + original_error.lineno, + original_error.name, + original_error.filename, + ) + self.original_error = original_error + + def __str__(self): + error_info = f"Error in template at line {self.lineno}\n" + error_info += f"Error message: {self.message}\n" + if hasattr(self.original_error, "source"): + lines = self.original_error.source.split("\n") + error_info += "Context:\n" + start = max(0, self.lineno - 2) + end = min(len(lines), self.lineno + 2) + for i in range(start, end): + if i == self.lineno - 1: + error_info += f"{i+1}: --> {lines[i]}\n" + if hasattr(self.original_error, "column"): + error_info += ( + " " + + " " * (self.original_error.column - 1) + + "^\n" + ) + else: + error_info += f"{i+1}: {lines[i]}\n" + return error_info + + try: + return env.from_string(source) + except TemplateSyntaxError as e: + raise DetailedTemplateSyntaxError(e) from e + + return None + + @staticmethod + def _fake_get_dtype(fake_out): + _get_dtype_real = V.graph.get_dtype + + def get_dtype(name): + if name == fake_out.get_name(): + return fake_out.get_dtype() + return _get_dtype_real(name) + + return get_dtype + + def __init__(self, name: str): + self.name = name + + def maybe_append_choice(self, choices, **kwargs): + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + + choices: A list of ChoiceCallers. + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. 
+ """ + + try: + choices.append(self.generate(**kwargs)) + except NotImplementedError as e: + pass + + def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller": + """ + Generates a ChoiceCaller instance from the given arguments. + """ + + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d54762d9e8fd5e9914d4706ba88e4d8818ab36 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py @@ -0,0 +1,4978 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import itertools +import math +import re +import sys +import warnings +from copy import copy, deepcopy +from enum import Enum +from typing import cast, Dict, List, Optional, Sequence, Set, Tuple, Union + +import sympy + +import torch +import torch.fx +from torch._inductor import dependencies +from torch._prims_common import is_float_dtype, is_integer_dtype +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT + +from ..._dynamo.utils import counters +from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics +from ..loop_body import LoopBody +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + ForeachKernelSchedulerNode, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from ..utils import ( + cache_on_self, + get_bounds_index_expr, + get_fused_kernel_name, + has_free_symbols, + is_welford_reduction, + parallel_num_threads, + Placeholder, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_product, + sympy_subs, +) +from ..virtualized import NullKernelHandler, ops, OpsValue, V +from .common import ( + BackendFeature, + BracesBuffer, + CppWrapperKernelArgs, + CSE, + CSEVariable, + DataTypePropagation, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + IndentedBuffer, + Kernel, + KernelArgs, + OpOverrides, + OptimizationContext, +) +from .cpp_utils import ( + _get_dtype_from_loopbodies, + _get_loop_body, + cexpr, + cexpr_index, + codegen_rand, + CppCSEVariable, + DTYPE_TO_CPP, + INDEX_TYPE, + LocalBufferContext, + promote_args, + unify_mask_base_type, + value_to_cpp, +) + + +_IS_WINDOWS = sys.platform == "win32" + + +def get_export_declaration(): + return "__declspec(dllexport)" if _IS_WINDOWS else "" + + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + +NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"} +RTYPE_TO_CPP = { + "sum": "+", + "prod": "*", + "xor_sum": "^", + "min": "min", + "max": "max", + "argmin": "argmin", + "argmax": "argmax", + "any": "||", + "welford_reduce": "welford", + "welford_combine": "welford", +} +VECTORIZABLE_RTYPES = { + "max", + "min", + "sum", + "prod", + "xor_sum", + "welford_reduce", + "welford_combine", + "argmin", + "argmax", + "any", +} + +PYTHON_TO_CPP = { + "Tensor": "at::Tensor", + "int": "long", + "float": "double", + "bool": "bool", + "str": "std::string", + "ScalarType": "c10::ScalarType", + "MemoryFormat": "at::MemoryFormat", + "Layout": "at::Layout", + "Device": "at::Device", + "number": "at::Scalar", +} + +CONTAINER_PYTHON_TO_CPP = { + "List": "std::vector", + "Optional": "std::optional", +} + +DTYPE_LOWP_FP = [ + torch.bfloat16, + torch.float16, +] + +VECTORIZABLE_DTYPES: List[torch.dtype] = [ + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + 
torch.bool, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, +] + +MASKED_VECTORIZABLE_DTYPES: List[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, +] + + +def reduction_init(reduction_type, dtype): + if dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, the initial + # constant for reduction must be promoted as well + dtype = torch.float32 + if reduction_type in ("xor_sum", "sum", "any"): + return 0 + if reduction_type == "prod": + return 1 + if reduction_type in ("max", "argmax", "min", "argmin"): + cdtype = DTYPE_TO_CPP[dtype] + min_var = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + max_var = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + init_var = min_var if reduction_type in ("max", "argmax") else max_var + return ( + init_var + if reduction_type in ("max", "min") + else f"IndexValue<{cdtype}>{{0, {init_var}}}" + ) + if is_welford_reduction(reduction_type): + return f"Welford<{DTYPE_TO_CPP[dtype]}>()" + raise AssertionError(reduction_type) + + +def reduction_acc_type(reduction_type, dtype): + scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]] + if is_welford_reduction(reduction_type): + return f"Welford<{scalar_type}>" + if reduction_type in {"argmin", "argmax"}: + return f"IndexValue<{scalar_type}>" + return scalar_type + + +def reduction_combine( + reduction_type, + var, + next_value, + index: Optional[sympy.Symbol] = None, + src_dtype=None, +): + is_bool = src_dtype == torch.bool + if reduction_type == "sum": + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + if reduction_type == "prod": + return f"{var} * {next_value}" + if reduction_type == "xor_sum": + return f"{var} ^ {next_value}" + if reduction_type == "any": + return f"{var} || {next_value}" + if reduction_type in ("min", "max"): + return f"{reduction_type}_propagate_nan({var}, {next_value})" + if reduction_type == "welford_reduce": + return f"welford_combine({var}, {next_value})" + if reduction_type == "welford_combine": + if isinstance(next_value, tuple): + mean, m2, weight = next_value + else: + mean, m2, weight = reduction_project(reduction_type, next_value) + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + if reduction_type in ("argmin", "argmax"): + if index is not None: + return f"{reduction_type}_combine({var}, {next_value}, {index})" + else: + return f"{reduction_type}_combine({var}, {next_value})" + raise AssertionError(reduction_type) + + +def reduction_project(reduction_type, acc): + if is_welford_reduction(reduction_type): + return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight" + elif reduction_type in {"argmin", "argmax"}: + return f"{acc}.index" + return acc + + +@functools.lru_cache +def stride_at(index: sympy.Expr, var: sympy.Symbol): + if not index.has(var): + # see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu + # which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails the calculation below. + # in this case, there are no dependencies between index and var.
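+ # Example (pseudo code): stride_at computes index(var + 1) - index(var);
+ # for index = 3*x0 + 7, stride_at(index, x0) simplifies to 3.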
+ return sympy.Integer(0) + replacement = {var: var + 1} + new_index = sympy_subs(index, replacement) # type: ignore[arg-type] + return sympy.simplify(new_index - index) + + +@functools.lru_cache +def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int): + """ + Simplifies the index expression within the range of a vectorized loop. + Given a vectorized loop variable `var` in the range of a loop with `vec_length`, + this function transforms the `index` into an equivalent form. It handles + simplifications for cases where `var` can be expressed as `vec_length * a + b`, + where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences + of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations. + + NOTE: + The simplified index expression is intended for analysis purposes only, not + for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables + which are not dependent on the loop variable `var` in the vectorized range. Check + https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details. + + Examples: + 1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then + `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable + when `div` is divisible by 16. + 2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free + variable when `mod` is divisible by 16. + """ + + div_freevar_id = 0 + mod_freevar_id = 0 + + def visit_indexing_div(divisor): + nonlocal div_freevar_id + result = FloorDiv(var, divisor) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_div_c{div_freevar_id}") + div_freevar_id += 1 + return result + + def visit_modular_indexing(divisor, modulus): + nonlocal mod_freevar_id + result = ModularIndexing(var, divisor, modulus) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length: + result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + return result + + original_index = index + + div = sympy.Wild("divisor", integer=True) + if index.has(FloorDiv): + index = index.replace(FloorDiv(var, div), visit_indexing_div) + + mod = sympy.Wild("modulus", integer=True) + if index.has(ModularIndexing): + index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) + + index = sympy.simplify(index) + if index != original_index: + return simplify_index_in_vec_range(index, var, vec_length) + + return index + + +@functools.lru_cache +def stride_at_vec_range( + index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None +): + if vec_length: + index = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index, var) + + +class OuterLoopFusedSchedulerNode(FusedSchedulerNode): + @classmethod + def fuse( # type: ignore[override] + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode, outer_loop_fusion_depth + ): + assert node1.scheduler is node2.scheduler + assert all( + type(node) + in ( + OuterLoopFusedSchedulerNode, + SchedulerNode, + FusedSchedulerNode, + ) + for node in (node1, node2) + ) + if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): + return cls( + node1.scheduler, + ( + list(node1.get_outer_nodes()) + if type(node1) is OuterLoopFusedSchedulerNode + else [ + node1, + ] + ) + + ( + list(node2.get_outer_nodes()) + if type(node2) is 
OuterLoopFusedSchedulerNode + else [ + node2, + ] + ), + outer_loop_fusion_depth, + ) + else: + return cls(node1.scheduler, [node1, node2], outer_loop_fusion_depth) # type: ignore[list-item] + + def __init__( + self, + scheduler: "Scheduler", + outer_fused_nodes: List[Union[FusedSchedulerNode, SchedulerNode]], + outer_loop_fusion_depth, + ): + self.outer_fused_nodes: List[ + Union[FusedSchedulerNode, SchedulerNode] + ] = outer_fused_nodes + self.outer_loop_fusion_depth = outer_loop_fusion_depth + flatten_snodes = [] + for _node in self.outer_fused_nodes: + assert isinstance(_node, (SchedulerNode, FusedSchedulerNode)) + flatten_snodes.extend(list(_node.get_nodes())) + super().__init__(scheduler, flatten_snodes) # type: ignore[arg-type] + + def get_outer_nodes(self): + return self.outer_fused_nodes + + def check_outer_fusion_loop_level_attr( + self, cpp_kernel_proxy_list, outer_loop_fusion_depth + ): + # This function ensures that the same tiling split is applied at each loop level within the outer loop fusion depth. + # In the fusion stage, we only examine nodes with same vars and reduce. + # However, for nodes with same vars and reduce, the loops may still have different tile splits. + # For example (test_expr_vec_non_contiguous in test_cpu_repro.py): + # * buf0 tiling along the 2nd loop level, buf1 tiling along the 3rd loop level. + # If the check failed, we should fall back to standard loop codegen. + def _inner( + left_loop_level: LoopLevel, + right_loop_level: LoopLevel, + loop_fusion_depth: int, + ) -> bool: + # Check if same loop level attr + outer_loops_attr_compare_list = [ + "var", + "size", + "offset", + "steps", + ] + if not ( + all( + getattr(left_loop_level, attr_compare) + == getattr(right_loop_level, attr_compare) + for attr_compare in outer_loops_attr_compare_list + ) + ): + return False + + assert loop_fusion_depth >= 1 + if (loop_fusion_depth := loop_fusion_depth - 1) > 0: + # If the next loop level is expected to undergo outer loop fusion, + # there should be no kernel present at the current loop level. 
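+ # Example (illustrative): with outer_loop_fusion_depth == 2, two kernels
+ # nested as "for i: for j: inner_kernel" can only be outer-loop fused if
+ # their i- and j-loop levels agree on var/size/offset/steps.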
+ assert ( + left_loop_level.kernel is None and right_loop_level.kernel is None + ) + # Check next loop level attr + if any( + # Assume no main/tail loop split at any outer loop fusion depth + # Given no clear performance benefit for this complex case + len(loop_level.inner) != 1 + for loop_level in [left_loop_level, right_loop_level] + ) or not _inner( + left_loop_level.inner[0], + right_loop_level.inner[0], + loop_fusion_depth, + ): + return False + + return True + + for idx in range(len(cpp_kernel_proxy_list) - 1): + left_loop_nest = cpp_kernel_proxy_list[idx].loop_nest + right_loop_nest = cpp_kernel_proxy_list[idx + 1].loop_nest + if any( + # Assume no main/tail loop split at any outer loop fusion depth + len(loop_nest.root) != 1 + for loop_nest in [left_loop_nest, right_loop_nest] + ) or not _inner( + left_loop_nest.root[0], right_loop_nest.root[0], outer_loop_fusion_depth + ): + return False + + return True + + def merge_outer_fusion_kernels( + self, + cpp_kernel_proxy_list, + ): + loop_nest_list: List[LoopNestWithSplit] = [ + kernel.loop_nest for kernel in cpp_kernel_proxy_list + ] + kernel_group = cpp_kernel_proxy_list[0].kernel_group + + def _merge_outer_fusion_loop_levels( + loop_level_nested_list: List[List["LoopLevel"]], + outer_loop_fusion_depth, + ): + assert outer_loop_fusion_depth >= 1 + # Assume no main/tail loop split at any outer loop fusion depth + assert all( + len(loop_level_list) == 1 for loop_level_list in loop_level_nested_list + ) + if (outer_loop_fusion_depth := outer_loop_fusion_depth - 1) >= 1: + # Further merge the next loop level + next_loop_level_nested_list = [ + loop_level_list[0].inner + for loop_level_list in loop_level_nested_list + ] + _merge_outer_fusion_loop_levels( + next_loop_level_nested_list, + outer_loop_fusion_depth, + ) + else: + outer_loop_fused_kernel = OuterLoopFusedKernel(kernel_group) + loop_level_of_first_kernel = loop_level_nested_list[0][0] + for kernel_idx in range(len(loop_level_nested_list)): + outer_loop_fused_kernel.inner.append( + deepcopy(loop_level_nested_list[kernel_idx][0]), + ) + loop_level_of_first_kernel.inner = [] + loop_level_of_first_kernel.kernel = outer_loop_fused_kernel + + # Merge the List[LoopNestWithSplit] from cpp_kernel_proxy_list + # into cpp_kernel_proxy_list[0].loop_nest + _merge_outer_fusion_loop_levels( + [_loop_nest.root for _loop_nest in loop_nest_list], # type: ignore[misc] + self.outer_loop_fusion_depth, + ) + return cpp_kernel_proxy_list[0] + + +class RecordOptimizationContext: + def __init__(self, func_name: str = ""): + self.func_name = func_name + self.current_node: Optional[torch.fx.Node] = None + self.opt_ctx: Optional[OptimizationContext] = None + + def __enter__(self): + assert V.interpreter + assert V.interpreter.current_node + + self.current_node = V.interpreter.current_node + assert self.current_node is not None + if OptimizationContext.key in self.current_node.meta: + self.opt_ctx = self.current_node.meta[OptimizationContext.key] + else: + self.opt_ctx = OptimizationContext() + assert self.opt_ctx is not None + self.opt_ctx.ops_name = self.func_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.current_node + assert self.opt_ctx + self.current_node.meta[OptimizationContext.key] = self.opt_ctx + + def get_opt_ctx(self): + return self.opt_ctx + + def get_fx_node(self): + assert self.current_node + return self.current_node + + +class CppOverrides(OpOverrides): + """Map element-wise ops to C++""" + + @staticmethod + def add(a, b): + return f"decltype({a})({a} + 
{b})" + + @staticmethod + def sub(a, b): + return f"decltype({a})({a} - {b})" + + @staticmethod + def mul(a, b): + return f"decltype({a})({a} * {b})" + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_types=True): + assert isinstance(x, CppCSEVariable) + if src_dtype is None: + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in [torch.bfloat16, torch.float16] and src_dtype == torch.float: + """ + https://github.com/pytorch/pytorch/issues/115260 + For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is + in low-precision floating point data type. When the output of node1 also serves as the output of the + kernel, the result of nodes would be different from the case when output of node1 is not the output + of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on + storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type + to the cse cache. + + Example (pseudo code): + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = load(buf) + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + Without cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + With cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + # also add `to_dtype(node1_input_lowp, dtype=torch.float)` -> `node1_output` to cse cache + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = node1_output # hit cse cache + """ + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def to_dtype_bitcast(x, dtype, src_dtype): + assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" + if src_dtype in (torch.float16, torch.bfloat16): + # c10::bit_cast requires the source and target have the bitwidth. + # Because the input tensor's dtype could be promoted, e.g. from float16 to + # float, we have to cast the tensor to its original source dtype before + # invoking bit_cast. We also need to convert the bit-casted tensor + # back to float to make sure we keep using higher precision values + # for the rest of the computation. 
+ cast_x = f"c10::convert<{DTYPE_TO_CPP[src_dtype]}>({x})" + cast_x = f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({cast_x})" + return f"c10::convert<{DTYPE_TO_CPP[torch.float32]}>({cast_x})" + else: + return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})" + + @staticmethod + def abs(x): + return f"std::abs({x})" + + @staticmethod + def sin(x): + return f"std::sin({x})" + + @staticmethod + def cos(x): + return f"std::cos({x})" + + @staticmethod + def neg(x): + return f"decltype({x})(-{x})" + + @staticmethod + def exp(x): + # return f"Sleef_expf_u10({x})" + return f"std::exp({x})" + + @staticmethod + def exp2(x): + return f"std::exp2({x})" + + @staticmethod + def expm1(x): + return f"std::expm1({x})" + + @staticmethod + def erf(x): + return f"std::erf({x})" + + @staticmethod + def erfc(x): + return f"std::erfc({x})" + + @staticmethod + def erfinv(x): + return f"calc_erfinv({x})" + + @staticmethod + def sqrt(x): + return f"std::sqrt({x})" + + @staticmethod + def rsqrt(x): + return f"1 / std::sqrt({x})" + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::log1p({x})" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def tan(x): + return f"std::tan({x})" + + @staticmethod + def tanh(x): + return f"std::tanh({x})" + + @staticmethod + def signbit(x): + """ + On windows std::signbit only support float type. + Ref: https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/signbit?view=msvc-170 + """ + return ( + f"std::signbit(static_cast({x}))" + if _IS_WINDOWS + else f"std::signbit({x})" + ) + + @staticmethod + def pow(a, b): + return f"std::pow({a}, {b})" + + @staticmethod + def log(x): + return f"std::log({x})" + + @staticmethod + def round(x): + return f"std::nearbyint({x})" + + @staticmethod + def floor(x): + return f"std::floor({x})" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + quot = f"{a} / {b}" + rem = f"{a} % {b}" + return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? 
{quot} - 1 : {quot}) : {quot})" + + @staticmethod + def ceil(x): + return f"std::ceil({x})" + + @staticmethod + def trunc(x): + return f"std::trunc({x})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def fmod(a, b): + return f"std::fmod({a}, {b})" + + @staticmethod + def isinf(x): + return f"std::isinf({x})" + + @staticmethod + def isnan(x): + return f"std::isnan({x})" + + @staticmethod + def lgamma(x): + return f"std::lgamma({x})" + + @staticmethod + def acos(x): + return f"std::acos({x})" + + @staticmethod + def acosh(x): + return f"std::acosh({x})" + + @staticmethod + def cosh(x): + return f"std::cosh({x})" + + @staticmethod + def sinh(x): + return f"std::sinh({x})" + + @staticmethod + def asin(x): + return f"std::asin({x})" + + @staticmethod + def asinh(x): + return f"std::asinh({x})" + + @staticmethod + def atan2(x, y): + return f"std::atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"std::atan({x})" + + @staticmethod + def atanh(x): + return f"std::atanh({x})" + + @staticmethod + def copysign(x, y): + return f"std::copysign({x}, {y})" + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): + return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) + + code = BracesBuffer() + exponent = V.kernel.cse.newvar() + mantissa = V.kernel.cse.newvar() + code.writeline(f"int32_t {exponent};") + code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.cache[cache_key] = cse_var + return mantissa, exponent + + @staticmethod + def hypot(x, y): + return f"std::hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"std::log10({x})" + + @staticmethod + def log2(x): + return f"std::log2({x})" + + @staticmethod + def nextafter(x, y): + return f"std::nextafter({x}, {y})" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::max({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"min_propagate_nan({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"max_propagate_nan({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"{a} ? 
{b} : {c}" + + @staticmethod + def mod(a, b): + return f"mod({a}, {b})" + + @staticmethod + def constant(val, dtype): + if dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, constants + # must be promoted as well + dtype = torch.float32 + return value_to_cpp(val, DTYPE_TO_CPP[dtype]) + + @staticmethod + def index_expr(expr, dtype): + idx_str = cexpr(V.kernel.rename_indexing(expr)) + var = V.kernel.cse.generate( + V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) + ) + return ops.to_dtype(var, dtype) + + @staticmethod + def masked(mask, body, other): + code = BracesBuffer() + + # Write masked operation into a lambda + body_var = V.kernel.cse.newvar() + code.writeline(f"auto {body_var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + # Use the lambda's return type as the type of other + other_code = value_to_cpp(other, f"decltype({body_var}())") + return f"{mask} ? {body_var}() : {other_code}" + + @staticmethod + def logical_and(a, b): + return f"{a} && {b}" + + @staticmethod + def logical_not(a): + return f"!{a}" + + @staticmethod + def logical_or(a, b): + return f"{a} || {b}" + + @staticmethod + def logical_xor(a, b): + return f"{a} != {b}" + + @staticmethod + def bitwise_and(a, b): + return f"decltype({a})({a} & {b})" + + @staticmethod + def bitwise_not(a): + return f"decltype({a})(~{a})" + + @staticmethod + def bitwise_or(a, b): + return f"decltype({a})({a} | {b})" + + @staticmethod + def bitwise_xor(a, b): + return f"decltype({a})({a} ^ {b})" + + @staticmethod + def bitwise_left_shift(a, b): + return f"decltype({a})({a} << {b})" + + @staticmethod + def bitwise_right_shift(a, b): + return f"decltype({a})({a} >> {b})" + + @staticmethod + def rand(seed: sympy.Expr, offset: sympy.Expr): + return f"normalized_rand_cpu({seed}, {offset})" + + @staticmethod + def randn(seed: sympy.Expr, offset: sympy.Expr): + return f"randn_cpu({seed}, {offset})" + + @staticmethod + def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high): + return f"randint64_cpu({seed}, {offset}, {low}, {high})" + + @staticmethod + def sigmoid(x): + return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))" + + @staticmethod + def sign(x): + code = BracesBuffer() + scalar_zero = f"decltype({x})(0)" + scalar_one = f"decltype({x})(1)" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};") + code.writeline(f"auto right = {x} < 0 ? {scalar_one} : {scalar_zero};") + code.writeline("return left - right;") + code.writeline("()") + return code + + +CppOverrides._initialize_pointwise_overrides("cpp") + + +class CppVecOverrides(CppOverrides): + """Map element-wise ops to aten vectorization C++""" + + def __new__(cls, *args, **kargs): + self = super().__new__(cls) + + def wrap(func): + # `CppVecKernel` generates both scalar ops and vector ops according to + # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` + # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in + # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to + # `CppOverrides` when all inputs are scalars. + # + # Notes on ops handled separately in their own functions: + # `ops.masked`: + # needs recursive handling of masked body. + # `ops.index_expr`: + # needs to further analyze the dependency of the index expression on + # the tiling itervar. 
+ def wrapper(*args, **kwargs): + scalars = [ + arg + for arg in args + if isinstance(arg, (int, sympy.Expr)) + or (isinstance(arg, CppCSEVariable) and not arg.is_vec) + ] + vectors = [ + arg + for arg in args + if isinstance(arg, CppCSEVariable) and arg.is_vec + ] + new_args = list(args) + if scalars and vectors: + new_args = [] + for arg in args: + if isinstance(arg, (int, sympy.Expr)): + if isinstance(arg, sympy.Expr) and not arg.is_number: + arg = ops.index_expr(arg, torch.int64) + else: + arg = ops.constant(arg, torch.int64) + arg = arg.value if isinstance(arg, OpsValue) else arg + new_args.append(arg) + + # DType Promotion + if vectors: + # We have seen several data type mismatch issues related to index_expr in + # the lowering phase for torch.int8, torch.int32, torch.int64. + # 1. int32 and int64 in test_torchinductor.py::test_max_pool2d_with_indices_backward3_cpu + # 2. int8 and int32 in test_torchinductor.py::test_max_pool2d5_cpu + # 3. int32 and fp32 in test_torchinductor_dynamic_shapes.py::test_avg_pool2d8_dynamic_shapes_cpu + if len(new_args) == 2: + new_args = promote_args(new_args) + elif func == CppVecOverrides.where: + new_args[1:] = promote_args(new_args[1:]) + + # Broadcast scalar args to vector + if scalars and vectors: + assert isinstance(V.kernel, CppVecKernel) + new_args = [ + V.kernel.broadcast(new_arg) + if ( + isinstance(new_arg, CppCSEVariable) + and not new_arg.is_vec + and func + not in [ + CppVecOverrides.rand, + CppVecOverrides.randn, + CppVecOverrides.randint64, + ] + ) + else new_arg + for new_arg in new_args + ] + + if vectors: + return func(*new_args, **kwargs) + else: + # fall back to scalar ops + scalar_ops = super(CppVecOverrides, self) + scalar_func = getattr( + scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined] + ) + assert scalar_func is not None + return scalar_func(*args, **kwargs) + + return wrapper + + for name, method in vars(CppVecOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in [ + "masked", + "index_expr", + ]: + setattr(self, name, wrap(method.__func__)) + + return self + + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def truediv(a, b): + return f"{a} / {b}" + + @staticmethod + def abs(x): + return f"{x}.abs()" + + @staticmethod + def sin(x): + return f"{x}.sin()" + + @staticmethod + def cos(x): + return f"{x}.cos()" + + @staticmethod + def exp(x): + return f"{x}.exp()" + + @staticmethod + def exp2(x): + return f"{x}.exp2()" + + @staticmethod + def expm1(x): + # decompose for better performance + vec_one = f"decltype({x})(1)" + return f"{x}.exp() - {vec_one}" + + @staticmethod + def erf(x): + return f"{x}.erf()" + + @staticmethod + def erfc(x): + return f"{x}.erfc()" + + @staticmethod + def erfinv(x): + return f"{x}.erfinv()" + + @staticmethod + def sqrt(x): + return f"{x}.sqrt()" + + @staticmethod + def eq(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} == {y})" + + @staticmethod + def ne(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + if x.dtype == torch.bool: + assert y.dtype == torch.bool + x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y)) + return f"{x_cast} != {y_cast}" + else: + assert x.dtype is not None + return
f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" + + @staticmethod + def lt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} < {y})" + + @staticmethod + def gt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} > {y})" + + @staticmethod + def le(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} <= {y})" + + @staticmethod + def ge(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} >= {y})" + + @staticmethod + def and_(x, y): + return f"{x} & {y}" + + @staticmethod + def rsqrt(x): + return f"{x}.rsqrt()" + + @staticmethod + def pow(a, b): + return f"{a}.pow({b})" + + @staticmethod + def log(x): + return f"{x}.log()" + + @staticmethod + def round(x): + return f"{x}.round()" + + @staticmethod + def floor(x): + return f"{x}.floor()" + + @staticmethod + def ceil(x): + return f"{x}.ceil()" + + @staticmethod + def trunc(x): + return f"{x}.trunc()" + + @staticmethod + def fmod(a, b): + return f"{a}.fmod({b})" + + @staticmethod + def lgamma(x): + return f"{x}.lgamma()" + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"~{a}" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_and(a, b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def load_seed(name, offset): + assert isinstance(V.kernel, CppVecKernel) + return f"{V.kernel.load(name, offset)}" + + @staticmethod + def rand(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = ( + f"result[offset_idx] = normalized_rand_cpu({seed}, offset[offset_idx]);" + ) + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randn(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randn_cpu({seed}, offset[offset_idx]);" + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randint64(seed, offset, low, high): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randint64_cpu({seed}, offset[offset_idx], {low}, {high});" + return codegen_rand(offset, code, rand_function, torch.int64) + + @staticmethod + def remainder(a, b): + assert ( + a.dtype == b.dtype + ), "remainder vec implementation expect the same inputs' dtype." 
+ return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}" + + @staticmethod + def tan(a): + return f"{a}.tan()" + + @staticmethod + def tanh(a): + vec_one = f"decltype({a})(1)" + vec_two = f"decltype({a})(2)" + vec_minus_two = f"decltype({a})(-2)" + return f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}" + + @staticmethod + def reciprocal(a): + return f"{a}.reciprocal()" + + @staticmethod + def atan(x): + return f"{x}.atan()" + + @staticmethod + def acos(x): + return f"{x}.acos()" + + @staticmethod + def asin(x): + return f"{x}.asin()" + + @staticmethod + def cosh(x): + return f"{x}.cosh()" + + @staticmethod + def sinh(x): + return f"{x}.sinh()" + + @staticmethod + def log10(x): + return f"{x}.log10()" + + @staticmethod + def log2(x): + return f"{x}.log2()" + + @staticmethod + def nextafter(x, y): + return f"{x}.nextafter({y})" + + @staticmethod + def copysign(a, b): + return f"{a}.copysign({b})" + + @staticmethod + def atan2(a, b): + return f"{a}.atan2({b})" + + @staticmethod + def hypot(a, b): + return f"{a}.hypot({b})" + + @staticmethod + def atanh(x): + # For real x, atanh(x) = 1/2 * log((1+x)/(1-x)) + vec_one = f"decltype({x})(1)" + vec_one_half = f"decltype({x})(0.5)" + return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()" + + @staticmethod + def asinh(x): + # For real x, asinh(x) = log(x + sqrt(1 + x**2)) + vec_one = f"decltype({x})(1)" + return f"({x} + ({vec_one} + {x}*{x}).sqrt()).log()" + + @staticmethod + def acosh(x): + return f"{x}.acosh()" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"at::vec::clamp_min({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + # TODO: this seems to be dead + @staticmethod + def sigmoid(x): + return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" + + @staticmethod + def neg(x): + return f"{x}.neg()" + + @staticmethod + def floordiv(a, b): + if is_float_dtype(a.dtype): + assert ( + a.dtype == b.dtype + ), "div_floor_floating_vec implementation expect the same inputs' dtype." 
+ return f"div_floor_floating_vec({a}, {b})" + else: + assert all(is_integer_dtype(item.dtype) for item in [a, b]) + # a and b are integer type + _t = f"decltype({a})" + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + quot = f"{a} / {b}" + has_rem = f"({a} % {b} != {_t}(0))" + is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))" + return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + _t = f"decltype({b})" + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + return f"{a} / {b}" + + @staticmethod + def minimum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} & {b_cast}" + else: + return f"at::vec::minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} | {b_cast}" + else: + return f"at::vec::maximum({a}, {b})" + + @staticmethod + def square(a): + return f"{a} * {a}" + + @staticmethod + def where(a, b, c): + assert isinstance(V.kernel, CppVecKernel) + if b.dtype == torch.bool: + assert c.dtype == torch.bool + blendv_a, blendv_b, blendv_c = unify_mask_base_type( + V.kernel.compute, (a, b, c) + ) + return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" + else: + return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})" + + @staticmethod + def sign(x): + code = BracesBuffer() + vec_zero = f"decltype({x})(0)" + vec_one = f"decltype({x})(1)" + blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" + blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {blendv_l};") + code.writeline(f"auto right = {blendv_r};") + code.writeline("return left - right;") + code.writeline("()") + return code + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True): + assert dtype in [ + torch.bool, + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ], f"{__name__} does not support {dtype}" + assert isinstance(x, CppCSEVariable) + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in [torch.bfloat16, torch.float16] and src_dtype == torch.float: + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"{x}.log1p()" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def masked(mask, body, other): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + var = V.kernel.cse.newvar() + with V.kernel.masked(mask) as new_mask: + code.writeline(f"auto {var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + 
code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + dtype = result.dtype + body_code = f"{var}()" + body_code_vec = ( + body_code + if result.is_vec + else f"{V.kernel._get_vec_type(dtype)}({body_code})" + ) + other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype]) + # loading bool as VecMask + other_code_vec = ( + f"{V.kernel._get_mask_type()}::from({other_code})" + if dtype == torch.bool + else f"{V.kernel._get_vec_type(dtype)}({other_code})" + ) + assert isinstance(new_mask, CppCSEVariable), new_mask + if new_mask.is_vec: + code = BracesBuffer() + code.writeline("[&]") + with V.kernel.swap_buffers(code), code.indent(): + code.writeline(f"if ({new_mask}.all_zero())") + with code.indent(): + code.writeline(f"return {other_code_vec};") + code.writeline("else") + with code.indent(): + # Create cse variable to reuse kernel.overrides.where + body_vec_var = V.kernel.cse.generate( + V.kernel.compute, + body_code_vec, + ) + other_vec_var = V.kernel.cse.generate( + V.kernel.compute, + other_code_vec, + ) + assert isinstance(body_vec_var, CppCSEVariable), body_vec_var + assert isinstance(other_vec_var, CppCSEVariable), other_vec_var + body_vec_var.dtype = dtype + other_vec_var.dtype = dtype + code.writeline( + f"return {V.kernel.overrides.where(new_mask, body_vec_var, other_vec_var)};" + ) + code.writeline("()") + csevar = V.kernel.cse.generate( + V.kernel.compute, + code, + ) + elif result.is_vec: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}" + ) + else: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code} : {other_code}" + ) + # `result` is explicitly added to the args for correct propagation + # of relevant itervars and vectorization status. 
+ csevar.update_on_args("masked", (mask, body, other, result), {}) + return csevar + + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppVecKernel) + index = V.kernel.rename_indexing(expr) + tiling_var = V.kernel.itervars[V.kernel.tiling_idx] + stride = V.kernel._try_get_const_stride(index, tiling_var) + if stride == 0: + return CppOverrides.index_expr(expr, dtype) + elif stride is not None: + idx = V.kernel.cse.generate( + V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr) + ) + value = ops.to_dtype(idx, dtype) + if isinstance(value, OpsValue): + value = value.value + csevar = V.kernel.arange(value, stride) + else: + csevar = V.kernel._load_or_store_non_contiguous(  # type: ignore[assignment] + None, index, dtype, V.kernel.compute + ) + csevar.update_on_args("index_expr", (expr, dtype), {}) + return csevar + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): + return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) + + cdtype = DTYPE_TO_CPP[x.dtype] + size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor + code = BracesBuffer() + exponent = V.kernel.cse.newvar() + mantissa = V.kernel.cse.newvar() + exponent.update_on_args("frexp", (x,), kwargs={}) + mantissa.update_on_args("frexp", (x,), kwargs={}) + n_vec = V.kernel._get_num_vectors(x.dtype) + mantissa_t = ( + f"at::vec::Vectorized<{cdtype}>" + if n_vec == 1 + else f"at::vec::VectorizedN<{cdtype}, {n_vec}>" + ) + code.writeline( + f"at::vec::Vectorized<int32_t> {exponent};" + if n_vec == 1 + else f"at::vec::VectorizedN<int32_t, {n_vec}> {exponent};" + ) + code.writeline(f"{mantissa_t} {mantissa};") + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;" + ) + code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});") + code.writeline( + f"__at_align__ std::array<int32_t, {V.kernel.tiling_factor}> tmpbuf_exponent;" + ) + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;" + ) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline( + "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);" + ) + code.writeline( + f"{exponent} = at::vec::Vectorized<int32_t>::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + if n_vec == 1 + else f"{exponent} = at::vec::VectorizedN<int32_t, {n_vec}>::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + ) + code.writeline( + f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});" + ) + code.writeline("();") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.cache[cache_key] = cse_var + return mantissa, exponent + + @classmethod + def scalarize(cls, scalar_func): + def inner(*args, **kwargs): + assert not kwargs + kernel = V.kernel + assert isinstance(kernel, CppVecKernel) + code = BracesBuffer() + code.writeline("[&]()") + vec_dtype = args[0].dtype + n_vec = kernel._get_num_vectors(vec_dtype) + size = kernel.tail_size if kernel.tail_size else kernel.tiling_factor + scalar_args = [] + cdtype = DTYPE_TO_CPP[vec_dtype] + output_mask = scalar_func.__name__ in ( + "isinf", + "isnan", + "signbit", + ) + octype = "bool" if output_mask else cdtype + octype = ( + DTYPE_TO_CPP[args[-2]] + if (scalar_func.__name__ == "to_dtype_bitcast") + else octype + ) + with code.indent(): + for argidx, arg in enumerate(args): + if 
isinstance(arg, CppCSEVariable): + assert arg.is_vec + assert arg.dtype == vec_dtype + code.writeline( + f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};" + ) + code.writeline( + f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});" + ) + scalar_args.append(f"tmpbuf{argidx}[i]") + else: + scalar_args.append(arg) + code.writeline( + f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;" + ) + res = scalar_func(*scalar_args) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline(f"tmpbuf_out[i] = {res};") + if output_mask: + assert not kernel.tail_size + load_args = "tmpbuf_out.data()" + load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" + else: + load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" + if n_vec == 1: + load_fn = f"at::vec::Vectorized<{octype}>::loadu" + else: + load_fn = f"at::vec::VectorizedN<{octype}, {n_vec}>::loadu" + code.writeline(f"return {load_fn}({load_args});") + code.writeline("()") + return code + + return inner + + @classmethod + def _initialize_scalarize(cls): + for name, method in vars(CppOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in vars( + CppVecOverrides + ): + func = cls.scalarize(method.__func__) + func.__name__ = name + setattr(cls, name, staticmethod(func)) + + + CppVecOverrides._initialize_pointwise_overrides("cppvec") + CppVecOverrides._initialize_scalarize() + + + class CppTile2DOverrides(CppVecOverrides): + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppTile2DKernel) + expr = V.kernel.transform_indexing(expr) + return CppVecOverrides.index_expr(expr, dtype) + + + class CppKernel(Kernel): + overrides = CppOverrides  # type: ignore[assignment] + sexpr = cexpr + newvar_prefix = "auto " + suffix = ";" + + def __init__(self, args, num_threads): + super().__init__(args) + self.call_ranges: Optional[Tuple[sympy.Expr, ...]] = None + self.ranges: List[sympy.Expr] = [] + self.itervars: List[sympy.Symbol] = [] + self.reduction_depth = None + self.reduction_prefix = IndentedBuffer() + self.reduction_suffix = IndentedBuffer() + self.parallel_reduction_prefix = IndentedBuffer() + self.parallel_reduction_suffix = IndentedBuffer() + self.local_reduction_init = IndentedBuffer() + self.local_reduction_stores = IndentedBuffer() + self.is_reduction = False + self.non_parallel_reduction_prefix = IndentedBuffer() + self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") + self.weight_recps_cse = CSE( + self.newvar_prefix, self.suffix, name_prefix="wrecps" + ) + self.preloads = IndentedBuffer() + self.poststores = IndentedBuffer() + self.num_threads = num_threads  # number of threads the kernel is specialized for + self.reduction_omp_dec: Dict[Tuple[str, str], str] = {} + + def _gen_parallel_reduction_buffers( + self, + acc, + acc_type, + reduction_type, + dtype, + reduction_combine_fn=reduction_combine, + reduction_init_fn=reduction_init, + welford_weight_reciprocal_vec_fn=None, + ): + if config.cpp.dynamic_threads and not self.parallel_reduction_prefix: + self.parallel_reduction_prefix.writeline( + "int max_threads = omp_get_max_threads();" + ) + acc_local = f"{acc}_local" + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + acc_per_thread_var_name = f"{acc}_arr" + acc_per_thread = f"{acc_per_thread_var_name}[{num_threads}]" + """ + MSVC does not support dynamic arrays (VLAs); use std::unique_ptr instead. 
+ Ref: https://stackoverflow.com/questions/56555406/creating-dynamic-sized-array-using-msvc-c-compiler + MSVC is the only major compiler that does not support VLAs, and it does not reach good Inductor performance anyway, + so we use std::unique_ptr to keep this code working on MSVC. + For all other compilers we keep using a VLA for the best performance. + """ + acc_per_thread_unique_ptr_decl = f"auto {acc_per_thread_var_name} = std::make_unique<{acc_type}[]>({num_threads})" + acc_per_thread_vla_decl = f"{acc_per_thread_var_name}[{num_threads}]" + acc_local_in_array = acc_per_thread.replace(f"[{num_threads}]", "[tid]") + self.local_reduction_init.writeline( + f"{acc_type} {acc_local} = {reduction_init_fn(reduction_type, dtype)};" + ) + self.parallel_reduction_prefix.writeline( + f"{acc_per_thread_unique_ptr_decl};" + if cpp_builder.is_msvc_cl() + else f"{acc_type} {acc_per_thread_vla_decl};" + ) + self.parallel_reduction_prefix.writelines( + [ + f"for (int tid = 0; tid < {num_threads}; tid++)", + "{", + f" {acc_local_in_array} = {reduction_init_fn(reduction_type, dtype)};", + "}", + ], + ) + self.local_reduction_stores.writelines( + [ + f"{acc_local_in_array} = {acc_local};", + ] + ) + self.parallel_reduction_suffix.writelines( + [ + f"for (int tid = 0; tid < {num_threads}; tid++)", + "{", + f" {acc} = {reduction_combine_fn(reduction_type, acc, acc_local_in_array, src_dtype=dtype)};", + "}", + ], + ) + + def get_reduction_var_pattern(self, line: str): + return re.search("tmp_acc[0-9]+", line) + + def update_stores_with_parallel_reduction(self): + for i, line in enumerate(self.stores._lines): + if isinstance(line, str): + m = self.get_reduction_var_pattern(line) + if m: + var_name = m.group(0) + self.stores._lines[i] = line.replace(var_name, f"{var_name}_local") + + @contextlib.contextmanager + def masked(self, mask): + """Context manager to add an additional mask to loads and stores.""" + prior = self._load_mask + if prior: + mask = ops.and_(mask, prior) + if isinstance(mask, OpsValue): + mask = mask.value + assert isinstance(mask, CppCSEVariable) + # see NOTE [dtype of CppCSEVariable] + # mask's dtype should be bool + mask.dtype = torch.bool + + self._load_mask = mask + try: + yield mask + finally: + self._load_mask = prior + + def scale_index_with_offset( + self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0 + ): + var = self.itervars[itervar_idx] + replacement = {var: var * scale + offset} + new_index = sympy_subs(index, replacement) + return new_index + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in cpp code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel. + """ + return cexpr(self.rename_indexing(index)) + + def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + """ + Check whether the index has a free symbol (a CppCSEVariable) that depends on `itervar`. 
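+ For example, for an index like tmp3 + 16 * x1, where tmp3 is a CppCSEVariable + produced by an earlier indirect load that itself reads x0, this returns True + for itervar x0 (names are illustrative).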
+ """ + return any( + self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined] + for s in index.free_symbols + if s.name in self.cse.varname_map # type: ignore[attr-defined] + and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined] + ) + + def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + return itervar in index.free_symbols or self.index_indirect_depends_on( + index, itervar + ) + + def var_ranges(self): + return dict(zip(self.itervars, self.ranges)) + + def check_bounds( + self, + expr: sympy.Expr, + size: sympy.Expr, + lower: bool, + upper: bool, + ): + if not (lower or upper): + return + + indirect = free_symbol_is_type(expr, SymT.TMP) + if indirect: + # indexing in compute + csevar = ops.index_expr(expr, torch.int64).value + buffer = V.kernel.compute + else: + # indexing in loads + prior_compute = V.kernel.compute + try: + V.kernel.compute = self.loads + csevar = ops.index_expr(expr, torch.int64).value + finally: + V.kernel.compute = prior_compute + buffer = self.loads + + size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None + + line = self.indirect_assert( + csevar, "0" if lower else None, size_str, self._load_mask + ) + self.cse.generate(buffer, line, assignment=False) + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + line = f"{var}[{cexpr_index(index)}]" + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (self, name, index), {}) + return csevar + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + index = self.rename_indexing(index) + if mode is None: + line = f"{var}[{cexpr_index(index)}] = {value};" + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + line = f"{var}[{cexpr_index(index)}] += {value};" + else: + dtype = V.graph.get_dtype(name) + # mirroring static_cast(...) 
in load: + value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})" + line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});" + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + + def reduction(self, dtype, src_dtype, reduction_type, value): + argmax_or_argmin = reduction_type in {"argmax", "argmin"} + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + self.is_reduction = True + init_dtype = src_dtype if argmax_or_argmin else dtype + acc_type = reduction_acc_type(reduction_type, init_dtype) + self.reduction_prefix.writeline( + f"{acc_type} {acc} = {reduction_init(reduction_type, init_dtype)};" + ) + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value, index)};" + ) + self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype) + result = reduction_project(reduction_type, acc) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + self.reduction_suffix.writeline( + DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};") + ) + + def set_ranges(self, lengths, reduction_lengths): + if self.call_ranges: + assert self.call_ranges == tuple(lengths) + tuple( + reduction_lengths + ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + assert self.reduction_depth == len(lengths) + else: + self.call_ranges = tuple(lengths) + tuple(reduction_lengths) + self.ranges = [self.rename_indexing(x) for x in self.call_ranges] + self.itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(self.ranges)) + ] + self.reduction_depth = len(lengths) + return ( + self.itervars[: self.reduction_depth], + self.itervars[self.reduction_depth :], + ) + + def size_hint(self): + return V.graph.sizevars.size_hint( + sympy_product(self.call_ranges), fallback=8192 + ) + + def codegen_loops_impl(self, loop_nest, code, worksharing): + threads = parallel_num_threads() + assert self.call_ranges is not None + kernels = loop_nest.get_kernels() + has_outer_loop_kernel = any( + isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels + ) + if has_outer_loop_kernel: + assert len(kernels) == 1 + assert isinstance(kernels[0], OuterLoopFusedKernel) + par_depth = kernels[0].decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + else: + par_depth = self.decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + + with contextlib.ExitStack() as stack: + if par_depth: + if loop_nest.is_reduction_only(): + # need to close the worksharing scope to define reduction vars outside it + worksharing.close() + else: + worksharing.parallel(threads) + loop_nest.mark_parallel(par_depth) + elif threads > 1: + if worksharing.single(): + stack.enter_context(code.indent()) + + def gen_loop_kernel(loop: LoopLevel): + def is_parallel_reduction(loop): + root = loop.get_root() + return root.is_reduction and root.parallel + + kernels = loop.get_kernels() + assert len(kernels) == 1 + if not isinstance( + kernels[0], 
OuterLoopFusedKernel + ) and is_parallel_reduction(loop): + kernels[0].update_stores_with_parallel_reduction() + gen_kernel(kernels[0]) + + def gen_kernel(kernel): + if isinstance(kernel, OuterLoopFusedKernel): + for loop in kernel.inner: + if loop.inner: + gen_loops(loop.inner, loop.is_reduction) + else: + with contextlib.ExitStack() as stack: + # If there is any kernel existing at the final outer loop fusion level, + # the kernel code should be placed within its respective indent to prevent + # the duplication of variable definitions. + stack.enter_context(code.indent()) + gen_loop_kernel(loop) + else: + with contextlib.ExitStack() as stack: + assert kernel + if hasattr(kernel, "codegen_inner_loops"): + code.splice(kernel.preloads) + kernel.codegen_inner_loops(code) + stack.enter_context(code.indent()) + code.splice(kernel.loads) + code.splice(kernel.compute) + code.splice(kernel.stores) + if hasattr(kernel, "codegen_inner_loops"): + code.splice(kernel.poststores) + + def get_reduction_code_buffer(loops, buffer="prefix"): + assert buffer in ("prefix", "suffix", "local") + for loop in loops: + for kernel in loop.get_kernels(): + if buffer == "local": + return ( + kernel.local_reduction_init, + kernel.local_reduction_stores, + ) + elif buffer == "suffix": + suffix = kernel.reduction_suffix + if loop.parallel: + suffix = kernel.parallel_reduction_suffix + suffix + return suffix + else: + prefix = kernel.reduction_prefix + if loop.parallel: + prefix = prefix + kernel.parallel_reduction_prefix + else: + prefix = prefix + kernel.non_parallel_reduction_prefix + return prefix + + def gen_loops(loops: List[LoopLevel], in_reduction=False): + with contextlib.ExitStack() as stack_outer: + local_reduction_init = local_reduction_stores = None + if loops: + loop = loops[0] + if loop.is_reduction and not in_reduction: + reduction_prefix = get_reduction_code_buffer(loops) + if reduction_prefix: + stack_outer.enter_context(code.indent()) + code.splice(reduction_prefix) + if loop_nest.is_reduction_only() and loop.parallel: + ( + local_reduction_init, + local_reduction_stores, + ) = get_reduction_code_buffer(loops, "local") + worksharing.parallel(threads) + if local_reduction_init: + assert local_reduction_stores + code.splice(local_reduction_init) + + for loop in loops: + gen_loop(loop) + + if loops: + loop = loops[0] + if loop_nest.is_reduction_only() and loop.parallel: + if local_reduction_stores: + code.splice(local_reduction_stores) + worksharing.close() + if loop.is_reduction and not in_reduction: + code.splice(get_reduction_code_buffer(loops, "suffix")) + + def gen_loop(loop: LoopLevel): + with contextlib.ExitStack() as stack: + loop_lines = loop.lines() + if loop_lines is None: + return + code.writelines(loop_lines) + stack.enter_context(code.indent()) + # generate inner loops or loop body + if loop.inner: + gen_loops(loop.inner, loop.is_reduction) + else: + gen_loop_kernel(loop) + + stack.enter_context(code.indent()) + if loop_nest.root: + if ( + has_outer_loop_kernel + and isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + # Allocate local buffer + local_buffers = V.local_buffer_context.local_buffers + for local_buffer in local_buffers.values(): + # For dynamic size, rename s to ks + local_buf_size = sympy_product( + [ + self.rename_indexing(size_val) + for size_val in local_buffer.get_layout().size + ] + ) + local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype] + allocate = f"std::make_unique<{local_buf_dtype} 
[]>({cexpr(local_buf_size)})" + local_buffer_name = local_buffer.get_name() + code.splice( + f"std::unique_ptr<{local_buf_dtype} []> buf_{local_buffer_name} = {allocate};" + ) + code.splice( + f"{local_buf_dtype}* {local_buffer_name} = buf_{local_buffer_name}.get();" + ) + gen_loops(loop_nest.root) + else: + gen_kernel(loop_nest.kernel) + + def codegen_loops(self, code, worksharing): + loop_nest = LoopNestWithSplit.build(self) + self.codegen_loops_impl(loop_nest, code, worksharing) + + @property + def assert_function(self) -> str: + if V.graph.aot_mode: + # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models + # compared with JIT Inductor which uses TORCH_CHECK + return "AOTI_TORCH_CHECK" + else: + return "TORCH_CHECK" + + def decide_parallel_depth(self, max_parallel_depth, threads): + assert self.call_ranges is not None + ranges = self.call_ranges[:max_parallel_depth] + seq = self.size_hint() + par = 1 + depth = 0 + for expr in ranges: + hint = V.graph.sizevars.size_hint(expr, fallback=8192) + if par >= 2 * threads or par == threads: + break + if seq // threads < config.cpp.min_chunk_size: + # not enough work + break + depth += 1 + par *= hint + seq /= hint + # if we assume thread number is dynamic, make sure we + # have at least one parallel scope and let OMP runtime + # to manage the serial vs. parallel. + if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0: + depth = 1 + return depth + + @contextlib.contextmanager + def write_to_suffix(self): + prior = (self.loads, self.compute, self.stores, self.cse) + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.cse = self.cse.clone() + yield + self.reduction_suffix.splice(self.loads) + self.reduction_suffix.splice(self.compute) + self.reduction_suffix.splice(self.stores) + (self.loads, self.compute, self.stores, self.cse) = prior + + def create_cse_var(self, *args, **kwargs): + return CppCSEVariable(*args, **kwargs) + + def get_to_dtype_expr(self, src, dtype, src_dtype): + return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({src})" + + def cache_dtype_convert(self, dst, dst_dtype, src, src_dtype): + expr = self.get_to_dtype_expr(src, dst_dtype, src_dtype) + self.cse.cache[expr] = dst + + +class CppVecKernel(CppKernel): + overrides = CppVecOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_idx, + tail_size=None, + ): + super().__init__(args, num_threads) + self.vec_isa = cpu_vec_isa.pick_vec_isa() + assert self.vec_isa + assert tiling_factor > 0, "Expect pass in Non-Zero tiling_factor explicitly" + self.tiling_factor = tiling_factor + self.tiling_idx = tiling_idx + self.tail_size = tail_size + self.num_elems = tail_size if tail_size else tiling_factor + + def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol): + if self.index_indirect_depends_on(index, itervar): + return None + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + return None + stride = stride_at_vec_range(index, itervar, self.tiling_factor) + return stride if stride.is_number else None + + def _get_num_vectors(self, dtype: torch.dtype) -> int: + num_vectors = math.ceil( + self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + ) + assert num_vectors >= 1 + return num_vectors + + def _get_raw_num_vectors(self, dtype: torch.dtype) -> 
float: + # This utility function is used to check if the vector lanes has been + # fully utilized. For example, uint8 will only use 1/4 of the vector lanes. + return self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + + def _get_vec_type(self, dtype: torch.dtype) -> str: + num_vectors = self._get_num_vectors(dtype) + if num_vectors == 1: + return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" + else: + return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_type(self, dtype: torch.dtype = torch.float) -> str: + if dtype == torch.bool: + return "" + num_vectors = self._get_num_vectors(dtype) + return f"at::vec::VecMask<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_cast(self, mask: CppCSEVariable, dtype: torch.dtype) -> str: + assert mask.dtype == torch.bool, repr(mask) + num_vectors = self._get_num_vectors(dtype) + return f"{mask}.template cast<{DTYPE_TO_CPP[dtype]},{num_vectors}>()" + + def get_reduction_var_pattern(self, line: str): + return re.search("tmp_acc[0-9]+_vec", line) + + def _get_vec_load_line( + self, + var: str, + index: sympy.Expr, + dtype: torch.dtype, + load_mask: Optional[CppCSEVariable] = None, + ): + """ + Get a load line str that loads a vector from `var` at `index` of type `dtype`. + If `load_mask` is not None, we do a masked load accordingly. + Notes on the `dtype`: + 1. We always load `self.tiling_factor` number of elements regardless of the `dtype`. + It means we load half of the vector lanes for 16-bit data types and quarter of the + vector lanes for 8-bit data types. + 2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors. + """ + cpp_type = DTYPE_TO_CPP[dtype] + num_vectors = self._get_num_vectors(dtype) + load_mask_str = None + if load_mask: + if not load_mask.is_vec: + # TODO: avoid hard-code torch.float + load_mask_str = f"{self._get_mask_type(torch.float)}::from({load_mask})" + else: + load_mask_str = f"{self._get_mask_cast(load_mask, torch.float)}" + loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var + if dtype == torch.bool: + # TODO: should we consider load mask here? + line = f"{self._get_mask_type()}::from({loadbuf})" + else: + line = ( + f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})" + if load_mask_str + else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})" + ) + return line + + def _load_or_store_non_contiguous( + self, + var: Optional[str], + index: sympy.Expr, + dtype: torch.dtype, + buffer: Optional[IndentedBuffer] = None, + store_value: Optional[Union[str, CppCSEVariable]] = None, + accu_store: bool = False, + ) -> Optional[CppCSEVariable]: + """ + Load or store a vector in a non-contiguous way. The vector is initialized from an array that is + filled in an inner loop over the tiling factor. + :param var: buffer to load from or store to, i.e. `var[transformed(index)]`. If None, we load the index + as index expression, i.e. `transformed(index)`. + :param index: index into the `var` or the index expression by its own if `var` is None. + The `index` could contain indirect indexing or the tiling itervar. When used in + the inner loop, the index is transformed as follows: + 1. the index is linearized along the tiling dim. + 2. the indirect indexing vector variables are transformed into arrays over the tiling dim. + :param dtype: data type of `var` or `index` if `var` is None. + :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. 
+ :param store_value: the value to store. If None, we load the vector. + :param accu_store: whether accumulate the store_value to store_ptr. If True, a store_value should be provided + :return: a CppCSEVariable that represents the loaded vector or None if it is a store. + """ + assert not store_value or var is not None, "store var must be provided" + if accu_store: + assert store_value + if buffer is None: + buffer = self.loads + + def get_result_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.num_elems * (4 // dtype.itemsize) + else: + return self.num_elems + + def get_tiling_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.tiling_factor * (4 // dtype.itemsize) + else: + return self.tiling_factor + + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: + assert vec_var.is_vec + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + vec_dtype = vec_var.dtype + assert vec_dtype is not None + if vec_dtype == torch.bool: + vec_dtype = torch.float + result_size = get_result_size(vec_dtype) + tiling_size = get_tiling_size(vec_dtype) + code.writeline( + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;" + ) + line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});" + code.writeline(line) + code.writeline("return tmpbuf;") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + return csevar + + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + result_size = get_result_size(dtype) + tiling_size = get_tiling_size(dtype) + result_declare = ( + f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;" + ) + code.writeline(result_declare) + if store_value: + code.writeline( + f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});" + ) + itervar_inner = sympy_index_symbol( + f"{self.itervars[self.tiling_idx]}_inner" + ) + replacements = {} + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + array_var = vec_to_array(indirect_var) + replacements[indirect_var] = f"{array_var}[{itervar_inner}]" + index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=itervar_inner + ) + load_mask = None + if self._load_mask is not None: + assert not store_value, "unexpected store with load mask" + assert isinstance(self._load_mask, CppCSEVariable), self._load_mask + if self._load_mask.is_vec: + load_mask = f"{self._load_mask}.is_masked({itervar_inner})" + else: + load_mask = f"{self._load_mask} != 0" + if cpp_builder.is_gcc(): + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") + else: + code.writeline(f"#pragma unroll {self.tiling_factor}") + code.writeline( + f"for (long {itervar_inner} = 0; " + + f"{itervar_inner} < {cexpr_index(self.num_elems)}; " + + f"{itervar_inner}++)" + ) + with code.indent(), contextlib.ExitStack() as stack: + index_c = cexpr_index(index) + for indirect_var in replacements: + index_c = re.sub( + r"\b" + f"{indirect_var}" + r"\b", + replacements[indirect_var], + index_c, + ) + rhs = f"{var}[{index_c}]" if var is not None else f"{index_c}" + if load_mask: + code.writeline(f"if ({load_mask})") + stack.enter_context(code.indent()) + if store_value: + conjunction = "+=" if accu_store else "=" + code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];") + else: + 
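+ # Load path (a software gather); schematically the emitted C++ is + #     for (long x1_inner = 0; x1_inner < N; x1_inner++) + #         tmpbuf[x1_inner] = in_ptr[<scalar index>]; + # followed by a single vector load from tmpbuf (loop and buffer names illustrative).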
code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") + if not store_value: + load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type] + code.writeline(f"return {load_line};") + code.writeline("()") + if store_value: + code.writeline(";") + buffer.splice(code) + return None + else: + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + tiling_var = self.itervars[self.tiling_idx] + stride = self._try_get_const_stride(index, tiling_var) + if stride == 0: + # load scalar and lazily broadcast it on demand + return super().load(name, index) + elif stride == 1: + # load contiguously + line = self._get_vec_load_line(var, index, dtype, self._load_mask) + csevar = self.cse.generate(self.loads, line) # type: ignore[assignment] + else: + csevar = self._load_or_store_non_contiguous(var, index, dtype) # type: ignore[assignment] + assert isinstance(csevar, CppCSEVariable) + csevar.update_on_args("load", (self, name, index), {}) + csevar.is_vec = True + return csevar + + def _get_store_line( + self, + value: Union[str, CppCSEVariable], + var: str, + index: sympy.Expr, + dtype: torch.dtype, + accu_store: bool = False, + ): + """ + Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles + both contiguous and non-contiguous store cases. + :param value: Vectorized type templaterized on `dtype`. + :param var: buffer to store into. + :index: index into the `var`. + """ + # when value's type is str (e.g., welford reduction), caller should make sure + # it is a vector + assert isinstance(value, str) or ( + isinstance(value, CppCSEVariable) and value.is_vec + ), value + tiling_var = self.itervars[self.tiling_idx] + var_expr = f"{var} + {cexpr_index(index)}" + stride = self._try_get_const_stride(index, tiling_var) + code = IndentedBuffer() + if stride == 1: + if dtype == torch.float and self.tail_size is None: + code.writeline(f"{value}.store({var_expr});") + else: + code.writeline( + f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});" + ) + else: + self._load_or_store_non_contiguous( + var, index, dtype, buffer=code, store_value=value, accu_store=accu_store + ) + return code + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + var = self.args.output(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + if mode is None: + code = self._get_store_line(value, var, index, dtype) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + code = self._get_store_line( + f"{value}", + var, + index, + dtype, + accu_store=True, + ) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + else: + n_src = self._get_num_vectors(dtype) + n_idx = self._get_num_vectors(torch.int64) + cdtype = DTYPE_TO_CPP[dtype] + index = ops.index_expr(index, torch.int64).value + assert index.is_vec + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" + self.stores.writeline(DeferredLine(name, line)) + else: + raise NotImplementedError(f"store mode={mode}") + + def reduction(self, dtype, 
src_dtype, reduction_type, value): + assert reduction_type in VECTORIZABLE_RTYPES + argmax_or_argmin = reduction_type in {"argmax", "argmin"} + horizontal_reduction = self.tiling_idx >= self.reduction_depth + init_dtype = src_dtype if argmax_or_argmin else dtype + assert isinstance(value, CppCSEVariable), value + + if not value.is_vec: + value = self.broadcast(value) + + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + vec_ns = "at::vec" + vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" + acc_type = reduction_acc_type(reduction_type, init_dtype) + acc_type_vec = self.reduction_acc_type_vec(reduction_type, init_dtype) + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + acc_vec = f"{acc}_vec" + self.is_reduction = True + self.reduction_prefix.writeline( + f"{acc_type} {acc} = {reduction_init(reduction_type, init_dtype)};" + ) + self.reduction_prefix.writeline( + f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, init_dtype)};" + ) + if reduction_type == "welford_reduce": + # save the reciprocal of weights for welford reduce + assert self.reduction_depth is not None + # use masked acc_vec for tail vec kernel + self.reduction_prefix.writeline( + f"{acc_type_vec} masked_{acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};" + ) + reduction_size = functools.reduce( + lambda x, y: x * y, self.ranges[self.reduction_depth :] + ) + reduction_factor = ( + self.tiling_factor if self.tiling_idx >= self.reduction_depth else 1 + ) + self.weight_recp_vec_range = FloorDiv(reduction_size, reduction_factor) + if self.weight_recp_vec_range not in self.weight_recps_cse.reduction_cache: + self.weight_recps_val = self.weight_recps_cse.generate( + self.compute, f"reduction {self.weight_recp_vec_range}", write=False + ) + self.weight_recps_cse.reduction_cache[ + self.weight_recp_vec_range + ] = self.weight_recps_val + self.non_parallel_reduction_prefix.writeline( + self.welford_weight_reciprocal_vec(dtype) + ) + # generate weight_recps for parallel reduction + num_threads = ( + "max_threads" + if config.cpp.dynamic_threads + else parallel_num_threads() + ) + self.local_reduction_init.writeline( + self.welford_weight_reciprocal_vec(dtype, num_threads) + ) + else: + self.weight_recps_val = self.weight_recps_cse.reduction_cache[ + self.weight_recp_vec_range + ] + # use masked acc_vec for tail vec kernel + acc_vec_ = f"masked_{acc_vec}" if self.tail_size else acc_vec + self.stores.writeline( + f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, True)};" + ) + else: + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + combine = self.reduction_combine_vec( + reduction_type, + acc_vec, + value, + index=index, + horizontal_reduction=horizontal_reduction, + src_dtype=src_dtype, + ) + self.stores.writeline(f"{acc_vec} = {combine};") + self._gen_parallel_reduction_buffers( + acc, + acc_type, + reduction_type, + init_dtype, + ) + self._gen_parallel_reduction_buffers( + acc_vec, + acc_type_vec, + reduction_type, + init_dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + if reduction_type == "welford_reduce": + # use masked acc_vec for tail vec kernel + self._gen_parallel_reduction_buffers( + 
f"masked_{acc_vec}", + acc_type_vec, + reduction_type, + dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + tmpvar: Union[str, CSEVariable] + is_bool = dtype == torch.bool + if horizontal_reduction: + # Horizontal reduction + if is_welford_reduction(reduction_type): + assert self._get_num_vectors(dtype) in [ + 1, + 2, + ], "Welford reduction does not support VectorizedN (N>2)" + next_value = f"welford_vec_reduce_all({acc_vec})" + masked_next_value = f"welford_vec_reduce_all(masked_{acc_vec})" + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, masked_next_value)};" + ) + elif argmax_or_argmin: + next_value = f"{reduction_type}_vec_reduce_all({acc_vec})" + elif is_bool: + if reduction_type in ( + "any", + "sum", + "max", + ): + next_value = f"!{acc_vec}.all_zero()" + else: + assert reduction_type == "min" + next_value = f"{acc_vec}.all_masked()" + else: + reduce_all_body = ( + "{ return " + + self.reduction_combine_vec(reduction_type, "x", "y") + + "; }" + ) + is_bool = dtype == torch.bool + # we are using at::vec::VecMask for bool + vec_dtype = torch.float if is_bool else dtype + vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" + vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" + next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" + + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};" + ) + tmpvar = acc + else: + tmpvar = acc_vec + if is_welford_reduction(reduction_type): + masked_tmpvar = f"masked_{tmpvar}" + self.reduction_suffix.writeline( + f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};" + ) + + result = reduction_project(reduction_type, tmpvar) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + out_dtype = V.graph.get_dtype(name) + dtype = ( + (out_dtype if out_dtype == torch.double else torch.float) + if out_dtype.is_floating_point + else torch.int64 + ) + out_num_vectors = V.kernel._get_num_vectors(out_dtype) + src_num_vectors = V.kernel._get_num_vectors(dtype) + code = IndentedBuffer() + if self.tiling_idx >= self.reduction_depth: + # Horizontal reduction + code.writeline( + f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});" + ) + else: + # Vertical reduction + if out_dtype != dtype: + converted_value = f"{DTYPE_TO_CPP[out_dtype]}_{value}" + if out_dtype == torch.bool: + convert = f"{value}.template cast()" + else: + if src_num_vectors == out_num_vectors == 1: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" + ) + else: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}," + f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})" + ) + code.writeline(f"auto {converted_value} = {convert};") + value = converted_value + code.splice(self._get_store_line(value, var, index, out_dtype)) + self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x))) + + def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable: + assert not scalar_var.is_vec + if scalar_var.dtype == torch.bool: + vec_var = self.cse.generate( + self.compute, f"{self._get_mask_type()}::from({scalar_var.name})" + ) + else: + assert scalar_var.dtype is not None + vec_var = self.cse.generate( + 
self.compute, + f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})", + ) + assert isinstance(vec_var, CppCSEVariable) + vec_var.dtype = scalar_var.dtype + vec_var.dependent_itervars = scalar_var.dependent_itervars + vec_var.is_vec = True + return vec_var + + def arange(self, index: CppCSEVariable, stride: sympy.Symbol) -> CppCSEVariable: + assert not index.is_vec + assert index.dtype is not None + csevar = self.cse.generate( + self.compute, + f"{self._get_vec_type(index.dtype)}::arange({index}, {stride})", + ) + assert isinstance(csevar, CppCSEVariable) + csevar.dtype = index.dtype + csevar.is_vec = True + return csevar + + def reduction_init_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>()" + + if reduction_type in {"argmin", "argmax"}: + cdtype = DTYPE_TO_CPP[scalar_type] + acc_type = self.reduction_acc_type_vec(reduction_type, dtype) + if reduction_type == "argmin": + val = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + else: + val = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + return f"{acc_type}({val})" + + if reduction_type == "any": + return f"{self._get_mask_type()}::from(0)" + + scalar_init = reduction_init(reduction_type, dtype) + vec_init = f"{vec_type}({scalar_init})" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "sum") + return f"{self._get_mask_type()}::from({scalar_init})" + return vec_init + + def reduction_acc_type_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>" + if reduction_type in {"argmin", "argmax"}: + n_src = self._get_num_vectors(scalar_type) + n_idx = self._get_num_vectors(torch.int64) + return f"IndexValueVec<{DTYPE_TO_CPP[scalar_type]}, {n_src}, {n_idx}>" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "any", "sum") + return f"{self._get_mask_type()}" + return vec_type + + def welford_weight_reciprocal_vec(self, dtype, num_threads=None): + vec_num_range_thread = ( + CeilDiv(self.weight_recp_vec_range, num_threads) + if num_threads + else self.weight_recp_vec_range + ) + vec_num_range_thread_expr = cexpr_index(vec_num_range_thread) + return ( + f"static WeightRecp<{self._get_vec_type(dtype)}> {self.weight_recps_val}" + f"(" + f"{vec_num_range_thread_expr}" + f");" + ) + + def reduction_combine_vec( + self, + reduction_type, + var, + next_value, + use_weight_recps=False, + index: Optional[sympy.Symbol] = None, + horizontal_reduction: Optional[bool] = None, + src_dtype: Optional[torch.dtype] = torch.float32, + ): + is_bool = src_dtype == torch.bool + if reduction_type == "max": + if self.tail_size: + return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} | {next_value}" + if is_bool + else f"at::vec::maximum({var}, {next_value})" + ) + elif reduction_type == "min": + if self.tail_size: + return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} & {next_value}" + if is_bool + else f"at::vec::minimum({var}, {next_value})" + ) + elif reduction_type == "sum": + if self.tail_size: + return f"sum_masked_reduce({var}, {next_value}, 
{cexpr_index(self.tail_size)})" + else: + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + elif reduction_type == "prod": + if self.tail_size: + return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} * {next_value}" + elif reduction_type == "xor_sum": + if self.tail_size: + return f"xor_sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} ^ {next_value}" + elif reduction_type == "welford_reduce": + if use_weight_recps: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{self.weight_recps_val})" + else: + return f"welford_combine({var}, {next_value}, &{self.weight_recps_val})" + else: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {next_value})" + elif reduction_type == "welford_combine": + if isinstance(next_value, tuple): + # When reading a value from Inductor IR we have a tuple of variable names + mean, m2, weight = next_value + else: + # When combining intermediate accumulators we have a Welford struct + mean, m2, weight = reduction_project(reduction_type, next_value) + if self.tail_size: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + elif reduction_type in ("argmin", "argmax"): + assert src_dtype is not None + cdtype = DTYPE_TO_CPP[src_dtype] + n_src = self._get_num_vectors(src_dtype) + n_idx = self._get_num_vectors(torch.int64) + t_extra = "" + arg_extra = "" + if index is not None: + assert horizontal_reduction is not None + t_extra = f", {str(horizontal_reduction).lower()}" + arg_extra = f", {index}" + if self.tail_size: + return ( + f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>" + f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})" + ) + else: + return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})" + elif reduction_type == "any": + return f"{var} | {next_value}" + else: + raise NotImplementedError + + def indirect_assert(self, var, lower, upper, mask=None): + assert isinstance(var, CppCSEVariable) + assert var.dtype is not None + if not var.is_vec: + if isinstance(mask, CppCSEVariable) and mask.is_vec: + mask = f"({mask}).all_masked()" + return super().indirect_assert(var, lower, upper, mask) + lower_scalar = lower + upper_scalar = upper + if lower: + lower = f"{self._get_vec_type(var.dtype)}({lower})" + if upper: + upper = f"{self._get_vec_type(var.dtype)}({upper})" + if lower and upper: + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower_scalar} <= {var} < {upper_scalar}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = f"{lower_scalar} <= {var}" + else: + assert upper + cond = f"{var} < {upper}" + cond_print = f"{var} < {upper_scalar}" + cond = f"{self._get_mask_type(var.dtype)}({cond})" + if mask: + if not mask.is_vec: + mask = f"{self._get_mask_type(var.dtype)}({mask})" + # We need not check when the mask is False + cond = f"({cond}) | ~({mask})" + if self.tail_size: + cond = ( + f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)" + f", ({cond}), {cexpr_index(self.tail_size)})" + ) + cond = f"({cond}).all_masked()" + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def get_to_dtype_expr(self, src, dtype, 
src_dtype): + assert isinstance(src, CppCSEVariable) + if not src.is_vec: + return super().get_to_dtype_expr(src, dtype, src_dtype) + src_cpp_type = DTYPE_TO_CPP[src_dtype] + src_num_vectors = self._get_num_vectors(src_dtype) + dst_cpp_type = DTYPE_TO_CPP[dtype] + dst_num_vectors = self._get_num_vectors(dtype) + expr = f"({src})" + if src_dtype != torch.bool and dtype == torch.bool: + expr = f"{self._get_mask_type(src_dtype)}::from<{src_cpp_type},{src_num_vectors}>({src})" + elif src_dtype == torch.bool and dtype != torch.bool: + expr = f"{src}.to<{dst_cpp_type},{dst_num_vectors}>()" + elif src_dtype != dtype: + if src_num_vectors == dst_num_vectors == 1: + expr = f"at::vec::convert<{dst_cpp_type}>({src})" + else: + expr = f"at::vec::convert<{dst_cpp_type},{dst_num_vectors},{src_cpp_type},{src_num_vectors}>({src})" + return expr + + +class CppTile2DKernel(CppVecKernel): + """ + A vector kernel that handles the 2d tiles with the tile size defined in `tiling_factor` on + the inner-most loop level and one of the outer loop level (`outer_tiling_idx`). When the data + tile is accessed in a contiguous way from the outer loop axis, a transposition is applied on the + tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization + logic from its parent `CppVecKernel` is leveraged for load/store/compute. The transposed tile load + and store are generated into kernel.preloads and kernel.poststores buffers. + + The loop structure looks like below: + for ... + for i_outer ... + for ... + for inner_most ... + // generated by CppTile2DKernel + float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads + float tmp1[16*16]; // into kernel.preloads + for i_inner ... { // the kernel inner loop + vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores + } + at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores + for inner_most ... (tail) + // generated by CppVecKernel + ... + for i_outer ... (tail) + for ... + for ... + // generated by CppKernel + ... 
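+ + As a concrete illustration with tiling_factor = 16 and float data: a 16x16 + tile that is contiguous along i_outer is transposed once into tmp0, after + which each row of tmp0 is contiguous along i_inner and can be consumed by + plain 16-lane vector loads inside the kernel inner loop.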
+ """ + + overrides = CppTile2DOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_indices, + inner_tail_size=None, + outer_tail_size=None, + ): + super().__init__( + args, + num_threads, + tiling_factor, + tiling_indices[1], + inner_tail_size, + ) + self.tiling_indices = tiling_indices + self.inner_tail_size = inner_tail_size + self.outer_tail_size = outer_tail_size + self.inner_num_elems = inner_tail_size if inner_tail_size else tiling_factor + self.outer_num_elems = outer_tail_size if outer_tail_size else tiling_factor + self.inner_is_tiling_idx = True + + def inner_itervar(self): + return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner") + + def need_vec_transpose(self, index): + outer_var = self.itervars[self.outer_idx] + inner_var = self.itervars[self.tiling_idx] + outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor) + inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor) + return ( + self._load_mask is None # TODO: support transposition with mask + and outer_stride == 1 + and index.has(inner_var) + and not inner_stride.has(inner_var) + and not inner_stride.has(outer_var) + ) + + def gen_transposed_tile_load_store(self, name, var, index, is_store): + # transposed tile load/store outside the kernel inner loop + dtype = V.graph.get_dtype(name) + factor = self.tiling_factor + src = f"{var} + {cexpr_index(index)}" + dst = "__place_holder__" + ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" + ld_dst = f"{cexpr_index(self.num_elems)}" + if is_store: + src, dst = dst, src + ld_src, ld_dst = ld_dst, ld_src + + need_define = True + if self.inner_is_tiling_idx ^ is_store: + M, N = self.inner_num_elems, self.outer_num_elems + else: + M, N = ( + self.outer_num_elems, + self.inner_num_elems, + ) + if (isinstance(M, sympy.Expr) and not M.is_number) or ( + isinstance(N, sympy.Expr) and not N.is_number + ): + load_or_store = ( + f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]}>" + f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});" + ) + else: + load_or_store = ( + f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)}>" + f"({src}, {ld_src}, {dst}, {ld_dst});" + ) + if is_store: + tile_var = self.cse.newvar() + elif load_or_store not in self.cse.cache: + tile_var = self.cse.generate(self.preloads, load_or_store, write=False) + else: + need_define = False + tile_var = self.cse.cache[load_or_store] + + if need_define: + define_line = f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}];" + self.preloads.writeline(define_line) + + load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) + if is_store: + self.poststores.writeline(DeferredLine(name, load_or_store)) + else: + self.preloads.writeline(load_or_store) + + return tile_var + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + + inner = self.inner_itervar() + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=False + ) + # vector load inside the kernel inner loop + loadbuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + dtype = V.graph.get_dtype(name) + line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (self, name, index), {}) + assert isinstance(csevar, 
CppCSEVariable) + csevar.is_vec = True + return csevar + else: + new_index = self.transform_indexing(index) + return super().load(name, new_index) + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + + inner = self.inner_itervar() + index = self.rename_indexing(index) + assert mode is None + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=True + ) + # vector store inside the kernel inner loop + storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [ + torch.uint8, + torch.int8, + ]: + line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});" + else: + line = f"{value}.store({storebuf});" + self.stores.writeline(DeferredLine(name, line)) + else: + new_index = self.transform_indexing(index) + super().store(name, new_index, value, mode) + + def codegen_inner_loops(self, code): + inner = self.inner_itervar() + if self.inner_is_tiling_idx: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)" + ) + else: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)" + ) + + def set_ranges(self, group, reduction_group): + vars = super().set_ranges(group, reduction_group) + # do vertical reduction as the tail loop + self.outer_idx, self.tiling_idx = ( + self.tiling_indices + if self.tiling_indices[1] < self.reduction_depth + else reversed(self.tiling_indices) + ) + if self.tiling_idx == self.tiling_indices[0]: + self.tail_size = self.outer_tail_size + self.num_elems = self.outer_num_elems + self.inner_is_tiling_idx = False + else: + self.tail_size = self.inner_tail_size + self.num_elems = self.inner_num_elems + self.inner_is_tiling_idx = True + return vars + + def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: + return self.scale_index_with_offset( + index, + itervar_idx=self.outer_idx, + offset=self.inner_itervar(), + ) + + +def get_loop_body_lowp_fp(_body: LoopBody) -> Tuple[Optional[torch.dtype], bool]: + """ + Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes + and if all the nodes can codegen with this data type without converting to float. + Otherwise returns None and True. + """ + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + + _lowp_fp_type: Optional[torch.dtype] = None + _use_fp32 = False + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.op == "placeholder" or _node.target in ( + "get_index", + "index_expr", + ): + continue + + # Fast path if all operations can support bf16/fp16 without converting to fp32 + if _node.target not in [ + "load", + "store", + "abs", + "neg", + "output", + ]: + _use_fp32 = True + + if hasattr(_node, "meta") and _node.meta: + assert OptimizationContext.key in _node.meta + opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] + if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: + _use_fp32 = True + elif _lowp_fp_type is not None: + if _lowp_fp_type != opt_ctx.dtype: + warnings.warn("bf16 and fp16 are mixed in the scheduler node.") + else: + _lowp_fp_type = opt_ctx.dtype + else: + _use_fp32 = True + + return _lowp_fp_type, _use_fp32 + + +class TilingSelect: + """ + Implement the heuristic to select the tiling factors and tiling indices. + In the future, we can implement advanced heuristic in a subclass. 
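+ For example (illustrative): for a float32 kernel on AVX2 (8 float lanes per
+ vector register), a kernel whose inner-most itervar is contiguous would
+ typically get tiling_factors == [8] and tiling_indices == [len(itervars) - 1].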
+ """ + + def __init__(self): + super().__init__() + + def select_tiling( + self, + fn_list, + var_sizes_list, + ) -> Tuple[List[int], List[int]]: + # TODO(jgong5): support alternative tiling factors and data types + loop_bodies = _get_loop_body(fn_list) + all_dtypes = _get_dtype_from_loopbodies(loop_bodies) + assert all_dtypes + if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes): + return [], [] + dtype = torch.float + _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0] + if _lowp_fp_dtype and all( + (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype) + for loop_body in loop_bodies[1:] + ): + dtype = _lowp_fp_dtype + + tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + tiling_indices = self._select_tiling_indices( + fn_list, var_sizes_list, tiling_factor + ) + + if tiling_indices: + group, reduction_group = max( + var_sizes_list, key=lambda sizes: len(sizes[1]) + ) + call_ranges = tuple(group) + tuple(reduction_group) + + if config.cpp.enable_tiling_heuristics: + + def _try_get_stride( + index, + itervars, + tiling_factor, + tiling_indices, + ): + itervar = itervars[tiling_indices[0]] + stride = stride_at_vec_range(index, itervar, tiling_factor) + return stride if stride.is_number else None + + def _update_negative_op_count( + node_name, non_contig_indexing_op_counter + ): + if node_name not in non_contig_indexing_op_counter: + non_contig_indexing_op_counter[node_name] = 1 + else: + non_contig_indexing_op_counter[node_name] += 1 + + def _is_valid_indices( + itervars, + tiling_indices, + ): + return ( + len(tiling_indices) == 1 + and len(itervars) > 0 + and ( + tiling_indices[0] + if tiling_indices[0] >= 0 + else tiling_indices[0] + len(itervars) + ) + < len(itervars) + ) + + itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(call_ranges)) + ] + reduction_depth = len(group) + vars, reduction_vars = ( + itervars[:reduction_depth], + itervars[reduction_depth:], + ) + op_counter: Dict[str, int] = {} + # ops may cause overhead with vectorization, like non-contiguous + # index_expr, load, store + non_contig_indexing_op_counter: Dict[str, int] = {} + for _body in loop_bodies: + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.target in ["index_expr", "load", "store"]: + # get the index and replace prefix from z to x + arg_idx = 1 if _node.target == "index_expr" else 2 + index = sub_block.body.indexing_from_args( + (vars, reduction_vars) + )[_node.args[arg_idx].args[0]] + if _is_valid_indices(itervars, tiling_indices): + stride = _try_get_stride( + index, itervars, tiling_factor, tiling_indices + ) + if ( + stride is None + if _node.target == "index_expr" + else stride not in [0, 1] + ): + _update_negative_op_count( + _node.target, non_contig_indexing_op_counter + ) + if isinstance(_node.target, str) and not ( + _node.target.startswith("masked_subblock") + or _node.target + in ["ops", "output", "constant", "get_index"] + ): + if _node.target not in op_counter: + op_counter[_node.target] = 1 + else: + op_counter[_node.target] += 1 + + op_num = sum(op_counter.values()) + non_contig_indexing_op_num = sum( + non_contig_indexing_op_counter.values() + ) + threshold = 0.08 + if op_num > 0 and non_contig_indexing_op_num / op_num >= threshold: + # Too many non-contiguous load/store/index_expr which hurts the + # vectorization performance. Disable vectorization when exceeding + # the threshold. 
+ return [], [] + + if ( + not reduction_group + and group + and len(tiling_indices) == 1 + and not has_free_symbols( + [ + group[tiling_indices[0]], + ] + ) + and group[tiling_indices[0]] < tiling_factor / 2 + ): + # For case of Multi Thread AMP Static shape of pyhpc_isoneutral_mixing, + # the inner loop range doesn't have enough elements to do vectorization + # explicitly and found that `#pragma GCC ivdep` has better performance than + # `#pragma omp simd simdlen(8)`. Disable vectorization for this case. + # Leslie: maybe we can always disable vectorization when loop range is less + # than tiling factor and enable `#pragma omp simd simdlen(8)` for scalar kernel + # when needed. + return [], [] + + if dtype in DTYPE_LOWP_FP: + # For lower precision data type, if the call_range is not long enough, + # use tiling_factor // 2 for better performance + factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + for tiling_indice in tiling_indices: + if tiling_indice < 0: + tiling_indice = tiling_indice + len(call_ranges) + if tiling_indice < 0 or tiling_indice >= len(call_ranges): + continue + if has_free_symbols(call_ranges): + call_range = V.graph.sizevars.size_hint( + call_ranges[tiling_indice], fallback=0 + ) + if call_range < factor_lowp: + V.graph.sizevars.guard_lt(call_range, factor_lowp) + tiling_factor = factor_lowp // 2 + break + elif call_ranges[tiling_indice] < factor_lowp: + tiling_factor = factor_lowp // 2 + break + + if len(tiling_indices) == 1: + return [tiling_factor], tiling_indices + if len(tiling_indices) == 2: + return [tiling_factor, tiling_factor], tiling_indices + return [], [] + + def _select_tiling_indices( + self, + fn_list, + var_sizes_list, + tiling_factor, + ): + all_index = [] + for fn, var_sizes in zip(fn_list, var_sizes_list): + rw = dependencies.extract_read_writes(fn, *var_sizes) + all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] + contig_vars = set() + contig_vars_list = [] + non_contig_stride_const = set() + non_contig_stride_other = set() + for index in all_index: + for var in index.free_symbols: + if not re.search(r"^d\d+$", var.name): + continue + stride = stride_at_vec_range(index, var, tiling_factor) + if stride == 0: + continue + elif stride == 1: + contig_vars.add(int(var.name[1:])) + contig_vars_list.append(int(var.name[1:])) + elif all(symbol_is_type(s, SymT.SIZE) for s in stride.free_symbols): + non_contig_stride_const.add(int(var.name[1:])) + else: + non_contig_stride_other.add(int(var.name[1:])) + contig_only = contig_vars - non_contig_stride_const - non_contig_stride_other + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + num_itervars = len(group) + len(reduction_group) + if len(contig_vars) == 0: + # no contiguous vars + return [num_itervars - 1] + if contig_only: + return sorted(contig_only)[-1:] + contig_and_const_stride = ( + contig_vars & non_contig_stride_const + ) - non_contig_stride_other + contig_vars_sorted = sorted(contig_vars) + if ( + len(contig_vars_sorted) == 2 + and contig_vars_sorted[-1] in contig_and_const_stride + and contig_vars_sorted[-1] == num_itervars - 1 + ): + return contig_vars_sorted + return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:] + + +class CppKernelProxy(CppKernel): + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.kernel_group = kernel_group + self.loop_nest = None + self.call_ranges = None + self.picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + + def 
data_type_propagation(self, nodes): + for _node in nodes: + assert isinstance(_node, SchedulerNode) + DataTypePropagation.propagate_scheduler_node(_node) + + # Check if all the nodes of a given fx graph can support BF16/FP16 + def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): + if not isinstance(scheduler_node._body, LoopBody): + return True + # Propagate the dtype to check if all the fx node is bf16/fp16 + DataTypePropagation.propagate_scheduler_node(scheduler_node) + return ( + get_loop_body_lowp_fp(scheduler_node._body)[0] is not None + and not get_loop_body_lowp_fp(scheduler_node._body)[1] + ) + + def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody): + def add_to_dtype(sub_graph: torch.fx.Graph): + def is_lowp_fp_load(node: torch.fx.Node): + if node.target not in ["load"]: + return False + assert len(node.args) == 3 + load_dtype = V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + return load_dtype in DTYPE_LOWP_FP + + def is_lowp_fp_store(node: torch.fx.Node): + if node.target != "store": + return False + _, store_var, _, _, _ = node.args + store_dtype = V.graph.get_dtype(store_var) # type: ignore[arg-type] + return store_dtype in DTYPE_LOWP_FP + + sub_graph_nodes = list(sub_graph.nodes) + to_lowp_fp_legalized_nodes = [] + for _node in sub_graph_nodes: + if is_lowp_fp_load(_node): + # No need to promote to float if all users are direct stores + if all(user.target == "store" for user in _node.users): + continue + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + to_type_node_args = to_type_node.args + _node.replace_all_uses_with(to_type_node) + to_type_node.args = to_type_node_args + metrics.cpp_to_dtype_count += 1 + elif is_lowp_fp_store(_node): + ops, name, _, value_var, _ = _node.args + # No need to promote to float if it is a user of a load which are all directly stored + if value_var.target == "load" and all( + user.target == "store" for user in value_var.users + ): + continue + dtype = V.graph.get_dtype(name) + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, dtype) + ) + _node.replace_input_with(value_var, to_type_node) + metrics.cpp_to_dtype_count += 1 + elif _node.target == "reduction": + ( + ops, + dtype, + src_dtype, + reduction_type, + value, + ) = _node.args + if src_dtype in DTYPE_LOWP_FP: + # Since we always convert the load/store value to float if the tensor is bfloat16/float16. + # Therefore, the reduction should never work with bfloat16/float16 value. Hence, we update + # the bfloat16/float16 reduction by + # 1) updating the src_dtype to float + # and 2) updating the dtype to float if it is bfloat16/float16. + assert dtype in [ + torch.float, + torch.bfloat16, + torch.float16, + torch.int64, + ] + _node.args = ( + ops, + torch.float if dtype in DTYPE_LOWP_FP else dtype, + torch.float, + reduction_type, + value, + ) + elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP: + (ops, x, _) = _node.args + # The legalization always loads the BF16/FP16 tensor as FP32 for computation + # and converts back to BF16/FP16 after the computation. + # Hence, there should be no computation w/ BF16/FP16. + # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32. 
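+ # For example (illustrative graph fragment), a node
+ #   %t = call_method[target=to_dtype](args = (%ops, %x, torch.bfloat16))
+ # has its dtype argument rewritten in place to torch.float below.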
+ # Save the legalized to_dtype node for elimination (the eliminate_to_dtype step):
+ # 1) Eliminate the redundant to_dtype node if we have a pattern as follows:
+ # graph():
+ # %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float))
+ # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16))
+ # Regarding the first to_dtype, it is redundant because
+ # the second to_dtype also converts to torch.bfloat16/torch.float16.
+ # Hence, we remove the first to_dtype.
+ to_lowp_fp_legalized_nodes.append(_node)
+ _node.args = (ops, x, torch.float)
+ else:
+ pass
+
+ def eliminate_to_dtype(sub_graph: torch.fx.Graph):
+ def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph):
+ # Eliminate the redundant to_dtype node. Let's consider a pattern as follows:
+ # graph():
+ # %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {})
+ # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {})
+ # Regarding the first to_dtype, it is redundant because the second to_dtype also converts to
+ # torch.float. Hence, we remove the first to_dtype.
+ def _used_by_to(to_node: torch.fx.Node):
+ return all(usr.target == "to_dtype" for usr in to_node.users)
+
+ all_to_nodes = [
+ node for node in sub_graph.nodes if node.target == "to_dtype"
+ ]
+ all_to_nodes_and_users = [
+ {node: node.users} for node in all_to_nodes if _used_by_to(node)
+ ]
+ for node_users in all_to_nodes_and_users:
+ for node, users in node_users.items():
+ if node in sub_graph.nodes and (
+ all(usr.args[-1] == node.args[-1] for usr in users)
+ or (
+ node in to_lowp_fp_legalized_nodes
+ and all(
+ usr.args[-1] in DTYPE_LOWP_FP for usr in users
+ )
+ )
+ ):
+ val_node = node.all_input_nodes[-1]
+ node.replace_all_uses_with(val_node)
+ sub_graph.erase_node(node)
+
+ # For debug mode, the graph of LoopBody will attach a new GraphModule as
+ # owning_module for debugging, while the release mode will not. The lint will
+ # check whether the graph has an owning_module to decide if it needs to check
+ # call_module. LoopBody might contain get_index as a module call, but it
+ # is just a function, so it cannot pass the lint check in debug mode.
+ # We bypass the check if the owning_module is None. Eventually, we should call
+ # get_index via call_function rather than call_module.
+ if sub_graph.owning_module is None: + sub_graph.lint() + + _eliminate_duplicate_to_node(sub_graph) + + eliminate_to_dtype(sub_graph) + + sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values()) + for sub_block in sub_blocks: + add_to_dtype(sub_block.graph) + + def legalize_lowp_fp_dtype(self, nodes): + if all( + isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node) + for _node in nodes + ): + # Mark the load node to load bf16/fp16 + for _node in nodes: + sub_blocks = [_node._body.root_block] + list( + _node._body.subblocks.values() + ) + for sub_block in sub_blocks: + for fx_node in sub_block.graph.nodes: + if fx_node.target in ["load", "store"]: + assert fx_node.meta + assert OptimizationContext.key in fx_node.meta + opt_ctx: OptimizationContext = fx_node.meta[ + OptimizationContext.key + ] + assert opt_ctx.dtype in DTYPE_LOWP_FP + + # Bypass the legalization as the kernel can run with bf16/fp16 directly + return + + for _node in nodes: + assert isinstance(_node, SchedulerNode) + assert isinstance(_node._body, LoopBody) + body: LoopBody = _node._body + if not body.is_memory_copy(): + self.legalize_lowp_fp_dtype_loopbody(body) + + def codegen_functions(self, fn_list, var_sizes_list): + assert len(fn_list) == len(var_sizes_list) + kernel_group = self.kernel_group + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + + self.set_ranges(group, reduction_group) + + def codegen_kernel(cls, *args): + with kernel_group.new_kernel(cls, *args) as kernel: + # Ugly hack to maintain the metrics kernel count since + # we only count in CppKernelProxy, not those contained in it + metrics.generated_kernel_count -= 1 + + run(kernel) + return kernel + + def run(kernel): + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + in_suffix = False + for fn, var_sizes in zip(fn_list, var_sizes_list): + if var_sizes in [ + (group, reduction_group), + (tuple(itertools.chain(group, reduction_group)), ()), + ]: + assert not in_suffix + fn(vars, reduction_vars) + else: + in_suffix = True + assert var_sizes == ( + group, + (), + ), f"unexpected group: {var_sizes} != {group}, {reduction_group}" + # we can fuse in some extra pointwise into the suffix + with kernel.write_to_suffix(): + fn(vars, ()) + + scalar_kernel = codegen_kernel(CppKernel) + V.graph.removed_buffers |= scalar_kernel.removed_buffers + V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove + self.loop_nest = LoopNestWithSplit.build(scalar_kernel) + + if not self.picked_vec_isa: + return + + if not self.itervars: + # not a loop + return + + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + # But the generated scalar kernel has updated these global contexts. Hence, the other kernels + # should not do this again to avoid context conflict. By now, we only control the + # config.inplace_buffers. In the future, we could maintain more contexts. + with torch._inductor.config.patch(inplace_buffers=False): + tiling_select = TilingSelect() + tiling_factors, tiling_indices = tiling_select.select_tiling( + fn_list, var_sizes_list + ) + assert len(tiling_factors) == len(tiling_indices) + # This should be removed after full support for vectorization is implemented. 
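+ # Illustrative example (made-up sizes): with call ranges (64, 100) and
+ # tiling_factors == [16] on the last itervar, split_with_tiling below produces
+ # a main loop over [0, 96) stepping by 16 plus a tail loop over [96, 100);
+ # when masked vectorization applies, the tail gets a masked CppVecKernel,
+ # otherwise it falls back to the scalar kernel.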
+ could_masked_vec = True + all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list)) + if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes): + # can be removed after masked vectorizable dtype are same with vectorizable dtype + could_masked_vec = False + + if len(tiling_indices) == 1: + vec_kernel = codegen_kernel( + CppVecKernel, tiling_factors[0], tiling_indices[0] + ) + metrics.generated_cpp_vec_kernel_count += 1 + main_loop, tail_loop = self.loop_nest.split_with_tiling( + tiling_indices[0], factor=tiling_factors[0] + ) + main_loop.set_kernel(vec_kernel) + main_loop.simd_vec = True + if config.cpp.enable_loop_tail_vec and could_masked_vec: + tail_loop.steps = tail_loop.size - tail_loop.offset + masked_vec_kernel = codegen_kernel( + CppVecKernel, + tiling_factors[0], + tiling_indices[0], + tail_loop.steps, + ) + tail_loop.set_kernel(masked_vec_kernel) + tail_loop.simd_vec = True + else: + tail_loop.set_kernel(scalar_kernel) + tail_loop.simd_omp = True + # We chop the loop into two cubes by the nelements - main loop and tail loop. + # Regarding the main loop, it is straightforward that it could be vectorized with + # nelements. But for the tail loop, it still could be vectorized. For example, + # if the nelements is 8(256bits), then the tail loop still could be vectorized + # as 4(128bits). + tail_loop.simd_nelements = tiling_factors[0] // 2 + elif len(tiling_indices) == 2: + assert ( + tiling_indices[1] == len(self.itervars) - 1 + and tiling_factors[0] == tiling_factors[1] + ) + + metrics.generated_cpp_vec_kernel_count += 2 + outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling( + tiling_indices[0], factor=tiling_factors[0] + ) + ( + inner_main_loop, + inner_tail_loop, + ) = outer_main_loop.split_with_tiling( + tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0] + ) + tile2d_kernel = codegen_kernel( + CppTile2DKernel, tiling_factors[0], tiling_indices + ) + inner_main_loop.set_kernel(tile2d_kernel) + + if config.cpp.enable_loop_tail_vec and could_masked_vec: + ( + inner_main_loop_of_outer_tail_loop, + inner_tail_loop_of_outer_tail_loop, + ) = outer_tail_loop.split_with_tiling( + tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0] + ) + + for tail_loop in ( + inner_tail_loop, + outer_tail_loop, + inner_tail_loop_of_outer_tail_loop, + ): + tail_loop.steps = tail_loop.size - tail_loop.offset + + for tail_loop, inner_tail_size, outer_tail_size in ( + (inner_tail_loop, inner_tail_loop.steps, None), + ( + inner_main_loop_of_outer_tail_loop, + None, + outer_tail_loop.steps, + ), + ( + inner_tail_loop_of_outer_tail_loop, + inner_tail_loop_of_outer_tail_loop.steps, + outer_tail_loop.steps, + ), + ): + masked_tile2d_kernel = codegen_kernel( + CppTile2DKernel, + tiling_factors[0], + tiling_indices, + inner_tail_size, + outer_tail_size, + ) + tail_loop.set_kernel(masked_tile2d_kernel) + else: + vec_kernel = codegen_kernel( + CppVecKernel, tiling_factors[0], tiling_indices[0] + ) + inner_tail_loop.set_kernel(vec_kernel) + + outer_tail_loop.set_kernel(scalar_kernel) + + def codegen_loop_bodies(self, loop_bodies, var_sizes_list): + for body in loop_bodies: + self.legalize_lowp_fp_dtype_loopbody(body) + DataTypePropagation.propagate_loopbody(body) + self.codegen_functions(loop_bodies, var_sizes_list) + + def codegen_nodes(self, nodes: List[SchedulerNode]): + # Legalize BF16 node by adding to_dtype explicitly + self.legalize_lowp_fp_dtype(nodes) + self.data_type_propagation(nodes) + assert len(nodes) >= 1 + + def fn(node, *index_vars): + 
node.decide_inplace_update() + node.mark_run() + if isinstance(V.kernel, NullKernelHandler): + return node._body(*index_vars) + else: + return node.codegen(index_vars) + + fn_list = [functools.partial(fn, node) for node in nodes] + + if ( + isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + + def wrap_fn(fn): + wrapped_fn = V.local_buffer_context.localize_function( + fn, + ) + wrapped_fn.original_fn = fn + return wrapped_fn + + fn_list = [wrap_fn(fn) for fn in fn_list] + + var_sizes_list = [node.group[1] for node in nodes] + self.codegen_functions(fn_list, var_sizes_list) + + def codegen_loops(self, code, worksharing): + self.codegen_loops_impl(self.loop_nest, code, worksharing) + + +class OuterLoopFusedKernel(CppKernel): + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.inner: List[LoopLevel] = [] + + def decide_parallel_depth(self, max_parallel_depth, threads) -> int: + kernels_parallel_depth = [] + nested_kernels: List[List[CppKernel]] = [ + loop.get_kernels() for loop in self.inner + ] + for kernels in nested_kernels: + # For any ScalarKernel, VecKernel, or Tile2DKernel, + # they should all have the same call_ranges + call_ranges = kernels[0].call_ranges + assert call_ranges is not None + assert all(kernel.call_ranges == call_ranges for kernel in kernels) + kernels_parallel_depth.append( + kernels[0].decide_parallel_depth(len(call_ranges), threads) + ) + return min( + max_parallel_depth, + max(kernels_parallel_depth), + ) + + +class ReasonFusedNodes(Enum): + SAME_VARS_REDUCE = "same_vars_reduce" + COMPATIBLE_REDUCTION = "compatible_reduction" + COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction" + + +class CppScheduling(BaseScheduling): + # ctypes limits the number of args to 1024, refer to: + # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237 + # We set a conservative threshold here. 
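+ # E.g. once a kernel group has accumulated more than this many argument
+ # definitions, codegen_node() sets the ready-to-flush flag so the group is
+ # emitted before more nodes are fused in (see _get_scheduled_num_args and
+ # ready_to_flush).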
+ MAX_FUSED_KERNEL_ARGS_NUM = 500 + backend_features = dict.fromkeys( + [ + BackendFeature.INPLACE_BUFFERS, + BackendFeature.REDUCE_TO_SINGLE_ELEMENT, + ] + ) + + @classmethod + def get_backend_features(cls, device: torch.device): + return cls.backend_features + + def __init__(self, scheduler): + super().__init__() + self.scheduler = scheduler + if scheduler: + self.reset_kernel_group() + self._ready_to_flush = False + + def _set_flush_status(self, status: bool): + self._ready_to_flush = status + + def group_fn(self, sizes): + return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + + def reset_kernel_group(self): + from .cpp_wrapper_cpu import CppWrapperCpu + + self.kernel_group: Union[CppWrapperKernelGroup, KernelGroup] + if isinstance(V.graph.wrapper_code, CppWrapperCpu): + self.kernel_group = CppWrapperKernelGroup() + else: + self.kernel_group = KernelGroup() + + def fuse(self, node1, node2): + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + elif node1.is_template(): + assert not node2.is_template() + return FusedSchedulerNode.fuse(node1, node2) + else: + if ( + self._why_fuse_nodes(node1, node2) + == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + ): + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + assert reduce1 == () and reduce2 == (), (reduce1, reduce2) + + def get_indexing_ranges_exprs(node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0, node.snodes + var_ranges = None + indexing_exprs = set() + for snode in node.snodes: + v, exprs = get_indexing_ranges_exprs(snode) + if var_ranges is None: + var_ranges = v + assert var_ranges == v, (var_ranges, v, node.snodes) + indexing_exprs.update(exprs) + return var_ranges, list(indexing_exprs) + else: + assert isinstance(node, SchedulerNode) + comp_buffer = node.node + assert isinstance(comp_buffer, ir.ComputedBuffer) + _, body, _ = comp_buffer.get_default_sizes_body() + return body.var_ranges, list(body.indexing_exprs.values()) + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + assert isinstance(node_to_recomp, SchedulerNode) + + ref_node = node2 if len(vars1) < len(vars2) else node1 + + extra_indexing_constraints = get_indexing_ranges_exprs(ref_node) + + node_to_recomp.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints + ) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + assert vars1 == vars2, (vars1, vars2) + return FusedSchedulerNode.fuse(node1, node2) + elif self.can_fuse_vertical_outer_loop(node1, node2): + return OuterLoopFusedSchedulerNode.fuse( + node1, node2, self._get_outer_loop_fusion_depth(node1, node2) + ) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def _why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]: + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + if vars1 == vars2 and reduce1 == reduce2: + return ReasonFusedNodes.SAME_VARS_REDUCE + if reduce1 == () and vars1 == vars2 + reduce2: + return ReasonFusedNodes.COMPATIBLE_REDUCTION + if self._can_fuse_nodes_with_compatible_ranges(node1, node2): + return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + # TODO(jansel): allow fusion pointwise (vars1, ()) suffix? 
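+ # Illustrative shapes for the three reasons above: vars1 == vars2 == (s0, s1)
+ # with reduce1 == reduce2 gives SAME_VARS_REDUCE; vars1 == (s0, s1),
+ # reduce1 == () vs. vars2 == (s0,), reduce2 == (s1,) gives COMPATIBLE_REDUCTION;
+ # vars1 == (s0, s1, s2) vs. vars2 == (s0*s1*s2,) with no reductions may give
+ # COMPATIBLE_RANGES_NO_REDUCTION, subject to the checks in
+ # _can_fuse_nodes_with_compatible_ranges.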
+ return None + + def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): + # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges + # e.g. (s0, s1, s2) and (s0 * s1 * s2) + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + c1 = reduce1 == () and reduce2 == () + c2 = math.prod(vars1) == math.prod(vars2) + c3 = len(vars1) == 1 or len(vars2) == 1 + if not (c1 and c2 and c3): + return False + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + ref_node = node2 if len(vars1) < len(vars2) else node1 + + # We can not recompute sizes and body for nodes other than SchedulerNode + # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode + if isinstance(node_to_recomp, FusedSchedulerNode): + return False + + # It may happen that node1 and node2 compatible number of elements + # but different original ranges, for example: + # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2} + # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details + # TODO: we can fix if it allows us to CSE at least one of the variables + + assert isinstance(node_to_recomp, SchedulerNode) + if isinstance(node_to_recomp.node, ir.TemplateBuffer): + return False + assert isinstance(node_to_recomp.node, ir.ComputedBuffer) + # node.data.get_size() is a cheaper version of node.get_read_writes().var_ranges + # but without variable name + ranges2 = node_to_recomp.node.data.get_size() + ranges1 = None + if isinstance(ref_node, FusedSchedulerNode): + ranges_set = set() + for snode in ref_node.snodes: + if isinstance(snode.node, ir.TemplateBuffer): + break + assert isinstance(snode.node, ir.ComputedBuffer) + ranges_set.add(tuple(snode.node.data.get_size())) + + if len(ranges_set) != 1: + return False + + ranges1 = list(next(iter(ranges_set))) + else: + assert isinstance(ref_node, SchedulerNode) + assert isinstance(ref_node.node, ir.ComputedBuffer) + ranges1 = ref_node.node.data.get_size() + + if ranges1 != ranges2: + return False + + return True + + def _can_fuse_horizontal_impl(self, node1, node2): + assert isinstance(node1, (FusedSchedulerNode, SchedulerNode)) + assert isinstance(node2, (FusedSchedulerNode, SchedulerNode)) + if any( + isinstance(node, OuterLoopFusedSchedulerNode) for node in (node1, node2) + ): + return False + return self._why_fuse_nodes(node1, node2) is not None + + def can_fuse_horizontal(self, node1, node2): + if node1.is_template() or node2.is_template(): + return False + if ( + len(node1.get_nodes()) + len(node2.get_nodes()) + > config.cpp.max_horizontal_fusion_size + ): + return False + + return self._can_fuse_horizontal_impl(node1, node2) + + def _get_outer_loop_fusion_depth(self, node1, node2): + DISABLE_OUTER_LOOP_FUSION = 0 + if not all( + type(node) + in (OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode) + for node in (node1, node2) + ): + return DISABLE_OUTER_LOOP_FUSION + + _node1 = ( + node1.get_outer_nodes()[-1] + if isinstance(node1, OuterLoopFusedSchedulerNode) + else node1 + ) + assert isinstance(_node1, (FusedSchedulerNode, SchedulerNode)) + _node2 = ( + node2.get_outer_nodes()[0] + if isinstance(node2, OuterLoopFusedSchedulerNode) + else node2 + ) + assert isinstance(_node2, (FusedSchedulerNode, SchedulerNode)) + + _, (vars1, reduce1) = _node1.group + _, (vars2, reduce2) = _node2.group + if vars1 == () and vars2 == () and reduce1 != () and reduce2 != (): + # Reduction only + return DISABLE_OUTER_LOOP_FUSION + if all(type(node) is OuterLoopFusedSchedulerNode for node in 
(node1, node2)):
+ return (
+ node1.outer_loop_fusion_depth
+ if node1.outer_loop_fusion_depth == node2.outer_loop_fusion_depth
+ else DISABLE_OUTER_LOOP_FUSION
+ )
+ outer_loop_fusion_depth = min(len(vars1), len(vars2))
+ if (
+ outer_loop_fusion_depth >= 1
+ and vars1[:outer_loop_fusion_depth] == vars2[:outer_loop_fusion_depth]
+ ):
+ if any(
+ type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)
+ ):
+ _compare_node = (
+ node1 if type(node1) is OuterLoopFusedSchedulerNode else node2
+ )
+ if _compare_node.outer_loop_fusion_depth == outer_loop_fusion_depth:
+ # Same outer loop fusion depth as prev nodes in OuterLoopFusedSchedulerNode
+ return outer_loop_fusion_depth
+ else:
+ return DISABLE_OUTER_LOOP_FUSION
+ else:
+ # First 2 nodes to generate OuterLoopFusedSchedulerNode
+ return outer_loop_fusion_depth
+ return DISABLE_OUTER_LOOP_FUSION
+
+ def can_fuse_vertical_outer_loop(self, node1, node2):
+ return (
+ not node1.is_template()
+ and not node2.is_template()
+ and node1.get_operation_names() & node2.ancestors
+ and not (
+ self._can_fuse_horizontal_impl(node1, node2)
+ and not node1.is_reduction()
+ )
+ and self._get_outer_loop_fusion_depth(node1, node2) >= 1
+ )
+
+ def get_fusion_pair_priority(self, node1, node2):
+ if self.can_fuse_vertical_outer_loop(node1, node2):
+ # Outer loop fusion with lower priority
+ return 1
+ else:
+ return 0
+
+ def can_fuse_vertical(self, node1, node2):
+ if node2.is_template():
+ # TODO(jgong5): support pre-op fusion with template
+ return False
+ if node1.is_template():
+ return not node2.is_reduction()
+ return (
+ self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
+ ) or self.can_fuse_vertical_outer_loop(node1, node2)
+
+ def try_loop_split(self, nodes: List[SchedulerNode]):
+ """
+ Apply loop split optimization.
+ When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop
+ to avoid non-contiguous loads, subject to the following conditions:
+ 1. No reduction and no modular indexing for any node.
+ 2. Exactly one indexing expression of one node contains a division; from that expression we
+ derive the dimension that needs to be split, and that dimension must be contiguous
+ in all other indexing expressions.
+
+ For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs:
+ {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2},
+ we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to
+ {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to
+ {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}.
+ """
+
+ # No reduction and no modular indexing
+ if any(
+ len(node.group[1][1]) != 0
+ or any(
+ expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values()
+ )
+ for node in nodes
+ ):
+ return nodes
+
+ split_var = None
+ split_number = None
+ divide_index_name = None
+ num_div = 0
+ match_div = False
+ matched_node = None
+
+ for node in nodes:
+ assert isinstance(node.node, ir.ComputedBuffer)
+ _, original_body, _ = node.node.get_default_sizes_body()
+ for name, expr in original_body.indexing_exprs.items():
+ num_div += expr.count(FloorDiv)
+ if num_div > 1:
+ return nodes
+ if expr.count(FloorDiv) == 1:
+ div_expr = expr.find(FloorDiv).pop()
+ split_var = div_expr.args[0]
+ split_number = div_expr.args[1]
+ divide_index_name = name
+ if (
+ isinstance(split_number, sympy.core.numbers.Integer)
+ and isinstance(split_var, sympy.core.symbol.Symbol)
+ and split_var in original_body.iter_vars
+ and divide_index_name is not None
+ and all(
+ stride_at_vec_range(expr, split_var) == 1
+ for name, expr in original_body.indexing_exprs.items()
+ if name != divide_index_name
+ )
+ ):
+ match_div = True
+ matched_node = node
+
+ # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs.
+ if not match_div:
+ return nodes
+
+ extra_indexing_constraints = None
+
+ def loop_split(sizes, body, vars):
+ index_size, reduce_size = sizes
+ index_vars, reduce_vars = vars
+ split_idx = index_vars.index(split_var)
+ new_index_size = index_size.copy()
+ new_index_size[split_idx] = index_size[split_idx] // split_number
+ new_index_size.insert(split_idx + 1, split_number)
+ (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze(
+ new_index_size, reduce_size, prefix="y"
+ )
+ iter_vars = new_index_vars.copy()
+ divisor_var = iter_vars.pop(split_idx + 1)
+ iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var
+ body = ir.LoopBody(
+ body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars
+ )
+ nonlocal extra_indexing_constraints
+ if not extra_indexing_constraints:
+ extra_indexing_constraints = (
+ body.var_ranges,
+ list(body.indexing_exprs.values()),
+ )
+ return (
+ (new_index_size, reduce_size),
+ body,
+ (new_index_vars, reduce_vars),
+ )
+
+ # Here we decide the final loop order
+ for node in nodes:
+ if node == matched_node:
+ node.recompute_size_and_body(recompute_sizes_body_func=loop_split)
+ for node in nodes:
+ if node != matched_node:
+ node.recompute_size_and_body(
+ extra_indexing_constraints=extra_indexing_constraints,
+ recompute_sizes_body_func=loop_split,
+ )
+
+ return nodes
+
+ def codegen_outer_loop_node(
+ self,
+ node: OuterLoopFusedSchedulerNode,
+ ):
+ """
+ Generate the code for the outer loop fused scheduler node.
+ 1. Codegen with fused outer loop: depends on the analysis of
+ the outer loop fused scheduler node, with or without the local buffer.
+ 2. If that fails, fall back to standard codegen.
+ """
+ kernel_group = self.kernel_group
+ generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count
+ cpp_kernel_proxy_list: List[CppKernelProxy] = []
+ nodes_list: List[List[SchedulerNode]] = []
+ assert isinstance(node, OuterLoopFusedSchedulerNode)
+
+ def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode):
+ """
+ Codegen code with fused outer loop and local Buffer.
+ """ + assert isinstance(node, OuterLoopFusedSchedulerNode) + cpp_kernel_proxy_list.clear() + nodes_list.clear() + + def get_call_ranges(node: BaseSchedulerNode): + assert isinstance(node, (SchedulerNode, FusedSchedulerNode)) + nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + return call_ranges + + local_buffers: List[ir.Buffer] = [] + # Map local buffer name to a list of global buffers + local_to_global_buffers: Dict[str, List[ir.Buffer]] = {} + if all( + len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1 + for _node in node.get_outer_nodes() + ): + # Ref to the typical case of local buffer + # in https://github.com/pytorch/pytorch/blob/ + # 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 + # where the buffer is with size of last dim and contiguous. + # Only support this typical case at first. + visited_scheduler_nodes: Set[str] = set() + for scheduler_node in node.get_nodes(): + # all users inside same OuterLoopFusedSchedulerNode + assert isinstance(scheduler_node, SchedulerNode) + visited_scheduler_nodes.add(scheduler_node.get_name()) + if ( + scheduler_node.is_reduction() + or len(scheduler_node.get_outputs()) != 1 + ): + continue + + scheduler_buffer = scheduler_node.get_outputs()[0] + if all( + user.node in node.get_nodes() for user in scheduler_buffer.users + ): + global_buffer = scheduler_buffer.node + assert isinstance(global_buffer, ir.ComputedBuffer) + global_buffer_layout = global_buffer.get_layout() + size_offset = node.outer_loop_fusion_depth - len( + get_call_ranges(scheduler_node) + ) + + def is_all_write_read_contiguous(): + contiguous_index_expr = 0 + stride = 1 + for var, range in reversed( + scheduler_node._body.var_ranges.items() + ): + contiguous_index_expr += stride * var + stride *= range + write_index_expr = scheduler_node._body.get_write_expr( + scheduler_buffer.get_name() + ) + + def is_contiguous_index(x): + return x == contiguous_index_expr + + return is_contiguous_index(write_index_expr) and all( + isinstance(user.node, SchedulerNode) + and is_contiguous_index( + user.node._body.get_read_expr( + scheduler_buffer.get_name() + ), + ) + for user in scheduler_buffer.users + ) + + if not ( + global_buffer_layout.is_contiguous() + and is_all_write_read_contiguous() + ): + continue + # Local Buffer is a view of global buffer + local_buffer_layout = ir.FixedLayout( + global_buffer_layout.device, + global_buffer_layout.dtype, + global_buffer_layout.size[size_offset:], + global_buffer_layout.stride[size_offset:], + ) + + def try_share_local_buffer(local_buffer_layout, local_buffers): + for local_buf in local_buffers: + if local_buffer_layout == local_buf.layout and all( + all( + user.node.get_name() in visited_scheduler_nodes + for user in V.graph.scheduler.name_to_buf[ + global_buffer.name + ].users + ) + for global_buffer in local_to_global_buffers[ + local_buf.name + ] + if global_buffer.name is not None + ): + return local_buf + return None + + local_buf_prefix = "local_buffer_data" + # Share existing local buffer + local_buffer_used = try_share_local_buffer( + local_buffer_layout, local_buffers + ) + if not local_buffer_used: + # Create new local buffer + local_buffer_used = ir.Buffer( + f"{local_buf_prefix}_{len(local_buffers)}", + local_buffer_layout, + ) + local_buffers.append(local_buffer_used) + 
local_to_global_buffers[local_buffer_used.name] = []
+ local_to_global_buffers[local_buffer_used.name].append(
+ global_buffer,
+ )
+
+ with LocalBufferContext(kernel_group.args) as scope:
+ if len(local_buffers) > 0:
+ for local_buffer in local_buffers:
+ assert local_buffer.name is not None
+ scope.add_local_buffer(
+ local_buffer, local_to_global_buffers[local_buffer.name]
+ )
+ for _node in node.get_outer_nodes():
+ assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
+ cpp_kernel_proxy = CppKernelProxy(kernel_group)
+ cpp_kernel_proxy.codegen_nodes(_node.get_nodes())  # type: ignore[arg-type]
+ cpp_kernel_proxy_list.append(cpp_kernel_proxy)
+ nodes_list.append(_node.get_nodes())  # type: ignore[arg-type]
+
+ if not node.check_outer_fusion_loop_level_attr(
+ cpp_kernel_proxy_list, node.outer_loop_fusion_depth
+ ):
+ return False
+ metrics.cpp_outer_loop_fused_inner_counts.append(
+ metrics.CppOuterLoopFusedCount(
+ len(cpp_kernel_proxy_list),
+ local_buffer_number=len(scope.local_buffers),
+ )
+ )
+ outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
+ cpp_kernel_proxy_list,
+ )
+ kernel_group.finalize_kernel(
+ outer_fusion_cpp_kernel_proxy,
+ [_node for _nodes in nodes_list for _node in _nodes],
+ )
+
+ return True
+
+ if not try_outer_loop_fusion_with_local_buf(node):
+ # Reset generated_cpp_vec_kernel_count to codegen again
+ metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count
+ cpp_kernel_proxy_list.clear()
+ nodes_list.clear()
+ # Similar to the comment in
+ # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272
+ # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args.
+ with torch._inductor.config.patch(inplace_buffers=False):
+ for _node in node.get_outer_nodes():
+ assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
+ _nodes: List[SchedulerNode] = _node.get_nodes()  # type: ignore[assignment]
+ cpp_kernel_proxy = CppKernelProxy(kernel_group)
+ cpp_kernel_proxy.codegen_nodes(_nodes)
+ kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes)
+
+ def codegen_node(
+ self,
+ node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode],
+ ):
+ """
+ Turn a set of pre-fused nodes into a C++ kernel.
+ """ + kernel_group = self.kernel_group + + if isinstance(node, OuterLoopFusedSchedulerNode): + self.codegen_outer_loop_node(node) + else: + nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + nodes = self.try_loop_split(nodes) + cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy.codegen_nodes(nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) + + args_num = self._get_scheduled_num_args() + if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: + self._set_flush_status(True) + + def is_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ir.CppTemplateBuffer + ) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CPP template, possibly with fused epilogues + """ + counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cpp_template( + template_node + ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (_, rnumel) = template_node.group + assert rnumel == () + ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) + epilogue_ir_nodes: List[Optional[ir.Operation]] = [ + n.node for n in epilogue_nodes + ] + assert all( + isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes + ), "Epilogue nodes must all be instances of ir.ComputedBuffer" + + def template_buffer_has_other_users( + template_buffer, outputs_by_name, epilogue_nodes + ): + assert template_buffer.get_name() in outputs_by_name + users = outputs_by_name[template_buffer.get_name()].users + return not all( + isinstance(user.node, BaseSchedulerNode) + and user.node.node in epilogue_nodes + for user in users + ) + + flag_template_buffer_has_other_users = template_buffer_has_other_users( + ctb, template_node.outputs_by_name, epilogue_ir_nodes + ) + kernel, render = ctb.make_kernel_render( + ctb, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_ir_nodes, + ) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() # type: ignore[attr-defined] + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() + + def _get_scheduled_num_args(self): + return self.kernel_group.get_num_args() + + def ready_to_flush(self): + return self._ready_to_flush + + def codegen_sync(self): + pass + + def define_kernel(self, src_code, nodes, kernel_args=None): + wrapper = V.graph.wrapper_code + fused_name = ( + get_fused_kernel_name(nodes, config.cpp.descriptive_names) + if config.cpp.descriptive_names + else "" + ) + kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) + kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name) + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. 
+ src_code = src_code.replace("#pragma CMT", "//") + + compile_wrapper = IndentedBuffer() + args = self.kernel_group.args if kernel_args is None else kernel_args + _, _, arg_types = args.cpp_argdefs() + if not V.graph.cpp_wrapper: + compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") + compile_wrapper.splice(src_code, strip=True) + if not V.graph.cpp_wrapper: + compile_wrapper.writeline("''')") + wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), cuda=False) + return kernel_name + + def flush(self): + src_code = self.kernel_group.codegen_group() + if src_code: + kernel_name = self.define_kernel( + src_code, self.kernel_group.scheduled_nodes + ) + self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) + self.reset_kernel_group() + self._set_flush_status(False) + + +class KernelGroup: + def __init__(self): + super().__init__() + self.args = KernelArgs() + self.loops_code = BracesBuffer() + self.ws = WorkSharing(self.loops_code) + self.stack = contextlib.ExitStack() + self.stack.enter_context(self.ws) + self.scheduled_nodes = [] + + def new_kernel(self, cls, *args): + return cls(self.args, parallel_num_threads(), *args) + + def finalize_kernel(self, new_kernel, nodes): + self.scheduled_nodes += nodes + code = self.loops_code + ws = self.ws + new_kernel.codegen_loops(code, ws) + + def get_num_args(self): + arg_defs, call_args, arg_types = self.args.cpp_argdefs() + args_num = len(arg_defs) + return args_num + + def codegen_group(self, name=None) -> str: + self.stack.close() + if not self.scheduled_nodes: + return "" + code = BracesBuffer() + # 1. Include header files + # TODO: support kernel profile on other platforms + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if enable_kernel_profile: + code.writelines(["#include "]) + code.writeline(codecache.cpp_prefix()) + + # 2. Function definition + kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name + kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name + arg_defs, _, _ = self.args.cpp_argdefs() + arg_defs = ",\n".ljust(25).join(arg_defs) + func_export_decl = get_export_declaration() + code.writeline( + f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})' + ) + + # 3. 
Function body + with code.indent(): + if enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + code.writelines( + [ + f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef({{}}));' + ] + ) + for old, new in self.args.aliases(): + code.writeline(f"auto {old} = {new};") + code.splice(self.loops_code) + return code.getvalue() + + def call_kernel(self, wrapper, kernel_name): + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call( + kernel_name, call_args, cuda=False, arg_types=arg_types + ) + + +class CppWrapperKernelGroup(KernelGroup): + def __init__(self): + super().__init__() + self.args = CppWrapperKernelArgs() + + +class WorkSharing: + def __init__(self, code): + self.code = code + self.in_parallel = False + self.num_threads = None + self.stack = contextlib.ExitStack() + + def parallel(self, threads): + if self.in_parallel and threads != self.num_threads: + # wrong number of threads + self.close() + if not self.in_parallel: + self.num_threads = threads + self.in_parallel = True + if config.cpp.dynamic_threads: + self.code.writeline("#pragma omp parallel") + else: + self.code.writeline(f"#pragma omp parallel num_threads({threads})") + self.stack.enter_context(self.code.indent()) + self.code.writeline( + "int tid = omp_get_thread_num();", + ) + + def single(self): + if self.in_parallel: + self.code.writeline("#pragma omp single") + return self.in_parallel + + def close(self): + self.stack.close() + self.in_parallel = False + + def __enter__(self): + self.stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stack.__exit__(exc_type, exc_val, exc_tb) + + +@dataclasses.dataclass +class LoopLevel: + var: Optional[sympy.Expr] = None + size: Optional[sympy.Expr] = None + offset: sympy.Expr = sympy.Integer(0) + steps: sympy.Expr = sympy.Integer(1) + parallel: int = 0 + simd_omp: bool = False + simd_vec: bool = False + collapsed: bool = False + is_reduction: bool = False + parent: Optional["LoopLevel"] = None + # the next inner level of the loop, empty if it is inner-most + # contains >1 LoopLevel if the inner level of loop is split + inner: List["LoopLevel"] = dataclasses.field(default_factory=list) + # kernel assigned to this loop level, only valid when it is a leaf + kernel: Optional[CppKernel] = None + + def __post_init__(self): + # Regarding the C++/OpenMP backend, `cpu_vec_isa.pick_vec_isa()` to check + # vectorization ISA is a time-consuming and one-shot operation. It leads + # to taking a longer time to import `codegen.cpp` package because the + # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while + # the decorator will invoke `cpu_vec_isa.pick_vec_isa()` to initialize the + # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation + # overhead to the Triton backend. 
Therefore, we moved the `simd_nelements` to + # `__post_init__` + picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 + + def get_kernels(self) -> List[CppKernel]: + """Get all kernel objects under this loop level""" + if self.kernel: + return [self.kernel] + kernels = [] + for loop in self.inner: + kernels += loop.get_kernels() + return kernels + + def get_root(self): + """Get all kernel objects under this loop level""" + root = self + while root.parent: + root = root.parent + return root + + def set_kernel(self, kernel: CppKernel): + """ + Set the kernel under this loop level. No split is allowed under + this loop level. + """ + if not self.inner: + self.kernel = kernel + loop: Optional[LoopLevel] = self + assert loop is not None + return + assert len(self.inner) == 1 + self.inner[0].set_kernel(kernel) + + def get_loops_at(self, depth) -> List["LoopLevel"]: + if depth == 0: + return [self] + else: + loops = [] + for loop in self.inner: + loops += loop.get_loops_at(depth - 1) + return loops + + def split_with_tiling(self, depth, factor): + def clone_inner(): + inner = [] + if self.inner: + for loop in self.inner: + inner.append(loop.clone()) + return inner + + def do_split_with_tiling(): + sympy_factor = sympy.Integer(factor) + + offset = FloorDiv(self.size, sympy_factor) * sympy_factor + main_loop = LoopLevel(self.var, offset) + main_loop.steps = sympy_factor + main_loop.parallel = self.parallel + main_loop.collapsed = False + main_loop.is_reduction = self.is_reduction + main_loop.inner = clone_inner() + if main_loop.inner: + for loop in main_loop.inner: + loop.parent = main_loop + + tail_loop = LoopLevel(self.var, self.size) + tail_loop.offset = offset + tail_loop.parallel = self.parallel + tail_loop.collapsed = False + tail_loop.is_reduction = self.is_reduction + tail_loop.inner = clone_inner() + if tail_loop.inner: + for loop in tail_loop.inner: + loop.parent = tail_loop + + return main_loop, tail_loop + + if depth == 0: + main_loop, tail_loop = do_split_with_tiling() + parent = self.parent + if parent: + parent.inner = [main_loop, tail_loop] + main_loop.parent = parent + tail_loop.parent = parent + return main_loop, tail_loop + else: + assert len(self.inner) == 1 + return self.inner[0].split_with_tiling(depth - 1, factor) + + def clone(self): + loop = copy(self) + loop.inner = [] + if self.inner: + for inner_loop in self.inner: + inner_loop_clone = inner_loop.clone() + inner_loop_clone.parent = loop + loop.inner.append(inner_loop_clone) + loop.kernel = deepcopy(self.kernel) + return loop + + def lines(self): + offset_expr = cexpr_index(self.offset) + size_expr = cexpr_index(self.size) + if config.cpp.no_redundant_loops and offset_expr == size_expr: + return None + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) + if self.parallel: + # TODO(jansel): look into chunk size and other schedules + line1 = "#pragma omp for" + if self.parallel > 1: + line1 += f" collapse({self.parallel})" + if self.simd_omp: + line1 = line1.replace(" for ", f" for {simd}") + elif self.simd_vec: + line1 = "" + elif self.simd_omp: + line1 = f"#pragma omp {simd}" + elif not self.is_reduction and cpp_builder.is_gcc(): + line1 = "#pragma GCC ivdep" + else: + line1 = "" + offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" + size_str = f"{self.var}<{size_expr}" + if self.steps.is_number: + steps_str = f"{self.var}+={cexpr_index(self.steps)}" + else: + # If 
the step size is 0, change it to 1 because a step size of 0 + # will cause floating point exception (core dump) during parallelization. + steps_str = ( + f"{self.var}+=({cexpr_index(self.steps)} == 0 ? " + f"1 : {cexpr_index(self.steps)})" + ) + line2 = f"for({offset_str}; {size_str}; {steps_str})" + if self.collapsed or not line1: + return [line2] + return [line1, line2] + + +@dataclasses.dataclass +class LoopNestWithSplit: + """ + A loop-nest like structure but with some loop level split along + the loop range into the main tiling loop and the tail. It is built + with the `build` method as a loop nest and then split with + `split_with_tiling` at some depth. + + A typical case is for vectorization where we typically split at the inner-most + loop level. A more complicated case is 2D tiling where we split at + both inner-most and outer levels. + """ + + root: Optional[List[LoopLevel]] = None + kernel: Optional[CppKernel] = None + + @staticmethod + def build(kernel: CppKernel): + """Build a LoopNest with the given `kernel` as the leaf""" + itervars = kernel.itervars + ranges = kernel.ranges + reduction_depth = kernel.reduction_depth + assert reduction_depth is not None + + root: List[LoopLevel] = [] + levels: List[LoopLevel] = root + loop: Optional[LoopLevel] = None + for loop_idx, (var, size) in enumerate(zip(itervars, ranges)): + loop = LoopLevel(var, size, parent=loop) + if loop_idx >= reduction_depth: + loop.is_reduction = kernel.is_reduction + levels.append(loop) + levels = loop.inner + loop_nest = LoopNestWithSplit(root) + if loop: + loop.kernel = kernel + else: + loop_nest.kernel = kernel + return loop_nest + + def __bool__(self): + return bool(self.root) + + def get_loops_at(self, depth) -> List[LoopLevel]: + """Get all the loop levels at the given `depth` (most outer loop has depth 0)""" + loops: List[LoopLevel] = [] + assert self.root is not None + for loop in self.root: + loops += loop.get_loops_at(depth) + return loops + + @cache_on_self + def max_parallel_depth(self): + """ + Maximal allowed depth for parallelism: + 1) Levels without splitting and + 2) All reduction or non-reduction levels + When the loop is split at the top level, the max depth is 1. + """ + max_depth = 0 + assert self.root is not None + loops = self.root + if len(loops) > 1: + return 1 + is_reduction = loops[0].is_reduction if loops else False + while len(loops) == 1 and loops[0].is_reduction == is_reduction: + max_depth += 1 + loops = loops[0].inner + return max_depth + + def is_reduction_only(self): + """ + Whether all the loops are for reduction. Reduction loops + are always the inner most ones. + """ + return ( + self.root is not None and len(self.root) > 0 and self.root[0].is_reduction + ) + + def mark_parallel(self, par_depth): + assert ( + par_depth <= self.max_parallel_depth() + ), "Parallel depth cannot exceed the maximal allowed parallel depth" + assert self.root is not None + loops = self.root + for loop in loops: + loop.parallel = par_depth + for i in range(1, par_depth): + loops = loops[0].inner + loops[0].collapsed = True + + def split_with_tiling(self, depth, factor): + """ + Split the loop into main and tail loops at given `depth` so that the range + of the main loop has range `floor_div(range, factor) * factor` and + the tail loop handles the remainder. The main loop is tiled + according to the `factor`. 
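+ For example (illustrative sizes): splitting a loop of range 100 at some depth
+ with factor 16 yields a main loop over [0, 96) stepping by 16 and a tail loop
+ over [96, 100).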
+ """ + loops = self.get_loops_at(depth) + assert len(loops) == 1 + split_loops = loops[0].split_with_tiling(0, factor) + if depth == 0: + self.root = split_loops + return split_loops + + def get_kernels(self) -> List[CppKernel]: + """Get all kernel objects under this loop nest""" + if self.kernel: + return [self.kernel] + kernels: List[CppKernel] = [] + assert self.root is not None + for loop in self.root: + kernels += loop.get_kernels() + return kernels diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_gemm_template.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a2b1c41576d0d68435dcab31e11b94cf436c15 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_gemm_template.py @@ -0,0 +1,1043 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import math +from functools import lru_cache +from typing import Any, Callable, cast, List, Optional, Set, Union +from unittest.mock import patch + +import torch +import torch.utils + +from ..._dynamo.utils import counters +from .. import config, ir, lowering as L +from ..kernel.mm_common import mm_args +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import cache_on_self, has_free_symbols, parallel_num_threads +from ..virtualized import ops, V +from .cpp import get_export_declaration +from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType +from .cpp_template import CppTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import ( + create_epilogue_with_attr, + DTYPE_TO_CPP, + GemmBlocking, + get_gemm_template_output_and_compute_dtype, +) + + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} + +{{micro_gemm.codegen_define(kernel)}} + +{%- if x_scale is not none %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %} +{%- else %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp} %} +{%- endif %} + +extern "C" {{export_declaration}} +{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}} +{ + {{kernel.maybe_codegen_profile()}} + constexpr int64_t num_threads = {{num_threads}}; + constexpr int64_t N = {{N}}; + constexpr int64_t K = {{K}}; + constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}}; + constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}}; + constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}}; + constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; + constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; + +{%- if is_dynamic_M %} + const int64_t M = {{kernel.size(GemmOut, 0)}}; + const int64_t Mr_blocks = (M + Mr - 1) / Mr; + {%- if num_threads > 1 %} + int64_t Mt_blocks, Nt_blocks, Kt_blocks; + mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); + {%- else %} + const auto Mt_blocks = Mr_blocks; + const auto Nt_blocks = Nr_blocks; + const auto Kt_blocks = Kr_blocks; + {%- endif %} + int64_t Mc_blocks, Nc_blocks, Kc_blocks; + uint32_t L1_cache_size = {{L1_cache_size}}; + uint32_t L2_cache_size = {{L2_cache_size}}; + mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>( + num_threads, + M, + N, + K, + Mr, + Nr, + Kr, + Mt_blocks, + Nt_blocks, + Kt_blocks, + Mc_blocks, + Nc_blocks, + Kc_blocks, + L1_cache_size, + L2_cache_size + ); + const int64_t num_Mc_blocks = 
(Mr_blocks + Mc_blocks - 1) / Mc_blocks;
+    const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
+    const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
+{%- else %}
+    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
+    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
+    constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}};
+    constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}};
+    constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
+    constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
+    constexpr int64_t Nc_blocks = {{template.cache_blocking().block_n}};
+    constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
+    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
+    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
+    constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
+{%- endif %}
+
+    // make sure all partitions are assigned
+    {{kernel.assert_function}}(
+        Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
+        "Not all partitions are assigned."
+    );
+
+{%- if maybe_k_slicing %}
+    std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
+    if (num_k_slices > 1) {
+        local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
+    }
+{%- endif %}
+
+{%- if num_threads > 1 %}
+    #pragma omp parallel num_threads({{num_threads}})
+    {
+        const int tid = omp_get_thread_num();
+        int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
+        mm_get_thread_blocks(
+            tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
+            m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
+    {%- if maybe_k_slicing %}
+        const int64_t k_group_id = tid / num_k_slices;
+        const int64_t k_slice_id = tid % num_k_slices;
+    {%- endif %}
+{%- else %}
+    {
+        const int tid = 0;
+        const int64_t m_block_start = 0;
+        const int64_t m_block_end = Mr_blocks;
+        const int64_t n_block_start = 0;
+        const int64_t n_block_end = Nr_blocks;
+        const int64_t k_block_start = 0;
+        const int64_t k_block_end = Kr_blocks;
+{%- endif %}
+        {{ micro_gemm.codegen_init(kernel) }}
+{%- if use_local_acc %}
+    {%- set acc_buf_name = "local_acc_buf" %}
+        {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
+{%- endif %}
+        for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
+            const int64_t m_start = mc * Mr;
+            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
+            const int64_t m_size = m_end - m_start;
+            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
+                const int64_t n_start = nc * Nr;
+                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
+                const int64_t n_size = n_end - n_start;
+                // NB: assume we pad N, nc_block_end won't exceed padded N here.
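+                // For example, with Nr = 16, a thread whose n_block_end is 6 and
+                // Nc_blocks is 4 runs one full cache block (nc = 0..3) and one
+                // partial one (nc = 4..5); n_end above is already clipped both by
+                // the thread's block range and by N itself.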
+ const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); +{%- if use_local_acc %} + {%- set acc = kernel.local_buffers[acc_buf_name] %} + {{ kernel.reinit_buffer_if_null(acc_buf_name) }} +{%- else %} + {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- endif %} + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + int64_t k_start = kc * Kr; + int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); +{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} + for (int64_t nci = nc; nci < nc_block_end; nci++) { +{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} +{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %} +{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} + if (kc == k_block_start) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }} + } else { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }} + } + } + } +{%- if maybe_k_slicing %} + if (num_k_slices > 1) { + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }}); + } else +{%- endif %} + { +{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %} + {{ kernel.store_output( + tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- if maybe_k_slicing %} + if (num_k_slices > 1) { + #pragma omp barrier + for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { + // We slice M-dim and each thread in the k-slicing group works on a slice + const int64_t m_start_unsliced = mc * Mr; + const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); + const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced; + const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices; + const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced); + const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced); + const int64_t m_size = m_end - m_start; + const int64_t m_offset = m_start - m_start_unsliced; + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + const int64_t n_start = nc * Nr; + const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); + const int64_t n_size = n_end - n_start; + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get(); + for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) { + auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get(); + for (int64_t m = m_offset; m < m_offset + m_size; m++) { + #pragma omp simd + for (int64_t n = 0; n < n_size; n++) { + {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n]; + } + } + } + {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %} + {{ kernel.store_output( + tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, 
offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- endif %} + {{ micro_gemm.codegen_finalize(kernel) }} + } +} +""" + + +def get_padded_n(n, block_n): + return (n + block_n - 1) // block_n * block_n + + +class CppPackedGemmTemplate(CppTemplate): + def __init__( + self, + input_nodes, + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta=1, + alpha=1, + has_bias=False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + ) -> None: + assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8] + super().__init__( + "packed_gemm", + input_nodes, + layout, + num_threads, + epilogue_creator=epilogue_creator, + ) + self.beta = beta + self.alpha = alpha + self.has_bias = has_bias + self.register_blocking = register_blocking + m, n = layout.size + _, k = input_nodes[0].get_size() + self.m, self.n, self.k = m, n, k + self.padded_n = get_padded_n(n, self.register_blocking.block_n) + self.is_dynamic_M = has_free_symbols((m,)) + + @cache_on_self + def thread_blocking(self) -> GemmBlocking: + """ + NOTE [Thread blocking in Cpp GEMM] + We use simple heuristics to decide the thread blocking: + 1. Make sure all threads are occupied as much as possible. + 2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse. + 3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing. + TODO(jgong5): allow tuning various blocking options + """ + + @lru_cache(maxsize=100) + def get_factors(number): + factors = [] + for i in range(int(number**0.5), 0, -1): + if number % i == 0: + factors.append(number // i) + factors.append(i) + return factors + + def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks): + thread_block_k = math.ceil(k_blocks / k_factor) + thread_block_n = math.ceil(n_blocks / n_factor) + thread_block_m = math.ceil(m_blocks / m_factor) + return GemmBlocking(thread_block_m, thread_block_n, thread_block_k) + + assert ( + not self.is_dynamic_M + ), "Unable to determine thread blocking for dynamic M." 
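+        # Worked example with hypothetical sizes: for num_threads = 8,
+        # get_factors(8) yields [4, 2, 8, 1]. With m_blocks = 16 and
+        # n_blocks = 8, the candidate n_factor = 4 pairs with m_factor = 2
+        # and gives GemmBlocking(8, 2, k_blocks); the helper
+        # get_better_blocking defined below then ranks candidates, preferring
+        # larger k blocks and, on ties, smaller (more square) m/n footprints.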
+ register_blocking = self.register_blocking + m_blocks = math.ceil(self.m / register_blocking.block_m) + n_blocks = math.ceil(self.n / register_blocking.block_n) + k_blocks = math.ceil(self.k / register_blocking.block_k) + factors = get_factors(self.num_threads) + assert len(factors) > 0 + + if config.cpp.gemm_thread_factors is not None: + factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")] + assert len(factors) == 3 + assert math.prod(factors) == self.num_threads + return get_blocking( + factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks + ) + + # we favor square-sized thread blocks for good data reuse + def get_better_blocking(blocking, best_blocking): + if best_blocking is None: + best_blocking = blocking + else: + block_m_size = blocking.block_m * register_blocking.block_m + block_n_size = blocking.block_n * register_blocking.block_n + best_block_m_size = best_blocking.block_m * register_blocking.block_m + best_block_n_size = best_blocking.block_n * register_blocking.block_n + if blocking.block_k > best_blocking.block_k: + best_blocking = blocking + elif ( + blocking.block_k == best_blocking.block_k + and block_m_size + block_n_size + < best_block_m_size + best_block_n_size + ): + best_blocking = blocking + return best_blocking + + best_blocking = None + # check if we can have a thread-blocking to occupy all threads without k-slicing + for n_factor in factors: + m_factor = self.num_threads // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for k_factor in factors: + if k_blocks >= k_factor and ( + config.cpp.gemm_max_k_slices == 0 + or k_factor <= config.cpp.gemm_max_k_slices + ): + n_factors = get_factors(self.num_threads // k_factor) + for n_factor in n_factors: + m_factor = (self.num_threads // k_factor) // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, + n_factor, + k_factor, + m_blocks, + n_blocks, + k_blocks, + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for n_factor in factors: + m_factor = self.num_threads // n_factor + if n_blocks >= n_factor or m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + assert best_blocking is not None + return best_blocking + + @cache_on_self + def cache_blocking(self) -> GemmBlocking: + def get_cache_blocking(register_blocking, thread_blocking): + Mr = register_blocking.block_m + Nr = register_blocking.block_n + Kr = register_blocking.block_k + + Mt_blocks = thread_blocking.block_m + Nt_blocks = thread_blocking.block_n + Kt_blocks = thread_blocking.block_k + + if config.cpp.gemm_cache_blocking is not None: + blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")] + assert len(blockings) == 3 + Mc_blocks, Nc_blocks, Kc_blocks = blockings + return ( + min(Mc_blocks, Mt_blocks), + min(Nc_blocks, Nt_blocks), + min(Kc_blocks, Kt_blocks), + ) + + # The ratios below are empirically determined to decide + # the effective sizes of L1 and L2. 
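+            # For instance, assuming a 48 KB per-core L1d and a 2 MB per-core
+            # L2, the factors below budget roughly 38 KB of L1 for the B
+            # cache block and 1 MB of L2 for the A cache block.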
+ # TODO: tune the factor here + L1_limit_factor = 0.8 + L2_limit_factor = 0.5 + + L1_cache_size = ( + torch._C._cpu._L1d_cache_size() + ) # per core cache size in Bytes + assert ( + L1_cache_size > 0 + ), f"Expect L1_cache_size > 0 but got {L1_cache_size}" + L1 = L1_cache_size * L1_limit_factor + + L2_cache_size = ( + torch._C._cpu._L2_cache_size() + ) # per core cache size in Bytes + assert ( + L2_cache_size > 0 + ), f"Expect L2_cache_size > 0 but got {L2_cache_size}" + L2 = L2_cache_size * L2_limit_factor + + def get_num_byte(dtype): + return torch.tensor([], dtype=dtype).element_size() + + num_byte_A = get_num_byte(self.input_nodes[0].get_dtype()) + num_byte_B = get_num_byte(self.input_nodes[1].get_dtype()) + + # NOTE [CPP GEMM Cache Blocking Algorithm] + # Our overall strategy is to + # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc. + # Here, B is Kc x Nr where Nr is a single register block. We use L1 size to + # decide Kc. We want to make Mc large enough to better reuse B. + # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A + # along N, where we have two sub-strategies (see notes below) to decide Mc and Nc. + + # Step 1: Decide Kc assuming B block is L1-reside. + size_cache_B = Kr * Kt_blocks * Nr * num_byte_B + Kc_blocks = Kt_blocks + if size_cache_B > L1: + Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B)) + + # Step 2: Decide Mc assuming A block is L2-reside. + min_Mc_ratio = 2 # TODO(jgong5): something to tune? + min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr) + assert min_Mc_blocks >= 1 + Kt_bytes = Kt_blocks * Kr * num_byte_A + if min_Mc_blocks * Mr * Kt_bytes < L2: + # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt + # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks) + # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside + # in L1. + Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes))) + Nc_blocks = 1 + else: + # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse + # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2. + Mc_blocks = Mt_blocks + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + Nc_bytes = Nc_blocks * Nr * 4 # assume C or acc is float32/int32 + Kc_bytes = Kc_blocks * Kr * num_byte_A + if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2: + # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2, + # assuming Mc == Nc for good data reuse. + M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8 + if M_max < Mc_blocks * Mr: + Mc_blocks = math.floor(M_max / Mr) + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + + return Mc_blocks, Nc_blocks, Kc_blocks + + assert ( + not self.is_dynamic_M + ), "Unable to determine cache blocking for dynamic M." 
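+        # Worked example with hypothetical numbers (bf16 weights, 48 KB L1d):
+        # with Kr = 32, Nr = 32 and 2-byte elements, one (Kr x Nr) slice of B
+        # takes 2 KB, so step 1 above caps Kc_blocks at
+        # floor(0.8 * 48 KB / 2 KB) = 19 when the thread's K extent does not
+        # fit in L1.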
+ register_blocking = self.register_blocking + thread_blocking = self.thread_blocking() + + return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking)) + + def log_blockings(self): + log.debug(f"Register blocking: {self.register_blocking}") # noqa: G004 + if self.is_dynamic_M: + # thread and cache blockings are determined at runtime for dynamic shapes + return + log.debug(f"Cache blocking: {self.cache_blocking()}") # noqa: G004 + thread_blocking = self.thread_blocking() + log.debug(f"Thread blocking: {thread_blocking}") # noqa: G004 + + def get_occupancy(): + m_blocks = math.ceil(self.m / self.register_blocking.block_m) + n_blocks = math.ceil(self.n / self.register_blocking.block_n) + k_blocks = math.ceil(self.k / self.register_blocking.block_k) + m = math.ceil(m_blocks / thread_blocking.block_m) + n = math.ceil(n_blocks / thread_blocking.block_n) + k = math.ceil(k_blocks / thread_blocking.block_k) + return (m, n, k) + + log.debug( + f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}" # noqa: G004 + ) + + def maybe_k_slicing(self): + if self.num_threads == 1: + return False + if self.is_dynamic_M: + # TODO(jgong5): perhaps use size hint to decide? + return True + register_blocking = self.register_blocking + k_blocks = math.ceil(self.k / register_blocking.block_k) + thread_blocking = self.thread_blocking() + return k_blocks > thread_blocking.block_k + + @staticmethod + def add_choices( + choices, + layout, + input_nodes, + beta=1, + alpha=1, + has_bias=False, + trans_w=False, + input_indices=None, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + ): + if input_indices is None: + input_indices = list(range(len(input_nodes))) + + def reorder_and_filter(inputs, layout_or_out): + if has_bias: + assert len(input_indices) >= 3 + # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp] + inp_idx = input_indices[0] + x_idx = input_indices[1] + w_idx = input_indices[2] + return [ + inputs[x_idx], + inputs[w_idx], + inputs[inp_idx], + *[inputs[idx] for idx in input_indices[3:]], + ], layout_or_out + else: + assert len(input_indices) >= 2 + return [inputs[idx] for idx in input_indices], layout_or_out + + def maybe_to_dense(inputs, layout_or_out): + new_inputs = list(inputs) + if isinstance(inputs[1], torch.Tensor): + W = inputs[1] + new_inputs[1] = W.to_dense() if W.is_mkldnn else W + return new_inputs, layout_or_out + + def normalize_shapes(inputs, layout_or_out): + if not trans_w: + return inputs, layout_or_out + new_inputs = list(inputs) + X = inputs[0] + W = inputs[1] + B = inputs[2] if has_bias else None + if isinstance(W, ir.IRNode): + if trans_w: + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + W = L.permute(W, [1, 0]) + else: + if trans_w: + assert isinstance(W, torch.Tensor) + W = W.transpose(0, 1) + if B is not None: + if isinstance(B, ir.IRNode): + if not isinstance(B, ir.TensorBox): + B = ir.TensorBox(B) + B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) + else: + assert isinstance(B, torch.Tensor) + B = B.expand(X.shape[0], B.shape[-1]) + new_inputs[1] = W + if B is not None: + new_inputs[2] = B + return new_inputs, layout_or_out + + # TODO(jgong5): decide proper number of threads per problem size + num_threads = parallel_num_threads() + new_inputs, _ = normalize_shapes( + *maybe_to_dense(*reorder_and_filter(input_nodes, layout)) + ) + m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + new_inputs[0].get_dtype() + ) + 
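+        # At this point the inputs have passed through
+        # reorder_and_filter -> maybe_to_dense -> normalize_shapes: e.g. with
+        # has_bias=True an incoming [inp, x, w] ordering is now [x, w, inp],
+        # an MKLDNN-packed weight has been densified, and with trans_w the
+        # weight has been (logically) transposed to (k, n).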
micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=new_inputs[0].get_dtype(), + input2_dtype=new_inputs[1].get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=alpha, + num_threads=num_threads, + ) + assert micro_gemm is not None + _, block_n, _ = micro_gemm.register_blocking + padded_n = get_padded_n(n, block_n) + + def pack_weight(inputs, layout_or_out): + W = inputs[1] + new_inputs = list(inputs) + blocked_w: Union[ir.IRNode, torch.Tensor] = W + if isinstance(W, ir.IRNode): + new_size = [padded_n // block_n, k, block_n] + blocked_w = ir.Buffer( + W.get_name(), # Borrow the registered buffer name + ir.FixedLayout( + W.get_device(), + W.get_dtype(), + new_size, + ir.FlexibleLayout.contiguous_strides(new_size), + 0, + ), + ) + else: + blocked_w = ( + torch.nn.functional.pad(W, (0, padded_n - n)) + .reshape(k, padded_n // block_n, block_n) + .transpose(0, 1) + .contiguous() + ) + if micro_gemm.get_b_layout() != LayoutType.NORMAL: + layout_str = ( + "VNNI4" + if micro_gemm.get_b_layout() == LayoutType.VNNI4 + else "VNNI2" + ) + assert micro_gemm.get_b_layout() in [ + LayoutType.VNNI2, + LayoutType.VNNI4, + ], f"We only support {layout_str} for now" + vnni_size = ( + 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 + ) + assert ( + k % vnni_size == 0 + ), f"k should be divisible by vnni_size for {layout_str} layout" + blocked_w = ( + blocked_w.view( + padded_n // block_n, k // vnni_size, vnni_size, block_n + ) + .transpose(-1, -2) + .contiguous() + .view(padded_n // block_n, k, block_n) + ) + # normalize stride to be "contiguous_strides" per size + # this avoids the problems in L.view during template codegen + new_stride = [1] + for sz in reversed(blocked_w.shape[1:]): + new_stride.insert(0, new_stride[0] * sz) + blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride) + new_inputs[1] = blocked_w + + def _is_int8_gemm(inputs): + return ( + isinstance(inputs[0], ir.IRNode) + and inputs[0].get_dtype() == torch.uint8 + ) or ( + isinstance(inputs[0], torch.Tensor) + and inputs[0].dtype == torch.uint8 + ) + + if _is_int8_gemm(new_inputs): + BCompensate = None + if isinstance(W, ir.IRNode): + BCompensate = V.graph.add_tensor_constant( + V.graph.constants[W.get_name() + "_BMatrixCompens"], + W.get_name() + "_BMatrixCompens", + ) + else: + BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment] + new_inputs.append(BCompensate) + return new_inputs, layout_or_out + + def preprocessor(inputs, layout): + return pack_weight( + *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) + ) + + def postprocessor(output): + if isinstance(output, ir.TensorBox): + # prepack the weight as input to the template buffer + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + + W_node = new_input_nodes[1] + assert W_node.get_name() in V.graph.constants + W = V.graph.constants[W_node.get_name()] + new_input_nodes[1] = W + new_input_nodes, _ = pack_weight( + *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) + ) + + # By using the new packed weight for the GEMM template, we can prune the + # old weight if it has no other users. This saves memory but makes the FX graph + # non-retraceable. To support retracing, we can add a repack node to the + # FX graph. 
For example: + # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template + W_tensor_users = 0 + for node in reversed(V.graph.graph.nodes): + # Case may happen when the wgt tensor is used by more than 1 get_attr node + # https://github.com/pytorch/pytorch/issues/134998 + if node.op == "get_attr" and hasattr( + V.graph.module, node.name + ): # wgt might already be deleted + comp_tensor = getattr(V.graph.module, node.name) + if ( + W.is_mkldnn == comp_tensor.is_mkldnn + and W.dtype == comp_tensor.dtype + and W.device == comp_tensor.device + and ( + ( + not W.is_mkldnn + and ( + W.untyped_storage().data_ptr() + == comp_tensor.untyped_storage().data_ptr() + ) + ) + or ( + W.is_mkldnn + and ( + torch.ops.mkldnn.data_ptr(W) + == torch.ops.mkldnn.data_ptr(comp_tensor) + ) + ) + ) + ): + W_tensor_users += 1 + + for node in reversed(V.graph.graph.nodes): + # The wgt tensor has been used by only 1 get_attr node + # The get_attr node has only 1 user fx node + if ( + node.name == W_node.get_name() + and len(node.users) == 1 + and W_tensor_users == 1 + ): + del V.graph.constants[node.name] + delattr(V.graph.module, node.name) + delattr(V.graph.graph.owning_module, node.name) + + W_packed = new_input_nodes[1] + W_packed_constant = V.graph.add_tensor_constant(W_packed) + template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( + W_packed_constant + ) + return output + + template = DataProcessorTemplateWrapper( + CppPackedGemmTemplate, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + ) + template.maybe_append_choice(choices) + return template + + def render( # type: ignore[override,return] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[List[ir.IRNode]] = None, + **kwargs, + ) -> str: + assert len(self.input_nodes) >= 2 + + int8_gemm = self.input_nodes[0].get_dtype() == torch.uint8 + x_scale = None + x_zp = None + w_scale = None + w_zp = None + if int8_gemm: + X, W = self.input_nodes[0], self.input_nodes[1] + bias_idx = 2 if self.has_bias else 1 + inp = self.input_nodes[bias_idx] if self.has_bias else None + x_scale = self.input_nodes[bias_idx + 1] + x_zp = self.input_nodes[bias_idx + 2] + w_scale = self.input_nodes[bias_idx + 3] + w_zp = self.input_nodes[bias_idx + 4] + Y = self.output_node + else: + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + inp = self.input_nodes[2] if self.has_bias else None + + template_buffer_has_other_users = None + + if template_buffer_node is not None: + # Use the updated prepacked weight buffer + W = template_buffer_node.inputs[1] + Y = template_buffer_node + + assert flag_template_buffer_has_other_users is not None + template_buffer_has_other_users = flag_template_buffer_has_other_users + + template_buffer = Y + gemm_output_buffer = template_buffer + + epilogues: List[ir.IRNode] = [] + reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = [] + epilogue_creators: List[Callable[[ir.Buffer], ir.Pointwise]] = [] + fake_buffers: List[ir.Buffer] = [] + Y_aliases: Set[str] = set() + + use_local_acc = ( + self.layout.dtype != torch.float + or template_buffer_has_other_users + or int8_gemm + or self.padded_n != self.n + or self.maybe_k_slicing() + ) + + # TODO(jgong5): for int8 gemm, 
bias-add is handled outside of gemm template, + # but we'd better move it here to align with fp. + if inp is not None and self.beta != 0 and not int8_gemm: + # add an epilogue for bias add + def _bias_add_epilogue(buf): + return create_epilogue_with_attr( + buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype + ) + + epilogue_creators.append(_bias_add_epilogue) + + if self.epilogue_creator is not None: + epilogue_creators.append(self.epilogue_creator) + + # When the GEMM output buffer is localized but it has users other than the epilogue nodes, + # we need to copy the value in the GEMM output local buffer to a global buffer. + def need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + # The GEMM output buffer is a global buffer, thus copy is not needed. + if not use_local_acc: + return False + + # The possible value of template_buffer_has_other_users is (None, False, True) + # It is None when generating the gemm template during autotune and it will have value during scheduler codegen. + # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases: + # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune) + # 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the + # GEMM output buffer in local buffer only (no users outside of the epilogues will use its value). + if not template_buffer_has_other_users: + return False + + # When bias is not None or self.epilogue_creator is not None, + # there will be epilogue_creators after the GEMM. + # The GEMM output buffer is localized while + # the output buffer of the epilogue_creators is a global buffer. + if epilogue_creators: + return False + + return True + + if need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + + def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer): + dtype = self.layout.dtype + input_loader = input_buffer.make_loader() + + def copy_inner(index): + input = input_loader(index) + result = ops.to_dtype(input, dtype) + return result + + return ir.Pointwise( + device=input_buffer.get_device(), + dtype=self.layout.dtype, + inner_fn=copy_inner, + ranges=input_buffer.get_size(), + ) + + epilogue_creators.append(copy_from_local_to_global_buffer_epilogue) + + # NOTE [How CPP GEMM template epilogues are organized] + # gemm_output_buffer + # --> zero or more in-template epilogues (created by `epilogue_creators`) --> + # template_buffer + # --> zero or more out-of-template epilogues (`epilogue_nodes`) --> + # Y + if epilogue_creators: + gemm_output_name = "buf_GemmOut" + gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout) + current_input_buffer = gemm_output_buffer + for i, creator in enumerate(epilogue_creators): + if i == len(epilogue_creators) - 1: + buffer_name = template_buffer.get_name() + else: + buffer_name = f"buf_GemmOut_epilogue_{i}" + epilogues.append( + ir.ComputedBuffer( + name=buffer_name, + layout=template_buffer.layout, + data=creator(current_input_buffer), + ) + ) + fake_buffers.append(current_input_buffer) + Y_aliases.add(current_input_buffer.get_name()) + reindexers.append(None) + if i < len(epilogue_creators) - 1: + current_input_buffer = ir.Buffer( + buffer_name, template_buffer.layout + ) + + Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y + + if epilogue_nodes: + epilogues.extend(epilogue_nodes) + assert Y.get_numel() == 
epilogues[-1].get_numel() + Y = cast(ir.Buffer, epilogues[-1]) + + if not template_buffer_has_other_users: + Y_aliases.add(template_buffer.get_name()) + + if ( + Y.get_size() == template_buffer.get_size() + and Y.get_stride() == template_buffer.get_stride() + ): + reindexers.extend([None] * len(epilogue_nodes)) + Y_2d = Y + else: + + def get_reindexer(epilogue_node): + # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example: + # template_buffer: + # size (324, 512), stride (512, 1) + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + stride_order = list( + ir.get_stride_order( + V.graph.sizevars.size_hints(epilogue_node.get_stride()) + ) + ) + fill_order = ir.stride_order2fill_order(stride_order) + reversed_fill_order = list(reversed(fill_order)) + size_with_stride_ordered_decreasingly = [ + epilogue_node.get_size()[i] for i in reversed_fill_order + ] + reshape_reindex = ir.View.dynamic_reshape_indexer( + size_with_stride_ordered_decreasingly, + template_buffer.get_size(), + ) + + # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example: + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + # epilogue_node: + # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) + from_stride_ordered_decreasingly_to_epilogue_node_order = [ + (len(stride_order) - 1) - stride_order[i] + for i in range(len(stride_order)) + ] + stride_reindex = ir.same_reorder( + from_stride_ordered_decreasingly_to_epilogue_node_order + ) + + reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) + return reindexer + + reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes]) # type: ignore[list-item] + if isinstance(Y, ir.BaseView): + storage = ir.StorageBox(Y.unwrap_view()) + else: + assert isinstance(Y, ir.Buffer) + storage = ir.StorageBox(Y) + Y_2d = ir.ReinterpretView(storage, template_buffer.get_layout()) + + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + X.get_dtype() + ) + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + input_dtype=X.get_dtype(), + input2_dtype=W.get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=self.alpha, + num_threads=self.num_threads, + ) + assert micro_gemm is not None + assert self.register_blocking == micro_gemm.register_blocking + self.log_blockings() + if isinstance(micro_gemm, CppMicroGemmAMX): + counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + + options = dict( + X=X, + W=W, + inp=inp, + Y=Y, + N=self.n, + K=self.k, + PADDED_N=self.padded_n, + GemmOut=gemm_output_buffer, + aliases={alias: Y.get_name() for alias in Y_aliases}, + beta=self.beta, + alpha=self.alpha, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, + template=self, + kernel=kernel, + export_declaration=get_export_declaration(), + epilogue_nodes=epilogues, + reindexers=reindexers, + Y_2d=Y_2d, + use_local_acc=use_local_acc, + 
maybe_k_slicing=self.maybe_k_slicing(), + x_scale=x_scale, + x_zp=x_zp, + w_scale=w_scale, + w_zp=w_zp, + acc_buf_dtype=torch.int32 if int8_gemm else torch.float, + DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + config=config, + ) + with contextlib.ExitStack() as stack: + for buf in fake_buffers: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..cb26c48fce53c66f05e57c7a36d68063e0f3240b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py @@ -0,0 +1,850 @@ +# mypy: allow-untyped-defs +import dataclasses +import sys +from enum import Enum +from typing import Callable, Dict, List, Optional, Type + +import sympy + +import torch + +from .. import ir +from ..cpu_vec_isa import pick_vec_isa, VecAMX, VecAVX2, VecAVX512, VecISA +from ..utils import IndentedBuffer, parallel_num_threads +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp + + +class LayoutType(Enum): + NORMAL = 0 + VNNI2 = 1 + VNNI4 = 2 + + +_IS_WINDOWS = sys.platform == "win32" + + +def get_restrict_keyword() -> str: + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170 + return "__restrict" + else: + return "__restrict__" + + +class CppMicroGemm: + """ + A class that codegens a kernel that computes small-sized matrix multiplication. + + A micro GEMM kernel is responsible for register blocking, instruction selection, + and other CPU architecture-specific optimizations. + + The subclasses need to override `codegen_define` to define the kernel function + that is called by the code generated by `codegen_call`. + """ + + # TODO(jgong5): support constant shapes and lds as template args. 
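+    # The rendered kernel is templated on `accum` (see DECLARE_KERNEL below):
+    # `{{kernel_name}}<false>` overwrites C while `{{kernel_name}}<true>`
+    # accumulates into it, matching the `{self.name}<{value_to_cpp(accum,
+    # 'bool')}>(...)` call emitted by codegen_call.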
+    DECLARE_KERNEL = r"""
+template <bool accum>
+inline void {{kernel_name}}(
+{%- if kernel_extra_args_declare %}
+    {{kernel_extra_args_declare}}
+{%- endif %}
+    const {{input_t}}* {{restrict_keyword}} A,
+    const {{input2_t}}* {{restrict_keyword}} B,
+    {{output_t}}* {{restrict_keyword}} C,
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc
+)
+"""
+
+    def __init__(
+        self,
+        name,
+        input_dtype,
+        input2_dtype,
+        output_dtype,
+        compute_dtype,
+        register_blocking,
+        alpha=1,
+    ) -> None:
+        self.name = name
+        self.input_dtype = input_dtype
+        assert input2_dtype is not None
+        self.input2_dtype = input2_dtype
+        self.output_dtype = output_dtype
+        self.compute_dtype = compute_dtype
+        self.register_blocking = register_blocking
+        self.alpha = alpha
+
+    def get_common_options(self):
+        if self.input_dtype == torch.uint8:
+            assert self.compute_dtype == torch.int32
+            assert self.output_dtype == torch.int32
+            assert self.input2_dtype == torch.int8
+        return {
+            "torch": torch,
+            "kernel_name": self.name,
+            "input_dtype": self.input_dtype,
+            "input2_dtype": self.input2_dtype,
+            "output_dtype": self.output_dtype,
+            "compute_dtype": self.compute_dtype,
+            "input_t": DTYPE_TO_CPP[self.input_dtype],
+            "input2_t": DTYPE_TO_CPP[self.input2_dtype],
+            "output_t": DTYPE_TO_CPP[self.output_dtype],
+            "compute_t": DTYPE_TO_CPP[self.compute_dtype],
+            "alpha": self.alpha,
+            "kernel_extra_args_declare": self.get_kernel_extra_args_declare(),
+            "int8_gemm": self.input_dtype == torch.uint8,
+            "vnni_size": 4 if self.input_dtype == torch.uint8 else 2,
+            "restrict_keyword": get_restrict_keyword(),
+        }
+
+    def get_kernel_declaration(self):
+        options = self.get_common_options()
+        return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options)
+
+    def get_kernel_extra_args_declare(self) -> str:
+        return ""
+
+    def get_kernel_extra_args(self) -> str:
+        return ""
+
+    def codegen_define(self, kernel: CppTemplateKernel) -> str:
+        raise NotImplementedError
+
+    def codegen_call(
+        self,
+        kernel: CppTemplateKernel,
+        A: ir.Buffer,
+        B: ir.Buffer,
+        C: ir.Buffer,
+        accum: bool,
+    ) -> str:
+        """
+        Generate the code for calling the templated kernel that computes
+        `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise.
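+
+        A rendered call looks roughly like
+        `micro_gemm<false>(&A[0][0], &B[0][0], &C[0][0], M, N, K, lda, ldb, ldc);`
+        where the pointers and leading dimensions come from the kernel's
+        indexing of the sliced buffers.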
+ """ + A_ptr = f"&({kernel.index(A, [0, 0])})" + B_ptr = f"&({kernel.index(B, [0, 0])})" + C_ptr = f"&({kernel.index(C, [0, 0])})" + M = kernel.size(C, 0) + N = kernel.size(C, 1) + K = kernel.size(A, 1) + lda = kernel.stride(A, 0) + ldb = kernel.stride(B, 0) + ldc = kernel.stride(C, 0) + res = IndentedBuffer() + res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(") + with res.indent(): + extra_args = self.get_kernel_extra_args() + if extra_args: + res.writeline(extra_args) + res.writeline(f"{A_ptr},") + res.writeline(f"{B_ptr},") + res.writeline(f"{C_ptr},") + res.writeline(f"{M},") + res.writeline(f"{N},") + res.writeline(f"{K},") + res.writeline(f"{lda},") + res.writeline(f"{ldb},") + res.writeline(f"{ldc}") + res.writeline(");") + return res.getvalue() + + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def get_b_layout(self) -> LayoutType: + return LayoutType.NORMAL + + +@dataclasses.dataclass +class CppMicroGemmConfig: + input_dtype: torch.dtype + input2_dtype: torch.dtype + output_dtype: torch.dtype + compute_dtype: torch.dtype + vec_isa_cls: Type[VecISA] + register_blocking: GemmBlocking + extra_check: Optional[Callable[..., bool]] = None + + +micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {} + + +def register_micro_gemm(*configs): + def inner(cls): + assert ( + cls not in micro_gemm_configs + ), f"Duplicate micro_gemm registration for {cls}" + assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" + micro_gemm_configs[cls] = list(configs) + return cls + + return inner + + +def generate_gemm_config( + vec_isa_cls, + register_blockings, + input_dtype=torch.float, + input2_dtype=None, + output_dtype=None, + compute_dtype=None, + extra_check=None, +): + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = output_dtype + if input2_dtype is None: + input2_dtype = input_dtype + return [ + CppMicroGemmConfig( + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + vec_isa_cls, + GemmBlocking(*blocking), + extra_check, + ) + for blocking in register_blockings + ] + + +class CppMicroGemmRef(CppMicroGemm): + """ + A reference implementation of the CppMicroGemm class with naive C++ code. + It is used for correctness debugging. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + {{compute_t}} result = accum ? 
C[m * ldc + n] : 0; + for (int64_t k = 0; k < K; ++k) { + result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; + } + C[m * ldc + n] = result; + } + } +} +""" + + def __init__( + self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha + ) -> None: + super().__init__( + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + GemmBlocking(1, 1, 1), + alpha, + ) + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + **self.get_common_options(), + } + return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) + + +@register_micro_gemm( + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.half, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.half, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + ), +) +class CppMicroGemmFP32Vec(CppMicroGemm): + """ + This class generates the code for micro gemm using fp32 vec instructions for compute. + It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output. + The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template, + if the desired output is BF16/FP16. 
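+    For int8 weights (input2_dtype == torch.int8), each weight lane is widened
+    to int32 and then to fp32 on the fly, so the FMA still happens in fp32.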
+ """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + // TODO(jgong5): loop unroll for M and N + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + for (int64_t n = 0; n < N; n += {{block_n}}) { + if (block_m == {{block_m}}) { + {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>( + A + m * lda, + B + n, + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); + } else { + switch (block_m) { +{%- for b in range(block_m - 1, 0, -1) %} + case {{b}}: + {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( + A + m * lda, + B + n, + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); + break; +{%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); + } + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" +template +inline void {{kernel_name}}_kernel( + const {{input_t}}* {{restrict_keyword}} A, + const {{input2_t}}* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) { + using Vectorized = at::vec::Vectorized<{{compute_t}}>; + using VectorizedIn = at::vec::Vectorized<{{input_t}}>; + constexpr auto VLEN = Vectorized::size(); + constexpr auto ROWS = BLOCK_M; + constexpr auto COLS = BLOCK_N / VLEN; + + Vectorized va; + at::vec::VectorizedN<{{compute_t}}, COLS> vb; + at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; + + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); + } else { + vc[i] = Vectorized(0.0f); + } + }; + c10::ForcedUnroll{}(loadc); + + auto compute = [&, COLS](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { +{%- if alpha != 1 %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}}); +{%- else %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k])); +{%- endif %} + } + + if constexpr (row == 0) { +{%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); + vb[col] = at::vec::convert<{{compute_t}}>(b); +{%- elif input2_dtype == torch.int8 %} + // Convert VLEN int8 elements to int32, and then fp32 + auto b32 = at::vec::convert_to_int32(B + k * ldb + col * VLEN); + vb[col] = at::vec::convert(b32); +{%- else %} + vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); +{%- endif %} + } + + constexpr int idx = row * COLS + col; + vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); + }; + + for (int k = 0; k < K; ++k) { + c10::ForcedUnroll{}(compute, k); + } + + // store to C + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i].store(C + row * ldc + col * VLEN); + }; + c10::ForcedUnroll{}(storec); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + 
+# extra check for CppMicroGemmAMX
+def check_amx_extra(config, m, n, k, alpha, num_threads):
+    vnni_size = 4 if config.input_dtype == torch.uint8 else 2
+    return k % vnni_size == 0 and alpha == 1
+
+
+@register_micro_gemm(
+    *generate_gemm_config(
+        VecAMX,
+        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
+        input_dtype=torch.bfloat16,
+        input2_dtype=torch.int8,
+        output_dtype=torch.float,
+        compute_dtype=torch.float,
+        extra_check=check_amx_extra,
+    ),
+    *generate_gemm_config(
+        VecAMX,
+        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
+        input_dtype=torch.bfloat16,
+        output_dtype=torch.float,
+        extra_check=check_amx_extra,
+    ),
+    *generate_gemm_config(
+        VecAMX,
+        [(32, 32, 64), (48, 16, 64)],
+        input_dtype=torch.uint8,
+        input2_dtype=torch.int8,
+        output_dtype=torch.int32,
+        compute_dtype=torch.int32,
+        extra_check=check_amx_extra,
+    ),
+)
+class CppMicroGemmAMX(CppMicroGemm):
+    """
+    This class generates the code for micro gemm using Advanced Matrix eXtensions (AMX)
+    instructions available in 4th generation Intel Xeon for compute.
+    It supports input types of torch.bfloat16 with fp32 output.
+    TODO(jgong5): support int8 data type.
+    """
+
+    TEMPLATE_ENTRY = r"""
+{{declare_kernel}} {
+    TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
+    TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
+    // TODO(jgong5): loop unroll for M and N
+    for (int64_t m = 0; m < M; m += {{block_m}}) {
+        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
+        int64_t m_tail = m;
+        for (int64_t n = 0; n < N; n += {{block_n}}) {
+{%- for num_rows in range(block_m, 0, -16) %}
+    {%- if num_rows != block_m %}
+            else
+    {%- endif %}
+            if (block_m >= {{num_rows}}) {
+                {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
+                    amx_state,
+                    A + m * lda,
+                    B + n,
+                    C + m * ldc + n,
+                    K,
+                    lda,
+                    ldb,
+                    ldc,
+                    16
+                );
+                block_m -= {{num_rows}};
+                m_tail += {{num_rows}};
+            }
+{%- endfor %}
+            if (block_m > 0) {
+                {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
+                    amx_state,
+                    A + m_tail * lda,
+                    B + n,
+                    C + m_tail * ldc + n,
+                    K,
+                    lda,
+                    ldb,
+                    ldc,
+                    block_m
+                );
+            }
+        }
+    }
+}
+"""
+
+    TEMPLATE_KERNEL = r"""
+template <bool accum>
+inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
+    AMXState& amx_state,
+    const {{input_t}}* {{restrict_keyword}} A,
+    const {{input2_t}}* {{restrict_keyword}} B,
+    {{output_t}}* {{restrict_keyword}} C,
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
+    uint8_t tilecfg_rows
+) {
+    // TODO(jgong5): add prefetch hint for A, B, C
+    auto loadconfig = [](const amx_tilecfg& cfg) {
+        _tile_loadconfig(&cfg);
+    };
+    const auto last_k_offset = K / {{block_k}} * {{block_k}};
+    const auto tail_k_size = K - last_k_offset;
+    if C10_LIKELY (last_k_offset > 0) {
+        amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig);
+    } else {
+        amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
+    }
+    auto load_c = [&]() {
+{%- for tile_row in range(num_rows // 16) %}
+    {%- for tile_col in range(num_columns) %}
+        {%- set tile_idx = tile_row * num_columns + tile_col %}
+        _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
+    {%- endfor %}
+{%- endfor %}
+    };
+    auto zero_c = [&]() {
+{%- for tile_row in range(num_rows // 16) %}
+    {%- for tile_col in range(num_columns) %}
+        {%- set tile_idx = tile_row * num_columns + tile_col %}
+        _tile_zero({{tile_idx}});
+    {%- endfor %}
+{%- endfor %}
+    };
+
+    if constexpr (accum) {
+        load_c();
+    }
else { + zero_c(); + } + +{%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %} + // create a buffer for tiles of B. + alignas(64) {{input_t}} bf16_weights_buf[512]; + + int num_b_rows = (last_k_offset > 0) ? 16 : (tail_k_size * sizeof({{input_t}})) / 4; + int b_tile_ptr_stride = ldb * {{vnni_size}}; + + auto load_B_row = [&]({{input2_t}}* src, {{input_t}}* dst) { + {{kernel.unroll_pragma(2)}} + for (int i = 0; i < 2; i++) { + // int8 -> int32 -> fp32 -> bf16 + auto b32 = at::vec::convert_to_int32(src + i * 16); + auto b_bf16 = at::vec::convert<{{input_t}}>(b32); + b_bf16.store(dst + i * 16); + } + }; + + auto load_B_in_buf = [&]({{input2_t}}* B_ptr) { + {{kernel.unroll_pragma(8)}} + for (int i = 0; i < num_b_rows; i++) { + load_B_row( + B_ptr + i * b_tile_ptr_stride, + bf16_weights_buf + i * 32 + ); + } + }; +{%- endif %} + + auto compute = [&](int k) { +{%- set tile_offset_a = num_rows // 16 * num_columns %} +{%- set tile_offset_b = tile_offset_a + num_rows // 16 %} +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx_a = tile_offset_a + tile_row %} + {%- set tile_idx_b = tile_offset_b + tile_col %} + {%- set tile_idx_c = tile_row * num_columns + tile_col %} + {%- if tile_col == 0 %} + _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}})); + {%- endif %} + {%- if tile_row == 0 %} + {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %} + load_B_in_buf(const_cast<{{input2_t}}*>(B) + k * ldb + {{tile_col * 16 * vnni_size}}); + _tile_loadd({{tile_idx_b}}, bf16_weights_buf, 64); + {%- else %} + _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}})); + {%- endif %} + {%- endif %} + {%- if int8_gemm %} + _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- else %} + _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- endif %} + {%- endfor %} +{%- endfor %} + }; + + {{kernel.unroll_pragma(4)}} + for (int k = 0; k < last_k_offset; k += {{block_k}}) { + compute(k); + } + + auto store_c = [&]() { + // store to C +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); + {%- endfor %} +{%- endfor %} + }; + + // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead + if C10_UNLIKELY (tail_k_size > 0) { + if C10_LIKELY (last_k_offset > 0) { + store_c(); + amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); + load_c(); + } + compute(last_k_offset); + } + + store_c(); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + block_m, block_n, block_k = self.register_blocking + assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX" + assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX" + if self.input_dtype == torch.uint8: + assert block_k == 64, "Only support block_k = 64 for AMX INT8" + else: + assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16" + num_columns = block_n // 16 + options = { + "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": block_m, + "block_n": block_n, + "block_k": block_k, + "num_columns": num_columns, + "restrict_keyword": get_restrict_keyword(), + 
**self.get_common_options(), + } + result = "" + for num_rows in range(block_m, 0, -16): + amx_kernel_options = {**options, "num_rows": num_rows} + result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + amx_kernel_options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "AMXState amx_state;" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "amx_state.release([]() { _tile_release(); });" + + def get_kernel_extra_args_declare(self) -> str: + return "AMXState& amx_state," + + def get_kernel_extra_args(self) -> str: + return "amx_state," + + def get_b_layout(self): + if self.input_dtype == torch.uint8: + return LayoutType.VNNI4 + else: + return LayoutType.VNNI2 + + +def create_micro_gemm( + name, + m, + n, + k, + input_dtype, + input2_dtype, + output_dtype=None, + compute_dtype=None, + alpha=1, + num_threads=-1, + use_ref=True, +) -> Optional[CppMicroGemm]: + def create_from_config(cls, config: CppMicroGemmConfig): + return cls( + name, + config.input_dtype, + config.input2_dtype, + config.output_dtype, + config.compute_dtype, + config.register_blocking, + alpha, + ) + + assert isinstance(n, int) or n.is_number, n + assert isinstance(k, int) or k.is_number, k + m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m + assert isinstance(m, int), m + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = output_dtype + if num_threads < 0: + num_threads = parallel_num_threads() + vec_isa = pick_vec_isa() + matched_configs = [] + for cls, configs in micro_gemm_configs.items(): + for config in configs: + if not issubclass(vec_isa.__class__, config.vec_isa_cls): + continue + if ( + config.input_dtype == input_dtype + and config.compute_dtype == compute_dtype + and config.input2_dtype == input2_dtype + and config.output_dtype == output_dtype + # The output_dtype here is the output dtype of the micro-kernel. + # In some cases, the actual output dtype of the op for which the micro-kernel + # is being created would be same as that of the activation, but the micro-kernels + # compute output in Float/int32, which is converted in the GEMM template. This is + # subject to change in the future. + ): + if config.extra_check is not None and not config.extra_check( + config, m, n, k, alpha, num_threads + ): + continue + block_m, block_n, block_k = config.register_blocking + if ( + config.vec_isa_cls == VecAMX + and m < block_m + and input_dtype == torch.bfloat16 + and input2_dtype == torch.int8 + ): + # For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m + continue + # Criteria on the ranking of configurations + # 1. ISA: AMX > VEC + # 2. Dividable by block sizes (block_m, block_n, block_k) + # 3. Number of mxn blocks is large enough to occupy all the threads + # 4. 
Register blocks are larger + isa_score = 0 + if config.vec_isa_cls == VecAMX: + isa_score += 1 + dividable_score = 0 + if m % block_m == 0: + dividable_score += 1 + if n % block_n == 0: + dividable_score += 1 + if k % block_k == 0: + dividable_score += 1 + occupancy_score = 0 + n_blocks = (n + block_n - 1) // block_n + total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m) + if n_blocks >= num_threads: + occupancy_score += 1 + if total_mxn_blocks >= num_threads: + occupancy_score += 1 + register_bytes = ( + block_m * block_n * config.compute_dtype.itemsize + + (block_m * block_k + block_k * block_n) + * config.input_dtype.itemsize + ) + matched_configs.append( + ( + (isa_score, dividable_score, occupancy_score, register_bytes), + cls, + config, + ) + ) + if len(matched_configs) == 0: + if use_ref: + return CppMicroGemmRef( + name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha + ) + else: + return None + # TODO(jgong5): allow autotuning on choices of configs + return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:]) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template.py new file mode 100644 index 0000000000000000000000000000000000000000..2bce16b2ad1538624be9c4e32893e32a412e7f46 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +import ctypes +import functools +import itertools +import logging +import sys +from typing import Callable, List, Optional +from unittest.mock import patch + +import sympy + +from .. import codecache, config, ir +from ..autotune_process import CppBenchmarkRequest, TensorMeta +from ..utils import IndentedBuffer, Placeholder, unique +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel + + +log = logging.getLogger(__name__) + + +class CppTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes, + layout: ir.Layout, + num_threads: int, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + ) -> None: + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: ir.Buffer = ir.Buffer("buf_out", layout) + self.layout = layout + self.num_threads = num_threads + self.epilogue_creator = epilogue_creator + + def generate(self, **kwargs): + kernel_name = f"cpp_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), patch.object(ir.FlexibleLayout, "allow_indexing", True), CppTemplateKernel( + kernel_name=kernel_name, num_threads=self.num_threads + ) as kernel: + code = kernel.render(self, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + expected_args = list( + unique(input_node.get_name() for input_node in self.input_nodes) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + # Cast the size hint from int to ctypes.c_ulonglong explicitly + # since in cpp kernel, we bind it to C long + extra_args = tuple(ctypes.c_ulonglong(x) for x in 
extra_args) + + kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}" + + # Create the BenchmarkRequest for CPP + bmreq = CppBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: ir.CppTemplateBuffer, + flag_template_buffer_has_other_users: bool, + epilogue_nodes: Optional[List[ir.IRNode]] = None, + ): + kernel = CppTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), num_threads=self.num_threads + ) + render = functools.partial( + kernel.render, + self, + template_buffer_node=template_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + return kernel, render + + return CppTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.writeline(codecache.cpp_prefix()) + res.splice( + """ + #include "c10/util/Unroll.h" + """ + ) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if enable_kernel_profile: + res.writelines(["#include <ATen/record_function.h>"]) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template_kernel.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..0333720bbdc689c8748c518ed71d2782f8c75602 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_template_kernel.py @@ -0,0 +1,384 @@ +# mypy: allow-untyped-defs +import itertools +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import sympy +from sympy.parsing.sympy_parser import parse_expr + +import torch +from torch.utils._sympy.symbol import SymT + +from .. 
import config, cpp_builder, ir, lowering as L +from ..autotune_process import CppBenchmarkRequest +from ..loop_body import LoopBody +from ..select_algorithm import PartialRender +from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix +from ..virtualized import V +from .common import CppWrapperKernelArgs +from .cpp import CppKernel, CppKernelProxy, KernelGroup +from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext +from .cpp_wrapper_cpu import CppWrapperCpu + + +def parse_expr_with_index_symbols(expr): + if isinstance(expr, sympy.Expr): + return expr + elif isinstance(expr, (list, tuple)): + return [parse_expr_with_index_symbols(e) for e in expr] + else: + expr = parse_expr(str(expr)) + int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} + return expr.subs(int_symbols) + + +def wrap_with_tensorbox(node) -> ir.TensorBox: + return ( + ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) + ) + + +class CppTemplateKernel(CppKernel): + def __init__(self, kernel_name, num_threads): + super().__init__(None, num_threads) + self.kernel_name = kernel_name + self.render_hooks = {} + self.local_buffers = {} + if isinstance(V.graph.wrapper_code, CppWrapperCpu): + self.args = CppWrapperKernelArgs() + + def render(self, template, **kwargs): + return PartialRender( + template.render(kernel=self, **kwargs), self.render_hooks + ).finalize_all() + + def def_kernel( + self, + inputs: Dict[str, ir.Buffer], + outputs: Dict[str, ir.Buffer], + aliases: Optional[Dict[str, str]] = None, + ) -> str: + for name, inp in inputs.items(): + if inp is not None: + self.args.input_buffers[inp.get_name()] = name + for name, out in outputs.items(): + self.args.output_buffers[out.get_name()] = name + if aliases is not None: + for alias, orig in aliases.items(): + if orig in self.args.input_buffers: + self.args.input_buffers[alias] = self.args.input_buffers[orig] + if orig in self.args.output_buffers: + self.args.output_buffers[alias] = self.args.output_buffers[orig] + + unique_sizevars = { + s + for input in inputs.values() + if input is not None + for sym in itertools.chain(input.get_size(), input.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + } + unique_sizevars |= { + s + for output in outputs.values() + for sym in itertools.chain(output.get_size(), output.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + } + sizevars = sorted(unique_sizevars, key=str) + for sizevar in sizevars: + self.args.sizevars[sizevar] = f"k{sizevar}" + + def hook(): + # remove all aliases before generating the function definition + if aliases is not None: + for alias in aliases: + if alias in self.args.input_buffers: + self.args.input_buffers[alias] = "REMOVED" + if alias in self.args.output_buffers: + self.args.output_buffers[alias] = "REMOVED" + cpp_argdefs, _, _ = self.args.cpp_argdefs() + return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" + + placeholder = "<DEF_KERNEL>" + assert placeholder not in self.render_hooks + self.render_hooks[placeholder] = hook + return placeholder + + def call_kernel(self, name: str, node: ir.CppTemplateBuffer): + wrapper = V.graph.wrapper_code + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) + + def dtype(self, node: ir.Buffer) -> str: + return DTYPE_TO_CPP[node.get_dtype()] + + def acc_dtype(self, node: ir.Buffer) -> str: + if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]: + return 
"float" + else: + raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + + def size(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_size()[dim])) + + def stride(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_stride()[dim])) + + def index(self, node: ir.Buffer, indices: List[Any]) -> str: + indexer = node.layout.as_fixed().make_indexer() + index = indexer(parse_expr_with_index_symbols(indices)) + index = self.rename_indexing(index) + outer_name = node.get_name() + inner_name = ( + outer_name + if outer_name in self.local_buffers + else self.args.input(node.get_name()) + ) + return f"{inner_name}[{cexpr_index(index)}]" + + def slice_nd(self, node, ranges: List[Tuple[Any, Any]]) -> ir.ReinterpretView: + """ + Slice the given node with a list of ranges (start and end) corresponding to its dims. + The dim is not sliced if the corresponding range is empty. + """ + assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}" + sliced = wrap_with_tensorbox(node) + for dim, _range in enumerate(ranges): + if len(_range) == 0: + continue + assert len(_range) == 2 + start, end = parse_expr_with_index_symbols(_range) + sliced = L.slice_(sliced, dim, start, end, clamp=False) + assert isinstance(sliced.data, ir.ReinterpretView), sliced.data + return sliced.data + + def view(self, node, sizes: List[Any]) -> ir.View: + node = wrap_with_tensorbox(node) + sizes = parse_expr_with_index_symbols(sizes) + return L.view(node, sizes).data + + def permute(self, node, dims): + node = wrap_with_tensorbox(node) + permuted = L.permute(node, dims).data + assert isinstance(permuted, ir.ReinterpretView) + return permuted + + def maybe_codegen_profile(self) -> str: + if config.cpp.enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef({{}}));' + else: + return "" + + def unroll_pragma(self, unroll): + if cpp_builder.is_gcc(): + return f"#pragma GCC unroll {unroll}" + else: + return f"#pragma unroll {unroll}" + + def define_buffer(self, name, sizes: List[Any], dtype=torch.float) -> str: + """Define kernel local buffer""" + sizes = parse_expr_with_index_symbols(sizes) + buf = ir.Buffer(name, ir.FixedLayout(torch.device("cpu"), dtype, sizes)) + self.local_buffers[name] = buf + ctype = f"{DTYPE_TO_CPP[dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();" + + def reinit_buffer_if_null(self, name): + """Reinit the previously defined local buffer if it is null""" + assert name in self.local_buffers + buf = self.local_buffers[name] + ctype = f"{DTYPE_TO_CPP[buf.layout.dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"if (_{name} == nullptr) {{ _{name} = std::make_unique<{ctype}[]>({numel}); {name} = _{name}.get(); }}" + + def release_buffer(self, name): + """Codegen the code to release the ownership of a local buffer to others""" + assert name in self.local_buffers + return f"_{name}.release()" + + def store_pointwise_nodes( + self, + dst: ir.Buffer, + nodes: List[ir.IRNode], + offsets: Optional[List[sympy.Expr]] = None, + reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None, + ) -> str: + var_sizes = (tuple(dst.get_size()), ()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes[0]) + } + if not 
offsets: + offsets = [sympy.Integer(0)] * len(var_sizes[0]) + if not reindexers: + reindexers = [None] * len(nodes) + assert len(offsets) == len(var_sizes[0]) + output_index = dst.get_layout().make_indexer()(var_ranges.keys()) + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(nodes): + output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name() + node = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(node, ir.Pointwise), node + + def fn(*args): + assert len(args) == 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] + if reindexers[i] is not None: + new_args = reindexers[i](new_args) # type: ignore[misc] + V.ops.store( + output_name, + output_index, + node.make_loader()(new_args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys()), ()), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() + + def store_output( + self, + dst: ir.Buffer, + src: ir.Buffer, + orig_src: Optional[ir.Buffer] = None, + epilogue_nodes: Optional[List[ir.IRNode]] = None, + offsets: Optional[List[Any]] = None, + reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None, + ): + """ + Store the `src` buffer to the `dst` buffer. The sizes of `src` and `dst` should match. + If `epilogue_nodes` is provided, the `src` buffer is first computed with the epilogues + before being stored to `dst`. The `epilogue_nodes` are all pointwise. + + Notes: + 1. `src` and `dst` could be the same buffer, in which case we are doing in-place compute + and stores. If `epilogue_nodes` are not provided, we do nothing. + 2. The `epilogue_nodes`, if they exist, have computations on `src` before storing to `dst`, but since + they come from the original Inductor IR, they might need to be adjusted before working with + `src` and `dst` as outlined below: + a) `src` or `dst` could be a sub-slice of the ranges the `epilogue_nodes` work on. + In this case, `offsets` can be provided to adjust the indices passed to + `epilogue_nodes` during codegen, and the data ranges are also configured according to + the sizes of `src` and `dst`. + b) `dst` might be indexed in a different way than the `epilogue_nodes`, hence a `reindexer` is + needed on the indices to `epilogue_nodes` to match the indexing of `dst`. + c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer + in `epilogue_nodes` with `src`. 
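+ + Illustrative call (a sketch with hypothetical names; `Y_slice`, `acc`, + `buf_out`, `relu_node`, `m_start` and `n_start` are not upstream code): + + kernel.store_output( + dst=Y_slice, # slice of the global output buffer + src=acc, # kernel-local accumulator + orig_src=buf_out, # buffer the epilogues were written against + epilogue_nodes=[relu_node], # pointwise epilogues + offsets=[m_start, n_start], # map local indices to global ones + reindexers=[None], + )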
+ """ + assert dst.get_size() == src.get_size(), f"{dst=}, {src=}" + if offsets: + offsets = parse_expr_with_index_symbols(offsets) + if epilogue_nodes: + with LocalBufferContext(self.args) as scope: + assert orig_src is not None + if orig_src.get_name() != src.get_name(): + scope.add_local_buffer( + src, + [ + orig_src, + ], + ) + epilogue_nodes = scope.localize_nodes(epilogue_nodes) + return self.store_pointwise_nodes( + dst, epilogue_nodes, offsets, reindexers # type: ignore[arg-type] + ) + else: + if dst.get_name() != src.get_name(): + # src is local + copy = L.copy(dst, src).data.data + with LocalBufferContext(self.args) as scope: + scope.add_local_buffer(src) + return self.store_pointwise_nodes(dst, [copy]) + else: + assert dst.layout == src.layout, f"{dst=}, {src=}" + return "" + + +class CppTemplateCaller(ir.ChoiceCaller): + """ + CppTemplateCaller + + This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CppBenchmarkRequest): The benchmark request for the caller. + template_buffer (ir.CppTemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: List[ir.Buffer], + layout: ir.Layout, + make_kernel_render: Callable[ + [ + ir.CppTemplateBuffer, + bool, + Optional[List[ir.IRNode]], + ], + str, + ], + bmreq: CppBenchmarkRequest, + template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 + info_kwargs: Optional[ + Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]] + ] = None, + ): + super().__init__(name, input_nodes, layout) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict( + self, + ) -> Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]]: + return {"backend": "CPP", "op_type": "unknown"} + + def output_node(self) -> ir.TensorBox: + return ir.TensorBox.create( + ir.CppTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + template=self.template, + choice=self, + ) + ) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69bb3637f724936da1255c777fb85c96500f226 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_utils.py @@ -0,0 +1,916 @@ +# mypy: allow-untyped-defs +import contextlib +import copy +import functools +import math +import sys +from collections import namedtuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from unittest.mock import patch + +import sympy + +import torch +from torch._prims_common import is_integer_dtype +from torch.utils._sympy.symbol import symbol_is_type, SymT +from torch.utils._sympy.value_ranges import ValueRanges + +from .. 
import ir +from ..loop_body import LoopBody +from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs +from ..virtualized import ops, OpsValue, V +from .common import ( + CSEVariable, + deduce_output_dtype_by_name, + ExprPrinter, + Kernel, + KernelArgs, + OptimizationContext, +) + + +DTYPE_TO_CPP = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "half", + torch.int64: "int64_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint64: "uint64_t", + torch.uint32: "uint32_t", + torch.uint16: "uint16_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "bfloat16", + torch.complex64: "c10::complex<float>", + torch.float8_e4m3fn: "float8_e4m3fn", + torch.float8_e5m2: "float8_e5m2", +} + +DTYPE_TO_ATEN = { + torch.float32: "at::kFloat", + torch.float64: "at::kDouble", + torch.float16: "at::kHalf", + torch.int64: "at::kLong", + torch.int32: "at::kInt", + torch.int16: "at::kShort", + torch.int8: "at::kChar", + torch.uint64: "at::kUInt64", + torch.uint32: "at::kUInt32", + torch.uint16: "at::kUInt16", + torch.uint8: "at::kByte", + torch.bool: "at::kBool", + torch.bfloat16: "at::kBFloat16", + torch.complex32: "at::kComplexHalf", + torch.complex64: "at::kComplexFloat", + torch.complex128: "at::kComplexDouble", + torch.float8_e4m3fn: "at::kFloat8_e4m3fn", + torch.float8_e5m2: "at::kFloat8_e5m2", + torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", +} + +DEVICE_TO_ATEN = { + "cpu": "at::kCPU", + "cuda": "at::kCUDA", +} + +LAYOUT_TO_ATEN = { + torch.strided: "at::kStrided", + torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined] +} + +_IS_WINDOWS = sys.platform == "win32" + +INDEX_TYPE = "int64_t" + +GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) + + +def get_promote_dtype(args): + return ( + functools.reduce( + torch.promote_types, # type: ignore[arg-type] + [n.dtype for n in args if isinstance(n, CppCSEVariable)], + ) + if all(n.dtype is not None for n in args if isinstance(n, CppCSEVariable)) + else None # not enough info to calculate the promote dtype + ) + + +def promote_args(new_args): + def promote_arg(arg, promote_type): + if ( + isinstance(arg, CppCSEVariable) + and arg.dtype + and promote_type + and arg.dtype != promote_type + ): + arg = ops.to_dtype(arg, promote_type) + arg = arg.value if isinstance(arg, OpsValue) else arg + arg.dtype = promote_type + return arg + + promote_type = get_promote_dtype(new_args) + promote_fn = functools.partial( + promote_arg, + promote_type=promote_type, + ) + if ( + all( + new_arg.dtype is not None + for new_arg in new_args + if isinstance(new_arg, CppCSEVariable) + ) + and promote_type + ): + new_args = list(map(promote_fn, new_args)) + return new_args + + +def get_opt_ctx(node: torch.fx.Node) -> OptimizationContext: + return node.meta.get(OptimizationContext.key, None) + + +def get_current_node_opt_ctx() -> OptimizationContext: + assert V.interpreter.current_node + return get_opt_ctx(V.interpreter.current_node) + + +def deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs): + if ( + output_dtype := deduce_output_dtype_by_name( + name, + *args, + **kwargs, + ) + ) is not None: + return output_dtype + elif name == "masked": + # Leslie: perhaps we can also deduce the masked dtype from the + # inputs' CppCSEVariable like the other ops; revisit if any + # unexpected failures show up. 
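+ # For example (illustrative, not upstream code): an ops.masked call lowers + # to a masked_subblock0 FX node whose OptimizationContext records the dtype + # read below, while the generic branch at the end of this function falls + # back to type promotion over the inputs, e.g. + # torch.promote_types(torch.bfloat16, torch.float32) == torch.float32.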
+ assert ( + hasattr(V.interpreter, "current_node") + and V.interpreter.current_node.target.startswith("masked_subblock") + and get_current_node_opt_ctx() is not None + ) + return get_current_node_opt_ctx().dtype + else: + # deduce output dtype by inputs' dtype + assert all( + arg.dtype is not None for arg in args if isinstance(arg, CppCSEVariable) + ) + return functools.reduce( + torch.promote_types, # type: ignore[arg-type] + [arg.dtype for arg in args if isinstance(arg, CppCSEVariable)], + ) + + +class CppCSEVariable(CSEVariable): + def __init__(self, name, bounds: ValueRanges[Any]) -> None: + super().__init__(name, bounds) + self.is_vec = False + self.dtype: Optional[torch.dtype] = None + self.dependent_itervars: Set[sympy.Symbol] = set() + + def __repr__(self) -> str: + return ( + f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, " + f"dependent_itervars: {self.dependent_itervars})" + ) + + def update_on_args(self, name, args, kwargs): + if name == "load": + # args[2] is index + self._set_dependent_itervars(args[2]) + else: + # propagate relevant itervars and is_vec from args + self.dependent_itervars.update( + *[ + arg.dependent_itervars + for arg in args + if isinstance(arg, CppCSEVariable) + ] + ) + if name == "index_expr": + self._set_dependent_itervars(args[0]) + if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): + self.is_vec = True + # NOTE [Deduce dtype of CppCSEVariable at runtime] + self.dtype = deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs) + assert self.dtype is not None + + def _set_dependent_itervars(self, index: sympy.Expr): + """ + Set the relevant itervars for this variable based on the `index` expression. + This includes the itervars directly used in the `index` as well as relevant itervars + of other cse variables used in the `index`. + """ + for s in index.free_symbols: + if s in V.kernel.itervars: + self.dependent_itervars.add(s) # type: ignore[arg-type] + elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined] + self.dependent_itervars.update( + V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined] + ) + + def depends_on(self, itervar: sympy.Symbol): + return itervar in self.dependent_itervars + + +class CppPrinter(ExprPrinter): + def _print_Integer(self, expr): + return ( + f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L" + ) + + def _print_Where(self, expr): + c = self.paren(self.doprint(expr.args[0])) + p = self.paren(self.doprint(expr.args[1])) + q = self.paren(self.doprint(expr.args[2])) + return f"{c} ? 
{p} : {q}" + + def _print_ModularIndexing(self, expr): + x, div, mod = expr.args + x = self.paren(self.doprint(x)) + if div != 1: + div = self.paren(self.doprint(div)) + if expr.is_integer: + x = f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" + else: + x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + mod = self.paren(self.doprint(mod)) + return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})" + + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + if expr.is_integer: + return f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" + return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_FloorToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::trunc({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" + + def _print_TruncToFloat(self, expr): + assert len(expr.args) == 1 + return f"std::trunc({self._print(expr.args[0])})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"static_cast({self._print(expr.args[0])})" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_CMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**53 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + # TODO: PowByNatural: we need to implement our own int-int pow. 
Do NOT + # use std::pow, that operates on floats + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _print_FloatPow(self, expr): + base, exp = expr.args + return f"std::pow({self._print(base)}, {self._print(exp)})" + + def _print_Pow(self, expr): + # Uses float constants to perform FP div + base, exp = expr.args + base = self._print(base) + + if exp == 0.5 or exp == -0.5: + return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" + if exp.is_integer: + exp = int(exp) + if exp > 0: + r = "*".join([self.paren(base)] * exp) + elif exp < 0: + r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + r = "1.0" + + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + else: + # TODO: float vs double + return f"std::pow({base}, {float(exp)})" + + def _print_Rational(self, expr): + # Uses float constants to perform FP div + if expr.q == 1: + r = f"{expr.p}" + else: + r = f"{expr.p}.0/{expr.q}.0" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Min(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::min({il})" + + def _print_Max(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::max({il})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"std::abs({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"std::cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"std::cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"std::acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert len(expr.args) == 1 + return f"std::sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"std::sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"std::asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"std::tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"std::tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"std::atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sqrt(self, expr): + return 
f"std::sqrt({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr): + assert len(expr.args) == 1 + # TODO: dispatch to llrint depending on index type + return f"std::lrint({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + return f"static_cast(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})" + + def _print_BooleanTrue(self, expr): + return "true" + + def _print_BooleanFalse(self, expr): + return "false" + + +# A function to print, useful for printing sympy symbols. +cexpr = CppPrinter().doprint + + +def cexpr_index(index): + return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" + + +def value_to_cpp(value, cpp_type): + if value == float("-inf"): + return f"-std::numeric_limits<{cpp_type}>::infinity()" + elif value == float("inf"): + return f"std::numeric_limits<{cpp_type}>::infinity()" + elif isinstance(value, bool): + return f"static_cast<{cpp_type}>({str(value).lower()})" + elif math.isnan(value): + return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" + else: + return f"static_cast<{cpp_type}>({repr(value)})" + + +def rewrite_index_for_function( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, + global_buf_name: str, +): + # Local buffer at the inner dimensions + snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op + local_buf = localize_buffer_handler.global_to_local[global_buf_name] + scheduler_nodes = snode.get_nodes() + _, (group, reduction_group) = max( + scheduler_nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + indices_to_keep = [ + f"x{len(call_ranges) - (idx + 1)}" + for idx in range(len(local_buf.get_layout().size)) + ] + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined] + replacements = {} + for x in sorted_symbols: + if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined] + # Only keep index used by local buffer + replacements[x] = sympy.core.numbers.Zero() + index = sympy_subs(index, replacements) # type: ignore[arg-type] + return index + + +def rewrite_index_for_nodes( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, + global_buf_name: str, +): + used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)} + index_vars = [] + local_buf = localize_buffer_handler.global_to_local[global_buf_name] + for i in range(len(local_buf.get_size())): + var = sympy_index_symbol_with_prefix(SymT.INDEX, i) + index_vars.append(var if var in used_vars else 0) + index = local_buf.layout.make_indexer()(index_vars) + return index + + +class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined] + def __init__( + self, + inner, + global_to_local: Dict[str, ir.Buffer], + rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr], + ) -> None: + super().__init__(inner) + self.global_to_local = global_to_local + self.rewrite_index = rewrite_index + + def localize(self, name: str, index: sympy.Expr): + if self.global_to_local and name in self.global_to_local: + assert self.rewrite_index is not None + index = self.rewrite_index(self, index, name) + name = self.global_to_local[name].get_name() + return 
name, index + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(*self.localize(name, index)) + + def store(self, name, index, value, mode=None): + local_buffer_name, local_buffer_index = self.localize(name, index) + res = self._inner.store(local_buffer_name, local_buffer_index, value, mode) + if ( + self.global_to_local + and name in self.global_to_local + and isinstance(V.kernel, Kernel) + ): + # Remove name of local buffer from Kernel.store_buffer_names + # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store. + V.kernel.store_buffer_names.discard(local_buffer_name) + return res + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(*self.localize(name, index), value) + + +class LocalBufferContext: + """ + This class creates a context that helps to generate code involving Inductor IR with + function local buffers. These buffers are constructed during the codegen process and + are used to store intermediate results such as local accumulators. We do not want to + add them to `V.graph` since they are not global and we do not want to add them as + function arguments either. So we patch the codegen processes under this scope to support + these buffers without exposure to the outside world. + """ + + def __init__(self, kernel_args: KernelArgs) -> None: + self.kernel_args = kernel_args + self.exit_stack = contextlib.ExitStack() + # map local buffer name to local buffer + self.local_buffers: Dict[str, ir.Buffer] = {} + # map global buffer name to global buffer + self.global_buffers: Dict[str, ir.Buffer] = {} + # map global buffer name to local buffer + self.global_to_local: Dict[str, ir.Buffer] = {} + + def __enter__(self): + self.exit_stack.__enter__() + original_get_dtype = V.graph.get_dtype + + def get_dtype(name): + if name in self.local_buffers: + return self.local_buffers[name].get_dtype() + return original_get_dtype(name) + + self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) + + original_input = self.kernel_args.input + + def input(name): + if name in self.local_buffers: + return name + return original_input(name) + + self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input)) + + original_output = self.kernel_args.output + + def output(name): + if name in self.local_buffers: + return name + return original_output(name) + + self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output)) + + # Set current LocalBufferContext into V + self.exit_stack.enter_context(V.set_local_buffer_context(self)) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.local_buffers.clear() + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def add_local_buffer( + self, local_buffer: ir.Buffer, global_buffers: Optional[List[ir.Buffer]] = None + ): + assert local_buffer.get_name() not in self.local_buffers + self.local_buffers[local_buffer.get_name()] = local_buffer + if global_buffers: + for global_buffer in global_buffers: + global_buffer_name = global_buffer.get_name() + assert ( + global_buffer_name not in self.global_buffers + and global_buffer_name not in self.global_to_local + ) + self.global_buffers[global_buffer_name] = global_buffer + self.global_to_local[global_buffer_name] = local_buffer + V.graph.removed_buffers.add(global_buffer_name) + + def localize_function( + self, + fn: Callable[..., Any], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr + ] = rewrite_index_for_function, + ): + def 
inner(*args, **kwargs): + with V.set_ops_handler( + LocalizeBufferHandler( + V.get_ops_handler(), + global_to_local=self.global_to_local, + rewrite_index=rewrite_index, + ) + ): + return fn(*args, **kwargs) + + return inner + + def localize_nodes( + self, + nodes: List[ir.IRNode], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr + ] = rewrite_index_for_nodes, + ) -> List[ir.IRNode]: + """ + Given `local_buf` and `global_buf` registered in the current `LocalBufferContext` + through the method `add_local_buffer`, localizes the `global_buf` to `local_buf` + for the given `nodes` and returns a new list of IR nodes that work on `local_buf` + instead of `global_buf`, i.e., all the loads and stores are redirected to + `local_buf`. This helps the fused loops to work on smaller-sized local buffers + for better data locality. + + The data access of `local_buf` is assumed to be contiguous, with the + same order as the `global_buf`. + """ + assert len(nodes) > 0 + + def wrap_inner_fn_for_node(node: ir.IRNode): + loops = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(loops, ir.Loops) + new_loops = copy.copy(loops) + if isinstance(node, ir.ComputedBuffer): + new_node = ir.ComputedBuffer( + node.get_name(), node.get_layout(), new_loops + ) + else: + new_node = new_loops # type: ignore[assignment] + + new_loops.inner_fn = self.localize_function( + new_loops.inner_fn, + rewrite_index, + ) + return new_node + + return [wrap_inner_fn_for_node(node) for node in nodes] + + +def unify_mask_base_type( + buffer: IndentedBuffer, + vars: Tuple[CSEVariable, ...], + dtype=torch.float, +): + """ + Given a list of cse variables, cast each to the new mask base dtype + and return the casted cse variables. + """ + new_vars = ( + V.kernel.cse.generate( + buffer, + f"{V.kernel._get_mask_cast(var, dtype)}", + ) + for var in vars + ) + return new_vars + + +def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32): + assert is_integer_dtype(offset.dtype) + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"{DTYPE_TO_CPP[offset.dtype]} offset[{V.kernel.tiling_factor}];" + ) + code.writeline(f"{DTYPE_TO_CPP[dst_dtype]} result[{V.kernel.tiling_factor}];") + code.writeline(f"{offset}.store(offset);") + code.writeline( + f"for( {DTYPE_TO_CPP[offset.dtype]} offset_idx = 0; offset_idx < {V.kernel.tiling_factor}; offset_idx++ )" + ) + with code.indent(): + code.writeline(rand_function) + num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype) + if num_vectors == 1: + code.writeline( + f"return at::vec::Vectorized<{DTYPE_TO_CPP[dst_dtype]}>::loadu(result);" + ) + else: + code.writeline( + f"return at::vec::VectorizedN<{DTYPE_TO_CPP[dst_dtype]}, {num_vectors}>::loadu(result);" + ) + code.writeline("()") + return code + + +def get_gemm_template_output_and_compute_dtype(input_dtype): + if input_dtype == torch.uint8: + return (torch.int32, torch.int32) + else: + return (torch.float32, torch.float32) + + +def create_epilogue_with_attr(input_buffer, attr, **kwargs): + input_loader = input_buffer.make_loader() + dtype = input_buffer.get_dtype() + if attr == "relu": + + def inner_fn(index): + input = input_loader(index) + zero = ops.constant(0, dtype) + return ops.maximum(input, zero) + + elif attr == "gelu": + assert "algorithm" in kwargs + if kwargs["algorithm"] == "none": + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + half = ops.constant(0.5, torch.float) + one = 
ops.constant(1.0, torch.float) + const = ops.constant(0.7071067811865476, torch.float) + result = input * half * (ops.erf(input * const) + one) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + else: + assert kwargs["algorithm"] == "tanh" + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + half = ops.constant(0.5, torch.float) + one = ops.constant(1.0, torch.float) + const1 = ops.constant(0.7978845608028654, torch.float) + const2 = ops.constant(0.044715, torch.float) + result = ( + half + * input + * ( + one + + ops.tanh(const1 * (input + const2 * input * input * input)) + ) + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "swish": + + def inner_fn(index): + input = input_loader(index) + result = input * ops.sigmoid(input) + return result + + elif attr == "sigmoid": + + def inner_fn(index): + return ops.sigmoid(input_loader(index)) + + elif attr == "tanh": + + def inner_fn(index): + return ops.tanh(input_loader(index)) + + elif attr == "hardswish" or attr == "hardsigmoid": + + def hardsigmoid_float(input): + zero = ops.constant(0, torch.float) + six = ops.constant(6, torch.float) + three = ops.constant(3, torch.float) + one_over_six = ops.constant(0.16666666666666666, torch.float) + max = ops.maximum(input + three, zero) + min = ops.minimum(max, six) + return min * one_over_six + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + result = hardsigmoid_float(input) + if attr == "hardswish": + result = input * result + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "leaky_relu": + assert "scalars" in kwargs + assert len(kwargs["scalars"]) == 1 + negative_slope = kwargs["scalars"][0] + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + zero = ops.constant(0, torch.float) + result = ops.where( + input > zero, input, input * ops.constant(negative_slope, torch.float) + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "hardtanh": + assert "scalars" in kwargs + assert len(kwargs["scalars"]) == 2 + min_value = kwargs["scalars"][0] + max_value = kwargs["scalars"][1] + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + result = ops.minimum( + ops.maximum(input, ops.constant(min_value, torch.float)), + ops.constant(max_value, torch.float), + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr in ["add", "sub", "mul"]: + assert "other" in kwargs + other = kwargs["other"] + num_input_dims = len(input_buffer.get_size()) + num_other_dims = len(other.get_size()) + dims_diff = num_input_dims - num_other_dims + other_loader = other.make_loader() + + def inner_fn(index): + op = getattr(ops, attr) + if dims_diff != 0: + return op(input_loader(index), other_loader(index[dims_diff:])) + else: + return op(input_loader(index), other_loader(index)) + + elif attr == "bias_add": + assert "other" in kwargs + assert "beta" in kwargs + assert "dtype" in kwargs + beta = kwargs["beta"] + other = kwargs["other"] + dtype = kwargs["dtype"] + bias_loader = other.make_loader() + + def inner_fn(index): + bias = bias_loader(index) + input = input_loader(index) + if beta != 1: + result = ops.constant(beta, 
torch.float) * bias + input + else: + result = bias + input + return result + + else: + raise ValueError(f"Unsupported epilogue attribute: {attr}") + return ir.Pointwise( + device=input_buffer.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + +def _get_loop_body(fn_list): + if all(isinstance(fn, LoopBody) for fn in fn_list): + loop_bodies = fn_list + else: + if hasattr(fn_list[0], "original_fn"): + # For the case of local buffer, we wrap the fn with localize_function + assert all(hasattr(fn, "original_fn") for fn in fn_list) + assert all( + isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list + ) + loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] + else: + assert all(isinstance(fn, functools.partial) for fn in fn_list) + assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list) + loop_bodies = [fn.args[0]._body for fn in fn_list] + assert loop_bodies is not None + return loop_bodies + + +def _get_dtype_from_loopbodies(loop_bodies): + dtypes = set() + for loop_body in loop_bodies: + graphs = [loop_body.root_block.graph] + [ + body.graph for body in list(loop_body.subblocks.values()) + ] + for graph in graphs: + for node in graph.nodes: + if node.op != "call_method": + continue + dtypes.add(node.meta[OptimizationContext.key].dtype) + return dtypes diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..290787b99b9930c88fb9bd6d8316c5fcf60ea9d4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -0,0 +1,2595 @@ +# mypy: allow-untyped-defs +import functools +import math +import os +import sys +from itertools import count +from typing import Dict, List, Optional, Tuple + +import sympy +from sympy import Expr + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._ops +from torch._inductor.codegen.debug_utils import IntermediateValueDebuggingLevel +from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes + +from .. 
import config, ir +from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product +from ..virtualized import V +from .aoti_hipify_utils import maybe_hipify_code_wrapper +from .common import IndentedBuffer +from .cpp_utils import ( + cexpr, + DEVICE_TO_ATEN, + DTYPE_TO_ATEN, + DTYPE_TO_CPP, + LAYOUT_TO_ATEN, +) +from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen + + +class CppWrapperCpu(WrapperCodeGen): + """ + Generates cpp wrapper for running on CPU and calls cpp kernels + """ + + def __init__(self): + if not hasattr(self, "device"): + self.device = "cpu" + super().__init__() + self.declare = "auto " + self.declare_maybe_reference = "decltype(auto) " + self.ending = ";" + self.open_bracket = "{" + self.closed_bracket = "}" + self.comment = "//" + self.namespace = "at::" + self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()" + self.extern_call_ops = set() + self.size = "sizes()" + self.stride = "strides()" + self.cuda = False + self.supports_intermediate_hooks = False + self.outputs_need_copy = set() + self.kernel_callsite_id = count() + self.var_array_id = ( + count() + ) # for different types of local array variable declarations + self.declared_var_array_vars = set() + self.int_array_id = count() # for int array local variable declarations + self.declared_int_array_vars = set() + self.tmp_tensor_id = count() # for tmp tensor local variable declarations + self.arg_var_id = count() + self.used_cached_devices = set() + self.used_cached_dtypes = set() + self.used_cached_layouts = set() + self.cached_output_id = count() + self.scalar_to_tensor_id = count() + self.custom_op_wrapper_loaded = False + self.expr_printer = cexpr + + def generate_kernel_call( + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + cuda=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", + ): + """ + Generates kernel call code. + + cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + if cuda: + return super().generate_kernel_call( + kernel_name, + call_args, + grid, + device_index, + cuda, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, + ) + else: + if config.abi_compatible: + assert arg_types is not None and len(call_args) == len( + arg_types + ), "Mismatch call_args and arg_types in generate_kernel_call" + new_args = [] + for idx, arg in enumerate(call_args): + if "*" in arg_types[idx]: + var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"auto* {var_name} = get_data_ptr_wrapper({arg});" + ) + new_args.append(f"({arg_types[idx]})({var_name})") + else: + # arg is a scalar + new_args.append(arg) + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + else: + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + + def write_constant(self, name, hashed): + # include a hash so our code cache gives different constants different files + self.header.writeline(f"// {name} {hashed}") + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. 
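+ # (The main module emits the full header and splices this const graph's + # wrapper code into itself; see the V.graph.const_module handling in + # write_wrapper_decl.)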
+ return + + if V.graph.aot_mode: + for header_cpp_file in ("interface.cpp", "implementation.cpp"): + with open( + os.path.join( + os.path.dirname(__file__), "aoti_runtime", header_cpp_file + ) + ) as f: + self.header.splice(f.read()) + else: + self.header.splice( + """ + import torch + from torch._inductor.codecache import CppWrapperCodeCache + + cpp_wrapper_src = ( + ''' + """ + ) + + if config.abi_compatible: + self.header.splice( + f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>" + ) + self.header.splice( + """ + #include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h> + #include <torch/csrc/inductor/aoti_runtime/thread_local.h> + #include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h> + """ + ) + if V.graph.aot_mode: + self.header.splice( + """ + #include <torch/csrc/inductor/aoti_runtime/model.h> + """ + ) + else: + self.header.splice( + """ + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #define reinterpret_tensor torch::inductor::_reinterpret_tensor + #define alloc_from_pool torch::inductor::_alloc_from_pool + """ + ) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if config.profiler_mark_wrapper_call or enable_kernel_profile: + self.header.splice("#include <ATen/record_function.h>") + + self.header.splice("typedef at::Half half;") + self.header.splice("typedef at::BFloat16 bfloat16;") + self.header.splice("#include <c10/util/generic_math.h>") + + if not V.graph.aot_mode: + self.header.splice( + """ + #include <pybind11/pybind11.h> + + namespace py = pybind11; + using namespace torch::aot_inductor; + + class RAIIPyObject { + public: + RAIIPyObject() : obj_(nullptr) {} + RAIIPyObject(PyObject* obj) : obj_(obj) {} + ~RAIIPyObject() { + Py_XDECREF(obj_); + } + RAIIPyObject& operator=(const RAIIPyObject& other) { + if (this != &other) { + Py_XDECREF(obj_); + obj_ = other.obj_; + Py_XINCREF(obj_); + } + return *this; + } + operator PyObject*() { + return obj_; + } + PyObject* get() { + return obj_; + } + private: + PyObject* obj_; + }; + """ + ) + + # Round up to the nearest multiple of ALIGN_BYTES + # ALIGN_BYTES must be a power of 2 + self.header.splice( + f""" + [[maybe_unused]] static int64_t align(int64_t nbytes) {{ + return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES}; + }} + """ + ) + + @functools.lru_cache(None) # noqa: B019 + def include_extra_header(self, header: str): + # This is needed for cpp to python dtype conversion + self.header.splice(f"#include <{header}>") + + def mark_output_type(self): + # mark output type to unwrap tensor back to python scalar + from ..ir import ShapeAsConstantBuffer + + output_is_tensor = {} + for idx, x in enumerate(V.graph.graph_outputs): + if isinstance(x, ShapeAsConstantBuffer): + output_is_tensor[idx] = False + else: + output_is_tensor[idx] = True + + self.output_is_tensor = output_is_tensor + + def write_prefix(self): + if V.graph.is_const_graph: + # We do not write prefix for constant graph, it will be written by main module. 
+ return + + if V.graph.aot_mode: + self.prefix.writeline("namespace torch {") + self.prefix.writeline("namespace aot_inductor {") + + def write_input_output_info( + self, + info_kind: str, + idx: int, + name: str, + ): + self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""") + + @staticmethod + def get_input_cpp_type(input): + assert config.use_minimal_arrayref_interface + + if isinstance(input, sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype(input) + assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}" + return DTYPE_TO_CPP[dtype] + return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" + + def generate_input_output_runtime_checks(self): + # In debug_compile mode, we generate checks to ensure the dtype/shape/stride of each + # real input/output tensor match ones provided at compile time via sample + # input/output. + def gen_check(handle_kind, idx, name, tensor): + self.prefix.writeline(f"auto {name} = {handle_kind}[{idx}];") + self.codegen_tensor_dtype_var_decl(self.prefix, name) + expected_dtype_name = DTYPE_TO_ATEN[tensor.dtype] + dtype_str = str(tensor.dtype).split(".")[-1] + self.prefix.splice( + f""" + int32_t {name}_expected_dtype = aoti_torch_dtype_{dtype_str}(); + if ({name}_expected_dtype != {name}_dtype) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched dtype, " + << "expected: " << {name}_expected_dtype << "({expected_dtype_name}), " + << "but got: " << {name}_dtype << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + self.codegen_input_size_var_decl(self.prefix, name) + for dim_idx, d in enumerate(tensor.get_size()): + if isinstance(d, (int, sympy.Integer)): + self.prefix.splice( + f""" + if ({d} != {name}_size[{dim_idx}]) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched dim value at {dim_idx}, " + << "expected: {d}, " << "but got: " << {name}_size[{dim_idx}] + << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + else: + from torch.utils._sympy.value_ranges import bound_sympy + + sym_range = bound_sympy(d, V.graph.sizevars.shape_env.var_to_range) + if not math.isinf(sym_range.lower): + self.prefix.splice( + f""" + if ({name}_size[{dim_idx}] < {sym_range.lower}) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: dim value is too small at {dim_idx}, " + << "expected it to be >= {sym_range.lower}, " << "but got: " + << {name}_size[{dim_idx}] << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + if not math.isinf(sym_range.upper): + self.prefix.splice( + f""" + if ({name}_size[{dim_idx}] > {sym_range.upper}) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, " + << "expected to be <= {sym_range.upper}, " << "but got: " + << {name}_size[{dim_idx}] << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + self.codegen_input_stride_var_decl(self.prefix, name) + for stride_idx, s in enumerate(tensor.get_stride()): + if not isinstance(s, (int, sympy.Integer)): + continue + self.prefix.splice( + f""" + if ({s} != {name}_stride[{stride_idx}]) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched stride value at {stride_idx}, " + << "expected: {s}, " << "but got: " << {name}_stride[{stride_idx}] + << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + # force noinline to avoid any potential compilation slowdown due to aggressive + # inline done by the host compiler + self.prefix.splice( + """ + AOTI_NOINLINE static void 
__check_inputs_outputs(
+                AtenTensorHandle* input_handles,
+                AtenTensorHandle* output_handles) {
+            """
+        )
+        with self.prefix.indent():
+            for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()):
+                gen_check("input_handles", idx, name, tensor)
+        self.prefix.writeline("}")
+
+    def write_wrapper_decl(self):
+        inputs_len = len(V.graph.graph_inputs.keys())
+        if V.graph.aot_mode:
+            if config.use_minimal_arrayref_interface and not V.graph.is_const_graph:
+                input_cpp_types = ", ".join(
+                    f"{CppWrapperCpu.get_input_cpp_type(x)}"
+                    for x in V.graph.graph_inputs.values()
+                )
+                output_arrayref_types = ", ".join(
+                    f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
+                    for x in V.graph.graph_outputs
+                )
+
+                self.prefix.splice(
+                    f"""
+                    using AOTInductorModelInputs = std::tuple<{input_cpp_types}>;
+                    using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>;
+                    """
+                )
+
+            if V.graph.const_module:
+                self.header.splice(V.graph.const_module.wrapper_code.header)
+                self.prefix.splice(V.graph.const_code)
+
+            if V.graph.is_const_graph:
+                self.prefix.splice(
+                    """
+                    void AOTInductorModel::_const_run_impl(
+                        std::vector<AtenTensorHandle>& output_handles,
+                        DeviceStreamType stream,
+                        AOTIProxyExecutorHandle proxy_executor
+                    ) {
+                    """
+                )
+            else:
+                if not config.aot_inductor.use_runtime_constant_folding:
+                    # If we do not split the constant graph, we'll just create
+                    # an empty implementation when wrapping the main module.
+                    self.prefix.splice(
+                        """
+                        void AOTInductorModel::_const_run_impl(
+                            std::vector<AtenTensorHandle>& output_handles,
+                            DeviceStreamType stream,
+                            AOTIProxyExecutorHandle proxy_executor
+                        ) {}
+
+                        """
+                    )
+
+                run_impl_proto = """
+                    void AOTInductorModel::run_impl(
+                        AtenTensorHandle*
+                            input_handles, // array of input AtenTensorHandle; handles
+                                           // are stolen; the array itself is borrowed
+                        AtenTensorHandle*
+                            output_handles, // array for writing output AtenTensorHandle; handles
+                                            // will be stolen by the caller; the array itself is
+                                            // borrowed
+                        DeviceStreamType stream,
+                        AOTIProxyExecutorHandle proxy_executor
+                    ) {
+                    """
+                # Since we are removing non-abi-compatible mode, let's generate
+                # runtime checks only for abi_compatible mode to avoid extra branches.
+                if config.aot_inductor.debug_compile and config.abi_compatible:
+                    self.generate_input_output_runtime_checks()
+                    run_impl_proto += """
+                        __check_inputs_outputs(input_handles, output_handles);
+                        """
+                if config.use_minimal_arrayref_interface:
+                    self.prefix.splice(
+                        """
+                        template <>
+                        AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface<
+                          AOTInductorModelInputs, AOTInductorModelOutputs>(
+                            const AOTInductorModelInputs& inputs,
+                            DeviceStreamType stream,
+                            AOTIProxyExecutorHandle proxy_executor
+                        ) {
+                        """
+                    )
+                    self.suffix.splice(run_impl_proto)
+                    self.suffix.splice(
+                        """
+                        AOTInductorModelInputs inputs;
+                        convert_handles_to_inputs(input_handles, inputs);
+                        auto outputs = run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
+                            inputs, stream, proxy_executor);
+                        // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this
+                        // interface to perform well for a DSO using the minimal arrayref interface, all we need
+                        // to do is provide ThreadLocalCachedTensor for each one!
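+                        // convert_outputs_to_handles turns those ArrayRef outputs into
+                        // caller-visible AtenTensorHandles before this wrapper returns.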
+                        convert_outputs_to_handles(outputs, output_handles);
+                        }
+                        """
+                    )
+
+                    self.suffix.splice(
+                        """
+                        extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface(
+                            AOTInductorModelHandle model_handle,
+                            const AOTInductorModelInputs& inputs,
+                            AOTInductorModelOutputs& outputs) {
+                          auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
+                          CONVERT_EXCEPTION_TO_ERROR_CODE({
+                              outputs = model->run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
+                                  inputs,
+                                  (torch::aot_inductor::DeviceStreamType)nullptr,
+                                  nullptr);
+                          })
+                        }
+                        """
+                    )
+                else:
+                    self.prefix.splice(run_impl_proto)
+        else:
+            # cpp entry function for JIT with cpp wrapper
+            self.prefix.splice(
+                """
+                void inductor_entry_impl(
+                    AtenTensorHandle*
+                        input_handles, // array of input AtenTensorHandle; handles
+                                       // are stolen; the array itself is borrowed
+                    AtenTensorHandle*
+                        output_handles  // array for writing output AtenTensorHandle; handles
+                                        // will be stolen by the caller; the array itself is
+                                        // borrowed
+                ) {
+                """
+            )
+        with self.prefix.indent():
+            # assign inputs and outputs in both cases so the later codegen can be simplified
+            if not config.use_minimal_arrayref_interface:
+                if not V.graph.is_const_graph:
+                    if V.graph.aot_mode:
+                        num_args = len(V.graph.graph_inputs)
+                    else:
+                        # Weights are promoted in the JIT mode
+                        num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
+                        # release GIL to support multiple instances inference (in different threads of the same process)
+                        self.prefix.splice("py::gil_scoped_release release;")
+
+                    if config.abi_compatible:
+                        self.prefix.splice(
+                            f"""
+                            auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
+                            """
+                        )
+                    else:
+                        # This looks dumb, but can avoid creating two versions of code in the AOTInductor runtime.
+                        self.prefix.splice(
+                            f"""
+                            auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, {num_args});
+                            """
+                        )
+
+            if inputs_len != 0:
+                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
+                    if config.use_minimal_arrayref_interface:
+                        self.prefix.writeline(
+                            f"auto {input_key} = std::get<{idx}>(inputs);"
+                        )
+                        continue
+                    # unwrap input tensor back to scalar
+                    if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
+                        from ..graph import may_get_constant_buffer_dtype
+
+                        dtype = may_get_constant_buffer_dtype(
+                            V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
+                        )
+                        assert (
+                            dtype is not None
+                        ), "Failed to get the dtype of the sympy.Expr"
+                        cpp_dtype = DTYPE_TO_CPP[dtype]
+                        if config.abi_compatible:
+                            self.codegen_tensor_item(
+                                dtype, f"inputs[{idx}]", input_key, self.prefix
+                            )
+                        else:
+                            self.prefix.writeline(
+                                f"{cpp_dtype} {input_key} = inputs[{idx}].item<{cpp_dtype}>();"
+                            )
+                    else:
+                        self.prefix.writeline(
+                            f"auto {input_key} = std::move(inputs[{idx}]);"
+                        )
+
+            assert all(
+                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
+            ), "Expect all constants to be Tensor"
+            for idx, constants_key in enumerate(V.graph.constants.keys()):
+                if V.graph.aot_mode:
+                    # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
+                    # Don't call std::move here because it will cause constants_ to lose the ownership.
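+                    # constants_->at(idx) returns the handle still owned by constants_;
+                    # the plain assignments below copy it without transferring ownership.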
+                    if config.abi_compatible:
+                        self.prefix.writeline(
+                            f"""auto {constants_key} = constants_->at({idx});"""
+                        )
+                    else:
+                        self.prefix.writeline(
+                            f"auto {constants_key} = *tensor_handle_to_tensor_pointer("
+                            + f"""constants_->at({idx}));"""
+                        )
+                else:
+                    # Append constants as inputs to the graph
+                    constants_idx = inputs_len + idx
+                    if config.abi_compatible:
+                        self.prefix.writeline(
+                            f"auto {constants_key} = std::move(inputs[{constants_idx}]);"
+                        )
+                    else:
+                        self.prefix.writeline(
+                            f"auto {constants_key} = inputs[{constants_idx}];"
+                        )
+
+            self.codegen_inputs(self.prefix, V.graph.graph_inputs)
+
+            if V.graph.aot_mode:
+                if not V.graph.is_const_graph:
+                    if config.use_minimal_arrayref_interface:
+                        # TODO: input shape checking for regular tensor interface as well?
+                        self.codegen_input_numel_asserts()
+                    else:
+                        self.prefix.writeline("inputs.clear();")
+                self.prefix.writeline(
+                    "auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());"
+                )
+
+    def codegen_input_numel_asserts(self):
+        for name, buf in V.graph.graph_inputs.items():
+            if isinstance(buf, sympy.Expr):
+                continue
+
+            # comparing strides for 0 size tensor is tricky. Ignore them for now.
+            if sympy_product(buf.get_size()) == 0:
+                continue
+            numel = buf.get_numel()
+            self.prefix.writeline(f"assert_numel({name}, {numel});")
+
+    def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
+        if config.abi_compatible:
+            code.writeline(f"int32_t {name}_dtype;")
+            code.writeline(
+                "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype"
+                f"({name}, &{name}_dtype));"
+            )
+        else:
+            # Note that we don't have a corresponding class method from
+            # the WrapperCodeGen since this method is used for asserting AOTI
+            # cpp wrapper code.
+            code.writeline(f"auto {name}_dtype = {name}.dtype();")
+
+    def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
+        if config.abi_compatible:
+            code.writeline(f"int64_t* {name}_size;")
+            code.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));"
+            )
+        else:
+            super().codegen_input_size_var_decl(code, name)
+
+    def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
+        if config.abi_compatible:
+            code.writeline(f"int64_t* {name}_stride;")
+            code.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));"
+            )
+        else:
+            super().codegen_input_stride_var_decl(code, name)
+
+    def codegen_model_kernels(self):
+        self.prefix.writeline("namespace {")
+        self.prefix.writeline(
+            "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {"
+        )
+        self.prefix.writeline("  public:")
+        declare_kernel = set(self.src_to_kernel.values())
+        declare_kernel.update(
+            entry[0] for entry in self.user_defined_kernel_cache.values()
+        )
+        if V.graph.const_module:
+            declare_kernel.update(
+                V.graph.const_module.wrapper_code.src_to_kernel.values()
+            )
+        for kernel in sorted(declare_kernel):
+            self.prefix.writeline(
+                maybe_hipify_code_wrapper(f"    CUfunction {kernel}{{nullptr}};")
+            )
+        self.prefix.writeline("};")
+        self.prefix.writeline("}  // namespace")
+
+    def codegen_model_constructor(self):
+        """
+        // Generated code example
+        AOTInductorModel::AOTInductorModel()
+            : AOTInductorModelBase(4, 1) {
+        inputs_info_[0].name = "input0";
+        inputs_info_[0].dtype = "torch.float16";
+        ...
+        constants_info_[0].name = "L__self___weight";
+        constants_info_[0].dtype = at::kFloat;
+        constants_info_[0].offset = 0;
+        constants_info_[0].data_size = 8192;
+        constants_info_[0].shape = {64, 32};
+        constants_info_[0].stride = {32, 1};
+        ...
+        outputs_info_[0].name = "output0";
+        outputs_info_[0].dtype = "torch.float16";
+        }
+        """
+
+        num_inputs = len(V.graph.graph_inputs)
+        num_outputs = len(V.graph.graph_outputs)
+        num_constants = len(V.graph.constants)
+        self.prefix.splice(
+            f"""
+            AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
+                                               std::shared_ptr<std::vector<ConstantHandle>> constants_array,
+                                               const std::string& device_str,
+                                               std::optional<std::string> cubin_dir)
+                : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{
+            """
+        )
+
+        with self.prefix.indent():
+            for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
+                assert not isinstance(
+                    inp, sympy.Expr
+                ), f"input {name=} cannot be symbolic"
+                self.write_input_output_info("inputs_info_", idx, name)
+
+            all_cuda = all(
+                V.graph.get_original_value_of_constant(name).is_cuda
+                for name in V.graph.constants.keys()
+                if name not in V.graph.folded_constants
+            )
+            for idx, name in enumerate(V.graph.constants.keys()):
+                tensor = V.graph.get_original_value_of_constant(name)
+                assert isinstance(tensor, torch.Tensor)
+                self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""")
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].dtype = static_cast<int32_t>({self.codegen_dtype(tensor.dtype)});"
+                )
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
+                )
+
+                # If the constants to serialize contain cpu tensors, we always align data_size to 64.
+                # When loading the constants, the valid data depends on the size,
+                # not the data_size, so there won't be a correctness issue.
+                data_size = (
+                    torch.ops.mkldnn._nbytes(tensor)
+                    if tensor.is_mkldnn
+                    else tensor.untyped_storage().nbytes()
+                )
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};"
+                )
+
+                from_folded = "true" if name in V.graph.folded_constants else "false"
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].from_folded = {from_folded};"
+                )
+
+                size_str = ", ".join([str(s) for s in tensor.size()])
+                self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};")
+
+                stride_str = ", ".join([str(s) for s in tensor.stride()])
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].stride = {{{stride_str}}};"
+                )
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].layout = static_cast<int32_t>({self.codegen_layout(tensor.layout)});"
+                )
+
+                if tensor.is_mkldnn:
+                    opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
+                        tensor
+                    )
+                    assert (
+                        opaque_metadata_tensor.dim() == 1
+                    ), "Expect opaque_metadata_tensor to be 1-D"
+
+                    opaque_metadata_list = opaque_metadata_tensor.tolist()
+                    opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
+                    self.prefix.writeline(
+                        f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};"
+                    )
+                if name in V.graph.dynamo_flat_name_to_original_fqn:
+                    original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
+                        name, name
+                    )
+                elif name in V.graph.allocated_constant_name:
+                    original_fqn = V.graph.allocated_constant_name[name]
+                else:
+                    raise AssertionError("original_fqn must be set for constant")
+                self.prefix.writeline(
+                    f"""constants_info_[{idx}].original_fqn = "{original_fqn}";"""
+                )
+            self.prefix.writeline("update_constants_map(std::move(constants_map));")
+            self.prefix.writeline("update_constants_array(std::move(constants_array));")
+
+            def escape_string(x):
+                return (
+                    x.replace("\\", "\\\\")
+                    .replace('"', '\\"')
+                    .replace("\n", "\\n")
+                    .replace("\t", "\\t")
+                )
+
+            self.prefix.writeline(
"{escape_string(config.aot_inductor.serialized_in_spec)}";' + ) + self.prefix.writeline( + f'out_spec_ = "{escape_string(config.aot_inductor.serialized_out_spec)}";' + ) + + for idx, output in enumerate(V.graph.graph_outputs): + assert not isinstance( + output, sympy.Expr + ), f"output {name=} cannot be symbolic" + name = f"output{idx}" + self.write_input_output_info("outputs_info_", idx, name) + + self.prefix.writeline( + "this->kernels_ = std::make_unique();" + ) + + self.prefix.writeline("}") + + def codegen_const_run_driver(self): + """ + // Generated code example + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + std::unordered_map folded_constants_map; + std::vector output_handles; + // build up output_handles over here. + _const_run_impl(output_handles, stream, proxy_executor); + // build up folded_constants_map + return folded_constants_map; + } + """ + + self.prefix.splice( + """ + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + """ + ) + if not config.aot_inductor.use_runtime_constant_folding: + self.prefix.splice( + """ + if (!initialization) { + std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: " + << "aot_inductor.use_runtime_constant_folding=False\\n"; + } + return {}; + } + """ + ) + return + + with self.prefix.indent(): + # This is a mapping to the index of constant folding graph's output + const_index_mapping: List[Optional[Tuple[int, str]]] = [None] * len( + V.graph.const_output_index + ) + for idx, (name, _) in enumerate(V.graph.constants.items()): + if name in V.graph.const_output_index: + const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload] + assert ( + None not in const_index_mapping + ), "Not all constant gets mapped for constant folding graph." + + self.prefix.writeline( + f""" + std::unordered_map folded_constants_map; + folded_constants_map.reserve({len(const_index_mapping)}); + std::vector output_handles({len(const_index_mapping)}); + """ + ) + + self.prefix.splice( + """ + // The below assignment of output_handles to constants is not used directly. + // It's only used to memo the correspondence of handle and constants. 
+ """ + ) + + for output_idx, (const_idx, _) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f"output_handles[{output_idx}] = constants_->at({const_idx});" + ) + + self.prefix.writeline( + "_const_run_impl(output_handles, stream, proxy_executor);" + ) + + for output_idx, (_, const_name) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];' + ) + self.prefix.writeline("return folded_constants_map;") + + self.prefix.writeline("}") + + def generate(self, is_inference): + if V.graph.aot_mode and not V.graph.is_const_graph: + self.codegen_model_kernels() + self.codegen_model_constructor() + self.codegen_const_run_driver() + self.write_wrapper_decl() + return super().generate(is_inference) + + def finalize_prefix(self): + cached_dtypes_buffer = IndentedBuffer() + if config.abi_compatible: + for dtype in self.used_cached_dtypes: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});") + for device in self.used_cached_devices: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});") + for layout in self.used_cached_layouts: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});") + cached_dtypes_buffer.splice(self.prefix) + self.prefix = cached_dtypes_buffer + + def define_kernel( + self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False + ): + self.header.splice(f"\n{kernel}\n") + + def codegen_scalar_to_tensor(self, output: str): + name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}" + self.wrapper_call.writeline( + f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});" + ) + return name + + def codegen_tensor_item( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + assert ( + config.abi_compatible + ), "codegen_tensor_item is only used for the ABI-compatible mode" + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + + if dtype == torch.float16 or dtype == torch.bfloat16: + scalar_tmp = f"{scalar}_tmp" + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};") + + # need convert_arrayref_tensor_to_tensor for ArrayRefTensors + tensor = f"convert_arrayref_tensor_to_tensor({tensor})" + + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" + ) + writer.writeline(f"float {scalar} = float({scalar_tmp});") + else: + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + + # need convert_arrayref_tensor_to_tensor for ArrayRefTensors + tensor = f"convert_arrayref_tensor_to_tensor({tensor})" + + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) + + @cache_on_self + def get_output_refs(self): + return [ + f"torch::tensor({x.codegen_reference(self.wrapper_call)})" + if isinstance(x, ir.ShapeAsConstantBuffer) and not config.abi_compatible + else x.codegen_reference(self.wrapper_call) + for x in V.graph.graph_outputs + ] + + def generate_return(self, output_refs: List[str]): + cst_names = V.graph.constants.keys() + arr_iface = ( + not V.graph.is_const_graph and config.use_minimal_arrayref_interface + ) # For brevity. 
+
+        def use_thread_local_cached_output_tensor(idx, output):
+            cached_output_name = f"cached_output_{next(self.cached_output_id)}"
+            cache_type = "Array" if arr_iface else "Tensor"
+            self.wrapper_call.writeline(
+                f"thread_local ThreadLocalCachedOutput{cache_type}<std::decay_t<decltype({output})>> "
+                f"{cached_output_name}({output});"
+            )
+            if arr_iface:
+                self.wrapper_call.writeline(
+                    f"{cached_output_name}.copy_data_from({output});"
+                )
+                output_entry = f"std::get<{idx}>(output_arrayref_tensors)"
+                element_type = f"std::decay_t<decltype({output_entry}.data()[0])>"
+                self.wrapper_call.writeline(
+                    f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();"
+                )
+            else:
+                self.wrapper_call.writeline(
+                    f"{cached_output_name}.copy_data_from({output});"
+                )
+                self.wrapper_call.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));"
+                )
+                self.wrapper_call.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), "
+                    f"output_handles[{idx}]));"
+                )
+
+        if arr_iface:
+            self.wrapper_call.writeline(
+                "AOTInductorModelOutputs output_arrayref_tensors;"
+            )
+
+        output2idx: Dict[str, int] = {}
+        for idx, output in enumerate(output_refs):
+            if output == self.none_str:
+                continue
+
+            is_constant_buffer = output in cst_names
+            output_buffer = V.graph.graph_outputs[idx]
+            if isinstance(output_buffer, ir.BaseView):
+                output_storage = output_buffer.unwrap_view()
+                if isinstance(output_storage.data, ir.ConstantBuffer):
+                    is_constant_buffer = True
+
+            if config.abi_compatible:
+                if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
+                    # Need to wrap scalar into tensor as the main function returns a vector of tensors
+                    output_tensor = self.codegen_scalar_to_tensor(output)
+                    self.wrapper_call.writeline(
+                        f"output_handles[{idx}] = {output_tensor}.release();"
+                    )
+                    continue
+
+                output_is_tensor_handle_expr = (
+                    f"std::is_same_v<std::decay_t<decltype({output})>,"
+                    "RAIIAtenTensorHandle> || "
+                    f"std::is_same_v<std::decay_t<decltype({output})>,"
+                    "AtenTensorHandle> || "
+                    f"std::is_same_v<std::decay_t<decltype({output})>,"
+                    "ConstantHandle>"
+                )
+                self.wrapper_call.writeline(
+                    f"if constexpr ({output_is_tensor_handle_expr}) {{"
+                )
+                with self.wrapper_call.indent():
+                    if arr_iface:
+                        cached_output_name = (
+                            f"cached_output_{next(self.cached_output_id)}"
+                        )
+                        output_value_type = f"std::decay_t<decltype(std::get<{idx}>(output_arrayref_tensors).data()[0])>"
+                        self.wrapper_call.writeline(
+                            f"thread_local RAIIAtenTensorHandle {cached_output_name};"
+                        )
+                        if is_constant_buffer:
+                            # NOTE(return_constant): In some rare cases where we return
+                            # a constant, we have to return a copy of this constant,
+                            # because (1) constants are not owned by the Model instance
+                            # (2) constants remain the same across inference runs,
+                            # assuming they are not updated at runtime. Basically, we
+                            # cannot release or transfer the ownership of any original
+                            # constant to the user.
+                            self.wrapper_call.writeline(
+                                f"AtenTensorHandle {cached_output_name}_tmp;"
+                            )
+                            self.wrapper_call.writeline(
+                                f"aoti_torch_clone({output}, &{cached_output_name}_tmp);"
+                            )
+                            self.wrapper_call.writeline(
+                                f"{cached_output_name} = {cached_output_name}_tmp;"
+                            )
+                        else:
+                            self.wrapper_call.writeline(
+                                f"{cached_output_name} = {output}.release();"
+                            )
+                        self.wrapper_call.writeline(
+                            f"convert_handle_to_arrayref_tensor({cached_output_name}, "
+                            f"std::get<{idx}>(output_arrayref_tensors));"
+                        )
+                    else:
+                        if is_constant_buffer:
+                            # See NOTE(return_constant) above.
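+                            # aoti_torch_clone hands the caller an owning copy; the original
+                            # constant tensor stays alive inside constants_.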
+                            self.wrapper_call.writeline(
+                                f"aoti_torch_clone({output}, &output_handles[{idx}]);"
+                            )
+                        else:
+                            if output in output2idx:
+                                src_idx = output2idx[output]
+                                self.wrapper_call.writeline(
+                                    f"output_handles[{idx}] = output_handles[{src_idx}];"
+                                )
+                            else:
+                                self.wrapper_call.writeline(
+                                    f"output_handles[{idx}] = {output}.release();"
+                                )
+                self.wrapper_call.writeline("} else {")
+                with self.wrapper_call.indent():
+                    use_thread_local_cached_output_tensor(idx, output)
+                self.wrapper_call.writeline("}")
+
+            else:
+                assert (
+                    not arr_iface
+                ), "minimal ArrayRef interface is only supported in ABI-compatible mode"
+                if is_constant_buffer:
+                    # See NOTE(return_constant) above.
+                    output_expr = f"{output}.clone()"
+                else:
+                    output_expr = output
+                self.wrapper_call.writeline(
+                    f"output_handles[{idx}] = reinterpret_cast<AtenTensorHandle>("
+                    + f"new at::Tensor({output_expr}));"
+                )
+
+            if output not in output2idx:
+                output2idx[output] = idx
+        if arr_iface:
+            self.wrapper_call.writeline("return output_arrayref_tensors;")
+
+    def generate_before_suffix(self, result):
+        if not V.graph.is_const_graph:
+            if V.graph.aot_mode:
+                result.writeline("} // AOTInductorModel::run_impl")
+            else:
+                result.writeline("} // inductor_entry_impl")
+
+    def generate_end(self, result):
+        if V.graph.aot_mode:
+            if V.graph.is_const_graph:
+                result.writeline("} // AOTInductorModel::_const_run_impl")
+            else:
+                result.writeline("} // namespace aot_inductor")
+                result.writeline("} // namespace torch")
+            return
+
+        # cpp entry function for JIT with cpp wrapper
+        result.writeline("'''\n)")
+        result.splice(
+            f"""
+            inductor_entry = CppWrapperCodeCache.load_pybinding(
+                ["std::vector<AtenTensorHandle>"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)})
+            """
+        )
+
+        wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]"
+        if V.graph.constants:
+            # Append constants to the input args for cpp wrapper.
+            # Python wrapper directly gets the value inside the wrapper call
+            # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__).
+            # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly.
+            assert all(
+                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
+            ), "Expect all constants to be Tensor"
+            constants_str = f"[{', '.join(V.graph.constants.keys())}]"
+            wrapper_body += f"""
+                    constants_tensor = {constants_str}
+                    input_tensors.extend(constants_tensor)
+            """
+        # Convert vector of at::Tensor to vector of AtenTensorHandle.
+        # If we pass at::Tensor, the compilation will be too slow.
+        wrapper_body += """
+                    input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)
+        """
+        # Release the inputs for memory reuse.
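+        # Clearing args drops the caller's references to the input tensors, allowing
+        # their storage to be reused while the compiled graph runs.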
+ wrapper_body += """ + args.clear() + """ + + # unwrap output tensor back to python scalar + if all(x for x in self.output_is_tensor.values()): + # If no ShapeAsConstantBuffer in the output, directly return the output as tensors + outputs_str = "output_tensors" + else: + outputs = [ + f"output_tensors[{i}]" + if self.output_is_tensor[i] + else f"output_tensors[{i}].item()" + for i in range(len(V.graph.graph_outputs)) + ] + outputs_str = f"[{', '.join(outputs)}]" + wrapper_body += f""" + output_handles = f(input_handles) + output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles) + return {outputs_str} + """ + + # Wrap the func to support setting result._boxed_call = True + result.splice( + f""" + def _wrap_func(f): + def g(args): + {wrapper_body} + return g + + call = _wrap_func(inductor_entry) + """ + ) + + def get_c_shim_func_name(self, kernel): + if not config.abi_compatible or kernel.startswith("aoti_torch_"): + return kernel + + assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'" + kernel_tokens = kernel.split("::") + kernel_suffix = kernel_tokens[-1] + if kernel_suffix == "call": + kernel_suffix = kernel_tokens[-2] + + shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}" + return shim_fn + + def generate_c_shim_extern_kernel_call(self, kernel, args): + # In the abi_compatible mode, we call fallback aten ops through a C shim layer + # Setting self.allow_stack_allocation to False because the exchange between + # ArrayRefTensor and at::Tensor is still fragile. + self.allow_stack_allocation = False + + wrapped_args = [] + + args_to_print_or_save = None + debug_printer_manager = V.graph.wrapper_code.debug_printer + if ( + debug_printer_manager.debug_printer_level + != IntermediateValueDebuggingLevel.OFF + ): + args_to_print_or_save = [] + + for x in args: + pieces = x.split(", ") + for piece in pieces: + # We only really *need* convert_arrayref_tensor_to_tensor for + # ArrayRefTensors. The code flowing into here uses `0` for nullptr, + # which convert_arrayref_tensor_to_tensor would blindly coerce to int, + # so just avoid wrapping integers. + # Name matching is to find tensor is hacky, but fixing all the + # ArrayRefTensor issues is not a priority for now. 
+ if isinstance(piece, str) and piece.startswith( + ("buf", "arg", "wrap_with_raii_handle_if_needed") + ): + # TODO: The current way to find a 'tensor' type arg is hacky also as mentioned above + # Find a more reliable way to detect tensor kernel args for extern kernel calls + if ( + debug_printer_manager.debug_printer_level + != IntermediateValueDebuggingLevel.OFF + ): + if piece.startswith(("buf", "arg")): + args_to_print_or_save.append(piece) + piece = f"convert_arrayref_tensor_to_tensor({piece})" + wrapped_args.append(piece) + + debug_printer_manager.set_printer_args( + args_to_print_or_save, kernel, None, None + ) + with debug_printer_manager: + shim_fn = self.get_c_shim_func_name(kernel) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));" + ) + + def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args): + # registered output buffer name + name = extern_kernel.name + output_handle_name = f"{name}_handle" + self.writeline(f"AtenTensorHandle {output_handle_name};") + output_arg = f"&{output_handle_name}" + self.generate_c_shim_extern_kernel_call( + extern_kernel.get_kernel_name(), args + [output_arg] + ) + self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") + + def generate_extern_kernel_alloc(self, extern_kernel, args): + if config.abi_compatible: + self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) + else: + super().generate_extern_kernel_alloc(extern_kernel, args) + + def generate_c_shim_fallback_kernel(self, fallback_kernel, args): + output_args = [] + output_raii_handles = [] + output_name_base = fallback_kernel.get_name() + for idx, output in enumerate(fallback_kernel.outputs): + if isinstance(output, ir.MultiOutput): + # TODO: handle integer output (e.g., as in attention) + name = f"{output.get_name()}" + output_handle_name = f"{name}_handle" + if output.indices: + assert ( + output.indices[0][1] == idx + ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}" + self.writeline(f"AtenTensorHandle {output_handle_name};") + output_args.append(f"&{output_handle_name}") + output_raii_handles.append( + f"RAIIAtenTensorHandle {name}({output_handle_name});" + ) + elif isinstance(output, int): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"int64_t {output_name} = {output};") + output_args.append(f"&{output_name}") + elif isinstance(output, sympy.Symbol): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"auto {output_name} = {output};") + output_args.append(f"&{output_name}") + elif output is None: + output_args.append("nullptr") + else: + raise NotImplementedError(f"unsupported type of {output=}") + args = args + output_args + self.generate_c_shim_extern_kernel_call(fallback_kernel.cpp_kernel_name, args) + for raii_handle in output_raii_handles: + self.writeline(raii_handle) + + def generate_fallback_kernel(self, fallback_kernel, args): + if config.abi_compatible: + self.generate_c_shim_fallback_kernel(fallback_kernel, args) + else: + super().generate_fallback_kernel(fallback_kernel, args) + + def generate_extern_kernel_out( + self, kernel: str, out: str, out_view: Optional[str], args: List[str] + ): + if out_view: + out_name = f"{out}_as_strided" + self.writeline(f"auto {out_name} = {out_view};") + args.insert(0, out_name) + else: + args.insert(0, out) + + if config.abi_compatible: + self.generate_c_shim_extern_kernel_call(kernel, args) + else: + # TODO: add debug printing info for non-abi compatible mode extern kernel call + 
self.writeline(self.wrap_kernel_call(kernel, args))
+
+    def generate_scatter_fallback(
+        self,
+        output,
+        inputs,
+        cpp_kernel_name,
+        python_kernel_name,
+        src_is_tensor,
+        reduce,
+        kwargs,
+    ):
+        # No stack allocation when there is a fallback op
+        self.allow_stack_allocation = False
+
+        if config.abi_compatible:
+            # call the ABI shim function instead of the ATen one
+            cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name)
+            # TODO: consider removing "_out" and adding the missing inplace variants to fallback_ops.py
+            cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
+            inputs_wrapped = [
+                f"convert_arrayref_tensor_to_tensor({x})"
+                if isinstance(x, str)
+                else str(x)
+                for x in inputs
+            ]
+            line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
+        else:
+            line = f"{cpp_kernel_name}({','.join(map(str, inputs))}"
+
+        if python_kernel_name.startswith("aten.scatter_reduce"):
+            line += f", {','.join(kwargs)}"
+        else:
+            if src_is_tensor:
+                if reduce:
+                    line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
+            else:
+                assert (
+                    reduce is None
+                ), "Expect reduce to be None for aten.scatter_ with scalar src"
+        line += ");"
+        self.writeline(line)
+
+    def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
+        # No stack allocation when there is a fallback op
+        self.allow_stack_allocation = False
+
+        # TODO: update aoti_torch_index_put_out in ir.py to use the autogen out version
+        if config.abi_compatible:
+            # See the comment in codegen_reinterpret_view about why having something like
+            # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
+            # tensor to be prematurely deallocated, thus this std::vector().data() trick here.
+            indices_str = (
+                "std::vector<AtenTensorHandle>{"
+                + (
+                    ", ".join(
+                        [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices]
+                    )
+                )
+                + "}.data()"
+            )
+            args = [
+                f"convert_arrayref_tensor_to_tensor({x})",
+                indices_str,
+                str(len(indices)),
+                f"convert_arrayref_tensor_to_tensor({values})",
+                accumulate,
+            ]
+            args.insert(
+                0, f"convert_arrayref_tensor_to_tensor({x})"
+            )  # set x as the output tensor; this fallback mutates x.
+        else:
+            indices_str = (
+                f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
+            )
+            args = [x, indices_str, values, accumulate]
+            args.insert(0, x)  # set x as the output tensor; this fallback mutates x.
+
+        self.writeline(self.wrap_kernel_call(kernel, args))
+
+    def add_benchmark_harness(self, output):
+        if V.graph.aot_mode:
+            return
+        super().add_benchmark_harness(output)
+
+    def codegen_sizevar(self, x: Expr) -> str:
+        return self.expr_printer(V.graph.sizevars.simplify(x))
+
+    def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
+        if config.abi_compatible:
+            # in the abi_compatible mode, outputs are returned via arguments
+            return name
+        else:
+            return f"std::get<{index}>({basename})"
+
+    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
+        parts = list(map(self.codegen_sizevar, shape))
+        if len(parts) == 0:
+            return "{}"
+        if len(parts) == 1:
+            return f"{{{parts[0]}, }}"
+        return f"{{{', '.join(parts)}}}"
+
+    def codegen_dynamic_scalar(self, node):
+        (data,) = (t.codegen_reference() for t in node.inputs)
+        if config.abi_compatible:
+            self.codegen_tensor_item(
+                node.inputs[0].get_dtype(), data, f"{node.sym}_raw"
+            )
+        else:
+            convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace(
+                "at::k", "to"
+            )
+            self.writeline(f"auto {node.sym}_raw = {data}.item().{convert_type}();")
+
+        if len(node.keypath) == 0:
+            self.writeline(f"auto {node.sym} = {node.sym}_raw;")
+        elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey):
+            self.writeline(f"int64_t {node.sym} = {node.sym}_raw ? 1 : 0;")
+        elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey):
+            # TODO: assert divisibility here
+            self.writeline(
+                f"int64_t {node.sym} = {node.sym}_raw / {node.keypath[0].divisor};"
+            )
+        else:
+            raise AssertionError(f"unrecognized keypath {node.keypath}")
+
+        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
+        self.unbacked_symbol_decls.add(str(node.sym))
+
+    def can_stack_allocate_buffer(self, buffer):
+        return (
+            self.allow_stack_allocation
+            and buffer.get_device().type == "cpu"
+            and self.can_prove_buffer_has_static_shape(buffer)
+            and ir.is_contiguous_strides_for_shape(
+                buffer.get_stride(), buffer.get_size()
+            )
+        )
+
+    def make_buffer_free(self, buffer):
+        return (
+            ""
+            if isinstance(buffer.get_layout(), ir.MultiOutputLayout)
+            or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers)
+            or (
+                config.use_minimal_arrayref_interface
+                and V.graph.aot_mode
+                and buffer.get_name() in V.graph.graph_inputs
+            )
+            else f"{buffer.get_name()}.reset();"
+        )
+
+    def make_free_by_names(self, names_to_del: List[str]):
+        return " ".join(f"{name}.reset();" for name in names_to_del)
+
+    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
+        if config.abi_compatible:
+            return f"auto {new_name} = std::move({old_name});  // reuse"
+        else:
+            return super().codegen_exact_buffer_reuse(old_name, new_name, del_line)
+
+    def generate_profiler_mark_wrapper_call(self, stack):
+        self.wrapper_call.writeline(
+            'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>());'
+        )
+
+    def write_triton_header_once(self):
+        pass
+
+    def generate_start_graph(self):
+        pass
+
+    def generate_end_graph(self):
+        pass
+
+    def generate_inf_and_nan_checker(self, nodes):
+        for buf in nodes.get_names():
+            # TODO: Add buf name directly into check_inf_and_nan.
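+            # The shim call reports an error code (surfaced by AOTI_TORCH_ERROR_CODE_CHECK)
+            # when buf contains an inf or NaN value.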
+ self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));" + ) + + def codegen_device(self, device): + if config.abi_compatible: + self.used_cached_devices.add(device.type) + return f"cached_torch_device_type_{device.type}, {device.index if device.index else 0}" + else: + return ( + f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})" + if device.index is not None + else f"{DEVICE_TO_ATEN[device.type]}" + ) + + def codegen_dtype(self, dtype): + if config.abi_compatible: + dtype_str = str(dtype).split(".")[-1] + self.used_cached_dtypes.add(dtype_str) + return f"cached_torch_dtype_{dtype_str}" + else: + return DTYPE_TO_ATEN[dtype] + + def codegen_layout(self, layout): + if config.abi_compatible: + layout_str = str(layout).split(".")[-1] + self.used_cached_layouts.add(layout_str) + return f"cached_torch_layout_{layout_str}" + else: + return LAYOUT_TO_ATEN[layout] + + @functools.lru_cache(None) # noqa: B019 + def codegen_int_array_var( + self, + int_array: str, + writer=None, + known_statically=False, + graph=None, # for per-graph caching + ): + # This is used for size/stride declaration + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. + # As a result, the emitted int array declarations may appear in a later + # position of the generated code, so the second pass codegen should not + # reuse int array declarations generated in the first pass + if writer is None: + # The first pass codegen uses `self` as the writer + writer = self + + var = f"int_array_{next(self.int_array_id)}" + ctype = "int64_t" + if var not in self.declared_int_array_vars: + self.declared_int_array_vars.add(var) + if known_statically: + writer.writeline(f"static constexpr {ctype} {var}[] = {int_array};") + else: + writer.writeline(f"const {ctype} {var}[] = {int_array};") + return var + + def make_buffer_allocation(self, buffer): + return self.make_allocation( + buffer.get_name(), + buffer.get_device(), + buffer.get_dtype(), + buffer.get_size(), + buffer.get_stride(), + buffer if self.can_stack_allocate_buffer(buffer) else None, + ) + + def make_allocation( + self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None + ): + orig_stride = stride + device_str = self.codegen_device(device) + dtype_code = self.codegen_dtype(dtype) + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(orig_stride) + if config.abi_compatible: + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + if buffer_if_can_stack_allocate is not None: + self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate + cpp_type = DTYPE_TO_CPP[dtype] + numel = buffer_if_can_stack_allocate.get_numel() + # Note: we don't zero storage because empty_strided doesn't zero either. 
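+                # Stack-allocation path: emit a plain C array in the generated code and wrap
+                # it in an ArrayRefTensor instead of going through aoti_torch_empty_strided.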
+ self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];") + args = [ + f"{name}_storage", + size_array_var, + stride_array_var, + device_type, + device_idx, + ] + return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});" + + args = [ + str(len(shape)), + size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{name}_handle", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + ) + + return f"RAIIAtenTensorHandle {name}({name}_handle);" + + if V.graph.aot_mode and device_str.startswith("c10::Device("): + tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)" + else: + tensor_device = device_str + + if device.type == "cpu": + return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});" + if device.type == "cuda": + return ( + f"at::Tensor {name} = at::detail::empty_strided_cuda(" + f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);" + ) + return ( + f"{self.declare}{name} = {self.namespace}empty_strided(" + f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}" + ) + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + if config.abi_compatible: + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(stride) + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + name, + self.expr_printer(offset), # bytes not numel + self.codegen_dtype(dtype), + str(len(shape)), + self.codegen_int_array_var( + size, self.wrapper_call, graph=self.get_codegened_graph() + ), + self.codegen_int_array_var( + stride, self.wrapper_call, graph=self.get_codegened_graph() + ), + f"&{tmp_name}", + ] + self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" + ) + return f"RAIIAtenTensorHandle({tmp_name})" + + return "alloc_from_pool({})".format( + ", ".join( + [ + name, + self.expr_printer(offset), # bytes not numel + self.codegen_dtype(dtype), + self.codegen_shape_tuple(shape), + self.codegen_shape_tuple(stride), + ] + ) + ) + + def codegen_reinterpret_view( + self, data, size_list, stride_list, offset, writer, dtype=None + ) -> str: + dim = str(len(size_list)) + original_offset = offset + size = self.codegen_shape_tuple(size_list) + stride = self.codegen_shape_tuple(stride_list) + offset = self.codegen_sizevar(offset) + call_strs = [] + if config.abi_compatible: + final_tmp_name = None + final_tmp_name_is_RAIIAtenTensorHandle = False + + def create_reinterpret_call(): + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + size, + writer, + known_statically=self.is_statically_known_list_of_ints( + size_list + ), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + stride, + writer, + known_statically=self.is_statically_known_list_of_ints( + stride_list + ), + graph=self.get_codegened_graph(), + ), + offset, + ] + call_str = ( + f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});" + ) + return tmp_name, call_str + + def create_dtypeview_call(reinterpret_call): + tmp_AtenTensorHandle = ( + f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + ) + call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] + dtype_name = str(dtype).split(".")[-1] + device_name = "cuda" if 
data.layout.device.type == "cuda" else "cpu"
+                get_dtype_function = f"aoti_torch_dtype_{dtype_name}"
+                dtypeview_function = f"aoti_torch_{device_name}_view_dtype"
+                call_strs.append(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}"
+                    f"({reinterpret_call}, {get_dtype_function}(), &{tmp_AtenTensorHandle}));"
+                )
+                tmp_RAIIAtenTensorHandle = (
+                    f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}_handle"
+                )
+                call_strs.append(
+                    f"RAIIAtenTensorHandle {tmp_RAIIAtenTensorHandle}({tmp_AtenTensorHandle});"
+                )
+                return tmp_RAIIAtenTensorHandle, call_strs
+
+            if (
+                size_list == data.layout.size
+                and stride_list == data.layout.stride
+                and original_offset == data.layout.offset
+            ):
+                # pure dtypeview
+                if dtype is not None and dtype != data.dtype:
+                    tmp_output_name, tmp_call_strs = create_dtypeview_call(
+                        data.get_name()
+                    )
+                    call_strs.extend(tmp_call_strs)
+                    final_tmp_name = tmp_output_name
+                    final_tmp_name_is_RAIIAtenTensorHandle = True
+                else:
+                    return f"{data.get_name()}"
+            else:
+                # first create the reinterpret view
+                final_tmp_name, reinterpret_call = create_reinterpret_call()
+                call_strs.append(reinterpret_call)
+
+                if dtype is not None and dtype != data.dtype:
+                    # then wrap it with a dtype view
+                    final_tmp_name, tmp_call_strs = create_dtypeview_call(
+                        reinterpret_call
+                    )
+                    call_strs.extend(tmp_call_strs)
+            # Because the memory planning is done in two passes (see the implementation
+            # of self.generate), the writeline behavior is different in the two passes.
+            if writer is None:
+                writer = self
+            writer.writelines(call_strs)
+            if (
+                self.can_stack_allocate_buffer(data)
+                and self.is_statically_known_list_of_ints(size_list)
+                and self.is_statically_known_list_of_ints(stride_list)
+                and ir.is_contiguous_strides_for_shape(stride_list, size_list)
+            ):
+                return final_tmp_name
+
+            # NB: the return handle here represents a temporary tensor, which will be
+            # automatically released.
+            # Here's a sample usage in the cpp wrapper code:
+            # ```
+            # aoti_torch_addmm_out(
+            #     buf1,
+            #     arg1_1,
+            #     RAIIAtenTensorHandle(tmp_tensor_handle_0),
+            #     buf0,
+            #     1L,
+            #     1L));
+            # ```
+            # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out.
+            # This could be problematic when it's used in a different pattern, for example:
+            # ```
+            # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
+            # aoti_torch_proxy_executor_call_function(..., tensor_args);
+            # ```
+            # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter
+            # kernel call.
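+            # (in C++, a temporary like RAIIAtenTensorHandle(h) is destroyed at the end of
+            # the full expression that creates it, so the raw array above would be left
+            # holding a dangling handle by the time the executor call runs.)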
+            #
+            # This is solved by updating the proxy_executor invocation to
+            # ```
+            # aoti_torch_proxy_executor_call_function(...,
+            #     std::vector<AtenTensorHandle>{
+            #         RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6
+            #     }.data()
+            # );
+            # ```
+            if not final_tmp_name_is_RAIIAtenTensorHandle:
+                return f"wrap_with_raii_handle_if_needed({final_tmp_name})"
+            else:
+                return final_tmp_name
+        else:
+            args = [data.get_name(), size, stride, offset]
+            return f"reinterpret_tensor({', '.join(args)})"
+
+    def codegen_device_copy(self, src, dst):
+        if config.abi_compatible:
+            # aoti_torch_tensor_copy_ takes an AtenTensorHandle as input,
+            # while stack allocation results in an ArrayRefTensor,
+            # so disable stack allocation here
+            self.allow_stack_allocation = False
+            self.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
+            )
+        else:
+            self.writeline(f"{dst}.copy_({src});")
+
+    def codegen_multi_output(self, name, value):
+        # in the abi_compatible mode, outputs are retrieved by passing
+        # output pointers, so we skip its codegen here.
+        if not config.abi_compatible:
+            super().codegen_multi_output(name, value)
+
+    def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
+        for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
+            if config.abi_compatible:
+                # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional
+                # input (outer_input) into another at::Tensor to be used as a subgraph input
+                # (inner_input) in the nested scope. we can't std::move here, as the codegened
+                # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we
+                # can't necessarily std::move it back to the origin (x).
+                self.writeline(f"AtenTensorHandle {inner_input}_handle;")
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));"
+                )
+                self.writeline(
+                    f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);"
+                )
+            else:
+                self.writeline(
+                    f"{self.declare}{inner_input} = {outer_input}{self.ending}"
+                )
+
+    def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
+        for inner_output, outer_output in zip(
+            subgraph.graph.graph_outputs, outer_outputs
+        ):
+            src = inner_output.codegen_reference()
+            if config.abi_compatible:
+                # in ABI-compatible mode, we need to std::move the subgraph output (inner_output)
+                # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
+                # constructor is deleted.
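+                # (moving is safe here: inner_output is the subgraph-local handle and is not
+                # referenced again after this hand-off.)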
+ src = f"std::move({src})" + # in case the outer_output carried a value + # before (e.g., in the while_loop codegen) + self.writeline(f"{outer_output}.reset();") + self.writeline(f"{outer_output} = {src}{self.ending}") + + def codegen_conditional(self, conditional): + name = conditional.get_name() + outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands] + if config.abi_compatible: + outer_outputs = [] + for out in conditional.outputs: + # in ABI-compatible mode, ir.MultiOutput is not codegened, + # hence pre-declare output variables directly and separately + self.writeline(f"RAIIAtenTensorHandle {out.get_name()};") + outer_outputs.append(out.get_name()) + + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # in ABI-compatible mode, we need to use the ABI shim function + # to extract a C++ bool from the unrelying scalar bool Tensor + predicate = f"{conditional.predicate.get_name()}_scalar" + self.codegen_tensor_item( + torch.bool, + conditional.predicate.codegen_reference(), + predicate, + ) + else: + # the predicate is not a Tensor: SymBool or Python bool + predicate = conditional.predicate.codegen_reference() + else: + # in non-ABI-compatible mode, we can codegen the conditional outputs + # as array of at::Tensor instances, as the ir.MultiOutput is codegened + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];") + predicate = f"{conditional.predicate.codegen_reference()}" + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # move the Tensor predicate to host + predicate = f"{predicate}.item()" + + self.writeline(f"if ({predicate}) {{") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("} else {") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + + def codegen_while_loop(self, while_loop): + name = while_loop.get_name() + outer_carried_inputs = [ + buf.codegen_reference() for buf in while_loop.carried_inputs + ] + outer_additional_inputs = [ + buf.codegen_reference() for buf in while_loop.additional_inputs + ] + cond_result_name = f"{name}_cond_result" + + if config.abi_compatible: + self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") + + cond_outer_inputs = [] + for inp, out in zip(outer_carried_inputs, while_loop.outputs): + # in ABI-compatible mode, the carried inputs are codegened + # as buffers outside the while loop and set to the initial + # values. at the end of each while_loop iteration, they + # will be assined the carried values. 
+                out_name = out.get_name()
+                self.writeline(f"AtenTensorHandle {out_name}_handle;")
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));"
+                )
+                self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);")
+                cond_outer_inputs.append(out_name)
+
+            # additional inputs will be assigned within the while_loop
+            # iteration directly from the corresponding outer graph buffers
+            cond_outer_inputs.extend(outer_additional_inputs)
+        else:
+            self.writeline(f"at::Tensor {cond_result_name};")
+            self.writeline(f"at::Tensor {name}[{len(outer_carried_inputs)}];")
+            for i, inp in enumerate(outer_carried_inputs):
+                # set the initial state before the loop
+                self.writeline(f"{name}[{i}] = {inp};")
+
+            cond_outer_inputs = [
+                *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
+                *outer_additional_inputs,
+            ]
+
+        cond_outer_outputs = [cond_result_name]
+        body_outer_inputs = list(cond_outer_inputs)
+        body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]
+
+        self.writeline("while (1) {")
+        self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
+        self.codegen_subgraph(
+            while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
+        )
+
+        if config.abi_compatible:
+            cond_result = f"{cond_result_name}_scalar"
+            self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
+        else:
+            cond_result = f"{cond_result_name}.item<bool>()"
+        self.writeline(f"if (!{cond_result}) break;")
+
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
+        self.codegen_subgraph(
+            while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
+        )
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline("}")
+
+    def generate_extern_kernel_args_decl_if_needed(
+        self, op_overload, raw_args, output_args
+    ):
+        arg_types = [x.real_type for x in op_overload._schema.arguments]
+        return_types = [x.type for x in op_overload._schema.returns]
+
+        new_tensor_args = []
+        new_int_args = []
+
+        def fill_args(arg, arg_type):
+            static_arg_types = (
+                torch.FloatType,
+                torch.BoolType,
+                torch.StringType,
+                torch.Type,
+                torch.DeviceObjType,
+            )
+            inductor_tensor_buffers = (
+                ir.Buffer,
+                ir.ReinterpretView,
+            )
+
+            if isinstance(arg_type, torch.TensorType):
+                assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}"
+                new_tensor_args.append(f"{arg.codegen_reference()}")
+            elif isinstance(arg_type, torch.IntType):
+                # int
+                new_int_args.append(str(arg))
+            elif isinstance(arg_type, torch.SymIntType):
+                # SymInt
+                expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg
+                new_int_args.append(self.expr_printer(expr))
+            elif isinstance(arg_type, torch.NumberType):
+                # Scalar of type int
+                assert isinstance(arg, (int, float, bool))
+                # Only treat int Scalar as dynamic
+                if isinstance(arg, int):
+                    new_int_args.append(str(arg))
+            elif isinstance(arg_type, torch.ListType):
+                assert isinstance(arg, (list, tuple))
+
+                # List[Tensor]
+                if isinstance(arg_type.getElementType(), torch.TensorType):
+                    new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg])
+                # List[Optional[Tensor]]
+                elif isinstance(
+                    arg_type.getElementType(), torch.OptionalType
+                ) and isinstance(
+                    arg_type.getElementType().getElementType(), torch.TensorType
+                ):
+                    new_tensor_args.extend(
+                        [f"{a.codegen_reference()}" for a in arg if a is not None]
+                    )
+                # List[int]
+                elif isinstance(arg_type.getElementType(), torch.IntType):
+                    new_int_args.extend([str(a) for a in arg])
+                # List[SymInt]
+                elif
isinstance(arg_type.getElementType(), torch.SymIntType): + expressions = [ + a.node.expr if isinstance(a, torch.SymInt) else a for a in arg + ] + new_int_args.extend( + [self.expr_printer(expr) for expr in expressions] + ) + # List[Scalar] + elif isinstance(arg_type.getElementType(), torch.NumberType): + # Only treat int Scalar as dynamic + is_int_type = [isinstance(a, int) for a in arg] + if any(is_int_type): + assert all( + is_int_type + ), "AOTInductor only supports int scalars of the same type" + new_int_args.extend([str(a) for a in arg]) + else: + assert isinstance( + arg_type.getElementType(), static_arg_types # type: ignore[arg-type] + ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + else: + assert isinstance( + arg_type, static_arg_types # type: ignore[arg-type] + ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + + for arg, arg_type in zip(raw_args, arg_types): + if arg is not None: + if isinstance(arg_type, torch.OptionalType): + fill_args(arg, arg_type.getElementType()) + else: + fill_args(arg, arg_type) + + def fill_output_arg(arg, return_type): + if isinstance(return_type, torch.TensorType): + self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);") + new_tensor_args.append(f"{arg}") + elif isinstance(return_type, torch.SymIntType): + raise NotImplementedError("NYI support for return type: SymInt") + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.SymIntType + ): + raise NotImplementedError("NYI support for return type: List[SymInt]") + else: + raise AssertionError(f"Unsupported return type found: {return_type}") + + # TODO: Only support tensor(s) returns for now, SymInt is not implemented yet + for return_type in return_types: + if isinstance(return_type, (torch.TensorType)): + pass + elif isinstance(return_type, torch.OptionalType): + assert isinstance(return_type.getElementType(), torch.TensorType) + elif isinstance(return_type, torch.ListType): + assert isinstance(return_type.getElementType(), torch.TensorType) + else: + raise NotImplementedError( + f"return type {return_type} is not yet supported." 
+ ) + + for output_arg in output_args: + assert output_arg is not None, "Optional return types are not yet supported" + if isinstance(output_arg, (list, tuple)): + for out in output_arg: + fill_output_arg(out, torch.TensorType.get()) + else: + fill_output_arg(output_arg, torch.TensorType.get()) + + return new_tensor_args, new_int_args + + def generate_extern_kernel_alloc_and_find_schema_if_needed( + self, + buf_name: str, + python_kernel_name: str, + cpp_kernel_name: str, + codegen_args: List[str], + cpp_op_schema: str, + cpp_kernel_key: str, + cpp_kernel_overload_name: str = "", + op_overload: Optional[torch._ops.OpOverload] = None, + raw_args=None, + outputs=None, + ): + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + + def extract_output_name(out): + if out is None: + # Because out is not a MultiOutput, we assume the kernel returns a single output + return [buf_name] + elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): + return out.get_name() + elif isinstance(out, (list, tuple)): + return type(out)(extract_output_name(o) for o in out) + else: + raise AssertionError(f"Unexpected output: {type(out)}") + + # output_args has the same pytree structure as outputs + output_args = None + if config.abi_compatible: + output_args = extract_output_name(outputs) + if isinstance(output_args, str): + output_args = [output_args] + + if V.graph.aot_mode and config.abi_compatible: + assert op_overload is not None + assert raw_args is not None + assert outputs is not None + + return self.generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor( + cpp_kernel_key, + op_overload, + raw_args, + output_args, + ) + else: + return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit( + buf_name, + python_kernel_name, + cpp_kernel_name, + codegen_args, + cpp_op_schema, + cpp_kernel_key, + cpp_kernel_overload_name, + op_overload, + raw_args, + output_args, + ) + + def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope): + scoped_lines = IndentedBuffer() + for declaration in declarations_before_scope: + scoped_lines.writeline(declaration) + + scoped_lines.writeline("{") + with scoped_lines.indent(): + scoped_lines.writeline("py::gil_scoped_acquire acquire;") + scoped_lines.writelines(lines_in_scope.split("\n")) + scoped_lines.writelines("}") + return scoped_lines._lines + + def load_custom_op_wrapper(self): + # TODO: need to support control flow + if self.custom_op_wrapper_loaded: + return + + lines = """ +RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache")); +if (codecache_module.get() == NULL) { + throw std::runtime_error("Failed to load torch._inductor.codecache"); +} +custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper"); +if (custom_op_wrapper.get() == NULL) { + throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper"); +}""" + + declarations_before_scope = ["RAIIPyObject custom_op_wrapper;"] + scope_gil_acquire = self.generate_scoped_gil_acquire( + declarations_before_scope, lines + ) + self.writelines(scope_gil_acquire) + + self.custom_op_wrapper_loaded = True + + def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type): + def generate_py_arg_inner(lines, raw_arg, arg_type): + if raw_arg is None: + # Py_None is a singleton, so we have to explicitly incref it here + lines.append("Py_INCREF(Py_None);\n") + return "Py_None" + elif isinstance(arg_type, torch.TensorType): + # Store AtenTensorHandle as void* + 
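# (illustrative, hypothetical buffer name "buf0"): the emitted C++ wraps the + # tensor handle in a PyCapsule so it can cross into the Python-side wrapper, e.g. + # PyCapsule_New(reinterpret_cast<void*>(buf0.get()), NULL, NULL) +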
base_handle = raw_arg.codegen_reference() + ( + tmp_raii_handle_var, + tmp_raii_handle_var_decl, + ) = self.create_tmp_raii_handle_var(base_handle) + if tmp_raii_handle_var: + lines.append(tmp_raii_handle_var_decl) + base_handle = tmp_raii_handle_var + return f"PyCapsule_New(reinterpret_cast<void*>({base_handle}.get()), NULL, NULL)" + elif isinstance(arg_type, torch.OptionalType): + return generate_py_arg_inner(lines, raw_arg, arg_type.getElementType()) + elif isinstance(arg_type, torch.IntType): + # int + return f"PyLong_FromLongLong({raw_arg})" + elif isinstance(arg_type, torch.SymIntType): + # SymInt + expr = ( + raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg + ) + return f"PyLong_FromLongLong({self.expr_printer(expr)})" + elif isinstance(arg_type, torch.FloatType): + return f"PyFloat_FromDouble({raw_arg})" + elif isinstance(arg_type, torch.BoolType): + return f"PyBool_FromLong({1 if raw_arg else 0})" + elif isinstance(arg_type, torch.StringType): + return f'PyUnicode_FromString("{raw_arg}")' + elif isinstance(arg_type, torch.NumberType): + # Union[bool, int, float, complex] + # torch/_prims_common/__init__.py + if isinstance(raw_arg, int): + return f"PyLong_FromLongLong({raw_arg})" + elif isinstance(raw_arg, float): + return f"PyFloat_FromDouble({raw_arg})" + elif isinstance(raw_arg, bool): + return f"PyBool_FromLong({1 if raw_arg else 0})" + elif isinstance(raw_arg, complex): + return f"PyComplex_FromDoubles({raw_arg.real}, {raw_arg.imag})" + elif isinstance(raw_arg, torch.SymInt): + expr = raw_arg.node.expr + return f"PyLong_FromLongLong({self.expr_printer(expr)})" + else: + raise NotImplementedError( + f"arg type {arg_type} with raw_arg {raw_arg}, {type(raw_arg)} is not yet supported by custom_op_wrapper" + ) + elif isinstance(raw_arg, torch.dtype): + # dtype + self.include_extra_header("torch/csrc/DynamicTypes.h") + return f"Py_NewRef(torch::getTHPDtype(static_cast<c10::ScalarType>({self.codegen_dtype(raw_arg)})))" + else: + raise NotImplementedError( + f"arg type {arg_type} is not yet supported by custom_op_wrapper" + ) + + lines = [] + if isinstance(arg_type, torch.ListType): + assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list" + lines.append( + f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n" + ) + for i, elem in enumerate(raw_arg): + lines.append( + f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(lines, elem, arg_type.getElementType())});\n" + ) + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n" + ) + else: + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(lines, raw_arg, arg_type)});\n" + ) + return "".join(lines) + + def generate_extern_kernel_alloc_and_find_schema_if_needed_jit( + self, + buf_name: str, + python_kernel_name: str, + cpp_kernel_name: str, + codegen_args: List[str], + cpp_op_schema: str, + cpp_kernel_key: str, + cpp_kernel_overload_name: str = "", + op_overload: Optional[torch._ops.OpOverload] = None, + raw_args=None, + output_args: Optional[List[str]] = None, + ): + if not config.abi_compatible: + # Will update this to use an OSS version ProxyExecutor + if cpp_kernel_key not in self.extern_call_ops: + self.writeline( + f"static auto op_{cpp_kernel_key} = c10::Dispatcher::singleton()" + ) + self.writeline( + f'\t.findSchemaOrThrow("{cpp_kernel_name}", "{cpp_kernel_overload_name}")' + ) + self.writeline(f"\t.typed<{cpp_op_schema}>();") + self.extern_call_ops.add(cpp_kernel_key) + + self.writeline( + f"auto {buf_name} = op_{cpp_kernel_key}.call({', 
'.join(codegen_args)});" + ) + else: + # In the JIT mode, because of the ABI-compatible requirement, we can't directly call + # c10::Dispatcher to find the custom op and call it. Instead, we go back to Python + # to invoke this custom op. + self.load_custom_op_wrapper() + + assert output_args is not None, "output_args should not be None" + num_args = len(raw_args) + py_args_var = f"py_args_{next(self.arg_var_id)}" + # First arg is always the python op name + lines = f""" +RAIIPyObject {py_args_var}(PyTuple_New({num_args+1})); +if ({py_args_var}.get() == NULL) {{ + throw std::runtime_error("PyTuple_New {py_args_var} failed"); +}} +PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}")); +""" + + assert op_overload is not None, "op_overload should not be None" + + for idx, (raw_arg, schema_arg) in enumerate( + zip(raw_args, op_overload._schema.arguments) + ): + lines += self.generate_py_arg( + py_args_var, idx + 1, raw_arg, schema_arg.real_type + ) + + lines += f""" +// Call the custom op in Python +RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var})); +if (py_{buf_name}.get() == NULL) {{ + throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed"); +}}""" + + if len(output_args) == 1: + # result is a single tensor + lines += f""" +{output_args[0]} = reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(py_{buf_name}.get(), NULL));""" + else: + # result is a tuple of tensors + for idx, output_arg in enumerate(output_args): + lines += f""" +{output_arg} = + reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));""" + + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for idx, output_arg in enumerate(output_args) + ] + scope_gil_acquire = self.generate_scoped_gil_acquire( + declarations_before_scope, lines + ) + self.writelines(scope_gil_acquire) + + def generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor( + self, + cpp_kernel_key, + op_overload, + raw_args, # contains both args and flatten kwargs + output_args: Optional[List[str]] = None, + ): + ( + tensor_call_args, + int_call_args, + ) = self.generate_extern_kernel_args_decl_if_needed( + op_overload, raw_args, output_args + ) + + tensor_call_args_str = ", ".join(tensor_call_args) + int_call_args_str = ", ".join(int_call_args) + + extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1 + + self.writeline( + f"aoti_torch_proxy_executor_call_function(proxy_executor, " + f"{extern_kernel_node_index}, " + f"{len(int_call_args)}, " + f"std::vector<int64_t>{{{int_call_args_str}}}.data(), " + f"{len(tensor_call_args)}, " + f"std::vector<AtenTensorHandle>{{{tensor_call_args_str}}}.data());" + ) + + self.extern_call_ops.add(cpp_kernel_key) + + def generate_reset_kernel_saved_flags(self): + pass + + def generate_save_uncompiled_kernels(self): + pass + + def c_type_for_prim_type(self, val, type_) -> str: + assert ( + config.abi_compatible + ), "c_type_for_prim_type is only used in ABI compatible mode" + if isinstance(type_, torch.OptionalType): + return f"{self.c_type_for_prim_type(val, type_.getElementType())}*" + elif isinstance(type_, torch.TensorType): + return "AtenTensorHandle" + elif isinstance(type_, (torch.IntType, torch.SymIntType)): + return "int64_t" + elif isinstance( + type_, (torch.BoolType, torch.SymBoolType, torch.EnumType) + ) or repr(type_) in ("ScalarType", "Layout"): + return "int32_t" + elif isinstance(type_, torch.FloatType): + return "double" + elif isinstance(type_, torch.NumberType): + if isinstance(val, bool): + 
return "int32_t" + elif isinstance(val, int): + return "int64_t" + elif isinstance(val, float): + return "double" + elif val is None: + # This could happen when val is an optional value + return "double" + else: + raise AssertionError( + f"Unexpected type in c_type_for_prim_type: {type_=}" + ) + elif isinstance(type_, torch.StringType): + return "const char*" + else: + raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}") + + def val_to_arg_str_for_prim_type(self, val, type_) -> str: + # TODO: not using type_ as the first step of refactoring. Will update this later. + if isinstance(val, bool): + if config.abi_compatible: + return "1" if val else "0" + else: + return "true" if val else "false" + elif isinstance(val, int): + # uint64_t is long on Linux, but long long on MacOS and Windows + return f"{val}LL" if sys.platform in ["darwin", "win32"] else f"{val}L" + elif isinstance(val, str): + return f'"{val}"' + elif isinstance( + val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox) + ): + return val.codegen_reference() + elif isinstance(val, torch.device): + return self.codegen_device(val) + elif isinstance(val, torch.dtype): + return self.codegen_dtype(val) + elif isinstance(val, float) and val in [float("inf"), float("-inf")]: + if val == float("inf"): + return "std::numeric_limits::infinity()" + else: + return "-std::numeric_limits::infinity()" + elif isinstance(val, (list, tuple)): + # FIXME: This happens because type_ is not always properly set to torch.ListType + return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}" + elif isinstance(val, SymTypes): + return self.expr_printer(val.node.expr) + elif isinstance(val, sympy.Expr): + return self.expr_printer(val) + else: + return repr(val) + + def val_to_arg_str(self, val, type_=None) -> str: + if val is None: + # None needs special care. It either represent nullopt or an empty tensor + if config.abi_compatible: + if type_ is None or isinstance(type_, torch.OptionalType): + if type_ is not None and isinstance( + type_.getElementType(), + ( + torch.ListType, + torch.TupleType, + torch.DeviceObjType, + ), + ): + return "0, 0" + else: + return "0" # nullptr is not available in C + elif isinstance(type_, torch.TensorType): + # create an empty tensor, the equivalent of at::Tensor() + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" + ) + self.writeline( + f"RAIIAtenTensorHandle {var_name}({var_name}_handle);" + ) + return var_name + else: + raise AssertionError("Can not map None to a known data type") + else: + return "std::nullopt" + + if isinstance(type_, torch.OptionalType): + element_type = type_.getElementType() + if config.abi_compatible: + if not isinstance(element_type, torch.TensorType): + var_name = f"var_{next(self.arg_var_id)}" + if isinstance( + element_type, + (torch.ListType, torch.TupleType, torch.DeviceObjType), + ): + # type_ is something like Optional[List] or Optional[Device] + arg_str = self.val_to_arg_str(val, element_type) + # For datatypes with auxiliary info, we need to hoist out the extra arguments. + # NOTE: This only works if there is one additional argument, though it can easily be generalized. 
+ main_value, aux = arg_str.rsplit(", ") + self.writeline(f"auto {var_name} = {main_value};") + return f"&{var_name}, {aux}" + else: + self.writeline( + f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" + ) + return f"&{var_name}" + else: + # type_ is Optional[Tensor] + # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim + base_handle = self.val_to_arg_str(val, element_type) + if config.use_minimal_arrayref_interface: + base_handle = ( + f"convert_arrayref_tensor_to_tensor({base_handle})" + ) + ( + tmp_raii_handle_var, + tmp_raii_handle_var_decl, + ) = self.create_tmp_raii_handle_var(base_handle) + if tmp_raii_handle_var: + self.writeline(tmp_raii_handle_var_decl) + base_handle = tmp_raii_handle_var + var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"AtenTensorHandle {var_name} = {base_handle}.get();" + ) + return f"&{var_name}" + else: + return self.val_to_arg_str(val, element_type) + + elif isinstance(type_, torch.ListType): + assert isinstance( + val, (list, tuple) + ), f"{val} does not match with arg type {type_}" + element_type = type_.getElementType() + if config.abi_compatible: + var_name = f"var_array_{next(self.var_array_id)}" + if len(val) == 0: + # Zero-size array is not supported in the C or C++ standard, so + # we declare a null pointer for it. + self.writeline( + f"const {self.c_type_for_prim_type(None, element_type)}* {var_name} = nullptr;" + ) + else: + result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" + self.writeline( + f"const {self.c_type_for_prim_type(val[0], element_type)} {var_name}[] = {result};" + ) + # Need to pass the array length because we can't use std::vector + return f"{var_name}, {len(val)}" + else: + return f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" + + return self.val_to_arg_str_for_prim_type(val, type_) + + def create_tmp_raii_handle_var(self, base_handle): + if base_handle.startswith( + ( + "convert_arrayref_tensor_to_tensor", + "wrap_with_raii_handle_if_needed", + ) + ): + # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to + # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. + tmp_var_name = f"var_{next(self.arg_var_id)}" + return ( + tmp_var_name, + f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};\n", + ) + else: + return "", "" diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..b82a707961df9e940c5849734db64d2dc6784ad1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -0,0 +1,432 @@ +# mypy: allow-untyped-defs +import functools +import os +from itertools import chain, count +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union + +import sympy + +from torch import dtype as torch_dtype +from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name +from torch._inductor.runtime.triton_heuristics import grid as default_grid + +from .. 
import config +from ..codecache import CudaKernelParamCache +from ..utils import DeferredLineBase +from ..virtualized import V +from .aoti_hipify_utils import maybe_hipify_code_wrapper +from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header +from .cpp_utils import cexpr, DTYPE_TO_CPP +from .cpp_wrapper_cpu import CppWrapperCpu +from .wrapper import SymbolicCallArg + + +if TYPE_CHECKING: + from ..graph import GraphLowering + + +class DeferredCudaKernelLine(DeferredLineBase): + """ + When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred line to backfill that information + """ + + def __init__( + self, + kernel_name: str, + line_template: str, + keys: Tuple[str, ...], + ): + super().__init__(line_template) + assert not isinstance(line_template, DeferredLineBase) + self.kernel_name = kernel_name + self.line_template = line_template + self.keys = keys + + def __call__(self): + params = CudaKernelParamCache.get(self.kernel_name) + assert ( + params is not None + ), f"{self.kernel_name} not found in CudaKernelParamCache" + for key in self.keys: + assert ( + key in params + ), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]" + if key == get_cpp_wrapper_cubin_path_name(): + assert os.path.exists(params[key]), f"{params[key]} does not exist" + + return self.line_template % tuple(params[key] for key in self.keys) + + def _new_line(self, line): + return DeferredCudaKernelLine(self.kernel_name, line, self.keys) + + +class DeferredCudaDefaultGrid: + """ + A container for the default grid, which may be used by DeferredCudaGridLine + """ + + def __init__( + self, + kernel_name: str, + grid, + grid_callable: Optional[Callable[..., Any]] = None, + **grid_extra_kwargs, + ): + self.kernel_name = kernel_name + self.grid = grid + self.grid_callable = grid_callable + self.grid_extra_kwargs = grid_extra_kwargs + + def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]): + if isinstance(grid, (list, tuple)): + return [self._process_grid(e) for e in grid] + else: + return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid + + def __call__(self): + grid = self.grid + assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list" + grid = self._process_grid(grid) + grid_callable = self.grid_callable or default_grid + if not self.grid_extra_kwargs: + grid_fn = grid_callable(*grid) + else: + grid_fn = grid_callable(*grid, **self.grid_extra_kwargs) + + params = CudaKernelParamCache.get(self.kernel_name) + assert ( + params is not None + ), f"{self.kernel_name} not found in CudaKernelParamCache" + block_cfg = { + "XBLOCK": params["x_block"], + "YBLOCK": params["y_block"], + "ZBLOCK": params["z_block"], + } + return grid_fn(block_cfg) + + +class DeferredCudaGridLine(DeferredLineBase): + """ + When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred line to backfill that information + """ + + def __init__( + self, + kernel_name: str, + grid_var: str, + grid, + autotune_configs, + ): + super().__init__("") + self.kernel_name = kernel_name + self.grid_var = grid_var + self.grid = grid + self.autotune_configs = autotune_configs + + def __call__(self): + params = CudaKernelParamCache.get(self.kernel_name) + assert ( + params is not None + ), f"{self.kernel_name} not found in CudaKernelParamCache" + + if self.autotune_configs is not None: + # This indicates the Triton kernel is a
user-defined one. + grid = None + if len(self.grid) == 1: + grid = self.grid[0] + else: + for i, c in enumerate(self.autotune_configs): + if all(arg == params["meta"][key] for key, arg in c.kwargs.items()): + grid = self.grid[i] + break + assert grid is not None + elif isinstance(self.grid, DeferredCudaDefaultGrid): + grid = self.grid() + else: + grid = self.grid + + assert len(grid) != 0, "Grid can't be empty" + grid_args_str = ", ".join( + [cexpr(V.graph.sizevars.simplify(item)) for item in grid] + ) + return f" Grid {self.grid_var} = Grid({grid_args_str});" + + def _new_line(self, line): + return DeferredCudaGridLine( + self.kernel_name, self.grid_var, self.grid, self.autotune_configs + ) + + +class CppWrapperCuda(CppWrapperCpu): + """ + Generates cpp wrapper for running on GPU and calls CUDA kernels + """ + + def __init__(self) -> None: + self.device = "cuda" + super().__init__() + self.grid_id = count() + self.cuda = True + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. + return + + super().write_header() + + self.header.splice("#include <filesystem>") + if config.abi_compatible: + self.header.splice( + "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>" + ) + else: + self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_header())) + self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_driver())) + + def write_get_raw_stream(self, index, graph=None): + name = f"stream{index}" + self.writeline(maybe_hipify_code_wrapper(f"cudaStream_t {name};")) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));" + ) + return name + + def define_kernel( + self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True + ): + if not cuda: + return super().define_kernel(name, kernel, metadata, cuda) + + def generate(self, is_inference): + self.prefix.writeline("\n") + if not V.graph.aot_mode: + for kernel in chain( + sorted(self.src_to_kernel.values()), + sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]), + ): + self.prefix.writeline( + maybe_hipify_code_wrapper(f"static CUfunction {kernel} = nullptr;") + ) + self.prefix.writeline("\n") + return super().generate(is_inference) + + def generate_user_defined_triton_kernel( + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, + ): + # in C++ wrapper, we don't pass constexpr args, as they don't + # get added as parameters to the PTX code compiled from the + # user-defined Triton kernel (only non-constexpr args do) + raw_args = [ + raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs + ] + args = [self.val_to_arg_str(v) for v in raw_args] + arg_types = [ + arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg) + for arg in raw_args + ] + self.generate_kernel_call( + kernel_name, + args, + arg_types=arg_types, + raw_args=raw_args, + grid=grid, + cuda=True, + triton=True, + triton_meta=triton_meta, + autotune_configs=configs, + ) + + @functools.lru_cache(None) # noqa: B019 + def generate_load_kernel_once( + self, + kernel_name: str, + graph: "GraphLowering", # for per-graph caching + ): + keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem") + kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + self.writeline(f"if ({kernel_var_name} == nullptr) {{") + self.writeline( + DeferredCudaKernelLine( + kernel_name, + """ """ + + kernel_var_name + + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);""" + if
V.graph.aot_mode + else """ """ + + kernel_var_name + + """ = loadKernel("%s", "%s", %s);""", + keys, + ) + ) + self.writeline("}") + return kernel_var_name + + def generate_args_decl(self, call_args, arg_types): + new_args = [] + for arg, arg_type in zip(call_args, arg_types): + var_name = f"var_{next(self.arg_var_id)}" + if isinstance(arg_type, torch_dtype): + if arg.endswith(".item()"): + # Need to declare a scalar in this case + ctype = DTYPE_TO_CPP[arg_type] + arg = arg[:-7] + if config.abi_compatible: + self.codegen_tensor_item( + arg_type, + arg, + var_name, + ) + else: + from torch import bfloat16, float16 + + if arg_type in (float16, bfloat16): + var_name_tmp = f"{var_name}_tmp" + self.writeline( + f"{ctype} {var_name_tmp} = {arg}.item<{ctype}>();" + ) + self.writeline(f"float {var_name} = float({var_name_tmp});") + else: + self.writeline( + f"{ctype} {var_name} = {arg}.item<{ctype}>();" + ) + else: + if config.abi_compatible: + self.writeline( + maybe_hipify_code_wrapper(f"CUdeviceptr {var_name};") + ) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));" + ) + else: + self.writeline( + maybe_hipify_code_wrapper( + f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());" + ) + ) + elif arg_type in (sympy.Integer, int): + self.writeline(f"int {var_name} = {self.expr_printer(arg)};") + elif arg_type in (sympy.Float, float): + self.writeline(f"float {var_name} = {self.expr_printer(arg)};") + else: + self.writeline(f"auto {var_name} = {self.expr_printer(arg)};") + new_args.append(f"&{var_name}") + + return ", ".join(new_args) + + def generate_default_grid( + self, + kernel_name: str, + grid: List[Any], + cuda: bool = True, + grid_callable: Optional[Callable[..., Any]] = None, + **grid_extra_kwargs, + ): + """ + Generate grid configs for launching a CUDA kernel using the grid + function from triton_heuristics. Because its computation needs + to read kernel config after autotune, it is done in a deferred way + using DeferredCudaDefaultGrid.
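+ + For example (hypothetical values): a grid of [1024] combined with an + autotuned block config {"XBLOCK": 128, "YBLOCK": 1, "ZBLOCK": 1} can only + be resolved to the final launch grid once CudaKernelParamCache has been + populated, hence a DeferredCudaDefaultGrid is returned instead of a + concrete tuple.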
+ """ + if not cuda: + return grid + return DeferredCudaDefaultGrid( + kernel_name, grid, grid_callable, **grid_extra_kwargs + ) + + def generate_kernel_call( + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + cuda=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", + ): + assert arg_types is not None and len(call_args) == len( + arg_types + ), "call_args and arg_types do not match" + + if not cuda: + # Even in CppWrapperCuda, we may see cpp kernels + return super().generate_kernel_call( + kernel_name, + call_args, + grid, + device_index, + cuda, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, + ) + + device_index, call_args = self.prepare_triton_kernel_call( + device_index, call_args + ) + kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph) + + # args with value 1 are added into equal_to_1 and constants + # in triton_meta (in the Python codegen) which makes them + # inlined in the PTX and compiled CUBIN + if ( + triton_meta is not None + and "configs" in triton_meta + and triton_meta["configs"] + ): + equal_to_1 = triton_meta["configs"][0].equal_to_1 + call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1] + arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1] + + call_args_str = self.generate_args_decl(call_args, arg_types) + kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}" + self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};") + stream = ( + "stream" + if V.graph.aot_mode + else self.write_get_raw_stream(device_index, V.graph) + ) + + grid_var = f"{kernel_name}_grid_{next(self.grid_id)}" + self.writeline( + DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs) + ) + + kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + # add debug printer code for all triton kernel related calls + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None) + with debug_printer_manager: + self.writeline(f"if ({grid_var}.is_non_zero()) {{") + self.writeline( + DeferredCudaKernelLine( + kernel_name, + r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( + kernel_var_name, + f"{grid_var}.grid_x", + f"{grid_var}.grid_y", + f"{grid_var}.grid_z", + kernel_args_var, + stream, + ), + ("num_warps", "shared_mem"), + ), + ) + self.writeline("}") diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a53f24b1759967a4e48eec424e15a7e0c86952d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..62f97185b48d4288694dbbf7fb23929be85170f1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0357e793efb0ea73a6c844a17e8e313bbd584ef4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..905ab4ffad8ec5d5e0648b805b4861312909f7e6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94a81dc567dd5cd0eb7098ae022b2ad05c5aee28 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ec77ae167b0c2dfaaafa1c6547046c93fdb3599 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..884e49860ecf5be1c79eb7fe2161f29b84f264f2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd8140b33da534e2070aea04587ca8a7eea5f078 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66ed46da78ef7e72abb8842060021eea15a09992 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..871588eefca9a6725c2e18ad2a4783404509b767 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs +import logging +from typing import cast, Sequence + +from ...._dynamo.utils import counters +from ... import config +from ...codecache import code_hash, get_path +from ...ir import CUDATemplateBuffer +from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import IndentedBuffer + + +log = logging.getLogger(__name__) + + +class CUDACPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for CUDA C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and CUDA C++ specific template code generation. + """ + + def __init__(self, scheduler: Scheduler) -> None: + super().__init__() + self.scheduler = scheduler + + @classmethod + def get_backend_features(cls, device): + return {} + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, CUDATemplateBuffer + ) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.cuda(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''', 'so')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CUDA template, possibly with fused epilogues + """ + counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cuda_cpp_template( + template_node + ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (numel, rnumel) = 
template_node.group + assert rnumel == 1 + ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node) + kernel, render = ctb.make_kernel_render(ctb) + with kernel: + template_node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code, node_schedule) + + # debug printing values of intermediate tensors + _, call_args, arg_signatures, _ = kernel.args.python_argdefs() + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_signatures, kernel + ) + with debug_printer_manager: + kernel.call_kernel(kernel_name, ctb) + + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py new file mode 100644 index 0000000000000000000000000000000000000000..fa27231426002c8bc4dd81e29d180d25779d147a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py @@ -0,0 +1,46 @@ +import functools +import logging +from typing import Optional + +import torch + +from ... import config + + +log = logging.getLogger(__name__) + + +def get_cuda_arch() -> Optional[str]: + try: + cuda_arch = config.cuda.arch + if cuda_arch is None: + # Get Compute Capability of the first Visible device + major, minor = torch.cuda.get_device_capability(0) + return str(major * 10 + minor) + return str(cuda_arch) + except Exception as e: + log.error("Error getting cuda arch: %s", e) + return None + + +def get_cuda_version() -> Optional[str]: + try: + cuda_version = config.cuda.version + if cuda_version is None: + cuda_version = torch.version.cuda + return cuda_version + except Exception as e: + log.error("Error getting cuda version: %s", e) + return None + + +@functools.lru_cache(None) +def nvcc_exist(nvcc_path: str = "nvcc") -> bool: + if nvcc_path is None: + return False + import subprocess + + res = subprocess.call( + ["which", nvcc_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + return res == 0 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..d6472a48f1e08c572956ae7d6de0a6a52f1cd747 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -0,0 +1,397 @@ +# mypy: allow-untyped-defs +import logging +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union + +from ...autotune_process import CUDABenchmarkRequest +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + IRNode, + Layout, + PrimitiveInfoType, + TensorBox, +) +from ...utils import sympy_product +from ...virtualized import V +from ..common import IndentedBuffer, Kernel, OpOverrides +from ..cpp_utils import CppPrinter, DTYPE_TO_CPP + + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda.cuda_template import CUDATemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +class CUDAKernel(Kernel): + """ + Baseclass for CUDA / Cutlass based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + +class CUDATemplateKernel(CUDAKernel): + """ 
+ Template kernels defined by CUDA / Cutlass in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" + + def __init__(self, kernel_name) -> None: + """ + Initializes a new instance of the CUDATemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + # Mapping from arg name to IRNode. + self.named_nodes: Dict[str, IRNode] = {} + + def arg_name(self, node: IRNode) -> Optional[str]: + """ + Returns arg name of a given input or output node. + """ + if node is None: + return None + return {**self.args.input_buffers, **self.args.output_buffers}.get( + node.get_name(), None + ) + + def check_not_null(self, node: IRNode) -> str: + """ + Generates code to check that a node is not null. + """ + + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + f""" + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """ + ) + return res.getvalue() + + def def_kernel( + self, + inputs: List[IRNode], + outputs: List[IRNode], + names_str: str = "", + input_reorder: Optional[List[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + """ + + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs() + return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})" + + def call_kernel( + self, + name: str, + node: "CUDATemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.WrapperCodeGen + + name: Name of kernel function. + node: The CUDATemplateBuffer node which contains information about the kernel, its fused epilogue nodes + as well as all required inputs and outputs.
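+ + As a rough sketch (hypothetical buffer names), a call with one input and + one output is emitted along the lines of + name(c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr()), None, + c_void_p(workspace.data_ptr())) + where the None marks that no workspace-size query is intended.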
+ """ + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + else: + call_args[i] = f"c_void_p({call_args[i]}.data_ptr())" + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + call_args.append("None") + + if node.get_workspace_size() > 0: + wrapper.generate_workspace_allocation( + node.get_workspace_size(), V.graph.scheduler.current_device, False + ) + call_args.append("c_void_p(workspace.data_ptr())") + else: + call_args.append("None") + + wrapper.generate_kernel_call( + name, + call_args, + cuda=True, + triton=False, + arg_types=arg_types, + ) + if node.get_workspace_size() > 0: + wrapper.writeline(wrapper.make_free_by_names(["workspace"])) + + def dtype(self, node: IRNode) -> Optional[str]: + """ + Generates code which represents dtype of a given node. + """ + + if node is None: + return "void" + return DTYPE_TO_CPP.get(node.get_layout().dtype) + + def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default_dtype + from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate + + return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype] + + def max_valid_index(self, node: IRNode, default=-1): + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default + max_valid_offset = 0 + for i in range(len(node.get_size())): + max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] + return max_valid_offset + + def offset(self, node: IRNode) -> str: + """ + Generates code which represents offset of a given node. + """ + + if node is None: + return "0" + return str(node.get_layout().offset) + + def ptr(self, node: IRNode) -> str: + """ + Generates code which represents pointer of a given node. + """ + + if node is None: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + offset = self.offset(node) + return arg_name if offset == "0" else f"{arg_name} + {offset}" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. + Generates code which represents size of a given node in [start_index, end_index). + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + + sizes = node.get_size()[start_index : end_index + 1] + if len(sizes) == 0: + return str(default_value) + + val = sympy_product(sizes) + return cexpr(self.rename_indexing(val)) + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. + Generates code which represents stride of a given node at index. + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. 
+ """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + return cexpr(self.rename_indexing(stride)) + + def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the row or column stride of an arg. + This is required by some CUTLASS 2.X APIs. + If the node is in row_major, it returns stride[-2]. + If the node is in column_major, it returns stride[-1]. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None or len(node.get_stride()) < 2: + return str(default_value) + + stride0 = node.get_stride()[-1] + stride1 = node.get_stride()[-2] + if stride0 == 1: + return cexpr(self.rename_indexing(stride1)) + elif stride1 == 1: + return cexpr(self.rename_indexing(stride0)) + else: + raise RuntimeError( + f"At least 1 stride should be 1. Strides: {node.get_stride()=}" + ) + + +class CUDATemplateCaller(ChoiceCaller): + """ + CUDATemplateCaller + + This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CUDABenchmarkRequest): The benchmark request for the caller. + template_buffer (CUDATemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: List[Buffer], + layout: Layout, + make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str], + bmreq: CUDABenchmarkRequest, + template: "CUDATemplate", # type: ignore[name-defined] + info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg] + ) -> None: + super().__init__(name, input_nodes, layout) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + return self.bmreq.benchmark( + *args, output_tensor=out + ) # @TODO: Hack for ensuring that Cutlass Kernel is preferred + + def __str__(self) -> str: + return f"CUDATemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"cuda_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + if self.info_kwargs is not None and "op" in self.info_kwargs: + op: Any = self.info_kwargs["op"] + return { + "backend": "CUDA", + "op_type": type(op).__name__, + "op_conf_name": str(op.configuration_name()), + "op_arch": str(op.arch), + "tile_shape": str(op.tile_description.tile_shape), + "epilogue_schedule": str(op.epilogue_schedule), + "kernel_schedule": str(op.kernel_schedule), + "element_accumulator": str(op.accumulator_type()), + "op_name": str(op.procedural_name()), + "instruction_shape": str( + op.tile_description.math_instruction.instruction_shape + ), + } + else: + return {"backend": "CUDA", "op_type": "unknown"} + + def output_node(self) -> TensorBox: + self.bmreq.update_workspace_size() + return TensorBox.create( + CUDATemplateBuffer( + 
layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5e59b3b8cc21dc8b9ff63887cd96834e1f28bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py @@ -0,0 +1,258 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from typing import List, Optional +from unittest.mock import patch + +import sympy + +import torch + +from ...autotune_process import CUDABenchmarkRequest, TensorMeta +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel + + +log = logging.getLogger(__name__) + + +class CUDATemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes: List[Buffer], + layout: Layout, + input_reorder: Optional[List[int]] = None, + ) -> None: + """ + + Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the CUDATemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer("buf_out", layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate( # type: ignore[override] + self, + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller + may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. 
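+ + Sketch of the intended flow (hypothetical call): caller = + template.generate(op=some_op); the autotuner can then invoke + caller.benchmark(*args, out=out) and, if this candidate wins, + caller.output_node() to materialize the CUDATemplateBuffer.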
+ """ + kernel_name = f"cuda_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), CUDATemplateKernel( + kernel_name=kernel_name, + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + + kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}" + + # create the BenchmarkRequest + bmreq = CUDABenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: CUDATemplateBuffer, + epilogue_nodes: Optional[List[IRNode]] = None, + ): + kernel = CUDATemplateKernel( + kernel_name="KERNEL_NAME", + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return CUDATemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + kwargs, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + using bfloat16 = nv_bfloat16; + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + +class CUTLASSTemplate(CUDATemplate): + """ + CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the + CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels. 
+ """ + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cute/tensor.hpp" + #include "cutlass/cutlass.h" + #include "cutlass/numeric_types.h" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/host_tensor.h" + #include "cutlass/util/reference/host/tensor_fill.h" + #include "cutlass/util/reference/device/tensor_fill.h" + #include "cutlass/util/device_memory.h" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using namespace cute; + #define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + throw std::runtime_error(msg); \\ + } \\ + } + + // Used as pass-through functor in EVT just for type casting / rounding + template + struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } + }; + + """ + ) + return res + + def cute_int(self, int_str: str, var_name: str) -> str: + res = "" + if int_str in {"1", "1L"}: + res = "cute::Int<1>{}" + else: + res = int_str + + return f"{res} /* {var_name} */" + + _DTYPE_TO_CUTLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "cutlass::half_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "cutlass::bfloat16_t", + } + + _DTYPE_TO_CUTLASS_SPARSE_META = { + torch.int32: "uint32_t", + torch.int16: "uint16_t", + } + + def cutlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})" + + def cutlass_sparse_meta_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return ( + f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})" + ) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..a41fa62b5a7b9fa28500362d7c4db7af22f6d68c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py @@ -0,0 +1,361 @@ +# mypy: allow-untyped-defs +from typing import Dict, List +from unittest.mock import patch + +import sympy + +import torch._inductor.virtualized as virtualized +from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise +from torch._inductor.utils import IndentedBuffer, sympy_str + + +# Used as a magic string to indicate an unsupported sympy expression +# became part of generated C++ code. +_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]" + + +def _arg_str(a): + if isinstance(a, sympy.Expr): + # If this return value containing the _MAGIC_SYMPY_ERROR_STRING + # is used as part of the final generated C++ code, + # a CUTLASSEVTOpNotImplementedError is raised to indicate that + # the op could not be converted to a valid EVT expression. + return f"{_MAGIC_SYMPY_ERROR_STRING}('{sympy_str(a)}')" + return str(a) + + +class CUTLASSEVTOpNotImplementedError(NotImplementedError): + pass + + +class CutlassEVTEpilogueTypeFormatter: + """ + Codegen class, which provides an entry point to generate + Cutlass "Epilogue Visitor Tree" (EVT) functor declarations. 
+ + See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder + for more about EVTs and how they are declared and used to generate. + + Notes: + * Used by CUTLASSGemmTemplate. + * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. + + + """ + + def __init__(self, accumulator_node_name, evt_type_name): + """ + + Initialize an instance of CutlassEVTEpilogueTypeFormatter. + + Parameters: + - accumulator_node_name (str): The name of the output Buffer for the GEMM operation in the original (unfused) + IR graph. + - evt_type_name (str): The output name of the EVT type we are generating. + + """ + self.accumulator_node_name = accumulator_node_name + self.output = IndentedBuffer(0) + self.var_counter = 0 + self.evt_type_name = evt_type_name + self.aliases = {} + + @staticmethod + def ir_to_evt_string( + template_output_node_name: str, + evt_type_name: str, + epilogue_nodes: List[IRNode], + ): + """ + Formats IR nodes into a string representation compatible with Cutlass EVT format. + + Args: + template_output_node_name (str): The name of the template output node. + evt_type_name (str): The name of the EVT type. + epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be + ComputedBuffer nodes wrapping Pointwise nodes. + + Returns: + A string representation of the IR nodes formatted according to the Cutlass EVT format. + """ + formatter = CutlassEVTEpilogueTypeFormatter( + template_output_node_name, evt_type_name + ) + + with virtualized.V.set_ops_handler(formatter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + for node in epilogue_nodes: + if isinstance(node, ComputedBuffer): + pnode = node.data + else: + raise RuntimeError( + "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer" + ) + assert isinstance(pnode, Pointwise) + index = pnode._index(pnode.ranges) + result = pnode.inner_fn(index) + # each epilogue node results in a single "using" statement and may refer to the previous steps by name + formatter.aliases[node.name] = result + res = formatter.getvalue(result) # type: ignore[possibly-undefined] + if _MAGIC_SYMPY_ERROR_STRING in res: + raise CUTLASSEVTOpNotImplementedError( + "sympy / indexing expressions not yet supported in EVT fusion" + ) + else: + return res + + def __getattr__(self, name): + """ + Resolve V.ops. calls, after this instance has been installed as V.ops handler. + """ + + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} + fn = getattr(self, f"_op_{name}") + line = fn(*fargs, **fkwargs) + self.var_counter += 1 + varname = f"EVT_expr_{self.var_counter}" + # replace line with a new variable name + self.output.writeline(f"using {varname} = {line};") + return varname + + if name.startswith("_"): + raise CUTLASSEVTOpNotImplementedError(name) + if hasattr(self, f"_op_{name}"): + return inner + else: + raise CUTLASSEVTOpNotImplementedError(name) + + def _op_load(self, name, index_expr): + # Load an input to an operation. Might be the output of the matmul, the result + # of a previous epilogue node, a constant or (TODO) an auxiliary input. 
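+        # Illustrative dispatch sketch (assumed example): for relu(acc + prev), loading
+        # `acc` resolves to Sm90AccFetch below, while `prev` (the output of an earlier
+        # epilogue node) resolves through self.aliases to the EVT_expr_N alias already
+        # emitted for it, so nested Sm90EVT expressions can reference both by name.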
+        if name == self.accumulator_node_name:
+            return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */"
+        elif name in self.aliases:
+            return self.aliases[name]
+        else:
+            # return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */"
+            raise CUTLASSEVTOpNotImplementedError(
+                f"Operand {name} not found. Auxiliary inputs not supported yet."
+            )
+
+    def _op_constant(self, value, dtype):
+        # Load a constant
+        if str(dtype) in ("torch.float16", "torch.float32"):
+            return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc> /* value={value}, dtype={dtype} */"
+        else:
+            raise CUTLASSEVTOpNotImplementedError(
+                f"Unsupported dtype for constant: {dtype}"
+            )
+
+    def _cutlass_binary_functional_op(self, op, a, b):
+        # Perform a named operation on two inputs
+        # see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops
+        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::{op}, ElementAcc, ElementAcc, RoundStyle>,{a},{b}>"  # noqa: B950
+
+    def _convert_to_output_dtype(self, a):
+        # Convert the final output to the dtype of the output buffer
+        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,{a}>"  # noqa: B950
+
+    def _op_to_dtype(self, a, *args, **kwargs):
+        # no-op in our case, since we convert to the output dtype at the end and convert everything to the accumulator
+        # dtype.
+        # It is asserted (and ascertained during the can_fuse decision) that the dtype remains compatible
+        # throughout the fusion chain.
+        return a  # noqa: B950
+
+    def _op_mul(self, a, b):
+        return self._cutlass_binary_functional_op("multiplies", a, b)
+
+    def _op_div(self, a, b):
+        return self._cutlass_binary_functional_op("divides", a, b)
+
+    def _op_truediv(self, a, b):
+        return self._cutlass_binary_functional_op("divides", a, b)
+
+    def _op_ge(self, a, b):
+        return self._cutlass_binary_functional_op("greater_equal", a, b)
+
+    def _op_add(self, a, b):
+        return self._cutlass_binary_functional_op("plus", a, b)
+
+    def _op_sub(self, a, b):
+        return self._cutlass_binary_functional_op("minus", a, b)
+
+    def _op_minimum(self, a, b):
+        return self._cutlass_binary_functional_op("minimum", a, b)
+
+    def _op_maximum(self, a, b):
+        return self._cutlass_binary_functional_op("maximum", a, b)
+
+    def _op_relu(self, a):
+        const_zero = self._op_constant(0.0, "torch.float32")
+        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,{a}, {const_zero}>"  # noqa: B950
+
+    def reduction(self, dtype, src_dtype, reduction_type, value):
+        raise CUTLASSEVTOpNotImplementedError
+
+    # Add more ops here...
+    def getvalue(self, result) -> str:
+        # Return final result
+        dtype_converted_expr = self._convert_to_output_dtype(
+            f"EVT_expr_{self.var_counter}"
+        )
+        self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};")
+        return self.output.getvalue()
+
+
+class CutlassEVTEpilogueArgumentFormatter:
+    """
+    Codegen class, which provides an entry point to generate
+    Cutlass "Epilogue Visitor Tree" (EVT) Argument initializers
+
+    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
+    for more about EVTs and how they are declared and used to generate.
+
+    Notes:
+    * Used by CUTLASSGemmTemplate.
+    * This class should not be instantiated by users, it is intended to be used
+      by calling CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...)
+      which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
+    * Extend this with more _op_ nodes to add support for new pointwise operations.
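+    * Illustrative sketch (hypothetical values): where the type formatter above emits
+      the functor *types*, this class emits the matching brace-initializer *arguments*;
+      e.g. relu(acc + 1.0) yields roughly
+      { { /*plus: */ {}, { static_cast<ElementAcc>(1.0) } }, { static_cast<ElementAcc>(0.0) } }.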
+ + + """ + + def __init__(self, accumulator_node_name: str): + """ + + Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly. + Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method. + + Args: + accumulator_node_name (str): The name of the accumulator node which should contain + the Matmul result before fusion according to the IR graph. + """ + self.accumulator_node_name: str = accumulator_node_name # + self.output: IndentedBuffer = IndentedBuffer(0) # The output buffer for codegen + self.var_counter: int = ( + 0 # used to generate variable names, incremented for each new variable + ) + self.aliases: Dict[str, str] = {} # Aliases for subexpression functors + + @staticmethod + def ir_to_evt_argument_string( + template_output_node_name: str, + epilogue_nodes: List[IRNode], + ) -> str: + formatter = CutlassEVTEpilogueArgumentFormatter( + template_output_node_name, + ) + + with virtualized.V.set_ops_handler(formatter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + for node in epilogue_nodes: + assert isinstance(node, ComputedBuffer) + pnode = node.data + assert isinstance(pnode, Pointwise) + index = pnode._index(pnode.ranges) + result = pnode.inner_fn(index) + # each epilogue node results in a single "using" statement and may refer to the previous steps by name + if node.name is not None: + formatter.aliases[node.name] = result + + res: str = formatter.getvalue(result) # type: ignore[possibly-undefined] + if _MAGIC_SYMPY_ERROR_STRING in res: + raise CUTLASSEVTOpNotImplementedError( + "sympy / indexing expressions not yet supported in EVT fusion" + ) + else: + return res + + def __getattr__(self, name): + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} + fn = getattr(self, f"_op_{name}") + line = fn(*fargs, **fkwargs) + return line + + if name.startswith("_"): + raise CUTLASSEVTOpNotImplementedError(name) + + if hasattr(self, f"_op_{name}"): + return inner + else: + raise CUTLASSEVTOpNotImplementedError(name) + + def _op_load(self, name, index_expr): + if name == self.accumulator_node_name: + return "{}" + elif name in self.aliases: + return self.aliases[name] + else: + raise CUTLASSEVTOpNotImplementedError( + f"Operand {name} not found. Auxiliary inputs not supported yet." 
+            )
+
+    def _op_constant(self, value, dtype):
+        if str(dtype) in ("torch.float16", "torch.float32"):
+            return "{ static_cast<ElementAcc>(" + str(value) + ") }"
+        else:
+            raise CUTLASSEVTOpNotImplementedError(
+                f"Unsupported dtype for constant: {dtype}"
+            )
+
+    def _cutlass_binary_functional_op(self, op, a, b):
+        return f"{{ /*{op}: */ {a}, {b} }}"
+
+    def _op_mul(self, a, b):
+        return self._cutlass_binary_functional_op("multiplies", a, b)
+
+    def _op_div(self, a, b):
+        return self._cutlass_binary_functional_op("divides", a, b)
+
+    def _op_truediv(self, a, b):
+        return self._cutlass_binary_functional_op("divides", a, b)
+
+    def _op_ge(self, a, b):
+        return self._cutlass_binary_functional_op("greater_equal", a, b)
+
+    def _op_add(self, a, b):
+        return self._cutlass_binary_functional_op("plus", a, b)
+
+    def _op_sub(self, a, b):
+        return self._cutlass_binary_functional_op("minus", a, b)
+
+    def _op_minimum(self, a, b):
+        return self._cutlass_binary_functional_op("minimum", a, b)
+
+    def _op_maximum(self, a, b):
+        return self._cutlass_binary_functional_op("maximum", a, b)
+
+    def _op_relu(self, a):
+        const_zero = self._op_constant(0.0, "torch.float32")
+        return "{" + str(a) + ", " + const_zero + "}"
+
+    def _op_to_dtype(self, a, dtype, src_dtype=None):
+        # It is asserted (and ascertained during the can_fuse decision) that the dtype remains compatible
+        # throughout the fusion chain.
+        assert dtype in (
+            "torch.float32",
+            "torch.float16",
+        ), f"Unsupported dtype: {dtype}"
+        assert src_dtype in (
+            None,
+            "torch.float32",
+            "torch.float16",
+        ), f"Unsupported source dtype: {src_dtype}"
+        return a
+
+    def reduction(self, dtype, src_dtype, reduction_type, value):
+        raise CUTLASSEVTOpNotImplementedError
+
+    def getvalue(self, result) -> str:
+        return "{" + str(result) + "}"
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f168a2dd175579689377a5b7e7501435cf6629f2
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f237d2cba7902f30826c4367a467fdab3d26f74e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py
new file mode 100644
index
0000000000000000000000000000000000000000..c95e5a29fa1bdccc473b4b853b6e4d03305a05d7
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py
@@ -0,0 +1,188 @@
+# mypy: allow-untyped-defs
+from ..cutlass_utils import try_import_cutlass
+
+
+if try_import_cutlass():
+    import enum
+
+    from cutlass_library.gemm_operation import *  # noqa: F401, F403
+    from cutlass_library.library import *  # noqa: F401, F403
+
+    # copied / modified from original at
+    # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658
+    # to support EVT similar to
+    # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L315C69-L315C69  # noqa: B950
+    class EmitGemmUniversal3xInstanceWithEVT:
+        """Responsible for emitting a CUTLASS 3.x template definition"""
+
+        def __init__(self, operation_suffix="") -> None:
+            self.operation_suffix = operation_suffix
+            self.includes = [
+                "cutlass/cutlass.h",
+                "cutlass/gemm/gemm.h",
+                "cutlass/numeric_types.h",
+                "cutlass/gemm/kernel/gemm_universal.hpp",
+                "cutlass/gemm/collective/collective_builder.hpp",
+                "cutlass/epilogue/collective/collective_builder.hpp",
+            ]
+            self.builtin_epilogue_functor_template = """
+    ${epilogue_functor}<
+      ${element_c},
+      ${epilogue_vector_length},
+      ${element_accumulator},
+      ${element_epilogue}
+    >
+"""
+            self.gemm_template = """
+using EpilogueScheduleType = ${epilogue_schedule};
+static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
+       cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
+      "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
+static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+using ElementAcc = ${element_accumulator};
+using ElementD = ${element_d};
+${epilogue_functor};
+using ${operation_name}_epilogue =
+  typename cutlass::epilogue::collective::CollectiveBuilder<
+    ${arch}, ${opcode_class},
+    cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
+    cute::Shape<cute::_${cluster_m}, cute::_${cluster_n}, cute::_${cluster_k}>,
+    cutlass::epilogue::collective::EpilogueTileAuto,
+    ${element_accumulator}, ${element_epilogue},
+    ${element_c}, ${layout_c}, ${align_c},
+    ${element_d}, ${layout_d}, ${align_d},
+    EpilogueScheduleType,
+    ${operation_name}_epilogue_functor
+  >::CollectiveOp;
+
+using ${operation_name}_mainloop =
+  typename cutlass::gemm::collective::CollectiveBuilder<
+    ${arch}, ${opcode_class},
+    ${element_a}, ${layout_a}, ${align_a},
+    ${element_b}, ${layout_b}, ${align_b},
+    ${element_accumulator},
+    cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
+    cute::Shape<cute::_${cluster_m}, cute::_${cluster_n}, cute::_${cluster_k}>,
+    ${stages},
+    ${kernel_schedule}
+  >::CollectiveOp;
+
+// Gemm operator ${operation_name}
+using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
+    cute::Shape<int,int,int,int>,
+    ${operation_name}_mainloop,
+    ${operation_name}_epilogue,
+    ${tile_scheduler}>;
+
+// Define named type
+struct ${operation_name} :
+  public ${operation_name}_base { };
+
+"""
+
+        #
+        def instance_template(self):
+            return """
+${compile_guard_start}
+  using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
+  manifest.append(
+    new ${gemm_kind}<GemmKernel>("${operation_name}"));
+${compile_guard_end}
+"""
+
+        #
+        def emit(self, operation):
+            tile_shape = operation.tile_description.tile_shape
+            warp_count = operation.tile_description.warp_count
+            # stage count set to zero indicates builder automatic stage selection
+            if operation.tile_description.stages > 0:
+                stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
+            else:
+                stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage)>"  # noqa: B950
+            warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)]
+
+            (
+                instance_layout_A,
+                instance_layout_B,
+                instance_layout_C,
+                instance_layout_D,
+            ) = (
+                operation.A.layout,
+                operation.B.layout,
+                operation.C.layout,
+                operation.D.layout,
+            )
+
+            # 3.0 profiler integration only supports trivial epilogues for now
+            epilogue_vector_length = 1
+
+            # Support built-in epilogue functors or user-defined functions
+            if isinstance(operation.epilogue_functor, enum.Enum):
+                values = {
+                    "epilogue_vector_length": str(epilogue_vector_length),
+                    "element_epilogue": str(DataTypeTag[operation.element_epilogue]),  # type: ignore[name-defined]
+                    "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],  # type: ignore[name-defined]
+                }
+                epilogue_functor = SubstituteTemplate(  # type: ignore[name-defined]
+                    self.builtin_epilogue_functor_template, values
+                )
+
+            elif callable(operation.epilogue_functor):
+                epilogue_functor = operation.epilogue_functor(
+                    operation.procedural_name() + "_epilogue_functor"
+                )
+            else:
+                epilogue_functor = str(operation.epilogue_functor)
+            #
+
+            values = {
+                "operation_name": operation.procedural_name(),
+                "operation_suffix": self.operation_suffix,
+                "element_a": DataTypeTag[operation.A.element],  # type: ignore[name-defined]
+                "layout_a": LayoutTag[instance_layout_A],  # type: ignore[name-defined]
+                "element_b": DataTypeTag[operation.B.element],  # type: ignore[name-defined]
+                "layout_b": LayoutTag[instance_layout_B],  # type: ignore[name-defined]
+                "element_c": DataTypeTag[operation.C.element],  # type: ignore[name-defined]
+                "layout_c": LayoutTag[instance_layout_C],  # type: ignore[name-defined]
+                "element_d": DataTypeTag[operation.D.element],  # type: ignore[name-defined]
+                "layout_d": LayoutTag[instance_layout_D],  # type: ignore[name-defined]
+                "element_accumulator": DataTypeTag[operation.accumulator_type()],  # type: ignore[name-defined]
+                "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],  # type: ignore[name-defined] # noqa: B950
+                "arch": "cutlass::arch::Sm%d" % operation.arch,
+                "tile_shape_m": str(operation.tile_description.tile_shape[0]),
+                "tile_shape_n": str(operation.tile_description.tile_shape[1]),
+                "tile_shape_k": str(operation.tile_description.tile_shape[2]),
+                "cluster_m": str(operation.tile_description.cluster_shape[0]),
+                "cluster_n": str(operation.tile_description.cluster_shape[1]),
+                "cluster_k": str(operation.tile_description.cluster_shape[2]),
+                "warp_shape_m": str(warp_shape[0]),
+                "warp_shape_n": str(warp_shape[1]),
+                "warp_shape_k": str(warp_shape[2]),
+                "instruction_shape_m": str(
+                    operation.tile_description.math_instruction.instruction_shape[0]
+                ),
+                "instruction_shape_n": str(
+                    operation.tile_description.math_instruction.instruction_shape[1]
+                ),
+                "instruction_shape_k": str(
+                    operation.tile_description.math_instruction.instruction_shape[2]
+                ),
+                "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]),  # type: ignore[name-defined]
+                "epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]),  # type: ignore[name-defined]
+                "epilogue_functor": epilogue_functor,
+                "stages": stage_count_string,
+                "align_a": str(operation.A.alignment),
+                "align_b": str(operation.B.alignment),
+                "align_c": str(operation.C.alignment),
+                "align_d": str(operation.C.alignment),
+                "transform_a":
ComplexTransformTag[operation.A.complex_transform], # type: ignore[name-defined] + "transform_b": ComplexTransformTag[operation.B.complex_transform], # type: ignore[name-defined] + "math_operation": MathOperationTag[ # type: ignore[name-defined] + operation.tile_description.math_instruction.math_operation + ], + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), # type: ignore[name-defined] + "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]), # type: ignore[name-defined] + } + + return SubstituteTemplate(self.gemm_template, values) # type: ignore[name-defined] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a149bd66a21bdbdaacebc7028114d5c67b2be9b3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -0,0 +1,363 @@ +# mypy: allow-untyped-defs +import functools +import logging +import os +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional + +import sympy + +import torch + +from ... import config +from ...ir import Layout +from ...runtime.runtime_utils import cache_dir +from .cuda_env import get_cuda_arch, get_cuda_version + + +log = logging.getLogger(__name__) + + +def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str: + for cutlass_module in cutlass_modules: + content = content.replace( + f"from {cutlass_module} import ", + f"from cutlass_library.{cutlass_module} import ", + ) + return content + + +def _gen_cutlass_file( + file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str +) -> None: + orig_full_path = os.path.abspath(os.path.join(src_dir, file_name)) + text = "" + with open(orig_full_path) as f: + text = f.read() + text = _rename_cutlass_import(text, cutlass_modules) + dst_full_path = os.path.abspath( + os.path.join( + dst_dir, + file_name, + ) + ) + with open(dst_full_path, "w") as f: + f.write(text) + + +@functools.lru_cache(None) +def try_import_cutlass() -> bool: + if config.is_fbcode(): + return True + + # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. + # This is a temporary hack to avoid CUTLASS module naming conflicts. + # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues. + + cutlass_py_full_path = os.path.abspath( + os.path.join(config.cuda.cutlass_dir, "python/cutlass_library") + ) + tmp_cutlass_py_full_path = os.path.abspath( + os.path.join(cache_dir(), "torch_cutlass_library") + ) + dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library") + + if os.path.isdir(cutlass_py_full_path): + if tmp_cutlass_py_full_path not in sys.path: + if os.path.exists(dst_link): + assert os.path.islink( + dst_link + ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." 
+ assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath( + cutlass_py_full_path + ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}" + else: + os.makedirs(tmp_cutlass_py_full_path, exist_ok=True) + os.symlink(cutlass_py_full_path, dst_link) + sys.path.append(tmp_cutlass_py_full_path) + try: + import cutlass_library.generator # noqa: F401 + import cutlass_library.library # noqa: F401 + import cutlass_library.manifest # noqa: F401 + + return True + + except ImportError as e: + log.debug( + "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.", + str(e), + ) + else: + log.debug( + "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s", + cutlass_py_full_path, + ) + return False + + +def _normalize_cuda_arch(arch: str) -> str: + if int(arch) >= 90: + return "90" + elif int(arch) >= 80: + return "80" + elif int(arch) >= 75: + return "75" + elif int(arch) >= 70: + return "70" + else: + raise NotImplementedError(f"Unsupported cuda arch: {arch}") + + +@dataclass +class CUTLASSArgs: + """ + CUTLASS args used to initialize a CUTLASS Manifest. + """ + + architectures: Optional[str] = None + cuda_version: Optional[str] = None + + operations = "all" + build_dir = "" + curr_build_dir = "" + generator_target = "" + kernels = "all" + ignore_kernels = "" + # TODO: these three look dead? + kernel_filter_file: None = None + selected_kernel_list: None = None + interface_dir: None = None + filter_by_cc = True + disable_full_archs_compilation = False + + def __post_init__(self): + if self.architectures is None or self.cuda_version is None: + raise RuntimeError( + f"{self.architectures=} or {self.cuda_version=} is None!" + ) + self.architectures = _normalize_cuda_arch(self.architectures) + + +@functools.lru_cache(None) +def _gen_ops_cached(arch, version) -> List[Any]: + # Note: Cache needs to be specific for cuda architecture and version + + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library.generator as cutlass_generator + import cutlass_library.manifest as cutlass_manifest + + if arch is None or version is None: + log.error( + "Cannot detect cuda arch %s or cuda version %s. " + "Will discard all cutlass ops. " + "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.", + arch, + version, + ) + return [] + arch = _normalize_cuda_arch(arch) + args = CUTLASSArgs(architectures=arch, cuda_version=version) + manifest = cutlass_manifest.Manifest(args) + + if arch == "90": + cutlass_generator.GenerateSM90(manifest, args.cuda_version) + cutlass_generator.GenerateSM80(manifest, args.cuda_version) + else: + try: + func = getattr(cutlass_generator, "GenerateSM" + arch) + func(manifest, args.cuda_version) + except AttributeError as e: + raise NotImplementedError( + "Arch " + arch + " is not supported by current cutlass lib." + ) from e + return manifest.operations + + +def gen_ops() -> List[Any]: + """ + Generates all supported CUTLASS operations. + """ + arch = get_cuda_arch() + version = get_cuda_version() + return _gen_ops_cached(arch, version) + + +def torch_dtype_to_cutlass_type( + torch_dtype: torch.dtype, +) -> "cutlass_library.library.DataType": # type: ignore[name-defined] # noqa: F821 + # Import cutlass python scripts. 
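+    # For example, torch.half maps to cutlass_library.library.DataType.f16 below;
+    # unsupported dtypes raise NotImplementedError.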
+ assert try_import_cutlass() + import cutlass_library # type: ignore[import] + + if torch_dtype == torch.float: + return cutlass_library.library.DataType.f32 + elif torch_dtype == torch.half: + return cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_library.library.DataType.bf16 + else: + raise NotImplementedError(f"Unsupported data type: {torch_dtype=}") + + +def dtype_match( + torch_dtype: Optional[torch.dtype], + cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821 +) -> bool: + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library + + if torch_dtype == torch.float: + return ( + cutlass_dtype == cutlass_library.library.DataType.f32 + or cutlass_dtype == cutlass_library.library.DataType.tf32 + ) + elif torch_dtype == torch.half: + return cutlass_dtype == cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_dtype == cutlass_library.library.DataType.bf16 + elif torch_dtype == torch.int8: + return cutlass_dtype == cutlass_library.library.DataType.s8 + elif torch_dtype == torch.uint8: + return cutlass_dtype == cutlass_library.library.DataType.u8 + elif torch_dtype == torch.int32: + return cutlass_dtype == cutlass_library.library.DataType.s32 + else: + return False + + +def get_accumulator_dtype( + input_torch_dtypes: List[torch.dtype], +) -> Optional[torch.dtype]: + """ + Given a pair of input torch dtypes, returns the inferred accumulator torch dtype. + """ + + if len(input_torch_dtypes) != 2: + return None + + torch_dtype = None + if input_torch_dtypes[0] == input_torch_dtypes[1]: + torch_dtype = input_torch_dtypes[0] + else: + size0 = torch.tensor([], dtype=input_torch_dtypes[0]).element_size() + size1 = torch.tensor([], dtype=input_torch_dtypes[1]).element_size() + if size0 > size1: + dtype0, dtype1 = input_torch_dtypes + else: + dtype1, dtype0 = input_torch_dtypes + if dtype0 in [torch.half, torch.bfloat16] and dtype1 in [ + torch.int8, + torch.uint8, + ]: + torch_dtype = dtype0 + + if torch_dtype == torch.half: + if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction: + return torch_dtype + else: + return torch.float + if torch_dtype in {torch.bfloat16, torch.float}: + return torch.float + if torch_dtype == torch.int8: + return torch.int32 + raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}") + + +def get_alignments(torch_dtype: torch.dtype) -> List[int]: + """ + Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype. + CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment. + """ + + if torch_dtype in (torch.half, torch.bfloat16): + return [8, 4, 2, 1] + elif torch_dtype == torch.float: + return [4, 2, 1] + elif torch_dtype in (torch.uint8, torch.int8): + return [16, 8, 4, 2] + elif torch_dtype == torch.int32: + return [4, 2, 1] + else: + raise NotImplementedError(f"unsupported {torch_dtype=} for alignments") + + +def get_max_alignment(inductor_layout: Layout) -> int: + """ + Returns the max alignment (in terms of number of elements) for a given Inductor Layout. 
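+
+    For example, a contiguous row-major float16 tensor of shape (M, K) with K a
+    multiple of 8 and zero storage offset yields a max alignment of 8 elements
+    (16 bytes), the largest alignment the CUTLASS GEMM / conv SM80 APIs accept.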
+ """ + + dtype = inductor_layout.dtype + size = inductor_layout.size + offset = inductor_layout.offset + + def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + try: + contiguous_dim = inductor_layout.stride.index(1) + except ValueError: + # No dim with stride 1 found, return 1 + return 1 + if ( + is_static_int(size[contiguous_dim]) + and is_static_int(offset) + and all(is_static_int(s) for s in inductor_layout.stride) + ): + alignments = get_alignments(dtype) + for alignment in alignments: + if ( + int(size[contiguous_dim]) % alignment != 0 + or int(offset) % alignment != 0 + ): + continue + if all( + (dim == contiguous_dim) + or (inductor_layout.stride[dim] % alignment == 0) + for dim in range(len(size)) + ): + return alignment + return 1 + + +class CUDACompileSourceCapturingContext: + # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation. + # Can be used to capture the sourcecode passed to CUDACodeCache.compile + + def __init__(self): + self.sources = [] + self._compile_patch = None + + def __enter__(self, *args, **kwargs): + import unittest.mock as mock + + import torch._inductor.codecache + + _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile + + def my_compile(source_code, dst_file_ext): + self.sources.append(source_code) + return _compile_method_orig(source_code, dst_file_ext) + + self._compile_patch = mock.patch( + "torch._inductor.codecache.CUDACodeCache.compile", my_compile + ) + return self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr] + + def __exit__(self, *args, **kwargs): + return self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr] + + +def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path): + # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run + # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled. 
+ from torch._inductor.codecache import cuda_compile_command + + extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"] + compile_command = cuda_compile_command( + [str(srcpath)], str(exepath), "exe", extra_args=extra_args + ) + return compile_command diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff99b871c82d6ce701a89c65bdd94707463accd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +from ..common import DeviceOpOverrides, register_device_op_overrides + + +class CUDADeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return f"from torch._C import _cuda_getCurrentRawStream as {name}" + + def set_device(self, device_idx): + return f"torch.cuda.set_device({device_idx})" + + def synchronize(self): + return "torch.cuda.synchronize()" + + def device_guard(self, device_idx): + return f"torch.cuda._DeviceGuard({device_idx})" + + +register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..7a999b45b789df49b9dd4423f4e0ec0570f27432 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py @@ -0,0 +1,1564 @@ +# mypy: allow-untyped-defs +import copy +import enum +import logging +import re +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Union + +from ... import ir +from ...config import cuda as inductor_cuda_config +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + FixedLayout, + IRNode, + Layout, + ReinterpretView, +) +from ..common import IndentedBuffer +from . import cutlass_utils +from .cuda_kernel import CUDATemplateKernel +from .cuda_template import CUTLASSTemplate + + +log = logging.getLogger(__name__) + +# Jinja template for GEMM Kernel, used by the CUTLASSGemm3xTemplate class below. +GEMM_TEMPLATE_CUTLASS_3X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. 
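+// Two-phase call protocol (illustrative summary): callers first invoke the kernel with a
+// non-null workspace_size pointer to query the required workspace bytes, then invoke it
+// again with workspace_size == nullptr and a workspace allocation of that size to run.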
+extern "C" { +{{kernel_call_signature}} { + try { + int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}}; + int64_t M = {{kernel.size(X, -2)}}; + int64_t K = {{kernel.size(X, -1)}}; + int64_t N = {{kernel.size(W, -1)}}; + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + {{kernel.check_not_null(X)}} + {{kernel.check_not_null(W)}} + {{kernel.check_not_null(Bias)}} + {{kernel.check_not_null(Y)}} + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments, used by the CUTLASSGemmTemplate class below. +GEMM_ARGS_CUTLASS_3X = r""" + // Initialize GemmUniversal3xInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast({{M}}), + static_cast({{N}}), + static_cast(K), + static_cast(B) + }, // ProblemShape problem_shape + { + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A + { + {{template.cute_int(kernel.stride(X, -2), "stride_x0")}}, + {{template.cute_int(kernel.stride(X, -1), "stride_x1")}}, + {{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}} + }, // StrideA dA + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B + { + {{template.cute_int(kernel.stride(W, -1), "stride_w1")}}, + {{template.cute_int(kernel.stride(W, -2), "stride_w0")}}, + {{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}} + }, // StrideB dB + }, // MainloopArguments mainloop + {{epilogue_arguments}}, + hw_info + }; +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied, +# used by the CUTLASSGemmTemplate class below. 
+GEMM_ARGS_CUTLASS_3X_EPILOGUE = r""" + // see https://tinyurl.com/4rk89z48 + { + {{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C + { + {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}}, + {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}}, + {{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}} + }, // StrideC dC + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D + { + {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}}, + {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}}, + {{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}} + }, // StrideD dD + }, // EpilogueArguments epilogue +""" + +# Jinja template for GEMM Kernel, used by the CUTLASS2xGemmTemplate class below. +GEMM_TEMPLATE_CUTLASS_2X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. +extern "C" { +{{kernel_call_signature}} { + try { + int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}}; + int64_t M = {{kernel.size(X, -2)}}; + int64_t K = {{kernel.size(W, -2)}}; + int64_t N = {{kernel.size(W, -1)}}; + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(instance_type, argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Meta, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + {{kernel.check_not_null(X)}} + {{kernel.check_not_null(W)}} + {{kernel.check_not_null(Bias)}} + {{kernel.check_not_null(Meta)}} + {{kernel.check_not_null(Y)}} + + + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + +# Jinja template for Cutlass 2.x GEMM Kernel arguments, used by the CUTLASS2xGemmTemplate class below. 
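+# Note (comparison sketch): unlike the 3.x arguments above, which pass a rank-4
+# ProblemShape (M, N, K, B) and cute stride tuples, the 2.x variant below passes a
+# GemmCoord problem size plus explicit flat batch and row strides.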
+GEMM_ARGS_CUTLASS_2X = r""" + int64_t batch_stride_x = {{kernel.stride(X, -3)}}; + int64_t row_stride_x = {{kernel.row_or_column_stride(X)}}; + int64_t batch_stride_w = {{kernel.stride(W, -3)}}; + int64_t row_stride_w = {{kernel.row_or_column_stride(W)}}; + int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}}; + int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}}; + int64_t batch_stride_y = {{kernel.stride(Y, -3)}}; + int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}}; + // Initialize GemmUniversalInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast(M), + static_cast(N), + static_cast(K) + }, // GemmCoord problem_size + {{split_k if split_k > 1 else 'B'}}, // int batch_count + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // void const * ptr_A + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // void const * ptr_B + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // void const * ptr_C + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // void * ptr_D + batch_stride_x, // int64_t batch_stride_A + batch_stride_w, // int64_t batch_stride_B + batch_stride_bias, // int64_t batch_stride_C + batch_stride_y, // int64_t batch_stride_D + row_stride_x, // typename LayoutA::Stride::LongIndex lda + row_stride_w, // typename LayoutB::Stride::LongIndex ldb + row_stride_bias, // typename LayoutC::Stride::LongIndex ldc + row_stride_y, // typename LayoutC::Stride::LongIndex ldd + }; +""" + +GEMM_ARGS_SPARSE_CUTLASS_2X = r""" + using TensorRefA = cutlass::TensorRef<{{instance_type}}::ElementA, + {{instance_type}}::LayoutA>; + using TensorRefB = cutlass::TensorRef<{{instance_type}}::ElementB, + {{instance_type}}::LayoutB>; + using TensorRefC = cutlass::TensorRef<{{instance_type}}::ElementC, + {{instance_type}}::LayoutC>; + using TensorRefE = cutlass::TensorRef<{{instance_type}}::ElementE, + {{instance_type}}::LayoutE>; + // Note that "X" and "W" names may be misleading here. Namely, for + // sparse GEMM, the first argument is always sparse, while typically + // weight matrix, implied by name "W" will be sparse in + // applications. Thus, just remember that here: "X" refers to first + // argument, that is sparse, and "W" to second, that is dense. + TensorRefA X_ref({{template.cutlass_type_cast(X, kernel.ptr(X))}}, {{kernel.row_or_column_stride(X)}}); + TensorRefB W_ref({{template.cutlass_type_cast(W, kernel.ptr(W))}}, {{kernel.row_or_column_stride(W)}}); + TensorRefC Y_ref({{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, {{kernel.row_or_column_stride(Y)}}); + TensorRefE Meta_ref({{template.cutlass_sparse_meta_type_cast(Meta, kernel.ptr(Meta))}}, + TensorRefE::Layout::packed({ {{kernel.size(Meta, 0)}}, {{kernel.size(Meta, 1)}} })); + // Initialize GemmSparse arguments. 
+  arguments = {
+    {
+      static_cast<coord_t>({{M}}),
+      static_cast<coord_t>({{N}}),
+      static_cast<coord_t>(K),
+    },  // GemmCoord problem_size
+    X_ref,  // TensorRef ref_A
+    W_ref,  // TensorRef ref_B
+    Y_ref,  // TensorRef ref_C
+    Y_ref,  // TensorRef ref_D
+    Meta_ref,  // TensorRef ref_E
+    {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})},  // typename EpilogueOutputOp::Params epilogue,
+  };
+"""
+
+# Additional includes which are necessary if the standalone test / debug runner is generated as well
+GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES = r"""
+#ifdef GENERATE_STANDALONE_RUNNER
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/reference/device/gemm_complex.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "cutlass/util/reference/device/tensor_fill.h"
+#include <iostream>
+#endif
+"""
+
+# Jinja template for the standalone runner that may be generated as part of the code.
+GEMM_STANDALONE_RUNNER_TEMPLATE = r"""
+#ifdef GENERATE_STANDALONE_RUNNER
+/// Helper to initialize a block of device data
+template <class Element>
+bool initialize_block(
+  cutlass::DeviceAllocation<Element>& block,
+  uint64_t seed, float max=1.0, float min=-1.0) {
+  if (block.size()<=0) return false;
+  Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
+  cutlass::reference::device::BlockFillRandomUniform(
+    block.get(), block.size(), seed, scope_max, scope_min, 0);
+
+  return true;
+}
+
+extern "C" int run_standalone(uint64_t seed, int repetitions) {
+  std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
+  size_t workspace_size = 0;
+  size_t* workspace_size_ptr = &workspace_size;
+
+  using ElementA = {{kernel.cutlass_dtype(X)}};
+  using ElementB = {{kernel.cutlass_dtype(W)}};
+  using ElementC = {{kernel.cutlass_dtype(Bias, default_dtype='uint8_t')}}; // may not be void
+  using ElementD = {{kernel.cutlass_dtype(Y)}};
+
+  cutlass::DeviceAllocation<ElementA> X_data({{kernel.max_valid_index(X)+1}});
+  initialize_block(X_data, seed++);
+  cutlass::DeviceAllocation<ElementB> W_data({{kernel.max_valid_index(W)+1}});
+  initialize_block(W_data, seed++);
+  cutlass::DeviceAllocation<ElementC> Bias_data({{kernel.max_valid_index(Bias)+1}});
+  initialize_block(Bias_data, seed++);
+  cutlass::DeviceAllocation<ElementD> Y_data({{kernel.max_valid_index(Y)+1}});
+
+  cutlass::DeviceAllocation<uint8_t> workspace_data;
+  // Call once with workspace_size_ptr set to get workspace size
+
+  std::cout << "Calling once to get workspace size" << std::endl;
+  {{test_call_statement}};
+  // Allocate workspace if necessary
+  if (workspace_size > 0) {
+    workspace_data.reset(workspace_size);
+    std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl;
+  }
+  std::cout << "Calling Kernel as {{test_call_statement}};" << std::endl;
+  workspace_size_ptr = nullptr;
+  for (int i=0; i<repetitions; i++) {
+    {{test_call_statement}};
+  }
+  cudaError_t result = cudaDeviceSynchronize();
+  if (result != cudaSuccess) {
+    std::cerr << "Device synchronize failed with error "
+      << cudaGetErrorString(result) << std::endl;
+    return -1;
+  }
+  return 0;
+}
+
+int main(int argc, char** argv) {
+  // warmup
+  run_standalone(1, 2);
+  // repeat
+  return run_standalone(2, 10);
+}
+
+#endif
+"""
+
+
+class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
+    """
+    CUTLASS GEMM template, used to generate CUTLASS GEMM kernels,
+    including those with fused epilogues.
+    """
+
+    def __init__(
+        self,
+        input_nodes: List[Buffer],
+        layout: Layout,
+        alpha: float,
+        beta: float,
+        input_reorder: Optional[List[int]] = None,
+    ) -> None:
+        """
+        Args:
+            input_nodes (List[Buffer]): List of input nodes of the GEMM kernel.
+            layout (Layout): Layout type of the resulting output node.
+            alpha (float): The scaling factor for the product of the inputs in the GEMM operation.
+            beta (float): The scaling factor applied to the output matrix.
+            input_reorder (Optional[List[int]]): Specifies the reordering of the input nodes. If not provided,
+                no reordering is performed. Defaults to None.
+ """ + super().__init__("cutlass_gemm", input_nodes, layout, input_reorder) + self.alpha = alpha + self.beta = beta + assert len(input_nodes) == 2 or len(input_nodes) == 3 + assert self._are_inputs_layout_compatible( + [node.get_layout() for node in input_nodes] + ) + + @staticmethod + @abstractmethod + def add_cutlass_gemm_choices( + choices: List[ChoiceCaller], + layout: ir.Layout, + input_nodes: List[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[List[int]] = None, + **extra_kwargs, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def _get_supported_ops() -> "List[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError + + @staticmethod + @abstractmethod + def _has_tma_epilogue(self) -> bool: + raise NotImplementedError + + @abstractmethod + def _get_template(self) -> str: + raise NotImplementedError + + @abstractmethod + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[str, Optional[str]]: + raise NotImplementedError + + @abstractmethod + def _are_inputs_layout_compatible(self, layouts: List[Layout]) -> bool: + raise NotImplementedError + + @abstractmethod + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _define_gemm_instance( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[str, str]: + raise NotImplementedError + + @abstractmethod + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[Optional[Buffer], List[Optional[Buffer]], List[str]]: + raise NotImplementedError + + def _add_cutlass_gemm_choices( + self, + choices: List[ChoiceCaller], + layout: ir.Layout, + input_nodes: List[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[List[int]] = None, + **extra_kwargs, + ) -> None: + """ + Adds Cutlass GEMM configurations choices to the auto-tuning list. + + This function mutates the passed list of choices by appending the choices for Cutlass GEMM configs to it. + + Args: + choices (list): The list to which choices are appended. + layout (ir.Layout): The layout configuration. + input_nodes (list): The list of input nodes. + alpha (float,int): Scaling factor, defaults to 1. + beta (float,int): Offset, defaults to 0. + input_reorder (list, optional): Order of the inputs, defaults to None. + **extra_kwargs: Additional keyword arguments. 
+ + """ + + ops = self.gen_ops() + for op in ops: + self.maybe_append_choice( + choices, + op=op, + ) + if len(ops) == 0: + input_layouts = [node.get_layout() for node in input_nodes] + input_strides = [node.get_stride() for node in input_nodes] + output_layout = layout + warning_msg = f"No suitable Cutlass GEMM configs found, fallbacks used ( {len(ops)=}, {output_layout=}, {input_layouts=}, {input_strides=} )" # noqa: B950 + log.warning(warning_msg) + log.debug( + "Added %d Cutlass gemm configs.", + len(ops), + ) + + def header(self) -> IndentedBuffer: + """ + Returns a buffer containing CUDA C++ code for the header section of the CUTLASS GEMM template. + This section primarily includes the necessary header files. + + Returns: + IndentedBuffer: An instance of IndentedBuffer that contains the generated CUDA C++ header code. + """ + res = super().header() + res.splice( + """ + #include "cutlass/gemm/gemm.h" + #include "cutlass/gemm/device/gemm_universal.h" + #include "cutlass/gemm/device/gemm_universal_adapter.h" + #include "cutlass/gemm/kernel/gemm_universal.hpp" + #include "cutlass/gemm/device/gemm_sparse.h" + #include "cutlass/gemm/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/default_epilogue.hpp" + #include "cutlass/epilogue/thread/linear_combination.h" + #include "cutlass/epilogue/thread/activation.h" + #include "cutlass/gemm/dispatch_policy.hpp" + #include "cutlass/gemm/kernel/tile_scheduler.hpp" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/distribution.h" + #include "cutlass/util/packed_stride.hpp" + #include "cutlass/util/tensor_view_io.h" + """ + ) + if inductor_cuda_config.generate_test_runner: + res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES) + return res + + @staticmethod + def cutlass_layout(torch_layout: ir.Layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821 + """ + Converts an ir.Layout instance into the corresponding cutlass_library.LayoutType enum value + (RowMajor, ColumnMajor, or None if no matching value is found ). + + Args: + torch_layout (ir.Layout): The layout that needs to be looked up. + + Returns: + cutlass_lib.LayoutType: The converted layout corresponding to the `torch_layout` or None if no matching + value is found. 
+ """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if torch_layout.stride[-1] == 1: + return cutlass_lib.LayoutType.RowMajor + elif torch_layout.stride[-2] == 1: + return cutlass_lib.LayoutType.ColumnMajor + else: + return None + + @staticmethod + def flip_cutlass_layout( + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821 + """Helper method: Flips a given cutlass layout (cutlass_lib.LayoutType) from RowMajor + to ColumnMajor or vice versa""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if cutlass_layout == cutlass_lib.LayoutType.RowMajor: + return cutlass_lib.LayoutType.ColumnMajor + else: + return cutlass_lib.LayoutType.RowMajor + + @staticmethod + def layout_match( + torch_layout: ir.Layout, + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """Helper Method: Determines whether a given torch layout matches a given Cutlass layout""" + return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout + + @staticmethod + def set_alignment(torch_layout, op_element) -> bool: + """ + Helper method to update the alignment of a given CUTLASS GEMM op operand's element. + + This method modifies the alignment of the given Cutlass GEMM op operand's element to match the + layout of the corresponding ir.Buffer node. + + Args: + torch_layout: The layout of the corresponding ir.Buffer node. + op_element: The Cutlass GEMM op operand's element whose alignment is to be updated. + + Returns: + bool: True if the alignment was successfully updated, False otherwise. + """ + alignment = cutlass_utils.get_max_alignment(torch_layout) + cuda_arch = cutlass_utils.get_cuda_arch() + if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment: + return False + else: + op_element.alignment = alignment + return True + + @staticmethod + def should_swap_XW( + bias: IRNode, + ) -> bool: + """ + Helper method to determine whether we should do an explicit transpose by switching the order of the + matmul operands. This might be neccessary when we can't otherwise arrive at the right memory + layout for the given Bias operand. + + Note: This method is a workaround for CUDA Errors that seemingly non-deterministically + occurred in practice in some CUTLASS GEMM Kernels with Linear epilogues that have a bias term. + it might make sense to check on newer Cutlass releases whether it makes sense to keep + returning True in certain cases or whether it becomes unneccessary. + """ + # If bias is row major, swap all M and N dimensions + if ( + bias is not None + and len(bias.get_stride()) >= 2 + and bias.get_stride()[-1] in (0, 1) + ): + log.debug("GEMM Layout swapped X and W -> explicit transpose") + return True + return False + + @staticmethod + def swap_XW( + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Swap operands X and W (aka operans A and B) of the GEMM operation. This + requires transposing the operands, which is done by swapping the strides. + Note that we don't change the apparent external layout, just the operand layout. + this is intentional. 
+ """ + new_op = copy.deepcopy(op) + new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout) + new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout) + new_op.A, new_op.B = new_op.B, new_op.A + new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout) + new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout) + return new_op + + def fix_op_layout( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + X: Buffer, + W: Buffer, + Bias: Optional[Buffer], + Y: Union[Buffer, ReinterpretView], + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + # This is a workaround to deal with cases where the input layouts have changed + # between autotuning and rendering. This happens if the inputs layout + # are FlexibleLayout instances. In this case, we need to update the + # op's input layouts. It is a hack, because now the op + # we benchmarked is not the same as the op we render, + # but there is no simple way to fix this in the autotuner, since that would + # potentially disable other optimizations. + a_layout = X.get_layout() + b_layout = W.get_layout() + c_layout = Bias.get_layout() if Bias is not None else None + + d_layout = copy.deepcopy(Y.get_layout()) + match_list = [ + CUTLASSGemmTemplate.layout_match(buf.get_layout(), op_layout) + for buf, op_layout in zip( + (X, W, Bias, Y), + (op.A.layout, op.B.layout, op.C.layout, op.D.layout), + ) + if buf is not None + ] + all_match = all(match_list) + if all_match: + return op + log.warning( + f"Cutlass GEMM Layout change: Input and/or output layouts have changed between autotuning/retuning and call to render on {self}. Applying workaround. This can lead to suboptimal performance. Match List: {match_list}" # noqa: G004, B950 + ) + new_op = copy.deepcopy(op) + + if a_layout is not None: + new_op.A.layout = CUTLASSGemmTemplate.cutlass_layout(a_layout) + if b_layout is not None: + new_op.B.layout = CUTLASSGemmTemplate.cutlass_layout(b_layout) + if c_layout is not None: + new_op.C.layout = CUTLASSGemmTemplate.cutlass_layout(c_layout) + new_op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(c_layout.dtype) + if d_layout is not None: + new_op.D.layout = CUTLASSGemmTemplate.cutlass_layout(d_layout) + return new_op + + def filter_op( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Helper method: + + Determines whether a given Cutlass GEMM op definition is suitable for the current + input / output of the operation that this template is supposed to implement. + + Takes memory layout, dtype and support for EVT operations into account, + and filters potentially problematic ops. + + Returns None if the op is not suitable, otherwise returns the op to be used, which might + have been mutated. + """ + + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + # Skip simt kernels + if ( + op.tile_description.math_instruction.opcode_class + == cutlass_lib.OpcodeClass.Simt + ): + return None + + if op.gemm_kind not in self._get_supported_ops(): + return None + + X = self.input_nodes[0] + W = self.input_nodes[1] + + # Filter ops according to the shape match. + if not self._shape_match(op): + return None + + # Filter ops by dtypes. 
+ accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()], + ) + if not ( + cutlass_utils.dtype_match(X.get_dtype(), op.A.element) + and cutlass_utils.dtype_match(W.get_dtype(), op.B.element) + and cutlass_utils.dtype_match( + self.output_node.get_layout().dtype, op.C.element + ) + and cutlass_utils.dtype_match( + accumulator_torch_dtype, op.accumulator_type() + ) + ): + return None + + # Filter ops by input layouts. + if not ( + self.layout_match(X.get_layout(), op.A.layout) + and self.layout_match(W.get_layout(), op.B.layout) + ): + return None + + # Filter ops by alignment. + if not self._alignment_match(op): + return None + + # Update op. + op = copy.deepcopy(op) + + # Set output layout. + op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) + + # Filter ops by alignments and set alignments. + if not ( + self.set_alignment(X.get_layout(), op.A) + and self.set_alignment(W.get_layout(), op.B) + and self.set_alignment(self.output_node.get_layout(), op.D) + ): + return None + + # Set epilogue. + # TODO: update epilogue functor according to epilogues. + op.element_epilogue = op.accumulator_type() + if inductor_cuda_config.cutlass_op_allowlist_regex is not None: + if not re.search( + inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name() + ): + return None + if inductor_cuda_config.cutlass_op_denylist_regex is not None: + if re.search( + inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name() + ): + return None + + # Set bias layout and alignment. + if not self._set_bias_layout_and_alignment(op): + return None + + return op + + def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + """ + Creates a list of Cutlass GemmOperation instances that match the operation this template is designed to represent. + The matching is carried out with respect to the input and output specifications of the operation. + + No function arguments. + + Returns: + List[cutlass_gemm_op.GemmOperation]: A list of GemmOperation instances that are compatible with the + operation requirements of this template. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm] + res: Dict[str, cutlass_gemm_op.GemmOperation] = {} + for op_dict in ops.values(): + for op_list in op_dict.values(): + for op in op_list: + assert isinstance(op, cutlass_gemm_op.GemmOperation) + filter_res = self.filter_op(op) + if ( + filter_res is not None + and res.get(filter_res.configuration_name(), None) is None + ): + res[filter_res.configuration_name()] = filter_res + log.debug("Got cutlass configs: total number of ops: %d, ", len(res)) + return list(res.values())[: inductor_cuda_config.cutlass_max_profiling_configs] + + def gemm_mode(self) -> str: + """ + Returns a Cutlass GEMM mode string for the current operation, dependent on whether this op implements + a batched GEMM or a simple GEMM without batch dimension. + + Returns: + str: A string indicating the Cutlass GEMM mode. If the output node has more than two dimensions, + "cutlass::gemm::GemmUniversalMode::kBatched" is returned, otherwise + "cutlass::gemm::GemmUniversalMode::kGemm" is returned. 
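+
+        Example: an output of size (B, M, N) yields
+        "cutlass::gemm::GemmUniversalMode::kBatched", while an output of size
+        (M, N) yields "cutlass::gemm::GemmUniversalMode::kGemm".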
+ """ + sizes = self.output_node.get_size() + if len(sizes) > 2: + return "cutlass::gemm::GemmUniversalMode::kBatched" + else: + return "cutlass::gemm::GemmUniversalMode::kGemm" + + def render( # type: ignore[override] + self, + kernel: CUDATemplateKernel, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + template_buffer_node: Optional[CUDATemplateBuffer] = None, + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + Renders the Cutlass based CUDA C++ code for the GEMM Kernel that this template is designed to implement, + including potentially fused epilogues. + + Args: + kernel (CUDATemplateKernel): The kernel to be rendered. + op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the + input and output definitions as well as a possible epilogue. Defaults to None. + **kwargs: Additional keyword arguments. Currently unused. + + Returns: + str: Cutlass based CUDA C++ code fragment as a string, to be used by the current + CUDATemplateKernel or autotuning code. + + Note: + All inputs and their corresponding buffer addresses and names take precedence over previously + passed inputs to the template at construction time. However, they should be layout compatible. + """ + + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + assert isinstance( + op, cutlass_gemm_op.GemmOperation + ), "op argument is required and has to be an instance of GemmOperation" + + assert len(self.input_nodes) >= 2 and self.output_node is not None + X, W = self.input_nodes[0], self.input_nodes[1] + assert isinstance(X.layout, FixedLayout), "X.layout is not fixed" + assert isinstance(W.layout, FixedLayout), "W.layout is not fixed" + Y = self.output_node + if template_buffer_node is not None: + Y = template_buffer_node + + Bias, extra_inputs, extra_names = self._get_extra_inputs_and_names(op) + + # Define Kernel call signature + # Important: This step also populates Kernel name to node mapping data structures, + # which are required further below ( for example by CutlassEVTEpilogueArgumentFormatter and + # the template renderer ) + inputs = [X, W, Bias, *extra_inputs] + names = ["X", "W", "Bias", *extra_names] + ["Y"] + names_str = ",".join(names) + if self.input_reorder is not None: + input_reorder = self.input_reorder + else: + input_reorder = None + kernel_call_signature = kernel.def_kernel( + inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder # type: ignore[arg-type] + ) + test_call_statement = self.test_call_statement(kernel, inputs, names_str) + # The layouts might have changed between autotuning and this call if they were FlexibleLayout + # we need to adapt, which might lead to suboptimal performance. 
+ op = self.fix_op_layout(op, X, W, Bias, Y) + + # to make op mutable without affecting others + op = copy.deepcopy(op) + if Bias is not None: + assert Bias.get_layout().dtype == X.get_layout().dtype + # This might have been set to void during filtering, when the assumption was still that there's no C + # operand + op.C.element = op.A.element + + argument_template, epilogue_template = self._get_template_args(op) + should_swap_xw: bool = False + epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}" + if Bias is not None and self._has_tma_epilogue(op): + if ( + op.epilogue_schedule + != cutlass_lib.EpilogueScheduleType.EpilogueTransposed + and self.should_swap_XW(Bias) + ): + # TMA epilogue requires bias vector in column major to get best perf. + op = self.swap_XW(op) + should_swap_xw = True + + instance_definition, instance_type = self._define_gemm_instance(op) + + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + kernel_call_signature=kernel_call_signature, + Bias=Bias, + epilogue_template=epilogue_template, + argument_template=argument_template, + should_swap_xw=should_swap_xw, + template=self, + kernel=kernel, + instance_definition=instance_definition, + instance_type=instance_type, + input_reorder=self.input_reorder, + epilogue_args=epilogue_args, + test_call_statement=test_call_statement, + ) + options.update(dict(zip(extra_names, extra_inputs))) + res = self._template_from_string(self._get_template()).render(**options) + if inductor_cuda_config.generate_test_runner: + test_runner_code = self._template_from_string( + GEMM_STANDALONE_RUNNER_TEMPLATE + ).render(**options) + res += "\n\n" + test_runner_code + return res + + def test_call_statement( + self, + kernel, + input_nodes, + names_str: str = "", + ) -> str: + """ + Helper method to render the Cutlass CUDA C++ code required for calling the GEMM operation in the standalone + test runner that might also be generated along with the rest of the code, if the corresponding config is + enabled. + + Returns a C++ statement that calls the GEMM operation with the correct arguments. 
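+
+        Illustrative example of a generated statement (kernel and argument
+        names are hypothetical)::
+
+            test_kernel(((half*)X_data.get()), ((half*)W_data.get()),
+                ((half*)Y_data.get()), workspace_size_ptr,
+                (uint8_t*)workspace_data.get(), 0);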
+ """ + _, __, arg_types = kernel.args.cpp_argdefs() + arg_names = [name.strip() for name in names_str.strip().split(",")] + if input_nodes[2] is None: + del arg_names[2] + arguments = [ + f"(({arg_type}){arg_name}_data.get())" + for arg_type, arg_name in zip(arg_types, arg_names) + ] + return f"{kernel.kernel_name}({', '.join(arguments)}, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" + + +class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): + def __init__( + self, + input_nodes: List[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[List[int]] = None, + ): + super().__init__(input_nodes, layout, alpha, beta, input_reorder) + + @staticmethod + def add_cutlass_gemm_choices( + choices: List[ChoiceCaller], + layout: ir.Layout, + input_nodes: List[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[List[int]] = None, + **extra_kwargs, + ) -> None: + template = CUTLASS3xGemmTemplate( + input_nodes, layout, alpha, beta, input_reorder + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + def _get_supported_ops() -> "List[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal3x] + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_3X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[str, Optional[str]]: + return (GEMM_ARGS_CUTLASS_3X, GEMM_ARGS_CUTLASS_3X_EPILOGUE) + + @staticmethod + def _has_tma_epilogue( # noqa: F821 # type: ignore[arg-type,name-defined] + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined,arg-type] # noqa: F821 + ) -> bool: # type: ignore[name-defined] + """Helper method: Determine whether a given Cutlass GEMM op has a TMA Epilogue""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + result = False + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1] + result = epilogue_schedule_str.lower().startswith("tma") + return result + + def _are_inputs_layout_compatible(self, layouts: List[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for General Matrix Multiply (GEMM). + + This function checks compatibility of A, B, and possibly C operand layouts for + a General Matrix Multiply (GEMM) operation, expressed as 'alpha * matmul(A, B) + beta * C'. + It verifies requirements such as matching data types, minimum rank, and suitability + for broadcasting, as defined by PyTorch operations like `torch.matmul`, `torch.aten.mm`, + `addmm`, `bmm`, `baddbmm`, etc. + + Args: + layouts (List[Layout]): List containing 2 or 3 Layout objects representing + the input matrices A, B, and possibly C. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert len(layouts) == 2 or len(layouts) == 3 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) < 1: + return False + if len(B_layout.size) < 1: + return False + A_size = [int(i) for i in A_layout.size] + B_size = [int(i) for i in B_layout.size] + if len(A_size) < 2: + A_size.insert(0, 1) + if len(B_size) < 2: + A_size.insert(1, 1) + # Are batch dims broadcastable? 
+        while len(A_size) < len(B_size):
+            A_size.insert(0, 1)
+        while len(B_size) < len(A_size):
+            B_size.insert(0, 1)
+        K = max(A_size[-1], B_size[-2])
+        M = A_size[-2]
+        N = B_size[-1]
+        if K != A_size[-1] and A_size[-1] != 1:
+            return False
+        if K != B_size[-2] and B_size[-2] != 1:
+            return False
+        # check batch dim broadcastable
+        for i in range(len(A_size) - 2):
+            if A_size[i] != B_size[i] and A_size[i] != 1 and B_size[i] != 1:
+                return False
+        if len(layouts) == 3:
+            C_layout = layouts[2]
+            C_size = [int(i) for i in C_layout.size]
+            while len(C_size) < len(A_size):
+                C_size.insert(0, 1)
+            # check batch dims
+            for i in range(len(A_size) - 2):
+                bd = max(A_size[i], B_size[i])
+                if bd != C_size[i] and C_size[i] != 1:
+                    return False
+            if len(C_size) > len(A_size):
+                # This may happen if the last elements of C are contiguous and
+                # their multiplied size equals the last dim size of B
+                if M != C_size[len(A_size) - 2] and C_size[len(A_size) - 2] != 1:
+                    return False
+                remaining_size = 1
+                for i in range(len(A_size) - 1, len(C_size)):
+                    remaining_size *= C_size[i]
+                if N != remaining_size and remaining_size != 1:
+                    return False
+                return True
+            assert len(C_size) == len(A_size)
+            if M != C_size[-2] and C_size[-2] != 1:
+                return False
+            if N != C_size[-1] and C_size[-1] != 1:
+                return False
+        return True
+
+    def _shape_match(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> bool:
+        X, W = self.input_nodes[0], self.input_nodes[1]
+        return X.layout.size[1] == W.layout.size[0]
+
+    def _alignment_match(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> bool:
+        return True
+
+    def _set_bias_layout_and_alignment(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> bool:
+        import cutlass_library.library as cutlass_lib
+
+        if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None:
+            Bias = self.input_nodes[2]
+            bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout())
+            if op.gemm_kind != cutlass_lib.GemmKind.Universal3x:
+                if bias_layout != op.D.layout:
+                    # For cutlass2, bias and output layout must match
+                    return False
+            else:
+                op.C.layout = bias_layout
+            if not self.set_alignment(Bias.get_layout(), op.C):
+                return False
+        else:
+            if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
+                op.C.element = cutlass_lib.DataType.void
+            else:
+                op.C.layout = op.D.layout
+        return True
+
+    def _define_gemm_instance(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> Tuple[str, str]:
+        """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance.
+
+        This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply
+        forms a core part of a number of scientific applications, so this efficient and adaptable implementation is
+        crucial.
+
+        Args:
+            op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering.
+
+        Returns:
+            Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++
+            code (render) and the second part is the string that specifies the operation type.
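+
+        Illustrative example (hypothetical kernel name): if the emitted
+        definition contains "struct cutlass3x_sm90_tensorop_gemm_example : public ...",
+        the returned operation type is "cutlass3x_sm90_tensorop_gemm_example_device_type",
+        an alias for cutlass::gemm::device::GemmUniversalAdapter over that struct.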
+ """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance() + if not hasattr(op, "epilogue_functor") or not isinstance( + op.epilogue_functor, enum.Enum + ): + op = copy.deepcopy(op) + op.epilogue_functor = cutlass_lib.EpilogueFunctor.LinearCombination + op_def = emitter.emit(op) + pattern = re.compile(r"\s*struct\s(.*?)\s:") + decl = [line for line in op_def.split("\n") if "struct " in line][-1] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n" + op_type = f"{op_type}_device_type" + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[Optional[Buffer], List[Optional[Buffer]], List[str]]: + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + inputs: List[Optional[Buffer]] = [] + names: List[str] = [] + return (Bias, inputs, names) + + def render_gemm_arguments( + self, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. 
+ """ + options = dict( + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) + assert epilogue_template is not None + + if should_swap_xw: + # Swap + def clone_with_transposed_stride(node: IRNode) -> IRNode: + old_layout = node.get_layout() + new_stride = list(old_layout.stride) + new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2] + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(old_layout.size), + new_stride, + old_layout.offset, + ) + return Buffer(node.get_name(), new_layout) + + new_X = clone_with_transposed_stride(X) + new_W = clone_with_transposed_stride(W) + new_Bias = clone_with_transposed_stride(Bias) + new_Y = clone_with_transposed_stride(Y) + options["X"], options["W"], options["Bias"], options["Y"] = ( + new_W, + new_X, + new_Bias, + new_Y, + ) + options["M"], options["N"] = "N", "M" + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments + + +class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate): + def __init__( + self, + input_nodes: List[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[List[int]] = None, + ): + super().__init__(input_nodes, layout, alpha, beta, input_reorder) + + @staticmethod + def add_cutlass_gemm_choices( + choices: List[ChoiceCaller], + layout: ir.Layout, + input_nodes: List[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[List[int]] = None, + **extra_kwargs, + ) -> None: + template = CUTLASS2xGemmTemplate( + input_nodes, layout, alpha, beta, input_reorder + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + def _get_supported_ops() -> "List[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal, cutlass_lib.GemmKind.Sparse] + + @staticmethod + def _has_tma_epilogue(self) -> bool: + return False + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_2X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[str, Optional[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return (GEMM_ARGS_SPARSE_CUTLASS_2X, None) + + return (GEMM_ARGS_CUTLASS_2X, None) + + def _are_inputs_layout_compatible(self, layouts: List[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for set of operations supported by this class. + + Args: + layouts (List[Layout]): List containing Layout objects representing + the input matrices. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. 
+ """ + assert len(layouts) == 2 or len(layouts) == 3 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) != 2: + return False + if len(A_layout.size) != 2: + return False + A_size = [int(i) for i in A_layout.size] + B_size = [int(i) for i in B_layout.size] + K = max(A_size[-1], B_size[-2]) + M = A_size[-2] + N = B_size[-1] + if K != A_size[-1] and K != 2 * A_size[-2]: + return False + if K != B_size[-2]: + return False + return True + + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + X, W = self.input_nodes[0], self.input_nodes[1] + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return X.layout.size[1] * 2 == W.layout.size[0] + + return X.layout.size[1] == W.layout.size[0] + + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + return True + + # SparseGemm in CUTLASS has specific alignment check that for + # small k could make some of the choices throw kMisalignedOperand + # CUTLASS error when run, see: + # https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/include/cutlass/gemm/kernel/sparse_gemm.h#L198-L200 # noqa: B950 + # So, let's skip these choices if that would be the case. + X = self.input_nodes[0] + return (X.layout.size[1] * 2) % op.tile_description.tile_shape[2] == 0 + + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + op.C.layout = op.D.layout + return True + + if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None: + Bias = self.input_nodes[2] + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + if bias_layout != op.D.layout: + # For cutlass2, bias and output layout must match + return False + if not self.set_alignment(Bias.get_layout(), op.C): + return False + else: + op.C.layout = op.D.layout + return True + + def _define_gemm_instance( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[str, str]: + """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + + This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply + forms a core part of a number of scientific applications, so this efficient and adaptable implementation is + crucial. + + Args: + op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. + + Returns: + Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. 
+ """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + emitter = cutlass_gemm_op.EmitSparseGemmInstance() + else: + emitter = cutlass_gemm_op.EmitGemmInstance() + op_def = emitter.emit(op) + op_def = op_def.replace( + "cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal" + ) + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + op_def = op_def.replace("false,", "") + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = op_def.split("\n")[2] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[Optional[Buffer], List[Optional[Buffer]], List[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + Bias = None + Meta = self.input_nodes[2] + else: + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + Meta = None + inputs = [Meta] + names = ["Meta"] + return (Bias, inputs, names) + + def render_gemm_arguments( + self, + instance_type: str, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Meta: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + instance_type (str): GEMM instance type. + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Meta (IRNode): The meta tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. 
+ """ + options = dict( + instance_type=instance_type, + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + Meta=Meta, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) + + if epilogue_template is None: + arguments = self._template_from_string(argument_template).render( + split_k=1, **options + ) + return arguments + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..909710453796563636d55194024a69001a7038bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -0,0 +1,100 @@ +# mypy: allow-untyped-defs +from typing import Sequence, Union + +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from .cuda.cuda_cpp_scheduling import CUDACPPScheduling +from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling +from .triton import TritonScheduling + + +class CUDACombinedScheduling(BaseScheduling): + """ + Scheduler for CUDA Kernels, which delegates calls as appropriate + to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices + and use a unified-wrapper for codegen. + + If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code, + this would also be the place to do it. + """ + + def __init__(self, scheduler: Scheduler) -> None: + super().__init__() + self._scheduler = scheduler + self._triton_scheduling = TritonScheduling(scheduler) + self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) + self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) + + def get_backend_features(self, device): + return self._triton_scheduling.get_backend_features(device) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._cuda_cpp_scheduling.is_cuda_cpp_template(node): + return self._cuda_cpp_scheduling + if self._rocm_cpp_scheduling.is_rocm_cpp_template(node): + return self._rocm_cpp_scheduling + return self._triton_scheduling + + def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2): + return True + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + for node in (node1, node2): + if self._cuda_cpp_scheduling.is_cuda_cpp_template(node): + return self._cuda_cpp_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn(self, sizes): + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ): + if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node): + assert epilogue_nodes is None or len(epilogue_nodes) == 0 + return self._cuda_cpp_scheduling.codegen_template( + template_node, epilogue_nodes + ) + elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node): + assert epilogue_nodes is None or len(epilogue_nodes) == 0 
+ return self._rocm_cpp_scheduling.codegen_template( + template_node, epilogue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes + ) + + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]): + return self._triton_scheduling.codegen_node(node) + + def codegen_sync(self): + return self._triton_scheduling.codegen_sync() + + def flush(self): + return self._triton_scheduling.flush() + + def codegen_combo_kernel(self, *args, **kwargs): + return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs) + + def benchmark_fused_nodes(self, nodes): + return self._triton_scheduling.benchmark_fused_nodes(nodes) + + def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): + return self._triton_scheduling.generate_kernel_code_from_nodes( + nodes, benchmark_kernel + ) + + def benchmark_combo_kernel(self, node_list): + return self._triton_scheduling.benchmark_combo_kernel(node_list) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/debug_utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3dcd8eea2cf5ab5c7f688891297a05c812de99 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/debug_utils.py @@ -0,0 +1,175 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +from enum import Enum +from typing import List, Optional + +from torch import dtype as torch_dtype + +from .. import config +from ..virtualized import V +from .multi_kernel import MultiKernel + + +log = logging.getLogger(__name__) + + +# AOTI debug printing related configs +class IntermediateValueDebuggingLevel(Enum): + # OFF: No intermediate tensor value debug info will be printed or saved. + OFF = "0" + # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed. + SAVE_ONLY = "1" + # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed. 
+ PRINT_ONLY = "2" + + +class DebugPrinterManager: + def __init__( + self, + debug_printer_level, + args_to_print_or_save: Optional[List[str]] = None, + kernel_name: str = "", + kernel=None, + arg_signatures: Optional[List[type]] = None, + ): + self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level) + if args_to_print_or_save is None: + args_to_print_or_save = [] + self.args_to_print_or_save = args_to_print_or_save + self.kernel_name = kernel_name + self.arg_signatures: Optional[List[type]] = None + self.kernel = kernel + self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names() + + def __enter__(self): + self._perform_debug_print_or_save_helper( + self.args_to_print_or_save, + self.kernel_name, + before_launch=True, + arg_signatures=self.arg_signatures, + ) + + def __exit__(self, args_to_print_or_save, kernel_name, arg_signatures): + self._perform_debug_print_or_save_helper( + args_to_print_or_save, + kernel_name, + before_launch=False, + arg_signatures=arg_signatures, + ) + + def _perform_debug_print_or_save_helper( + self, + args_to_print_or_save, + kernel_name, + before_launch, + arg_signatures: Optional[List[type]] = None, + ): + if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF: + return + if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY: + # by default save all the tensor values before launch + self.codegen_intermediate_tensor_value_save( + self.args_to_print_or_save, + self.kernel_name, + before_launch, + arg_signatures=self.arg_signatures, + ) + if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: + # by default print all the tensor values before launch + self.codegen_intermediate_tensor_value_print( + self.args_to_print_or_save, + self.kernel_name, + before_launch, + arg_signatures=self.arg_signatures, + ) + + @functools.lru_cache # noqa: B019 + def _get_debug_filtered_kernel_names(self) -> List[str]: + if config.aot_inductor.filtered_kernel_names is None: + return [] + return [ + x.strip() + for x in config.aot_inductor.filtered_kernel_names.lower().split(",") + ] + + def set_printer_args( + self, + args_to_print_or_save: List[str], + kernel_name: str, + arg_signatures: Optional[List[type]], + kernel, + ): + # Note: MultiKernel debug printing is not supported for now + if isinstance(kernel, MultiKernel): + log.info( + "MultiKernel type is not supported in AOTI debug printer tool yet." + ) + self.debug_printer_level = IntermediateValueDebuggingLevel.OFF + self.args_to_print_or_save = args_to_print_or_save + self.kernel_name = kernel_name + self.arg_signatures = arg_signatures + self.kernel = kernel + + def codegen_intermediate_tensor_value_save( + self, + args_to_save, + kernel_name, + before_launch=True, + arg_signatures: Optional[List[type]] = None, + ) -> None: + for i, arg in enumerate(args_to_save): + if arg_signatures is not None and not isinstance( + arg_signatures[i], torch_dtype + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + continue + launch_prefix = "before_launch" if before_launch else "after_launch" + if V.graph.cpp_wrapper: + if config.abi_compatible: + V.graph.wrapper_code.writeline( + f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");' + ) + else: + # TODO: add non-abi compatible mode debug printing info + pass + else: + # currently, not cpp wrapper codegen mode not supported. 
+ pass + + def codegen_intermediate_tensor_value_print( + self, + args_to_print, + kernel_name, + before_launch=True, + arg_signatures: Optional[List[type]] = None, + ) -> None: + for i, arg in enumerate(args_to_print): + if arg_signatures is not None and not isinstance( + arg_signatures[i], torch_dtype + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + continue + if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: + # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY, + # check if filtered kernel name list is provided + if ( + len(self.filtered_kernel_names_to_print) > 0 + and kernel_name not in self.filtered_kernel_names_to_print + ): + continue + + launch_prefix = "before_launch" if before_launch else "after_launch" + if V.graph.cpp_wrapper: + if config.abi_compatible: + V.graph.wrapper_code.writeline( + f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' + ) + else: + # TODO: add non-abi compatible mode debug printing info + pass + else: + line = f"print('{launch_prefix} - {kernel_name} - {arg}', {arg})" + V.graph.wrapper_code.writeline(line) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/halide.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/halide.py new file mode 100644 index 0000000000000000000000000000000000000000..20968a57a4444f48548d48f27684484150086b58 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/halide.py @@ -0,0 +1,1699 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import functools +import itertools +import logging +import re +from collections import defaultdict +from math import inf +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy + +import torch +import torch._logging + +from ..._prims_common import is_integer_dtype +from ...utils._sympy.functions import FloorDiv, ModularIndexing +from ...utils._sympy.symbol import symbol_is_type, SymT +from ...utils._sympy.value_ranges import ValueRanges +from .. 
import config, ir +from ..codecache import HalideCodeCache +from ..ir import get_reduction_combine_fn +from ..metrics import is_metric_table_enabled, log_kernel_metadata +from ..ops_handler import AddParenHandler, MockHandler +from ..runtime.hints import HalideInputSpec, HalideMeta, ReductionHint +from ..utils import ( + get_bounds_index_expr, + get_kernel_metadata, + parallel_num_threads, + sympy_index_symbol, + sympy_subs, +) +from ..virtualized import _ops as ops, OpsHandler, V +from .common import ( + BackendFeature, + CSEVariable, + DeferredLine, + IndentedBuffer, + OpOverrides, + PythonPrinter, + SizeArg, + TensorArg, +) +from .cpp import DTYPE_TO_CPP +from .cpp_utils import cexpr +from .simd import constant_repr, SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from torch.utils._ordered_set import OrderedSet + + from ..ops_handler import ReductionType, StoreMode + +log = logging.getLogger(__name__) + + +def halide_constant(val): + if isinstance(val, int) and not (-2147483648 <= val <= 2147483647): + info = torch.iinfo(torch.int64) + if val == info.min: + return "hl.Int(64).min()" + if val == info.max: + return "hl.Int(64).max()" + return f"hl.i64({val!r})" + if isinstance(val, float): + return f"hl.f64({constant_repr(val)})" + return repr(val) + + +class Unsupported(RuntimeError): + def __init__(self, thing) -> None: + super().__init__(f"halide backend does not support: {thing}") + + +class HalidePrinter(PythonPrinter): + @staticmethod + def cast_index(expr): + return f"hl.cast({V.kernel.index_dtype}, {expr})" + + @staticmethod + def cast_float(expr): + return f"hl.cast(hl.Float(32), {expr})" + + def _print_Float(self, expr): + return f"hl.f32({expr})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"hl.f32({self._print(expr.args[0])})" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.floor({self._print(expr.args[0])})") + + def _print_Trunc(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.trunc({self._print(expr.args[0])})") + + _print_TruncToInt = _print_Trunc + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.ceil({self._print(expr.args[0])})") + + def _helper_sqrt(self, expr): + return f"hl.sqrt({self.cast_float(self._print(expr))})" + + def _print_Where(self, expr): + c = self.doprint(expr.args[0]) + p = self.doprint(expr.args[1]) + q = self.doprint(expr.args[2]) + return f"hl.select({c}, {p}, {q})" + + def _print_Min(self, expr): + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Min(*expr.args[:mid])) + b = self._print(sympy.Min(*expr.args[mid:])) + return f"hl.min({a}, {b})" + + def _print_Max(self, expr): + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Max(*expr.args[:mid])) + b = self._print(sympy.Max(*expr.args[mid:])) + + return f"hl.max({a}, {b})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.abs({self._print(expr.args[0])})") + + def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"hl.cos(({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"hl.cosh(({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"hl.acos(({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert 
len(expr.args) == 1 + return f"hl.sin(({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"hl.sinh(({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"hl.asin(({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"hl.tan(({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"hl.tanh(({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"hl.atan(({self._print(expr.args[0])})" + + def _print_FloorDiv(self, expr): + if expr.is_integer: + return super()._print_FloorDiv(expr) + + x, div = expr.args + x = self.cast_float(self.paren(self.doprint(x))) + div = self.cast_float(self.paren(self.doprint(div))) + return self.cast_index(f"hl.floor({x} / {div})") + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.round({self._print(expr.args[0])})") + + _print_RoundToInt = _print_Round + + def _print_IntTrueDiv(self, expr): + a, b = expr.args + # force a cast to float + return f"({a}) / ({b}+hl.f32(0))" + + def _print_RoundDecimal(self, expr): + val, n = expr.args + val = self._print(val) + n = int(n) + return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))" + + +texpr = HalidePrinter().doprint +pexpr = PythonPrinter().doprint + + +_halide_type = { + torch.bool: "hl.Bool()", + torch.bfloat16: "hl.BFloat(16)", + torch.float16: "hl.Float(16)", + torch.float32: "hl.Float(32)", + torch.float64: "hl.Float(64)", + torch.int8: "hl.Int(8)", + torch.int16: "hl.Int(16)", + torch.int32: "hl.Int(32)", + torch.int64: "hl.Int(64)", + torch.uint8: "hl.UInt(8)", + torch.uint16: "hl.UInt(16)", + torch.uint32: "hl.UInt(32)", + torch.uint64: "hl.UInt(64)", +} + + +def halide_type(dtype): + return _halide_type[dtype] + + +def halide_acc_type(dtype): + if is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64: + dtype = torch.int32 + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + return halide_type(dtype) + + +class HalideOverrides(OpOverrides): + @staticmethod + def to_dtype( + x, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ): + if dtype == torch.bool: + return f"({x} != 0)" + return f"hl.cast({halide_type(dtype)}, {x})" + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + if src_dtype in (torch.float16, torch.bfloat16): + x = f"hl.cast({halide_type(src_dtype)}, {x})" # body compute is upcast to fp32 + line = f"hl.reinterpret({halide_type(dtype)}, {x})" + if dtype in (torch.float16, torch.bfloat16): + line = f"hl.cast(hl.Float(32), {line})" + return line + + @classmethod + def constant(cls, value, dtype): + return cls.to_dtype(halide_constant(value), dtype) + + @staticmethod + def abs(x): + return f"hl.abs({x})" + + @staticmethod + def exp(x): + if not hasattr(x, "name"): + return f"hl.exp({x})" + return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})" + + @staticmethod + def libdevice_exp(x): + return f"hl.exp({x})" # higher precision that ops.exp + + @staticmethod + def sqrt(x): + return f"hl.sqrt({x})" + + @staticmethod + def minimum(a, b): + # return f"hl.min({a}, {b})" <== handles nan wrong + if not hasattr(a, "name"): + return f"hl.min({a}, {b})" + b = f"hl.cast({a.name}.type(), {b})" + 
return f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})" + + @staticmethod + def maximum(a, b): + # return f"hl.max({a}, {b})" <== handles nan wrong + if not hasattr(a, "name"): + return f"hl.max({a}, {b})" + b = f"hl.cast({a.name}.type(), {b})" + return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})" + + @staticmethod + def where(a, b, c): + if hasattr(b, "name"): + c = f"hl.cast({b.name}.type(), {c})" + return f"hl.select({a}, {b}, {c})" + + @staticmethod + def cos(x): + return f"hl.cos({x})" + + @staticmethod + def sin(x): + return f"hl.sin({x})" + + @staticmethod + def lgamma(x): + raise Unsupported("lgamma") + + @staticmethod + def erf(x): + return f"hl.erf({x})" + + @staticmethod + def cosh(x): + return f"hl.cosh({x})" + + @staticmethod + def sinh(x): + return f"hl.sinh({x})" + + @staticmethod + def acos(x): + return f"hl.acos({x})" + + @staticmethod + def acosh(x): + return f"hl.acosh({x})" + + @staticmethod + def asin(x): + return f"hl.asin({x})" + + @staticmethod + def asinh(x): + return f"hl.asinh({x})" + + @staticmethod + def atan2(x, y): + return f"hl.atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"hl.atan({x})" + + @staticmethod + def atanh(x): + return f"hl.atanh({x})" + + @staticmethod + def copysign(x, y): + raise Unsupported("copysign") + + @staticmethod + def erfinv(x): + raise Unsupported("erfinv") + + @staticmethod + def hypot(x, y): + return f"hl.hypot({x}, {y})" + + @staticmethod + def nextafter(x, y): + raise Unsupported("nextafter") + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"({a} ^ {b})" + + @staticmethod + def bitwise_and(a, b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def rand(seed, offset): + return f"halide_helpers.rand({seed}, {offset})" + + @staticmethod + def randn(seed, offset): + return f"halide_helpers.randn({seed}, {offset})" + + @staticmethod + def randint64(seed, offset, low, high): + return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def load_seed(name, offset): + return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}" + + @staticmethod + def rsqrt(x): + # return f"hl.fast_inverse_sqrt({x})" <== accuracy issues + return f"1./hl.sqrt({x})" + + @staticmethod + def tan(x): + return f"hl.tan({x})" + + @staticmethod + def tanh(x): + return f"hl.tanh({x})" + + @staticmethod + def signbit(x): + return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0" + + @staticmethod + def fmod(a, b): + # TODO(jansel): find a better way to do this, builtin % has wrong sign + return f"{a} - hl.trunc({a}/{b})*{b}" + + @staticmethod + def pow(a, b): + return f"hl.pow({a}, {b})" # hl.fast_pow fails accuracy + + @staticmethod + def log(x): + return f"hl.log({x})" # hl.fast_log fails accuracy + + @staticmethod + def isinf(x): + # workaround https://github.com/halide/Halide/issues/8309 + return f"hl.is_inf(hl.cast(hl.Float(32), {x}))" + + @staticmethod 
+ def isnan(x): + # workaround https://github.com/halide/Halide/issues/8309 + return f"hl.is_nan(hl.cast(hl.Float(32), {x}))" + + @staticmethod + def round(x): + return f"hl.round({x})" + + @staticmethod + def floor(x): + return f"hl.floor({x})" + + @staticmethod + def int_truediv(a, b): + return f"({a}) / ({b} + hl.f32(0))" + + @staticmethod + def floordiv(a, b): + # TODO(jansel): find a better ways to do this, the select-based trick from triton.py didn't work + return ( + f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" + ) + + @classmethod + def sign(cls, x): + left = ops.to_dtype(ops.lt("0", x), torch.int8) + right = ops.to_dtype(ops.lt(x, "0"), torch.int8) + sub = ops.sub(left, right) + return f"hl.cast({x.name}.type(), {sub})" + + @staticmethod + def trunc(x): + return f"hl.trunc({x})" + + @staticmethod + def truncdiv(a, b): + # this causes crashes with floating point exception, see test_div_zero_dim_cpu + # return f"hl.div_round_to_zero({a}, {b})" + return ( + f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" + ) + + @staticmethod + def ceil(x): + return f"hl.ceil({x})" + + @staticmethod + def relu(x): + return f"hl.max({x}, 0)" + + @classmethod + def index_expr(cls, expr, dtype): + index = V.kernel.prepare_indexing(expr) + var = V.kernel.genfunc( + V.kernel.index_to_str(index), + V.kernel.used_dims_from_index(index), + bounds=get_bounds_index_expr(expr), + ) + if dtype not in {torch.int32, torch.int64}: + return ops.to_dtype(var, dtype) + return var + + @classmethod + def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True): + # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow + index_var = ops.to_dtype(index_var, torch.int32) + index_var = ops.halide_clamp(index_var, size, check) + index_var.indirect_indexing_size = size + return sympy_index_symbol(str(index_var)) + + @classmethod + def halide_clamp(cls, value, size, check): + end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1) + if not isinstance(size, (int, sympy.Integer)): + end = f"hl.cast({value.name}.type(), {end})" + # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692 + # return f"hl.unsafe_promise_clamped({value}, 0, {end})" + return f"hl.clamp({value}, 0, {end})" + + @staticmethod + def masked(mask, body, other): + with V.kernel.mask_loads(mask, other) as new_mask: + result = body() + + if result.bounds.is_bool: + other = bool(other) + + # Take dtype from result to prevent accidental promotion + other = V.kernel.genfunc( + f"hl.cast({result.name}.type(), {halide_constant(other)})", + [], + bounds=ValueRanges.wrap(other), + ) + # TODO(jansel): look into removing the where in the same places triton does + return ops.where(new_mask, result, other) + + +# Use mypy to check protocol implemented correctly +def _typecheck_HalideOverrides(h: HalideOverrides) -> OpsHandler[str]: + return h + + +class HalideCSEVariable(CSEVariable): + undefined_re = re.compile(r"\b(tmp\d+)\[\?\]") + + def __init__(self, name, bounds: ValueRanges[Any]) -> None: + super().__init__(name, bounds) + self.used_dims: Optional[List[sympy.Symbol]] = None + + def update_on_args(self, name, args, kwargs): + used = set(self.used_dims or ()) + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, HalideCSEVariable): + assert arg.used_dims is not None, (name, arg, args) + used.update(arg.used_dims) + self.used_dims = V.kernel.sort_used_dims(used) + + def index_str(self, dims): + if 
len(dims) == 0: + return f"{self.name}[()]" + # Reversed since Halide is column major + return f"{self.name}[{', '.join(map(str, dims))}]" + + def __str__(self) -> str: + if self.used_dims is None: + # This will get recomputed and replaced in codegen_kernel() + return f"{self.name}[?]" + return self.index_str(self.used_dims) + + def subs_str(self, replacements): + assert self.used_dims is not None and all( + isinstance(x, sympy.Expr) for x in self.used_dims + ) + return self.index_str([replacements.get(n, n) for n in self.used_dims]) + + +@dataclasses.dataclass +class DimensionInfo: + expr: Optional[sympy.Expr] + size: sympy.Expr + stride: sympy.Expr + + def __init__(self, expr, size, stride) -> None: + super().__init__() + if V.graph.sizevars.statically_known_lt(stride, 0): + stride = -stride + expr = -expr + self.expr = expr + self.size = size + self.stride = stride + + def index_str(self, replacements=None, zero_vars=False): + assert self.expr is not None + expr = self.expr + if zero_vars and expr == 0: + return "hl.Var()" + if replacements: + replacements = {**replacements} + for sym in expr.free_symbols: + if symbol_is_type(sym, SymT.TMP): + assert isinstance(sym, sympy.Symbol) + var = V.kernel.lookup_cse_var(sym.name) + assert isinstance(var, HalideCSEVariable) + replacements[sym] = sympy_index_symbol(var.subs_str(replacements)) + expr = sympy_subs(expr, replacements) + return V.kernel.index_to_str(expr) + + +def eq(left, right): + if V.graph.sizevars.statically_known_equals(left, right): + return True + try: + a = V.graph.sizevars.size_hint(left) + b = V.graph.sizevars.size_hint(right) + except TypeError: # unbacked symints + return False + if a == b: + V.graph.sizevars.guard_equals(left, right) + return a == b + + +def lt(left, right): + if V.graph.sizevars.statically_known_lt(left, right): + return True + try: + a = V.graph.sizevars.size_hint(left) + b = V.graph.sizevars.size_hint(right) + except TypeError: # unbacked symints + gcd = sympy.gcd(left, right) + if gcd == left: + return left != right + return False + if a < b: + V.graph.sizevars.guard_lt(left, right) + return a < b + + +class HalideKernel(SIMDKernel): + overrides = HalideOverrides # type: ignore[assignment] + kexpr: Callable[[sympy.Expr], str] = texpr + + def __init__( + self, + *groups, + index_dtype: str, + mutations: Optional[OrderedSet[str]] = None, + pid_cache=None, + reduction_hint=ReductionHint.DEFAULT, + override_persistent_reduction=None, + ) -> None: + super().__init__( + *groups, + index_dtype=index_dtype, + mutations=mutations, + reduction_hint=reduction_hint, + pid_cache=pid_cache, + override_persistent_reduction=override_persistent_reduction, + ) + # For halide, we just write directly to the body + self.compute = self.body + self.loads = self.body + self.stores = self.body + self.indexing_code_dom = IndentedBuffer() + self.needs_dom_indexing = self.inside_reduction + self.has_reduction = self.inside_reduction + self.buffer_dimensions: Dict[str, List[DimensionInfo]] = {} + self.buffer_offsets: Dict[str, sympy.Expr] = {} + # {h0: size1, h1: size2, ...} + self.halide_vars: Dict[sympy.Symbol, sympy.Expr] = {} + # {x0: h0, x1: h1+10*h2, ...} + self.index_replacements: Dict[sympy.Expr, sympy.Expr] = {} + # {h1: hr1, ...} + self.reduction_renames: Dict[sympy.Symbol, sympy.Symbol] = {} + # {"i": {h0: hi0}, "o": ...} + self.dom_renames: Dict[str, Dict[sympy.Symbol, sympy.Symbol]] = {} + # {"in_ptr0": ["in_ptr0_view0"], ...} + self.buffer_aliases: Dict[str, List[str]] = defaultdict(list) + 
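+        # set by finalize_indexing() when any used index symbol is an
+        # indirect (SymT.INDIRECT) symbol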
self.has_indirect_indexing = False + + def create_cse_var(self, name, bounds=None): + self.body.writeline(f"{name} = hl.Func({name!r})") + return HalideCSEVariable(name, bounds) + + def finalize_indexing(self, indices: Sequence[sympy.Expr]): + """ + Hook called right before codegen with every index that will be + used in the fused kernel. + + This populates self.halide_vars/index_replacements/reduction_renames which is an alternate indexing + scheme that avoids using divide and modulus. Instead of xindex/yindex/rindex + we base indexing on a larger number of vars whose product combines to those. + + This function populates self.halide_vars, self.index_replacements, and self.reduction_renames + """ + assert not ( + self.index_replacements or self.halide_vars or self.reduction_renames + ) + size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type] + indices = dict.fromkeys(map(super().prepare_indexing, indices)) + all_used_symbols = set() + sym_to_node = { + n.symbol(): n + for n in itertools.chain.from_iterable( + [tree.nodes.values() for tree in self.range_trees] + ) + } + + def simplify(expr): + return sympy.simplify( + V.graph.sizevars.remove_precomputed_replacements(expr) + ) + + def visit_modular_indexing(base, divisor, modulus): + if base in sym_to_node: + node = sym_to_node[base] + all_used_symbols.add( + node.root.lookup( + node.divisor * divisor, + V.graph.sizevars.evaluate_min( + modulus, FloorDiv(node.length, divisor) + ), + ).symbol() + ) + + def visit_floor_div(base, divisor): + if base in sym_to_node: + node = sym_to_node[base] + all_used_symbols.add( + node.root.lookup( + node.divisor * divisor, + FloorDiv(node.length, divisor), + ).symbol() + ) + + # first figure out all_used_symbols to do dead symbol elimination + for index in indices: + if index.has(ModularIndexing): + index.replace( + ModularIndexing( + sympy.Wild("base"), + sympy.Wild("divisor"), + sympy.Wild("modulus"), + ), + visit_modular_indexing, + ) + if index.has(FloorDiv): + index.replace( + FloorDiv( + sympy.Wild("base"), + sympy.Wild("divisor"), + ), + visit_floor_div, + ) + all_used_symbols.update(super().prepare_indexing(index).free_symbols) + + self.has_indirect_indexing = any( + symbol_is_type(sym, SymT.INDIRECT) for sym in all_used_symbols + ) + + had_fallback = False + for tree in reversed(self.range_trees): + nodes = [n for n in tree.nodes.values() if n.symbol() in all_used_symbols] + nodes.sort(key=lambda n: size_hint(n.divisor)) + if not nodes: + nodes.append(tree.lookup(1, tree.numel)) + handled_count = 0 + divisor = sympy.Integer(1) + added_sym_size = [] + # decide on a minimal set of symbols and put them in self.halide_vars + while handled_count < len(nodes) and not eq(tree.numel, divisor): + sizes_to_add = [ + simplify(n.length) for n in nodes if eq(n.divisor, divisor) + ] + handled_count += len(sizes_to_add) + assert sizes_to_add, nodes + end = divisor * functools.reduce( + V.graph.sizevars.evaluate_max, sizes_to_add + ) + sizes_to_add.extend( + [ + simplify(n.divisor / divisor) + for n in nodes + if lt(divisor, n.divisor) and lt(n.divisor, end) + ] + ) + while sizes_to_add: + next_size = functools.reduce(sympy.gcd, sizes_to_add) + if eq(next_size, 1): + # sizes share no common factors, e.g [2, 21, 42, 441, 889056] + # TODO(jansel): we should just prevent fusion in cases that hit this + next_size = simplify(tree.numel / divisor) + assert not eq(next_size, 1) + sizes_to_add = [] + handled_count = len(nodes) + had_fallback = True + sym = 
sympy_index_symbol(f"h{len(self.halide_vars)}") + if tree.prefix == "r": + self.reduction_renames[sym] = sympy_index_symbol( + f"hr{len(self.halide_vars)}" + ) + self.halide_vars[sym] = next_size + added_sym_size.append((sym, next_size)) + divisor *= next_size + new_sizes = [n.length for n in nodes if eq(n.divisor, divisor)] + handled_count += len(new_sizes) + prior_len = len(sizes_to_add) + sizes_to_add = [ + sympy.simplify(s / next_size) + for s in sizes_to_add + if not eq(s, next_size) + ] + assert len(sizes_to_add) < prior_len or prior_len == 0 + sizes_to_add.extend(new_sizes) + + # create a mapping to the new set of symbols in self.index_replacements + for node in nodes: + try: + idx = 0 + divisor = 1 + while not eq(node.divisor, divisor): + sym, size = added_sym_size[idx] + idx += 1 + divisor *= size + length = 1 + expr = sympy.Integer(0) + while not eq(node.length, length): + sym, size = added_sym_size[idx] + idx += 1 + expr += length * sym + length *= size + self.index_replacements[node.symbol()] = expr + except IndexError: + assert had_fallback + full_index = sympy.Integer(0) + stride = sympy.Integer(1) + for sym, size in added_sym_size: + full_index += stride * sym + stride *= size + self.index_replacements[ + node.symbol() + ] = V.graph.sizevars.simplify_with_ranges( + ModularIndexing(full_index, node.divisor, node.length), + self.halide_vars, # type: ignore[arg-type] + ) + + # codegen the variable definitions + for sym in self.halide_vars: + self.indexing_code.writeline(f"{sym} = hl.Var({sym.name!r})") + if self.reduction_renames: + self.codegen_rdom( + "rdom", + {rv: self.halide_vars[v] for v, rv in self.reduction_renames.items()}, + ) + + def setup_dom_indexing(self): + """RDom based indexing uses explicit iteration ranges for Func updates""" + prefix = "i" if self.inside_reduction else "o" + if prefix in self.dom_renames: + return self.dom_renames[prefix] + + renames = {} + for var in self.halide_vars.keys(): + if not self.inside_reduction and var in self.reduction_renames: + continue + m = re.match(r"^h(\d+)$", var.name) + assert m + renames[var] = sympy_index_symbol(f"h{prefix}{m.group(1)}") + + self.codegen_rdom( + f"{prefix}dom", {rv: self.halide_vars[v] for v, rv in renames.items()} + ) + + self.dom_renames[prefix] = renames + return renames + + def codegen_rdom(self, name, vars): + rsizes = [ + f"hl.Range(0, {self.kexpr(self.rename_indexing(size))})" + for size in vars.values() + ] + self.indexing_code.writeline(f"{name} = hl.RDom([{', '.join(rsizes)}])") + for i, rsym in enumerate(vars.keys()): + self.indexing_code.writeline(f"{rsym} = {name}[{i}]") + + def prepare_indexing( + self, + index: sympy.Expr, + ): + index = super().prepare_indexing(index) + index = sympy_subs(index, self.index_replacements) + return V.graph.sizevars.simplify_with_ranges(index, self.halide_vars) # type: ignore[arg-type] + + def sym_size(self, sym): + """The size of an index symbol""" + if symbol_is_type(sym, SymT.TMP): + return self.lookup_cse_var(sym.name).indirect_indexing_size + return self.halide_vars[sym] + + def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool): + """Convert address-based indexing into dimensions using self.halide_vars""" + symbols = [] + for sym in sorted(index.free_symbols, key=lambda x: x.name): # type: ignore[attr-defined] + if symbol_is_type(sym, (SymT.HALIDE, SymT.TMP)): + symbols.append(sym) + else: + assert symbol_is_type( + sym, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + ), + ), sym + + # group the expression by 
variables used + offset = sympy.Integer(0) + split_expr = {s: sympy.Integer(0) for s in symbols} + split_failed: List[Tuple[List[sympy.Symbol], sympy.Expr]] = [] + index = sympy.expand(self.rename_indexing(index)) + for part in index.args if isinstance(index, sympy.Add) else [index]: + part_vars = [v for v in part.free_symbols if v in split_expr] + if len(part_vars) == 0: + offset += part + elif len(part_vars) == 1: + split_expr[part_vars[0]] += part + else: + new_split_failed = [] + for i in range(len(split_failed)): + assert split_failed[i] is not None + other_vars, other_part = split_failed[i] + if set(other_vars) & set(part_vars): + part_vars.extend([v for v in other_vars if v not in part_vars]) + part += other_part + else: + new_split_failed.append((other_vars, other_part)) + split_failed = [*new_split_failed, (part_vars, part)] + + def expr_to_dimension(expr, syms): + expr = sympy.factor(expr) + if len(syms) == 1: + stride_wild = sympy.Wild("wild", exclude=symbols) + m = expr.match(stride_wild * syms[0]) + if m: + return DimensionInfo( + syms[0], self.sym_size(syms[0]), m[stride_wild] + ) + assert not is_store, expr + length = sympy.simplify( + sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1 + ) + stride = sympy.Integer(1) + if isinstance(expr, sympy.Mul): + for term in expr.args: + if isinstance(term, sympy.Integer): + stride *= term + expr = sympy.simplify(expr / term) + length = sympy.simplify(sympy.ceiling(length / term)) + return DimensionInfo(expr, length, stride) + + # try to turn each group into a strided access + dims = [] + for syms, expr in split_failed: + for v in syms: + expr += split_expr.pop(v) + dims.append(expr_to_dimension(expr, syms)) + for sym, expr in split_expr.items(): + dims.append(expr_to_dimension(expr, [sym])) + dims.sort(key=lambda d: V.graph.sizevars.size_hint(d.stride, fallback=inf)) # type: ignore[arg-type] + + if not dims: # scalar load/store + if self.has_indirect_indexing: + # workaround https://github.com/halide/Halide/issues/8338 + dims.append(DimensionInfo(sympy.Integer(0), 1, 1)) + elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1): + # Halide assumes dimension 0 is stride == 1, so add a dummy dimension + dims.insert( + 0, DimensionInfo(sympy.Integer(0), 1 if is_store else dims[0].stride, 1) + ) + + if dims and not is_store: + if var in self.buffer_offsets and V.graph.sizevars.statically_known_geq( + offset, self.buffer_offsets[var] + ): + # reuse the existing offset to avoid needing an input alias + self.apply_offset_to_dimension(dims, offset - self.buffer_offsets[var]) + offset = self.buffer_offsets[var] + elif V.graph.sizevars.statically_known_gt( + offset, 0 + ): # TODO(jansel): negative offsets + # roll the offset into the dimensions for cleaner indexing + self.apply_offset_to_dimension(dims, offset) + offset = 0 + + orig_var = var + for i in itertools.count(): + if self.install_dims(var, dims, offset, is_store): + return var, dims + assert not is_store + var = f"{orig_var}_view{i}" + if var not in self.buffer_aliases[orig_var]: + self.buffer_aliases[orig_var].append(var) + + def install_dims(self, var, dims, offset, is_store): + """Try to set self.buffer_dimensions[var], return True on success""" + if var not in self.buffer_dimensions: + self.buffer_dimensions[var] = dims + self.buffer_offsets[var] = offset + return True + if self.buffer_offsets[var] != offset or len( + self.buffer_dimensions[var] + ) != len(dims): + return False + if is_store: + return self.buffer_dimensions[var] == dims + for old, new 
in zip(self.buffer_dimensions[var], dims): + if old.stride != new.stride: + return False + if old.size != new.size or old.expr != new.expr: + old.size = V.graph.sizevars.evaluate_max(old.size, new.size) + old.expr = None + return True + + def apply_offset_to_dimension(self, dims, offset): + if offset == 0: + return + for i in reversed(range(len(dims))): + if dims[i].stride == 1 or V.graph.sizevars.statically_known_geq( + offset, dims[i].stride + ): + part = FloorDiv(offset, dims[i].stride) + offset -= part * dims[i].stride + dims[i].expr += part + assert offset == 0 + + def used_dims_from_index(self, index: sympy.Expr): + """Detect which range trees are used to populate HalideCSEVariable.used_dims""" + used_dims = set() + for sym in index.free_symbols: + assert isinstance(sym, sympy.Symbol) + if symbol_is_type(sym, SymT.TMP): + # indirect indexing + cse_var = self.lookup_cse_var(sym.name) + assert ( + isinstance(cse_var, HalideCSEVariable) + and cse_var.used_dims is not None + ) + used_dims.update(cse_var.used_dims) + elif symbol_is_type(sym, SymT.HALIDE): + used_dims.add(sym) + elif symbol_is_type( + sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX) + ): + pass + else: + raise NotImplementedError(f"unhandled symbol {sym}") + return self.sort_used_dims(used_dims) + + def sort_used_dims(self, used_dims): + assert all(isinstance(x, sympy.Expr) for x in used_dims) + ordered = [ + sym + for sym in itertools.chain( + self.halide_vars, self.reduction_renames.values() + ) + if sym in used_dims + ] + assert len(ordered) == len(used_dims) + return ordered + + def make_index_str(self, dims, replacements=None, zero_vars=False): + index_str = ", ".join(d.index_str(replacements, zero_vars) for d in dims) + if len(dims) == 0: + index_str = "()" + elif len(dims) == 1: + # workaround for https://github.com/halide/Halide/issues/8299 + index_str = f"{index_str}," + return index_str + + def load(self, name: str, index: sympy.Expr): + """Codegen a load from an InputBuffer""" + var = self.args.input(name) + index = self.prepare_indexing(index) + var, dims = self.indexing_to_dimensions(var, index, False) + line = f"{var}[{self.make_index_str(dims)}]" + dtype = V.graph.get_dtype(name) + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + line = f"hl.cast(hl.Float(32), {line})" + + if self._load_mask: + assert ( + isinstance(self._load_mask, HalideCSEVariable) + and self._load_mask.used_dims is not None + ) + used_dims = {*self.used_dims_from_index(index), *self._load_mask.used_dims} + result = self.newfunc(self.sort_used_dims(used_dims)) + if result.used_dims: + self.body.writeline(f"{result.name}_mask = hl.RDom([hl.Range(0, 1)])") + self.body.writeline(f"{result.name}_mask.where({self._load_mask})") + other = self.kexpr(self._load_other or 0) # type: ignore[arg-type] + self.body.writeline( + f"{result} = hl.cast({halide_type(dtype)}, {other})" + ) + self.body.writeline( + f"{result} = {line} + hl.cast({halide_type(dtype)}, {result.name}_mask)" + ) + else: + # scalar case + self.body.writeline( + f"{result} = hl.select({self._load_mask}, {line}, hl.cast({halide_type(dtype)}, 0))" + ) + return result + else: + return self.genfunc(line, self.used_dims_from_index(index)) + + def lookup_cse_var(self, name: str): + return self.cse.varname_map[re.sub(r"\[.*", "", name)] + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + """Codegen a store to an OutputBuffer""" + assert isinstance(value, HalideCSEVariable) + var = 
self.args.output(name) + index = self.prepare_indexing(index) + var, dims = self.indexing_to_dimensions(var, index, True) + if self.is_indirect_indexing(index) or mode is not None: + replacements = self.setup_dom_indexing() + index_str = self.make_index_str(dims, replacements) + value_str = value.subs_str(replacements) + undef_dims = (", ".join(["hl.Var()"] * len(dims))) or "()" + self.body.writeline( + DeferredLine(name, f"{var}[{undef_dims}] = hl.undef({var}.type())") + ) + else: + index_str = self.make_index_str(dims, zero_vars=True) + value_str = str(value) + + dtype = V.graph.get_dtype(name) + if mode is None: + line = f"{var}[{index_str}] = hl.cast({halide_type(dtype)}, {value_str})" + elif mode == "atomic_add": + line = f"{var}[{index_str}] += hl.cast({halide_type(dtype)}, {value_str})" + else: + raise NotImplementedError(f"store mode={mode}") + self.body.writeline(DeferredLine(name, line)) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + """Codegen a reduction operation""" + assert self.inside_reduction + assert not self._load_mask + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + if isinstance(value, tuple): + assert reduction_type == "welford_combine" + self.cse.reduction_cache[ + cache_key + ] = result_tuple = self.welford_combine_impl(*value) + return result_tuple + + assert isinstance(value, HalideCSEVariable) and value.used_dims is not None + reduction_vars = {*self.reduction_renames} + result_var = self.newfunc( + [v for v in value.used_dims if v not in reduction_vars] + ) + if reduction_vars - {*value.used_dims}: + value = self.genfunc( + f"{value}", self.sort_used_dims({*value.used_dims, *reduction_vars}) + ) + value_str = value.subs_str(self.reduction_renames) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + acc_type = halide_acc_type(dtype) + + if reduction_type in ("argmax", "argmin"): + index = f"{result_var.name}_{reduction_type}" + self.body.writeline(f"{index} = hl.{reduction_type}(rdom, {value_str})") + # turn the N-D argmax index into a 1-D one + parts = [] + stride = 1 + for i, sym in enumerate(self.reduction_renames): + parts.append(f"{index}[{i}]") + if stride != 1: + parts[-1] += f"*{stride}" + stride *= self.halide_vars[sym] + self.body.writeline(f"{result_var} = {' + '.join(parts)}") + elif reduction_type == "welford_reduce": + # TODO(jansel): implement welford_reduce without fallback + result_var = self.welford_reduce_fallback(dtype, value) + else: + combine_fn = get_reduction_combine_fn(reduction_type, acc_type) + with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))): + combine_str = combine_fn(result_var, value_str) # type: ignore[arg-type] + default_str = f"hl.cast({acc_type}, {halide_constant(default)})" + self.body.writeline(f"{result_var} = {default_str}") + self.body.writeline(f"{result_var} = {combine_str}") + + self.cse.reduction_cache[cache_key] = result_var + return result_var + + def welford_combine_impl(self, mean, m2, weight): + assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None + assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None + assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None + used_dims = {*mean.used_dims, *m2.used_dims, *weight.used_dims} or { + *self.halide_vars + } + 
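# the combined result is indexed only by non-reduction dims, so drop the renamed reduction symbols +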
used_dims -= {*self.reduction_renames} + result_var = self.newfunc(self.sort_used_dims(used_dims)) + default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)] + pfx = result_var.name + self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])") + self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]") + self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]") + self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]") + self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}") + self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}") + self.body.writeline( + f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}" + ) + self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1") + self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2") + self.body.writeline( + f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)" + ) + update = [ + f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w", + f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w", + f"{pfx}_new_weight", + ] + self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])") + + unpacked = [] + for i in range(3): + unpacked.append(self.newfunc(result_var.used_dims)) + self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]") + return tuple(unpacked) + + def scan( + self, + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[ + [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...] + ], + values_orig: Tuple[CSEVariable, ...], + ) -> Tuple[CSEVariable, ...]: + assert self.inside_reduction + assert len(dtypes) == len(values_orig) + values: List[HalideCSEVariable] = [] + all_used_dims = set() + for value in values_orig: + assert isinstance(value, HalideCSEVariable) and value.used_dims is not None + if set(value.used_dims) & set(self.reduction_renames): + values.append(value) + else: + values.append( + self.genfunc( + f"{value}", [*value.used_dims, [*self.reduction_renames][:1]] + ) + ) + all_used_dims.update(value.used_dims) + result_var = self.newfunc(self.sort_used_dims(all_used_dims)) + assert result_var.used_dims and set(result_var.used_dims) & set( + self.reduction_renames + ) + initial = [ + f"hl.cast({halide_acc_type(dtype)}, {value})" + for dtype, value in zip(dtypes, values) + ] + + length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel)) + scan_dom = f"{result_var.name}_rdom" + scan = f"{scan_dom}.x" + self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])") + + assert ( + len(self.reduction_renames) == 1 + ), "multi-dimensional scan not implemented" + (scan_var,) = [*self.reduction_renames] # type: ignore[misc] + scan_renames_cur = {scan_var: sympy_index_symbol(scan)} + scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1} + + if len(values) == 1: + + def maybe_tuple(x): + return x[0] + + read_left = [result_var.subs_str(scan_renames_pri)] + read_right = [result_var.subs_str(scan_renames_cur)] + else: + + def maybe_tuple(x): + return f"hl.Tuple([{', '.join(x)}])" + + read_left = [ + result_var.subs_str(scan_renames_pri) + f"[{i}]" + for i in range(len(values)) + ] + read_right = [ + result_var.subs_str(scan_renames_cur) + f"[{i}]" + for i in range(len(values)) + ] + + self.body.writeline(f"{result_var} = {maybe_tuple(initial)}") + + # Disable CSE for update fn + with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))): + combine_str = combine_fn(read_left, read_right) # type: 
ignore[arg-type] + self.body.writeline( + f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}" + ) + + if len(values) == 1: + return (result_var,) + + unpack_vars = [self.newfunc(self.sort_used_dims(all_used_dims)) for _ in values] + for i, v in enumerate(unpack_vars): + self.body.writeline(f"{v} = {result_var}[{i}]") + return tuple(unpack_vars) + + def genfunc( + self, line, used_dims, *, bounds=ValueRanges.unknown() + ) -> HalideCSEVariable: + var = self.cse.generate(self.body, line, bounds=bounds) + assert isinstance(var, HalideCSEVariable) + var.used_dims = used_dims + return var + + def newfunc(self, used_dims) -> HalideCSEVariable: + var = self.cse.newvar() + assert isinstance(var, HalideCSEVariable) + var.used_dims = used_dims + return var + + def halide_buffer_numel(self, name: str): + """ + We map all tensors to 1D buffers in Halide since Halide has trouble representing some strides that PyTorch + supports. If there are gaps in the underlying layout the numel we pass to Halide includes the gaps while + PyTorch's numel excludes them. + """ + return V.graph.get_buffer(name).get_layout().storage_size() + + def halide_argdefs(self): + """ + Halide requires scalar inputs before outputs, so we need to reorder args. + """ + + def arg_order(arg_tuple): + call_str, arg = arg_tuple + if isinstance(arg, SizeArg): + return 1 # this would normally be at the end, move it to middle + elif "out_ptr" in arg.name: + return 2 + else: + assert "in_ptr" in arg.name + return 0 + + result = [] + _, a, b, _ = self.args.python_argdefs() + for call_str, arg in sorted(zip(a, b), key=arg_order): + result.append((call_str, arg)) + if isinstance(arg, TensorArg): + assert arg.offset == 0 and arg.alias_of is None + for alias in self.buffer_aliases.get(arg.name, ()): + result.append( + ( + None, + TensorArg( + alias, + arg.buffer, + arg.dtype, + arg.offset, + alias_of=arg.name, + ), + ) + ) + return result + + def halide_kernel_meta(self) -> HalideMeta: + """Compute metadata required by codecache.py""" + argtypes = [] + for _, arg in self.halide_argdefs(): + if isinstance(arg, SizeArg): + shape = None + stride = None + offset = None + dtype = "long" + else: + shape = [ + cexpr(self.rename_indexing(x.size)) + for x in self.buffer_dimensions[arg.name] + ] + stride = [ + cexpr(self.rename_indexing(x.stride)) + for x in self.buffer_dimensions[arg.name] + ] + assert len(shape) == len(stride) + offset = cexpr(self.buffer_offsets[arg.name]) + dtype = f"{DTYPE_TO_CPP[arg.dtype]}*" + argtypes.append( + HalideInputSpec( + dtype, + arg.name, + shape=shape, + stride=stride, + offset=offset, + alias_of=arg.alias_of, + ) + ) + + current_device = V.graph.scheduler.get_current_device_or_throw() + if current_device.type == "cpu": + target = [config.halide.cpu_target] + scheduler = config.halide.scheduler_cpu + scheduler_flags = { + "parallelism": parallel_num_threads(), + } + cuda_device = None + else: + assert current_device.type == "cuda", "only cpu/cuda supported" + assert current_device.index <= 0, "only default device supported" + target = [config.halide.gpu_target] + scheduler = config.halide.scheduler_cuda + capability = torch.cuda.get_device_properties(current_device) + if "cuda_capability" not in target[0]: + for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]: + if capability.major >= major and capability.minor >= minor: + target.append(f"cuda_capability_{major}{minor}") + break + target.append("user_context") + scheduler_flags = { + "parallelism": capability.multi_processor_count, + # 
TODO(jansel): explore other flags, see: + # grep parser.parse ~/Halide/src/autoschedulers/anderson2021/AutoSchedule.cpp + } + cuda_device = max(0, current_device.index) + + # strict_float is required for correctness + target.append("strict_float") + + # without this we will initialize cuda once per kernel and hit errors + target.append("no_runtime") + + if not config.halide.asserts: + target.append("no_asserts") + + if config.halide.debug: + target.append("debug") + + if "64" in self.index_dtype: + # TODO(jansel): it is unclear if this does anything, since input sizes are still int32 + target.append("large_buffers") + + return HalideMeta( + argtypes, + target="-".join(target), + scheduler=scheduler, + scheduler_flags=scheduler_flags, + cuda_device=cuda_device, + ) + + def codegen_kernel(self, name=None): + """Called at the end to generate a final kernel string""" + if self.args.inplace_buffers: + raise Unsupported("inplace_buffers") + meta = self.halide_kernel_meta() # ensure needed args are added early + code = IndentedBuffer() + code.splice( + """ + import halide as hl + from torch._inductor.runtime import halide_helpers + from math import inf, nan + + @hl.generator(name="kernel") + class Kernel: + """, + strip=True, + ) + code.do_indent() + for _, arg in self.halide_argdefs(): + if isinstance(arg, SizeArg): + code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})") + else: + assert arg.buffer, arg + argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer" + argtype = halide_type(arg.dtype) + ndim = len(self.buffer_dimensions[arg.name]) + code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})") + code.splice( + """ + def generate(g): + """ + ) + code.do_indent() + for _, arg in self.halide_argdefs(): + code.writeline(f"{arg.name} = g.{arg.name}") + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.indexing_code) + + def update_index(m): + var = self.cse.varname_map[m.group(1)] + assert var.used_dims is not None, var + return str(var) + + for line in self.body._lines: + if isinstance(line, str): + # fill in missing indices + line = HalideCSEVariable.undefined_re.sub(update_index, line) + code.writeline(line) + code.writeline("") + code.writeline("assert g.using_autoscheduler()") + + for _, arg in self.halide_argdefs(): + # fallback=1 below because halide requires buffers to be at least as large as the estimates + # This causes crashes if our estimate is greater than the vector length + # https://github.com/halide/Halide/issues/3103 + if isinstance(arg, SizeArg): + hint = V.graph.sizevars.size_hint(arg.expr, fallback=1) + code.writeline(f"{arg.name}.set_estimate({hint})") + else: + dims = self.buffer_dimensions[arg.name] + range_hints = [] + for i, dim in enumerate(dims): + hint = self._autoscheduler_workarounds( + V.graph.sizevars.size_hint(dim.size, fallback=1), dims + ) + range_hints.append(f"hl.Range(0, {hint})") + if "out" not in arg.name: + code.writeline(f"{arg.name}.dim({i}).set_min(0)") + try: + code.writeline( + f"{arg.name}.dim({i}).set_stride({int(dim.stride)})" + ) + except TypeError: + pass # not integer + try: + code.writeline( + f"{arg.name}.dim({i}).set_extent({int(dim.size)})" + ) + except TypeError: + pass # not integer + code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])") + + code.do_unindent(2) + code.splice( + """ + if __name__ == "__main__": + hl.main() + """.rstrip(), + ) + if meta.scheduler: + code.splice( + f""" + else: + 
hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r}) + target = hl.Target({meta.target!r}) + autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r}) + with hl.GeneratorContext(target, autoscheduler): + gen = Kernel() + pipeline = gen._build_pipeline() + # gen.compile_to_callable() does not run the autoscheduler + pipeline.apply_autoscheduler(target, autoscheduler) + kernel = pipeline.compile_to_callable([ + gen._get_input_parameter(a.name)._to_argument() + for a in gen._get_arginfos() + if a.dir == hl.ArgInfoDirection.Input + ], target) + """, + strip=True, + ) + else: + code.splice( + f""" + else: + with hl.GeneratorContext(hl.Target({meta.target!r})): + kernel = Kernel().compile_to_callable() + """, + strip=True, + ) + return code.getvalue() + + @staticmethod + def _autoscheduler_workarounds(n, dims): + if ( + len(dims) == 1 + and config.halide.scheduler_cuda == "Anderson2021" + and V.graph.scheduler.get_current_device_or_throw().type == "cuda" + ): + # workaround https://github.com/halide/Halide/issues/8246 + n = max(2, n) + return n + + def call_kernel(self, name: str, node=None): + """Codegen a call to this kernel""" + wrapper = V.graph.wrapper_code + call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None] + current_device = V.graph.scheduler.get_current_device_or_throw() + if current_device.type == "cuda": + stream_name = wrapper.write_get_raw_stream(current_device.index, V.graph) + call_args.append(stream_name) + wrapper.generate_kernel_call( + name, + call_args, + cuda=False, # grid/stream is handled internally in halide + ) + + def generate_assert(self, check): + return False # TODO(jansel): support asserts + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + pass # TODO(jansel): support asserts + + +class HalideScheduling(SIMDScheduling): + int32_type = "hl.Int(32)" + # TODO(jansel): Halide doesn't actually support 64 bit indexing... 
+ int64_type = "hl.Int(64)" + kernel_type = HalideKernel # type: ignore[arg-type] + + @classmethod + def get_backend_features(cls, device: torch.device): + result = dict.fromkeys( + [ + BackendFeature.TUPLE_REDUCTION, + BackendFeature.PREFER_STORE_LOOP_ORDER, + BackendFeature.REDUCE_TO_SINGLE_ELEMENT, + ] + ) + if config.halide.scan_kernels: + result[BackendFeature.SCAN] = None + return result + + def define_kernel(self, src_code, node_schedule, kernel): + """Codegen kernel definition to go in output wrapper code""" + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}" + wrapper.src_to_kernel[src_code] = kernel_name + wrapper.add_import_once( + "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec" + ) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline( + f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''" + ) + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''')") + + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + if is_metric_table_enabled("kernel_metadata"): + log_kernel_metadata(kernel_name, "", src_code) + + return kernel_name diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py new file mode 100644 index 0000000000000000000000000000000000000000..60360597ec1cb1cc6d82d138d6807aeda1d42ac3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py @@ -0,0 +1,770 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import dataclasses +import itertools +import pprint +from typing import Any, Dict, Iterable, List, Optional, Protocol + +import sympy + +import torch + +from .. import config, ir +from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer +from ..virtualized import V +from .wrapper import ( + AllocateLine, + FreeIfNotReusedLine, + MemoryPlanningLine, + NullLine, + ReuseLine, +) + + +@dataclasses.dataclass +class LiveRange: + """ + A range where a given tensor is live. Begin and end are both counters + representing points in the program of grouped memory operations. + Begin is inclusive, end is exclusive. + + Invariant: begin <= end + """ + + begin: float # int | +/-inf + end: float # int | +/-inf + + def contains(self, other: LiveRange): + """Is other entirely within self""" + return self.begin <= other.begin and other.end <= self.end + + def join(self, other: LiveRange): + """Combine two ranges using a union operation""" + return LiveRange(min(self.begin, other.begin), max(self.end, other.end)) + + def __len__(self): + return self.end - self.begin + + +class LiveRanges: + """ + A collection of LiveRange regions, allowing for non-contiguous + live regions. 
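+ + For example, LiveRanges([LiveRange(0, 4), LiveRange(3, 7), LiveRange(9, 12)]) + merges the overlapping first two into [0, 7) and keeps [9, 12) separate.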
+ + Invariant: LiveRanges.ranges is in sorted order and non-overlapping + """ + + def __init__(self, ranges: Iterable[LiveRange]): + ranges = [*sorted(ranges, key=lambda x: x.begin)] + self.ranges = ranges[:1] + for r in ranges[1:]: + assert self.ranges[-1].begin <= r.begin + if self.ranges[-1].end >= r.begin: + self.ranges[-1] = LiveRange.join(self.ranges[-1], r) + else: + self.ranges.append(r) + + def overlaps(self, other: LiveRanges): + """Check if any pair of ranges in self and other overlap""" + left = collections.deque(self.ranges) + right = collections.deque(other.ranges) + while left and right: + if left[0].begin > right[0].begin: + left, right = right, left + assert left[0].begin <= right[0].begin + if left[0].end > right[0].begin: + return True + left.popleft() + return False + + @property + def begin(self): + return self.ranges[0].begin + + @property + def end(self): + return self.ranges[-1].end + + def __repr__(self): + return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])" + + +class AllocationTreeNode: + """ + Abstract base class for nodes in the allocation pool. + """ + + def allocate(self, block: Allocation, is_last: bool) -> bool: + """ + Try to assign block to a memory location in this pool. Return True if + an assignment was made. + """ + return False + + def get_live_ranges(self) -> LiveRanges: + """Aggregate LiveRanges for all objects below this node in the tree""" + raise NotImplementedError + + def get_size_hint(self) -> int: + """Number of bytes used for example inputs""" + raise NotImplementedError + + def get_symbolic_size(self) -> sympy.Expr: + """Number of bytes needed at runtime""" + raise NotImplementedError + + def finalize(self, pool, offset) -> AllocationTreeNode: + """Called after all allocations have been made""" + return self + + def is_empty(self): + return False + + +@dataclasses.dataclass +class Allocation(AllocationTreeNode): + """ + Represents memory allocated to a given node in the allocation pool. + """ + + node: ir.Buffer + live_range: LiveRange + size_hint: int + symbolic_size: sympy.Expr + allocated: bool = False + pool: Optional[AllocationPool] = None + offset: Optional[sympy.Expr] = None + + @property + def device(self): + return self.node.get_device() + + def get_live_ranges(self): + return LiveRanges([self.live_range]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return self.symbolic_size + + def mark_allocated(self): + assert not self.allocated + self.allocated = True + + def finalize(self, pool, offset): + assert self.pool is None and self.offset is None + self.pool = pool + self.offset = offset + return self + + def codegen_alloc_from_pool(self, wrapper): + assert self.pool + node = self.node + shape = tuple(node.get_size()) + stride = tuple(node.get_stride()) + return wrapper.codegen_alloc_from_pool( + self.pool.name, self.offset, node.get_dtype(), shape, stride + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"node={self.node.get_name()}, " + f"live_range={self.live_range}, " + f"size_hint={self.size_hint}, " + f"symbolic_size={self.symbolic_size}, " + f"pool={self.pool.name if self.pool else None}, " + f"offset={self.offset})" + ) + + +@dataclasses.dataclass +class Empty(AllocationTreeNode): + """ + Placeholder to represent empty space in the allocation pool. + Only exists to get the size_hint correct in parent nodes. 
+ """ + + size_hint: int + + def get_live_ranges(self): + return LiveRanges([]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return 0 + + def is_empty(self): + return True + + +class MemorySplitProtocol(Protocol): + get_live_ranges: CachedMethod[[], LiveRanges] + get_size_hint: CachedMethod[[], int] + get_symbolic_size: CachedMethod[[], sympy.Expr] + + def _allocate(self, block: Allocation, is_last: bool) -> bool: + ... + + +class ClearCacheOnAllocateMixin(MemorySplitProtocol): + """ + Helper to assist in caching get_live_ranges, get_size_hint, and + get_symbolic_size. + """ + + def allocate(self, block: Allocation, is_last: bool): + is_allocated = self._allocate(block, is_last) + if is_allocated: + self.clear_cache() + return is_allocated + + def clear_cache(self): + self.get_live_ranges.clear_cache(self) + self.get_size_hint.clear_cache(self) + self.get_symbolic_size.clear_cache(self) + + +@dataclasses.dataclass +class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains a list of allocations not overlapping in LiveRanges. + + Invariant: no pair (a,b) in self.allocations will have: + a.get_live_ranges().overlaps(b.get_live_ranges()) + """ + + allocations: List[AllocationTreeNode] + + def _allocate(self, block: Allocation, is_last: bool): + slot_size = self.get_size_hint() + block_size = block.get_size_hint() + if not is_last and block_size > slot_size: + return False # doesn't fit + + block_live = block.get_live_ranges() + overlapping = [ + s for s in self.allocations if s.get_live_ranges().overlaps(block_live) + ] + if len(overlapping) > 1: + # TODO(jansel): we could try harder here by merging overlapping in space + return False + elif len(overlapping) == 1: + return overlapping[0].allocate(block, is_last) + else: + block.mark_allocated() + + if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty): + self.allocations.pop() + + if slot_size == block_size: + # perfect fit + self.allocations.append(block) + elif slot_size > block_size: + self.allocations.append( + SpatialSplit.create(block, slot_size - block_size) + ) + else: # grow this allocation + assert is_last + self.allocations = [ + *( + SpatialSplit.create(a, block_size - slot_size) + for a in self.allocations + ), + block, + ] + return True + + @cache_on_self + def get_live_ranges(self) -> LiveRanges: + return LiveRanges( + itertools.chain.from_iterable( + x.get_live_ranges().ranges for x in self.allocations + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + if not self.allocations: + return 0 + return max(x.get_size_hint() for x in self.allocations) + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + if not self.allocations: + return 0 # type: ignore[return-value] + return sympy.Max(*[x.get_symbolic_size() for x in self.allocations]) + + def is_empty(self): + return len(self.allocations) == 1 and self.allocations[0].is_empty() + + def finalize(self, pool, offset): + self.allocations = [block.finalize(pool, offset) for block in self.allocations] + self.clear_cache() + if len(self.allocations) == 1: + return self.allocations[0] + return self + + +@dataclasses.dataclass +class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains two allocations, left and right, that do not overlap in space. + Right will be allocated immediately after left in memory. 
+ """ + + left: TemporalSplit + right: TemporalSplit + + @staticmethod + def create(left, extra_space): + assert isinstance(left, AllocationTreeNode) + assert isinstance(extra_space, int) and extra_space >= 1 + return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)])) + + def _allocate(self, block: Allocation, is_last: bool): + return self.left.allocate(block, False) or self.right.allocate(block, is_last) + + @cache_on_self + def get_live_ranges(self): + return LiveRanges( + itertools.chain( + self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + return _align(self.left.get_size_hint()) + self.right.get_size_hint() + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size() + + def finalize(self, pool, offset): + self.left = self.left.finalize(pool, offset) + self.right = self.right.finalize( + pool, offset + align(self.left.get_symbolic_size()) + ) + self.clear_cache() + if self.right.is_empty(): + return self.left + return self + + +@dataclasses.dataclass +class AllocationPool: + """ + Represents a pool of allocations that will be generated by a single + call to torch.empty. + """ + + device: torch.device + root: TemporalSplit + can_expand: bool = True + restrict_live_range: Optional[LiveRange] = None + name: Optional[str] = None + names_to_del: List[str] = dataclasses.field(default_factory=list) + creation_cache: Dict[str, str] = dataclasses.field(default_factory=dict) + + def allocate(self, block: Allocation, is_last: bool): + if self.restrict_live_range and not self.restrict_live_range.contains( + block.live_range + ): + return False + + is_last = self.can_expand and is_last + if self.root.allocate(block, is_last): + return True + + if is_last: + return self.allocate_at_end(block) + + return False + + def allocate_at_end(self, block): + block.mark_allocated() + self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) + return True + + def finalize(self, name): + assert not self.name + self.name = name + self.names_to_del.append(name) + self.root.finalize(self, 0) + + def codegen_create(self, wrapper, code: IndentedBuffer): + assert self.name + nbytes = self.root.get_symbolic_size() + for block in self.root.allocations: + if isinstance(block, Allocation) and nbytes == block.get_symbolic_size(): + # optimization: fuse first allocation and pool creation + node = block.node + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=node.get_dtype(), + shape=tuple(node.get_size()), + stride=tuple(node.get_stride()), + ) + ) + self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name + return + else: + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=torch.uint8, + shape=(nbytes,), + stride=(1,), + ) + ) + + def codegen_destroy(self, wrapper, code: IndentedBuffer): + code.writeline(wrapper.make_free_by_names(self.names_to_del)) + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + +@dataclasses.dataclass +class AllocationPools: + """ + Collection of many AllocationPool objects grouped by device. 
+ """ + + device_to_pools: Dict[torch.device, List[AllocationPool]] = dataclasses.field( + default_factory=dict + ) + + def get_pools(self, block): + if block.device not in self.device_to_pools: + self.device_to_pools[block.device] = [] + return self.device_to_pools[block.device] + + def allocate(self, block: Allocation): + pools = self.get_pools(block) + + for pool in pools: + if pool.allocate(block, is_last=pool is pools[-1]): + return + + # everything is full, make a new pool + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool != "none", + ) + ) + block.mark_allocated() + + def allocate_output(self, block: Allocation): + """Outputs get different pools so memory gets freed properly""" + pools = self.get_pools(block) + if pools and config.memory_pool in ("outputs", "combined"): + pools[-1].allocate_at_end(block) + else: + # create a new pool + block.mark_allocated() + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool == "combined", + ) + ) + + def finalize(self): + """Called at the end of allocation process""" + for i, pool in enumerate( + itertools.chain.from_iterable(self.device_to_pools.values()) + ): + pool.finalize(f"pool{i}") + + def pprint(self): + for pool in itertools.chain.from_iterable(self.device_to_pools.values()): + print() + print(pool.name) + print(pool.root.get_live_ranges()) + pprint.pprint(pool.root) + + +class BufferGroup: + """ + Due to inplace reuse an allocated buffer can have many names. + This tracks these collections of buffers sharing underlying memory. + """ + + def __init__(self, node: ir.Buffer): + self.node = node + self.names = [node.get_name()] + self.is_output = False + self.allocation: Optional[Allocation] = None + self.live_range = LiveRange(float("inf"), -float("inf")) + + def update_usage(self, timestep: int): + """Expand self.live_range to include timestep""" + self.live_range = LiveRange( + min(timestep, self.live_range.begin), + max(timestep, self.live_range.end), + ) + + def sym_nbytes(self): + return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize + + def make_allocation(self): + assert not self.allocation, "multiple allocations" + assert isinstance(self.live_range.begin, int), "live ranges not computed" + nbytes = self.sym_nbytes() + # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have + # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored. 
size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64) + self.allocation = Allocation( + self.node, + self.live_range, + size_hint=size_hint, + symbolic_size=nbytes, + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, " + f"live_range={self.live_range})" + ) + + +@dataclasses.dataclass +class PoolMemoryPlanningLine(MemoryPlanningLine): + """Abstract base class for {Alloc,Dealloc}FromPoolLine""" + + group: BufferGroup + timestep: Optional[int] = None + + @property + def node(self): + return self.group.node + + +@dataclasses.dataclass +class AllocFromPoolLine(PoolMemoryPlanningLine): + """Similar to AllocateLine, but takes memory from a pool""" + + is_first_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + allocation = self.group.allocation + assert allocation and allocation.pool + pool = allocation.pool + name = self.node.get_name() + + if self.is_first_pool_usage: + pool.codegen_create(self.wrapper, code) + + pool.names_to_del.extend(self.group.names) + alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper) + if alloc_from_pool in pool.creation_cache: + code.writeline( + self.wrapper.make_tensor_alias( + name, pool.creation_cache[alloc_from_pool], "alloc" + ) + ) + else: + pool.creation_cache[alloc_from_pool] = name + code.writeline( + f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}" + ) + + +@dataclasses.dataclass +class DeallocFromPoolLine(PoolMemoryPlanningLine): + """Similar to FreeIfNotReusedLine, but takes memory from a pool""" + + is_last_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + if self.is_last_pool_usage: + assert self.group.allocation and self.group.allocation.pool + self.group.allocation.pool.codegen_destroy(self.wrapper, code) + + +@dataclasses.dataclass +class MemoryPlanner: + """ + Coordination object to run memory planning passes during wrapper + codegen. + """ + + wrapper: Any + pools: AllocationPools = dataclasses.field(default_factory=AllocationPools) + buffer_groups: Optional[List[BufferGroup]] = None + + def plan(self, lines: List[Any]) -> List[Any]: + """Call all the memory planning passes in sequence""" + lines = [*lines] + self.drop_removed_buffers(lines) + self.convert_to_pool_lines(lines) + self.compute_live_ranges(lines) + self.allocate_groups() + self.mark_first_last_usage(lines) + return lines + + def drop_removed_buffers(self, lines): + """ + Replace any memory planning lines for buffers in V.graph.removed_buffers with NullLine + """ + # drop any removed buffers + for i, line in enumerate(lines): + if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)): + if line.node.get_name() in V.graph.removed_buffers: + lines[i] = NullLine(self.wrapper) + + def compute_buffer_groups(self, lines): + """ + Populates self.buffer_groups with BufferGroup objects that join + allocations with common storage (due to inplace reuse) into a + single object. 
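+ + Returns a dict mapping each buffer name (original or reused) to its BufferGroup. + """ +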
+ """ + name_to_group = {} + for line in lines: + if isinstance(line, AllocateLine): + name = line.node.get_name() + assert name not in name_to_group + name_to_group[name] = BufferGroup(line.node) + elif isinstance(line, ReuseLine): + old_name = line.node.get_name() + new_name = line.reused_as.get_name() + assert new_name not in name_to_group + # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc + if old_name in name_to_group: + name_to_group[old_name].names.append(new_name) + name_to_group[new_name] = name_to_group[old_name] + + outputs = set(V.graph.get_output_names()) + unique_groups = [*{id(g): g for g in name_to_group.values()}.values()] + for group in unique_groups: + group.is_output = any(x in outputs for x in group.names) + + assert self.buffer_groups is None + self.buffer_groups = unique_groups + return name_to_group + + def convert_to_pool_lines(self, lines): + """ + Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their + pool-based counterparts. + """ + name_to_group = self.compute_buffer_groups(lines) + for i, line in enumerate(lines): + if isinstance(line, AllocateLine): + if line.node.get_name() in name_to_group: + lines[i] = AllocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, FreeIfNotReusedLine): + assert not line.is_reused + if line.node.get_name() in name_to_group: + lines[i] = DeallocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, ReuseLine): + if line.node.get_name() in name_to_group: + line.delete_old = False + + def compute_live_ranges(self, lines): + """Populate every BufferGroup.live_ranges field based on first/last usage""" + timestep = 0 + worklist = collections.deque(lines) + while worklist: + if isinstance(worklist[0], MemoryPlanningLine): + timestep += 1 + while worklist and isinstance(worklist[0], MemoryPlanningLine): + line = worklist.popleft() + if isinstance(line, PoolMemoryPlanningLine): + line.group.update_usage(timestep) + line.timestep = timestep + else: + worklist.popleft() + + timestep += 1 + assert self.buffer_groups is not None + for group in self.buffer_groups: + if group.is_output: + group.update_usage(timestep) + + def allocate_groups(self): + """ + Assign every allocation to a specific location in a specific AllocationPool. + """ + assert config.memory_pool in ("none", "intermediates", "outputs", "combined") + assert self.buffer_groups is not None + + for group in self.buffer_groups: + group.make_allocation() + + outputs: List[Allocation] = [] + intermediates: List[Allocation] = [] + for group in self.buffer_groups: + assert group.allocation + if group.is_output and config.memory_pool != "combined": + outputs.append(group.allocation) + else: + intermediates.append(group.allocation) + + for block in sorted( + outputs, + key=lambda x: ( + x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate_output(block) + + for block in sorted( + intermediates, + key=lambda x: ( + -x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate(block) + + self.pools.finalize() + + def mark_first_last_usage(self, lines): + """ + Populate the AllocFromPoolLine.is_first_pool_usage and + DeallocFromPoolLine.is_last_pool_usage fields so that pools + are created/destroyed. 
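+ + A pool is created at its first AllocFromPoolLine; its final DeallocFromPoolLine destroys it once the pool's live ranges have ended by that timestep. + """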
+ seen = set() + for line in lines: + if isinstance(line, AllocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_first_pool_usage = True + seen.add(pool) + + seen = set() + for line in reversed(lines): + if isinstance(line, DeallocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_last_pool_usage = ( + pool.root.get_live_ranges().end <= line.timestep + ) + seen.add(pool) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..020c38dcc77e4ef9d185cb1fd900dab6c768d36a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py @@ -0,0 +1,385 @@ +# mypy: allow-untyped-defs +import logging +import os +import pathlib +from typing import Any, List + +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..codecache import get_path, TritonFuture +from ..runtime.benchmarking import benchmarker +from ..utils import cache_on_self, IndentedBuffer +from ..virtualized import V +from .common import TensorArg + + +log = logging.getLogger(__name__) + + +def get_kernel_argdefs(kernel): + arg_defs, _, _, _ = kernel.args.python_argdefs() + return arg_defs + + +def _get_all_args(args_list, arg_types_list=None): + all_args = max(args_list, key=len)[:] + arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None + for args in args_list: + assert set(args).issubset(set(all_args)), f"{args} v.s. {all_args}" + + return all_args, arg_types + + +def get_all_kernel_argdefs(kernels): + """ + The logic here must match `get_all_call_args`, except there is no need to get arg_types here + """ + argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels] + + return _get_all_args(argdefs_list)[0] + + +def get_all_call_args(call_args_list, arg_types_list): + """ + Takes the call_args for each subkernel and returns the call_args for the + combined multi-kernel. + + Note that the following algorithm does not always work: + ``` + all_call_args: Dict[ + Any, None + ] = {} # use a dict rather than set to maintain insertion order + for call_args in call_args_list: + all_call_args.update({arg: None for arg in call_args}) + + all_call_args = list(all_call_args.keys()) + ``` + It will fail if any kernel has the same argument passed in multiple times. + Check test_pass_same_arg_multi_times in test_multi_kernel.py + + Instead, we pick the longest call args and assert that the other call args are + a subset of it. + """ + return _get_all_args(call_args_list, arg_types_list) + + +def get_numel_argdefs(kernel): + numel_argdefs = [] + for tree in kernel.range_trees: + if tree.prefix != "r" or kernel.inside_reduction: + numel_argdefs.append(f"{tree.prefix}numel") + + return numel_argdefs + + +class MultiKernelState: + """ + Maintain state of multi-kernel compilation so we don't define duplicate + multi-kernels for the same set of sub-kernels. + + V.graph.wrapper_code has a reference to a MultiKernelState instance. + """ + + def __init__(self): + self.subkernel_to_kernel_name = {} + + def define_kernel(self, kernels): + """ + Previously we named the multi-kernel "multi_kernel_{kernel_names[0]}". + This has a minor issue. 
+ + E.g. for persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca , + there are 2 flavors of non-persistent reduction: + https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4 + and + https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd + + The only difference is the cache eviction policy. + + We should name the multi-kernel differently in these 2 cases. + """ + kernel_names = tuple(k.kernel_name for k in kernels) + if kernel_names in self.subkernel_to_kernel_name: + return self.subkernel_to_kernel_name[kernel_names] + + # name the multi-kernel with a sequential index rather than after the first kernel + multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}" + self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name + + if V.graph.cpp_wrapper: + # we should not generate any python code for multi-kernel during + # the second pass of cpp-wrapper. + return multi_kernel_name + + buf = IndentedBuffer() + buf.writeline( + f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [" + ) + with buf.indent(): + for name in kernel_names: + buf.writeline(f"{name},") + buf.writeline("])") + + wrapper = V.graph.wrapper_code + wrapper.header.splice(buf) + if config.triton.autotune_at_compile_time: + wrapper.kernel_autotune_defs.splice(buf) + + return multi_kernel_name + + +class MultiKernel: + """ + This class maintains the compile-time state for multi-kernels. + + Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2. + The generated definition for the multi-kernel will look like: + ``` + multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code) + ``` + + Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 + """ + + def __init__(self, kernels): + assert len(kernels) >= 2 + + self.kernels = kernels + self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel( + kernels + ) + + # need this since some code in inductor checks if the kernel object has an args + # attribute to decide if it's a non-null kernel. + self.args = object() + + def call_kernel(self, kernel_name): + """ + Collect the union of arguments from all subkernels as the arguments + for the multi-kernel. + """ + assert kernel_name == self.kernel_name + V.graph.wrapper_code.write_triton_header_once() + _, call_args, _, arg_types = self.kernels[0].args.python_argdefs() + for kernel in self.kernels[1:]: + _, other_call_args, _, other_arg_types = kernel.args.python_argdefs() + assert call_args == other_call_args + assert arg_types == other_arg_types + + grid: List[Any] = [] + + if V.graph.cpp_wrapper: + # for the second pass of cpp-wrapper codegen, we should call + # the fast kernel directly + picked_kernel = MultiKernelCall.lookup_choice(kernel_name) + kernel_name = self.kernels[picked_kernel].kernel_name + + # numels for all subkernels should be the same. 
Use kernels[0] here + self.kernels[0].add_numel_to_call_args_and_grid( + kernel_name, call_args, arg_types, grid + ) + + grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid) + V.graph.wrapper_code.generate_kernel_call( + kernel_name, + call_args, + grid, + arg_types=arg_types, + ) + + def codegen_nan_check(self): + wrapper = V.graph.wrapper_code + seen = set() + for k in self.kernels: + _, call_args, precompile_args, _ = k.args.python_argdefs() + for arg, precompile_arg in zip(call_args, precompile_args): + if arg in seen: + continue + seen.add(arg) + if isinstance(precompile_arg, TensorArg): + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + @property + def removed_buffers(self): + return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels]) + + @property + def inplaced_to_remove(self): + return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels]) + + @property + @cache_on_self + def inplace_update_buffers(self): + """ + Make sure all kernels have the same inplace update mappings. + """ + for k in self.kernels[1:]: + assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers + return self.kernels[0].inplace_update_buffers + + def warn_mix_layout(self, kernel_name: str): + pass + + +class MultiKernelCall: + """ + This class is used at run time to actually run the kernel + """ + + def __init__(self, multi_kernel_name, kernels): + assert len(kernels) >= 2 + self._kernels = kernels + self.multi_kernel_name = multi_kernel_name + + self.disable_cache = os.environ.get( + "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE" + ) == "1" or is_metric_table_enabled("persistent_red_perf") + + self.picked_kernel = None + if config.triton.multi_kernel > 1: + # manually force a subkernel to ease perf testing + picked_by_config = config.triton.multi_kernel - 2 + assert picked_by_config < len(self._kernels) + self.picked_kernel = picked_by_config + elif not self.disable_cache: + self.load_cache() + + self._recorded = False + + def cache_file_path(self): + _, _, path = get_path(self.kernels[0].fn.cache_key, "picked_kernel") + return pathlib.Path(path) + + def load_cache(self): + assert self.picked_kernel is None + path = self.cache_file_path() + if path.exists(): + with path.open() as fd: + self.picked_kernel = int(fd.read()) + assert self.picked_kernel >= 0 and self.picked_kernel < len( + self._kernels + ) + log.debug( + "Load picked kernel %d from cache file %s", self.picked_kernel, path + ) + + def store_cache(self): + assert self.picked_kernel is not None + path = self.cache_file_path() + path.parent.mkdir(parents=True, exist_ok=True) + + with path.open("w") as fd: + fd.write(str(self.picked_kernel)) + log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path) + + @property + def kernels(self): + """ + Read results from futures. + + This should be called after parallel compilation is done. + If you call this before compilation is done, + it may slow down the parallel compilation. + """ + for i, kernel in enumerate(self._kernels): + if isinstance(kernel, TritonFuture): + self._kernels[i] = kernel.result() + + return self._kernels + + def benchmark_sub_kernels(self, *args, **kwargs): + """ + Benchmark all the sub-kernels and return the execution time + (in milliseconds) for each of them. + + Unit tests may mock this method to force a specific kernel to + be picked. 
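+ + Arguments are cloned before each run (see wrap_fn below) so in-place kernels do not corrupt the inputs of later benchmark runs. +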
+        """
+
+        def wrap_fn(kernel):
+            def inner():
+                args_clone, kwargs_clone = kernel.clone_args(*args, **kwargs)
+                return kernel.run(*args_clone, **kwargs_clone)
+
+            return inner
+
+        return [
+            benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40, fast_flush=True)
+            for kernel in self.kernels
+        ]
+
+    # record_choice and lookup_choice are helper functions for cpp-wrapper
+    # codegen. The first pass uses record_choice to keep the choice, and
+    # the second pass looks it up by calling lookup_choice.
+    #
+    # An alternative that reuses the multi-kernel cache does not work well,
+    # since during codegen of the second pass it's very hard to know the
+    # path of the cache file. Also, reading the cache file requires IO,
+    # which can be slower.
+    @staticmethod
+    def record_choice(multi_kernel_name, choice):
+        """
+        Record the multi-kernel choice during cpp-wrapper first-pass codegen
+        so the second pass can look it up.
+
+        This function does nothing if it is called outside of codegen.
+        """
+        from torch._inductor.graph import GraphLowering
+
+        if not isinstance(V.graph, GraphLowering):
+            return
+
+        if not V.graph.record_multi_kernel_choice:
+            return
+
+        V.graph.multi_kernel_to_choice[multi_kernel_name] = choice
+
+    @staticmethod
+    def lookup_choice(multi_kernel_name):
+        # this should always be done during cpp-wrapper codegen
+        assert V.graph.record_multi_kernel_choice
+        # there should be no miss
+        return V.graph.multi_kernel_to_choice[multi_kernel_name]
+
+    def run(self, *args, **kwargs):
+        if self.picked_kernel is None:
+            timings = self.benchmark_sub_kernels(*args, **kwargs)
+            self.picked_kernel = timings.index(min(timings))
+            k0 = self.kernels[0]
+            log.debug(
+                "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s",
+                self.picked_kernel,
+                [k.inductor_meta.get("kernel_name") for k in self.kernels],
+                k0.size_hints,
+                k0.inductor_meta.get("reduction_hint"),
+                timings,
+            )
+
+            def get_kernel_path(k):
+                return k.fn.fn.__code__.co_filename
+
+            get_metric_table("persistent_red_perf").add_row(
+                lambda: {
+                    "kernel1_name": get_kernel_path(self.kernels[0]),
+                    "kernel2_name": get_kernel_path(self.kernels[1]),
+                    "kernel1_latency": timings[0],
+                    "kernel2_latency": timings[1],
+                    "size_hints": k0.size_hints,
+                    "reduction_hint": k0.inductor_meta.get("reduction_hint"),
+                    "speedup": timings[1] / timings[0],
+                }
+            )
+
+            if not self.disable_cache:
+                self.store_cache()
+
+        if not self._recorded:
+            self._recorded = True
+            self.record_choice(self.multi_kernel_name, self.picked_kernel)
+        self.run = self.kernels[self.picked_kernel].run  # type: ignore[method-assign]
+        self.run(*args, **kwargs)
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b6be5a7d79fba874768246ffe246b3d9f44037f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-311.pyc
b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04efd11fb2f04c9e4a1772808650a78e34348d11 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd3da7ad6d915e0075501e6fa61b616edf906050 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a78a86fecf631c625872de6c81b3e35a2be8d99 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21ddd64d5016d573de2c428e01b92ac925f9918b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8b99cd53859b18eed3a7351122790cbae4aa7ad Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66314387ee03b6b5a5205ce5949141b8eca4fec6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b7a8dddb6b794f0133c1a5cbaacf24c19915fc9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-311.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..e43d105a72e6309bdb17eab4d71d3b9236717509 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/ck_template.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/ck_template.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1ce40fc1372aa00ddeaad6be618f4e6afba3a7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/ck_template.py @@ -0,0 +1,93 @@ +import torch +from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate +from torch._inductor.ir import IRNode +from torch._inductor.utils import IndentedBuffer + + +class CKTemplate(ROCmTemplate): + """ + Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic + """ + + _TORCH_DTYPE_TO_CK = { + torch.float32: "F32", + torch.float64: "F64", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int32: "I32", + torch.int8: "I8", + torch.float8_e4m3fnuz: "F8", + torch.float8_e5m2fnuz: "BF8", + } + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // HIP headers + + #include + + // CK headers + + #ifdef DEBUG_LOG + #define DEBUG_LOG_TMP DEBUG_LOG + #undef DEBUG_LOG + #else + #define DEBUG_LOG_TMP 0 + #endif + #include "ck/ck.hpp" + #undef DEBUG_LOG + #define DEBUG_LOG DEBUG_LOG_TMP + + #include "ck/utility/data_type.hpp" + #include "ck/library/utility/check_err.hpp" + #include "ck/library/utility/device_memory.hpp" + #include "ck/library/utility/fill.hpp" + #include "ck/library/utility/host_tensor.hpp" + #include "ck/library/utility/host_tensor_generator.hpp" + #include "ck/library/utility/literals.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK globals + + template + using S = ck::Sequence; + + template + using Tuple = ck::Tuple; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Bilinear = ck::tensor_operation::element_wise::Bilinear; + + // see "composable_kernel/include/ck/utility/data_type.hpp" + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + using F16 = ck::half_t; + using F32 = float; + // using F64 = double; + using BF16 = ck::bhalf_t; + // using I32 = int32_t; + // using I8 = int8_t; + // using I4 = ck::int4_t; + + #if DEBUG_LOG + static constexpr auto kDEBUG_LOG = 1; + #else + static constexpr auto kDEBUG_LOG = 0; + #endif + """ + ) + return res + + def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..d247103a9d401937c76360ce9d744e7033cd393c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -0,0 +1,426 @@ +# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type" +import copy +import logging +import random +from typing import List, Optional + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from 
torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.ir import Buffer, Layout + +from ...utils import IndentedBuffer, try_import_ck_lib + + +_, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib() + + +log = logging.getLogger(__name__) + + +def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + +def torch_layout_to_ck_layout(torch_layout): + if torch_layout.stride[-1] == 1: + return "Row" + elif torch_layout.stride[-2] == 1: + return "Col" + else: + return None + + +class CKGemmTemplate(CKTemplate): + # the JINJA template for rendering CK Universal GEMMs + gemm_template = r"""{{version_comment}} + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + {{kernel_definition}} { + auto gemm = {{instance_type}} {}; + auto invoker = gemm.MakeInvoker(); + + auto argument = gemm.MakeArgument( + reinterpret_cast(X), + reinterpret_cast(W), + std::array{ {{'Bias' if has_bias else ''}} }, + reinterpret_cast<{{c_element_dtype}}*>(Y), + M, + N, + K, + LDA, + LDB, + std::array{ {{'LDD' if has_bias else ''}} }, + LDC, + 1, // kBatch + PassThrough {}, // a_elementwise_op + PassThrough {}, // b_elementwise_op + {{epilogue}} // c_elementwise_op + ); + if (!gemm.IsSupportedArgument(argument)) { + // we do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for gemm instance " << gemm.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = gemm.GetWorkSpaceSize(&argument); + return 0; + } + // run the kernel + float elapsed_time = invoker.Run(argument, StreamConfig{stream, /* time kernel */ false, /* log level */ kDEBUG_LOG}); + return 0; + } // kernel definition + } // extern C + """ + + def __init__( + self, + input_nodes: List[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[List[int]] = None, + ) -> None: + super().__init__( + "ck_gemm_template", + input_nodes=input_nodes, + layout=layout, + input_reorder=input_reorder, + ) + self.alpha = alpha + self.beta = beta + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK GEMM header(s) + + #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK GEMM globals + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + """ + ) + return res + + def filter_op(self, op: "CKGemmOperation"): + """ + Determines whether a given op definition is suitable for the current + input / output of the operation that this template implements. + + Filter is based on inputs' dtype, layout and statically inferred size. + + Returns None if the op is not suitable, otherwise returns the op to be used. 
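+
+        Concretely, the checks below compare the instance's dtypes and layouts
+        against the graph's input and output nodes, and reject instances whose
+        tile sizes or vector access widths do not evenly divide the statically
+        known M/N/K extents (unless the instance's GemmSpecialization pads the
+        corresponding dimension).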
+ """ + metas = [T.get_layout() for T in [*self.input_nodes, self.output_node]] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.c_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_layout(W_meta): + return None + if op.c_layout != torch_layout_to_ck_layout(Y_meta): + return None + # try to avoid launching the instance with invalid problem size + # see GridwiseGemm_xdl_cshuffle_v3::CheckValidity + + M = X_meta.size[-2] + K = X_meta.size[-1] + N = W_meta.size[-1] + + if is_static_int(M): + if not any( + m_padding in op.gemm_specialization + for m_padding in ["MPadding", "MNPadding", "MKPadding", "MNKPadding"] + ): + if M % op.m_per_block != 0: + return None + if is_static_int(N): + if not any( + n_padding in op.gemm_specialization + for n_padding in ["NPadding", "MNPadding", "NKPadding", "MNKPadding"] + ): + if N % op.n_per_block != 0: + return None + if is_static_int(K): + if not any( + k_padding in op.gemm_specialization + for k_padding in ["KPadding", "MKPadding", "NKPadding", "MNKPadding"] + ): + if K % op.k_per_block != 0: + return None + + a_contig_size = ( + K if op.a_layout == "Row" else M if op.a_layout == "Col" else None + ) + if ( + is_static_int(a_contig_size) + and a_contig_size % op.a_block_transfer_src_scalar_per_vector != 0 + ): + return None + b_contig_size = ( + N if op.b_layout == "Row" else K if op.b_layout == "Col" else None + ) + if ( + is_static_int(b_contig_size) + and b_contig_size % op.b_block_transfer_src_scalar_per_vector != 0 + ): + return None + c_contig_size = ( + N if op.c_layout == "Row" else M if op.c_layout == "Col" else None + ) + if ( + is_static_int(c_contig_size) + and c_contig_size + % op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + != 0 + ): + return None + + # TBD disable instances with invalid number of pipeline prefetch stages + # It will avoid compiling a small percentage of unrunnable instances which fail the gemm argument check + + return op + + def emit_ck_instance(self, op: "CKGemmOperation"): + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + template_params.append(arg) + else: + if field_value is not None: + template_params.append(f"/* {field_name} */ {field_value}") + return self._template_from_string(template_definition).render( + operation_name=op.name(), + template_params=(",\n" + 12 * " ").join(template_params), + ), 
self._template_from_string(template_type).render(operation_name=op.name()) + + def render(self, kernel: ROCmTemplateKernel, op: "CKGemmOperation", **kwargs) -> str: # type: ignore[override] + """ + The primary entry point for the code rendering process used in this template. + """ + epilogue_nodes = kwargs.get("epilogue_nodes", None) + assert epilogue_nodes is None or 0 == len(epilogue_nodes) + template_buffer_node = kwargs.get("template_buffer_node", None) + if template_buffer_node is not None: + self.output_node = template_buffer_node + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = self.input_nodes[2] if 3 == len(self.input_nodes) else None + + op = copy.deepcopy(op) + + # This parameter is converted into tuple because of change + # from DeviceGemm_Xdl_CShuffleV3 to DeviceGemmMultiD_Xdl_CShuffle_V3. + # The first tuple element corresponds to matmul result... + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, + ) + + if Bias is not None: + op.ds_layouts = (torch_layout_to_ck_layout(Bias.get_layout()),) + op.ds_element_dtypes = ((self._TORCH_DTYPE_TO_CK[Bias.get_layout().dtype]),) + op.c_elementwise_op = "Bilinear" + # c_shuffle_dtype is also used for adding bias to matmul result + # before converting down to the result dtype + op.c_shuffle_dtype = op.acc_dtype + # this parameter needs to be set accordingly to bias stride for correct accumulation + if op.ds_layouts[0] == "Row": + # bias has (N, ) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + else: + # bias has (M, 1) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = (1,) + # ...and the second tuple element corresponds to the bias + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += ( + bias_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + + instance_definition, instance_type = self.emit_ck_instance(op) + + version_comment = rf"""/** +* Generated code for CK inductor backend +* See {type(self).__module__}.{type(self).__qualname__} +* +* Template instance {op} +* +* {torch.__version__=} +* {torch.version.git_version=} +*/ +""" + + return self._template_from_string(self.gemm_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + kernel_definition=kernel.def_kernel( + inputs=[X, W, Bias], # type: ignore[list-item] + outputs=[Y], + names_str="X, W, Bias, Y", + input_reorder=self.input_reorder, + size_args=[ + f"ck::index_t {arg}" + for arg in ["M", "N", "K", "LDA", "LDB", "LDC", "LDD"] + ], + ), + instance_type=instance_type, + a_element_dtype=op.a_element_dtype, + b_element_dtype=op.b_element_dtype, + c_element_dtype=op.c_element_dtype, + bias_element_dtype=op.ds_element_dtypes[0] if Bias is not None else "", + alpha=self.alpha, + beta=self.beta, + epilogue=f"Bilinear {{ {self.alpha}, {self.beta} }}" + if Bias is not None + else "PassThrough {}", + has_bias=Bias is not None, + version_comment=version_comment, + ) + + def _is_rcr_f16(self): + X_meta, W_meta, Y_meta = ( + T.get_layout() for T in [*self.input_nodes, self.output_node] + ) + X_dtype, W_dtype, Y_dtype = ( + self._TORCH_DTYPE_TO_CK[m.dtype] for m in (X_meta, W_meta, Y_meta) + ) + X_layout, W_layout, Y_layout = ( + torch_layout_to_ck_layout(m) for m in (X_meta, W_meta, Y_meta) + ) + + return ( + X_dtype == "F16" + and W_dtype == "F16" + and Y_dtype == "F16" + and X_layout 
== "Row" + and W_layout == "Col" + and Y_layout == "Row" + ) + + def gen_ops(self): + """ + Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents. + The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments. + + An instance may invalidate the GEMM configuration at runtime. + Such instances will be assigned +inf runtime by the autotune process. + """ + unfiltered_instances = ( + gen_ops_preselected() + if config.rocm.use_preselected_instances and self._is_rcr_f16() + else gen_ops_library() + ) + filtered_instances = list( + filter(lambda op: self.filter_op(op), unfiltered_instances) + ) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.n_max_profiling_configs), + ) + if config.rocm.n_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + @staticmethod + def add_ck_gemm_choices( + choices, + layout, + input_nodes, + alpha=1, + beta=0, + input_reorder=None, + ): + """ + Add Composable Kernel Universal GEMM instance choices to the auto-tuning list. + """ + template = CKGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op, + ) + + def size_args(self): + X = self.input_nodes[0] + W = self.input_nodes[1] + Bias = self.input_nodes[2] if len(self.input_nodes) > 2 else None + Y = self.output_node + + M = X.get_size()[0] + K = X.get_size()[1] + N = W.get_size()[1] + LDA = X.get_stride()[0 if X.get_stride()[1] == 1 else 1] + LDB = W.get_stride()[0 if W.get_stride()[1] == 1 else 1] + LDC = Y.get_stride()[0 if Y.get_stride()[1] == 1 else 1] + LDD = ( + 0 + if Bias is None + else Bias.get_stride()[0 if Bias.get_stride()[1] == 1 else 1] + ) + + return M, N, K, LDA, LDB, LDC, LDD diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/compile_command.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/compile_command.py new file mode 100644 index 0000000000000000000000000000000000000000..dddb0c56d27f52c3074caff383bd5dd6676cf46a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/compile_command.py @@ -0,0 +1,127 @@ +# mypy: allow-untyped-defs +import logging +import os +from typing import List, Optional + +from torch._inductor import config +from torch._inductor.utils import is_linux + + +log = logging.getLogger(__name__) + + +def _rocm_include_paths() -> List[str]: + from torch.utils import cpp_extension + + rocm_include = ( + os.path.join(config.rocm.rocm_home, "include") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("include") + ) + if not config.rocm.ck_dir: + log.warning("Unspecified Composable Kernel include dir") + ck_include = os.path.join( + config.rocm.ck_dir or cpp_extension._join_rocm_home("composable_kernel"), + "include", + ) + return [os.path.realpath(rocm_include), os.path.realpath(ck_include)] + + +def _rocm_lib_options() -> List[str]: + from torch.utils import cpp_extension + + rocm_lib_dir = ( + os.path.join(config.rocm.rocm_home, "lib") + if config.rocm.rocm_home + else 
cpp_extension._join_rocm_home("lib") + ) + hip_lib_dir = ( + os.path.join(config.rocm.rocm_home, "hip", "lib") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("hip", "lib") + ) + + return [ + f"-L{os.path.realpath(rocm_lib_dir)}", + f"-L{os.path.realpath(hip_lib_dir)}", + "-lamdhip64", + ] + + +def _rocm_compiler_options() -> List[str]: + arch_list = config.rocm.arch or ["native"] + gpu_arch_flags = [f"--offload-arch={arch}" for arch in arch_list] + opts = [ + config.rocm.compile_opt_level, + "-x", + "hip", + "-std=c++17", + *gpu_arch_flags, + "-fno-gpu-rdc", + "-fPIC", + "-mllvm", + "-amdgpu-early-inline-all=true", + "-mllvm", + "-amdgpu-function-calls=false", + "-mllvm", + "-enable-post-misched=0", + ] + if config.rocm.is_debug: + opts += ["-DDEBUG_LOG=1", "-g"] + if config.rocm.save_temps: + opts += ["--save-temps=obj"] + if config.rocm.print_kernel_resource_usage: + opts += ["-Rpass-analysis=kernel-resource-usage"] + if config.rocm.flush_denormals: + opts += ["-fgpu-flush-denormals-to-zero"] + if config.rocm.use_fast_math: + opts += ["-ffast-math"] + return opts + + +def rocm_compiler() -> Optional[str]: + if is_linux(): + if config.rocm.rocm_home: + return os.path.realpath( + os.path.join(config.rocm.rocm_home, "llvm", "bin", "clang") + ) + try: + from torch.utils import cpp_extension + + return os.path.realpath( + cpp_extension._join_rocm_home("llvm", "bin", "clang") + ) + except OSError: + # neither config.rocm.rocm_home nor env variable ROCM_HOME are set + return "clang" + return None + + +def rocm_compile_command( + src_files: List[str], + dst_file: str, + dst_file_ext: str, + extra_args: Optional[List[str]] = None, +) -> str: + include_paths = _rocm_include_paths() + lib_options = _rocm_lib_options() + compiler_options = _rocm_compiler_options() + compiler = rocm_compiler() + options = ( + compiler_options + + (extra_args if extra_args else []) + + ["-I" + path for path in include_paths] + + lib_options + ) + src_file = " ".join(src_files) + res = "" + if dst_file_ext == "o": + res = f"{compiler} {' '.join(options)} -c -o {dst_file} {src_file}" + elif dst_file_ext == "so": + options.append("-shared") + res = f"{compiler} {' '.join(options)} -o {dst_file} {src_file}" + elif dst_file_ext == "exe": + res = f"{compiler} {' '.join(options)} -o {dst_file} {src_file}" + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + return res diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py new file mode 100644 index 0000000000000000000000000000000000000000..a70f45b7033d6d9778875696ccb4140d60e09307 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -0,0 +1,133 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +from ctypes import byref, c_int, c_size_t, c_void_p +from typing import Any, Callable, Iterable, List, Optional, Union + +import torch +from torch._inductor.autotune_process import GPUDeviceBenchmarkRequest, TensorMeta +from torch._inductor.codecache import DLLWrapper, ROCmCodeCache + + +log = logging.getLogger(__name__) + + +class ROCmBenchmarkRequest(GPUDeviceBenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! 
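+    #
+    # Autotuning may compile and benchmark candidates in worker subprocesses,
+    # so a request only carries picklable state: TensorMeta descriptions of
+    # the inputs/outputs, the generated source code, and scalar extra args.
+    # Real tensors are created in the worker right before benchmarking.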
+ + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = ROCmCodeCache.write(self.source_code, "so") + + def precompile(self): + # Prepopulate code cache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + ROCmCodeCache.compile(self.source_code, "so") + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + self.ensure_dll_loaded() + self.update_workspace_size() + args = [ + c_void_p(tensor.data_ptr()) + for tensor in list(input_tensors) + [output_tensor] + ] + size_args = [c_int(arg) for arg in self.extra_args] + log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=output_tensor.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *size_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len({meta.name for meta in self.input_tensor_meta}) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. 
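+        # The generated kernel treats a non-NULL workspace_size pointer as a
+        # pure size query: it writes the required byte count and returns
+        # without launching, which is why NULL data pointers are safe here.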
+ c_workspace_size = c_size_t() + size_args = [c_int(arg) for arg in self.extra_args] + run_method( + *args, # input ptrs and output ptrs + *size_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.cuda.synchronize() # shake out any CUDA errors + self.workspace_size = c_workspace_size.value + log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = ROCmCodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..e02c17edd04d25543c9a5c0a1b1414d6cd419cd2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -0,0 +1,99 @@ +# mypy: allow-untyped-defs +import logging +from typing import cast, Sequence + +from ... import config +from ...codecache import code_hash, get_path +from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import IndentedBuffer +from .rocm_template_buffer import ROCmTemplateBuffer + + +log = logging.getLogger(__name__) + + +class ROCmCPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for ROCm C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and ROCm C++ specific template code generation. 
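+
+    Only nodes wrapping a ROCmTemplateBuffer are handled here; everything
+    else, including vertical fusion into these templates (see
+    can_fuse_vertical below), is left to the other scheduling backends.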
+ """ + + def __init__(self, scheduler: Scheduler) -> None: + super().__init__() + self.scheduler = scheduler + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_rocm_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ROCmTemplateBuffer + ) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["rocm", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.rocm(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''', 'so')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a ROCm template, possibly with fused epilogues + """ + assert self.is_rocm_cpp_template( + template_node + ), "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: ROCmTemplateBuffer = cast(ROCmTemplateBuffer, template_node.node) + kernel, render = ctb.make_kernel_render(ctb) + with kernel: + template_node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code, node_schedule) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..9029dbe644a5ba641b9808aaed84b8896d8e4b2b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -0,0 +1,238 @@ +# mypy: allow-untyped-defs +import logging +from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union + +from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, TensorBox +from ...virtualized import V +from ..common import Kernel, OpOverrides +from ..cpp_utils import CppPrinter +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_template_buffer import ROCmTemplateBuffer + + +if TYPE_CHECKING: + from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, 
total_length: int) -> int: + return index if index >= 0 else index + total_length + + +class ROCmKernel(Kernel): + """ + Baseclass for ROCm based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + +class ROCmTemplateKernel(ROCmKernel): + """ + Template kernels defined by ROCm in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, hipStream_t stream" + + def __init__(self, kernel_name) -> None: + """ + Initializes a new instance of the ROCmTemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + # Mapping from arg name to IRNode. + self.named_nodes: Dict[str, IRNode] = {} + + def arg_name(self, node: IRNode) -> Optional[str]: + """ + Returns arg name of a given input or output node. + """ + if node is None: + return None + return {**self.args.input_buffers, **self.args.output_buffers}.get( + node.get_name(), None + ) + + def def_kernel( + self, + inputs: List[IRNode], + outputs: List[IRNode], + size_args: List[str], + names_str: str = "", + input_reorder: Optional[List[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + """ + + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs() + + return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {', '.join(size_args)}, {self._EXTRA_CPP_ARGS})" + + def call_kernel( + self, + name: str, + node: "ROCmTemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.WrapperCodeGen + + name: Name of kernel function. + node: The ROCmTemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + kernel_args = [] + for arg in call_args: + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + if V.graph.is_unspec_arg(arg): + arg = arg + ".item()" + else: + arg = f"c_void_p({arg}.data_ptr())" + kernel_args.append(arg) + + # add size args + kernel_args.extend( + [ + f"c_int({V.graph.sizevars.simplify(sarg)})" + for sarg in node.template.size_args() + ] + ) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. 
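+        # (The size query itself happens earlier, during autotuning, via
+        # ROCmBenchmarkRequest.update_workspace_size.)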
+ # workspace_size should have already been retrieved prior to this call. + kernel_args.append("None") + + if node.get_workspace_size() > 0: + wrapper.generate_workspace_allocation( + node.get_workspace_size(), V.graph.scheduler.current_device, False + ) + kernel_args.append("c_void_p(workspace.data_ptr())") + else: + kernel_args.append("None") + + current_device = V.graph.scheduler.get_current_device_or_throw() + wrapper.generate_kernel_call( + name, + kernel_args, + device_index=current_device.index, + cuda=True, + triton=False, + arg_types=arg_types, + ) + if node.get_workspace_size() > 0: + wrapper.writeline(wrapper.make_free_by_names(["workspace"])) + + +class ROCmTemplateCaller(ChoiceCaller): + """ + ROCmTemplateCaller + + This class represents a caller for ROCm template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (ROCmBenchmarkRequest): The benchmark request for the caller. + template_buffer (ROCmTemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: List[Buffer], + layout: Layout, + make_kernel_render: Callable[[ROCmTemplateBuffer, Optional[List[IRNode]]], str], + bmreq: ROCmBenchmarkRequest, + template: "ROCmTemplate", # type: ignore[name-defined] + info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg] + ) -> None: + super().__init__(name, input_nodes, layout) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + def __str__(self) -> str: + return f"ROCmTemplateCaller(source_file={self.bmreq.source_file}, {self.info_dict()})" + + def call_name(self) -> str: + return f"rocm_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "ROCm", + "name": self.name, + **dict(self.info_kwargs["op"].dict_items()), # type: ignore[union-attr, index] + } + + def output_node(self) -> TensorBox: + self.bmreq.update_workspace_size() + return TensorBox.create( + ROCmTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_template.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..bd6957c17702cb64097ea8634c1f5101206a6abd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_template.py @@ -0,0 +1,168 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from typing import List, Optional +from unittest.mock import patch + +from ...autotune_process import TensorMeta +from ...ir import Buffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common 
import KernelTemplate +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_kernel import ROCmTemplateCaller, ROCmTemplateKernel +from .rocm_template_buffer import ROCmTemplateBuffer + + +log = logging.getLogger(__name__) + + +class ROCmTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes: List[Buffer], + layout: Layout, + input_reorder: Optional[List[int]] = None, + ) -> None: + """ + + Baseclass for ROCm C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the ROCmTemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer("buf_out", layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate( # type: ignore[override] + self, + **kwargs, + ) -> ROCmTemplateCaller: + """ + Generates the ROCm template caller object for the given GEMM template and operation. This ROCmTemplateCaller + may be used to call and benchmark the generated ROCm kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A ROCmTemplateCaller object representing the generated ROCm template caller. + """ + kernel_name = f"rocm_{self.name}" + kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), ROCmTemplateKernel( + kernel_name=kernel_name, + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + + size_args = ( + self.size_args() if hasattr(self, "size_args") else () + ) # subclass should define def size_args() + size_args_ints = [ + V.graph.sizevars.size_hint(arg) for arg in size_args + ] # resolve to ints for benchmarking + bmreq = ROCmBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=size_args_ints, + source_code=code, + ) + + def make_kernel_render( + template_node: ROCmTemplateBuffer, + epilogue_nodes: Optional[List[IRNode]] = None, + ): + kernel = ROCmTemplateKernel( + kernel_name="KERNEL_NAME", + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return ROCmTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + kwargs, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + 
""" + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + using bfloat16 = hip_bfloat16; + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..105a6224c005bda6a0a65924c3d92aa63d457aa3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs +from ...ir import TemplateBuffer + + +class ROCmTemplateBuffer(TemplateBuffer): + def __init__( + self, + layout, + inputs, + make_kernel_render, + workspace_size: int, + template: "ROCmTemplate", # type: ignore[name-defined] # noqa: F821 + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self): + return self.workspace_size if self.workspace_size is not None else 0 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/simd.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/simd.py new file mode 100644 index 0000000000000000000000000000000000000000..59c126b993e3d4ddf95de04b9bc6de6eaf9b4d25 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/simd.py @@ -0,0 +1,1869 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +from typing import ( + Any, + Callable, + Counter, + DefaultDict, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, +) + +import sympy + +import torch +import torch._logging +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT + +from ..._dynamo.utils import counters +from .. 
import config, ir, scheduler +from ..codecache import code_hash +from ..dependencies import Dep, MemoryDep, StarDep, WeakDep +from ..ir import IRNode, TritonTemplateBuffer +from ..optimize_indexing import indexing_dtype_strength_reduction +from ..runtime.hints import ReductionHint +from ..runtime.runtime_utils import green_text, yellow_text +from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse +from ..utils import ( + get_dtype_size, + IndentedBuffer, + Placeholder, + sympy_index_symbol, + sympy_product, + sympy_subs, + unique, +) +from ..virtualized import ops, OpsWrapper, V +from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter +from .multi_kernel import MultiKernel + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +pexpr = PythonPrinter().doprint + + +@dataclasses.dataclass +class IterationRanges: + """ + Each range tree represents multiple sets of iteration indexing + in a single tiled dimension in the output kernel. + + If you have two loops ranges one (4, 3, 2) and another (4, 6), + then the range tree will be: + 4 (i0) + 3 (i1) 6 (i3) + 2 (i2) + Where i0 is shared between both loops, but then the split into + different indexing vars. All loop ranges must iterate over + the same number of elements. + """ + + def __init__( + self, + name: str, + var_list: List[sympy.Symbol], + var_ranges: Dict[sympy.Symbol, sympy.Expr], + numel: sympy.Expr, + prefix: str, + *, + kernel: SIMDKernel, + divisor=sympy.Integer(1), + length=sympy.Integer(1), + root: IterationRangesRoot, + ) -> None: + super().__init__() + self.name = name + self.var_list = var_list + self.var_ranges = var_ranges + self.numel = numel + self.prefix = prefix + self.divisor = divisor + self.length = length + self.kernel = kernel + self.root = root + + def symbol(self): + return sympy_index_symbol(self.name) + + +class IterationRangesRoot(IterationRanges): + def __init__( + self, + name: str, + numel: sympy.Expr, + # TODO: this is probably SymTy.INDEX and SymTy.RINDEX + prefix: str, + index: int, + kernel: SIMDKernel, + pid_cache=None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + has_zdim: bool, + ) -> None: + if pid_cache is None: + pid_cache = {} + super().__init__( + name=name, + var_list=[], + var_ranges={}, + numel=numel, + prefix=prefix, + kernel=kernel, + root=self, + ) + self.index = index + # Store all the nodes in one flat list + self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {} + # This is for re-ordering program ID in triton mm template + # pid_cache["tl.program_id(0)"] = pid_m + self.pid_cache: Dict[str, str] = pid_cache + + # True if the dimension is implemented as a single program looping over + # the full dimension (currently only used for non-persistent reduction) + assert not is_loop or (prefix == "r" and grid_dim is None) + self.is_loop = is_loop + # Index of corresponding dimension on triton tensors + self.tensor_dim = tensor_dim + # Index of corresponding dimension in the triton grid + self.grid_dim = grid_dim + self.has_zdim = has_zdim + + def __repr__(self) -> str: + return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" + + def cache_clear(self): + for node in self.nodes.values(): + node.cache_clear() + + def index_sym(self): + return sympy_index_symbol(f"{self.prefix}index") + + def lookup(self, divisor, 
length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(self.index_sym(), divisor) + else: + expr = ModularIndexing(self.index_sym(), divisor, length) + + if expr not in self.nodes: + node = IterationRangesEntry( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + return self.nodes[expr] + + def construct_entries(self, lengths: List[sympy.Expr]): + divisor = sympy.Integer(1) + itervars = [] + for length in reversed(lengths): + itervars.append(self.lookup(divisor, length)) + divisor = divisor * length + return list(reversed(itervars)) + + def construct(self, lengths: List[sympy.Expr]): + return [e.symbol() for e in self.construct_entries(lengths)] + + def vars_and_sizes(self, index: sympy.Expr): + """Figure out vars from this tree used in index""" + nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] + nodes = [n for n in nodes if n and n.prefix == self.prefix] + nodes.sort( + key=lambda x: V.graph.sizevars.size_hint( + x.divisor, fallback=config.unbacked_symint_fallback + ) + ) + divisor = sympy.Integer(1) + index_vars = [] + sizes = [] + + def add(node): + nonlocal divisor + index_vars.append(node.symbol()) + sizes.append(node.length) + divisor = divisor * node.length + + for node in nodes: + if not V.graph.sizevars.statically_known_equals(node.divisor, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(node.divisor, divisor))) + divisor = node.divisor + add(node) + if not V.graph.sizevars.statically_known_equals(self.numel, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(self.numel, divisor))) + + return list(reversed(index_vars)), list(reversed(sizes)) + + +class IterationRangesEntry(IterationRanges): + def __init__( + self, + name: str, + divisor: sympy.Expr, + length: sympy.Expr, + expr: sympy.Expr, + parent: IterationRanges, + ) -> None: + super().__init__( + name=name, + numel=parent.numel / length, + var_list=parent.var_list, + var_ranges=parent.var_ranges, + prefix=parent.prefix, + divisor=divisor, + length=length, + kernel=parent.kernel, + root=parent.root, + ) + self.parent = parent + self.codegen = functools.lru_cache(None)(self._codegen) + self.expr = expr + + def __repr__(self) -> str: + return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" + + def set_name(self, name): + self.codegen = lambda: name # type: ignore[assignment] + self.codegen.cache_clear = lambda: None # type: ignore[method-assign] + self.name = name + + def cache_clear(self): + self.codegen.cache_clear() + + def _codegen(self): + V.kernel.codegen_iteration_ranges_entry(self) + return self.name + + def precomputed_args(self): + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: List[sympy.Expr] = [] + if isinstance(self.expr, sympy.Symbol): + return precomputed_args + assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = arg.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, SymT.SIZE) for s in symbols + ): + precomputed_args.append(arg) + return precomputed_args + + def __hash__(self): + return 
hash(self.name) + + def __eq__(self, other): + return self.name == other.name + + +def constant_repr(value): + if value == float("inf"): + return 'float("inf")' + elif value == float("-inf"): + return 'float("-inf")' + elif math.isnan(value): + return 'float("nan")' + return repr(value) + + +class SIMDKernel(Kernel): + """ + Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests. + """ + + sexpr = pexpr + kexpr: Callable[[sympy.Expr], str] + allow_block_ptr = False + + def __init__( + self, + *groups, + index_dtype: str, + mutations: Optional[OrderedSet[str]] = None, + pid_cache=None, + reduction_hint=ReductionHint.DEFAULT, + override_persistent_reduction=None, + ) -> None: + if pid_cache is None: + pid_cache = {} + super().__init__() + self.body = IndentedBuffer() + self.indexing_code = IndentedBuffer() + self.numels = [V.graph.sizevars.simplify(s) for s in groups] + self.mutations: OrderedSet[str] = ( + mutations if mutations is not None else OrderedSet() + ) + self.range_trees: List[IterationRangesRoot] = [] + self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {} + self.iter_vars_count = itertools.count() + self.inside_reduction = self.numels[-1] != 1 + self.reduction_hint = reduction_hint + self.index_dtype: str = index_dtype + self.last_usage: OrderedSet[str] = OrderedSet() + self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list) + self.persistent_reduction: bool = ( + override_persistent_reduction + if override_persistent_reduction is not None + else self.should_use_persistent_reduction() + ) + self.no_x_dim = self.want_no_x_dim() + self.code_hash: Union[str, None] = None + + # define this in a closure to make cache local to object + @functools.lru_cache(None) + def simplify_indexing(index: sympy.Expr): + index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) + for tree in self.range_trees: + index = self.combine_contiguous_dims(index, tree) + + return self.combine_modular_indexing_pairs(index) + + self.simplify_indexing = simplify_indexing + self.initialize_range_tree(pid_cache) + + def want_no_x_dim(self): + return False + + def initialize_range_tree(self, pid_cache): + no_r_dim = not self.inside_reduction or self.numels[-1] == 1 + + prefixes = "zyxr" + active_prefixes = prefixes[-len(self.numels) :] + + grid_dims = "xyz" + if self.no_x_dim: + tensor_dims = "r" + elif no_r_dim: + tensor_dims = "xyz" + else: + tensor_dims = "xyzr" + + tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) + + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix == "r" + tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None + grid_dim = None if is_reduction else grid_dims.find(prefix) + index = i if grid_dim is None else grid_dim + self.range_trees.append( + IterationRangesRoot( + f"{prefix}index", + self.numels[i], + prefix, + index, + self, + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim="z" in active_prefixes, + ) + ) + + def finalize_indexing(self, indices: Sequence[sympy.Expr]): + """ + Hook called right before codegen with every index that will be + used in the fused kernel. 
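+
+        The base implementation is a no-op; subclasses may override this to
+        precompute or hoist shared index expressions before the kernel body
+        is emitted.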
+ """ + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + prior = self.inside_reduction + self.inside_reduction = False + try: + return self.store(name, index, value) + finally: + self.inside_reduction = prior + + def should_use_persistent_reduction(self) -> bool: + return False # defined in subclass + + def var_ranges(self): + return dict( + itertools.chain.from_iterable( + tree.var_ranges.items() for tree in self.range_trees + ) + ) + + def triton_tensor_ndim(self): + return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) + + def indexing_size_str(self, i): + sizes = ["None"] * self.triton_tensor_ndim() + sizes[i] = ":" + return f"[{', '.join(sizes)}]" + + def dense_size_list(self) -> List[str]: + sizes = ["1"] * self.triton_tensor_ndim() + for tree in self.range_trees: + if tree.tensor_dim is None: + continue + + if tree.prefix != "r" or self.inside_reduction: + sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" + return sizes + + def dense_size_str(self): + sizes = self.dense_size_list() + return f"[{', '.join(sizes)}]" + + def combine_modular_indexing_pairs(self, index): + if not isinstance(index, ModularIndexing): + return index + x = index.args[0] + if (tree_node := self.range_tree_nodes.get(x)) is None: + return index + new_index = sympy_subs(index, {x: tree_node.expr}) + new_index = V.graph.sizevars.combine_modular_indexing_pairs(new_index) + # the index now contains xindex/etc, which is nonstandard, fix it up + return sympy_subs( + new_index, + { + tree_node.root.index_sym(): tree_node.root.lookup( + sympy.Integer(1), tree_node.root.numel + ).symbol() + }, + ) + + def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + if expand_res := V.graph.sizevars.expand_floor_div(index): + new_index, denominator = expand_res # type: ignore[misc] + return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) + else: + return self._combine_contiguous_dims(index, tree) + + def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + """ + More aggressive simplification to merge contiguous dims + """ + if isinstance(index, (sympy.Integer, sympy.Symbol)): + return index + index_vars, sizes = tree.vars_and_sizes(index) + if len(sizes) <= 1: + return index + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) + ) + if new_sizes == sizes: + return index + new_index_vars = tree.construct(new_sizes) + new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) + return new_index + + def set_last_usage(self, nodes): + if not self.inside_reduction or self.persistent_reduction: + return + self.last_usage = OrderedSet( + itertools.chain.from_iterable( + n.last_usage for n in nodes if n is not EnableReduction + ) + ) + + def disable_reduction(self): + should_flush = self.range_trees[-1].is_loop + + @contextlib.contextmanager + def ctx(): + if self.numels[-1] == 1: + assert not self.inside_reduction + yield + return + if should_flush: + # calling codegen_body() will flush all the pending buffers + # and write out a reduction loop + self.codegen_body() + self.inside_reduction = False + try: + yield + if should_flush: + # flush out any code before opening the next loop + self.codegen_body() + finally: + self.inside_reduction = True + + return ctx() + + def set_ranges(self, *lengths): + assert len(lengths) == len(self.range_trees) + return [ + ranges.construct(length) + for length, ranges in 
zip(lengths, self.range_trees) + ] + + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] + ): + sv = V.graph.sizevars + new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + var_count = itertools.count() + + def add_range(i, expr): + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined(size, idx1, idx2): + def getter(flat_vars): + return size * flat_vars[idx1] + flat_vars[idx2] + + return getter + + return_getters_groups = [] + current_group = 0 + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.Integer(0)) + continue + + while current_group < len(remaining) and sv.statically_known_equals( + remaining[current_group], 1 # type: ignore[arg-type] + ): + # scroll to next group with remaining elements + current_group += 1 + + if current_group + 1 < len(remaining) and sv.statically_known_gt( + size, remaining[current_group] + ): + # need to break size in two + if not sv.statically_known_multiple_of( + size, remaining[current_group] + ): + raise CantSplit + size1 = remaining[current_group] + size2 = FloorDiv(size, remaining[current_group]) + return_getters.append( + make_combined( + size2, + add_range(current_group, size1), + add_range(current_group + 1, size2), + ) + ) + else: + return_getters.append( + operator.itemgetter(add_range(current_group, size)) + ) + return_getters_groups.append(return_getters) + + assert all( + V.graph.sizevars.size_hint(s) == 1 for s in remaining + ), f"failed to set ranges {remaining} {lengths}" + + return new_ranges, return_getters_groups + + @classmethod + def is_compatible( + cls, groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] + ): + try: + cls._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]): + """ + We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). + + To do this we need to split up the iteration space of i0 into something like: + for i1 in s0: + for i2 in s1: + i0 = i1*s1 + i2 + .... + + This function matches and resplits lengths to the groups of + this kernel to enable tiled + non-tiled fusions. + """ + groups = [rt.numel for rt in self.range_trees] + if not self.inside_reduction: + groups[-1] = sympy.Integer(1) + + if len(lengths) == len(self.range_trees) and all( + V.graph.sizevars.simplify(sympy_product(x) - g) == 0 + for x, g in zip(lengths, groups) + ): + return self.set_ranges(*lengths) + + new_ranges, return_getters_groups = self._split_iteration_ranges( + groups, lengths + ) + itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges))) + return [[fn(itervars) for fn in fns] for fns in return_getters_groups] + + def is_indirect_indexing(self, index: sympy.Expr): + # tmpX means indirect indexing + return free_symbol_is_type(index, SymT.TMP) + + def is_broadcasted(self, index: sympy.Expr): + # Note. 
This may not be correct when there is indirect indexing + if self.is_indirect_indexing(index): + return False + + index_numels = [1] * len(self.numels) + for symbol in index.free_symbols: + if symbol not in self.range_tree_nodes: + # Non-iterated variables, e.g. strides + continue + entry = self.range_tree_nodes[symbol] # type: ignore[index] + assert isinstance(entry.parent, IterationRangesRoot) + index_numels[entry.parent.index] *= entry.length + + # If the index variables only iterate over a subset of the kernel + # numels, then it must be broadcasted. + simplify = V.graph.sizevars.simplify + return any( + simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] + for idx_range, iter_range in zip(index_numels, self.numels) + ) + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in output code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the generated kernel. + + Index expressions often need to be passed in as arguments to the triton kernel. + Rename_indexing and codegen_indexing keep track of the needed indices and add + new parameters to the function signature. + """ + if isinstance(index, list): + return f"[{', '.join(map(self.index_to_str, index))}]" + return self.kexpr(self.rename_indexing(index)) # type: ignore[call-arg] + + def prepare_indexing( + self, + index: sympy.Expr, + ): + index = self.simplify_indexing(index) + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + # last resort, if no range vars are in the expr, hoist it + # TODO instead of trying to blindly find complicated exprs, we should hoist the + # inputs/outputs sizes and strides, but at the time indexing is generated + # kernel inputs and outputs are not set yet, we'd need a deeper refactor + # to do it this way + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) + for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + simp_index = self.simplify_indexing(index) + + # Now that we are done simplifying we can unwrap Identity so that downstream handling + # for its contained expression will work. 
+        # Previously, tl.full wrapping of sympy.Integer would not occur.
+        simp_index = (
+            simp_index if not isinstance(simp_index, Identity) else simp_index.args[0]
+        )
+
+        return self.codegen_indexing(simp_index)
+
+    def active_range_trees(self, reorder=False):
+        trees = [
+            t for t in self.range_trees if t.prefix != "r" or self.inside_reduction
+        ]
+        if reorder and len(trees) > 1:
+            count = sum(t.prefix in "xyz" for t in trees)
+            assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [
+                t.prefix for t in trees[:count]
+            ]
+            trees[:count] = reversed(trees[:count])
+        return trees
+
+    def codegen_indexing(self, expr: sympy.Expr):
+        expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
+        for sym in sorted(expr.free_symbols, key=str):
+            if sym in self.range_tree_nodes:
+                # If an indexing expression is complicated, we precompute it on
+                # the host side and send the result as a kernel argument.
+                replacements = {}
+                for ps in self.range_tree_nodes[sym].precomputed_args():  # type: ignore[index]
+                    replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
+                if len(replacements) > 0:
+                    self.range_tree_nodes[sym].expr = sympy_subs(  # type: ignore[index]
+                        self.range_tree_nodes[sym].expr, replacements  # type: ignore[index]
+                    )
+                self.range_tree_nodes[sym].codegen()  # type: ignore[index]
+        return expr
+
+    def codegen_nan_check(self) -> None:
+        raise NotImplementedError("NYI: codegen_nan_check")
+
+    def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None:
+        raise NotImplementedError("NYI: call_kernel")
+
+    @contextlib.contextmanager
+    def mask_loads(self, mask, value):
+        """Context manager to add an additional mask to tl.load/store"""
+        prior = self._load_mask
+        prior_val = self._load_other
+        if prior:
+            mask = ops.logical_and(mask, prior)
+
+        mask = OpsWrapper._unwrap(mask)
+        self._load_mask = mask
+        self._load_other = value
+        try:
+            # TODO(jansel): do we need a reshape here?
+            yield mask
+        finally:
+            self._load_mask = prior
+            self._load_other = prior_val
+
+    def get_strides_of_load(self, index: sympy.Expr):
+        """
+        This gets the stride of the index for each of the tiling variables
+        (technically, it does it at index 0)
+
+        For example, if
+        xindex = x0 + 512*x1 + 1024*r0
+        x0 = (xindex//512)
+        x1 = (xindex % 512)
+        r0 = rindex // 1024
+
+        this function would return
+        {xindex: 512, rindex: 1024}
+        """
+        index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()}
+        index_in_tile_vars = sympy_subs(index, index_to_tile_indexes)  # type: ignore[arg-type]
+        strides = {}
+        for range_tree in self.range_trees:
+            s = sympy_index_symbol(range_tree.name)
+            strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs(
+                index_in_tile_vars, {s: 0}
+            )
+        return strides
+
+    @staticmethod
+    def _map_tuple_or_scalar(fn, value):
+        if isinstance(value, tuple):
+            return tuple(map(fn, value))
+        return fn(value)
+
+    def estimate_kernel_num_bytes(self):
+        """
+        Make a best-effort estimate of the total size (in bytes) of the
+        kernel's inputs and outputs, which is used to estimate the memory
+        throughput of this kernel and check how far we are from the peak
+        memory bandwidth. It is important to avoid overestimating the sizes
+        of the inputs and outputs, because that can produce a misleadingly
+        large memory-traffic value, possibly even larger than the
+        theoretical bandwidth. This is particularly problematic for cases
+        where we slice some inputs: there we should only count the size of
+        the "slices" instead of the original inputs, because only the
+        slices contribute to the real memory traffic.
+        """
+        nbytes = []
+        ninplace_args = len(unique(self.args.inplace_buffers.values()))
+        _, call_args, _, _ = self.args.python_argdefs()
+
+        # For pointwise and reduction kernels, this is the upper-bound numel
+        # for the output buffer.
+        # FIXME: This is not exactly right for cases like below:
+        #     def foo(tensor0, tensor1):
+        #         x0 = narrow(tensor0)
+        #         return cat(x0, tensor1)
+        # For this example, we will end up overestimating the size for the
+        # slice s0. Potentially, we could have precise input information if
+        # we maintained the original inputs of the Pointwise kernel created
+        # for the "cat". However, that would add considerable complexity
+        # only to handle some particular cases for benchmarking.
+        out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels))
+        for i, arg in enumerate(call_args):
+            # "buf" may be narrowed. In this case, the number of memory accesses
+            # should be estimated based on the reinterpreted layout.
+            # On the other hand, buf may be broadcasted. In this case,
+            # counting the size of the underlying storage would give us
+            # a better estimate of the number of memory accesses.
+            if arg not in self.buf_accesses:
+                nbytes.append(0)
+                continue
+            arg_numel = V.graph.get_numel(arg)
+            buf_size = V.graph.sizevars.size_hint(arg_numel)
+            if buf_size > out_numel:
+                # This arg points to a buf that has been sliced.
+                # We need to count each individual slice to have
+                # a better estimation.
+                indices: OrderedSet[Any] = OrderedSet()
+                no_index_dep_count = 0
+                for dep in self.buf_accesses[arg]:
+                    if isinstance(dep, (StarDep, WeakDep)):
+                        indices.add(f"no_index_dep_{no_index_dep_count}")
+                        no_index_dep_count += 1
+                    else:
+                        indices.add(dep.index)
+                numel = len(indices) * out_numel
+            else:
+                numel = buf_size
+            dtype = V.graph.get_dtype(arg)
+            dtype_size = get_dtype_size(dtype)
+            nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
+        return sum(nbytes)
+
+    def warn_mix_layout(self, kernel_name):
+        """
+        Print a message if the kernel has mixed-layout inputs.
+        Only 4D tensors are considered for now.
+        """
+        if (
+            len(self.args.input_buffers) == 1
+            and len(self.args.output_buffers) == 1
+            and len(self.args.inplace_buffers) == 0
+        ):
+            # Even if the input and output buffers have different layouts,
+            # this can be a layout-conversion kernel, so there is no need to
+            # warn about mixed layouts.
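+            # [editor's note] A hedged sketch of the stride-order comparison this
+            # method performs; `_stride_order_sketch` is a hypothetical stand-in
+            # for the real ir.get_stride_order() used further below.
+            def _stride_order_sketch(strides):
+                # Rank dimensions by stride: the stride-1 (fastest-moving)
+                # dimension gets order 0, the largest stride the highest order.
+                return [sorted(strides).index(s) for s in strides]
+
+            # _stride_order_sketch((1024, 256, 16, 1)) -> [3, 2, 1, 0] (contiguous)
+            # _stride_order_sketch((4096, 1, 128, 32)) -> [3, 0, 2, 1] (channels-last)
+            # Two call args with differing orders would trigger the warning below.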
+ return + + argdefs, call_args, signature, _ = self.args.python_argdefs() + uniform_stride_order = None + for arg_name in call_args: + buf = V.graph.try_get_buffer(arg_name) + if buf and len(buf.layout.size) == 4: + # ignore the tensor if only 1 dimension is non-zero + if len([x for x in buf.layout.size if x == 1]) == 3: + continue + stride_order = ir.get_stride_order(buf.layout.stride) + if uniform_stride_order is None: + uniform_stride_order = stride_order + elif uniform_stride_order != stride_order: + msg = yellow_text( + f"Expected stride order {uniform_stride_order}, but found stride order" + + f" {stride_order} for kernel {kernel_name}" + ) + log.warning(msg) + + stride_order_list = [ + ir.get_stride_order(V.graph.get_buffer(name).layout.stride) + if V.graph.try_get_buffer(name) + else None + for name in call_args + ] + size_list = [ + V.graph.get_buffer(name).layout.size + if V.graph.try_get_buffer(name) + else None + for name in call_args + ] + source_list = [ + "GraphInput" + if name in V.graph.graph_inputs + else "IntermediateBuffer" + if name in V.graph.name_to_buffer + else None + for name in call_args + ] + + msg = yellow_text( + f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}" + + f"\n sizes {size_list}\n sources {source_list}\n" + ) + log.warning(msg) + return + msg = green_text( + f"All the inputs for the triton kernel {kernel_name} have uniform layout" + ) + log.warning(msg) + + def welford_reduce_fallback(self, dtype, value): + sum_ = ops.reduction(dtype, dtype, "sum", value) + self.inside_reduction = False + rnumel = ops.index_expr(self.numels[-1], dtype) + mean = ops.truediv(sum_, rnumel) + + self.inside_reduction = True + dx = ops.sub(value, mean) + dx2 = ops.mul(dx, dx) + m2 = ops.reduction(dtype, dtype, "sum", dx2) + return OpsWrapper._unwrap((mean, m2, rnumel)) + + def codegen_kernel(self): + raise NotImplementedError + + def codegen_body(self): + pass + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + pass + + +class SIMDScheduling(BaseScheduling): + kernel_type = SIMDKernel # override in subclass + int32_type = "torch.int32" + int64_type = "torch.int64" + + def __init__(self, scheduler) -> None: + super().__init__() + self.scheduler = scheduler + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def can_fuse(self, node1, node2): + """ + Hook called by Scheduler to determine if the Triton backend + can fuse node1 and node2. These nodes might already be + FusedSchedulerNodes. 
+ """ + if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( + node2, scheduler.ForeachKernelSchedulerNode + ): + return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) + + _, (numel1, rnumel1) = node1.group + _, (numel2, rnumel2) = node2.group + why = WhyNoFuse(node1, node2) + + if node1.is_split_scan() and not node2.is_split_scan(): + if node2.is_reduction(): + why("Split scan cannot fuse with reductions") + elif node2.is_split_scan() and not node1.is_split_scan(): + if node1.is_reduction(): + why("Split scan cannot fuse with reductions") + + if node1.is_reduction() and node2.is_reduction(): + reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 + if not reduction_can_fuse: + why( + "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return reduction_can_fuse + + if not node1.is_reduction() and not node2.is_reduction(): + if not (numel1 == numel2 and rnumel1 == rnumel2): + why( + "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return False + + if node1.is_template(): + # Only allow fusion for TritonTemplates for now. + # Fusion for CUDATemplates are not supported. + is_triton_template = isinstance(node1.node, TritonTemplateBuffer) + if not is_triton_template: + why("node1 is not TritonTemplateBuffer") + return is_triton_template + + # check for a bad combined tiling + tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) + tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) + tiling3 = self.select_tiling( + node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 + ) + if config.triton.tiling_prevents_pointwise_fusion: + cond = True + if len(tiling1) > 2: + if len(tiling2) > 2: + cond = tiling1 == tiling2 == tiling3 + else: + cond = tiling1 == tiling3 + elif len(tiling2) > 2: + cond = tiling2 == tiling3 + if not cond: + why( + "tiling mismatch (%s, %s, %s)", + tiling1, + tiling2, + tiling3, + ) + return False + + return True + + if not node1.is_reduction() and node2.is_reduction(): + assert rnumel1 == 1 and rnumel2 != 1 + if numel1 == numel2 * rnumel2: + if not all( + SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges()) + for n in node1.get_nodes() + ): + why("nodes numel/rnumel incompatibility") + return False + if ( + config.triton.tiling_prevents_reduction_fusion + and not node1.is_template() + ): + is_reduction_tiling_valid = self.select_tiling( + node1.get_nodes(), numel1 + ) in ( + (numel1, 1), + (numel2, rnumel2, 1), + ) + if not is_reduction_tiling_valid: + why("invalid tiling for reduction") + return is_reduction_tiling_valid + return True + + if numel1 != numel2: + why("nodes numel incompatibility") + return numel1 == numel2 + + assert node1.is_reduction() and not node2.is_reduction() + # swap args to hit the case above + return self.can_fuse_horizontal(node2, node1) + + can_fuse_vertical = can_fuse + can_fuse_horizontal = can_fuse + + def generate_node_schedule(self, nodes, numel, rnumel): + node_schedule: List[Any] = [] + done: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() + # Writes with a reduced shape, meaning they are only present once the + # reduction loop has ended + not_ready_yet_nodes: OrderedSet[str] = OrderedSet() + + def fits_in_main_body(n): + _, (node_numel, node_rnumel) = n.group + return (node_numel == numel and node_rnumel == rnumel) or ( + node_numel == numel * rnumel and node_rnumel == 1 + ) + + def fits_outside_reduction(n): + _, (node_numel, node_rnumel) = n.group + return 
node_numel == numel and node_rnumel == 1 and rnumel != 1 + + def schedule_node_in_loop(n): + done.add(n) + node_schedule.append(n) + # A scan is modelled as a reduction in the scheduler but has a + # full sized output that can be used inside the loop body + if ( + n.is_reduction() + and isinstance(n, scheduler.SchedulerNode) + and isinstance(n.node, ir.ComputedBuffer) + and not isinstance(n.node.data, ir.Scan) + ): + not_ready_yet_nodes.add(n.get_name()) + + @contextlib.contextmanager + def end_current_reduction_loop(): + if node_schedule and node_schedule[-1] is EnableReduction: + node_schedule.pop() + else: + node_schedule.append(DisableReduction) + yield + node_schedule.append(EnableReduction) + not_ready_yet_nodes.clear() + + def requires_closing_previous_reduction(node, node_schedule): + if rnumel == 1: + return False + if not not_ready_yet_nodes & node.ancestors: + return False + assert node_schedule and not isinstance( + node_schedule[-1], (EnableReduction, DisableReduction) + ) + return bool(not_ready_yet_nodes) + + for index, node in enumerate(nodes): + if node in done: + continue + done.add(node) + + if fits_in_main_body(node): + if requires_closing_previous_reduction(node, node_schedule): + with end_current_reduction_loop(): + pass # need to start a new reduction loop + + schedule_node_in_loop(node) + elif fits_outside_reduction(node): + with end_current_reduction_loop(): + node_schedule.append(node) + else: + raise NotImplementedError( + f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" + ) + + return node_schedule + + def codegen_node( + self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] + ): + """ + Given a set of pre-fused nodes, generate a Triton kernel. + """ + + nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + buf_accesses = collections.defaultdict(list) + for node in nodes: + for access in node.read_writes.reads | node.read_writes.writes: + buf_accesses[access.name].append(access) + + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel) + + @staticmethod + def reduction_hint(node): + assert node.is_reduction() + if all( + dep.is_contiguous() + for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) + ): + return ReductionHint.INNER + else: + return node.node.data.reduction_hint + + @staticmethod + def can_use_32bit_indexing( + numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]] + ) -> bool: + int_max = torch.iinfo(torch.int32).max + size_hint = V.graph.sizevars.size_hint + has_hint = V.graph.sizevars.shape_env.has_hint + + def within_32bit(e): + # Allow for unhinted e as long as we can still statically prove + # (e.g., via ValueRanges) that it is still in bounds + if V.graph.sizevars.is_expr_static_and_true(e <= int_max): + return True + # Otherwise, the hint MUST exist and be in range + return has_hint(e) and size_hint(e) <= int_max + + if not within_32bit(numel): + return False + + # Any use of a MultiOutputLayout will create a buffer with a + # Layout whose sizes are accounted for + buf_sizes = [ + buf.get_layout().storage_size() + for buf in buffers + if not isinstance(buf.get_layout(), ir.MultiOutputLayout) + ] + + if not all(within_32bit(size) for size in buf_sizes): + return False + + # Only install guards for 32-bit indexing as there 
is no correctness + # issue with using 64-bit for everything + V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] + for size in buf_sizes: + V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] + return True + + @classmethod + def select_index_dtype(cls, node_schedule, numel, reduction_numel): + # Gather all used buffer names + buffer_names: OrderedSet[str] = OrderedSet() + for node in node_schedule: + if not isinstance(node, scheduler.BaseSchedulerNode): + continue + + buffer_names.update(node.get_buffer_names()) + buffer_names.update(node.used_buffer_names()) + + # Get buffers objects + + def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]: + buf = V.graph.get_buffer(name) + if buf is None: + raise RuntimeError(f"Failed to find buffer matching name {name}") + return buf + + buffers = [V.graph.get_buffer(name) for name in buffer_names] + + # In theory we can separately check xnumel and rnumel are <= int_max + # but some indexers do use the full linear index so we need to be + # conservative here. + total_numel = numel * reduction_numel + + if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers): + return cls.int32_type + return cls.int64_type + + def has_non_contiguous_pw_in_reduction_kernel(self, node_schedule, numel, rnumel): + pointwise_nodes = list( + filter( + lambda n: n not in (EnableReduction, DisableReduction) + and not n.is_reduction() + and n.group[1][0] == numel * rnumel, + node_schedule, + ) + ) + for node in pointwise_nodes: + # An index can be an integer when loading a random seed. + if not all( + not isinstance(dep, MemoryDep) + or dep.is_contiguous() + or isinstance(dep.index, (sympy.Integer, int)) + or dep.stride1_for_last_dim() + for dep in itertools.chain( + node.read_writes.reads, node.read_writes.writes + ) + ): + return True + return False + + def get_kernel_args(self, node_schedule, numel, reduction_numel): + reductions = list( + filter( + lambda n: n not in (EnableReduction, DisableReduction) + and n.is_reduction(), + node_schedule, + ) + ) + if len(reductions) > 0: + hints = [self.reduction_hint(n) for n in reductions] + if hints.count(hints[0]) == len(hints): + reduction_hint_val = hints[0] + else: + reduction_hint_val = ReductionHint.DEFAULT + + if ( + reduction_hint_val == ReductionHint.INNER + and self.has_non_contiguous_pw_in_reduction_kernel( + node_schedule, numel, reduction_numel + ) + ): + reduction_hint_val = ReductionHint.DEFAULT + else: + reduction_hint_val = ReductionHint.DEFAULT + + mutations: OrderedSet[str] = OrderedSet() + for node in node_schedule: + if node in (DisableReduction, EnableReduction): + continue + + for buf in node.get_outputs(): + mutations.update(buf.get_mutations()) + + index_dtype = self.select_index_dtype(node_schedule, numel, reduction_numel) + + return reduction_hint_val, mutations, index_dtype + + def codegen_node_schedule( + self, node_schedule, buf_accesses, numel, reduction_numel + ): + from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel + + tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel) + ( + reduction_hint_val, + mutations, + index_dtype, + ) = self.get_kernel_args(node_schedule, numel, reduction_numel) + + is_split_scan = any( + isinstance(node, BaseSchedulerNode) and node.is_split_scan() + for node in node_schedule + ) + kernel_type: type = self.kernel_type + if is_split_scan and issubclass(TritonSplitScanKernel, kernel_type): + kernel_type = TritonSplitScanKernel + + kernel_args = tiled_groups + kernel_kwargs = dict( + 
reduction_hint=reduction_hint_val, + mutations=mutations, + index_dtype=index_dtype, + ) + + def _node_has_sort(node): + if node in (EnableReduction, DisableReduction): + return False + + sort_nodes = node._body.root_block.graph.find_nodes( + op="call_method", target="sort" + ) + return bool(sort_nodes) + + # ops.sort only works with persistent reduction, and is not bandwidth bound anyway + # so taking the hit of non-coalesced loads is okay + has_sort = any(_node_has_sort(node) for node in node_schedule) + if has_sort: + kernel_kwargs["override_persistent_reduction"] = True + + kernel = kernel_type( + *kernel_args, + **kernel_kwargs, + ) + kernel.buf_accesses = buf_accesses + + kernel2: Optional[SIMDKernel] = None + if kernel.persistent_reduction and config.triton.multi_kernel and not has_sort: + kernel2 = self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_persistent_reduction=False, + ) + self.codegen_node_schedule_with_kernel(node_schedule, kernel2) + with V.set_kernel_handler(kernel2): + src_code2 = kernel2.codegen_kernel() + kernel_name2 = self.define_kernel(src_code2, node_schedule, kernel) + kernel2.kernel_name = kernel_name2 + kernel2.code_hash = code_hash(src_code2) + + # Keep buffers needed by the non-persistent reduction so both + # kernels have the same arguments + kernel.must_keep_buffers = set(kernel2.must_keep_buffers) + + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + + final_kernel = MultiKernel([kernel, kernel2]) if kernel2 is not None else kernel + + with V.set_kernel_handler(final_kernel): + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. 
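+            # [editor's note] Hedged sketch of what the emitted hook line amounts
+            # to: after each live output is materialized, the generated wrapper
+            # calls run_intermediate_hooks(origin, buf) so debugging callbacks
+            # can observe intermediate tensors. The registry here is
+            # hypothetical; the real plumbing lives in torch._inductor.hooks.
+            _hooks_sketch = []  # populated by debug tooling in this sketch
+
+            def _run_intermediate_hooks_sketch(origin_name, tensor):
+                # Invoked from generated wrapper code, e.g.:
+                #   run_intermediate_hooks('convolution_2', buf4)
+                for hook in _hooks_sketch:
+                    hook(origin_name, tensor)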
+            live_outs = kernel.args.live_output_buffers()
+            for node in node_schedule:
+                if not isinstance(node, scheduler.BaseSchedulerNode):
+                    continue
+                name = node.get_name()
+                if name not in live_outs:
+                    continue
+                assert node.node is not None
+                origin_node = node.node.get_origin_node()
+                if origin_node is not None:
+                    counters["inductor"]["intermediate_hooks"] += 1
+                    V.graph.wrapper_code.writeline(
+                        f"run_intermediate_hooks({origin_node.name!r}, {name})"
+                    )
+
+        self.scheduler.free_buffers()
+
+    def codegen_node_schedule_with_kernel(self, node_schedule, kernel):
+        def current_reduction_nodes(nodes):
+            return itertools.takewhile(lambda n: n is not DisableReduction, nodes)
+
+        with kernel:
+            stack = contextlib.ExitStack()
+            kernel.set_last_usage(current_reduction_nodes(node_schedule))
+            all_indexing = {}
+
+            # First pass to collect indexing and decide inplace updates
+            for node in node_schedule:
+                if node is DisableReduction:
+                    stack.enter_context(kernel.disable_reduction())
+                elif node is EnableReduction:
+                    stack.close()
+                else:
+                    node.decide_inplace_update()
+                    index_vars = kernel.split_and_set_ranges(node.get_ranges())
+                    all_indexing.update(
+                        dict.fromkeys(
+                            node._body.indexing_from_args(index_vars).values()
+                        )
+                    )
+
+            kernel.finalize_indexing(all_indexing.keys())
+
+            # Second pass to do codegen
+            for i, node in enumerate(node_schedule):
+                if node is DisableReduction:
+                    stack.enter_context(kernel.disable_reduction())
+                elif node is EnableReduction:
+                    stack.close()
+                    kernel.set_last_usage(current_reduction_nodes(node_schedule[i:]))
+                else:
+                    # TODO - use split ranges?
+                    indexing_dtype_strength_reduction(node._body)
+                    index_vars = kernel.split_and_set_ranges(node.get_ranges())
+                    node.codegen(index_vars)
+
+    def codegen_template(
+        self, template_node, epilogue_nodes, only_gen_src_code=False
+    ) -> Optional[str]:
+        """
+        Codegen a triton template.
+
+        If `only_gen_src_code` is True, the source code is returned instead
+        of being codegen'd into the wrapper.
+        """
+        _, (numel, rnumel) = template_node.group
+        assert rnumel == 1
+        kernel, render = template_node.node.make_kernel_render(template_node.node)
+        with kernel:
+            if not only_gen_src_code:
+                for node in [template_node, *epilogue_nodes]:
+                    node.mark_run()
+            partial_code = render()
+            with kernel.set_subgraph_body("<STORE_OUTPUT>"):
+                for node in epilogue_nodes:
+                    node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
+
+        if not isinstance(partial_code, str):
+            partial_code.finalize_hook("<DEF_KERNEL>")
+            partial_code.finalize_hook("<ARGDEFS>", strict=False)
+        # finalize must be called after adding epilogue above
+        with V.set_kernel_handler(kernel):
+            # TODO: Maybe unify CUDATemplateKernel to also use PartialRender
+            # for flexible epilogue fusion.
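+            # [editor's note] Hedged miniature of the PartialRender protocol used
+            # here: the template renders with placeholder strings such as
+            # "<STORE_OUTPUT>", and finalize_hook() later splices in code that
+            # only exists after epilogue fusion. Illustrative class, not the
+            # real implementation:
+            class _PartialRenderSketch:
+                def __init__(self, code, hooks):
+                    self.code = code  # rendered source containing placeholders
+                    self._hooks = hooks  # placeholder name -> zero-arg codegen fn
+
+                def finalize_hook(self, name, strict=True):
+                    if name in self._hooks:
+                        self.code = self.code.replace(name, self._hooks.pop(name)())
+                    elif strict:
+                        raise KeyError(f"no hook registered for {name}")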
+            with kernel.set_subgraph_body("<STORE_OUTPUT>"):
+                if isinstance(partial_code, str):
+                    src_code = partial_code
+                else:
+                    partial_code.finalize_hook("<STORE_OUTPUT>")
+                    src_code = partial_code.code
+        node_schedule = [template_node, *epilogue_nodes]
+
+        if config.benchmark_kernel:
+            num_gb = kernel.estimate_kernel_num_bytes() / 1e9
+            grid_args = V.graph.sizevars.size_hints(kernel.call_sizes)
+            assert kernel.meta is not None, "meta is None"
+            grid = kernel.grid_fn(*grid_args, kernel.meta)
+            src_code = (
+                f"{kernel.imports_for_benchmark_kernel()}\n"
+                f"{src_code}\n"
+                f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}"
+            )
+
+        if only_gen_src_code:
+            return src_code
+
+        kernel_name = self.define_kernel(src_code, node_schedule, kernel)
+
+        self.codegen_comment(node_schedule)
+        kernel.call_kernel(kernel_name, template_node.node)
+
+        V.graph.removed_buffers |= kernel.removed_buffers
+        V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
+        self.scheduler.free_buffers()
+        return None
+
+    def codegen_sync(self):
+        V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())
+
+    def generate_combo_kernel_code(
+        self,
+        subkernel_nodes: List[BaseSchedulerNode],
+        custom_part_algorithm: bool,
+        enable_autotune: bool,
+        mixed_sizes: bool,
+        only_gen_src_code: bool = False,
+    ) -> List[Tuple[str, Any, Any]]:
+        from .triton_combo_kernel import ComboKernel
+
+        fused_node_lists = [node.get_nodes() for node in subkernel_nodes]
+        subkernel_map, node_schedule_map = {}, {}
+        for pn, nodes in zip(subkernel_nodes, fused_node_lists):
+            _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
+            node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
+            tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
+            node_schedule_map[pn] = node_schedule, tiled_groups, numel, rnumel
+            (
+                reduction_hint_val,
+                mutations,
+                index_dtype,
+            ) = self.get_kernel_args(node_schedule, numel, rnumel)
+            subkernel_map[pn] = ComboKernel.create_triton_kernel(
+                *tiled_groups,
+                reduction_hint=reduction_hint_val,
+                mutations=mutations,
+                index_dtype=index_dtype,
+                optimize_mask=not mixed_sizes,
+            )
+
+        partitions = ComboKernel.horizontal_partition(
+            nodes=subkernel_nodes,
+            triton_scheduling=self,
+            custom_algorithm=custom_part_algorithm,
+            kernel_map=subkernel_map,
+            node_info_map=node_schedule_map,
+        )
+        log.debug(
+            "ComboKernels: %d nodes partitioned into %s groups",
+            len(subkernel_nodes),
+            [len(p) for p in partitions],
+        )
+        kernel_code_list = []
+        for node_group in partitions:
+            fused_node_lists = [node.get_nodes() for node in node_group]
+            kernel = ComboKernel(
+                enable_autotune=enable_autotune,
+                mixed_sizes=mixed_sizes,
+            )
+
+            for pn, nodes in zip(node_group, fused_node_lists):
+                if only_gen_src_code:
+                    # empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
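+                    # [editor's note] For orientation, a hedged sketch of what
+                    # horizontal partitioning means here: independent subkernels
+                    # are grouped so that one launch dispatches several fused
+                    # bodies. The real ComboKernel.horizontal_partition uses
+                    # size/occupancy heuristics; a naive fixed-size grouping
+                    # would look like:
+                    def _horizontal_partition_sketch(nodes, max_group_size=8):
+                        return [
+                            nodes[i : i + max_group_size]
+                            for i in range(0, len(nodes), max_group_size)
+                        ]
+
+                    # Each group becomes one ComboKernel, as in the enclosing loop.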
+ for n in nodes: + n.last_usage = OrderedSet() + self.codegen_node_schedule_with_kernel( + node_schedule_map[pn][0], + kernel.create_sub_kernel(subkernel_map[pn]), + ) + subkernel = subkernel_map[pn] + node_schedule = node_schedule_map[pn][0] + if not only_gen_src_code: + with V.set_kernel_handler(subkernel): # type: ignore[call-arg] + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.mark_run() + V.graph.removed_buffers |= subkernel.removed_buffers + V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove + + src_code = kernel.codegen_kernel() + kernel_code_list.append((src_code, kernel, node_group)) + return kernel_code_list + + def codegen_combo_kernel(self, combo_kernel_node): + subkernel_nodes = combo_kernel_node.get_subkernel_nodes() + custom_part_algorithm = combo_kernel_node.use_custom_partition_algo + enable_autotune = combo_kernel_node.enable_autotune + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( + config.combo_kernel_allow_mixed_sizes == 1 and custom_part_algorithm + ) + + kernel_code_list = self.generate_combo_kernel_code( + subkernel_nodes, custom_part_algorithm, enable_autotune, mixed_sizes + ) + + for src_code, kernel, _ in kernel_code_list: + kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) + self.codegen_comment([combo_kernel_node]) + log.debug("ComboKernels: generated kernel %s.", kernel_name) + kernel.call_kernel(V.graph.wrapper_code, kernel_name) + + self.scheduler.free_buffers() + + @staticmethod + @functools.lru_cache(32) + def candidate_tilings(node): + ranges, reduction_ranges = node.get_ranges() + if len(ranges) <= 1: + return () + + rw = node.pointwise_read_writes() + assert len(rw.range_vars) == len(ranges) + + # isinstance(dep, MemoryDep): this filters out StarDeps. 
+        # StarDeps refer to reads that need to access the entire tensor; they
+        # don't contribute read indexing information (and practically, they
+        # don't have dep.index, so they can't be used for stride_hints below).
+        dep_sources = [rw.reads, rw.writes]
+        assert all(
+            isinstance(dep, (MemoryDep, StarDep))
+            for dep in itertools.chain.from_iterable(dep_sources)
+        )
+        deps = [
+            dep
+            for dep in itertools.chain.from_iterable(dep_sources)
+            if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep)
+        ]
+        write_names = {dep.name for dep in rw.writes}
+
+        tilings: List[CandidateTiling] = []
+
+        for dep in deps:
+            strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars)
+            assert len(strides) == len(ranges)
+            try:
+                split = strides.index(1) + 1
+                if split == len(ranges):
+                    continue
+                if all(s == 0 for s in strides[split:]):
+                    # If this is a broadcasted tensor and all dimensions after
+                    # split are broadcast, this is not a real split.
+                    continue
+
+            except ValueError:
+                continue
+            tiled_groups = (
+                V.graph.sizevars.simplify(sympy_product(ranges[:split])),
+                V.graph.sizevars.simplify(sympy_product(ranges[split:])),
+            )
+            # score by number of elements
+            score = V.graph.sizevars.size_hint(
+                sympy_product(
+                    size for size, stride in zip(ranges, strides) if stride != 0
+                )
+            )
+            if dep.name in write_names:
+                # ngimel said contiguous writes are more important than reads
+                score *= 2
+            if CandidateTiling.is_good_size(tiled_groups[0]):
+                score *= 2
+            if CandidateTiling.is_good_size(tiled_groups[1]):
+                score *= 2
+
+            if (
+                V.graph.sizevars.size_hint(
+                    score - sympy_product(itertools.chain(ranges, reduction_ranges))
+                )
+                >= 0
+            ):
+                tilings.append(CandidateTiling(tiled_groups, score, dep.name))
+        return tilings
+
+    @classmethod
+    def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
+        """
+        Heuristics to decide how to tile kernels.
+        Currently, we tile based on stride-1 dimensions.
+
+        Returns:
+            `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`
+        """
+        if reduction_numel != 1 or config.triton.max_tiles <= 1:
+            # TODO(jansel): should we tile reductions?
+            # do perf hint here if stride-1 dim is not being reduced
+            if perf_hint_log.level <= logging.WARNING:
+                for node in EnableReduction.filter(node_schedule):
+                    if len(cls.candidate_tilings(node)) > 0:
+                        perf_hint_log.info("reduction over non-contiguous dims")
+                        break
+            return (numel, reduction_numel)
+
+        seen_names: OrderedSet[str] = OrderedSet()
+        candidate_tiles: Counter[Any] = collections.Counter()
+        for node in EnableReduction.filter(node_schedule):
+            for tiling in cls.candidate_tilings(node):
+                if tiling.name in seen_names:
+                    continue
+                seen_names.add(tiling.name)
+                candidate_tiles[tiling.tiling] += tiling.score
+
+        ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()]
+
+        if config.triton.max_tiles >= 3:
+            # Consider adding a third dimension of tiling, but only
+            # when a1 is a multiple of b1; otherwise, you have a lot
+            # of stragglers, which is annoying to generate code for.
+            #
+            # NB: More than three max tiles is not enabled by default.
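+            # [editor's note] The 3D-candidate construction below, in isolation:
+            # given the two best 2D tilings (a0, a1) and (b0, b1) of the same
+            # numel, with b1 dividing a1, propose (a0, a1 // b1, b1). Hedged
+            # sketch with made-up sizes:
+            def _make_3d_tiling_sketch(a, b):
+                (a0, a1), (b0, b1) = a, b
+                assert a1 % b1 == 0  # mirrors statically_known_multiple_of(a1, b1)
+                return (a0, a1 // b1, b1)
+
+            # _make_3d_tiling_sketch((8, 6144), (384, 128)) -> (8, 48, 128)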
+ + # Add one 3D tiling choice + for i in range(1, len(ranked_tilings)): + a0, a1 = ranked_tilings[0] + b0, b1 = ranked_tilings[i] + if V.graph.sizevars.size_hint(a1 - b1) == 0: + continue + if V.graph.sizevars.size_hint(a1 - b1) < 0: + # swap so a0 is bigger + a0, a1 = ranked_tilings[i] + b0, b1 = ranked_tilings[0] + assert V.graph.sizevars.size_hint(a1 - b1) > 0 + if V.graph.sizevars.statically_known_multiple_of(a1, b1): + tiling = (a0, FloorDiv(a1, b1), b1) + ranked_tilings = [tiling] + ranked_tilings + break # only 1 choice for now + + if len(ranked_tilings) > 1: + perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) + + # Optionally, prefer tiling into as many dimensions as possible. + if config.triton.prefer_nd_tiling: + # Get candidate tilings from the node ranges. + node_ranges = [ + node.get_ranges()[0] + for node in EnableReduction.filter(node_schedule) + if isinstance(node, scheduler.SchedulerNode) + ] + new_tilings: OrderedSet[Tuple[sympy.Expr]] = OrderedSet() + for node_range in node_ranges: + # Collapse leading dims, to fit in the maximum dimensionality. + num_leading_dims = max(0, len(node_range) - config.triton.max_tiles) + first_trailing_dim = num_leading_dims + 1 + collapsed_leading_dim = sympy_product(node_range[:first_trailing_dim]) + tiling = [collapsed_leading_dim] + list(node_range[first_trailing_dim:]) + new_tilings.add(tuple(tiling)) + + # Rank tilings by the number of dimensions. E.g., prefer 2D to 1D. + # Since this is a stable sort, ties are broken by schedule order. + ranked_new_tilings = sorted(new_tilings, key=len, reverse=True) + ranked_tilings = ranked_new_tilings + ranked_tilings + + for tiled_groups in ranked_tilings: + new_groups = (*tiled_groups, reduction_numel) + if all( + SIMDKernel.is_compatible(new_groups, node.get_ranges()) + for node in node_schedule + if isinstance(node, scheduler.SchedulerNode) + ): + return new_groups + + return (numel, reduction_numel) + + def flush(self): + pass + + def ready_to_flush(self) -> bool: + return False + + def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): + @dataclasses.dataclass + class LastUsageHolder: + n: Any + last_usage: Any + + def __del__(self) -> None: + self.n.last_usage = self.last_usage + + last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes] + + # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. 
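+        # [editor's note] LastUsageHolder above is an RAII-style guard: it
+        # snapshots n.last_usage and restores it when the holder is collected.
+        # A hedged, more explicit equivalent using try/finally:
+        def _with_cleared_last_usage_sketch(nodes, fn):
+            saved = [(n, n.last_usage) for n in nodes]
+            try:
+                for n in nodes:
+                    n.last_usage = OrderedSet()
+                return fn()
+            finally:
+                for n, last_usage in saved:
+                    n.last_usage = last_usage
+
+        # The in-line clearing below relies on the holders above for restoration.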
+ for n in nodes: + n.last_usage = OrderedSet() + + if not nodes[0].is_template(): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + + tiled_groups = self.select_tiling(node_schedule, numel, rnumel) + reduction_hint_val, mutations, index_dtype = self.get_kernel_args( + node_schedule, numel, rnumel + ) + + kernel = self.kernel_type( + *tiled_groups, + reduction_hint=reduction_hint_val, + mutations=mutations, + index_dtype=index_dtype, + ) + + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + with config.patch( + "benchmark_kernel", benchmark_kernel + ), V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + else: + template_node = nodes[0] + epilogue_nodes = nodes[1:] + + with config.patch("benchmark_kernel", benchmark_kernel): + src_code = self.codegen_template( + template_node, epilogue_nodes, only_gen_src_code=True + ) + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + return src_code + + def codegen_comment(self, node_schedule): + pass + + def define_kernel(self, src_code, node_schedule, kernel): + raise NotImplementedError + + +@dataclasses.dataclass +class CandidateTiling: + tiling: Tuple[sympy.Expr, sympy.Expr] + score: int # higher is better + name: Optional[str] = None + + @staticmethod + def is_good_size(s): + """Somewhat arbitrary heuristic used to boost scores for some sizes""" + s = V.graph.sizevars.size_hint(s) + return s >= 32 and (s % 32 == 0) + + +class DisableReduction: + """ + Marker to invoke `kernel.disable_reduction()`. This closes a + reduction loop and allows for pointwise ops to occur on the output + of a reduction. + """ + + +class EnableReduction: + """ + Marker to end a DisableReduction block. + """ + + @staticmethod + def filter(node_schedule): + """ + Get the nodes from node_schedule skipping those in a + DisableReduction block. + """ + disabled = False + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + # Don't tile stuff outside the main reduction loop + disabled = node is DisableReduction + elif disabled: + pass + else: + yield node + + +class CantSplit(Exception): + pass diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..2596f3d59793e90fd7a90dc0bd81ec5bc2426552 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton.py @@ -0,0 +1,3225 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import functools +import itertools +import logging +import os +import textwrap +from functools import lru_cache +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy + +import torch +import torch._logging +from torch._dynamo.utils import preserve_rng_state +from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties +from torch._prims_common import is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._triton import has_triton_package + +from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT +from ...utils._sympy.value_ranges import ValueRanges +from .. 
import config, ir +from ..codecache import code_hash, get_path, PyCodeCache +from ..metrics import is_metric_table_enabled, log_kernel_metadata +from ..runtime.benchmarking import benchmarker +from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK +from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2 +from ..utils import ( + cache_on_self, + get_bounds_index_expr, + get_fused_kernel_name, + get_kernel_metadata, + is_welford_reduction, + Placeholder, + sympy_dot, + sympy_subs, +) +from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V +from ..wrapper_benchmark import get_kernel_category_by_source_code +from .common import ( + BackendFeature, + CSE, + CSEVariable, + DeferredLine, + IndentedBuffer, + OpOverrides, + PythonPrinter, + SizeArg, + TensorArg, + WorkspaceArg, +) +from .simd import ( + constant_repr, + IterationRangesEntry, + IterationRangesRoot, + pexpr, + SIMDKernel, + SIMDScheduling, +) +from .triton_utils import ( + config_of, + should_unwrap_unspec_arg, + signature_of, + signature_to_meta, +) + + +if TYPE_CHECKING: + from ..ir import IRNode + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +@lru_cache(None) +def gen_attr_descriptor_import(): + """ + import AttrsDescriptor if the triton version is new enough to have this + class defined. + """ + if not has_triton_package(): + return "" + + import triton.compiler.compiler + + if hasattr(triton.compiler.compiler, "AttrsDescriptor"): + return "from triton.compiler.compiler import AttrsDescriptor" + else: + return "" + + +@lru_cache(None) +def gen_common_triton_imports(): + imports = IndentedBuffer() + imports.splice( + """ + import triton + import triton.language as tl + """ + ) + if attr_desc := gen_attr_descriptor_import(): + imports.writeline(attr_desc) + + imports.splice( + """ + from torch._inductor.runtime import triton_helpers, triton_heuristics + from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math + from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties + """ + ) + return imports.getvalue() + + +block_offsets = { + symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True) + for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX] +} + +block_sizes = { + symt: sympy.Symbol(f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True) + for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX] +} + + +@dataclasses.dataclass +class IndexingOptions: + index_str: str + mask_vars: OrderedSet[str] + mask_str: str + expand_str: Optional[str] + _has_rindex: bool + index: sympy.Expr + + def has_mask(self): + return bool(self.mask_vars) + + def has_indirect(self): + return free_symbol_is_type(self.index, SymT.TMP) + + def has_rindex(self): + return self._has_rindex + + def has_tmpmask(self): + return "tmp" in self.mask_str + + def has_rmask(self): + return "rmask" in self.mask_str + + +@dataclasses.dataclass +class BlockPtrOptions: + params: BlockParameters + constant_offset: sympy.Expr + order: List[int] + mask_vars: OrderedSet[str] + reshape_suffix: List[str] + + @property + def shape(self) -> List[sympy.Expr]: + return self.params.shape + + @property + def block_shape(self) -> List[sympy.Expr]: + return self.params.block_shape + + @property + def strides(self) -> List[sympy.Expr]: + 
return self.params.strides + + @property + def offsets(self) -> List[sympy.Expr]: + return self.params.offsets + + @staticmethod + def create( + *, + params: BlockParameters, + constant_offset: sympy.Expr, + range_trees: List[IterationRangesEntry], + mask_vars: OrderedSet[str], + ) -> BlockPtrOptions: + """Helper to create a BlockPtrOptions instance""" + reshape_suffix = [f"{t.prefix.upper()}BLOCK" for t in range_trees] + + # Only drop broadcast dims if the output has the same + # rank as the block. Otherwise, we will get shape errors. + drop_broadcasts = len(reshape_suffix) == len(params.strides) + + broadcasting_dim = [s == 0 for s in params.strides] + for i, is_broadcasting in enumerate(broadcasting_dim): + if is_broadcasting and drop_broadcasts: + # drop any stride==0 dimensions for performance + reshape_suffix[i] = "1" + + if V.kernel.no_x_dim: + assert range_trees[0].prefix == "x" + reshape_suffix.pop(0) + + if ( + not V.kernel.inside_reduction + and len(params.strides) == len(V.kernel.numels) - 1 + and V.kernel.numels[-1] != 1 + ): + # Need to expand rank by 1 to match rank when self.inside_reduction=True + reshape_suffix.append("1") + + def filter(it): + """Removes any broadcasting dims from a given sequence""" + assert len(it) == len(broadcasting_dim) + return [ + item + for item, is_broadcasting in zip(it, broadcasting_dim) + if not is_broadcasting or not drop_broadcasts + ] + + # Drop broadcasting dimensions from the input. + params = BlockParameters( + **{key: filter(val) for key, val in dataclasses.asdict(params).items()} + ) + + def lookup_size(exprs: Iterable[sympy.Expr]) -> List[sympy.Expr]: + return [V.graph.sizevars.lookup_precomputed_size(expr) for expr in exprs] + + # Look up precomputed sizes + params.shape = lookup_size(params.shape) + params.strides = lookup_size(params.strides) + + return BlockPtrOptions( + params=params, + constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset), + order=list(reversed(range(len(params.shape)))), + mask_vars=mask_vars, + reshape_suffix=reshape_suffix, + ) + + def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr: + """ + Replaces instances of roffset with the new expression. + """ + roffset = block_offsets[SymT.RINDEX] + return sympy_subs(expr, {roffset: replacement}) + + def format(self, name: str, roffset=True) -> str: + """ + Codegen a call to tl.make_block_ptr() + + Args: + name: variable name for pointer + roffset: should roffset be included in offsets=..., for use with tl.advance() + + Returns: + "tl.make_block_ptr(...)" + """ + f = V.kernel.index_to_str + offsets = [*self.offsets] + if not roffset: + offsets = [ + self.replace_roffset(offset, sympy.Integer(0)) for offset in offsets + ] + args = [ + f"{name} + ({f(self.constant_offset)})" + if self.constant_offset != 0 + else name, + f"shape={f(self.shape)}", + f"strides={f(self.strides)}", + f"block_shape={f(self.block_shape)}", + f"order={f(self.order)}", + f"offsets={f(offsets)}", + ] + return f"tl.make_block_ptr({', '.join(args)})" + + @cache_on_self + def boundary_check(self) -> List[int]: + """List of indices to pass to tl.load(boundary_check=...)""" + sizevars = V.graph.sizevars + + # Substitute maximum block sizes in shape expressions. + # This works in multiple_of checks because block sizes are powers of 2. 
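+        # [editor's note] Why substituting the *maximum* block size is sound, as
+        # a hedged plain-Python check: block sizes are powers of two, so every
+        # actual block size divides the maximum one, and a shape divisible by
+        # the maximum block is divisible by any smaller power-of-two block, so
+        # no boundary check (mask) is needed for that dimension. The max value
+        # used here is illustrative only:
+        def _needs_boundary_check_sketch(shape, max_block=1024):
+            # Multiples of the largest block are never partially covered.
+            return shape % max_block != 0
+
+        # _needs_boundary_check_sketch(4096) -> False (no masking needed)
+        # _needs_boundary_check_sketch(4100) -> True (tail must be masked)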
+ block_to_max: Dict[sympy.Expr, Any] = { + block_size: TRITON_MAX_BLOCK[prefix_str[symt].upper()] + for symt, block_size in block_sizes.items() + } + + return [ + idx + for idx in range(len(self.shape)) + if ( + not sizevars.statically_known_equals( + self.strides[idx], sympy.Integer(0) + ) + and not sizevars.statically_known_multiple_of( + self.shape[idx], self.block_shape[idx] + ) + and not sizevars.statically_known_multiple_of( + self.shape[idx], sympy_subs(self.block_shape[idx], block_to_max) + ) + and not ( + V.kernel.no_x_dim + and self.block_shape[idx] == block_sizes[SymT.XBLOCK] + ) + ) + ] + + def advance_roffset(self): + """ + Codegen string to pass to tl.advance(name, ...). + + Advance is the difference between offsets in each loop iteration. + To compute it, we replace roffset with multiples of RBLOCK. + Since we expect roffset to vary in range(0, rnumel, RBLOCK), the first + iteration has roffset=0, while the second has roffset=RBLOCK. + """ + rblock = block_sizes[SymT.RINDEX] + advance = [ + ( + self.replace_roffset(offset, rblock) + - self.replace_roffset(offset, sympy.Integer(0)) + ) + for offset in self.offsets + ] + return V.kernel.index_to_str(advance) + + def has_indirect(self): + return False # block_ptr can't do indirect indexing + + def has_rindex(self) -> bool: + return any(free_symbol_is_type(expr, SymT.RINDEX) for expr in self.block_shape) + + def has_rmask(self): + return self.has_rindex() + + def has_tmpmask(self): + return False # block_ptr can't do indirect indexing + + def has_mask(self): + return bool(self.boundary_check()) + + +def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]): + """Workaround https://github.com/openai/triton/issues/2836""" + assert isinstance(old_shape, list) and isinstance(new_shape, list) + if old_shape == new_shape: + return value + if [s for s in new_shape if s != "1"] != old_shape: + return f"tl.reshape({value}, [{', '.join(new_shape)}])" + # rewrite to [:, None] syntax, which is less buggy + idx = 0 + expand = [] + for size in new_shape: + if idx < len(old_shape) and size == old_shape[idx]: + expand.append(":") + idx += 1 + else: + assert size == "1" + expand.append("None") + assert idx == len(old_shape) + return f"{value}[{', '.join(expand)}]" + + +# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a +# number of operators which Triton "implements", but in a way that is +# inconsistent with Python semantics (and consistent with C semantics). 
We +# must override all of these, or it is potential silent correctness problem +class TritonPrinter(PythonPrinter): + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + return ( + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" + + def _print_PythonMod(self, expr): + quot, div = expr.args + quot_s = self._print(quot) + div_s = self._print(div) + if quot.is_nonnegative and div.is_nonnegative: + return f"{self.paren(quot_s)} % {self.paren(div_s)}" + return f"triton_helpers.remainder_integer({quot_s}, {div_s})" + + def _print_FloorDiv(self, expr): + assert expr.is_integer + quot, div = expr.args + quot_s = self._print(quot) + div_s = self._print(div) + if quot.is_nonnegative and div.is_nonnegative: + return f"({self.paren(quot_s)} // {self.paren(div_s)})" + return f"triton_helpers.div_floor_integer({quot_s}, {div_s})" + + # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher + # precision algorithm, which we would need to replicate here + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + # NB: sympy.floor/ceiling produce integers, so we have to do the + # conversion to index dtype + def _print_floor(self, expr): + assert len(expr.args) == 1 + return ( + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_FloorToInt(self, expr): + assert len(expr.args) == 1 + return ( + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + + def _helper_sqrt(self, expr): + return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))" + + def _print_FloatPow(self, expr): + return ( + f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" + ) + + _print_PowByNatural = _print_FloatPow + + def _print_Where(self, expr): + c = self.doprint(expr.args[0]) + p = self.doprint(expr.args[1]) + q = self.doprint(expr.args[2]) + return f"tl.where({c}, {p}, {q})" + + def _print_min_max_helper(self, expr: sympy.Expr, cmp: str) -> str: + """ + Helper for max/min code genereration. + cmp: > or < + """ + nargs = len(expr.args) + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + cls = type(expr) + a = self._print(cls(*expr.args[:mid])) + b = self._print(cls(*expr.args[mid:])) + + # Use a macro so we can propagate constexprs. 
+ # https://github.com/triton-lang/triton/issues/3815 + a, b = tuple(f"({x})" for x in (a, b)) + assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'" + return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))" + + def _print_Min(self, expr): + return self._print_min_max_helper(expr, "<") + + def _print_Max(self, expr): + return self._print_min_max_helper(expr, ">") + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"tl_math.abs({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert len(expr.args) == 1 + return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_RoundToInt(self, expr): + assert len(expr.args) == 1 + return f"libdevice.llrint({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." 
+ ) + return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}" + + +texpr = TritonPrinter().doprint + + +def triton_compute_type(dtype): + triton_type_name = str(dtype).split(".")[-1] + if triton_type_name == "bool": + triton_type_name = "int1" + elif ( + triton_type_name in ("float16", "bfloat16") + and config.triton.codegen_upcast_to_fp32 + ): + # float16 math is done in float32 inside the kernel + triton_type_name = "float32" + elif triton_type_name == "float8_e4m3fn": + triton_type_name = "float8e4nv" + elif triton_type_name == "float8_e5m2": + triton_type_name = "float8e5" + elif triton_type_name == "float8_e4m3fnuz": + triton_type_name = "float8e4b8" + elif triton_type_name == "float8_e5m2fnuz": + triton_type_name = "float8e5b16" + return f"tl.{triton_type_name}" + + +def _get_primitive_bitwidth(dtype): + if hasattr(dtype, "is_floating_point"): + if dtype.is_floating_point: + # triton_compute_type changes the bitwidth + if ( + dtype in [torch.bfloat16, torch.float16] + and config.triton.codegen_upcast_to_fp32 + ): + return 32 + return torch.finfo(dtype).bits + else: + return torch.iinfo(dtype).bits + else: + return -1 + + +def triton_store_type(dtype): + triton_type_name = str(dtype).split(".")[-1] + if triton_type_name == "bool": + triton_type_name = "int8" + elif triton_type_name == "float8_e4m3fn": + triton_type_name = "float8e4nv" + elif triton_type_name == "float8_e5m2": + triton_type_name = "float8e5" + return f"tl.{triton_type_name}" + + +def triton_acc_type(dtype): + if is_integer_dtype(dtype) and dtype.is_signed: + nbits = 64 if dtype == torch.int64 else 32 + return f"tl.int{nbits}" + return triton_compute_type(dtype) + + +class TritonCSEVariable(CSEVariable): + def __init__(self, name, bounds: ValueRanges[Any]) -> None: + super().__init__(name, bounds) + # We'll use this to track which masks the variable needs when used for indirect indexing + self.mask_vars: OrderedSet[str] = OrderedSet() + + def update_on_args(self, name, args, kwargs): + for arg in args: + if isinstance(arg, TritonCSEVariable): + self.mask_vars.update(arg.mask_vars) + elif isinstance(arg, sympy.Symbol) and arg.name[0] in "xyr": + # most of the time index vars don't need masks associated with them + # however, when index vars are used to compute indices for indirect reads, + # those reads should subsequently be masked. + self.mask_vars.update({f"{arg.name[0]}mask"}) + + +class TritonOverrides(OpOverrides): + """Map element-wise ops to Triton""" + + @staticmethod + def to_dtype( + x, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ): + def _get_min_elements_per_thread( + src_dtype: torch.dtype, dst_dtype: torch.dtype + ) -> int: + if src_dtype == dst_dtype: + # No data type conversion is needed. No requirements on min_elem_per_thread. + return 0 + + # fp8 data type conversions have min_elem_per_thread requirements. + # Refer to Triton implementations here: + # https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10. + fp8_dtypes = ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + # Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2. + assert not ( + src_dtype in fp8_dtypes + and dst_dtype in fp8_dtypes + and src_dtype != dst_dtype + ), "Conversions between float8_e5m2 and float8_e4m3fn are not supported!"
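+ # Hedged note (rationale inferred from the Triton lowering linked above):
+ # the fp8 conversion intrinsics appear to operate on packed vectors, with
+ # e5m2 conversions 4 elements wide and e4m3 conversions 2 elements wide,
+ # which is where the minimums returned below come from.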
+ if src_dtype == torch.float8_e5m2 or dst_dtype == torch.float8_e5m2: + return 4 + if src_dtype == torch.float8_e4m3fn or dst_dtype == torch.float8_e4m3fn: + return 2 + # No requirements on min_elem_per_thread. + return 0 + + if src_dtype is not None: + # Both dtype and src_dtype are set. This is used by torch's to(dtype=dtype). + # It takes the maximum min_elem_per_thread if there are multiple fp8 conversions + # in the same kernel. + V.kernel.min_elem_per_thread = max( + _get_min_elements_per_thread(src_dtype, dtype), + V.kernel.min_elem_per_thread, + ) + + if dtype == torch.bool: + return f"({x} != 0)" + elif dtype == torch.uint8: + # to work around llvm uint conversion semantics + # that produce 0's for negative values + return f"{x}.to(tl.int8).to(tl.uint8)" + + if use_compute_types: + out_dtype = triton_compute_type(dtype) + else: + out_dtype = triton_store_type(dtype) + + return f"{x}.to({out_dtype})" + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + triton_dtype = triton_compute_type(dtype) + # We may promote float16 or bfloat16 to float32 and cause the + # bitwidth of dtype to be different from the input tensor (i.e. float32). + # In such a case, we will have to convert the input tensor to + # its src_type, perform bitcast, and then convert the bit-casted + # tensor back to float to ensure we use values with the right precision. + if ( + src_dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + ): + triton_src_dtype = str(src_dtype).split(".")[-1] + cast_x = f"{x}.to(tl.{triton_src_dtype})" + if dtype in (torch.float16, torch.bfloat16): + triton_type_name = str(dtype).split(".")[-1] + triton_dtype = f"tl.{triton_type_name}" + cast_x = f"{cast_x}.to({triton_dtype}, bitcast=True)" + return f"{cast_x}.to(tl.float32)" + else: + src_dtype_bitwidth = _get_primitive_bitwidth(src_dtype) + target_dtype_bitwidth = _get_primitive_bitwidth(dtype) + bitcast = "True" if src_dtype_bitwidth == target_dtype_bitwidth else "False" + return f"{x}.to({triton_dtype}, bitcast={bitcast})" + + @staticmethod + def _shaped_constant(value, dtype, shape): + type_ = torch._prims_common.dtype_to_type(dtype) + triton_val = constant_repr(type_(value)) + triton_type = triton_compute_type(dtype) + + if triton_type == "tl.float32": + # Float constants are always f32 in triton + return triton_val + + # NOTE: We use a tensor here in order to get the expected type. + # Otherwise, e.g. float64 constants would be truncated to float32. + return f"tl.full({shape}, {triton_val}, {triton_type})" + + @classmethod + def constant(cls, value, dtype): + return cls._shaped_constant(value, dtype, shape=[]) + + @staticmethod + def abs(x): + return f"tl_math.abs({x})" + + @staticmethod + def libdevice_abs(x): + return f"libdevice.abs({x})" + + @staticmethod + def exp(x): + return f"tl_math.exp({x})" + + @staticmethod + def libdevice_exp(x): + return f"libdevice.exp({x})" + + @staticmethod + def exp2(x): + return f"libdevice.exp2({x})" + + @staticmethod + def expm1(x): + return f"libdevice.expm1({x})" + + @staticmethod + def sqrt(x): + return f"libdevice.sqrt({x})" + + @staticmethod + def libdevice_sqrt(x): + return f"libdevice.sqrt({x})" + + @staticmethod + def relu(x): + bug = config.triton.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!"
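+ # Hedged note (purpose inferred, not stated here): this knob deliberately
+ # miscompiles relu in one of several ways so that tooling which is supposed
+ # to catch compile errors, runtime asserts, or accuracy regressions can be
+ # exercised against a known-bad kernel.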
+ elif bug == "runtime_error": + # NB: this only triggers runtime error as long as input + # is not all zero + return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})' + elif bug == "accuracy": + return f"{x} + 1" + elif bug is None: + return ops.maximum(ops.constant(0, torch.int32), x) + else: + raise AssertionError( + f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"triton_helpers.minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"triton_helpers.maximum({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"tl.where({a}, {b}, {c})" + + @staticmethod + def inline_asm_elementwise( + *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1 + ): + triton_type = triton_compute_type(dtype) + input_refs = ", ".join([str(i) for i in inputs]) + if constraints is None: + constraints = ", ".join(["=r"] + ["r" for _ in inputs]) + return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})" # noqa: B950 + + @staticmethod + def cos(x): + return f"tl_math.cos({x})" + + @staticmethod + def libdevice_cos(x): + return f"libdevice.cos({x})" + + @staticmethod + def sin(x): + return f"tl_math.sin({x})" + + @staticmethod + def libdevice_sin(x): + return f"libdevice.sin({x})" + + @classmethod + def index_expr(cls, expr, dtype): + raise NotImplementedError("ops.index_expr not implemented outside a kernel") + + @staticmethod + def masked(mask, body, other): + raise NotImplementedError("ops.masked not implemented outside a kernel") + + @staticmethod + def lgamma(x): + return f"libdevice.lgamma({x})" + + @staticmethod + def erf(x): + return f"libdevice.erf({x})" + + @staticmethod + def cosh(x): + return f"libdevice.cosh({x})" + + @staticmethod + def sinh(x): + return f"libdevice.sinh({x})" + + @staticmethod + def acos(x): + return f"libdevice.acos({x})" + + @staticmethod + def acosh(x): + return f"libdevice.acosh({x})" + + @staticmethod + def asin(x): + return f"libdevice.asin({x})" + + @staticmethod + def asinh(x): + return f"libdevice.asinh({x})" + + @staticmethod + def atan2(x, y): + return f"libdevice.atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"libdevice.atan({x})" + + @staticmethod + def atanh(x): + return f"libdevice.atanh({x})" + + @staticmethod + def copysign(x, y): + return f"libdevice.copysign({x}, {y})" + + @staticmethod + def erfc(x): + return f"libdevice.erfc({x})" + + @staticmethod + def erfinv(x): + return f"libdevice.erfinv({x})" + + @staticmethod + def hypot(x, y): + return f"libdevice.hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"libdevice.log10({x})" + + @staticmethod + def log2(x): + return f"libdevice.log2({x})" + + @staticmethod + def nextafter(x, y): + return f"libdevice.nextafter({x}, {y})" + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"({a} ^ {b})" + + @staticmethod + def bitwise_and(a, b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod 
+ def rand(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.rand({seed}, {offset})" + + @staticmethod + def randn(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.randn({seed}, {offset})" + + @staticmethod + def randint64(seed, offset, low, high): + offset = f"({offset}).to(tl.uint32)" + return f"triton_helpers.randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def load_seed(name, offset): + raise NotImplementedError("ops.load_seed not implemented outside a kernel") + + @staticmethod + def rsqrt(x): + return f"libdevice.rsqrt({x})" + + @staticmethod + def log1p(x): + return f"libdevice.log1p({x})" + + @staticmethod + def tan(x): + return f"libdevice.tan({x})" + + @staticmethod + def tanh(x): + return f"libdevice.tanh({x})" + + @staticmethod + def sigmoid(x): + return f"tl.sigmoid({x})" + + @staticmethod + def signbit(x): + # XX: This is wrong for the value -0.0 in floating point + return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0" + + @staticmethod + def fmod(a, b): + return f"libdevice.fmod({a}, {b})" + + @staticmethod + def pow(a, b): + return f"libdevice.pow({a}, {b})" + + @staticmethod + def log(x): + return f"tl_math.log({x})" + + @staticmethod + def libdevice_log(x): + return f"libdevice.log({x})" + + @staticmethod + def isinf(x): + return f"libdevice.isinf({x}).to(tl.int1)" + + @staticmethod + def isnan(x): + return f"libdevice.isnan({x}).to(tl.int1)" + + @staticmethod + def round(x): + return f"libdevice.nearbyint({x})" + + @staticmethod + def floor(x): + return f"libdevice.floor({x})" + + @staticmethod + def floordiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. + # Similar to div_floor_kernel_cuda in pytorch core. + # Notice that // in triton behaves as truncdiv instead of floordiv + quot = f"{a} // {b}" + rem = f"{a} % {b}" + return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})" + + @staticmethod + def sign(x): + z = ops.constant(0, torch.int32) + left = ops.to_dtype((ops.lt(z, x)), torch.int8) + right = ops.to_dtype((ops.lt(x, z)), torch.int8) + sub = ops.sub(left, right) + return f"{sub}.to({x}.dtype)" + + @staticmethod + def trunc(x): + return f"libdevice.trunc({x})" + + @staticmethod + def truncdiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. + # Notice that // in triton behaves as truncdiv instead of floordiv + return f"{a} // {b}" + + @staticmethod + def ceil(x): + return f"libdevice.ceil({x})" + + +TritonOverrides._initialize_pointwise_overrides("triton") + + +# Use mypy to check protocol implemented correctly +def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]: + return h + + +class TritonKernelOverrides(TritonOverrides): + """Map element-wise ops to Triton within a TritonKernel + + Unlike TritonOverrides, these assume the code is going to be inserted into + the body of the main triton kernel and so it may use indexing and mask + variables which are assumed to already be defined in the current scope. + """ + + @classmethod + def constant(cls, value, dtype): + # NOTE: Cannot use shape=[] as it's not supported by triton-rocm + # We could use shape=[1] instead but starting with the correct + # ndim avoids extra `tt.expand_dim` ops appearing in the triton IR. 
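+ # Illustrative example (assumed shapes, not from the source): in a 2D
+ # (XBLOCK, RBLOCK) kernel, triton_tensor_ndim() is 2, so constant(2.0,
+ # torch.float64) emits tl.full([1, 1], 2.0, tl.float64), a rank-2 size-1
+ # tensor that broadcasts cleanly against block-shaped values.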
+ ndim = V.kernel.triton_tensor_ndim() + shape = [1] * ndim + return cls._shaped_constant(value, dtype, shape=shape) + + @classmethod + def index_expr(cls, expr, dtype): + indexing = V.kernel.indexing(expr, block_ptr=False) + assert isinstance(indexing, IndexingOptions) + var = V.kernel.cse.generate( + V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr) + ) + + if dtype not in (torch.int32, torch.int64): + var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype)) + var.mask_vars = indexing.mask_vars + return var + + @staticmethod + def masked(mask, body, other): + if mask is not None and torch.version.hip is not None: + mask = V.kernel.cse.generate( + V.kernel.compute, + f"{mask}.to(tl.int1)", + ) + + nodes = body.graph.find_nodes(op="output") + assert nodes, "graph for body does not contain an output" + + need_where = False + for node in nodes: + for arg in node.args: + if arg.target != "load" or should_unwrap_unspec_arg(arg.args[0]): + need_where = True + + value = None if need_where else other + with V.kernel.mask_loads(mask, value=value) as new_mask: + result = body() + + if need_where: + # Remove once CSEVariables track the dtype + if result.bounds.is_bool: + other = bool(other) + # Take dtype from result to prevent accidental promotion + other = V.kernel.cse.generate( + V.kernel.compute, + f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", + bounds=ValueRanges.wrap(other), + ) + ret = ops.where(new_mask, result, other) + else: + ret = result + + ret.mask_vars.discard(new_mask) + return ret + + @staticmethod + def load_seed(name, offset): + var = V.kernel.args.input(name) + return ( + f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})" + ) + + @staticmethod + def frexp(x): + cache_key = f"frexp({x})" + if cache_key in V.kernel.cse.cache: + return V.kernel.cse.cache[cache_key] + + mantissa = V.kernel.cse.newvar() + exponent = V.kernel.cse.newvar() + V.kernel.compute.writeline( + f"{mantissa}, {exponent} = triton_helpers.frexp({x})" + ) + V.kernel.cse.cache[cache_key] = (mantissa, exponent) + return (mantissa, exponent) + + +# Use mypy to check protocol implemented correctly +def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str]: + return h + + +class HelperFunctions: + """An ordered set of helper functions.""" + + _templates_seen: Dict[str, str] # Template code to function name + finalized_helpers: List[str] + + def __init__(self) -> None: + self._templates_seen = {} + self.finalized_helpers = [] + + def add(self, template_code: str, *, base_name="_triton_helper_fn") -> str: + """This accepts a function definition with the function name + left as a format specifier e.g. + + @triton.jit + def {name}(arg0, arg1): + return arg0 + arg1 + + We add the templated code to the function set and return the name + assigned to that function. + + """ + existing_name = self._templates_seen.get(template_code) + if existing_name is not None: + # Don't duplicate existing helpers + return existing_name + + name = f"{base_name}{len(self.finalized_helpers)}" + self._templates_seen[template_code] = name + self.finalized_helpers.append(template_code.format(name=name)) + return name + + def __iter__(self): + return iter(self.finalized_helpers) + + def __getitem__(self, idx): + return self.finalized_helpers[idx] + + +@dataclasses.dataclass +class BlockParameters: + """ + Class representing ND block dimensions, for block pointer analysis. 
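+ 
+ The four lists are parallel, one entry per dimension: shape and strides
+ describe the underlying tensor, while block_shape and offsets describe
+ the tile being accessed. They feed the identically named arguments of
+ the tl.make_block_ptr(...) call emitted by BlockPtrOptions.format().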
+ """ + + shape: List[sympy.Expr] = dataclasses.field(default_factory=list) + block_shape: List[sympy.Expr] = dataclasses.field(default_factory=list) + strides: List[sympy.Expr] = dataclasses.field(default_factory=list) + offsets: List[sympy.Expr] = dataclasses.field(default_factory=list) + + def __add__(self, other: BlockParameters) -> BlockParameters: + """ + Concatenates block parameters. + """ + cls = type(self) + a, b = tuple(dataclasses.asdict(x) for x in (self, other)) + return cls(**{key: a[key] + b[key] for key in a}) + + +class TritonKernel(SIMDKernel): + overrides = TritonKernelOverrides # type: ignore[assignment] + helper_functions: HelperFunctions + kexpr: Callable[[sympy.Expr], str] = texpr + allow_block_ptr = True + + def __init__( + self, + *groups, + index_dtype: str, + mutations: Optional[OrderedSet[str]] = None, + pid_cache=None, + reduction_hint=ReductionHint.DEFAULT, + min_elem_per_thread=0, + override_persistent_reduction=None, + optimize_mask=True, + ) -> None: + self.optimize_mask: bool = optimize_mask + super().__init__( + *groups, + index_dtype=index_dtype, + mutations=mutations, + reduction_hint=reduction_hint, + pid_cache=pid_cache, + override_persistent_reduction=override_persistent_reduction, + ) + self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment] + self.outside_loop_vars: OrderedSet[Any] = OrderedSet() + self.min_elem_per_thread = min_elem_per_thread + self.block_ptr_id = itertools.count() + self.helper_functions = HelperFunctions() + + # A set of autotuning hints to pass as part of triton_meta + self.autotune_hints: OrderedSet[AutotuneHint] = OrderedSet() + self.triton_meta: Optional[Dict[str, object]] = None + + self.codegen_range_tree() + + def _get_symt(self, tree: IterationRangesEntry) -> SymT: + prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()} + return prefix_to_symt[tree.prefix] + + def _get_block_size(self, tree: IterationRangesEntry) -> sympy.Symbol: + return block_sizes[self._get_symt(tree)] + + def _get_block_offset(self, tree: IterationRangesEntry) -> sympy.Symbol: + return block_offsets[self._get_symt(tree)] + + def _max_block_size(self, tree: IterationRangesEntry) -> int: + return TRITON_MAX_BLOCK[tree.prefix.upper()] + + def codegen_range_tree(self): + for tree in self.range_trees: + # reduction indexing goes inside a loop + if not tree.is_loop: + self.iteration_ranges_codegen_header(tree, self.body) + if self.inside_reduction and self.range_trees[-1].is_loop: + # workaround for this issue: + # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 + self.body.writeline( + f"rbase = {self.iteration_ranges_ranges_code(self.range_trees[-1])}" + ) + + def need_numel_args(self): + r""" + Indicate whether we need provide numel as arguments for the generated + kernel calls in the benchmark. + + Should be true for pointwise/reduction kernels but false for triton + matmul kernels. + """ + return True + + def should_use_persistent_reduction(self) -> bool: + """ + Heuristic to set self.persistent_reduction and add guards + if needed. + """ + if not (self.inside_reduction and config.triton.persistent_reductions): + return False + threshold = { + ReductionHint.INNER: 1024, + }.get(self.reduction_hint, 64) + + # If multi_kernel is enabled, we do more aggressive persistent reduction. + # This may result in some persistent reductions slower than the + # corresponding non-persistent reductions. MultiKernel will do benchmarking + # to pick the faster one. 
+ if config.triton.multi_kernel: + threshold *= 16 + last_numel = self.numels[-1] + return V.graph.sizevars.statically_known_leq(last_numel, threshold) # type: ignore[arg-types] + + def want_no_x_dim(self): + return ( + self.reduction_hint == ReductionHint.INNER + and self.persistent_reduction + and len(self.numels) == 2 + and V.graph.sizevars.statically_known_geq(self.numels[-1], 256) # type: ignore[arg-types] + ) + + @property + def assert_function(self) -> str: + return "tl.device_assert" + + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + ): + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + index = self.prepare_indexing(index) + index_vars = index.free_symbols + has_rindex = False + + mask_vars: OrderedSet[str] = OrderedSet() + for var in index_vars: + assert isinstance(var, sympy.Symbol) + has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX) + if override_mask: + pass + elif symbol_is_type(var, SymT.TMP): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif symbol_is_type( + var, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + SymT.INDEX, + SymT.FLOAT, + SymT.UNBACKED_FLOAT, + ), + ): + pass + else: + # var is one of xN, yN or rN + assert symbol_is_type( + var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK) + ), var.name + mask_vars.add(f"{var.name[0]}mask") + + need_dense = ( + config.triton.dense_indexing + or dense_indexing + or self._load_mask is not None + ) and index != 0 + + have_dense = True + have_loop_vars = False + dense_mask_vars: OrderedSet[str] = OrderedSet() + + for tree in self.active_range_trees(): + if index_vars.intersection(tree.var_list): + have_loop_vars = True + else: + have_dense = False + dense_mask_vars.add(f"{tree.prefix}mask") + + if ( + block_ptr + and self.allow_block_ptr + and config.triton.use_block_ptr + and not override_mask + and not self._load_mask + and len(mask_vars - dense_mask_vars) == 0 + and not self.is_indirect_indexing(index) + and have_loop_vars + # workaround https://github.com/openai/triton/issues/2821 + and self.index_dtype == "tl.int32" + ): + + def match_strided_block( + index: sympy.Expr, range_tree: IterationRangesEntry + ) -> Optional[BlockParameters]: + """ + Matches expressions of the form: + idx = s * xindex + + This implies stride (s,), and shape (XBLOCK,). + """ + symbol = range_tree.symbol() + stride = sympy.Wild("stride", exclude=[symbol]) + m = index.match(symbol * stride) + if m is None: + return None + + return BlockParameters( + shape=[range_tree.numel], + block_shape=[self._get_block_size(range_tree)], + strides=[m[stride]], + offsets=[self._get_block_offset(range_tree)], + ) + + def match_mod_div_block( + index: sympy.Expr, range_tree: IterationRangesEntry + ) -> Optional[BlockParameters]: + """ + Matches higher-dimensional blocks coming from FloorDiv and ModularIndexing. + + Example expression to match: + sN * ((rindex//(d1 * ... * d(N-1)))) + + s1 * ModularIndexing(rindex, 1, d1) + + ... + + s(N-1) * ModularIndexing(rindex, d1 * ... * d(N-2), d(N-1)) + + This iterates over a block of shape (dN, ..., d1) and stride + (sN, ..., s1). (d1,...,d(N-1)) and (s1,...,sN) are + wildcards that we match. + + Note that dN does not appear in the expression, but we solve for it + using range tree numels and the other dims. + """ + # Bound the possible number of dims. We use the following heuristics: + # - At least one dim for each range tree node. 
+ # - At least one dim for every FloorDiv or ModularIndexing op. + # - At least 2 dims to pattern match. + num_dims = max( + 2, + len(self.range_tree_nodes), + (index.count(FloorDiv) + index.count(ModularIndexing)), + ) + + # Pattern match to find the strides and offset. + index_var = range_tree.symbol() + wild = functools.partial(sympy.Wild, exclude=[index_var]) + dims: List[sympy.Expr] = [ + wild(f"dim_mod{idx}") for idx in range(num_dims) + ] + strides: List[sympy.Expr] = [ + wild(f"stride_mod{idx}") for idx in range(num_dims) + ] + + def get_slice_numels(dims: List[Any]) -> List[Any]: + """ + Compute the cumulative size of each dimension's slice. + This proceeds from the last dim up to the second. + """ + numels = [sympy.Integer(1)] + for dim in dims[:0:-1]: + numel = dim * numels[0] + numels.insert(0, numel) + return numels + + # The first dimension's index is computed by division. + # The remaining are computed by modulo. + slice_numels = get_slice_numels(dims[:num_dims]) + block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [ + ModularIndexing(index_var, numel, dim) + for dim, numel in zip(dims[1:], slice_numels[1:]) + ] + + # Calculate a linear index from block indices. + match_expr = sympy_dot(strides, block_index_exprs) + + # Pattern match. + match = index.match(match_expr) + if match is None: + return None + + # Provide default values for unmatched dims and strides. + for dim in dims[1:]: + if dim not in match: + match[dim] = sympy.Integer(1) + for stride in strides[1:]: + if stride not in match: + match[stride] = sympy.Integer(0) + + sizevars = V.graph.sizevars + + def get_match(expr: sympy.Expr) -> sympy.Expr: + return sizevars.lookup_precomputed_size(match[expr]) + + # Replace wildcards with matched expressions. + dims = [dims[0]] + [get_match(dim) for dim in dims[1:]] + strides = [get_match(stride) for stride in strides] + slice_numels = get_slice_numels(dims) + block_index_exprs = [ + sympy_subs(expr, match) for expr in block_index_exprs + ] + + # The leading dimension is not directly matched in our expression. + # We solve for it by dividing the range tree numel by the product of + # all other dimensions. We quit if they are not known to be divisible. + assert ( + dims[0] not in match + ), "Expected not to match the leading dimension!" + if not sizevars.statically_known_multiple_of( + range_tree.numel, slice_numels[0] + ): + return None + dims[0] = range_tree.numel / slice_numels[0] + + # Check for applicable iteration range sizes. + # When mapping a 1D block into an ND one, we need to know that + # the number of elements is not changed. This means the slice numels of + # the ND iteration range must evenly divide the length of the 1D block. + # There are two cases where we can guarantee this: + # 1. Numels are powers of 2. If numel == 2 ** n, and we know XBLOCK == 2 ** m, + # with n and m integers, then either numel is a multiple of XBLOCK, or numel + # is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.) + # 2. Numels are multiples of the maximum possible block size. + max_block = self._max_block_size(range_tree) + if any( + not sizevars.statically_known_multiple_of(numel, max_block) + and not sizevars.statically_known_power_of_2(numel) + for numel in slice_numels + ): + return None + + def identity(expr: sympy.Expr) -> sympy.Expr: + return expr + + # Compute the ND block shape from the linear block size. + # Use CeilDiv to round leading dimensions up to 1.
+ # Non-leading dimensions are clamped to the size of the iteration range, + # while the leading dimension can exceed this to accommodate a larger + # block size. + linear_block_size = self._get_block_size(range_tree) + block_shape: List[sympy.Expr] = [ + CeilDiv(linear_block_size, slice_numels[0]) + ] + [ + sympy.Min(CeilDiv(linear_block_size, numel), dim) + for numel, dim in zip(slice_numels[1:], dims[1:]) + ] + + # Compute block offsets from {xyzr}offset and the matched expressions. + block_offsets: List[sympy.Expr] = [ + sympy_subs(expr, {index_var: self._get_block_offset(range_tree)}) + for expr in block_index_exprs + ] + + return BlockParameters( + shape=dims, + block_shape=block_shape, + strides=strides, + offsets=block_offsets, + ) + + def match_block_pointer_subexpr( + expr: sympy.Expr, range_tree: IterationRangesEntry + ) -> Optional[BlockParameters]: + """ + Match a block indexing subexpression involving a single range tree. + """ + for match_func in ( + match_strided_block, + match_mod_div_block, + ): + match = match_func(expr, range_tree) + if match is not None: + return match + + return None + + def match_block_pointer() -> Optional[BlockPtrOptions]: + index_relative_to_xyr_index = sympy_subs( + index, {v: t.expr for v, t in self.range_tree_nodes.items()} + ) + range_trees = self.active_range_trees(reorder=True) + + # Match each range tree separately. + range_symbols = {tree.symbol() for tree in range_trees} + index_terms = sympy.Add.make_args(index_relative_to_xyr_index) + block_params = BlockParameters() + for tree in range_trees: + # Partition the index into subexpressions pertaining to each range tree. + # For example xindex * 5 + rindex * 3 is partitioned to + # (xindex * 5, rindex * 3). + symbol = tree.symbol() + subexpr = sympy.Integer(0) + sum( + expr for expr in index_terms if symbol in expr.free_symbols + ) + + # Reject mixed terms, e.g. xindex * rindex. + # NB: the zero expression is allowed, for broadcasting. + if len(range_symbols.intersection(subexpr.free_symbols)) > 1: + return None + + # Match the subexpression for this range tree. + params = match_block_pointer_subexpr(subexpr, tree) + if params is None: + return None + block_params += params + + # Collect leftover terms as a constant offset. + offset = sum( + expr + for expr in index_terms + if not range_symbols.intersection(expr.free_symbols) + ) + + # Form the block pointer. + self.filter_masks(mask_vars) + return BlockPtrOptions.create( + params=block_params, + constant_offset=offset, + range_trees=range_trees, + mask_vars=mask_vars, + ) + + # Return a block pointer, if indexing matches the pattern.
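+ # Illustrative example (assumed sizes, not from the source): for
+ #     index = 2048*xindex + 64*FloorDiv(rindex, 32) + ModularIndexing(rindex, 1, 32)
+ # the xindex term matches match_strided_block with stride (2048,), while
+ # the two rindex terms match match_mod_div_block as a 2D block of shape
+ # (rnumel/32, 32) and strides (64, 1); any leftover constant terms become
+ # the constant_offset of the block pointer.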
+ options = match_block_pointer() + if options is not None: + return options + + expand_str = None + index_str = self.index_to_str(index) + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + return IndexingOptions( + index_str, OrderedSet(), "None", expand_str, has_rindex, index + ) + + if need_dense and not have_dense: + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.broadcast_to({index_str}, {expand_str})" + mask_vars = dense_mask_vars + elif not have_loop_vars and copy_shape: + index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" + mask_vars = dense_mask_vars + + if override_mask: + mask_vars = OrderedSet([override_mask]) + + if self._load_mask: + mask_vars.add(self._load_mask) + + self.filter_masks(mask_vars) + + mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" + return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex, index) # type: ignore[arg-type] + + def codegen_block_ptr( + self, name: str, var: str, indexing: BlockPtrOptions, other="" + ) -> Tuple[str, Optional[DeferredLine], str]: + advance_block_ptr = None + check = indexing.boundary_check() + if not check: + # workaround https://github.com/openai/triton/issues/2813 + other = "" + elif other: + assert other == ", other=0.0" + other = f", boundary_check={check!r}, padding_option='zero'" + else: + other = f", boundary_check={check!r}" + if ( + self.inside_reduction + and self.range_trees[-1].is_loop + and indexing.has_rindex() + ): + block_ptr = f"block_ptr{next(self.block_ptr_id)}" + self.body.writeline( + DeferredLine( + name, f"{block_ptr} = {indexing.format(var, roffset=False)}" + ) + ) + advance_block_ptr = DeferredLine( + name, + f"{block_ptr} = tl.advance({block_ptr}, {indexing.advance_roffset()})", + ) + else: + block_ptr = indexing.format(var) + return block_ptr, advance_block_ptr, other + + def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""): + # broadcasting is not implicit for block_ptrs + value = ( + f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})" + ) + # drop any extra size=1 dimensions + block_shape = [V.kernel.index_to_str(expr) for expr in indexing.block_shape] + value = triton_reshape(value, indexing.reshape_suffix, block_shape) + # workaround https://github.com/openai/triton/issues/2814 + value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})" + return f"tl.store({block_ptr}, {value}{other})" + + def check_bounds( + self, + expr: sympy.Expr, + size: sympy.Expr, + lower: bool, + upper: bool, + ): + if not (lower or upper): + return + + assert isinstance(expr, sympy.Expr) + indexing = self.indexing(expr, block_ptr=False) + assert isinstance(indexing, IndexingOptions) + + index_str = indexing.index_str + mask_str = indexing.mask_str if indexing.has_mask() else None + size_str = texpr(self.rename_indexing(size)) if upper else None + + # expr is already wrapped + line = self.indirect_assert( + index_str, "0" if lower else None, size_str, mask_str + ) + + indirect = self.is_indirect_indexing(expr) or any( + isinstance(m, TritonCSEVariable) for m in indexing.mask_vars + ) + buffer = self.get_load_buffer(indexing) + self.cse.generate(buffer, line, assignment=False) + + def get_load_buffer(self, indexing): + if indexing.has_indirect() or indexing.has_tmpmask(): + # Masked loads must come after the mask is computed + 
return self.compute + elif ( + self.inside_reduction + and self.range_trees[-1].is_loop + and not indexing.has_rindex() + ): + # can lift a common load outside of reduction loop + # One exception is when this is an indirect_load. + return self.body + else: + return self.loads + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + indirect_indexing = self.is_indirect_indexing(index) + original_index = index + indexing = self.indexing(index, block_ptr=True) + has_rindex = indexing.has_rindex() + has_tmpmask = indexing.has_tmpmask() + + # Keep the variable in cache if we're going to reuse it. Equiv., if any of the following hold: + # 1) We are doing broadcasting + # 2) It is a non-coalesced load. The intuition is that if it's + # non-coalesced, we will likely load each element multiple times in + # practice. + # 3) It will be used later and it won't be CSE'd. Equiv., if all the following hold: + # 3.1) We are in a reduction loop + # 3.2) It's not its last use + # 3.3) This load will not be lifted to the body + # + is_coalesced = any( + i == 1 for i in self.get_strides_of_load(original_index).values() + ) + if self.is_broadcasted(original_index): + ep = ", eviction_policy='evict_last'" + elif not is_coalesced: + ep = ", eviction_policy='evict_last'" + elif self.inside_reduction and self.range_trees[-1].is_loop: + if name in self.args.inplace_buffers: + names: OrderedSet[str] = OrderedSet( + self.args.inplace_buffers[name].other_names + ) + else: + names = OrderedSet([name]) + last_use = len(names & self.last_usage) > 0 + evict_last = not last_use and (has_rindex or indirect_indexing) + if evict_last: + ep = ", eviction_policy='evict_last'" + else: + ep = ", eviction_policy='evict_first'" + else: + ep = "" + + if (has_tmpmask or has_rindex) and indexing.has_mask(): + if self._load_other: + other = f", other={constant_repr(self._load_other)}" + else: + other = ", other=0.0" + else: + other = "" + + advance_block_ptr = None + append_broadcast = None + if should_unwrap_unspec_arg(name): + line = var + else: + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing, other + ) + line = f"tl.load({block_ptr}{other}{ep})" + # add needed size=1 dimensions + block_shape = [str(dim) for dim in indexing.block_shape] + line = triton_reshape(line, block_shape, indexing.reshape_suffix) + elif isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + ({original_index}))" + append_broadcast = indexing.expand_str + else: + line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})" + + dtype = V.graph.get_dtype(name) + if ( + dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + ): + line += ".to(tl.float32)" + if dtype == torch.bool and torch.version.hip is None: + # Workaround for https://github.com/openai/triton/issues/2151 + # tl.load returns int8 when loading from pointer to int1 + # NOTE: Currently causes hangs on bool UTs for ROCm + line += ".to(tl.int1)" + + load_buffer = self.get_load_buffer(indexing) + result_var = self.cse.generate(load_buffer, line) + assert isinstance(result_var, TritonCSEVariable) + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] + + if append_broadcast: + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate(load_buffer, line) + + if advance_block_ptr: + load_buffer.writeline(advance_block_ptr) + + if not self.inside_reduction or (not indexing.has_rmask() and not
has_rindex): + self.outside_loop_vars.add(result_var) + + return result_var + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + var = self.args.output(name) + original_index = index + indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None) + + # Guard against write-after-read corruption in triton. + # See https://github.com/openai/triton/issues/1615 + # This triton bug means that a load which is broadcasted over multiple + # warps may see the result of a store that happens later in the triton + # program. The workaround is to add a barrier before storing, which + # enforces that all warps have already read the data. + is_inplace = name in self.args.inplace_buffers + is_broadcasted = self.is_broadcasted(original_index) + if is_inplace and is_broadcasted: + self.stores.writeline(DeferredLine(name, "tl.debug_barrier()")) + + advance_block_ptr = None + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing + ) + # block_ptr stores don't do implicit casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_ptr, value, other + ) + elif mode is None: + line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})" + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str}, sem='relaxed')" + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + if advance_block_ptr: + self.stores.writeline(advance_block_ptr) + + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + def bucketize( + self, + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + + # Triton performance for bucketize_binary_search is much better when the number + # of threads equals the number of elements. + # If we're trying to use a bucketize kernel, we should make sure that an + # autotuning config with num_elements_per_warp=32 exists.
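+ # Hedged summary: each element of `values` is mapped to the index of its
+ # bucket within the sorted `offsets` tensor, so the helper below does an
+ # O(log n) binary search per element rather than a linear scan.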
+ self.autotune_hints.add(AutotuneHint.ELEMENTS_PER_WARP_32) + + offsets_ptr = self.args.input(offsets_name) + block_size = self.dense_size_str() + offsets_size_str = self.index_to_str(offsets_size) + + if indexing_dtype == torch.int32: + triton_dtype = "tl.int32" + elif indexing_dtype == torch.int64: + triton_dtype = "tl.int64" + else: + raise NotImplementedError( + "Bucketize only supports indexing with int32 and int64" + ) + + result = self.cse.generate( + self.compute, + f"triton_helpers.bucketize_binary_search({values}, {offsets_ptr}, {triton_dtype}, {right}, {offsets_size_str}, {block_size})", # noqa: B950 line too long + ) + + return result + + def reduction_resize(self, value): + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})" + + sizes = [":"] * ndims + sizes[-1] = "None" + return f"{value}[{', '.join(sizes)}]" + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + assert self.inside_reduction + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix + + # Say we have + # tmp0 = ops.constant(1, torch.int64) + # tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0) + # tmp0 in the triton code is either a scalar, or single-element tensor + # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1 + # To avoid this, we broadcast to the expected shape first. + dense_size_str = self.dense_size_str() + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate( + self.compute, f"tl.broadcast_to({v}, {dense_size_str})" + ), + value, + ) + + dim: int + root_op: str + + def final_reduction(value): + use_helper = reduction_type in {"any", "max", "min", "prod"} + module = "triton_helpers" if use_helper else "tl" + if reduction_type in {"max", "min"}: + return self.reduction_resize( + f"{module}.{reduction_type}2({value}, {dim})" + ) + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})") + + def final_argreduce(buffer, result_var, value, index): + buffer.splice( + f"""\ + _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f'{result_var}_tmp')} + """ + ) + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + dim = self.triton_tensor_ndim() - 1 + acc_type = triton_acc_type(src_dtype) + result_var: Any = self.cse.newvar() + result_var.mask_vars = OrderedSet(var for var in masks if var[0] != "r") + cond = " & ".join(masks) + + def where_cond(tval, fval): + if not cond: + return tval + return TritonKernelOverrides.where(cond, tval, fval) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + + def _mask_value(value, default): + return self.cse.generate(self.compute, where_cond(value, default)) + + if isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] + else: + masked_value = _mask_value(value, default) + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = str( + self.cse.generate( + self.compute, + 
f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", + ) + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce( + self.compute, result_var, masked_value, accumulator_index + ) + elif reduction_type == "welford_reduce": + # For persistent reductions, don't bother with + # welford's algorithm since it uses more registers, and + # taking two reductions doesn't increase memory usage. + result_var = self.welford_reduce_fallback(dtype, value) + elif reduction_type == "welford_combine": + mean, m2, weight = masked_value + welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" + mean, m2, weight = (self.cse.newvar() for _ in range(3)) + self.compute.writeline(f"{mean}, {m2}, {weight} = {welford}") + + result_var = tuple( + self.cse.generate(self.compute, self.reduction_resize(var_name)) + for var_name in (mean, m2, weight) + ) + else: + result_var = self.cse.generate( + self.compute, final_reduction(masked_value) + ) + else: + accumulator = f"_{result_var}" + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + if not isinstance(default, tuple): + self.body.writeline( + f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" + ) + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = f"_{result_var}_index" + long_max = torch.iinfo(torch.int64).max + self.body.writeline( + f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)" + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index + ) + {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} + {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)} + """ + ) + final_argreduce(self.suffix, result_var, accumulator, accumulator_index) + elif is_welford_reduction(reduction_type): + accumulator = f"{result_var}_mean" + accumulator_m2 = f"{result_var}_m2" + accumulator_weight = f"{result_var}_weight" + self.body.writeline( + f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + + if reduction_type == "welford_combine": + mean, m2, weight = value + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine( + {accumulator}, {accumulator_m2}, {accumulator_weight}, + {mean}, {m2}, {weight} + ) + """ + ) + else: + assert reduction_type == "welford_reduce" + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce( + {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, roffset == 0 + ) + """ + ) + + self.compute.splice( + f"""\ + {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} + {accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)} + {accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)} + """ + ) + + result_mean = result_var + result_m2 = self.cse.newvar() + result_weight = self.cse.newvar() + self.suffix.splice( + f"""\ + {result_mean}_tmp, {result_m2}_tmp, {result_weight}_tmp = 
triton_helpers.welford( + {accumulator}, {accumulator_m2}, {accumulator_weight}, {dim} + ) + {result_mean} = {self.reduction_resize(f'{result_mean}_tmp')} + {result_m2} = {self.reduction_resize(f'{result_m2}_tmp')} + {result_weight} = {self.reduction_resize(f'{result_weight}_tmp')} + """ + ) + result_var = result_mean, result_m2, result_weight + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + self.compute.writeline( + f"{accumulator} = {where_cond(updated, accumulator)}" + ) + + if src_dtype == torch.bool: + # This is only really used for aten.any. It changes the + # final reduction of a non-persistent reduction from + # tmp5 = triton_helpers.max(_tmp5, 1)[:, None] + # to + # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1) + # which is needed because tl.reduce doesn't support tl.int1 + accumulator = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + self.suffix.writeline( + f"{result_var} = {final_reduction(accumulator)}.to({result_type})" + ) + else: + self.suffix.writeline( + f"{result_var} = {final_reduction(accumulator)}" + ) + + self.cse.reduction_cache[cache_key] = result_var + + if isinstance(result_var, tuple): + assert all(isinstance(x, TritonCSEVariable) for x in result_var) + self.outside_loop_vars |= OrderedSet(result_var) + else: + assert isinstance(result_var, TritonCSEVariable) + self.outside_loop_vars.add(result_var) + + return result_var + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + assert self.inside_reduction + self.inside_reduction = False + indexing = self.indexing(index, block_ptr=True) + self.inside_reduction = True + var = self.args.output(name) + + if isinstance(indexing, BlockPtrOptions): + self.suffix.writeline( + DeferredLine( + name, + self.codegen_block_ptr_store_line( + name, + indexing, + indexing.format(var), + value, + f", boundary_check={indexing.boundary_check()!r}", + ), + ) + ) + else: + assert isinstance(indexing, IndexingOptions) + self.suffix.writeline( + DeferredLine( + name, + f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})", + ) + ) + + def _lift_helper(self, fn, num_args) -> str: + # Lift IR function for scan operations into a triton function + # in the global namespace + helper = IndentedBuffer() + helper.writeline("@triton.jit") + args = [tuple(f"arg{i}_{n}" for n in range(num_args)) for i in range(2)] + signature = ", ".join(itertools.chain.from_iterable(args)) + helper.writeline(f"def {{name}}({signature}):") + + cse = CSE(prefix="", suffix="") + overrides = TritonOverrides(V.MockHandler()) + + # Build a name that changes depending on fn to workaround a triton bug + # where the combine_fn to reduce and scan is not hashed, and so different + # scan ops may collide in the triton cache. + # This is fixed with the latest triton pin, but not the triton-rocm pin. 
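+ # Illustrative result (assumed combine_fn, not from the source): lifting
+ # the "add" combine_fn of a cumsum produces a module-level helper roughly
+ # like
+ #     @triton.jit
+ #     def _triton_helper_fn_add0(arg0_0, arg1_0):
+ #         tmp0 = arg0_0 + arg1_0
+ #         return tmp0
+ # whose name is then passed to tl.associative_scan as the combine function.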
+ helper_name = "_triton_helper_fn" + + class CSEProxy: + def __getattr__(self, name: str) -> Callable[..., CSEVariable]: + def inner(*args, **kwargs): + nonlocal helper_name + helper_name += f"_{name}" + return cse.generate( + helper, + getattr(overrides, name)(*args, **kwargs), + ) + + return inner + + with helper.indent(), V.set_ops_handler(CSEProxy()): + outputs = fn(*args) + outputs = ", ".join(str(output) for output in outputs) + helper.writeline(f"return {outputs}") + + return self.helper_functions.add(helper.getvalue(), base_name=helper_name) + + def scan( + self, + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[ + [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...] + ], + values: Tuple[CSEVariable, ...], + ) -> Tuple[CSEVariable, ...]: + assert self.inside_reduction + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + assert not self._load_mask, "ops.scan not supported inside ops.masked" + reduction_range_prefix = self.range_trees[-1].prefix + + broadcasted_values = [] + accumulators = [] + + cse_compute = functools.partial(self.cse.generate, self.compute) + combine_helper_fn = self._lift_helper(combine_fn, len(values)) + dim = self.triton_tensor_ndim() - 1 + + for value, dtype in zip(values, dtypes): + acc_type = triton_acc_type(dtype) + cond = " & ".join(masks) + + value_dtype = self.cse.generate( + self.compute, + f"{value}.to({triton_compute_type(dtype)})", + ) + value = self.cse.generate( + self.compute, + f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})", + ) + broadcasted_values.append(value) + + acc_type = triton_acc_type(dtype) + cond = " & ".join(masks) + + if not self.persistent_reduction: + accumulator = self.cse.newvar() + reduced_size = self.dense_size_list() + reduced_size[-1] = "1" + reduced_size = f"[{', '.join(reduced_size)}]" + + default = "float('nan')" if dtype.is_floating_point else "-1" + self.body.writeline( + f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})" + ) + + accumulators.append(accumulator) + + def csv(values): + return " ".join(f"{value}," for value in values) + + def cse_multiple(line, n, masks): + cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] + if all(cache_key in self.cse.cache for cache_key in cache_keys): + return [self.cse.cache[cache_key] for cache_key in cache_keys] + result_vars = [self.cse.newvar() for _ in range(n)] + self.compute.writeline( + f"{csv(result_vars)} = {line}", + ) + for result_var, cache_key in zip(result_vars, cache_keys): + if masks: + result_var.mask_vars = masks # type: ignore[attr-defined] + self.cse.cache[cache_key] = result_var + return tuple(result_vars) + + partial_scan_vars = cse_multiple( + f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})", + len(values), + masks, + ) + + if not self.persistent_reduction: + # tl.reduce doesn't work for non-commutative operators, so instead + # of repeating the scan op as a reduction, we use sum to select the + # last scan value + partial_reduce_vars = [ + cse_compute( + f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)" + ) + for partial_scan_var in partial_scan_vars + ] + accs_next = combine_fn(tuple(accumulators), tuple(partial_reduce_vars)) + full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars) + result_vars = [ + cse_compute(f"tl.where(roffset > 0, {full_scan}, {partial_scan})") + for full_scan, partial_scan in zip(full_scan_vars, 
partial_scan_vars) + ] + for acc_next, accumulator, partial_reduce in zip( + accs_next, accumulators, partial_reduce_vars + ): + self.compute.writeline( + f"{accumulator} = tl.where(roffset > 0, {acc_next}, {partial_reduce})" + ) + else: + result_vars = partial_scan_vars + + for result_var in result_vars: + result_var.mask_vars = masks # type: ignore[attr-defined] + + return tuple(result_vars) + + def sort( + self, + dtypes: Tuple[torch.dtype, ...], + values: Tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> Tuple[CSEVariable, ...]: + assert self.inside_reduction + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + assert not self._load_mask, "ops.sort not supported inside ops.masked" + assert ( + self.persistent_reduction + ), "ops.sort is only supported in persistent reductions" + reduction_range_prefix = self.range_trees[-1].prefix + + cse_compute = functools.partial(self.cse.generate, self.compute) + dim = self.triton_tensor_ndim() - 1 + + broadcasted_values = [ + cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})") + for value in values + ] + + def csv(values): + return " ".join(f"{value}," for value in values) + + def cse_multiple(line, n, masks): + cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] + if all(cache_key in self.cse.cache for cache_key in cache_keys): + return [self.cse.cache[cache_key] for cache_key in cache_keys] + result_vars = [self.cse.newvar() for _ in range(n)] + self.compute.writeline( + f"{csv(result_vars)} = {line}", + ) + for result_var, cache_key in zip(result_vars, cache_keys): + if masks: + result_var.mask_vars = masks # type: ignore[attr-defined] + self.cse.cache[cache_key] = result_var + return tuple(result_vars) + + assert self.range_trees[-1].prefix == "r" + rnumel = "None" if self._has_constant_mask(self.range_trees[-1]) else "rnumel" + + if len(values) == 2: + line = ( + f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]}," + f" {rnumel}, {dim}, stable={stable}, descending={descending})" + ) + result_vars = cse_multiple(line, len(values), masks) + else: + raise AssertionError("Unhandled sort") + + for result_var, input_var in zip(result_vars, values): + result_var.mask_vars = masks # type: ignore[attr-defined] + result_var.bounds = input_var.bounds + + return tuple(result_vars) + + def codegen_body(self): + """ + Concat output code from index_code, loads, compute, stores, + suffix into self.body. + + For pointwise kernels, this is called just once at the end. + + For reduction kernels, this generates a loop over the reduction + axis. 
+ """ + if not ( + self.indexing_code + or self.loads + or self.stores + or self.compute + or self.suffix + ): + return + + if self.inside_reduction and self.range_trees[-1].is_loop: + self.body.writeline("for roffset in range(0, rnumel, RBLOCK):") + with self.body.indent(): + # last range tree is always reduction + self.iteration_ranges_codegen_header(self.range_trees[-1], self.body) + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + + # invalidate any caches that came from inside the reduction loop + self.cse.invalidate(self.outside_loop_vars) + self.range_trees[-1].cache_clear() + else: + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.body.splice(self.suffix) + self.indexing_code.clear() + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.suffix.clear() + + def codegen_kernel_benchmark(self, num_gb, grid=None): + result = IndentedBuffer() + argdefs, call_args, signature, _ = self.args.python_argdefs() + + result.writelines(["", "", "def get_args():"]) + with result.indent(): + name_cnt = itertools.count() + var_names = [] + for arg_name, arg_sig in zip(call_args, signature): + var_name = f"arg_{next(name_cnt)}" + buf = V.graph.try_get_buffer(arg_name) + if buf: + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + ) + elif arg_name in V.graph.constants: + # note that random seed is put in V.graph.constants + const_tensor = V.graph.constants[arg_name] + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + ) + elif isinstance(arg_sig, SizeArg): + symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) + + # Force the seed_offset to be 0 so calls to the same kernel + # using different seed offset will have the same benchmark harness. + # We can dedup kernel definitions in this case. 
+ if "seed_offset" in arg_sig.name: + symval_hint = 0 + result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.scheduler.get_current_device_or_throw() + nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes) + result.writeline( + f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)" + ) + else: + raise KeyError( + f"Don't find the buffer or const tensor for {arg_name}" + ) + var_names.append(var_name) + result.writeline(f"return {', '.join(var_names)},") + + result.writelines(["\n", "\n", "def call(args):"]) + if grid is None: + grid = [] + extra_args = [] + extra_args_str = None + for tree in self.active_range_trees(): + expr = pexpr(V.graph.sizevars.size_hint(tree.numel)) + extra_args.append(expr) + if tree.prefix != "r": + grid.append(expr) + if self.need_numel_args(): + extra_args_str = ", ".join(map(str, extra_args)) + ", " + else: + extra_args_str = "" + grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})" + else: + grid_arg = f"grid={grid}" + current_device = V.graph.scheduler.get_current_device_or_throw() + index = current_device.index + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + stream_name = f"stream{index}" + result.writeline(f"{stream_name} = get_raw_stream({index})") + result.writeline( + f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})" + ) + + # benchmark all configs + result.writelines(["\n", "\n", "def benchmark_all_configs(args):"]) + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + result.writeline( + f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})" + ) + + result.writelines(["\n", "\n", "if __name__ == '__main__':"]) + with result.indent(): + result.writeline( + "from torch._inductor.runtime.benchmarking import benchmarker" + ) + result.writeline("") + + result.writeline("args = get_args()") + result.writeline( + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)" + ) + result.writeline(f"num_gb = {num_gb}") + result.writeline("gb_per_s = num_gb / (ms / 1e3)") + result.writeline( + 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")' + ) + + return result + + def imports_for_benchmark_kernel(self): + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid + """.format( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + ) + + def _get_heuristic(self): + if self.persistent_reduction: + assert self.inside_reduction + return "persistent_reduction" + elif self.inside_reduction: + return "reduction" + return "pointwise" + + @staticmethod + def inductor_meta_common(): + inductor_meta = { + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(), + "assert_indirect_indexing": config.assert_indirect_indexing, + "autotune_local_cache": config.autotune_local_cache, + "autotune_pointwise": config.triton.autotune_pointwise, + "autotune_remote_cache": config.autotune_remote_cache, + "force_disable_caches": config.force_disable_caches, + "dynamic_scale_rblock": 
config.dynamic_scale_rblock, + "max_autotune": config.max_autotune, + "max_autotune_pointwise": config.max_autotune_pointwise, + "min_split_scan_rblock": config.triton.min_split_scan_rblock, + "spill_threshold": config.triton.spill_threshold, + "store_cubin": config.triton.store_cubin, + } + if torch.version.hip is not None: + inductor_meta["is_hip"] = True + if config.is_fbcode(): + inductor_meta["is_fbcode"] = True + if config.profile_bandwidth: + inductor_meta["profile_bandwidth"] = config.profile_bandwidth + inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex + inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output + inductor_meta[ + "profile_bandwidth_with_do_bench_using_profiling" + ] = config.profile_bandwidth_with_do_bench_using_profiling + if config.coordinate_descent_tuning: + inductor_meta[ + "coordinate_descent_tuning" + ] = config.coordinate_descent_tuning + inductor_meta[ + "coordinate_descent_search_radius" + ] = config.coordinate_descent_search_radius + inductor_meta[ + "coordinate_descent_check_all_directions" + ] = config.coordinate_descent_check_all_directions + return inductor_meta + + def codegen_kernel(self, name=None): + code = IndentedBuffer() + + size_hints = [] + for numel in self.numels: + numel_hint = V.graph.sizevars.symbolic_hint(numel) + if not isinstance(numel_hint, (int, sympy.Integer)): + # This default heuristic hint was picked carefully: it is + # large, to ensure that we don't shrink the block size (since + # if you don't have many elements, it'd be wasteful to pick a + # large block size). Since we don't know how many elements we + # might have, we should be OK with some inefficiency to make + # sure we handle the large case well. 8192 is the largest + # block size we support, so we pick that. + # + # If we have a better hint for unbacked SymInts (e.g., because + # a user told us, or we are tracking upper bounds) we could + # use that here. + size_hint = 8192 + else: + size_hint = next_power_of_2(int(numel_hint)) + size_hints.append(size_hint) + + if not self.inside_reduction: + size_hints.pop() + + heuristics = self._get_heuristic() + + if name is None: + code.splice(gen_common_triton_imports()) + + if config.benchmark_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + # maps actual expression to SizeArg if it is in sizevars replacements + for i, arg in enumerate(signature): + if isinstance(arg, SizeArg): + # mypy is unhappy about the sympy.Expr + # type for the key of the dict below + symbol = cast(sympy.Symbol, arg.expr) + if symbol in V.graph.sizevars.inv_precomputed_replacements: + signature[i] = SizeArg( + arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] + ) + + mutated_args: OrderedSet[str] = OrderedSet() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if ( + mutation in self.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in self.removed_buffers + ): + mutated_args.add(self.args.inplace_buffers[mutation].inner_name) + if mutation in self.args.output_buffers: + mutated_args.add(self.args.output_buffers[mutation]) + + # workspace arguments are mutated, but are not marked as mutations in self.mutations + # because their buffers are added during codegen, and aren't tracked during + # lowering/scheduling. So we add them as mutated_args explicitly below. 
+ # + # In the logic below, we only mark the workspaces as mutated if they are marked with + # zero_fill: that's because, if we don't expect the buffer to be pre-filled with + # zeros, then, although we still mutate the data, we don't care about those + # mutations because we don't make any assumptions about the contents of the + # workspace buffer. + for argname, arg in zip(argdefs, signature): + if isinstance(arg, WorkspaceArg) and arg.zero_fill: + mutated_args.add(argname) + + mutated_args = sorted(mutated_args) + + triton_meta_signature = signature_to_meta( + signature, size_dtype=self.index_dtype + ) + triton_meta = { + "signature": triton_meta_signature, + "device": DeviceProperties.create( + V.graph.scheduler.get_current_device_or_throw() + ), + "constants": {}, + } + + inductor_meta = { + "autotune_hints": set(self.autotune_hints), + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + "no_x_dim": self.no_x_dim, + "num_load": self.num_load, + "num_reduction": self.num_reduction, + **self.inductor_meta_common(), + } + + num_gb = None + if config.benchmark_kernel or config.profile_bandwidth: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + + for tree in self.active_range_trees(): + sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) + signature.append(sizearg) + triton_meta_signature[len(argdefs)] = signature_of( + sizearg, size_dtype=self.index_dtype + ) + argdefs.append(f"{tree.prefix}numel") + # constexpr version causes issues, see + # https://github.com/pytorch/torchdynamo/pull/1362 + # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( + # tree.numel + # ) + # argdefs.append(f"{tree.prefix}numel: tl.constexpr") + triton_meta["configs"] = [config_of(signature)] + + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. Otherwise there may be a segfault + # when launching the Inductor-compiled Triton kernel.
+ # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] + triton_meta["constants"][arg_num] = 1 # type: ignore[index] + + self.triton_meta = triton_meta + + for tree in self.range_trees: + if tree.prefix == "r" and self.persistent_reduction: + # RBLOCK for persistent_reduction is defined in codegen_static_numels + continue + if tree.tensor_dim is None: + continue + argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr") + + self.codegen_body() + + for helper in self.helper_functions: + code.writeline("") + code.splice(helper) + + if self.inside_reduction: + reduction_hint = self.reduction_hint + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if len(signature) == 4: # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + min_elem_per_thread={self.min_elem_per_thread} + ) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):" + ) + with code.indent(): + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb)) + + return code.getvalue() + + def _get_persistent_RBLOCK(self, rnumel): + rnumel = V.graph.sizevars.simplify(rnumel) + if isinstance(rnumel, (sympy.Integer, int)): + val = int(rnumel) + val = next_power_of_2(val) + else: + val = 128 + while not V.graph.sizevars.statically_known_leq(rnumel, val): + assert val <= 16 * 1024, f"Failed to find static RBLOCK for {rnumel}" + val *= 2 + return val + + def codegen_static_numels(self, code): + """ + We get a small speedup from hard coding numels if they are static. + + This code stomps on the passed-in values by writing a constant to the top of the kernel. + + In a kernel like: + def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): + + We would add + xnumel = 4096 + rnumel = 768 + + After the signature, before the kernel code, if we decided to make these static. As it's hardcoded, it becomes + a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream + knows that it's a static numel, as that you just plop a constant into the kernel.
+ """ + for tree in self.range_trees: + if tree.prefix != "r" or self.inside_reduction: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if isinstance(simplified_tree_numel, (sympy.Integer, int)): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + + if tree.prefix == "r" and self.persistent_reduction: + val = self._get_persistent_RBLOCK(tree.numel) + code.writeline(f"RBLOCK: tl.constexpr = {val}") + + if tree.prefix == "x" and self.no_x_dim: + code.writeline("XBLOCK: tl.constexpr = 1") + + def _get_grid_fn(self): + return "grid" + + def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid): + # TODO(jansel): if there are constants, we shouldn't bother passing them as args + for tree in self.range_trees: + if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr(name, tree) + + if tree.prefix != "r" or self.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + if tree.grid_dim is not None: + grid.append(expr) + + def call_kernel(self, name: str, node: Optional[IRNode] = None): + wrapper = V.graph.wrapper_code + wrapper.write_triton_header_once() + _, call_args, _, arg_types = self.args.python_argdefs() + grid: List[Any] = [] + self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid) + current_device = V.graph.scheduler.get_current_device_or_throw() + + if self.args.workspace_arg is not None: + ws = self.args.workspace_arg + wrapper.generate_workspace_allocation( + ws.nbytes, current_device, ws.zero_fill + ) + + grid = wrapper.generate_default_grid(name, grid) + wrapper.generate_kernel_call( + name, + call_args, + grid, + current_device.index, + cuda=True, + triton=True, + arg_types=arg_types, + grid_fn=self._get_grid_fn(), + triton_meta=self.triton_meta, + ) + + if self.args.workspace_arg is not None: + wrapper.writeline(wrapper.make_free_by_names(["workspace"])) + + def codegen_nan_check(self): + wrapper = V.graph.wrapper_code + _, call_args, arg_signatures, _ = self.args.python_argdefs() + for arg, arg_signature in zip(call_args, arg_signatures): + if isinstance(arg_signature, TensorArg): + if V.graph.cpp_wrapper: + if config.abi_compatible: + wrapper.writeline( + f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));' + ) + else: + wrapper.writeline(f'assert_inf_and_nan("{arg}", {arg});') + else: + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + def create_cse_var(self, *args, **kwargs): + return TritonCSEVariable(*args, **kwargs) + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}" + if entry.root.is_loop: + self.indexing_code.writeline(line) + else: + # lift non-reduction stores outside loop + self.body.writeline(line) + + def iteration_ranges_ranges_code(self, entry): + assert entry.tensor_dim is not None + size = self.indexing_size_str(entry.tensor_dim) + index_dtype = self.index_dtype + convert = f".to({index_dtype})" if index_dtype != "tl.int32" else "" + return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{convert}" + + def iteration_ranges_scalar_code(self, entry, value): + index_dtype = self.index_dtype + ndim = self.triton_tensor_ndim() + size = [1] * ndim + return f"tl.full({size}, {value}, {index_dtype})" + + def iteration_ranges_get_pid(self, entry): + assert entry.grid_dim is not 
None + key = f"tl.program_id({entry.grid_dim})" + # y_grid has a limit, so express it in terms of y and z in case of overflow. + # z grid is only exercised when max_tiles == 3 (off by default). + if ( + entry.grid_dim == 1 + and not entry.has_zdim + and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid()) + ): + # For ynumel larger than max_ygrid, we need to use zdim. + # For each z dimension, there are tl.num_programs(1) yblocks which is passed by grid(x,y,z). + # So, we need to add tl.program_id(z) * tl.num_programs(y) * YBLOCK to get the correct yoffset. + key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))" + pid = entry.pid_cache.get(key, key) + if self.index_dtype != "tl.int32": + return f"{pid}.to({self.index_dtype})" + return pid + + def _has_constant_mask(self, tree: IterationRangesRoot): + if not self.optimize_mask: + return False + if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type] + return True + # Masks are superfluous if numel is a multiple of BLOCK + # (We use the fact that BLOCK is required by triton to be a power of 2) + if tree.prefix == "r" and self.persistent_reduction: + max_block = self._get_persistent_RBLOCK(tree.numel) + elif tree.prefix == "x" and self.no_x_dim: + max_block = 1 + else: + if tree.prefix.upper() not in TRITON_MAX_BLOCK: + return False + max_block = TRITON_MAX_BLOCK[tree.prefix.upper()] + + # Optional optimization: if block divides numel exactly, we will + # never need to do a masked load to handle stragglers at the end. + # It's faster to avoid masking at all. But it is sound to always + # mask. + return V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block) + + def filter_masks(self, mask_vars): + for tree in self.range_trees: + if self._has_constant_mask(tree): + mask_vars.discard(f"{tree.prefix}mask") + + def iteration_ranges_codegen_header(self, entry, code): + x = entry.prefix + if entry.is_loop: + code.writeline(f"{entry.name} = {x}offset + {x}base") + elif entry.grid_dim is None: + # no need to "{x}offset = " + code.writeline(f"{entry.name} = {self.iteration_ranges_ranges_code(entry)}") + code.writeline(f"{x}offset = 0") + else: + if entry.tensor_dim is not None: + line = f"{x}offset + {self.iteration_ranges_ranges_code(entry)}" + else: + line = self.iteration_ranges_scalar_code(entry, f"{x}offset") + code.writelines( + [ + f"{x}offset = {self.iteration_ranges_get_pid(entry)} * {x.upper()}BLOCK", + f"{entry.name} = {line}", + ] + ) + + if self._has_constant_mask(entry): + sizes = self.dense_size_str() + code.writeline(f"{x}mask = tl.full({sizes}, True, tl.int1)") + else: + code.writeline(f"{x}mask = {entry.name} < {x}numel") + + +class TritonScheduling(SIMDScheduling): + int32_type = "tl.int32" + int64_type = "tl.int64" + kernel_type = TritonKernel + backend_features = dict.fromkeys( # dict for deterministic order + [ + BackendFeature.FOREACH, + BackendFeature.BUCKETIZE, + BackendFeature.INPLACE_BUFFERS, + BackendFeature.MASKED_SCATTER_WITH_INDEX, + BackendFeature.SCAN, + BackendFeature.TRITON_TEMPLATES, + ] + ) + if torch.version.hip is None: + backend_features.update( + dict.fromkeys( + [ + # TODO: Move this above when ROCm triton adds support for multiple inputs + BackendFeature.TUPLE_REDUCTION, + BackendFeature.SORT, + ] + ) + ) + + @classmethod + def get_backend_features(cls, device: torch.device): + return cls.backend_features + + def codegen_comment(self, node_schedule): + wrapper = V.graph.wrapper_code + origins, detailed_origins =
get_kernel_metadata(node_schedule, wrapper) + if origins: + wrapper.writeline(origins) + + if config.debug_fusion: + from torch._inductor.scheduler import ( + BaseSchedulerNode, + ForeachKernelSchedulerNode, + ) + + if not any( + isinstance(n, ForeachKernelSchedulerNode) for n in node_schedule + ): + # We should probably look at the nodes inside a foreach + # schedule node + node_names = [ + n.get_name() + for n in node_schedule + if isinstance(n, BaseSchedulerNode) + ] + wrapper.writeline( + f"{wrapper.comment} Fused node name list: {', '.join(node_names)}" + ) + + def define_kernel(self, src_code, node_schedule, kernel): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_category = get_kernel_category_by_source_code(src_code)[:3] + kernel_name = "_".join( + ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] + ) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" + + # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name + # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set + # to "triton_" to maximize caching opportunities (when unique_kernel_names = False). + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name) + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + src_code = src_code.replace("#pragma CMT", "#") + + basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") + compile_wrapper.splice(src_code, strip=True) + current_device = V.graph.scheduler.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + + # log kernel metadata for offline analysis. + # E.g. one can find all unaligned inner reductions and check if + # padding helps with the perf kernel by kernel.
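# ---------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of the patch): the naming
# above composes "triton", a 3-letter kernel category, the fused node names,
# and a running suffix. A minimal mirror of that composition, with
# hypothetical inputs:
def make_kernel_name(category: str, fused_name: str, suffix: int) -> str:
    # mirrors: "_".join(["triton", kernel_category, fused_name, next_suffix])
    return "_".join(["triton", category[:3], fused_name, str(suffix)])

assert make_kernel_name("pointwise", "fused_add_mul", 0) == "triton_poi_fused_add_mul_0"
# ---------------------------------------------------------------------------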
+ if is_metric_table_enabled("kernel_metadata"): + log_kernel_metadata(kernel_name, kernel_path, src_code) + + return kernel_name + + def benchmark_fused_nodes(self, nodes): + with preserve_rng_state(): + src_code = self.generate_kernel_code_from_nodes( + nodes, benchmark_kernel=True + ) + mod = PyCodeCache.load(src_code) + + def cache_file_path(): + assert mod.__file__ is not None + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return float(fd.read()) + return None + + def store_cache(): + path = cache_file_path() + with open(path, "w") as fd: + fd.write(str(ms)) + + log.debug( + "kernel src code for %s written to: %s", + {n.get_name() for n in nodes}, + mod.__file__, + ) + ms = load_cache() + if ms is not None: + return ms, mod.__file__ + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + + # call once to trigger the compilation + try: + call(wrapped_jit_function.clone_args(*args)[0]) + except Exception as e: + log.debug( + "Exception (%s) in compiling fused nodes %s", + e, + {n.get_name() for n in nodes}, + ) + ms = float("inf") + store_cache() + return ms, mod.__file__ + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + if launchers[0].n_spills > 0: + # skip benchmarking the kernel if there are register spills + ms = float("inf") + else: + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. + ms = benchmarker.benchmark_gpu( + lambda: call(wrapped_jit_function.clone_args(*args)[0]) + ) + + # The overhead of cloning args biases the timing toward fusing the kernel + # in the case of a mutating/in-placeable second fusion, so subtract it out. + # TODO - would be better as a hook in triton do_bench that resets + # the input values between benchmark runs + ms = ms - benchmarker.benchmark_gpu( + lambda: wrapped_jit_function.clone_args(*args) + ) + + log.debug( + "The fused kernel for %s took %.3f ms to run", + {n.get_name() for n in nodes}, + ms, + ) + store_cache() + return ms, mod.__file__ + + def benchmark_combo_kernel(self, node_list): + def cache_file_path(): + assert mod.__file__ is not None + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return tuple(float(e) for e in fd.read().split()) + return (None, None) + + def store_cache(): + path = cache_file_path() + with open(path, "w") as fd: + fd.write(str(ms) + " " + str(ms_clone)) + + total_ms, file_list = 0, [] + total_clone_ms = 0 + removed_buffers_orig = V.graph.removed_buffers + V.graph.removed_buffers = OrderedSet(removed_buffers_orig) + inplaced_to_remove_orig = V.graph.inplaced_to_remove + V.graph.inplaced_to_remove = OrderedSet(inplaced_to_remove_orig) + enable_autotune = config.combo_kernels_autotune > 0 + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 0 + kernel_code_list = self.generate_combo_kernel_code( + subkernel_nodes=node_list, + custom_part_algorithm=True, + enable_autotune=enable_autotune, + mixed_sizes=mixed_sizes, + only_gen_src_code=True, + ) + + for src_code, _, node_group in kernel_code_list: + fused_node_lists = [node.get_nodes() for node in node_group] + names = [n.get_name() for nodes in fused_node_lists for n in nodes] + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + mod = PyCodeCache.load(src_code) + + log.debug( + "kernel src code for %s written to: %s", + names, +
mod.__file__, + ) + ms, ms_clone = load_cache() + if ms is not None: + total_ms += ms + total_clone_ms += ms_clone + file_list.append(mod.__file__) + continue + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + + # call once to trigger the compilation + call(wrapped_jit_function.clone_args(*args)[0]) + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + if launchers[0].n_spills > 0: + # skip benchmarking the kernel if there are register spills + ms = ms_clone = float("inf") + else: + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. + ms = benchmarker.benchmark_gpu( + lambda: call(wrapped_jit_function.clone_args(*args)[0]) + ) + ms_clone = benchmarker.benchmark_gpu( + lambda: wrapped_jit_function.clone_args(*args)[0] + ) + + log.debug( + "The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs", + {n.get_name() for n in node_group}, + ms, + ms_clone, + ) + store_cache() + total_ms += ms + total_clone_ms += ms_clone + file_list.append(mod.__file__) + V.graph.removed_buffers = removed_buffers_orig + V.graph.inplaced_to_remove = inplaced_to_remove_orig + return total_ms, total_clone_ms, file_list diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_combo_kernel.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_combo_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..abfd1dc76c698e044cc1e417b23ad47253328829 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_combo_kernel.py @@ -0,0 +1,1123 @@ +import itertools +import logging +import textwrap +from collections import defaultdict +from dataclasses import dataclass +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + Union, +) + +from sympy import Integer, Symbol + +from torch.utils._ordered_set import OrderedSet + +from .. import config, metrics +from ..runtime.hints import DeviceProperties, ReductionHint +from ..runtime.runtime_utils import next_power_of_2 +from ..runtime.triton_heuristics import grid_combo_kernels +from ..scheduler import BaseSchedulerNode +from ..utils import Placeholder +from ..virtualized import V +from .common import ( + DeferredLine, + IndentedBuffer, + Kernel, + PythonPrinter, + SizeArg, + WorkspaceArg, +) +from .simd import SIMDScheduling +from .triton import gen_common_triton_imports, TritonKernel +from .triton_utils import config_of, signature_to_meta + + +log = logging.getLogger(__name__) +pexpr = PythonPrinter().doprint +LARGE_NUMELS = 512e5 +BLOCK_UTILIZATION = 0.8 + + +def _default_custom_combo_kernel_horizontal_partition( + nodes: List[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + kernel_map: Dict[BaseSchedulerNode, TritonKernel], + node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]], +) -> List[List[BaseSchedulerNode]]: + """Horizontally partition the given list of nodes into a list of lists of nodes where each sublist + represents a partition. Nodes in different partitions are implemented in different combo kernels. + Nodes in the same partition are likely to be implemented + in the same combo kernel, but subject to subsequent restrictions like CUDA limits for number of args. + + Input arguments: + nodes: a list of fused scheduler nodes to partition. + triton_scheduling: TritonScheduling instance. + kernel_map: a map from node to its kernel.
+ node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel). + Output: + a list of lists of nodes with each sublist representing a partition. + + The default algorithm is to partition nodes based on the following rules: + 1) nodes with the same number of block dimensions are grouped together. + 2) large pointwise nodes (numels greater than LARGE_NUMELS) are separated from other nodes. + 3) large reduce nodes are separated from other nodes. + """ + + assert len(nodes) >= 1 + + # first partition nodes based on number of block dimensions + tilings = [node_info_map[n][1] for n in nodes] + + max_dims = max(len(t) for t in tilings) + nodes_per_ndim = [] + for i in range(2, max_dims + 1): + group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i] + reduction = [ + n + for n in group_per_dim + if kernel_map[n].inside_reduction + and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim) + ] + not_reduction = [n for n in group_per_dim if n not in reduction] + # rnumel > 2048 usually means a long execution time + # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes + long_reduction = [ + n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 + ] + short_reduction = [n for n in reduction if n not in long_reduction] + if long_reduction: + log.warning( + "ComboKernels: %d long reduction nodes are separated", + len(long_reduction), + ) + large_pointwise = [ + n + for n in not_reduction + if not kernel_map[n].inside_reduction + and len(kernel_map[n].numels) == 2 + and V.graph.sizevars.size_hint(kernel_map[n].numels[0]) > LARGE_NUMELS + ] + if large_pointwise: + # TODO: benchmark the performance of combining large pointwise nodes with others + log.warning( + "ComboKernels: %d large pointwise nodes are separated", + len(large_pointwise), + ) + not_reduction = [n for n in not_reduction if n not in large_pointwise] + for node in large_pointwise: + nodes_per_ndim.append([node]) + + for g in (not_reduction, short_reduction, long_reduction): + if g: + nodes_per_ndim.append(g) + + assert sum(len(p) for p in nodes_per_ndim) == len(nodes) + return nodes_per_ndim + + +_custom_combo_kernel_horizontal_partition_algorithm: Callable[ + [ + List[BaseSchedulerNode], + SIMDScheduling, + Dict[BaseSchedulerNode, TritonKernel], + Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]], + ], + List[List[BaseSchedulerNode]], +] = _default_custom_combo_kernel_horizontal_partition + + +def set_custom_combo_kernel_horizontal_partition( + algorithm: Callable[ + [ + List[BaseSchedulerNode], + SIMDScheduling, + Dict[BaseSchedulerNode, TritonKernel], + Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]], + ], + List[List[BaseSchedulerNode]], + ] +) -> None: + """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions + are implemented in different combo kernels. Nodes in the same partition are likely to be implemented + in the same combo kernel, but subject to subsequent restrictions like CUDA limits for number of args. + + The algorithm should take a list of nodes and return a list of lists of nodes. + + The default algorithm is to partition nodes based on number of block dimensions.
+ """ + global _custom_combo_kernel_horizontal_partition_algorithm + _custom_combo_kernel_horizontal_partition_algorithm = algorithm + + +@dataclass +class PartitionState: + partitions: List[List[BaseSchedulerNode]] + cur_partition: List[BaseSchedulerNode] + cur_count: int + + def finalize(self) -> None: + if self.cur_partition: + self.partitions.append(self.cur_partition) + + +class ComboKernel(Kernel): + MAX_NUM_ARGS = 250 # number where I would no longer get triton errors + + @staticmethod + def _update_partition( + partition_state: PartitionState, + node_rw_count: int, + node_info: BaseSchedulerNode, + ) -> None: + if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS: + partition_state.partitions.append(partition_state.cur_partition) + partition_state.cur_partition = [node_info] + partition_state.cur_count = node_rw_count + else: + partition_state.cur_count += node_rw_count + partition_state.cur_partition.append(node_info) + + @staticmethod + def _base_horizontal_partition( + subkernel_nodes: List[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]], + custom_algorithm: bool, + ) -> List[List[BaseSchedulerNode]]: + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + # TODO support combination of kernels with different block dimensions + assert len(subkernel_nodes) >= 1 + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( + config.combo_kernel_allow_mixed_sizes == 1 and custom_algorithm + ) + + ndim_to_partition_state: Dict[int, PartitionState] = defaultdict( + lambda: PartitionState([], [], 0) + ) + yelem_to_partition_state: Dict[int, PartitionState] = defaultdict( + lambda: PartitionState([], [], 0) + ) + + for node in subkernel_nodes: + node_schedule, tiled_groups, numel, rnumel = node_info_map[node] + node_info = node + + read_writes = node.read_writes + read_write_count = len(read_writes.reads) + len(read_writes.writes) + + ndim = len(tiled_groups) + assert ndim >= 2, f"Combokernel not support tile {tiled_groups}" + if not mixed_sizes and ndim == 3: + y_elem = tiled_groups[0] + partition_state = yelem_to_partition_state[y_elem] + ComboKernel._update_partition( + partition_state, read_write_count, node_info + ) + else: + assert mixed_sizes or ndim <= 3, f"No mixed sizes: tile {tiled_groups}" + partition_state = ndim_to_partition_state[ndim] + ComboKernel._update_partition( + partition_state, read_write_count, node_info + ) + + all_partitions = [] + for partition_state in ndim_to_partition_state.values(): + partition_state.finalize() + all_partitions.extend(partition_state.partitions) + for partition_state in yelem_to_partition_state.values(): + partition_state.finalize() + all_partitions.extend(partition_state.partitions) + + return all_partitions + + @staticmethod + def horizontal_partition( + nodes: List[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + kernel_map: Dict[BaseSchedulerNode, TritonKernel], + node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]], + custom_algorithm: bool = False, + ) -> List[List[BaseSchedulerNode]]: + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnum) + for each subkernel node where each sublist forms a ComboKernel. 
It horizontally partitions nodes into + sublists in the following way: + 1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True + 2) then, call _base_horizontal_partition() to partition nodes into sublists, each sublist is + guaranteed to not exceed CUDA limits for number of args (read/writes) and to have the same + 2D or 1D blocking strategy. + """ + if custom_algorithm: + raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm( + nodes, triton_scheduling, kernel_map, node_info_map + ) + else: + raw_partitions = [nodes] + + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + all_partitions = [] + for raw_partition in raw_partitions: + all_partitions.extend( + ComboKernel._base_horizontal_partition( + raw_partition, triton_scheduling, node_info_map, custom_algorithm + ) + ) + return all_partitions + + class SequentialDispatch: + """ + The dispatcher which dispatches the subkernels in a sequential manner: + the blocks are first dispatched to the 1st subkernel (until it is filled), + then to the 2nd subkernel, and so on. + The class defines the methods specific to the dispatch algorithm. + Methods: + codegen_pid_range(...): codegen the pid range for each subkernel. + grid(...): codegen the grid size for launching the combo kernel. + """ + + @classmethod + def codegen_pid_range( + cls, kernel: "ComboKernel", num: int, code: IndentedBuffer + ) -> None: + if num == 0: + cls._calculate_xblocks(kernel, code) + code.splice(f"if pid < num_xblocks_{num}:") + with code.indent(): + code.splice("pid_offset = pid") + else: + code.splice(f"elif pid < num_xblocks_{num}:") + with code.indent(): + code.splice(f"pid_offset = pid - num_xblocks_{num-1}") + + @classmethod + def _calculate_xblocks( + cls, kernel: "ComboKernel", code: IndentedBuffer + ) -> None: + x_numels_list = kernel.x_numels_list + for i in range(len(x_numels_list)): + xnumels, no_x_dim = ( + (x_numels_list[i], False) + if isinstance(x_numels_list[i], str) + and cast(str, x_numels_list[i])[0] != "-" + or ( + isinstance(x_numels_list[i], int) + and cast(int, x_numels_list[i]) > 0 + ) + else (kernel.min_x_blocks_list[i], True) + ) + xblock_str = ( + f"tl.cdiv({xnumels}, XBLOCK)" if not no_x_dim else f"{xnumels}" + ) + if i == 0: + code.splice(f"num_xblocks_{i} = {xblock_str}") + else: + code.splice(f"num_xblocks_{i} = num_xblocks_{i-1} + {xblock_str}") + + @classmethod + def grid( + cls, + sub_kernel_numels: List[List[int]], + x_blocks_list: List[Union[str, int]], + dynamic_shape: bool, + ) -> Tuple[Any, ...]: + xnumel = list(x_blocks_list) + ynumel: Any = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels] + znumel: Any = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels] + + if dynamic_shape: + ynumel = None if None in ynumel else ynumel + znumel = None if None in znumel else znumel + else: + # TODO: improve 1d/2d mixed cases + ynumel = ( + None + if any(e is None for e in cast(List[Any], ynumel)) + else max(cast(Iterable[int], ynumel)) + ) + znumel = ( + None + if any(e is None for e in cast(List[Any], znumel)) + else max(cast(Iterable[int], znumel)) + ) + + numels = ( + (xnumel,) + if not ynumel + else (ynumel, xnumel) + if not znumel + else (znumel, ynumel, xnumel) + ) + return numels + + class RoundRobinDispatch: + """ + The dispatcher which dispatches the 
subkernels in a round robin manner: + the blocks are dispatched to the subkernels in an interleaved fashion so that + they execute in parallel. + The class defines the methods specific to the dispatch algorithm. + Methods: + codegen_pid_range(...): codegen the pid range for each subkernel. + grid(...): codegen the grid size for launching the combo kernel. + """ + + @classmethod + def codegen_pid_range( + cls, kernel: "ComboKernel", num: int, code: IndentedBuffer + ) -> None: + num_kernels = len(kernel.sub_kernels) + if num == 0: + cond = "if" + else: + cond = "elif" + code.splice(f"{cond} pid % {num_kernels} == {num}:") + with code.indent(): + code.splice(f"pid_offset = pid // {num_kernels}") + + @classmethod + def grid( + cls, + sub_kernel_numels: List[List[int]], + x_blocks_list: List[Union[str, int]], + dynamic_shape: bool, + ) -> Tuple[Any, ...]: + xnumel = x_blocks_list + # set no_x_dim xnumels to 0 + xnumel_x_dim = [max(e, 0) for e in xnumel] + ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels] + znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels] + + # TODO: support 1d/2d mixed cases + xnumel = ( + None + if any(e is None for e in xnumel) + else xnumel + if dynamic_shape + else max(xnumel_x_dim) # type: ignore[type-var, arg-type] + ) + ynumel = ( + None + if any(e is None for e in ynumel) + else ynumel + if dynamic_shape + else max(ynumel) # type: ignore[type-var, arg-type] + ) + znumel = ( + None + if any(e is None for e in znumel) + else znumel + if dynamic_shape + else max(znumel) # type: ignore[type-var, arg-type] + ) + + numels = ( + (xnumel,) + if not ynumel + else (ynumel, xnumel) + if not znumel + else (znumel, ynumel, xnumel) + ) + return numels + + def __init__( + self, enable_autotune: bool = False, mixed_sizes: bool = False + ) -> None: + super().__init__() + self.sub_kernels: List[TritonKernel] = [] + self.iter_vars_count = itertools.count() + self.grids: List[List[int]] = [] + self.min_x_blocks_list: List[Union[int, str]] = [] + self.x_numels_list: List[Union[int, str]] = [] + self.enable_autotune = enable_autotune + self.mixed_sizes = mixed_sizes + self.dispatch_class: Optional[ + Union[ + Type[ComboKernel.SequentialDispatch], + Type[ComboKernel.RoundRobinDispatch], + ] + ] = None + self.block_args: List[str] = [] + # the following are used when autotuning is disabled + self.block_size_1d = 1024 # Try tuning this value + self.block_size_2d = 32 + self.num_warps = 8 + self.block_size_reduce = 256 + self.dynamic_shape_args: List[str] = [] + + def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel: + sub_kernel = triton_kernel + metrics.generated_kernel_count -= 1 + sub_kernel.args = self.args + sub_kernel.iter_vars_count = self.iter_vars_count + sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids + self.sub_kernels.append(sub_kernel) + return sub_kernel + + @staticmethod + def create_triton_kernel( + *groups: Any, + index_dtype: str, + mutations: OrderedSet[str], + reduction_hint: ReductionHint, + optimize_mask: bool, + ) -> TritonKernel: + """ + Only allow optimize_mask=True when 1) sequential dispatch is used, + 2) numels except the x dimension are the same for each sub kernel.
+ """ + return TritonKernel( + *groups, + index_dtype=index_dtype, + mutations=mutations, + pid_cache={"tl.program_id(0)": "pid_offset"}, + reduction_hint=reduction_hint, + optimize_mask=optimize_mask, + ) + + def codegen_static_numels_sub_kernel( + self, code: IndentedBuffer, sub_kernel: TritonKernel, num: int + ) -> List[str]: + """ + We get a small speedup from hard coding numels if they are static. + + This code stomps on the passed-in values by writing an constant to the top of the kernel. + + In a kernel like: + def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): + + We would add + xnumel = 4096 + rnumel = 768 + + After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes + a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream + knows that its a static numel, as that you just plop a constant into the kernel. + """ + grid = [] + uniquify_block_sizes = [] + for tree in sub_kernel.range_trees: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if isinstance(simplified_tree_numel, (Integer, int)): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + else: + assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args + uniquify_block_sizes.append(f"{tree.prefix}numel") + + if tree.prefix != "r": + if isinstance(simplified_tree_numel, (Integer, int)): + grid.append(int(simplified_tree_numel)) + else: + grid.append(f"{tree.prefix}numel_{num}") + + if tree.prefix == "r" and sub_kernel.persistent_reduction: + if isinstance(simplified_tree_numel, (Integer, int)): + val = int(simplified_tree_numel) + else: + raise RuntimeError( + "Dynamic shape on reduction dimension is not supported" + ) + val = next_power_of_2(val) + code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}") + uniquify_block_sizes.append("RBLOCK") + + if tree.prefix == "x" and sub_kernel.no_x_dim: + code.writeline(f"XBLOCK_{num}: tl.constexpr = 1") + uniquify_block_sizes.append("XBLOCK") + self.grids.append(grid) + return uniquify_block_sizes + + def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None: + """ + Kernels with no_x_dim being true has no tunable XBLOCK. They have a fixed number of X blocks. + Grid calculation needs to make sure that they are assigned with enough number of blocks. 
+ """ + min_x_blocks: Union[int, str] = 0 + x_numels: Union[int, str] = 0 + for tree in sub_kernel.range_trees: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if tree.prefix == "x": + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" + if sub_kernel.no_x_dim: + min_x_blocks = x_numels + x_numels = ( + -min_x_blocks + if isinstance(x_numels, int) + else "-" + cast(str, x_numels) + ) + else: + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" + self.min_x_blocks_list.append(min_x_blocks) + self.x_numels_list.append(x_numels) + + def select_heuristics(self, sub_kernel: TritonKernel) -> Tuple[str, List[int]]: + size_hints = [ + next_power_of_2(V.graph.sizevars.size_hint(numel)) + for numel in sub_kernel.numels + ] + if sub_kernel.persistent_reduction: + assert sub_kernel.inside_reduction + heuristics = "persistent_reduction" + elif sub_kernel.inside_reduction: + heuristics = "reduction" + else: + size_hints.pop() + heuristics = "pointwise" + return heuristics, size_hints + + def select_combo_heuristics( + self, heuristics_list: List[str], size_hints_list: List[List[int]] + ) -> Tuple[str, List[int], TritonKernel]: + if not self.enable_autotune: + return "foreach", size_hints_list[0], self.sub_kernels[0] + if "reduction" in heuristics_list: + i, _ = max( + enumerate(size_hints_list), + key=lambda x: x[1][0] if heuristics_list[x[0]] == "reduction" else 0, + ) + return heuristics_list[i], size_hints_list[i], self.sub_kernels[i] + elif "pointwise" in heuristics_list: + i, _ = max( + enumerate(size_hints_list), + key=lambda x: x[1][0] if heuristics_list[x[0]] == "pointwise" else 0, + ) + # modify size_hint to avoid oom check fail (may be a false alarm) + num_pointwise = len([e for e in heuristics_list if e == "pointwise"]) + num_reduction = len([e for e in heuristics_list if e == "reduction"]) + num_persistent_reduction = len( + [e for e in heuristics_list if e == "persistent_reduction"] + ) + assert ( + num_reduction == 0 + ), "combining pointwise and reduction are not supported yet." + heuristics = ( + "pointwise_with_reduction" + if num_persistent_reduction > 0 + else "pointwise" + ) + if len(heuristics_list) - num_pointwise >= 4: + size_hints = size_hints_list[i] + size_hints[0] = min(128, size_hints[0]) + return heuristics, size_hints_list[i], self.sub_kernels[i] + else: + return heuristics_list[0], size_hints_list[0], self.sub_kernels[0] + + def get_mutated_args_sub_kernels(self) -> List[str]: + mutated_args = set() + for sub_kernel in self.sub_kernels: + for mutation in sub_kernel.mutations: + if mutation in sub_kernel.args.input_buffers: + mutated_args.add(sub_kernel.args.input_buffers[mutation]) + if ( + mutation in sub_kernel.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in sub_kernel.removed_buffers + ): + mutated_args.add( + sub_kernel.args.inplace_buffers[mutation].inner_name + ) + if mutation in sub_kernel.args.output_buffers: + mutated_args.add(sub_kernel.args.output_buffers[mutation]) + return sorted(mutated_args) + + def select_dispatch_strategy(self) -> None: + if self.dispatch_class is not None: + return + # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch + # Not mixed sizes on y dim technically is ok to use round robin as wells. 
+ if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list): + # str in min_x_blocks_list means a dynamic shape + self.dispatch_class = ComboKernel.SequentialDispatch + return + # A negative x_blocks_list element means the kernel is not tunable, + # i.e., no_x_dim = True + x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list] + total = max(x_numels_list) * len(x_numels_list) + needed = sum(x_numels_list) + if needed / total > BLOCK_UTILIZATION: + # Introduced overhead (masked blocks) is less than 20% + self.dispatch_class = ComboKernel.RoundRobinDispatch + else: + self.dispatch_class = ComboKernel.SequentialDispatch + + def jit_line( + self, + heuristics: str, + size_hints: List[int], + selected_kernel: TritonKernel, + pointwise_with_reduce: bool = False, + signature: Optional[List[Any]] = None, + ) -> str: + can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) + size_dtype = "tl.int32" if can_use_32bit else "tl.int64" + if signature is None: + _, _, signature, _ = self.args.python_argdefs() + for i, sub in enumerate(self.sub_kernels): + self.min_x_blocks_sub_kernel(sub, i) + self.select_dispatch_strategy() + triton_meta = { + "signature": signature_to_meta(signature, size_dtype=size_dtype), + "device": DeviceProperties.create( + V.graph.scheduler.get_current_device_or_throw() + ), + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + mutated_args = self.get_mutated_args_sub_kernels() + inductor_meta = { + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + **TritonKernel.inductor_meta_common(), + } + + sub_kernel = selected_kernel + if heuristics == "foreach": + heuristics_line = f""" + @triton_heuristics.foreach( + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + @triton.jit + """ + elif sub_kernel.inside_reduction: + reduction_hint = sub_kernel.reduction_hint + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + + return heuristics_line + + def codegen_blocks(self, code: IndentedBuffer) -> None: + for block in self.block_args: + assert block in [ + "XBLOCK", + "YBLOCK", + "RBLOCK", + ], f"{block} is not supported without autotuning" + if "YBLOCK" in self.block_args: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}") + code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}") + else: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}") + if "RBLOCK" in self.block_args: + code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}") + + def add_blockd_to_args(self, argdefs: List[str]) -> List[str]: + block_args = {} + block_names = {} + for num, sub_kernel in enumerate(self.sub_kernels): + # TODO: we assume all sub_kernels have the same block size + for tree in sub_kernel.range_trees: + if tree.prefix == "r" and ( + not sub_kernel.inside_reduction or sub_kernel.persistent_reduction + ): + continue + if tree.prefix == "x" and sub_kernel.no_x_dim: + continue + 
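# ---------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of the patch): across all
# sub_kernels this loop deduplicates one constexpr block-size argument per
# active tree prefix. A minimal mirror of that collection, with hypothetical
# prefixes:
def block_argdefs(prefixes):
    return [f"{p.upper()}BLOCK : tl.constexpr" for p in prefixes]

assert block_argdefs(["x", "y"]) == ["XBLOCK : tl.constexpr", "YBLOCK : tl.constexpr"]
# ---------------------------------------------------------------------------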
block_args[f"{tree.prefix.upper()}BLOCK : tl.constexpr"] = tree.prefix + block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix + if self.enable_autotune: + argdefs.extend(block_args) + self.block_args = list(block_names.keys()) + return argdefs + + def add_numel_to_args(self, argdefs: List[str], signature: List[Any]) -> List[str]: + for num, sub_kernel in enumerate(self.sub_kernels): + for tree in sub_kernel.active_range_trees(): + if not isinstance(tree.numel, (Integer, int)): + # only if it is a dynamic shape + sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel) + signature.append(sizearg) + argdefs.append(f"{tree.prefix}numel_{num}") + self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}") + return argdefs + + def add_numel_to_call_args_and_grid( + self, name: str, call_args: List[Any], arg_types: List[Any], grid: List[Any] + ) -> None: + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + if isinstance(tree.numel, (Integer, Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr( + name, tree, suffix=str(num) + ) + if tree.prefix != "r": + assert isinstance( + grid[i][num], str + ), f"Grid {grid[i][num]} should be a dynamic shape." + numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" + assert ( + grid[i][num] == numel_sign + numel_name + ), f"numel args mismatch: {grid[i][num]} vs {numel_name}" + grid[i][num] = -expr if numel_sign == "-" else expr + + if tree.prefix != "r" or sub_kernel.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + + def add_numel_to_call_args_and_grid_benchmark( + self, extra_args: List[Any], grid: Union[List[Any], Tuple[Any, ...]] + ) -> None: + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + expr = V.graph.sizevars.size_hint(tree.numel) + if tree.prefix != "r": + assert isinstance( + grid[i][num], str + ), f"Grid {grid[i][num]} should be a dynamic shape." + numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" + assert ( + grid[i][num] == numel_sign + numel_name + ), f"grid mismatch: {grid[i][num]} vs {numel_name}" + grid[i][num] = -expr if numel_sign == "-" else expr + if tree.prefix != "r" or sub_kernel.inside_reduction: + extra_args.append(expr) + + def codegen_kernel(self, name: Optional[str] = None) -> str: + # TODO: is it correct to use the first sub kernel's heuristics? 
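# ---------------------------------------------------------------------------
# Editor's aside (illustrative sketch, not part of the patch): the generated
# combo kernel maps tl.program_id(0) to a (subkernel index, pid_offset) pair.
# A plain-Python model of the two dispatchers defined above:
def round_robin(pid, num_kernels):
    # blocks are interleaved across subkernels
    return pid % num_kernels, pid // num_kernels

def sequential(pid, num_xblocks):
    # blocks fill subkernel 0 first, then subkernel 1, and so on
    start = 0
    for i, n in enumerate(num_xblocks):
        if pid < start + n:
            return i, pid - start
        start += n
    raise ValueError("pid out of range")

assert round_robin(5, 2) == (1, 2)
assert sequential(5, [4, 4]) == (1, 1)
# ---------------------------------------------------------------------------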
+ heuristics_list, size_hints_list = [], [] + for subkernel in self.sub_kernels: + h, s = self.select_heuristics(subkernel) + heuristics_list.append(h) + size_hints_list.append(s) + heuristics, size_hints, selected_kernel = self.select_combo_heuristics( + heuristics_list, size_hints_list + ) + pointwise_with_reduction, heuristics = ( + (True, "pointwise") + if heuristics == "pointwise_with_reduction" + else (False, heuristics) + ) + code = IndentedBuffer() + + code.splice(gen_common_triton_imports()) + if config.benchmark_combo_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + argdefs = self.add_numel_to_args(argdefs, signature) + argdefs = self.add_blockd_to_args(argdefs) + code.splice( + self.jit_line( + heuristics, + size_hints, + selected_kernel, + pointwise_with_reduce=pointwise_with_reduction, + signature=signature, + ) + ) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):" + ) + + with code.indent(): + code.splice("pid = tl.program_id(0)") + if not self.enable_autotune: + self.codegen_blocks(code) + + for num, sub_kernel in enumerate(self.sub_kernels): + assert self.dispatch_class is not None + self.dispatch_class.codegen_pid_range(self, num, code) + with code.indent(): + uniquify = self.codegen_static_numels_sub_kernel( + code, sub_kernel, num + ) + sub_kernel.codegen_body() + uniquified_body = self.uniquify_block_sizes( + sub_kernel.body, num, uniquify + ) + code.splice(uniquified_body) + + code.splice("else:") + with code.indent(): + code.splice("pass") + + if config.benchmark_combo_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb=0)) + + return code.getvalue() + + def codegen_kernel_benchmark( + self, num_gb: float, grid: Optional[List[Any]] = None + ) -> IndentedBuffer: + result = IndentedBuffer() + argdefs, call_args, signature, _ = self.args.python_argdefs() + + result.writelines(["", "", "def get_args():"]) + with result.indent(): + name_cnt = itertools.count() + var_names = [] + for arg_name, arg_sig in zip(call_args, signature): + var_name = f"arg_{next(name_cnt)}" + buf = V.graph.try_get_buffer(arg_name) + if buf: + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + ) + elif arg_name in V.graph.constants: + # note that random seed is put in V.graph.constants + const_tensor = V.graph.constants[arg_name] + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + ) + elif isinstance(arg_sig, SizeArg): + symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) + + # Force the seed_offset to be 0 so calls to the same kernel + # using different seed offset will have the same benchmark harness. + # We can dedup kernel definitions in this case. 
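+                    # (e.g. two harnesses that differ only in their seed
+                    # offsets become identical after zeroing, so a single
+                    # kernel definition can be reused for both)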
+ if "seed_offset" in arg_sig.name: + symval_hint = 0 + result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.scheduler.get_current_device_or_throw() + nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes) + result.writeline( + f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)" + ) + else: + raise KeyError( + f"Don't find the buffer or const tensor for {arg_name}" + ) + var_names.append(var_name) + result.writeline(f"return {', '.join(var_names)},") + + result.writelines(["\n", "\n", "def call(args):"]) + if grid is None: + assert self.dispatch_class is not None + dynamic_shape = self.dynamic_shape_args != [] + grid_tuple = self.dispatch_class.grid( + self.grids, self.x_numels_list, dynamic_shape + ) + extra_args_str = "" + extra_args: List[Any] = [] + if dynamic_shape: + self.add_numel_to_call_args_and_grid_benchmark(extra_args, grid_tuple) + # convert nested list to list of str + grid_tuple = tuple( + "[" + ", ".join(pexpr(item) for item in e) + ",]" + for e in grid_tuple + ) + extra_args_str = ", ".join(map(str, extra_args)) + ", " + min_blocks = None + else: + min_blocks = max(self.min_x_blocks_list) * len(self.sub_kernels) + grid_str = ", ".join(pexpr(item) for item in grid_tuple) + grid_extra_kwargs = ( + f"num_kernels={len(self.sub_kernels)}, " + f"min_blocks={min_blocks}, " + f"is_sequential={self.dispatch_class is self.SequentialDispatch}" + ) + grid_str = f"{grid_str}, {grid_extra_kwargs}" + grid_arg = f"{extra_args_str}grid=grid_combo_kernels({grid_str})" + else: + grid_arg = f"grid={grid}" + index = V.graph.scheduler.get_current_device_or_throw().index + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + stream_name = f"stream{index}" + result.writeline(f"{stream_name} = get_raw_stream({index})") + result.writeline( + f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})" + ) + + # benchmark all configs + result.writelines(["\n", "\n", "def benchmark_all_configs(args):"]) + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + result.writeline( + f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})" + ) + + result.writelines(["\n", "\n", "if __name__ == '__main__':"]) + with result.indent(): + result.writeline( + "from torch._inductor.runtime.benchmarking import benchmarker" + ) + result.writeline("") + + result.writeline("args = get_args()") + result.writeline( + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)" + ) + result.writeline(f"num_gb = {num_gb}") + result.writeline("gb_per_s = num_gb / (ms / 1e3)") + result.writeline( + 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")' + ) + + return result + + def imports_for_benchmark_kernel(self) -> str: + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels + """.format( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + ) + + def uniquify_block_sizes( + self, code: IndentedBuffer, num_kernel: int, uniquify: List[str] + ) -> IndentedBuffer: + if not uniquify: + return code + modified = 
IndentedBuffer(initial_indent=code._indent) + for line in code._lines: + if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]): + modified_line = line + for block in blocks: + modified_line = modified_line.replace( + block, f"{block}_{num_kernel}" + ) + modified.writeline(modified_line) + elif isinstance(line, DeferredLine) and ( + blocks := [e for e in uniquify if e in line.line] + ): + modified_line = line.line + for block in blocks: + modified_line = modified_line.replace( + block, f"{block}_{num_kernel}" + ) + new_line = DeferredLine(line.name, modified_line) + modified.writeline(new_line) + else: + modified.writeline(line) + return modified + + def call_kernel(self, code: IndentedBuffer, name: str) -> None: + _, call_args, _, arg_types = self.args.python_argdefs() + + wrapper = V.graph.wrapper_code + assert self.dispatch_class is not None + dynamic_shape = self.dynamic_shape_args != [] + grid = list( + self.dispatch_class.grid(self.grids, self.x_numels_list, dynamic_shape) + ) + num_kernels = len(self.sub_kernels) + min_blocks = ( + max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None + ) + is_sequential = self.dispatch_class is self.SequentialDispatch + if dynamic_shape: + self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid) + # convert nested list to list of str + # grid = tuple("["+", ".join(pexpr(item) for item in e)+",]" for e in grid) + if not self.enable_autotune and not dynamic_shape: + launch_grid = self.grid_no_autotune( + grid, num_kernels, cast(int, min_blocks), is_sequential + ) + V.graph.wrapper_code.generate_kernel_call( + name, + call_args, + grid=launch_grid, + arg_types=arg_types, + grid_fn="", + ) + return + # autotuning is enabled + grid = wrapper.generate_default_grid( + name, + list(grid), + grid_callable=grid_combo_kernels, + num_kernels=num_kernels, + min_blocks=min_blocks, + is_sequential=is_sequential, + default_meta=None if self.enable_autotune else self.get_default_meta(), + ) + wrapper.generate_kernel_call( + name, + call_args, + grid, + V.graph.scheduler.get_current_device_or_throw().index, + cuda=True, + triton=True, + arg_types=arg_types, + grid_fn="grid_combo_kernels", + grid_extra_kwargs=( + f"num_kernels={num_kernels}, " + f"min_blocks={min_blocks}, " + f"is_sequential={is_sequential}, " + f"default_meta={None if self.enable_autotune else self.get_default_meta()}" + ), + ) + + def grid_no_autotune( + self, + grid: Union[Tuple[Any], List[Any]], + num_kernels: int, + min_blocks: int, + is_sequential: bool, + ) -> List[int]: + meta = self.get_default_meta() + grid_func = grid_combo_kernels( + *grid, + num_kernels=num_kernels, + min_blocks=min_blocks, + is_sequential=is_sequential, + ) + return grid_func(meta) + + def get_default_meta(self) -> Dict[str, int]: + if "YBLOCK" in self.block_args: + meta = {"XBLOCK": self.block_size_2d, "YBLOCK": self.block_size_2d} + else: + meta = {"XBLOCK": self.block_size_1d} + return meta diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c8dbd516f13595449e5c7803b6f4a78174f88d52 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py @@ -0,0 +1,170 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict, List, Optional + +import sympy + +import torch + +from .. 
import config +from ..runtime.hints import instance_descriptor +from ..utils import _type_of +from ..virtualized import V +from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg + + +def should_unwrap_unspec_arg(name: str): + if V.graph.is_unspec_arg(name): + # Unwrap on all devices except CPU + if V.graph.scheduler.get_current_device_or_throw().type != "cpu": + return True + # Only unwrap on CPU if the input is not used as an output + if name not in V.graph.mutated_buffers: + return True + return False + + +def signature_of(arg: KernelArgType, *, size_dtype: str) -> str: + if isinstance(arg, TensorArg): + # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes. + # Related PR: https://github.com/openai/triton/pull/2279/ + if arg.dtype == torch.float8_e4m3fn: + tye = "*fp8e4nv" + elif arg.dtype == torch.float8_e5m2: + tye = "*fp8e5" + elif arg.dtype == torch.float8_e4m3fnuz: + tye = "*fp8e4b8" + elif arg.dtype == torch.float8_e5m2fnuz: + tye = "*fp8e5b16" + else: + tye = _type_of(arg.dtype) + if should_unwrap_unspec_arg(arg.buffer): + # had unwrapped 0d tensor as scalar + new_tye = tye.lstrip("*") + if new_tye in ["fp16", "bf16"]: + return "fp32" + else: + return new_tye + else: + return tye + if isinstance(arg, SizeArg): + if arg.expr is None: + # From triton/runtime/jit.py + # `None` is nullptr. Implicitly convert to *i8. + return "*i8" + elif isinstance(arg.expr, (float, sympy.Float)): + return "fp32" + if size_dtype == "tl.int32": + return "i32" + elif size_dtype == "tl.int64": + return "i64" + else: + raise NotImplementedError(f"unhandled size_dtype {size_dtype}") + if isinstance(arg, WorkspaceArg): + return "*i8" + raise NotImplementedError(f"unhandled {type(arg)}: {arg}") + + +def signature_to_meta( + signature: List[KernelArgType], + *, + size_dtype: str, + indices: Optional[List[int]] = None, +) -> Dict[int, str]: + if indices is None: + indices = list(range(len(signature))) + return { + i: signature_of(arg, size_dtype=size_dtype) + for i, arg in zip(indices, signature) + } + + +def is_unaligned_buffer(arg: TensorArg): + buf_name = arg.buffer + if buf_name in V.graph.graph_inputs: + # See Note: [Input Alignment handling in Inductor] + return buf_name not in V.graph.aligned_inputs + + if buf_name in V.graph.constants: + # all constants are assumed to be aligned + return False + + if V.graph.scheduler: + layout = V.graph.scheduler.get_buffer_layout(buf_name) + else: + buffer = V.graph.try_get_buffer(buf_name) + # output arg + if not buffer: + assert buf_name == V.kernel.output_node.name + layout = V.kernel.output_node.layout + else: + layout = buffer.get_layout() + + if isinstance(layout, torch._inductor.ir.NonOwningLayout): + return not layout.maybe_guard_aligned() + else: + return False + + +def config_of( + args: List[KernelArgType], + *, + indices: Optional[List[int]] = None, +) -> Any: + if indices is None: + indices = list(range(len(args))) + + def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: + """ + Roughly follow triton code here: + https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222 + """ + if isinstance(x, TensorArg): + if include_tensor: + offset_aligned = V.graph.sizevars.statically_known_multiple_of( + x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type] + ) + return offset_aligned and not is_unaligned_buffer(x) + else: + return False + if isinstance(x, SizeArg): + # TODO(voz): These are kinda redundant, if we can solve out 
statically_known_multiple_of with + # _maybe_evaluate_static... + if x.name.startswith("load_seed_offset"): + return False + if x.expr is None: + return False + if isinstance(x.expr, float): + return False + return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type] + if isinstance(x, WorkspaceArg): + return V.graph.sizevars.statically_known_multiple_of(x.nbytes, alignment) # type: ignore[arg-type] + raise NotImplementedError(f"unhandled {type(x)}: {x}") + + if config.triton.divisible_by_16: + divisible_by_16 = tuple( + i + for i, arg in zip(indices, args) + if is_aligned(arg, alignment=16, include_tensor=True) + ) + else: + divisible_by_16 = () + divisible_by_8 = tuple( + i + for i, arg in zip(indices, args) + if is_aligned(arg, alignment=8, include_tensor=False) + ) + + equal_to_1 = tuple( + i + for i, arg in zip(indices, args) + if isinstance(arg, SizeArg) + and isinstance(arg.expr, (int, sympy.Integer)) + and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type] + ) + # ids_of_folded_args is set from equal_to_1 + # and None args by the Triton compiler + ids_of_folded_args = tuple(equal_to_1) + + return instance_descriptor( + divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8 + ) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a27869212efbc096ad46b9a5400c78b1a6f234e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93ef693fcec0e185f444c8e720b2ea60fa476878b07a02b41f61ca88faf8afbd +size 100608 diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..439dd43de252b39b9199da1545ec074e858eb3c8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c39af3e78c6b66cd27dd9fce142d74333e949b90af2040c3bd715b9c9647916 +size 271836