diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/async_compile.py b/.venv/lib/python3.11/site-packages/torch/_inductor/async_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..0794ceb3eed5719b5037de6ceb1c89c30a52e1cd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/async_compile.py @@ -0,0 +1,297 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +import multiprocessing +import os +import sys +from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures.process import BrokenProcessPool +from functools import partial +from time import time +from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING + +import torch +from torch._dynamo.device_interface import get_registered_device_interfaces +from torch._inductor import config +from torch._inductor.codecache import ( + CodeCacheFuture, + CppCodeCache, + CppPythonBindingsCodeCache, + CUDACodeCache, + HalideCodeCache, + LambdaFuture, + ROCmCodeCache, + TritonCodeCache, + TritonFuture, +) +from torch._inductor.compile_worker.subproc_pool import ( + _warm_process_pool, + AnyPool, + SubprocPool, +) +from torch._inductor.compile_worker.watchdog import _async_compile_initializer +from torch._inductor.runtime.compile_tasks import ( + _set_triton_ptxas_path, + _worker_compile_triton, +) +from torch.hub import _Faketqdm, tqdm +from torch.utils._triton import has_triton_package + + +if TYPE_CHECKING: + from torch._inductor.runtime.hints import HalideMeta + +# timing metrics for time spent in the compilation +_cumulative_compile_time = 0.0 +_t0: Optional[float] = None + +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +def pre_fork_setup(): + """ + Setup that must be done prior to forking with a process pool. + """ + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + + # Computing the triton key can be slow. If we call it before fork, + # it will be cached for the forked subprocesses. + try: + from triton.compiler.compiler import triton_key + + triton_key() + except ImportError: + # Triton might not be installed or might be an old version. + pass + + +def caching_device_properties(): + for _, device_interface in get_registered_device_interfaces(): + if device_interface.is_available(): + device_interface.Worker.get_device_properties() + + +def _compile_start() -> None: + global _t0 + if _t0 is None: + _t0 = time() + + +def _compile_end() -> None: + global _cumulative_compile_time, _t0 + if _t0 is not None: + t1 = time() + _cumulative_compile_time += t1 - _t0 + _t0 = None + # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + + +# Used to keep track of all process pools invoked so far. +_pool_set: Set[AnyPool] = set() + + +def shutdown_compile_workers() -> None: + """Shut down all outstanding compile-worker pools.""" + for pool in _pool_set: + pool.shutdown() + after_fork() + + +def after_fork(): + """Reset pools to initial state without shutting them down""" + _pool_set.clear() + AsyncCompile.process_pool.cache_clear() + + +try: + os.register_at_fork(after_in_child=after_fork) +except AttributeError: + pass # register_at_fork does not exists on windows + + +class AsyncCompile: + def __init__(self) -> None: + pass + + @staticmethod + @functools.lru_cache(1) + def pool() -> ThreadPoolExecutor: + assert config.compile_threads > 1 + return ThreadPoolExecutor(config.compile_threads) + + @staticmethod + def _get_ready(): + """No-op function to help mark when the subprocess pool is ready.""" + return "ready" + + @staticmethod + @functools.lru_cache(1) + def process_pool() -> AnyPool: + assert config.compile_threads > 1 + pool: AnyPool + if config.worker_start_method == "subprocess": + # Wrapper around ProcessPoolExecutor forks in a new process we control + pool = SubprocPool(config.compile_threads) + else: + pre_fork_setup() + ctx = multiprocessing.get_context(config.worker_start_method) + pool = ProcessPoolExecutor( + config.compile_threads, + mp_context=ctx, + initializer=partial(_async_compile_initializer, os.getpid()), + ) + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers... + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + + # Set an attribute we can check to see if the pool is ready. + pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[union-attr] + _pool_set.add(pool) + return pool + + @classmethod + def warm_pool(cls) -> None: + if config.compile_threads <= 1: + return + _compile_start() + _warm_process_pool(cls.process_pool(), config.compile_threads) + _compile_end() + + @classmethod + def submit(cls, task: Callable[..., Any]) -> Any: + if config.compile_threads <= 1: + return task() + return cls.pool().submit(task) + + def _use_process_pool(self): + return ( + config.compile_threads > 1 + and self.process_pool().ready_future.done() # type: ignore[union-attr] + ) + + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + kernel_code_log.info("Triton Kernel:\n%s", source_code) + _compile_start() + _set_triton_ptxas_path() + + kernel = TritonCodeCache.load(kernel_name, source_code) + if self._use_process_pool(): + # We want to support changing these env vars after (and while) the + # process pool is running, so pass them to the subprocess to reset. + env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] + extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} + return TritonFuture( + kernel, + self.process_pool().submit( + _worker_compile_triton, + kernel._reload_in_subproc, + extra_env, + ), + ) + else: + kernel.precompile() + return kernel + + def multi_kernel(self, *args, **kwargs) -> Any: + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + # no need to call this in parallel since the sub-kernels are already parallel tasks + return MultiKernelCall(*args, **kwargs) + + def cpp(self, source_code: str): + kernel_code_log.info("CPP Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppCodeCache.load(source_code).kernel + else: + get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) + return LambdaFuture(lambda: get_result().kernel) + + def cpp_pybinding(self, argtypes: List[str], source_code: str): + kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) + else: + get_result = CppPythonBindingsCodeCache.load_pybinding_async( + argtypes, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def cuda(self, source_code, dst_file_ext): + kernel_code_log.info("CUDA Kernel:\n%s", source_code) + + def task(): + return CUDACodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def rocm(self, source_code, dst_file_ext): + kernel_code_log.info("ROCm Kernel:\n%s", source_code) + + def task(): + return ROCmCodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def halide(self, meta: HalideMeta, source_code: str): + kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) + if config.compile_threads <= 1: + return HalideCodeCache.generate_halide(meta, source_code) + else: + get_result = HalideCodeCache.generate_halide_async( + meta, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def wait(self, scope: Dict[str, Any]) -> None: + num_kernels = len( + [ + value + for key, value in scope.items() + if isinstance(value, (Future, CodeCacheFuture)) + ] + ) + pbar = tqdm( + total=num_kernels, + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) + if config.compile_threads > 1: + for key, result in scope.items(): + if config.verbose_progress and not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(key) + if isinstance(result, (Future, CodeCacheFuture)): + try: + scope[key] = result.result() + except BrokenProcessPool as e: + raise RuntimeError( + "A compilation subprocess exited unexpectedly. This " + "is likely due to a crash. To facilitate debugging, " + "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 " + "to cause compilation to occur in the main process." + ) from e + pbar.update(1) + + _compile_end() + + +if ( + os.environ.get("TORCH_TNT_IN_USE", "0") == "1" + or os.environ.get("TORCH_WARM_POOL", "1") != "1" + # The subprocess pool is only used for the Triton backend + or not has_triton_package() +): + pass +else: + AsyncCompile.warm_pool() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1476c23ad5cecd2d655cefb36904076f3f4557e3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1896f1ed41b8635a0ef13c40e23b4289e9b539d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6c11e8dff6ec0fd1eaa597d74d9bddfeb6d44b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py @@ -0,0 +1,296 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: + if context.get_value('arith_intensity') <= 52.6245059967041: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 312.0: + return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)] + else: + if context.get_value('k') <= 40.0: + return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)] + else: + return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)] + else: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)] + else: + if context.get_value('k') <= 68.0: + return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)] + else: + return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)] + else: + if context.get_value('k') <= 35.0: + if context.get_value('k') <= 18.0: + if context.get_value('m*n') <= 19505152.0: + return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)] + else: + return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)] + else: + if context.get_value('n') <= 68.0: + return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)] + else: + return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)] + else: + if context.get_value('m*n') <= 309760.0: + return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)] + else: + if context.get_value('n') <= 72.0: + return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)] + else: + return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)] + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 815360.0: + if context.get_value('k') <= 1184.0: + return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)] + else: + return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)] + else: + if context.get_value('arith_intensity') <= 187.23922729492188: + if context.get_value('mat1_stride_0') <= 198.0: + return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)] + else: + return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)] + else: + return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)] + else: + return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b31638b816d6e806c2ca5ca13805e52da89a41 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py @@ -0,0 +1,321 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: + if context.get_value('arith_intensity') <= 29.89772129058838: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 432.0: + if context.get_value('arith_intensity') <= 7.8700292110443115: + return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)] + else: + return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)] + else: + if context.get_value('k') <= 40.0: + return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)] + else: + return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)] + else: + if context.get_value('mat1_stride_0') <= 40.0: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)] + else: + return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)] + else: + if context.get_value('mat1_stride_0') <= 68.0: + return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)] + else: + return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)] + else: + if context.get_value('k') <= 18.0: + if context.get_value('m*k') <= 528.0: + return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)] + else: + if context.get_value('n') <= 80.0: + return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)] + else: + return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)] + else: + if context.get_value('k') <= 36.0: + if context.get_value('n') <= 68.0: + return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)] + else: + return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)] + else: + if context.get_value('mat2_stride_0') <= 384.0: + return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)] + else: + return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)] + else: + if context.get_value('arith_intensity') <= 56.995582580566406: + if context.get_value('n') <= 68.0: + if context.get_value('k*n') <= 4448.0: + if context.get_value('m*n') <= 29626368.0: + return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)] + else: + return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)] + else: + if context.get_value('k') <= 348.0: + return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)] + else: + return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)] + else: + if context.get_value('m') <= 3264.0: + return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)] + else: + if context.get_value('k') <= 62.5: + return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)] + else: + return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)] + else: + if context.get_value('m*n') <= 1097728.0: + return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)] + else: + if context.get_value('m*n') <= 3244032.0: + return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)] + else: + if context.get_value('n') <= 136.0: + return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)] + else: + return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..6590e9c5344c288782b4ca92557effc5830d3a5b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py @@ -0,0 +1,150 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: + if str(context.get_value('1LEQmLEQ16')) != 'True': + if context.get_value('m') <= 32.5: + if context.get_value('n') <= 6976.0: + if context.get_value('n') <= 3520.0: + if context.get_value('m*n') <= 37632.0: + return None + else: + return [(1.000, 13)] + else: + if context.get_value('m*k') <= 452352.0: + return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)] + else: + return [(0.778, 8), (0.222, 13)] + else: + if context.get_value('k*n') <= 102776832.0: + if context.get_value('n') <= 14656.0: + return [(1.000, 11)] + else: + return [(0.889, 11), (0.111, 13)] + else: + return [(1.000, 11)] + else: + if context.get_value('m*n') <= 446464.0: + if context.get_value('m*n') <= 223424.0: + if context.get_value('mat1_stride_0') <= 3968.0: + return None + else: + return None + else: + if context.get_value('m*n') <= 346112.0: + return [(0.960, 16), (0.040, 7)] + else: + return [(0.750, 16), (0.136, 14), (0.114, 7)] + else: + if str(context.get_value('33LEQmLEQ64')) != 'True': + if context.get_value('n') <= 6976.0: + return [(1.000, 14)] + else: + return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)] + else: + if context.get_value('n') <= 13888.0: + return [(0.710, 14), (0.275, 21), (0.014, 12)] + else: + return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)] + else: + if context.get_value('n') <= 3520.0: + if context.get_value('arith_intensity') <= 3.994754433631897: + if str(context.get_value('mat2_dtype')) != 'torch.uint8': + if context.get_value('m*k') <= 18944.0: + return [(0.577, 5), (0.423, 6)] + else: + return [(0.988, 5), (0.012, 6)] + else: + if context.get_value('arith_intensity') <= 2.9899919033050537: + return None + else: + return None + else: + if context.get_value('arith_intensity') <= 7.956453561782837: + if context.get_value('k*n') <= 9244032.0: + return [(0.822, 5), (0.178, 6)] + else: + return [(0.977, 5), (0.023, 0)] + else: + if context.get_value('m*k') <= 978944.0: + return [(1.000, 5)] + else: + return [(0.971, 5), (0.029, 0)] + else: + if context.get_value('n') <= 13632.0: + if context.get_value('n') <= 6976.0: + return [(1.000, 6)] + else: + if context.get_value('k') <= 3968.0: + return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)] + else: + return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)] + else: + if context.get_value('k*n') <= 39518208.0: + return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)] + else: + if context.get_value('n') <= 20800.0: + return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)] + else: + return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py new file mode 100644 index 0000000000000000000000000000000000000000..07a343fff72eee7749fd62742a73bfc03736569d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py @@ -0,0 +1,149 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: + if context.get_value('arith_intensity') <= 15.988086223602295: + if context.get_value('n') <= 25280.0: + if context.get_value('n') <= 1344.0: + if context.get_value('mat1_stride_0') <= 7808.0: + return [(0.581, 7), (0.419, 6)] + else: + if context.get_value('m*n') <= 7680.0: + return [(0.875, 0), (0.125, 6)] + else: + return [(0.833, 0), (0.167, 7)] + else: + if context.get_value('n') <= 8512.0: + if str(context.get_value('mat2_dtype')) != 'torch.int8': + return [(0.763, 6), (0.237, 7)] + else: + return [(0.725, 7), (0.275, 6)] + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)] + else: + return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)] + else: + if context.get_value('n') <= 42254.0: + if context.get_value('n') <= 33856.0: + if context.get_value('k*n') <= 68157440.0: + return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)] + else: + return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)] + else: + return [(0.659, 5), (0.341, 6)] + else: + if context.get_value('k*n') <= 326052992.0: + if context.get_value('n') <= 55232.0: + return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)] + else: + return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)] + else: + if context.get_value('n') <= 57024.0: + return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)] + else: + return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)] + else: + if context.get_value('m*n') <= 543936.0: + if str(context.get_value('17LEQmLEQ32')) != 'True': + if context.get_value('m*n') <= 262272.0: + if context.get_value('n') <= 1592.5: + return [(0.860, 0), (0.140, 9)] + else: + return None + else: + if context.get_value('m*k') <= 1294336.0: + return [(0.833, 17), (0.150, 18), (0.017, 15)] + else: + return [(0.917, 17), (0.083, 8)] + else: + if context.get_value('n') <= 12416.0: + if context.get_value('m*n') <= 43008.0: + return None + else: + return [(0.853, 14), (0.147, 9)] + else: + return [(0.625, 12), (0.375, 14)] + else: + if context.get_value('m') <= 32.5: + if context.get_value('mat2_stride_1') <= 6656.0: + if context.get_value('n') <= 69184.0: + return [(0.611, 12), (0.361, 14), (0.028, 13)] + else: + return [(1.000, 12)] + else: + if context.get_value('mat2_stride_1') <= 20864.0: + return [(1.000, 12)] + else: + return [(0.958, 12), (0.042, 9)] + else: + if context.get_value('m*n') <= 1085440.0: + if context.get_value('n') <= 9152.0: + return [(1.000, 18)] + else: + return [(0.780, 18), (0.160, 16), (0.060, 20)] + else: + if context.get_value('m') <= 67.0: + return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)] + else: + return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..b61f8a9dd1e99056864a9dddc663b090f6971214 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py @@ -0,0 +1,109 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/ +from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicRegression, +) + + +class PadMMA100(LearnedHeuristicRegression): + + def __init__(self) -> None: + pass + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_feedback(self, context: AHContext, choice: Choice) -> float: + context.context_dict[CHOICE_COL] = choice + return self.predict(context) + + def get_confidence_threshold(self) -> float: + return 1.7025303314066 + + def get_name(self) -> str: + return 'pad_mm' + + def predict(self, context: AHContext) -> float: + if str(context.get_value('choice')) != 'pad': + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 4171264.0: + if context.get_value('m*k') <= 3999308.0: + return 1.8751469764071178 + else: + if str(context.get_value('n_multiple_32')) != 'True': + return 0.9117231355626345 + else: + return 1.1607689608873861 + else: + if str(context.get_value('n_multiple_2')) != 'True': + if str(context.get_value('using_tf32')) != 'True': + return 0.7430382200435992 + else: + return 0.8531269794448678 + else: + if str(context.get_value('k_multiple_2')) != 'True': + return 0.7577181972719917 + else: + return 0.8977349440424219 + else: + if context.get_value('m*n') <= 1299712.0: + return 1.1669723418995592 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + if context.get_value('m*n') <= 55884158.0: + return 1.0262769936909601 + else: + return 1.0022677428470845 + else: + if context.get_value('m') <= 18478.0: + return 1.1127066261894312 + else: + return 1.0337740659894263 + else: + if str(context.get_value('mat1_dtype')) != 'torch.float32': + if str(context.get_value('n_multiple_2')) != 'False': + if str(context.get_value('k_multiple_2')) != 'True': + if context.get_value('mat1_stride_0') <= 561.0: + return 1.2900382135142956 + else: + return 1.5761737616057887 + else: + if context.get_value('num_dims_needs_padding') <= 1.5: + return 1.0472263310239422 + else: + return 1.1727673465762514 + else: + if context.get_value('k') <= 28238.5: + if context.get_value('k/(m*n)') <= 0.00026227018679492176: + return 1.6770542505397175 + else: + return 1.3974785435105923 + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return 1.3952699800111992 + else: + return 1.5759286511628336 + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 14119424.0: + return 0.8875772670422478 + else: + if str(context.get_value('mat2_innermost_needs_padding')) != 'True': + return 1.1467728924377265 + else: + return 1.215842963532998 + else: + if context.get_value('arith_intensity') <= 396.8774871826172: + return 0.89940161869551 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + return 0.9964328169353532 + else: + return 0.9493479238294826 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic.py new file mode 100644 index 0000000000000000000000000000000000000000..c2dc04cc20f9d38dfaf962b70095984e8b1cf49f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic.py @@ -0,0 +1,315 @@ +import json +import os +from functools import partial +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + AHOperation, + Choice, + CHOICE_COL, + Feedback, + FEEDBACK_COL, + get_metadata_str_from_log, +) +from torch._inductor.autoheuristic.learned_heuristic_controller import ( + LearnedHeuristicController, +) +from torch._inductor.ir import ChoiceCaller +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import get_gpu_shared_memory + + +class LocalFeedback: + """ + To be able to collect data for a choice, a function providing feedback given a choice has to be provided. + LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice + (see pad_mm.py, where the autotuning happens locally, for an example). + """ + + def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None: + self.feedback_fn = feedback_fn + + def __call__(self, choice: Choice) -> Feedback: + return self.feedback_fn(choice) + + +class InconsistentMetadata(Exception): + """ + Exception that is thrown when AutoHeuristic tries to log data to a file where the metadata stored in the file does + not match the metadata it would store if the file didn't exist. + """ + + +class AutoHeuristic: + """ + AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and + generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train + a heuristic (see torchgen/autoheuristic/). + """ + + collected_feedback: Dict[Choice, Feedback] + + def __init__( + self, + fallback: Callable[[], Choice], + choices: List[Choice], + feedback: Optional[LocalFeedback], + context: AHContext, + name: str, + augment_context: Optional[List[AHOperation]] = None, + precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, + ) -> None: + """ + Initializes an instance of the AutoHeuristic class. + + Args: + fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or + AutoHeuristic is in data collection mode. + choices: A list of possible choices the heuristic can make. + feedback: An instance of LocalFeedback that provides feedback for a given choice. + context: Context to store with each choice and feedback. + name: A string that identifies the heuristic. + augment_context: An optional list of AHOperation instances that augment the context. + precondition: A callable that returns a boolean indicating whether AutoHeuristic should run. + """ + self.fallback = fallback + self.choices = choices + self.feedback = feedback + self.context = context + self.name = name + self.collected_feedback = {} + self.augment_context = augment_context + self.metadata = AHMetadata( + get_gpu_shared_memory(), + torch.cuda.get_device_capability(), + self.choices, + self.name, + ) + self.precondition = precondition + + if not self.satisfies_precondition(): + return + + if torch._inductor.config.autoheuristic_log_path == "DEFAULT": + self.log_path = self.get_default_log_path() + else: + self.log_path = torch._inductor.config.autoheuristic_log_path + + if torch._inductor.config.collect_autoheuristic(self.name): + if self.feedback is not None: + for choice in self.choices: + feedback_val = self.feedback(choice) + self.save_data(choice, feedback_val) + + def satisfies_precondition(self) -> bool: + return self.precondition is None or self.precondition( + self.metadata, self.context + ) + + def get_choice(self) -> Choice: + """ + Returns the chosen option based on the value of autoheuristic_use. + If self.name is one of the comma separated strings in autoheuristic_use, + it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option. + """ + + if not self.satisfies_precondition(): + return self.fallback() + + if torch._inductor.config.use_autoheuristic(self.name): + if self.augment_context is not None: + self.context.apply_operations(self.augment_context) + controller = LearnedHeuristicController( + self.metadata, + self.context, + ) + decision = controller.get_decision() + if decision not in self.choices: + # TODO(AlnisM): We might want to allow this in the future + return self.fallback() + if decision is not None: + return decision + return self.fallback() + + def get_top_k_choices( + self, top_k: int, always_included: Optional[List[str]] = None + ) -> Optional[List[Choice]]: + if not self.satisfies_precondition(): + return None + if torch._inductor.config.use_autoheuristic(self.name): + if self.augment_context is not None: + self.context.apply_operations(self.augment_context) + controller = LearnedHeuristicController( + self.metadata, + self.context, + ) + choices = controller.get_decisions_ranked(top_k) + if choices is None: + return None + if always_included is not None: + for choice in always_included: + if choice not in choices: + choices.append(choice) + return choices + return None + + def get_collected_feedback(self, choice: Choice) -> Any: + return self.collected_feedback.get(choice, None) + + @staticmethod + def get_device_identifier() -> str: + # a heuristic might work well for one GPU, but not for another + # we store the collected data per GPU model and learn a heuristic per GPU model + + # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names + device_name = torch.cuda.get_device_name().replace(" ", "_") + return device_name + + def get_default_log_path(self) -> str: + device_name = self.get_device_identifier() + path = f"{cache_dir()}/autoheuristic/{device_name}/" + os.makedirs(path, exist_ok=True) + path += f"{self.name}.txt" + return path + + def serialize_metadata(self) -> str: + metadata_dict = self.metadata.to_dict() + ( + num_features, + cat_features, + ) = self.context.get_numerical_and_categorical_features() + metadata_dict["numerical_features"] = num_features + metadata_dict["categorical_features"] = cat_features + return json.dumps(metadata_dict) + + def save_data(self, choice: Choice, feedback_val: Feedback) -> None: + self.collected_feedback[choice] = feedback_val + log_path = self.log_path + + lines = [] + log_exists = os.path.exists(log_path) + if log_exists: + # if log already exists, make sure it is consistent + metadata = self.serialize_metadata() + existing_metadata = get_metadata_str_from_log(self.log_path) + if existing_metadata != metadata: + raise InconsistentMetadata( + "Given metadata does not match existing metadata" + ) + else: + lines.append(self.serialize_metadata()) + feature_header = self.context.get_feature_names_csv() + header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL + lines.append(header) + + line = "" + feature_values = self.context.get_feature_values_csv() + line += feature_values + "," + choice + "," + str(feedback_val) + lines.append(line) + + with open(log_path, "a") as f: + f.write("\n".join(lines) + "\n") + + +class AutoHeuristicSelectAlgorithm(AutoHeuristic): + """ + AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic + when one wants to use AutoHeuristic for kernel choice selection. + """ + + def __init__( + self, + fallback: Callable[[], Optional[ChoiceCaller]], + choices: List[ChoiceCaller], + input_nodes: List[Any], + context: AHContext, + name: str, + augment_context: Optional[List[AHOperation]] = None, + precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, + ) -> None: + """ + The arguments choices, input_nodes and name have to match the ones used in the call to + autotune_select_algorithm(), e.g. if the following call is made + autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes + have to be used here. + """ + self.input_nodes = input_nodes + self.choicestr2choice: Dict[str, ChoiceCaller] = {} + for choice in choices: + self.choicestr2choice[choice.autoheuristic_id()] = choice + choices_str = list(self.choicestr2choice.keys()) + + def fallback_str() -> str: + fallback_choice = fallback() + if fallback_choice is None: + # TODO: Find a nicer way to handle this + return "unsure" + return fallback_choice.autoheuristic_id() + + super().__init__( + fallback_str, + choices_str, + None, + context, + name, + augment_context, + precondition, + ) + + if ( + torch._inductor.config.collect_autoheuristic(self.name) + and self.satisfies_precondition() + ): + self.register_global_feedback(input_nodes, choices) + + def register_global_feedback( + self, input_nodes: List[Any], choices: List[ChoiceCaller] + ) -> None: + """ + Registers a callback in select_algorithm, which is called with the timing of each choice. + """ + + from torch._inductor.select_algorithm import ( + add_feedback_saver, + create_inputs_key, + create_precompile_key, + ) + + def store_global_feedback( + ah_inputs_key: str, + ah_precompile_key: str, + timings: Dict[ChoiceCaller, float], + name: str, + input_nodes: List[Any], + choices: List[ChoiceCaller], + ) -> None: + current_inputs_key = create_inputs_key(input_nodes) + if current_inputs_key != ah_inputs_key: + return + current_precompile_key = create_precompile_key( + name, current_inputs_key, choices + ) + if current_precompile_key != ah_precompile_key: + return + for choice, time in timings.items(): + self.save_data(choice.autoheuristic_id(), time) + + inputs_key = create_inputs_key(input_nodes) + precompile_key = create_precompile_key(self.name, inputs_key, choices) + feedback_saver = partial(store_global_feedback, inputs_key, precompile_key) + add_feedback_saver(feedback_saver) + + def get_choice_caller(self) -> Optional[ChoiceCaller]: + choice = self.get_choice() + return self.choicestr2choice.get(choice, None) + + def get_top_k_choices_caller( + self, top_k: int, always_included: Optional[List[str]] = None + ) -> Optional[List[ChoiceCaller]]: + choices = self.get_top_k_choices(top_k, always_included) + if choices is None: + return None + return [self.choicestr2choice[choice] for choice in choices] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4db054aa23f1d8b5153758fa8127d58c880764c9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py @@ -0,0 +1,339 @@ +import functools +from typing import Any, Callable, Dict, List, Tuple + +import torch + + +Feedback = float +Choice = str +Value = Any + +CHOICE_COL = "choice" +FEEDBACK_COL = "feedback" + + +class AHFeature: + """ + The context, that AutoHeuristic stores, is a list of features. AutoHeuristic needs to know whether a feature is + categorical (i.e., not a continuous variable) to learn a machine learning model. + """ + + def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None: + self.name = name + self.value = value + self.is_categorical = is_categorical + + +class AHOperation: + """ + AHOperation can be used to augment the data collected by AutoHeuristic. + One might for example store features like m, k, n, but also want to use + features like m*n, or k*n, to learn a heuristic. Instead of storing features + that can be created from the collected data, one can use AHOperation to + create new features from the collected data. + """ + + def __init__( + self, name: str, func: Callable[[Any], Value], is_categorical: bool = False + ) -> None: + self.name = name + self.func = func + self.is_categorical = is_categorical + + def apply_operation(self, data: Any) -> None: + data[self.name] = self.func(data) + + +class AHContext: + """ + This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeursitic will + store the context and the collected feedback. The context could be something like the shape of a tensor, i.e., + information that will help to learn a heuristic. + """ + + features: List[AHFeature] + context_dict: Dict[str, Value] + + def __init__(self) -> None: + self.features = [] + self.context_dict = {} + + def add_feature( + self, name: str, value: Value, is_categorical: bool = False + ) -> None: + self.features.append(AHFeature(name, value, is_categorical=is_categorical)) + self.context_dict[name] = value + + def get_numerical_and_categorical_features(self) -> Tuple[List[str], List[str]]: + numerical_features = [] + categorical_features = [] + for feature in self.features: + if feature.is_categorical: + categorical_features.append(feature.name) + else: + numerical_features.append(feature.name) + + return numerical_features, categorical_features + + def get_feature_names_csv(self) -> str: + return ",".join(feature.name for feature in self.features) + + def get_feature_values_csv(self) -> str: + return ",".join(str(feature.value) for feature in self.features) + + def get_value(self, name: str) -> Value: + return self.context_dict[name] + + def apply_operations(self, operations: List[AHOperation]) -> None: + for op in operations: + op.apply_operation(self.context_dict) + + +class AHMetadata: + def __init__( + self, + shared_memory: Any, + device_capa: Tuple[int, int], + choices: List[Choice], + name: str, + ) -> None: + # use amount of shared_memory and device_capability to identify GPU + # TODO(AlnisM): there might be a better way to do this + self.shared_memory = shared_memory + self.device_capa = device_capa + self.choices = choices + self.name = name + + def to_dict(self) -> Dict[str, Value]: + return { + "shared_memory": self.shared_memory, + "device_capa": self.device_capa, + "name": self.name, + } + + +def get_metadata_str_from_log(log_path: str) -> str: + with open(log_path, newline="") as file: + json_string = file.readline().strip() + return json_string + + +def check_minsize(context: AHContext, minsize: int) -> bool: + return ( + context.get_value("m") >= minsize + and context.get_value("k") >= minsize + and context.get_value("n") >= minsize + ) + + +def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool: + if metadata.shared_memory == 166912 and metadata.device_capa == (8, 0): + # A100 precondition + return check_minsize(context, 512) + elif metadata.shared_memory == 232448 and metadata.device_capa == (9, 0): + # H100 precondition + return check_minsize(context, 768) + return True + + +def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool: + m = context.get_value("m") + k = context.get_value("k") + n = context.get_value("n") + if m > 128 or k < 1024 or n < 1024: + return False + mat1_iscontig = context.get_value("mat1_iscontig") + mat2_iscontig = context.get_value("mat2_iscontig") + return mat1_iscontig and not mat2_iscontig + + +def get_mult_dims_ops() -> List[AHOperation]: + m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"]) + m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"]) + k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"]) + return [m_times_k_op, m_times_n_op, k_times_n_op] + + +def get_arith_intensity(data: Any) -> float: + m = data["m"] + k = data["k"] + n = data["n"] + if m == 0 or k == 0 or n == 0: + return 0.0 + return m * k * n / (m * k + k * n + m * n) + + +def pad_mm_operations() -> List[AHOperation]: + mult_dims_ops = get_mult_dims_ops() + k_div_m_times_n_op = AHOperation( + "k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"]) + ) + + def bfloat_perf_hit(data: Any) -> bool: + m = data["m"] + k = data["k"] + n = data["n"] + is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16" + return k > (m * 1024) and k > (n * 1024) and is_bfloat + + bfloat_perf_hit_op = AHOperation( + "bfloat_perf_hit", bfloat_perf_hit, is_categorical=True + ) + + arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity) + dims_need_padding_ops = get_dims_need_padding_ops() + dims_multiple_ops = get_dims_multiple_ops() + is_contig_ops = get_is_contig_ops() + + ah_operations = mult_dims_ops + [ + k_div_m_times_n_op, + bfloat_perf_hit_op, + arith_intensity_op, + ] + ah_operations.extend(dims_need_padding_ops) + ah_operations.extend(dims_multiple_ops) + ah_operations.extend(is_contig_ops) + return ah_operations + + +def between_op(data: Any, dim: str, lower: int, upper: int) -> bool: + return data[dim] >= lower and data[dim] <= upper + + +def between_ops() -> List[AHOperation]: + dims = ["m", "k", "n"] + limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)] + ah_operations = [] + for dim in dims: + for lower, upper in limits: + between_op_fn = functools.partial( + between_op, dim=dim, lower=lower, upper=upper + ) + # using 'LEQ' instead of '<=' because '<=' cannot be exported to dot + between_op_name = f"{lower}LEQ{dim}LEQ{upper}" + ah_operations.append( + AHOperation(between_op_name, between_op_fn, is_categorical=True) + ) + return ah_operations + + +def pow2_op(data: Any, dim: str, exponent: int) -> bool: + return data[dim] == 2**exponent + + +def mm_operations() -> List[AHOperation]: + mult_dims_ops = get_mult_dims_ops() + arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity) + return mult_dims_ops + [arith_intensity_op] + + +def mixed_mm_operations() -> List[AHOperation]: + return mm_operations() + between_ops() + + +def is_multiple(data: Any, dim: str, mult: int) -> bool: + return data[dim] % mult == 0 + + +def get_dims_multiple_ops() -> List[AHOperation]: + multiples = [2, 4, 8, 16, 32] + dims = ["m", "k", "n"] + dims_multiple_ops = [] + for dim in dims: + for mult in multiples: + is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult) + dims_multiple_op = AHOperation( + f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True + ) + dims_multiple_ops.append(dims_multiple_op) + return dims_multiple_ops + + +def get_dims_need_padding_ops() -> List[AHOperation]: + def mat1_innermost_needs_padding_fn(data: Any) -> bool: + mat1_stride_0 = data["mat1_stride_0"] + mat1_stride_1 = data["mat1_stride_1"] + m_padded_length = data["m_padded_length"] + k_padded_length = data["k_padded_length"] + mat1_innermost_needs_padding = False + if mat1_stride_0 == 1 and m_padded_length != 0: + mat1_innermost_needs_padding = True + if mat1_stride_1 == 1 and k_padded_length != 0: + mat1_innermost_needs_padding = True + return mat1_innermost_needs_padding + + mat1_innermost_op = AHOperation( + "mat1_innermost_needs_padding", + mat1_innermost_needs_padding_fn, + is_categorical=True, + ) + + def mat2_innermost_needs_padding_fn(data: Any) -> bool: + mat2_stride_0 = data["mat2_stride_0"] + mat2_stride_1 = data["mat2_stride_1"] + k_padded_length = data["k_padded_length"] + n_padded_length = data["n_padded_length"] + mat2_innermost_needs_padding = False + if mat2_stride_0 == 1 and k_padded_length != 0: + mat2_innermost_needs_padding = True + if mat2_stride_1 == 1 and n_padded_length != 0: + mat2_innermost_needs_padding = True + return mat2_innermost_needs_padding + + mat2_innermost_op = AHOperation( + "mat2_innermost_needs_padding", + mat2_innermost_needs_padding_fn, + is_categorical=True, + ) + + def num_dims_needs_padding_fn(data: Any) -> int: + m_padded_length = data["m_padded_length"] + k_padded_length = data["k_padded_length"] + n_padded_length = data["n_padded_length"] + num_dims_needs_padding = 0 + if m_padded_length != 0: + num_dims_needs_padding += 1 + if k_padded_length != 0: + num_dims_needs_padding += 1 + if n_padded_length != 0: + num_dims_needs_padding += 1 + return num_dims_needs_padding + + num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn) + return [mat1_innermost_op, mat2_innermost_op, num_dims_op] + + +def get_is_contig_ops() -> List[AHOperation]: + def mat1_is_contig_fn(data: Any) -> bool: + stride_0 = data["mat1_stride_0"] + stride_1 = data["mat1_stride_1"] + k = data["k"] + return stride_0 == k and stride_1 == 1 + + mat1_is_contig_op = AHOperation( + "mat1_iscontig", mat1_is_contig_fn, is_categorical=True + ) + + def mat2_is_contig_fn(data: Any) -> bool: + stride_0 = data["mat2_stride_0"] + stride_1 = data["mat2_stride_1"] + n = data["n"] + return stride_0 == n and stride_1 == 1 + + mat2_is_contig_op = AHOperation( + "mat2_iscontig", mat2_is_contig_fn, is_categorical=True + ) + + return [mat1_is_contig_op, mat2_is_contig_op] + + +def context_add_strides(context: AHContext, name: str, stride: Tuple[int, ...]) -> None: + for i, s in enumerate(stride): + context.add_feature(f"{name}_stride_{i}", s) + + +def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None: + using_tf32 = "not_float_32" + if dtype == torch.float32: + using_tf32 = torch.backends.cuda.matmul.allow_tf32 + context.add_feature("using_tf32", using_tf32, is_categorical=True) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..23663c148df192b85178aa62d685ac0b332f9256 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py @@ -0,0 +1,119 @@ +import importlib +import inspect +import pkgutil +from collections import defaultdict +from typing import Any, Dict, List, Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic + + +def find_and_instantiate_subclasses( + package_name: str, base_class: Any +) -> List[LearnedHeuristic]: + instances = [] + + package = importlib.import_module(package_name) + for _, module_name, _ in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + try: + module_basename = module_name.split(".")[-1] + if not module_basename.startswith("_"): + # learned heuristics start with an underscore + continue + module = importlib.import_module(module_name) + + # look for classes that are subclasses of base_class + for name, obj in inspect.getmembers(module): + if ( + inspect.isclass(obj) + and issubclass(obj, base_class) + and obj != base_class + ): + instance = obj() + instances.append(instance) + except Exception as e: + print(f"Error processing module {module_name}: {e}") + + return instances + + +class LearnedHeuristicController: + """ + Class that finds and instantiates all learned heuristics. It also provides + a way to get the decision of a learned heuristic. + """ + + existing_heuristics: Dict[str, List[LearnedHeuristic]] = defaultdict(list) + """ + A dictionary that stores all the learned heuristics for each optimization. + The key is the optimization name, and the value is a list of LearnedHeuristic objects. + """ + + heuristics_initialized: bool = False + """ + A flag that indicates whether the learned heuristics have been initialized. + Set to true when the get_decision() function is called for the first time. + """ + + def __init__( + self, + metadata: AHMetadata, + context: AHContext, + ) -> None: + self.metadata = metadata + self.context = context + + def get_heuristics(self, name: str) -> List[LearnedHeuristic]: + """ + Returns a list of learned heuristics for the given optimization name. + """ + + if not LearnedHeuristicController.heuristics_initialized: + # learned heuristics are generated into the following package + learned_heuristics_package = "torch._inductor.autoheuristic.artifacts" + + # learned heuristics have to be of type LearnedHeuristic + base_class = LearnedHeuristic + found_heuristics = find_and_instantiate_subclasses( + learned_heuristics_package, base_class + ) + + for learned_heuristic in found_heuristics: + opt_name = learned_heuristic.get_name() + LearnedHeuristicController.existing_heuristics[opt_name].append( + learned_heuristic + ) + LearnedHeuristicController.heuristics_initialized = True + + return LearnedHeuristicController.existing_heuristics[name] + + def get_decision(self) -> Optional[Choice]: + """ + Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure + which choice to make. + """ + + heuristics = self.get_heuristics(self.metadata.name) + for heuristic in heuristics: + if heuristic.check_precondition(self.metadata, self.context): + return heuristic.get_decision(self.context, self.metadata.choices) + return None + + def get_decisions_ranked(self, top_k: int) -> Optional[List[Choice]]: + heuristics = self.get_heuristics(self.metadata.name) + for heuristic in heuristics: + if heuristic.check_precondition(self.metadata, self.context): + choices = heuristic.get_decisions_ranked(self.context) + if choices is None: + return None + avail_choices = [ + choice for choice in choices if choice in self.metadata.choices + ] + return avail_choices[:top_k] + return None diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa7671f0223e1bea0b2e34fbdb50053c0db6371 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py @@ -0,0 +1,92 @@ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) + + +class LearnedHeuristic: + """ + LearnedHeuristic is a base class for all learned heuristics. + """ + + def __init__(self) -> None: + pass + + def check_precondition( + self, + metadata: AHMetadata, + context: AHContext, + ) -> bool: + return True + + def get_decision( + self, context: AHContext, choices: List[Choice] + ) -> Optional[Choice]: + return None + + def get_confidence_threshold(self) -> float: + return 1.0 + + def get_name(self) -> str: + return "" + + def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]: + return None + + +class LearnedHeuristicRegression(LearnedHeuristic): + def __init__(self) -> None: + super().__init__() + + def get_feedback(self, context: AHContext, choice: Choice) -> float: + return 1.0 + + def get_decision( + self, context: AHContext, choices: List[Choice] + ) -> Optional[Choice]: + choice2feedback = {} + for choice in choices: + predicted_feedback = self.get_feedback(context, choice) + choice2feedback[choice] = predicted_feedback + sorted_choices_feedback = sorted(choice2feedback.items(), key=lambda t: t[1]) + highest_feedback = sorted_choices_feedback[-1][1] + second_highest_feedback = sorted_choices_feedback[-2][1] + if highest_feedback / second_highest_feedback > self.get_confidence_threshold(): + return sorted_choices_feedback[-1][0] + # We are not sure which choice is the best one + return None + + +class LearnedHeuristicDecision(LearnedHeuristic): + def __init__(self) -> None: + super().__init__() + + def get_choice(self, idx: int) -> Optional[str]: + return None + + def get_decision( + self, context: AHContext, choices: List[Choice] + ) -> Optional[Choice]: + best_choices = self.get_best_choices(context) + if not best_choices: + return None + (best_choice_proba, best_choice_idx) = best_choices[0] + if best_choice_proba <= self.get_confidence_threshold(): + return None + return self.get_choice(best_choice_idx) + + def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]: + feedback_idx_list = self.get_best_choices(context) + if feedback_idx_list is None: + return None + choices = [ + self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list + ] + choices = [choice for choice in choices if choice is not None] + return choices + + def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: + return [] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py b/.venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py new file mode 100644 index 0000000000000000000000000000000000000000..eea4f8d6573d82c5d629382893920c8256953c6a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py @@ -0,0 +1,876 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import ctypes +import dataclasses +import functools +import logging +import os +import queue +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +from ctypes import byref, c_size_t, c_void_p, CDLL +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + TYPE_CHECKING, + Union, +) + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch import multiprocessing +from torch._dynamo.testing import rand_strided +from torch._inductor import ir +from torch._inductor.codecache import ( + CppCodeCache, + CUDACodeCache, + DLLWrapper, + get_hash, + PyCodeCache, +) + + +if TYPE_CHECKING: + from multiprocessing.process import BaseProcess + from multiprocessing.queues import Queue + from types import ModuleType + + from torch._inductor.select_algorithm import TritonTemplateCaller + +from . import config +from .runtime.benchmarking import benchmarker +from .virtualized import V + + +CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" +EXIT_HANDLER_REGISTERED = False + +log = logging.getLogger(__name__) + + +# Used to synchronize between parent and child processes +class Ping: + pass + + +class Pong: + pass + + +class NonzeroWorkspaceNotSupportedError(Exception): + pass + + +@contextlib.contextmanager +def set_cuda_visible_device(device: Optional[int]): + """ + Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the + specified single device. If device is None, don't manipulate the environment. + """ + if device is None: + yield + return + + current = os.environ.get(CUDA_VISIBLE_DEVICES) + os.environ[CUDA_VISIBLE_DEVICES] = str(device) + try: + yield + finally: + if current is None: + del os.environ[CUDA_VISIBLE_DEVICES] + else: + os.environ[CUDA_VISIBLE_DEVICES] = current + + +@dataclasses.dataclass +class TuningProcess: + """ + Abstraction for launching a helper process to benchmark kernels. Spawns + the parent process and uses multiprocessing queues to send benchmark + requests and return results. + """ + + device: Optional[int] = None + process: Optional[BaseProcess] = None + request_queue: Optional[Queue[Any]] = None + response_queue: Optional[Queue[Any]] = None + + @staticmethod + def process_main( + request_queue: Queue[Any], + response_queue: Queue[Any], + ) -> None: + """ + Entry point for the child process. + """ + log.debug( + "Entering TuningProcess child. Visible devices = %s", + os.environ.get(CUDA_VISIBLE_DEVICES), + ) + try: + TuningProcess.workloop(request_queue, response_queue) + except Exception as ex: + log.exception("Exception in TuningProcess") + + @staticmethod + def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None: + """ + Work loop for the benchmarking subprocess. + """ + while True: + obj = request_queue.get() + + if obj is None: + break # None is a sentinel for the child to terminate + elif isinstance(obj, Ping): + response_queue.put(Pong()) + elif isinstance(obj, BenchmarkRequest): + response_queue.put(obj.benchmark()) + else: + raise RuntimeError(f"Invalid request type {type(obj)}") + + def valid(self) -> bool: + """ + True if the sub-process has been initialized. + """ + return ( + self.process is not None + and self.request_queue is not None + and self.response_queue is not None + ) + + def clear(self) -> None: + """ + Reset to an uninitialized state. + """ + self.process = self.request_queue = self.response_queue = None + + def initialize(self) -> None: + """ + Create child process, request/response queues, and do the warm up. + Set the environment to make only the provided GPU device visible + to the process. + """ + if self.valid(): + return + + # cuda runtime does not work with "fork", use "spawn" to start processes. + ctx = multiprocessing.get_context("spawn") + self.request_queue = ctx.Queue() + self.response_queue = ctx.Queue() + + self.process = ctx.Process( + target=self.process_main, + args=( + self.request_queue, + self.response_queue, + ), + ) + assert self.process is not None + with set_cuda_visible_device(self.device): + self.process.start() + + def put(self, obj: Any) -> None: + """ + Push a work item to the child process. + """ + # In case of a prior crash, ensure the subprocess is running + self.initialize() + assert self.request_queue is not None + self.request_queue.put(obj) + + def get( + self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0 + ) -> Any: + """ + Get a response from the child process. Raises queue.Empty on timeout + or if the process dies. + + This method is (so far) only used by TuningProcessPool, where torch._inductor.config entries are being used + to populate the timeouts: + + Arguments: + + @param result_timeout: Timeout in seconds, defaults to 120.0 or to + config.max_autotune_subproc_result_timeout_seconds when called by TuningProcessPool + @param graceful_timeout: Timeout in seconds to allow graceful shutdown (SIGTERM is sent after this time). + Defaults to 3.0 or to config.max_autotune_subproc_graceful_timeout_seconds + @param terminate_timeout: Timeout in seconds after SIGTERM, until we send SIGKILL if the process + remains alive. Defaults to 1.0 or to + config.max_autotune_subproc_terminate_timeout_seconds. + Returns: + A response from the child process (Any type) + """ + assert self.process is not None + assert self.response_queue is not None + while True: + try: + remaining_timeout = result_timeout + res = None + while remaining_timeout is not None and remaining_timeout >= 1.0: + remaining_timeout -= 0.5 + try: + res = self.response_queue.get(timeout=0.5) + break + except queue.Empty: + if not self.process.is_alive(): + raise # is being caught a few lines below + if res is None: + res = self.response_queue.get(timeout=remaining_timeout) + return res + except queue.Empty: + status = self.process.exitcode + if status is None: + self.kill( + graceful_timeout=graceful_timeout, + terminate_timeout=terminate_timeout, + ) + else: + # child process crashed + self.clear() + raise + + def terminate(self) -> None: + """ + Signal the child process to terminate. + """ + if self.valid(): + assert self.process is not None + assert self.request_queue is not None + self.request_queue.put(None) + + def wait(self) -> None: + """ + Wait for the child process to exit. + """ + if self.process is not None: + self.process.join() + self.clear() + + def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None: + # Tries to kill the process, using a graceful_timeout in which the process + # is allowed to exit gracefully. If the process is still alive, + # it will be terminated. If that is not sufficient to end it + # within terminate_timeout seconds, it will be killed. + if self.process is not None: + self.terminate() + self.process.join(timeout=graceful_timeout) + if self.process.is_alive(): + log.warning( + "Sending SIGTERM to process with PID %d", + self.process.pid, + ) + self.process.terminate() + self.process.join(timeout=terminate_timeout) + if self.process.is_alive(): + log.error( + "Sending SIGKILL to process with PID %d", + self.process.pid, + ) + self.process.kill() # This should definitely end the process + self.clear() + + +@dataclasses.dataclass +class TuningProcessPool: + """ + Maintains a pool of TuningProcesses to benchmark kernels in parallel + across devices. By default, we create one TuningProcess per device and + set the sub-process environment to make only that device visible. + """ + + processes: Optional[queue.Queue[TuningProcess]] = None + executor: Optional[ThreadPoolExecutor] = None + + def initialize(self) -> None: + """ + Start the child processes. + """ + assert (self.processes is None) == (self.executor is None) + if self.processes is not None: + return + + devices = self.get_device_list() + log.debug("Sub-process autotune device list: %s", devices) + + # Launch the child processes and push a msg to "warm up" + self.processes = queue.Queue() + for device in devices: + p = TuningProcess(device=device) + p.initialize() + p.put(Ping()) + self.processes.put(p) + + # Wait for the initialization to finish + for p in self.processes.queue: + assert isinstance(p.get(result_timeout=None), Pong) + + # Use a thread pool to manage distributing work to the subprocesses. + # Threads block on an available process, so it makes sense to match + # the number of threads with the number of devices. + self.executor = ThreadPoolExecutor(max_workers=len(devices)) + + # Register the exit handler for the parent process so it will terminate + # the child processes. + global EXIT_HANDLER_REGISTERED + if not EXIT_HANDLER_REGISTERED: + EXIT_HANDLER_REGISTERED = True + import atexit + + atexit.register(self.terminate) + + def get_device_list(self) -> Sequence[Optional[int]]: + """ + Gather the list of devices to be used in the pool. + """ + if not config.autotune_multi_device: + # Don't use multiple devices + return [None] + + count = torch.cuda.device_count() + + # If the user specified the visible devices in the env, use those. + if CUDA_VISIBLE_DEVICES in os.environ: + devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")] + assert len(devices) <= count + return devices + + return list(range(count)) + + def terminate(self) -> None: + """ + Signal all child processes to terminate. + """ + if self.executor is not None: + self.executor.shutdown() + self.executor = None + + if self.processes is not None: + for p in self.processes.queue: + p.terminate() + for p in self.processes.queue: + p.wait() + self.processes = None + + def target(self, choice: TritonTemplateCaller) -> float: + """ + Entry point for the thread-pool helper threads: Wait for an open TuningProcess, + remove it from the queue, execute the benchmark in that subprocess, and return + the TuningProcess to the queue. + """ + assert choice.bmreq is not None + assert self.processes is not None + + process = self.processes.get() + process.put(choice.bmreq) + try: + return process.get( + config.max_autotune_subproc_result_timeout_seconds, + config.max_autotune_subproc_graceful_timeout_seconds, + config.max_autotune_subproc_terminate_timeout_seconds, + ) + except queue.Empty: + warnings.warn( + f"Failed to benchmark choice '{choice}'. It will be ignored. " + "Please debug the root cause in case the choice can bring perf gains." + ) + # set to INF so this choice will be ignored + return float("inf") + finally: + self.processes.put(process) + + def benchmark( + self, + choices: List[TritonTemplateCaller], + ) -> Dict[TritonTemplateCaller, float]: + """ + Benchmark each choice in a separate process. + """ + assert self.processes is not None, "Tuning process pool is not initialized" + assert self.executor is not None + + results = {} + + # Use a ThreadExecutorPool to spread the work across the subprocesses and + # to grab subprocesses as soon as they're free. + for choice, result in zip(choices, self.executor.map(self.target, choices)): + results[choice] = result + + return results + + +tuning_pool = TuningProcessPool() + + +LayoutOrBuffer = Union[ir.Layout, ir.Buffer] + + +@dataclasses.dataclass +class TensorMeta: + device: torch.device + dtype: torch.dtype + sizes: torch._prims_common.ShapeType + strides: torch._prims_common.StrideType + offset: int + name: Optional[str] = None + + @classmethod + def from_irnodes( + cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]] + ) -> Union[TensorMeta, List[TensorMeta]]: + if isinstance(irnodes, Sequence): + result: List[Any] = [cls.from_irnodes(x) for x in irnodes] + assert all(isinstance(x, TensorMeta) for x in result) + return result + + node = irnodes + if isinstance(node, ir.Layout): + node = ir.Buffer("fake", node) + + dtype = node.get_dtype() + assert dtype is not None + + return TensorMeta( + device=node.get_device(), + dtype=dtype, + sizes=V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + strides=V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + offset=V.graph.sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + name=node.get_name(), + ) + + def to_tensor(self) -> torch.Tensor: + return rand_strided( + self.sizes, + self.strides, + device=self.device, + dtype=self.dtype, + extra_size=self.offset, + ) + + +@dataclasses.dataclass +class BenchmarkRequest: + """ + Only handle triton template benchmark for now. The extern kernel benchmark + can be done inside the same process since they usually don't cause crash. + + Important: Instances of this class and subclasses have to be serializable + across process boundaries. Do not put CUDA Tensors in here! + """ + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + ) -> None: + # the kernel name defined in the module + self.kernel_name = kernel_name + + if isinstance(input_tensor_meta, TensorMeta): + input_tensor_meta = [input_tensor_meta] + self.input_tensor_meta = input_tensor_meta + + if isinstance(output_tensor_meta, (tuple, list)): + assert len(output_tensor_meta) == 1 + output_tensor_meta = output_tensor_meta[0] + self.output_tensor_meta = output_tensor_meta + + self.extra_args = extra_args + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + raise NotImplementedError + + def cleanup_run_fn(self) -> None: + pass + + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + raise NotImplementedError + + def benchmark( + self, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + debug = log.isEnabledFor(logging.DEBUG) + if debug: + start_ts = time.time() + + # create args and out tensor + if output_tensor is None: + assert len(input_tensors) == 0 + input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta) + output_tensor = self.output_tensor_meta.to_tensor() + + if debug: + create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + try: + fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor) + except NonzeroWorkspaceNotSupportedError: + # Skipping all ops with nonzero workspace requirements + log.info("Skipping op due to nonzero workspace requirement") + return float("inf") + + if debug: + load_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + + out = self.do_bench(fn, *input_tensors, output_tensor) + + if debug: + bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + log.debug( + "InChildProcess %s: load %f, create tensor %f, bench %f", + str(self), + load_elapse, # type: ignore[possibly-undefined] + create_tensor_elapse, # type: ignore[possibly-undefined] + bench_elapse, + ) + self.cleanup_run_fn() + return out + + +class TestBenchmarkRequest(BenchmarkRequest): + """ + Supports unit testing. Defined in this file so that the TuningProcess + sub-process knows how to unpickle these objects. + """ + + def __init__(self, value: Optional[float] = None) -> None: + self.value = value + + def benchmark( + self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None + ) -> float: + if self.value is None: + raise Exception("Failed to run") # noqa: TRY002 + return self.value + + +class GPUDeviceBenchmarkRequest(BenchmarkRequest): + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + device_idx_set = { + tensor.device.index + for tensor in [*input_tensors, output_tensor] + if isinstance(tensor, torch.Tensor) + and tensor.is_cuda + and tensor.device.index is not None + } + assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" + if len(device_idx_set) == 1: + device_idx = next(iter(device_idx_set)) + else: + device_idx = torch.cuda.current_device() + + with torch.cuda.device(device_idx): + out = benchmarker.benchmark_gpu(fn) + torch.cuda.synchronize() # shake out any CUDA errors + + return out + + +class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + module_path: str, # the path of the module defining the triton kernel + module_cache_key: str, + grid: List[int], + num_stages: int, + num_warps: int, + matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction. + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.module_path = module_path + self.module_cache_key = module_cache_key + self.grid = grid + self.num_stages = num_stages + self.num_warps = num_warps + self.matrix_instr_nonkdim = matrix_instr_nonkdim + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + log.debug( + "benchmark module key: %s, path: %s", + self.module_cache_key, + self.module_path, + ) + + run_method = getattr(mod, self.kernel_name).run + extra_args = list(self.extra_args) + + # Newer version of triton add warmup argument to JITFunction.run. + # This code handles backward-compatibility. + warmup_arg = {} + import inspect + + if "warmup" in inspect.signature(run_method).parameters: + warmup_arg["warmup"] = False + + from torch._C import _cuda_getCurrentRawStream as get_raw_stream + + if torch.version.hip and self.matrix_instr_nonkdim != 0: + return functools.partial( + run_method, + *input_tensors, + output_tensor, + *self.extra_args, + grid=self.grid, + **warmup_arg, + stream=get_raw_stream(self.output_tensor_meta.device.index), + ) + else: + return functools.partial( + run_method, + *input_tensors, + output_tensor, + *self.extra_args, + grid=self.grid, + **warmup_arg, + stream=get_raw_stream(self.output_tensor_meta.device.index), + ) + + def precompile(self): + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + getattr(mod, self.kernel_name).precompile() + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" + + +class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") + + def precompile(self): + # Prepopulate CUDACodeCache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + CUDACodeCache.compile(self.source_code, "so") + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + self.ensure_dll_loaded() + self.update_workspace_size() + args = [ + c_void_p(tensor.data_ptr()) + for tensor in list(input_tensors) + [output_tensor] + ] + log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=output_tensor.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *self.extra_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len({meta.name for meta in self.input_tensor_meta}) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. + c_workspace_size = c_size_t() + run_method( + *args, # input ptrs and output ptrs + *self.extra_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.cuda.synchronize() # shake out any CUDA errors + self.workspace_size = c_workspace_size.value + log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = CUDACodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + + +class CPUDeviceBenchmarkRequest(BenchmarkRequest): + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + return benchmarker.benchmark_cpu(fn) + + +class CppBenchmarkRequest(CPUDeviceBenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.hash_key = get_hash(source_code) + self.DLL: Optional[Union[CDLL, ModuleType]] = None + + def precompile(self): + # Prepopulate CppCodeCache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + CppCodeCache.load(self.source_code, cuda=False) + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf + self.DLL = CppCodeCache.load(self.source_code, cuda=False) + args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]] + log.debug( + "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.DLL, + args, + self.extra_args, + ) + run_method = getattr(self.DLL, self.kernel_name) + # Assume only size with type ctypes.c_ulonglong in extra_args + assert all(isinstance(arg, ctypes.c_ulonglong) for arg in self.extra_args) + run_method.argtypes = [ctypes.c_ulonglong] * ( + len(args) + len(list(self.extra_args)) + ) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *self.extra_args, + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + """ + Check close attr due to it crash on Windows. + """ + if hasattr(self.DLL, "close"): + self.DLL.close() + + def __str__(self) -> str: + return f"{self.kernel_name=}" + + +def benchmark_in_sub_process( + choices: List[TritonTemplateCaller], +) -> Dict[TritonTemplateCaller, float]: + """ + Do benchmarking in a subprocess and return the perf number (latency). + """ + return tuning_pool.benchmark(choices) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py b/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py new file mode 100644 index 0000000000000000000000000000000000000000..59cc47ac06c94e492bed109e60c28bb6b7e61ba0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py @@ -0,0 +1,3353 @@ +from __future__ import annotations + +import base64 +import copyreg +import dataclasses +import functools +import hashlib +import importlib +import io +import json +import logging +import os +import pickle +import pkgutil +import re +import shlex +import shutil +import struct +import subprocess +import sys +import sysconfig +import tempfile +import textwrap +import threading +import warnings +from bisect import bisect_right +from copy import copy +from ctypes import c_void_p, CDLL, cdll +from datetime import timedelta +from functools import partial +from pathlib import Path +from time import time, time_ns +from types import ModuleType +from typing import ( + Any, + Callable, + cast, + Counter, + Dict, + Generator, + List, + NoReturn, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import TypeAlias + +import torch +import torch.distributed as dist +from torch import SymInt, Tensor +from torch._dynamo.utils import counters, dynamo_timed, get_chromium_event_logger +from torch._inductor import config, exc, metrics +from torch._inductor.codegen.cuda import cuda_env +from torch._inductor.codegen.rocm.compile_command import ( + rocm_compile_command, + rocm_compiler, +) +from torch._utils_internal import log_cache_bypass + +from .utils import _align + + +T = TypeVar("T") + + +if TYPE_CHECKING: + from collections.abc import KeysView + + from .remote_cache import JsonDataTy, RemoteCache + + +""" +codecache.py, cpp_builder.py and cpu_vec_isa.py import rule: +https://github.com/pytorch/pytorch/issues/124245#issuecomment-2197778902 +""" +from torch._inductor.cpp_builder import ( + _set_gpu_runtime_env, + _transform_cuda_paths, + CppBuilder, + CppOptions, + CppTorchCudaOptions, + get_compiler_version_info, + get_cpp_compiler, + get_name_and_dir_from_output_file_path, + normalize_path_separator, +) +from torch._inductor.cpu_vec_isa import pick_vec_isa +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + CudagraphCachedInfo, + log_cudagraph_skip_and_bump_counter, +) +from torch._inductor.runtime.compile_tasks import ( + _module_to_triton_kernel, + _reload_python_module, + _reload_python_module_in_subproc, +) +from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir +from torch._inductor.utils import ( + ALIGN_BYTES, + align_inputs_from_check_idxs, + BoxedBool, + clear_on_fresh_inductor_cache, + is_linux, + is_windows, + set_tracing_context_output_strides, +) +from torch._logging import trace_structured +from torch._subclasses.fake_tensor import ( + extract_tensor_metadata, + FakeTensor, + TensorMetadata, +) +from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv + + +if TYPE_CHECKING: + from concurrent.futures import Future + + from torch._inductor.graph import GraphLowering + from torch._inductor.ir import ChoiceCaller + from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta + + +_HERE = os.path.abspath(__file__) +_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) +_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld") + +_IS_WINDOWS = sys.platform == "win32" + +if config.is_fbcode(): + from triton.fb import build_paths + from triton.fb.build import _run_build_command + + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def use_global_cache() -> bool: # type: ignore[misc] + return False + + +output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") + +LOCK_TIMEOUT = 600 + +_IS_WINDOWS = sys.platform == "win32" + + +log = logging.getLogger(__name__) + + +def cpp_wrapper_cache_dir(name: str) -> str: + cu_str = ( + "cpu" + if torch.version.cuda is None + else f'cu{torch.version.cuda.replace(".", "")}' + ) + python_version = f"py{sys.version_info.major}{sys.version_info.minor}" + build_folder = f"{python_version}_{cu_str}" + + cpp_wrapper_dir = os.path.join(cache_dir(), build_folder) + cpp_wrapper_build_directory = os.path.join(cpp_wrapper_dir, name) + os.makedirs(cpp_wrapper_build_directory, exist_ok=True) + return cpp_wrapper_build_directory + + +def get_cpp_wrapper_cubin_path_name() -> str: + return "cubin_path" if torch.version.hip is None else "hsaco_path" + + +class CacheBase: + @staticmethod + @functools.lru_cache(None) + def get_system() -> Dict[str, Any]: + try: + from triton.compiler.compiler import triton_key + + # Use triton_key instead of triton.__version__ as the version + # is not updated with each code change + triton_version = triton_key() + except ModuleNotFoundError: + triton_version = None + + try: + system: Dict[str, Any] = { + "device": {"name": None}, + "version": { + "triton": triton_version, + }, + } + device_properties = torch.cuda.get_device_properties( + torch.cuda.current_device() + ) + if torch.version.cuda is not None: + system["device"]["name"] = device_properties.name + system["version"]["cuda"] = torch.version.cuda + else: + system["device"]["name"] = device_properties.gcnArchName + system["version"]["hip"] = torch.version.hip + except (AssertionError, RuntimeError): + # If cuda is not installed, none of the above config is relevant. + system = {} + + system["hash"] = hashlib.sha256( + json.dumps(system, sort_keys=True).encode("utf-8") + ).hexdigest() + + return system + + @staticmethod + @clear_on_fresh_inductor_cache + @functools.lru_cache(None) + def get_local_cache_path() -> Path: + return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) + + @staticmethod + @functools.lru_cache(None) + def get_global_cache_path() -> Optional[Path]: + return ( + Path(os.path.join(config.global_cache_dir, CacheBase.get_system()["hash"])) + if config.global_cache_dir is not None + else None + ) + + def __init__(self) -> None: + self.system = CacheBase.get_system() + + def get_local_cache(self) -> Dict[str, Any]: + local_cache_path = self.get_local_cache_path() + if not local_cache_path.is_file(): + return {} + with open(local_cache_path) as local_cache_fp: + local_cache = json.load(local_cache_fp) + return local_cache["cache"] + + def update_local_cache(self, local_cache: Dict[str, Any]) -> None: + local_cache_path = self.get_local_cache_path() + write_atomic( + str(local_cache_path), + json.dumps({"system": self.system, "cache": local_cache}, indent=4), + make_dirs=True, + ) + + +class LocalCache(CacheBase): + def lookup(self, *keys: str) -> Optional[Dict[str, Any]]: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys: + if key in cache: + sub_cache = cache[key] + else: + return None + + return sub_cache + + def set_value(self, *keys: str, value: Any) -> None: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys[0:-1]: + sub_cache.setdefault(key, {}) + sub_cache = sub_cache[key] + sub_cache[keys[-1]] = value + + self.update_local_cache(cache) + + +class PersistentCache(CacheBase): + @functools.lru_cache(None) # noqa: B019 + def get_global_cache(self) -> Dict[str, Any]: + global_cache_path = self.get_global_cache_path() + if global_cache_path is None or not global_cache_path.is_file(): + return {} + with open(global_cache_path) as global_cache_fp: + global_cache = json.load(global_cache_fp) + return global_cache["cache"] + + def lookup( + self, + choices: List[ChoiceCaller], + op: str, + inputs: str, + benchmark: Optional[Callable[[Any], Dict[ChoiceCaller, float]]], + ) -> Dict[ChoiceCaller, float]: + """ + Check to see if we have benchmarked the given choice callers. For each + choice caller: + + 1. Check global_cache[op][inputs][choice][precision], return benchmark if cached. + 2. Check local_cache[op][inputs][choice][precision], return benchmark if cached. + 3. If benchmark is not None: + a. `max_autotune_gemm=True`: benchmark the choice, update + local_cache[op][inputs][choice], and return the benchmark. + b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing. + """ + precision = torch.get_float32_matmul_precision() + + log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision) + log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision) + log_errors = partial( + log_global_cache_errors, self.system, op, inputs, precision + ) + timings = {} + + def check_cache(cache: Dict[str, Any], callback: Any = None) -> bool: + """Check if `cache` contains data for all the choices""" + hit = True + for choice in choices: + choice_hash = choice.hash_key() + if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}): + # cache hit + timings[choice] = cache[op][inputs][precision][choice_hash] + else: + # cache miss + hit = False + break + if callback: + callback(cached=hit) + return hit + + if config.max_autotune or config.max_autotune_gemm: + local_cache = self.get_local_cache() if config.autotune_local_cache else {} + # check local cache first since it is data specific to the current machine + if ( + not check_cache(local_cache) + and not ( + use_global_cache() + and check_cache(self.get_global_cache(), callback=log_stats) + ) + and benchmark is not None + ): + try: + # re-benchmark everything to try to get consistent numbers from the same machine + timings = benchmark(choices) + assert all(choice in timings for choice in choices) + local_cache.setdefault(op, {}) + local_cache[op].setdefault(inputs, {}).setdefault(precision, {}) + for choice, timing in timings.items(): + local_cache[op][inputs][precision][choice.hash_key()] = timing + except RuntimeError as e: + # catch and log autotuning failures + log_errors(e) + raise e + + self.update_local_cache(local_cache) + + timings_to_log = { + choice.hash_key(): timings[choice] for choice in choices + } + log_vals(timings_to_log) + elif use_global_cache(): + # only check global cache, not local one + check_cache(self.get_global_cache(), callback=log_stats) + # may have a partial cache hit, where not everything is benchmarked + + return timings + + +def get_lock_dir() -> str: + lock_dir = os.path.join(cache_dir(), "locks") + if not os.path.exists(lock_dir): + os.makedirs(lock_dir, exist_ok=True) + return lock_dir + + +def sha256_hash(data: bytes) -> str: + # [:51] to strip off the "Q====" suffix common to every hash value. + return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() + + +def code_hash(code: Union[str, bytes], extra: str = "") -> str: + hashing_str = code if isinstance(code, bytes) else code.encode("utf-8") + if extra != "": + hashing_str = hashing_str + b"||" + extra.encode("utf-8") + return "c" + sha256_hash(hashing_str) + + +def get_path( + basename: str, extension: str, specified_dir: str = "" +) -> Tuple[str, str, str]: + if specified_dir: + if os.path.isabs(specified_dir): + subdir = specified_dir + else: + subdir = os.path.join(cache_dir(), specified_dir) + else: + subdir = os.path.join(cache_dir(), basename[1:3]) + path = os.path.join(subdir, f"{basename}.{extension}") + return basename, subdir, path + + +def get_hash( + content: Union[str, bytes], extra: str = "", hash_type: str = "code" +) -> str: + if hash_type == "code": + return code_hash(content, extra) + if hash_type in ["cubin", "hsaco", "spv"]: + return code_hash(repr(content)) + raise AssertionError(f"Unknown hash type {hash_type}") + + +def write( + content: Union[str, bytes], + extension: str, + extra: str = "", + hash_type: str = "code", + specified_dir: str = "", +) -> Tuple[str, str]: + # use striped content to compute hash so we don't end up with different + # hashes just because the content begins/ends with different number of + # spaces. + key: str = get_hash(content.strip(), extra, hash_type) + basename, subdir, path = get_path(key, extension, specified_dir) + encode_utf_8: bool = hash_type == "code" + if not os.path.exists(path): + write_atomic(path, content, make_dirs=True) + return basename, path + + +def write_text(text: str) -> str: + """ + Write the `text` to a file and return the path computed based on the hash. + """ + return write(text, "txt")[1] + + +def write_atomic( + path_: str, + content: Union[str, bytes], + make_dirs: bool = False, + encode_utf_8: bool = False, +) -> None: + # Write into temporary file first to avoid conflicts between threads + # Avoid using a named temporary file, as those have restricted permissions + assert isinstance( + content, (str, bytes) + ), "Only strings and byte arrays can be saved in the cache" + path = Path(path_) + if make_dirs: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp" + write_mode = "w" if isinstance(content, str) else "wb" + with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f: + f.write(content) + tmp_path.rename(path) + + +@dataclasses.dataclass +class TensorMetadataAndValues: + """ + TensorMetadata plus the elements as a list of raw values. + Used for hashing inlined constants. + """ + + tensor_metadata: TensorMetadata + values: List[Any] + + +def _ident(x: T) -> T: + return x + + +def extract_tensor_metadata_for_cache_key( + device_map: Dict[torch.device, torch.device], t: Tensor +) -> TensorMetadata: + """ + Extracts the tensor metadata and removes fields of the TensorMetadata + that are not needed for caching + """ + meta = extract_tensor_metadata(t) + if not hasattr(t, "_is_inductor_static"): + meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None) + + # The pickle implementation avoids serializing the same object more than once. + # That behavior means the byte stream we create to hash will vary if, for example, + # we see two tensor objects with the same device, but the torch.device object is + # actually the same object vs. merely equivalent. We want to produce the same hash + # value in either situation, so we memoize the device objects and always reference + # the same object for a given device. It's possible other metadata fields deserve + # the same treatment, but so far we've only observed this issue with the device. + if meta.device not in device_map: + device_map[meta.device] = meta.device + meta = dataclasses.replace(meta, device=device_map[meta.device]) + + return meta + + +def _reduce_fake_tensor( + device_map: Dict[torch.device, torch.device], t: Tensor +) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]: + """ + See FxGraphCachePickler. Custom reducer to pickle FakeTensors. + """ + metadata = extract_tensor_metadata_for_cache_key(device_map, t) + return (_ident, (metadata,)) + + +def _reduce_tensor( + device_map: Dict[torch.device, torch.device], t: Tensor +) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]: + """ + See FxGraphCachePickler. Custom reducer to pickle Tensors. + If we see tensors, we know they're constants stored as attributes on + the GraphModule. Include the values in the key calculation. Small + tensors will be inlined, so we can't serve the same cache entry for + different values anyway. Large constants are treated as parameters, + so we could conceivably reuse a cache entry. To do that, however, + PyCodeCache would need more complexity to create a new module from its + cache, but with the right constants attached as attributes. + """ + if t.is_mkldnn: + # TODO: These tensors don't currently pickle, so we can't cache a + # compiled graph containing them. Just fail now. If mkldnn tensors + # get pickling support, we can remove this. + raise BypassFxGraphCache("mkldnn tensors unpickleable.") + + # Very large tensors could be expensive to copy to cpu and hash. Let's + # at least report if we find slowness. + start = time() + values = t.tolist() + elapsed = time() - start + if elapsed > 1.0: + warnings.warn( + f"FX graph cache handling of a large constant took {elapsed:.1}s. Please file an issue." + ) + + metadata = extract_tensor_metadata_for_cache_key(device_map, t) + return (_ident, (TensorMetadataAndValues(metadata, values),)) + + +def _reduce_symint(s: SymInt) -> Tuple[Callable[[T], T], Tuple[str]]: + """ + See FxGraphCachePickler. Custom reducer to pickle SymInts. + """ + # For hashing purposes, we only care about the name of the symbol and + # not the backed value. We evaluate guards stored with a cached graph + # to ensure a cached entity with SymInt args is safe to reuse. + return (_ident, (str(s),)) + + +def _reduce_unsupported(s: Any) -> NoReturn: + """ + See FxGraphCachePickler. Custom reducer to handle any objects that we don't + support and therefore raise to bypass caching. + """ + raise BypassFxGraphCache("Reduce unsupported.") + + +class FxGraphCachePickler(pickle.Pickler): + """ + Custom pickler to customize the pickling of some objects (Tensors), only for the + purpose of computing a hash for keying into the FxGraphCache. Tensors contain + objects that don't pickle and/or vary between runs, and we want to capture the + data that allow us to compute a stable, but safe hash. + """ + + # See extract_tensor_metadata_for_cache_key. Whenever we extract metadata during + # pickling, we make sure devices always reference the same torch.device object. + _device_map: Dict[torch.device, torch.device] = {} + + dispatch_table = copyreg.dispatch_table.copy() + dispatch_table[FakeTensor] = functools.partial(_reduce_fake_tensor, _device_map) + dispatch_table[torch.Tensor] = functools.partial(_reduce_tensor, _device_map) + dispatch_table[torch.SymInt] = _reduce_symint + dispatch_table[ + torch.fx.experimental._backward_state.BackwardState + ] = _reduce_unsupported + + @classmethod + def dumps(cls, obj: Any) -> bytes: + """ + Pickle an object using the FxGraphCachePickler. + """ + with io.BytesIO() as stream: + pickler = cls(stream) + # TODO: pickler.fast is technically deprecated. Will this work on new python versions? + pickler.fast = True # Run with pickler.fast so it doesn't intern strings, making the hash result more predictable + try: + pickler.dump(obj) + except (TypeError, AttributeError) as e: + # Some configs options are callables, e.g., post_grad_custom_pre_pass, + # and may not pickle. + log.warning("Can't pickle", exc_info=True) + raise BypassFxGraphCache("Config options may be unpickleable.") from e + return stream.getvalue() + + @classmethod + def get_hash(cls, obj: Any) -> str: + """ + Serialize an object using the FxGraphCachePickler and return a hash + of the pickled object. + """ + serialized_data = cls.dumps(obj) + return sha256_hash(serialized_data) + + @classmethod + def debug_lines(cls, inp: FxGraphHashDetails) -> List[str]: + """ + Get a printable string describing in more detail all the attributes + comprising an object. Useful for debugging when one graph hashes + to a different value than another. + """ + + def get_str(obj: Any) -> str: + if isinstance(obj, torch.Tensor): + return str(extract_tensor_metadata_for_cache_key(cls._device_map, obj)) + elif isinstance(obj, bytes): + return "" + elif type(obj) in cls.dispatch_table: + # Run the reducer on the object + return str(cls.dispatch_table[type(obj)](obj)[1]) + else: + return str(obj) + + lines = [] + for attr, obj in vars(inp).items(): + if isinstance(obj, list): + for ii in range(len(obj)): + h = cls.get_hash(obj[ii]) + lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}") + elif isinstance(obj, dict): + for k, v in obj.items(): + h = cls.get_hash(v) + lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}") + else: + h = cls.get_hash(obj) + lines.append(f"[{h}] {attr}: {get_str(obj)}") + return lines + + +def build_code_hash( + roots: List[str] | None, prefix: str, hasher: hashlib._Hash +) -> None: + for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name): + spec = lib.module_finder.find_spec(lib.name, None) + assert spec is not None + module = spec.origin + assert module is not None + with open(module, "rb") as f: + hasher.update(spec.name.encode("utf-8")) + hasher.update(f.read()) + if lib.ispkg: + # need to also hash submodules + build_code_hash(spec.submodule_search_locations, f"{spec.name}.", hasher) + + +@functools.lru_cache(None) +def torch_key() -> bytes: + """ + Compute a key that contains relevant information about torch source files + """ + if not config.is_fbcode(): + + def get_code_hash(root: str) -> bytes: + # This function isn't meant to be used outside of torch_key, just a + # helper for clarity. Instead, use torch_key() directly when you need + # a hash representing the state of the source code. + extra_files = ( + "codegen/aoti_runtime/interface.cpp", + "codegen/aoti_runtime/implementation.cpp", + "codegen/cpp_prefix.h", + "script.ld", + ) + inductor_root = os.path.dirname(__file__) + extra_files = [os.path.join(inductor_root, x) for x in extra_files] + hasher = hashlib.sha256() + hasher.update(torch.__version__.encode("utf-8")) + build_code_hash([root], "", hasher) + for path in extra_files: + if os.path.exists(path): + with open(path, "rb") as f: + hasher.update(f.read()) + return hasher.digest() + + return get_code_hash(_TORCH_PATH) + + from libfb.py import parutil + + return parutil.get_file_contents("torch/src_hash.txt").rstrip().encode("ascii") + + +def get_inductor_root() -> str: + return os.path.dirname(__file__) + + +@dataclasses.dataclass +class OrderedSetHolder: + """ + See FxGraphHashDetails. Holds a sorted list to support stable hashing + of set kwargs. + """ + + items: List[Any] + + +class BypassFxGraphCache(Exception): + """ + Exception to indicate that the FxGraphCache should be bypassed. + """ + + +class FxGraphHashDetails: + """ + Object to capture all the details for a compiled FX graph relevant to computing + a safe and stable cache key. + """ + + # Excluded kwargs param that are not stable between runs + EXCLUDED_KWARGS = ["graph_id"] + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], + ) -> None: + self.gm = gm + self.example_inputs = example_inputs + + # Order kwargs so hashing is stable to changes in kwarg order. + self.fx_kwargs = {} + for k in sorted(fx_kwargs): + if k not in self.EXCLUDED_KWARGS: + if type(fx_kwargs[k]) is set: + # Special case to handle set params. Python sets can't be + # ordered, so sort the elements and store them in a proxy. + self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k])) + else: + self.fx_kwargs[k] = fx_kwargs[k] + + # Alignment checks + self.inputs_to_check = inputs_to_check + + # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels. + self.deterministic_algorithms_settings = ( + torch.are_deterministic_algorithms_enabled(), + torch.is_deterministic_algorithms_warn_only_enabled(), + torch.utils.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined] + ) + + # Global settings affecting matmul codegen. + self.cuda_matmul_settings = ( + torch.backends.cuda.matmul.allow_tf32, + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction, + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction, + ) + + # Also hash on various system info (including the triton compiler version). + self.torch_version = torch_key() + self.system_info = CacheBase.get_system() + self.inductor_config = config.save_config_portable() + + def debug_lines(self) -> List[str]: + """ + Get a printable string describing in more detail all the attributes + comprising this object. Useful for debugging when one graph hashes + to a different value than another. + """ + return FxGraphCachePickler.debug_lines(self) + + +def compiled_fx_graph_hash( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], +) -> Tuple[str, List[str]]: + """ + Generate a unique hash of the FX graph for caching. + """ + details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check) + # The prefix distinguishes among the other kinds of objects we + # cache in this module. + key = "f" + FxGraphCachePickler.get_hash(details) + debug_lines = details.debug_lines() + debug_str = "\n".join(debug_lines) + log.debug(f"FX graph cache hash details for key {key}:\n{debug_str}") # noqa: G004 + return key, debug_lines + + +def cudagraph_post_compile( + example_inputs: List[Any], + compiled_graph: CompiledFxGraph, + cudagraphs: BoxedBool, +) -> None: + """ + Checks for any reasons not to run cudagraphs and then + runs it on compiled_graph. + Mutates the `compiled_graph.current_callable` and `cudagraphs` + """ + assert compiled_graph.current_callable is not None + assert compiled_graph.cudagraph_info is not None + cached_info = compiled_graph.cudagraph_info + cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons + inputs_to_check = compiled_graph.inputs_to_check + boxed_forward_device_index = compiled_graph.boxed_forward_device_index + is_inference = compiled_graph.fx_kwargs["is_inference"] + is_backward = compiled_graph.fx_kwargs["is_backward"] + + if not cudagraph_fail_reasons: + fx_kwargs = compiled_graph.fx_kwargs + static_input_idxs = fx_kwargs["static_input_idxs"] + + placeholders = cached_info.placeholders + stack_traces = cached_info.stack_traces + if not config.triton.cudagraph_trees: + # Force specialize all inputs so that CUDA graphs will work + for t in example_inputs: + if isinstance(t, torch.SymInt): + int(t) # guard + + if ( + boxed_forward_device_index is not None + and not is_inference + and not is_backward + ): + boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs))) + + from .compile_fx import cudagraphify + + current_callable = compiled_graph.current_callable + assert current_callable is not None + compiled_graph.current_callable = cudagraphify( + current_callable, + static_input_idxs=static_input_idxs, + device_index=next(iter(compiled_graph.device_idxs)), + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=tuple(compiled_graph.constants.values()), + placeholders=placeholders, + mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs), + ) + + else: + BoxedBool.disable(cudagraphs) + + # See [Backward Generation Handling] + # if cudagraph'd the forward and set the device, we need to let the cudagraph manager + # know we are we running the backward even if we will not run it in cudagraphs + if is_backward and config.triton.cudagraph_trees: + assert boxed_forward_device_index is not None + assert boxed_forward_device_index.value is not None + compiled_graph_callable = compiled_graph.current_callable + + manager = torch._inductor.cudagraph_trees.get_manager( + boxed_forward_device_index.value, create_if_none_exists=False + ) + # should already exist from forward + assert manager is not None + + def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]: + manager.set_to_running_backward() # type: ignore[union-attr] + return compiled_graph_callable(new_inputs) + + compiled_graph.current_callable = compiled_artifact + + if "cuda" in compiled_graph.device_types: + # prefer better disable_cudagraphs_reason bc stack trace + # TODO: migrate all disable reasons to stack trace, refactor + if compiled_graph.disabled_cudagraphs_reason: + log_cudagraph_skip_and_bump_counter( + compiled_graph.disabled_cudagraphs_reason + ) + else: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {cudagraph_fail_reasons}" + ) + + +def maybe_realign_inputs( + ran_cudagraphs: BoxedBool, + compiled_graph: CompiledFxGraph, + inputs_to_check: Sequence[int], +) -> None: + """ + Realigns input strides from inputs_to_check if + we didn't end up running cudagraphs. Mutates + `compiled_graph.current_callable` if cudagraphs + was run. Otherwise, does nothing. + """ + if not ran_cudagraphs: + assert compiled_graph.current_callable is not None + new_callable = align_inputs_from_check_idxs( + compiled_graph.current_callable, inputs_to_check + ) + if new_callable is not compiled_graph.current_callable: + compiled_graph.current_callable = new_callable + + +def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int: + """ + Ephemerally increases the NCCL timeout when compiling for a distributed job + Returns amount of seconds increased + """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return 0 + + increased_timeout_sec = int(time_saved_ns // 1e9) # convert to seconds + + if config.is_fbcode(): + fudge_factor = torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:ephemeral_timeout_fudge_factor_percentage" + ) + log.info( + "Ephemeral NCCL timeout increase fudge factor %d and original increase value %d", + fudge_factor, + increased_timeout_sec, + ) + increased_timeout_sec += int(increased_timeout_sec * fudge_factor / 100) + + log.info("Increasing NCCL timeout by %d", increased_timeout_sec) + dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs( + timedelta(seconds=increased_timeout_sec) + ) + return increased_timeout_sec + + +class FxGraphCache: + """ + Supports caching and reusing compiled Fx graphs. + + The overall strategy is as follows: + - This cache stores entries on disk. When saving an entry, we can't + serialize callables (that could be C++, Triton, etc.), so we serialize + their own disk cache location. We then recreate the compiled artifact + after fetching from disk. + - For indexing the cache, we gather the fields relevant to identifying an + FxGraph (the graph module, graph inputs, system settings etc.) into an + FxGraphCacheDetails object, pickle it, and compute a hash for the key. + See FxGraphCachePickler. + - Among the metadata we store, we also include a guards expression that's + appropriate for validating any symbols for Tensor arguments that have + symbolic bounds. On cache lookup then, we evaluate those guards in the + current context to validate that a cached entry can be served. + - A given graph could have multiple compiled versions, corresponding to + different sets of guards. Therefore, we store cache entries in the form: + // + - On lookup, we compute the key from the graph details, iterate over all + leaf files in the corresponding subdirectory, deserialize the entry, and + evaluate its guards expression. If the evaluation succeeds, we have a + cache hit. If it fails, we compile the graph and store a new entry. + - Finally, on a cache hit, we need to make sure any guards that would + have been created during compilation are added to the current context. + """ + + # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs + # in an in-memory cache after loading from disk. + @staticmethod + def _get_tmp_dir() -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(cache_dir(), "fxgraph") + + @staticmethod + def _get_tmp_dir_for_key(key: str) -> str: + """ + Return the disk location for a given cache key. + """ + return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key) + + @staticmethod + def _filter_backed_symints(inputs: List[Any]) -> List[torch.SymInt]: + """ + Get the backed SymInt objects from the input list. Note that we can never + have guards that depend on unbacked symint. + """ + return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)] + + @staticmethod + def _get_shape_env() -> Optional[ShapeEnv]: + """ + Helper to get the shape env from the tracing context. + """ + ctx = torch._guards.TracingContext.try_get() + if not ctx: + return None + return ctx.fake_mode.shape_env + + @staticmethod + def _lookup_graph( + key: str, + example_inputs: List[torch.Tensor], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + ) -> Optional[CompiledFxGraph]: + """ + Lookup a compiled graph in the cache by key. On a hit, return the + deserialized CompiledFxGraph object. On a miss, return None. + """ + shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + + symints = FxGraphCache._filter_backed_symints(example_inputs) + hints = [hint_int(s) for s in symints] + + def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]: + if local: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if os.path.exists(subdir): + for path in sorted(os.listdir(subdir)): + try: + with open(os.path.join(subdir, path), "rb") as f: + yield pickle.load(f) + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", + exc_info=True, + ) + + if remote_cache: + try: + if (cache_data := remote_cache.get(key)) is not None: + assert isinstance(cache_data, dict) + data = cache_data["data"] + assert isinstance(data, (str, bytes)) + content = base64.b64decode(data) + yield pickle.loads(content) + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", exc_info=True + ) + + # Iterate over any entries in the subdir for this key and evaluate + # their guards to determine whether there's a hit. + graph = None + + for candidate in iterate_over_candidates(): + if not candidate.guards_expr: + # No guards to evaluate, so this is a hit. + graph = candidate + break + + # Evaluate the guard expression in the current context. + # If there's not a cache hit, we don't want the evaluation to + # affect the current env, e.g., cause the creation of new guards, + # so we evaluate with the hints instead of the symbols. + hit = bool( + shape_env.evaluate_guards_expression(candidate.guards_expr, hints) + ) + log.debug( + "fx graph cache key %s evaluating guards [%s] with values %s => hit=%s", + key, + candidate.guards_expr, + hints, + hit, + ) + if hit: + graph = candidate + break + + if graph is None: + return None + + # See _save_graph(); we don't store the callable in the cache entry so + # recreate it here from the PyCodeCache disk cache. + artifact_path = get_path(graph.cache_key, "py")[2] + code = graph.source_code + if not os.path.exists(artifact_path): + counters["inductor"]["fxgraph_lookup_write_file"] += 1 + Path(os.path.dirname(artifact_path)).mkdir(parents=True, exist_ok=True) + cpp_pp = cpp_prefix_path() + if os.path.basename(cpp_pp) in code: + if cpp_pp in code: + # Great the name is correct + pass + else: + # Old dir name is included, replace it + pattern = rf'#include\s*"[^"]+{os.path.basename(cpp_pp)}"' + code = re.sub(pattern, f'#include "{cpp_pp}"', code) + + write_atomic(artifact_path, code, make_dirs=True) + + try: + graph.current_callable = PyCodeCache.load_by_key_path( + graph.cache_key, + artifact_path, + graph.cache_linemap, + graph.constants, + ).call + except OSError: + # Not expected, but in case the PyCodeCache entry is removed from + # underneath us, treat it as a cache miss and recompile. + log.error("Failed to load cached artifact: %s", artifact_path) + return None + + # Now re-evaluate with the symints to add any guards to the current env. + if graph.guards_expr: + check = bool( + shape_env.evaluate_guards_expression(graph.guards_expr, symints) + ) + assert check is True + log.debug( + "fx graph cache key %s post-load guards: %s", key, shape_env.guards + ) + + # Increment the cached metrics/counters by the amounts recorded when the FX + # graph was compiled for this cache entry. Pretending these counters + # were incremented normally is useful for testing with the cache enabled. + metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas) + counters["inductor"] += graph.counter_deltas + + from .graph import GraphLowering + + GraphLowering.save_output_code(code) + output_code_log.debug("Output code written to: %s", artifact_path) + output_code_log.debug("Output code: \n%s", code) + # On cache hit, use artifact path as filename + trace_structured( + "inductor_output_code", + lambda: {"filename": artifact_path}, + payload_fn=lambda: code, + ) + return graph + + @staticmethod + def post_compile( + compiled_graph: CompiledFxGraph, + example_inputs: List[torch.Tensor], + cudagraphs: BoxedBool, + ) -> CompiledFxGraph: + """ + Run a set of post processing steps after loading from the cache. These involve: + - Setting the tracing context output strides + - Running cudagraphs if enabled + - Realigning inputs + + This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph. + The results of this function are *not* saved in the cache itself. + """ + set_tracing_context_output_strides(example_inputs, compiled_graph) + + if cudagraphs: + # It's possible that cudagraphs is enabled, but was disabled + # during a previous compilation we're loading from the cache. + # If so, we need to disable it on this new process too. + if compiled_graph.disabled_cudagraphs_reason: + if "cuda" in compiled_graph.device_types: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}" + ) + else: + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + else: + cudagraph_post_compile( + example_inputs, + compiled_graph, + cudagraphs, + ) + inputs_to_check = compiled_graph.inputs_to_check + # cudagraphs could have been disabled from the earlier conditions + # so we still need to realign inputs if that happens + maybe_realign_inputs( + cudagraphs, + compiled_graph, + inputs_to_check, + ) + + return compiled_graph + + @staticmethod + def _save_graph( + key: str, + compiled_graph: CompiledFxGraph, + example_inputs: List[torch.Tensor], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + ) -> None: + """ + Store a serialized CompiledFxGraph on disk. + """ + disk_compiled_graph = copy(compiled_graph) + # We can't really serialize callables that may be C++/Triton/etc., + # so we serialize their PyCodeCache disk cache location instead. + # TODO: This could be better if we're ever able to serialize compiled + # models to disk. + disk_compiled_graph.current_callable = None + + # Before serializing, compute the guard expression that will be used to + # ensure that a CompiledFxGraph is valid when loaded from the cache. It's + # sufficient to consider only the SymInt args to the fx graph since the + # Tensor shapes are already captured in the hash for the cache key. Any + # Tensor arg with a symbolic shape will have a SymInt arg for the graph. + shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + symints = FxGraphCache._filter_backed_symints(example_inputs) + guards = shape_env.get_pruned_guards(symints) + disk_compiled_graph.guards_expr = shape_env.produce_guards_expression( + placeholders=symints, guards=guards + ) + + try: + content = pickle.dumps(disk_compiled_graph) + except Exception: + log.warning( + "fx graph cache unable to serialize compiled graph", exc_info=True + ) + counters["inductor"]["fxgraph_cache_pickle_error"] += 1 + return + + try: + if local: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized CompiledFxGraph to get a unique file + # name. The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. + path = os.path.join(subdir, sha256_hash(content)) + write_atomic(path, content, make_dirs=True) + + if remote_cache: + time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6) + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + "time_taken_ms": time_taken_ms, + } + remote_cache.put(key, cache_data) + except Exception: + log.warning("fx graph unable to write to cache", exc_info=True) + counters["inductor"]["fxgraph_cache_write_error"] += 1 + + @staticmethod + def _check_can_cache(gm: torch.fx.GraphModule) -> None: + """ + Check some conditions that would preclude caching and raise BypassFxGraphCache + to bypass in case caching is not possible. + """ + # Freezing can embed constants that wouldn't be static across runs. + if config.freezing or config.aot_inductor.use_runtime_constant_folding: + raise BypassFxGraphCache( + "Freezing may introduce constants that aren't static across runs." + ) + + # The treatment of guards in the caching implementation requires that + # we have a shape env. + if FxGraphCache._get_shape_env() is None: + log.debug("fx graph cache no shape env") + raise BypassFxGraphCache("No shape env.") + + # HigherOrderOperators should be handled on a case-by-case basis. + # Currently, we just skip caching if we have any. + # We also skip if there are any torchbind objects. + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.HigherOrderOperator): + raise BypassFxGraphCache("Can't cache HigherOrderOperators.") + if node.op == "getattr" and isinstance( + getattr(gm, node.target), torch._C.ScriptObject + ): + raise BypassFxGraphCache("Can't cache torchbind objects.") + + @staticmethod + def load( # type: ignore[no-untyped-def] + compile_fx_fn: Callable[..., Any], + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], + local: bool, + remote: bool, + ): + """ + Load a compiled graph from the cache. If a cached entry does not exist, + compile the graph and save it to the cache. + """ + assert local or remote, "at least one of them needs to be enabled" + compiled_graph = None + cache_state = None + cache_event_time = None + cache_info: Dict[str, Any] = {} + try: + FxGraphCache._check_can_cache(gm) + key, debug_lines = compiled_fx_graph_hash( + gm, example_inputs, fx_kwargs, inputs_to_check + ) + cache_info["key"] = key + cache_info["components"] = debug_lines + + remote_cache: Optional[RemoteCache[JsonDataTy]] = None + if remote: + cache_id = "fx-graph-v1" + try: + if config.is_fbcode(): + from torch._inductor.fb.remote_cache import FbRemoteFxGraphCache + + remote_cache = FbRemoteFxGraphCache(cache_id) + else: + from torch._inductor.remote_cache import RemoteFxGraphCache + + remote_cache = RemoteFxGraphCache(cache_id) + except ModuleNotFoundError as e: + # No need for a stack trace on this error + remote_cache = None + log.warning("Unable to create a remote cache: %s", e) + except Exception: + remote_cache = None + log.warning("Unable to create a remote cache", exc_info=True) + + compiled_graph = FxGraphCache._lookup_graph( + key, example_inputs, local, remote_cache + ) + + if compiled_graph is None: + log.debug("fx graph cache miss for key %s", key) + counters["inductor"]["fxgraph_cache_miss"] += 1 + cache_state = "miss" + start_time = time_ns() + cache_event_time = start_time + compiled_graph = compile_fx_fn( + gm, example_inputs, inputs_to_check, fx_kwargs + ) + compiled_graph._time_taken_ns = time_ns() - start_time + cache_info["time_taken_ns"] = compiled_graph._time_taken_ns + FxGraphCache._save_graph( + key, + compiled_graph, + example_inputs, + local, + remote_cache, + ) + else: + log.debug("fx graph cache hit for key %s", key) + counters["inductor"]["fxgraph_cache_hit"] += 1 + cache_state = "hit" + cache_event_time = time_ns() + if (time_saved_ns := compiled_graph._time_taken_ns) is not None: + cache_info["time_saved_ns"] = time_saved_ns + if ( + ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( + time_saved_ns + ) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase + compiled_graph._fx_graph_cache_key = key + except BypassFxGraphCache as e: + counters["inductor"]["fxgraph_cache_bypass"] += 1 + cache_state = "bypass" + log.info("Bypassing FX Graph Cache because '%s'", e) + cache_info["cache_bypass_reason"] = str(e) + if remote: + log_cache_bypass("bypass_fx_graph", str(e)) + cache_event_time = time_ns() + + if not compiled_graph: + compiled_graph = compile_fx_fn( + gm, example_inputs, inputs_to_check, fx_kwargs + ) + assert compiled_graph is not None + cache_info["cache_state"] = cache_state + chromium_log = get_chromium_event_logger() + chromium_log.log_instant_event( + f"fx_graph_cache_{cache_state}", cache_event_time, metadata=cache_info + ) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_cache_hash", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + # Use the passed in cudagraphs so that we mutate the BoxedBool correctly + FxGraphCache.post_compile( + compiled_graph, example_inputs, fx_kwargs["cudagraphs"] + ) + return compiled_graph + + @staticmethod + def clear() -> None: + """ + Clear out the on-disk cache. + """ + try: + shutil.rmtree(FxGraphCache._get_tmp_dir()) + except FileNotFoundError: + pass + + +_StrideExprStr: TypeAlias = str + + +@dataclasses.dataclass +class CompiledFxGraph: + """ + Class holding a compiled FX graph. This is the object serialized on disk + to support FxGraph caching. + """ + + current_callable: Optional[Callable[..., Any]] + cache_key: str + source_code: str = dataclasses.field(repr=False) # Do not display source_code + cache_linemap: Optional[List[Tuple[int, str]]] + device_types: Set[str] + device_idxs: Set[int] + mutated_inputs: Set[str] + mutated_input_idxs: Set[int] + constants: Dict[str, torch.Tensor] + torchbind_constants: Dict[str, torch._C.ScriptObject] + output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]] + disabled_cudagraphs_reason: Optional[str] + metrics_deltas: metrics.CachedMetricsDeltas + counter_deltas: Counter[str] + # This is a string representation of an expression we serialize + # with the object so the guards can be evaluated in a different + # context in order to verify the validity of serving a cached + # fx graph. The expression must be generated by: + # ShapeEnv.produce_guards_expression() + guards_expr: Optional[str] + + cudagraph_info: Optional[CudagraphCachedInfo] + fx_kwargs: Dict[str, Any] + inputs_to_check: Sequence[int] + boxed_forward_device_index: Optional[BoxedDeviceIndex] + + _time_taken_ns: Optional[int] = None + _boxed_call: Optional[bool] = None + _fx_graph_cache_key: Optional[str] = None + + def __init__( + self, + current_callable: Optional[Callable[..., Any]], + graph: GraphLowering, + output_strides: List[Optional[Tuple[_StrideExprStr, ...]]], + disabled_cudagraphs_reason: Optional[str], + metrics_deltas: metrics.CachedMetricsDeltas, + counter_deltas: Counter[str], + ) -> None: + self.current_callable = current_callable + self.cache_key = graph.cache_key + if graph.cache_path: + with open(graph.cache_path) as f: + self.source_code = f.read() + self.cache_linemap = graph.cache_linemap + # TODO - ordered set + self.device_types = set(graph.device_types) + self.device_idxs = set(graph.device_idxs) + self.mutated_inputs = set(graph.mutated_inputs) + self.mutated_input_idxs = set(graph.mutated_input_idxs) + self.constants = graph.constants + self.torchbind_constants = graph.torchbind_constants + self.output_strides = output_strides + self.disabled_cudagraphs_reason = disabled_cudagraphs_reason + self.metrics_deltas = metrics_deltas + self.counter_deltas = counter_deltas + self.guards_expr = None + self.cudagraph_info = None + self.fx_kwargs = {} + self.inputs_to_check = () + self.boxed_forward_device_index = None + + def __call__(self, inputs: List[Any]) -> Any: + assert self.current_callable is not None + return self.current_callable(inputs) + + +def run_command_and_check(cmd_: str) -> None: + cmd = shlex.split(cmd_) + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: + raise exc.CppCompileError(cmd, e.output) from e + + +@functools.lru_cache(None) +def split_aot_inductor_output_path(path: str) -> Tuple[str, str]: + """Returns the path where the AOT Inductor compiled kernels are stored.""" + if path.endswith(".so"): + return os.path.split(path) + elif path.endswith(".pt2"): + return os.path.split(path) + else: + return path, "" + + +@clear_on_fresh_inductor_cache +class CudaKernelParamCache: + cache: Dict[str, Dict[str, str]] = {} + cache_clear = staticmethod(cache.clear) + + @classmethod + def set(cls, key: str, params: Dict[str, str], cubin: str, bin_type: str) -> None: + _, path = write( + cubin, + bin_type, + hash_type=bin_type, + specified_dir=split_aot_inductor_output_path( + config.aot_inductor.output_path + )[0], + ) + + params[get_cpp_wrapper_cubin_path_name()] = path + + cls.cache[key] = params + + @classmethod + def get(cls, key: str) -> Optional[Dict[str, str]]: + return cls.cache.get(key, None) + + @classmethod + def get_keys(cls) -> KeysView[str]: + return cls.cache.keys() + + +class AotCodeCompiler: + @classmethod + def compile( + cls, + graph: GraphLowering, + source_code: str, + serialized_extern_kernel_nodes: Optional[str], + cuda: bool, + ) -> str: + if sys.platform == "win32": + raise RuntimeError("AotCodeCompiler not yet supported for inductor") + + _set_gpu_runtime_env() # cpp_extension consults the env + + picked_vec_isa = pick_vec_isa() + vec_isa_cmd_gen = CppBuilder( + name="o", + sources="i", + BuildOption=CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + ), + ) + # write function will calc source_code hash, the same source code with different + # ISA level should be generate different hash. + # So we need get a command_line which contains isa related parameter as a part of hash key. + # And then pass the command_line to below write function as extra parameter to + # guarantee the source code hash contains ISA difference. + cpp_command = repr(vec_isa_cmd_gen.get_command_line()) + + fbcode_aot_cpu_re = False + use_absolute_path = False + if config.is_fbcode(): + ld_command = build_paths.ld() + if not cuda and graph.aot_mode: # Meta internal AOTInductor CPU + objcopy_command = build_paths.objcopy_fallback() + fbcode_aot_cpu_re = True + use_absolute_path = True + else: + objcopy_command = build_paths.objcopy() + else: + ld_command = "ld" + objcopy_command = "objcopy" + + ( + specified_output_path, + specified_so_name, + ) = split_aot_inductor_output_path(config.aot_inductor.output_path) + key, input_path = write( + source_code, + "cpp", + extra=cpp_command, + specified_dir=specified_output_path, + ) + output_code_log.info("Output code written to: %s", input_path) + trace_structured( + "graph_dump", + lambda: { + "name": "inductor_aot_code", + "type": "cpp", + "filename": input_path, + }, + payload_fn=lambda: source_code, + ) + + # We use a file lock below to protect FS operations. The lock file + # is scoped to the 'key', so make sure the consts_path is protected + # by the same lock: + consts_specified_dir = os.path.join(os.path.split(input_path)[0], key) + + def _compile_consts_linux(consts: bytes) -> str: + _, consts_path = write( + consts, + "bin", + specified_dir=consts_specified_dir, + ) + + consts_o = os.path.splitext(consts_path)[0] + ".o" + if fbcode_aot_cpu_re: + cmd = f"{ld_command} -r -b binary -o {os.path.basename(consts_o)} {os.path.basename(consts_path)}" + compile_file(consts_path, consts_o, cmd.split()) + os.chmod(consts_o, 0o644) + else: + cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}" + run_command_and_check(cmd) + log.debug("aot constant binary command: %s", cmd) + + if graph.mutated_buffers & set(graph.constants.keys()): + # .data section is between .text and .bss. When the size of .data is large, + # during the linking, the relocation of .text against .bss may overflow. + # Rename it to .ldata so that it won't be in between the .text and .bss section + if len(consts) > 2_000_000_000: + raise ValueError( + "Models with buffer mutation included doesn't support constants greater than 2GB!" + ) + rename_data = " .data=.ldata" + else: + # if no buffer mutation is needed, we could instead set the data region + # as read-only (i.e. .lrodata) which could accomodate larger size of data + # to be linked. + rename_data = " .data=.lrodata,alloc,load,readonly,data,contents" + + assert ( + ALIGN_BYTES & (ALIGN_BYTES - 1) + ) == 0 and ALIGN_BYTES >= 64, "must be power of 2 and >= 64" + cmd = ( + f"{objcopy_command} --rename-section" + f"{rename_data}" + f" --set-section-alignment .data={ALIGN_BYTES}" # following the gAlignment of CPU in c10/core/alignment.h + f" {consts_o} {consts_o}" + ) + log.debug("aot constant rename section command: %s", cmd) + run_command_and_check(cmd) + + cmd = f"rm {consts_path}" + log.debug("aot constant bin removal command: %s", cmd) + run_command_and_check(cmd) + + if fbcode_aot_cpu_re: + body = re.sub(r"[\W]", "_", os.path.basename(consts_path)) + else: + body = re.sub(r"[\W]", "_", consts_path) + + symbol_list = [] + symbol_list.append( + f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}" + ) + symbol_list.append( + f"{objcopy_command} --redefine-sym _binary_{body}_size=_binary_constants_bin_size {consts_o}" + ) + symbol_list.append( + f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}" + ) + log.debug("aot constant binary redefine symbol: %s", " ".join(symbol_list)) + for cmd in symbol_list: + run_command_and_check(cmd) + return consts_o + + def _compile_consts_darwin(consts: bytes) -> str: + if config.aot_inductor.debug_dump_consts_bin: + _, _binary_constants_path = write( + consts, + "bin", + specified_dir=consts_specified_dir, + ) + log.debug("binary constants path: %s", _binary_constants_path) + + is_large_consts = len(consts) > 1024 + consts_asm = "\t.section\t__DATA,__data\n" + consts_asm += "\t.globl\t__binary_constants_bin_start\n" + consts_asm += "__binary_constants_bin_start:\n" + if not is_large_consts: + for c in consts: + consts_asm += f"\t.byte {c}\n" + # Add one element even if constants are empty + # Otherwise assembler will not put them in data section + if not consts: + consts_asm += "\t.space 1\n" + else: + consts_asm += "\t.quad 0x1234567899abcdef\n" + consts_asm += f"\t.space {len(consts) - 8}\n" + consts_asm += ".globl\t__binary_constants_bin_end\n" + consts_asm += "__binary_constants_bin_end:\n" + _, consts_path = write( + consts_asm, + "S", + specified_dir=consts_specified_dir, + ) + consts_o = os.path.splitext(consts_path)[0] + ".o" + cmd = f"{get_cpp_compiler()} -c -o {consts_o} {consts_path}" + run_command_and_check(cmd) + if is_large_consts: + with open(consts_o, "r+b") as f: + f.seek(0) + hdr = f.read(1024) + # Search for magic number and write the actual data over it + start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") + assert start_idx != -1 + f.seek(start_idx) + pos = 0 + while pos < len(consts): + rc = f.write(consts[pos:]) + pos += rc + return consts_o + + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + # Currently, this only support serializing extern nodes in fbcode + # Eventually, we should also have a serializer for OSS. + if serialized_extern_kernel_nodes: + output_json = os.path.splitext(input_path)[0] + ".json" + with open(output_json, "w") as f: + f.write(serialized_extern_kernel_nodes) + + output_so = ( + config.aot_inductor.output_path + if specified_so_name + else os.path.splitext(input_path)[0] + ".so" + ) + + output_o = os.path.splitext(input_path)[0] + ".o" + + all_cuda = all( + graph.get_original_value_of_constant(name).is_cuda + for name in graph.constants.keys() + if name not in graph.folded_constants + ) + + def get_nbytes_of_tensor(tensor: torch.Tensor, all_cuda: bool) -> int: + n_bytes = ( + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() + ) + return n_bytes if all_cuda else _align(n_bytes) + + consts_size = sum( + get_nbytes_of_tensor(tensor, all_cuda) + for (name, tensor) in graph.constants.items() + if name not in graph.folded_constants + ) + # TODO: Fix mmap weights with cuda + use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000 + if config.aot_inductor.force_mmap_weights: + use_mmap_weights = True + + if config.aot_inductor.package: + ( + object_output_name, + object_output_dir, + ) = get_name_and_dir_from_output_file_path(input_path) + object_build_options = CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + compile_only=True, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + object_builder = CppBuilder( + name=object_output_name, + sources=input_path, + output_dir=object_output_dir, + BuildOption=object_build_options, + ) + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() + + compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json" + object_build_options.save_flags_to_file(compile_flags) + + else: + ( + object_output_name, + object_output_dir, + ) = get_name_and_dir_from_output_file_path(input_path) + object_build_options = CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + compile_only=True, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + object_builder = CppBuilder( + name=object_output_name, + sources=input_path, + output_dir=object_output_dir, + BuildOption=object_build_options, + ) + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() + + log.debug("aot compilation command: %s", compile_cmd) + if fbcode_aot_cpu_re: + output_o = os.path.splitext(input_path)[0] + ".o" + compile_file(input_path, output_o, compile_cmd.split()) + os.chmod(output_o, 0o644) + else: + run_command_and_check(compile_cmd) + + def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: + def _pad_to_alignment(raw_bytes: bytes) -> bytes: + padded_bytes = raw_bytes.ljust( + (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, + b"\x00", + ) + return padded_bytes + + # This serializes the tensor's untyped_storage to bytes by accessing + # the raw data of the underlying structure. + import ctypes + + if t.numel() == 0: + return b"" + + if t.is_mkldnn: + data_ptr = torch.ops.mkldnn.data_ptr(t) + nbytes = torch.ops.mkldnn._nbytes(t) + else: + t_cpu = t.untyped_storage().cpu() + data_ptr = t_cpu.data_ptr() + nbytes = t_cpu.nbytes() + + raw_array = ctypes.cast( + data_ptr, + ctypes.POINTER(ctypes.c_ubyte * nbytes), + ) + raw_bytes = bytes(raw_array.contents) + return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) + + serialized_weights = b"".join( + _to_bytes(graph.get_original_value_of_constant(name), all_cuda) + for name in graph.constants.keys() + if name not in graph.folded_constants + ) + if not use_mmap_weights: + aot_constants = serialized_weights + magic_number = 0 + else: + magic_number = cast( + int, torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item() + ) + aot_constants = struct.pack("qq", consts_size + 8, magic_number) + + consts_o = { + "linux": _compile_consts_linux, + "darwin": _compile_consts_darwin, + }[sys.platform](aot_constants) + + if config.aot_inductor.package: + output_name, output_dir = get_name_and_dir_from_output_file_path( + output_so + ) + so_build_options = CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + use_absolute_path=use_absolute_path, + ) + so_builder = CppBuilder( + name=output_name, + sources=[output_o, consts_o], + output_dir=output_dir, + BuildOption=so_build_options, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + linker_flags = os.path.splitext(input_path)[0] + "_linker_flags.json" + so_build_options.save_flags_to_file(linker_flags) + + from torch._inductor.package import package_aoti + + if use_mmap_weights: + weight_file = ( + os.path.splitext(input_path)[0] + "_serialized_weights.bin" + ) + with open(weight_file, "wb") as f_weights: + f_weights.write(serialized_weights) + f_weights.write(struct.pack("q", magic_number)) + + archive_path = package_aoti(os.path.split(input_path)[0]) + return archive_path + else: + output_name, output_dir = get_name_and_dir_from_output_file_path( + output_so + ) + so_build_options = CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + use_absolute_path=use_absolute_path, + ) + so_builder = CppBuilder( + name=output_name, + sources=[output_o, consts_o], + output_dir=output_dir, + BuildOption=so_build_options, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + log.debug("aot linkage command: %s", link_cmd) + if fbcode_aot_cpu_re: + output_so = ( + config.aot_inductor.output_path + if specified_so_name + else os.path.splitext(input_path)[0] + ".so" + ) + compile_file([output_o, consts_o], output_so, link_cmd.split()) + os.chmod(output_so, 0o755) + else: + run_command_and_check(link_cmd) + + if use_mmap_weights: + import resource + + page_size_ = resource.getpagesize() + page_size = max(16384, page_size_) + + with open(output_so, "a+b") as f_so: + so_size = f_so.tell() + # Page align the weights + f_so.write(b" " * (page_size - so_size % page_size)) + f_so.write(serialized_weights) + f_so.write(struct.pack("q", magic_number)) + + # Append cmds to the end of codegen-ed wrapper file + with open(input_path, "a") as f: + f.write("\n") + f.write(f"// Compile cmd\n// {compile_cmd}\n") + f.write(f"// Link cmd\n// {link_cmd}\n") + + return output_so + + +# Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py. +# Why? importing from cpp.py invokes codecache.pick_vec_isa(), which takes out a lock. +# Cycle goes: +# - CppCodeCache.load() +# - pick_vec_isa() +# - valid_vec_isa_list() +# - VecISA.__bool__() <-- takes out a lock +# - compile_file() <-- imports cpp_prefix_path from cpp, which causes us to try to take out the same lock. +@clear_on_fresh_inductor_cache +@functools.lru_cache +def cpp_prefix_path() -> str: + path = Path(__file__).parent / "codegen/cpp_prefix.h" + with path.open() as f: + content = f.read() + _, filename = write( + content, + "h", + ) + return normalize_path_separator(filename) + + +def cpp_prefix() -> str: + filename = cpp_prefix_path() + if config.is_fbcode(): + # We need relative paths, since we bundle up + # everything that we compile into a folder for remote compilation. + return f'#include "{os.path.basename(filename)}"' + else: + return f'#include "{filename}"' + + +# Given a path to an input cpp file and an output path, +# Attempts to compile the file, storing the output in "output_path" +def compile_file( + input_path: Union[str, List[str]], output_path: str, cmd: List[str] +) -> None: + with dynamo_timed("compile_file"): + return _compile_file(input_path, output_path, cmd) + + +def _compile_file( + input_path: Union[str, List[str]], output_path: str, cmd: List[str] +) -> None: + input_paths = [input_path] if isinstance(input_path, str) else input_path + input_files = [ + os.path.basename(ip) if config.is_fbcode() else ip for ip in input_paths + ] + try: + if config.is_fbcode(): + # Need to copy our header into the same folder as the sourcecode. + header_path = cpp_prefix_path() + header_name = os.path.basename(header_path) + output_name = os.path.basename(output_path) + # When we build remotely, we need to make sure to carefully copy any files + # that are required during the compilation process into our build directly. + # This is where all of the ATen/c10/Torch includes come from. + torch_includes_path = os.path.join(_TORCH_PATH, "include") + with tempfile.TemporaryDirectory() as tmp_dir: + # Copy everything to tmp compilation folder + shutil.copy(header_path, os.path.join(tmp_dir, header_name)) + shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld")) + for p, f in zip(input_paths, input_files): + shutil.copy(p, os.path.join(tmp_dir, f)) + dest_include_path = os.path.join(tmp_dir, "include") + shutil.copytree(torch_includes_path, dest_include_path) + # Run the build + output_file_path = _run_build_command(cmd, tmp_dir, output_name) + # Copy output from the build + if os.path.exists(output_path): + os.remove(output_path) + shutil.copy(output_file_path, output_path) + else: + subprocess.check_output(cmd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8") + openmp_problem = "'omp.h' file not found" in output or "libomp" in output + if openmp_problem and sys.platform == "darwin": + instruction = ( + "\n\nOpenMP support not found. Please try one of the following solutions:\n" + "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " + "that has builtin OpenMP support;\n" + "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" + "(3) install libomp via brew: `brew install libomp`;\n" + "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" + " with `include/omp.h` under it." + ) + output += instruction + raise exc.CppCompileError(cmd, output) from e + + +_libgomp: Optional[CDLL] = None + + +def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p]: + # This function will be called from generated cpp wrapper code in the JIT mode. + # Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them. + def convert_arg(arg: Any) -> Any: + if str(type(arg)) == "": + # No easy way to do isinstance check on PyCapsule + return torch._C._aoti.alloc_tensor_by_stealing_from_void_ptr(arg) + elif isinstance(arg, (list, tuple)): + return type(arg)(convert_arg(a) for a in arg) + else: + return arg + + converted_args = [convert_arg(arg) for arg in args] + + assert op.startswith("torch.ops."), ( + op + " can not be called through custom_op_wrapper" + ) + func = None + for i, s in enumerate(op.split(".")): + if i == 0: + func = importlib.import_module(s) + func = getattr(func, s) + + assert callable(func), op + " can not be loaded through custom_op_wrapper" + result = func(*converted_args) + if isinstance(result, (list, tuple)): + for r in result: + assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors" + return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) # type: ignore[arg-type] + else: + assert isinstance(result, torch.Tensor), op + " returns a non-tensor" + return torch._C._aoti.unsafe_alloc_void_ptr_from_tensor(result) + + +@clear_on_fresh_inductor_cache +class CppCodeCache: + cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags: Dict[str, Any] = {} + + @staticmethod + def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]: + return cdll.LoadLibrary(path) + + @classmethod + def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]: + try: + result = cls._load_library_inner(path, key) + result.key = key # type: ignore[union-attr] + return result + except (ImportError, OSError) as e: + if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"): + # hacky workaround for fbcode/buck + global _libgomp + _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1") + result = cls._load_library_inner(path, key) + result.key = key # type: ignore[union-attr] + return result + if "failed to map segment from shared object" in str(e): + raise OSError( + f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder " + "is mounted with noexec (e.g., by default Docker mounts tmp file systems " + f"as noexec). Please remount {tempfile.gettempdir()} with exec enabled, or set another " + "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable." + ) from e + raise + + @classmethod + def load_async( + cls, + source_code: str, + cuda: bool = False, + submit_fn: Any = None, + extra_flags: Sequence[str] = (), + ) -> Any: + compile_command = { + **cls.cpp_compile_command_flags, + "cuda": cuda, + "vec_isa": pick_vec_isa(), + "extra_flags": extra_flags, + } + + _set_gpu_runtime_env() # cpp_extension consults the env + + command_gen = CppBuilder( + name="o", sources="i", BuildOption=CppTorchCudaOptions(**compile_command) + ) + # write function will calc source_code hash, the same source code with different + # ISA level should be generate different hash. + # So we need get a command_line which contains isa related parameter as a part of hash key. + # And then pass the command_line to below write function as extra parameter to + # guarantee the source code hash contains ISA difference. + vec_isa_cmd = repr(command_gen.get_command_line()) + key, input_path = write(source_code, "cpp", extra=vec_isa_cmd) + + if key not in cls.cache: + from filelock import FileLock + + lock_path = os.path.join(get_lock_dir(), key + ".lock") + output_name, output_dir = get_name_and_dir_from_output_file_path(input_path) + """ + If `fb_code` env, it need to be dispatched to original `compile_file` function. + So, we still need to prepare parameters for the function: `input_path` and `fb_output_path`. + """ + fb_output_path = input_path[:-3] + "so" + future: Optional[Future[Any]] = None + lib = None + + cpp_build_option = CppTorchCudaOptions(**compile_command) + cpp_builder = CppBuilder( + name=output_name, + sources=input_path, + output_dir=output_dir, + BuildOption=cpp_build_option, + ) + + worker_fn = functools.partial( + _worker_compile_cpp, + lock_path, + cpp_builder, + input_path, + fb_output_path, + ) + + binary_path = normalize_path_separator( + fb_output_path + if config.is_fbcode() + else cpp_builder.get_target_file_path() + ) + + def load_fn() -> Any: + nonlocal lib + if lib is None: + if future is not None: + future.result() + result = worker_fn() + assert result is None + lib = cls._load_library(binary_path, key) + assert lib is not None + return lib + + if submit_fn is not None: + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + if not os.path.exists(binary_path): + future = submit_fn(worker_fn) + + cls.cache[key] = load_fn + + return cls.cache[key] + + @classmethod + def load(cls, source_code: str, cuda: bool = False) -> Any: + return cls.load_async(source_code, cuda)() + + +def _worker_compile_cpp( + lock_path: str, + cpp_builder: CppBuilder, + fb_input_path: str, + fb_output_path: str, +) -> None: + from filelock import FileLock + + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + binary_path = ( + fb_output_path if config.is_fbcode() else cpp_builder.get_target_file_path() + ) + if not os.path.exists(binary_path): + if config.is_fbcode(): + compile_file( + fb_input_path, + fb_output_path, + shlex.split(cpp_builder.get_command_line()), + ) + else: + cpp_builder.build() + + +# Customized Python binding for cpp kernels +@clear_on_fresh_inductor_cache +class CppPythonBindingsCodeCache(CppCodeCache): + cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + # kernels have no dependency on libtorch + "include_pytorch": False, + "shared": True, + } + entry_function = "kernel" + call_entry_function = "kernel(%s);Py_RETURN_NONE;" + extra_parse_arg = "" + suffix_template = textwrap.dedent( + """ + // Python bindings to call %s(): + #define PY_SSIZE_T_CLEAN + #include + #include + #include + + #ifndef _MSC_VER + #if __cplusplus < 202002L + // C++20 (earlier) code + // https://en.cppreference.com/w/cpp/language/attributes/likely + #define likely(x) __builtin_expect(!!(x), 1) + #define unlikely(x) __builtin_expect(!!(x), 0) + #endif + #else + #define likely(x) (x) + #define unlikely(x) (x) + #endif + + // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow. + // We manually link it below to workaround issues with fbcode build. + static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj); + + template static inline T parse_arg(PyObject* args, size_t n) { + static_assert(std::is_pointer::value, "arg type must be pointer or long"); + return static_cast(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n))); + } + template <> inline int64_t parse_arg(PyObject* args, size_t n) { + auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == -1 && PyErr_Occurred())) + throw std::runtime_error("expected int arg"); + return result; + } + template <> inline uintptr_t parse_arg(PyObject* args, size_t n) { + auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == reinterpret_cast(-1) && PyErr_Occurred())) + throw std::runtime_error("expected int arg"); + return reinterpret_cast(result); + } + + %s + + static PyObject* %s_py(PyObject* self, PyObject* args) { + try { + if(unlikely(!PyTuple_CheckExact(args))) + throw std::runtime_error("tuple args required"); + if(unlikely(PyTuple_GET_SIZE(args) != %s)) + throw std::runtime_error("requires %s args"); + %s + } catch(std::exception const& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } catch(...) { + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return nullptr; + } + } + + static PyMethodDef py_methods[] = { + {"%s", %s_py, METH_VARARGS, ""}, + {NULL, NULL, 0, NULL}}; + + static struct PyModuleDef py_module = + {PyModuleDef_HEAD_INIT, "%s", NULL, -1, py_methods}; + + PyMODINIT_FUNC PyInit_%s(void) { + const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"); + if(!str_addr) { + PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set"); + return nullptr; + } + std::istringstream iss(str_addr); + uintptr_t addr = 0; + iss >> addr; + _torchinductor_pyobject_tensor_data_ptr = + reinterpret_cast(addr); + return PyModule_Create(&py_module); + } + """ + ) + + @classmethod + def _load_library_inner(cls, path: str, key: str) -> ModuleType: + os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str( + torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined] + ) + module_name = f"{key}.{cls.entry_function}" + try: + return sys.modules[module_name] + except KeyError: + pass + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore[union-attr] + return module + + @classmethod + def load_pybinding_async( + cls, + argtypes: List[str], + source_code: str, + cuda: bool = False, + num_outputs: int = -1, + submit_fn: Any = None, + extra_flags: Sequence[str] = (), + ) -> Any: + """ + Wrap a C++ function in fast Python bindings. + + Args: + argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"] + source_code: C++ source code containing a ENTRY_FUNCTION() function + + Returns: + A python version of ENTRY_FUNCTION() + """ + parseargs = ", ".join( + f"parse_arg<{argtype.replace('const ', '')}>(args, {n})" + for n, argtype in enumerate(argtypes) + ) + suffix = cls.suffix_template % ( + cls.entry_function, + cls.extra_parse_arg % num_outputs if cls.extra_parse_arg else "", + cls.entry_function, + len(argtypes), + len(argtypes), + cls.call_entry_function % parseargs, + cls.entry_function, + cls.entry_function, + cls.entry_function, + cls.entry_function, + ) + get_result = cls.load_async( + source_code + suffix, cuda, submit_fn=submit_fn, extra_flags=extra_flags + ) + result = None + + def future() -> Any: + nonlocal result + if result is None: + result = get_result() + assert isinstance(result, ModuleType) + return getattr(result, cls.entry_function) + + return future + + @classmethod + def load_pybinding(cls, *args: Any, **kwargs: Any) -> Any: + return cls.load_pybinding_async(*args, **kwargs)() + + +@clear_on_fresh_inductor_cache +class CppWrapperCodeCache(CppPythonBindingsCodeCache): + cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + "include_pytorch": True, + "shared": True, + } + entry_function = "inductor_entry_cpp" + call_entry_function = "return inductor_entry_cpp(%s);" + extra_parse_arg = textwrap.dedent( + """ + #include + + static inline std::vector unpack_tensor_handle_list(PyObject* pyvec) { + std::vector result; + size_t result_len = PyList_GET_SIZE(pyvec); + result.reserve(result_len); + for (size_t i = 0; i < result_len; i++) { + // AtenTensorHandle is essentially a pointer + void* elem = PyCapsule_GetPointer(PyList_GET_ITEM(pyvec, i), NULL); + result.push_back(reinterpret_cast(elem)); + } + return result; + } + + static inline PyObject* pack_tensor_handle_list(const std::vector& cppvec) { + size_t result_len = cppvec.size(); + PyObject* result = PyList_New(static_cast(result_len)); + for (size_t i = 0; i < result_len; i++) { + PyObject *elem = + cppvec[i] == nullptr + ? Py_None + // Store AtenTensorHandle as PyCapsulate + : PyCapsule_New(reinterpret_cast(cppvec[i]), NULL, NULL); + PyList_SET_ITEM(result, i, elem); + } + return result; + } + + template <> inline std::vector parse_arg>(PyObject* args, size_t n) { + return unpack_tensor_handle_list(PyTuple_GET_ITEM(args, n)); + } + + PyObject* inductor_entry_cpp(std::vector&& input_handles) { + // For outputs, we only allocate a vector to hold returned tensor handles, + // not allocating the actual output tensor storage here + std::vector output_handles(%s); + try { + inductor_entry_impl(input_handles.data(), output_handles.data()); + return pack_tensor_handle_list(output_handles); + } catch(std::exception const& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return {}; + } catch(...) { + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return {}; + } + } + """ + ) + + +@clear_on_fresh_inductor_cache +class HalideCodeCache(CppPythonBindingsCodeCache): + cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {} + cache_clear = staticmethod(cache.clear) + _standalone_runtime_path: Optional[str] = None + prefix = textwrap.dedent( + """ + #include "{halideruntime_h}" + #include "{headerfile}" + #include + #include + + namespace c10 {{ + inline long div_floor_integer(long a, long b) {{ + if ((a<0) != (b<0)) {{ + const auto quot = a / b; + const auto rem = a % b; + return rem ? quot - 1 : quot; + }} + return a / b; + }} + }} + """ + ) + glue_template_cpp = prefix + textwrap.dedent( + """ + void kernel({argdefs}) {{ + {buffers} + int err = halide_kernel({buffer_names}); + if(err != 0) throw std::runtime_error("halide_kernel failed"); + }} + """ + ) + glue_template_cuda = prefix + textwrap.dedent( + """ + #include + static const halide_device_interface_t* cuda_interface = halide_cuda_device_interface(); + + void kernel({argdefs}, uintptr_t stream) {{ + {buffers} + int err = halide_kernel(reinterpret_cast(stream), {buffer_names}); + if(err != 0) throw std::runtime_error("halide_kernel failed"); + }} + """ + ) + standalone_runtime_cuda_init = textwrap.dedent( + """ + #include "{}" + #include + + static int acquire_context(void* user_context, + void** cuda_context_out, + bool create) {{ + return cuCtxGetCurrent(reinterpret_cast(cuda_context_out)); + }} + + static int release_context(void* user_context) {{ + return 0; + }} + + static int get_stream(void* user_context, + void* cuda_context, + void** stream_out) {{ + *stream_out = user_context; + return 0; + }} + + static int register_halide_hooks() {{ + halide_set_cuda_acquire_context(&acquire_context); + halide_set_cuda_release_context(&release_context); + halide_set_cuda_get_stream(&get_stream); + return 0; + }} + + int inductor_register_halide_hooks_result = register_halide_hooks(); + """ + ) + + @classmethod + def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> List[str]: + assert arg.shape is not None + assert arg.stride is not None and len(arg.shape) == len(arg.stride) + assert arg.offset is not None + data_ptr = f"{arg.alias_of or arg.name} + {arg.offset}" + if cuda: + device = f"reinterpret_cast({data_ptr})" + device_interface = "cuda_interface" + host = "nullptr" + flags = "halide_buffer_flag_device_dirty" + else: + device = "0" + device_interface = "nullptr" + host = f"reinterpret_cast({data_ptr})" + flags = "halide_buffer_flag_host_dirty" + + dims = [] + for size, stride in zip(arg.shape, arg.stride): + dims.append(f"halide_dimension_t(0, {size}, {stride})") + + return [ + f"halide_buffer_t {name};", + f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};", + f"{name}.device = {device};", + f"{name}.device_interface = {device_interface};", + f"{name}.host = {host};", + f"{name}.flags = {flags};", + f"{name}.type = {arg.halide_type()};", + f"{name}.dimensions = {len(dims)};", + f"{name}.dim = {name}_dims;", + f"{name}.padding = nullptr;", + ] + + @classmethod + def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str: + is_cuda = meta.is_cuda() + assert is_cuda is ("user_context" in meta.target) + assert "no_runtime" in meta.target + buffers = [] + buffer_names = [] + for i, arg in enumerate(meta.argtypes): + if arg.is_buffer(): + buffer_names.append(f"&hl_buf_{i}") + buffers.extend(cls._codegen_buffer(f"hl_buf_{i}", arg, is_cuda)) + else: + assert "*" not in arg.ctype + buffer_names.append(arg.name) + buffers = "\n".join([f" {line}" for line in buffers]).lstrip() + + glue_template = cls.glue_template_cuda if is_cuda else cls.glue_template_cpp + glue_code = glue_template.format( + halideruntime_h=cls.find_header( + "HalideRuntimeCuda.h" if is_cuda else "HalideRuntime.h" + ), + headerfile=headerfile, + argdefs=", ".join( + f"{a.bindings_type()} {a.name}" + for a in meta.argtypes + if a.alias_of is None + ), + buffers=buffers, + buffer_names=", ".join(buffer_names), + ) + return glue_code + + @classmethod + @functools.lru_cache(None) + def config_hash(cls) -> str: + command_gen = CppBuilder( + name="O", + sources="I", + BuildOption=CppOptions(), + ) + command_line = command_gen.get_command_line() + return sha256_hash( + "\n".join( + [ + cls.glue_template_cpp, + cls.glue_template_cuda, + cls.standalone_runtime_cuda_init, + command_line, + ] + ).encode("utf-8") + ) + + @staticmethod + def _search_for_file(suffix: str, errmsg: str) -> str: + spec = importlib.machinery.PathFinder.find_spec("halide") + if spec is None or not spec.submodule_search_locations: + raise RuntimeError("halide python bindings not installed") + try: + search = spec.submodule_search_locations[0] + for file in os.listdir(search): + if file.endswith(".so"): + try: + out = subprocess.check_output( + ["ldd", os.path.join(search, file)] + ) + except subprocess.SubprocessError: + continue + m = re.search(r"(/.*)/libHalide.so", out.decode("utf-8")) + if m: + path = os.path.join(os.path.abspath(m.group(1)), suffix) + if os.path.exists(path): + return os.path.abspath(path) + except Exception as e: + raise RuntimeError(errmsg) from e + raise RuntimeError(errmsg) + + @staticmethod + @functools.lru_cache(None) + def find_libautoschedule(name: str) -> str: + sofile = f"libautoschedule_{name.lower()}.so" + if "HALIDE_LIB" in os.environ: + path = os.path.join(os.environ["HALIDE_LIB"], sofile) + if os.path.exists(path): + return path + errmsg = ( + f"Can't find {sofile}, set env HALIDE_LIB to the directory containing it" + ) + return HalideCodeCache._search_for_file(sofile, errmsg) + + @staticmethod + @functools.lru_cache(None) + def find_header(name: str) -> str: + if "HALIDE_INCLUDE" in os.environ: + path = os.path.join(os.environ["HALIDE_INCLUDE"], name) + if os.path.exists(path): + return path + if "HALIDE_LIB" in os.environ: + path = os.path.abspath( + os.path.join(os.environ["HALIDE_LIB"], f"../include/{name}") + ) + if os.path.exists(path): + return path + errmsg = ( + f"Can't find {name}, set env HALIDE_INCLUDE to the directory containing it" + ) + return HalideCodeCache._search_for_file(f"../include/{name}", errmsg) + + @classmethod + def generate_halide_async( + cls, meta: HalideMeta, source_code: str, submit_fn: Any = None + ) -> Callable[[], Any]: + dirpath = Path( + get_path( + code_hash( + source_code, + extra=repr((cls.config_hash(), meta)), + ), + "halide", + )[2] + ) + os.makedirs(dirpath, exist_ok=True) + wait_for_compile = None + genfile = str(dirpath / "generate_kernel.py") + libfile = str(dirpath / "halide_kernel.a") + headerfile = str(dirpath / "halide_kernel.h") + donefile = str(dirpath / "done") + lockfile = str(dirpath / "lock") + need_compile = not os.path.exists(donefile) + jobs = [] + if need_compile: + write_atomic(genfile, source_code) + cmd = [ + sys.executable, + genfile, + "-g", + "kernel", + "-o", + f"{dirpath}", + "-f", + "halide_kernel", + "-e", + "static_library,h,schedule", + ] + if meta.scheduler: + cmd.extend(["-p", cls.find_libautoschedule(meta.scheduler)]) + cmd.extend(meta.args()) + jobs.append(functools.partial(subprocess.check_call, cmd)) + + binding_types = [ + arg.bindings_type() for arg in meta.argtypes if arg.alias_of is None + ] + if meta.is_cuda(): + binding_types.append("uintptr_t") # stream + bindings_future = cls.load_pybinding_async( + binding_types, + cls._codegen_glue(meta, headerfile), + extra_flags=(libfile, cls.build_standalone_runtime()), + submit_fn=jobs.append if need_compile else None, + cuda=meta.is_cuda(), + ) + + if need_compile: + jobs.append(functools.partial(touch, donefile)) + task = functools.partial(_worker_task_halide, lockfile, jobs) + if submit_fn: + wait_for_compile = submit_fn(task).result + else: + task() + + def load() -> Callable[[], Any]: + if wait_for_compile: + wait_for_compile() + return bindings_future() + + return load + + @classmethod + def generate_halide(cls, *args: Any, **kwargs: Any) -> Callable[[], Any]: + return cls.generate_halide_async(*args, **kwargs)() + + @classmethod + def build_standalone_runtime(cls) -> str: + if cls._standalone_runtime_path and os.path.exists( + cls._standalone_runtime_path + ): + return cls._standalone_runtime_path + is_cuda = torch.cuda.is_available() + libname = "libStandaloneHalideRuntime.so" + target = "host-cuda" if is_cuda else "host" + if cls._standalone_runtime_path: + assert not os.path.exists(cls._standalone_runtime_path) + # We hit this case in unittests when we run with fresh_inductor_cache() + # Generating a fresh runtime over and over causes errors because we initialize + # cuda hundreds of times in the same process and run out of file descriptors. + # Workaround by jail breaking the current fresh_inductor_cache(). + base = default_cache_dir() + else: + base = cache_dir() + dirpath = Path(base) / f"halide-runtime-{target}-{cls.config_hash()}" + os.makedirs(dirpath, exist_ok=True) + donefile = str(dirpath / "done") + lockfile = str(dirpath / "lock") + hookfile = str(dirpath / "hooks.cpp") + afile = str(dirpath / "standalone_halide_runtime.a") + sofile = str(dirpath / libname) + if not os.path.exists(donefile): + import filelock + import halide as hl # type: ignore[import-untyped,import-not-found] + + with filelock.FileLock(lockfile, LOCK_TIMEOUT): + if not os.path.exists(donefile): + with open(hookfile, "w") as f: + if is_cuda: + f.write( + cls.standalone_runtime_cuda_init.format( + cls.find_header("HalideRuntimeCuda.h") + ) + ) + hl.compile_standalone_runtime(afile, hl.Target(target)) + + name, output_dir = get_name_and_dir_from_output_file_path(sofile) + halide_cmd_gen = CppBuilder( + name=name, + sources=[hookfile, afile], + output_dir=output_dir, + BuildOption=CppTorchCudaOptions( + cuda=is_cuda, + ), + ) + + subprocess.check_call( + shlex.split(halide_cmd_gen.get_command_line()) + ) + touch(donefile) + assert os.path.exists(sofile) + cls._standalone_runtime_path = sofile + return sofile + + +def _worker_task_halide(lockfile: str, jobs: List[partial[Any]]) -> None: + from filelock import FileLock + + try: + with FileLock(lockfile, LOCK_TIMEOUT): + for job in jobs: + job() + except subprocess.SubprocessError as e: + if os.environ.get("HALIDE_REPRO") == "1": + python, script, *cmd = getattr(e, "cmd", ("", "", "")) + if os.path.basename(python).startswith("python"): + code = open(script).read() + main = " hl.main()" + assert code.count(main) == 1 + + class Out: + def __repr__(self) -> str: + return "out" + + cmd[cmd.index("-o") + 1] = Out() # type: ignore[call-overload] + repl = textwrap.indent( + textwrap.dedent( + f"""\ + import sys, tempfile + with tempfile.TemporaryDirectory() as out: + sys.argv = {["repro.py", *cmd]!r} + hl.main() + """ + ), + " ", + ) + code = code.replace(main, repl) + with open("repro.py", "w") as fd: + fd.write(code.lstrip()) + raise RuntimeError(f"wrote repro.py: {e}") from e + raise + + +def touch(filename: str): # type: ignore[no-untyped-def] + open(filename, "a").close() + + +@clear_on_fresh_inductor_cache +class PyCodeCache: + cache: Dict[str, ModuleType] = {} + linemaps: Dict[str, List[Tuple[Any, ...]]] = {} + cache_clear = staticmethod(cache.clear) + + @classmethod + def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]: + return write(source_code, "py", extra=extra) + + @classmethod + def load( + cls, + source_code: str, + extra: str = "", + linemap: Optional[List[Tuple[int, str]]] = None, + attrs: Optional[Dict[str, Any]] = None, + ) -> ModuleType: + key, path = write(source_code, "py", extra=extra) + return cls.load_by_key_path(key, path, linemap, attrs) + + @classmethod + def load_by_key_path( + cls, + key: str, + path: str, + linemap: Optional[List[Tuple[int, str]]] = None, + attrs: Optional[Dict[str, Any]] = None, + ) -> ModuleType: + if linemap is None: + linemap = [] + if key not in cls.cache: + mod = _reload_python_module(key, path) + + # another thread might set this first + cls.cache.setdefault(key, mod) + # unzip into separate lines/nodes lists + cls.linemaps[path] = list(zip(*linemap)) + + if attrs is not None: + for k, v in attrs.items(): + setattr(mod, k, v) + + if not (linemap or attrs): + mod._reload_in_subproc = functools.partial( # type: ignore[attr-defined] + _reload_python_module_in_subproc, key, path + ) + + return cls.cache[key] + + @classmethod + @functools.lru_cache(None) + def stack_frames_for_code( + cls, path: str, lineno: int + ) -> Optional[List[Dict[str, Any]]]: + if path not in cls.linemaps: + return None + # [(starting_line, ), ...] + lines, nodes = cls.linemaps[path] + p = bisect_right(lines, lineno) + if p == 0: + return None + entry = nodes[p - 1] + if not entry: + return None + + def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]: + # ideally fx stores stack traces as data rather than a string + # but this is not along a performance critical path + regex = r'File "(.+)", line (\d+), in (.+)\n' + matches = re.findall(regex, stack_trace) + return [ + {"filename": f, "line": int(l), "name": n} + for f, l, n in reversed(matches) + ] + + return parse_stack_trace(entry) + + +class TritonCodeCache: + @classmethod + def load(cls, kernel_name: str, source_code: str) -> ModuleType: + return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name) + + +def _cuda_compiler() -> Optional[str]: + if cuda_env.nvcc_exist(config.cuda.cuda_cxx): + return config.cuda.cuda_cxx + if config.is_fbcode(): + return os.path.join(build_paths.cuda(), "bin", "nvcc") + if cuda_env.nvcc_exist(os.getenv("CUDACXX")): + return os.getenv("CUDACXX", "") + if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")): + return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc")) + return "nvcc" + + +def _cutlass_include_paths() -> List[str]: + if config.is_fbcode(): + from libfb.py import parutil + + cutlass_path = parutil.get_dir_path("cutlass-3-headers") + else: + cutlass_path = config.cuda.cutlass_dir + return [ + # Use realpath to get canonical absolute paths, in order not to mess up cache keys + os.path.realpath(os.path.join(cutlass_path, "include")), + os.path.realpath(os.path.join(cutlass_path, "tools/library/include")), + os.path.realpath(os.path.join(cutlass_path, "tools/library/src")), + os.path.realpath(os.path.join(cutlass_path, "tools/util/include")), + ] + + +def _cuda_lib_options() -> List[str]: + _set_gpu_runtime_env() # cpp_extension consults the env + from torch.utils import cpp_extension + + lpaths = cpp_extension.library_paths(cuda=True) + [ + sysconfig.get_config_var("LIBDIR") + ] + extra_ldflags: List[str] = [] + if is_linux(): + _transform_cuda_paths(lpaths) + for path in lpaths: + # -rpath ensures the DLL can find its dependencies when loaded, even + # if the library path is non-standard. + extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"]) + extra_ldflags.append("-lcuda") + extra_ldflags.append("-lcudart") + else: + raise NotImplementedError( + "Unsupported env, failed to find cuda libs! Currently only Linux is supported." + ) + return extra_ldflags + + +def _nvcc_host_compiler_options() -> List[str]: + return [ + "-fPIC", + "-fno-strict-aliasing", + "-fvisibility=hidden", + "-Wconversion", + ] + + +def _nvcc_compiler_options() -> List[str]: + arch = cuda_env.get_cuda_arch() + if arch == "90": + # Required by cutlass compilation. + arch = "90a" + code = [f"sm_{arch}", f"compute_{arch}"] + if config.cuda.enable_cuda_lto: + code += [f"lto_{arch}"] + options = [ + "-t=0", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-w", + f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", + config.cuda.compile_opt_level, + "-std=c++17", + "--expt-relaxed-constexpr", + "-DNDEBUG", + ] + if config.is_fbcode(): + options.extend(["-ccbin", os.path.dirname(build_paths.gcc())]) + if config.cuda.enable_debug_info: + options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) + if config.cuda.enable_ptxas_info: + options.extend( + [ + "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) + "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels + "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels + "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) + "--source-in-ptx", + ] + ) # Annotate the ptx file with source information + if config.cuda.use_fast_math: + options.extend( + [ + "--use_fast_math", + "-DCUTLASS_USE_TANH_FOR_SIGMOID=1", + ] + ) + return options + + +def cuda_compile_command( + src_files: List[str], + dst_file: str, + dst_file_ext: str, + extra_args: Optional[List[str]] = None, +) -> str: + if extra_args is None: + extra_args = [] + include_paths = _cutlass_include_paths() + cuda_lib_options = _cuda_lib_options() + nvcc_host_compiler_options = _nvcc_host_compiler_options() + nvcc_compiler_options = _nvcc_compiler_options() + options = ( + nvcc_compiler_options + + extra_args + + [ + f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}" + for opt in nvcc_host_compiler_options + ] + + ["-I" + path for path in include_paths] + + cuda_lib_options + ) + src_file = " ".join(src_files) + res = "" + if dst_file_ext == "o": + res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}" + elif dst_file_ext == "so": + options.append("-shared") + res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + elif dst_file_ext == "exe": + res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + log.debug("CUDA command: %s", res) + return res + + +class DLLWrapper: + """A wrapper for a dynamic library.""" + + def __init__( + self, + lib_path: str, + ) -> None: + self.lib_path = lib_path + self.is_open = False + self.DLL = cdll.LoadLibrary(lib_path) + self.is_open = True + + def close(self) -> None: + if self.is_open: + self._dlclose() + self.is_open = False + + def _dlclose(self) -> None: + f_dlclose = None + + if is_linux(): + syms = CDLL(None) + if not hasattr(syms, "dlclose"): + # Apline Linux + syms = CDLL("libc.so") + + if hasattr(syms, "dlclose"): + f_dlclose = syms.dlclose + elif is_windows(): + import ctypes + + kernel32 = ctypes.CDLL("kernel32", use_last_error=True) + + f_dlclose = kernel32.FreeLibrary + else: + raise NotImplementedError("Unsupported env, failed to do dlclose!") + + if f_dlclose is not None: + if is_linux(): + f_dlclose.argtypes = [c_void_p] + f_dlclose(self.DLL._handle) + elif is_windows(): + import ctypes + from ctypes import wintypes + + f_dlclose.argtypes = [wintypes.HMODULE] + f_dlclose(self.DLL._handle) + else: + log.warning( + "dll unloading function was not found, library may not be unloaded properly!" + ) + + def __getattr__(self, name: str) -> Callable[..., None]: + if not self.is_open: + raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}") + + method = getattr(self.DLL, name) + + def _wrapped_func(*args: Any) -> None: + err = method(*args) + if err: + raise RuntimeError(f"Error in function: {method.__name__}") + + return _wrapped_func + + def __enter__(self) -> DLLWrapper: # noqa: PYI034 + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + def __del__(self) -> None: + self.close() + + +@clear_on_fresh_inductor_cache +class CUDACodeCache: + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + + cache: Dict[str, CacheEntry] = {} + cache_clear = staticmethod(cache.clear) + _SOURCE_CODE_SUFFIX = "cu" + + @classmethod + def write(cls, source_code: str, dst_file_ext: str) -> Tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. + """ + + cuda_command = repr( + cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + key, input_path = write( + source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command + ) + return key, input_path + + @classmethod + def compile( + cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None + ) -> Tuple[str, str, str]: + """ + Compiles CUDA source_code into a file with dst_file_ext extension. + Returns a tuple of dst_file_path, hash_key, source_code_path + """ + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if not os.path.exists(output_path): + cmd = cuda_compile_command( + [input_path], output_path, dst_file_ext, extra_args + ) + start_time = time() + log.debug("CUDA Compilation: %s", cmd) + cmd_parts = cmd.split(" ") + try: + subprocess.check_output( + cmd_parts, stderr=subprocess.STDOUT, env=os.environ + ) + except subprocess.CalledProcessError as error: + raise exc.CUDACompileError(cmd_parts, error.output) from error + end_time = time() + log_duration_msg = f"CUDA Compilation took {end_time - start_time} seconds. Compile command: {cmd}" + log.info(log_duration_msg) + else: + log.debug( + "CUDA Compilation skipped: %s since output already exists", + input_path, + ) + cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path) + + return (cls.cache[key].output_path, key, input_path) + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> Tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) + + +@clear_on_fresh_inductor_cache +class ROCmCodeCache: + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + + cache: Dict[str, CacheEntry] = {} + cache_clear = staticmethod(cache.clear) + _SOURCE_CODE_SUFFIX = "cpp" + _logged_compiler_version = False + + @classmethod + def write(cls, source_code: str, dst_file_ext: str) -> Tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. + """ + + cuda_command = repr( + rocm_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + key, input_path = write( + source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command + ) + return key, input_path + + @classmethod + def compile( + cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None + ) -> Tuple[str, str, str]: + """ + Compiles source_code into a file with dst_file_ext extension, + using the compile command specific for the ROCm platform. + Returns a tuple of dst_file_path, hash_key, source_code_path + """ + if not cls._logged_compiler_version: + cls._logged_compiler_version = True + log.debug(get_compiler_version_info(str(rocm_compiler()))) + + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if not os.path.exists(output_path): + cmd = rocm_compile_command( + [input_path], output_path, dst_file_ext, extra_args + ) + start_time = time() + cmd_parts = cmd.split(" ") + try: + output = subprocess.check_output( + cmd_parts, + stderr=subprocess.STDOUT, + text=True, + env=os.environ, + ) + log.debug("Compilation output: %s", output) + except subprocess.CalledProcessError as error: + raise exc.CUDACompileError(cmd_parts, error.output) from error + end_time = time() + log_duration_msg = f"Compilation took {end_time - start_time} seconds. Compile command: {cmd}" + log.info(log_duration_msg) + else: + log.debug( + "Compilation skipped: %s since output already exists", + input_path, + ) + cls.cache[key] = ROCmCodeCache.CacheEntry(input_path, output_path) + + return (cls.cache[key].output_path, key, input_path) + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> Tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) + + +class CodeCacheFuture: + def result(self) -> None: + raise NotImplementedError + + +class TritonFuture(CodeCacheFuture): + kernel: ModuleType + + def __init__( + self, + kernel: Any, + future: Optional[Future[Any]], + ) -> None: + self.kernel = kernel + self.future = future + + def result(self) -> ModuleType: # type: ignore[override] + if self.future is not None: + # If the worker failed this will throw an exception. + result = self.future.result() + assert result is None + self.future = None + self.kernel.precompile() + return self.kernel + + +class LambdaFuture(CodeCacheFuture): + def __init__(self, result_fn: Callable[..., Any]) -> None: + self.result_fn = result_fn + + def result(self) -> Callable[..., Any]: # type: ignore[override] + return self.result_fn() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py b/.venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a233a3b9e21eb33d79b6ad60cd3ba4f276ddae --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py @@ -0,0 +1,264 @@ +import functools +import math +from enum import IntEnum + +import sympy + +import torch + +from . import ir +from .utils import get_dtype_size, sympy_product +from .virtualized import V + + +class NCCL_COLL(IntEnum): + ALL_REDUCE = 0 + ALL_GATHER = 1 + REDUCE_SCATTER = 2 + + +class NVIDIA_GPU_TYPE(IntEnum): + VOLTA = 0 + AMPERE = 1 + HOPPER = 2 + + +@functools.lru_cache +def get_gpu_type() -> NVIDIA_GPU_TYPE: + gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or "" + if "V100" in gpu_info: + return NVIDIA_GPU_TYPE.VOLTA + elif "A100" in gpu_info: + return NVIDIA_GPU_TYPE.AMPERE + elif "H100" in gpu_info: + return NVIDIA_GPU_TYPE.HOPPER + else: + # for other gpu types, assume Ampere + return NVIDIA_GPU_TYPE.AMPERE + + +def get_collective_type(node: ir.IRNode) -> NCCL_COLL: + if not isinstance(node, ir._CollectiveKernel): + raise ValueError(f"node is not a collective kernel: {node}") + + kernel_name = node.python_kernel_name + assert kernel_name is not None + if "all_reduce" in kernel_name: + return NCCL_COLL.ALL_REDUCE + elif "all_gather" in kernel_name: + return NCCL_COLL.ALL_GATHER + elif "reduce_scatter" in kernel_name: + return NCCL_COLL.REDUCE_SCATTER + else: + raise ValueError(f"Unsupported collective kernel: {kernel_name}") + + +def get_collective_input_size_bytes(node: ir.IRNode) -> int: + sz_bytes = 0 + for inp in node.inputs: # type: ignore[attr-defined] + numel = sympy_product(inp.layout.size) + if isinstance(numel, sympy.Integer): + # For ease of testing + numel = int(numel) + else: + numel = V.graph.sizevars.size_hint(numel, fallback=0) + sz_bytes += numel * get_dtype_size(inp.layout.dtype) + return sz_bytes + + +def get_collective_group_size(node: ir.IRNode) -> int: + if type(node) == ir._CollectiveKernel: + from torch.distributed.distributed_c10d import _get_group_size_by_name + + return _get_group_size_by_name(node.constant_args[-1]) + else: + raise TypeError(f"Unsupported collective type: {node}") + + +#################################################################################################################### +# The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +#################################################################################################################### + + +class NCCL_HW(IntEnum): + NVLINK = 0 + PCI = 1 + NET = 2 + + +class NCCL_ALGO(IntEnum): + TREE = 0 + RING = 1 + + +class NCCL_PROTO(IntEnum): + # The ordering and enum values here matches original in + # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28 + # For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990 + LL = 0 # Low-latency + # LL128 = 1 # Low-latency 128-byte + # SIMPLE = 2 + + +# Latencies in us +# len(NCCL_ALGO) x len(NCCL_PROTO) +# NOTE: use array instead of tensor to prevent incompatibility with fake mode +baseLat = [ + # Tree + [ + 6.8, # LL + ], + # Ring + [ + 6.6, # LL + ], +] + +# Latencies in us +# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO) +hwLat = [ + # NVLINK + [ + [0.6], # Tree (LL) + [0.6], # Ring (LL) + ], + # PCI + [ + [1.0], # Tree (LL) + [1.0], # Ring (LL) + ], + # NET + [ + [5.0], # Tree (LL) + [2.7], # Ring (LL) + ], +] + + +# LL128 max BW per channel +llMaxBws = [ + # Volta-N1/Intel-N2/Intel-N4 + [ + 39.0, + 39.0, + 20.4, + ], + # Ampere-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], + # Hopper-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], +] + + +def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: + """ + Returns estimated NCCL collective runtime in nanoseconds (ns). + + The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. + We aim to estimate the runtime as accurately as possible. + + Assumptions: + - only ring algorithm (NCCL_ALGO_RING) is used + - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used + - 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + - collective is one of: allreduce, reducescatter, allgather + """ + tensor_storage_size_bytes = get_collective_input_size_bytes(node) + # Convert bytes to GB + tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024 + + # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus. + # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + num_gpus_per_node = 8 + group_size = get_collective_group_size(node) + nNodes = math.ceil(group_size / num_gpus_per_node) + nRanks = group_size # this is total # of gpus globally that participate in this collective op + + if nRanks <= 1: + return 0 + + # Assumes ring algorithm + nccl_algo = NCCL_ALGO.RING + nccl_proto = NCCL_PROTO.LL + coll = get_collective_type(node) + + # =============== bandwidth computation =============== + # First compute bandwidth in GB/s; then at the end, convert it to GB/ns + + bwIntra = torch._inductor.config.intra_node_bw + bwInter = torch._inductor.config.inter_node_bw + + compCapIndex = get_gpu_type() + index2 = nNodes - 1 if nNodes <= 2 else 2 + # LL: for single node, we look at GPU type; for multi-node, we look at CPU type + index1 = compCapIndex if nNodes == 1 else 0 + llMaxBw = llMaxBws[index1][index2] + + # NOTE: each step of ring algorithm is synchronized, + # and is bottlenecked by the slowest link which is the inter-node interconnect. + # hence when nNodes >= 2, bw is inter-node bandwidth. + # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc + # have this as `if nNodes <= 2` which seems wrong. Corrected it here. + bw = bwIntra if nNodes == 1 else bwInter + nChannels = 2 # Assume # channels is 2 + busBw = nChannels * bw + + # Various model refinements + busBw = min( + llMaxBw, + busBw + * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0), + ) + + if coll == NCCL_COLL.ALL_REDUCE: + nsteps = 2 * (nRanks - 1) + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nsteps = nRanks - 1 + + # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time) + ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined] + bandwidth = busBw * ratio + # Convert GB/s to GB/ns + bandwidth_GB_per_ns = bandwidth / 1e9 + + # =============== latency computation =============== + intraHw = NCCL_HW.NVLINK + + if coll == NCCL_COLL.ALL_REDUCE: + if nNodes > 1: + nInterSteps = 2 * nNodes + else: + nInterSteps = 0 + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nInterSteps = nNodes - 1 + + # First compute latency in us; then at the end, convert it to ns + latency = baseLat[nccl_algo][nccl_proto] + intraLat = hwLat[intraHw][nccl_algo][nccl_proto] + interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto] + + # Inter-node rings still have to launch nsteps * net overhead. + netOverhead = 0.0 + if nNodes > 1: + netOverhead = 1.0 # getNetOverhead(comm); + intraLat = max(intraLat, netOverhead) + latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined] + # Convert us to ns + latency_ns = latency * 1e3 + + # =============== final result =============== + transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns + return transport_ns + latency_ns + + +################################################################################################################ +# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +################################################################################################################ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py b/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py new file mode 100644 index 0000000000000000000000000000000000000000..dcad1e1bf67c9841937792f06e1eac6cda52ce62 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py @@ -0,0 +1,640 @@ +# mypy: allow-untyped-defs +# pyre-strict +from __future__ import annotations + +import heapq +import operator +import sys +from collections import defaultdict +from typing import Dict, List, Set, TYPE_CHECKING + +import torch + +from . import config, ir +from .dependencies import WeakDep +from .utils import ( + contains_collective, + contains_wait, + find_recursive_deps_of_node, + find_recursive_users_of_node, + is_collective, + is_fallback_op, + is_wait, +) + + +overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") + +if TYPE_CHECKING: + from .scheduler import BaseSchedulerNode + + +def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: + """ + Greedily schedules waits as late as possible. + """ + return _schedule_for_comm( + snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False + ) + + +def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: + """ + Greedily schedules comms as early as possible. + """ + return _schedule_for_comm( + snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False + ) + + +def reorder_compute_for_overlap( + snodes: List[BaseSchedulerNode], +) -> List[BaseSchedulerNode]: + """ + This achieves the following overall scheduling procedure: + Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes + that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N. + Step 2: If all those compute nodes are sufficient to overlap comm N, we're done. + Otherwise, we now need to look elsewhere to find compute that overlaps with comm N. + We prioritize compute nodes that are needed sooner. + Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1. + Step 4: We schedule comm N + 1. + Repeat this for subsequent comm nodes. + """ + return _schedule_for_comm( + snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True + ) + + +def _schedule_for_comm( + snodes: List[BaseSchedulerNode], + raise_comms: bool, + sink_waits: bool, + reorder_for_overlap: bool, +) -> List[BaseSchedulerNode]: + """ + Schedule `snodes` for various comm optimization objectives. + + Args: + snodes: the nodes to be scheduled. + raise_comms: whether to greedily schedule collectives as early as possible + sink_wait: whether to greedily schedule waits as late as possible + reorder_compute_for_overlap: whether to reorder compute nodes to + optimize for compute/communication overlapping. + + Returns: + The new schedule order. + + Some notes on the synergy between different options: + - `raise_comms` provides more overlapping oppurtunies for `reorder_compute_for_overlap`. + - When both `raise_comms` and `sink_waits` is `True`, `raise_comms` is prioritized. + """ + # We assign each node a tuple of scores (score_0, score_1, score_2), + # decreasing in importance, with a lower value indicating a higher ranking: + # + # - score_0: the lowest comm_idx among the comm nodes that the node blocks. + # If a node doesn't block any comm nodes, its score_0 is set to + # sys.maxsize. This score ensures that comm nodes get scheduled as early as + # possible. + # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures + # that wait nodes are deferred as late as possible. + # - score_2: the index of the node in the original topological order. This + # score provides stability in case of ties. + # + # When only raise_comms is True, only score_0 and score_2 are considered. + # When only sink_waits is True, only score_1 and score_2 are considered. + # When neither is True, the original order is yielded. + buf_name_to_snode = {} + name_to_fused_node = {} + scores_0, scores_1, scores_2 = {}, {}, {} + for idx, snode in enumerate(snodes): + for buf_name in snode.get_buffer_names(): + buf_name_to_snode[buf_name] = snode + + for op_name in snode.get_operation_names(): + name_to_fused_node[op_name] = snode + name_to_fused_node[snode.get_name()] = snode + + node_name = snode.get_name() + scores_0[node_name] = sys.maxsize + scores_1[node_name] = 0 + scores_2[node_name] = idx + + comm_idx = 0 + for snode in snodes: + if raise_comms and contains_collective(snode): + scores_0[snode.get_name()] = comm_idx + for anc in snode.ancestors: + anc_fused_name = name_to_fused_node[anc].get_name() + scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx) + comm_idx += 1 + elif sink_waits and contains_wait(snode): + scores_1[snode.get_name()] = 1 + + class Runnable: + def __init__(self, snode) -> None: + self.snode = snode + name = next(iter(snode.get_operation_names())) + fused_name = name_to_fused_node[name].get_name() + self.score = ( + scores_0[fused_name], + scores_1[fused_name], + scores_2[fused_name], + ) + + def __lt__(self, other): + return self.score < other.score + + unmet_deps: Dict[BaseSchedulerNode, Set[str]] = { + snode: {dep.name for dep in snode.unmet_dependencies} for snode in snodes + } + + ready: List[Runnable] = [] + buffer_users: Dict[str, Set[BaseSchedulerNode]] = defaultdict(set) + snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes} + + for snode, deps in unmet_deps.items(): + if len(deps) == 0: + heapq.heappush(ready, Runnable(snode)) + for dep in deps: + buffer_users[dep].add(snode) + + scheduled = [] + + def schedule(snode): + """ + Schedules `snode` and put all unblocked nodes onto the ready queue. + """ + scheduled.append(snode) + for buf_name in snode.get_buffer_names(): + for snode in buffer_users[buf_name]: + unmet_deps[snode].remove(buf_name) + if len(unmet_deps[snode]) == 0: + heapq.heappush(ready, Runnable(snode)) + + def get_overlapping_candidate(): + """ + Return the next node in the ready queue that's neither a collective or + a wait. + """ + candidates = [ + x + for x in ready + if not contains_collective(x.snode) and not contains_wait(x.snode) + ] + if len(candidates) == 0: + return None + return min(candidates, key=lambda x: x.score) + + def schedule_collective_for_overlap(snode): + """ + Schedules collective node `snode`, along with one or more compute nodes + to overlap with it. The strategy is described in the comment of + `reorder_compute_for_overlap`. + """ + assert contains_collective(snode) + schedule(snode) + + collective_cost = snode_to_cost[snode] + while ( + collective_cost > 0 + and (candidate := get_overlapping_candidate()) is not None + ): + ready.remove(candidate) + schedule(candidate.snode) + collective_cost -= snode_to_cost[candidate.snode] + heapq.heapify(ready) + + while len(ready): + snode = heapq.heappop(ready).snode + if reorder_for_overlap and contains_collective(snode): + schedule_collective_for_overlap(snode) + else: + schedule(snode) + + for snode, deps in unmet_deps.items(): + assert len(deps) == 0, ( + "Detected unscheduled nodes. " + f"Nodes with unmet dependencies: {unmet_deps}" + ) + return scheduled + + +def decide_global_ordering_of_comms( + nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node +) -> List[BaseSchedulerNode]: + """ + Decide global ordering of comms, by just enforcing the ordering that's in the input graph + (might not be the same ordering as the eager mode program). + TODO: Come up with a better approach + """ + # If FSDP2 is used, we apply FSDP-specific passes. + if any( + is_fallback_op( + x.node, + { + torch.ops.fsdp.all_gather_copy_in.default, + torch.ops.fsdp.chunk_cat.default, + }, + ) + for x in nodes + ): + nodes = enforce_comm_ordering_for_fsdp(nodes, name_to_buf, name_to_fused_node) + + comm_nodes = [n for n in nodes if contains_collective(n)] + + for i in range(1, len(comm_nodes)): + # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm + mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) + for buf in comm_nodes[i - 1].get_buffer_names(): + comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf)) + + return nodes + + +def estimate_op_runtime(snode: BaseSchedulerNode) -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + if config.estimate_op_runtime == "default": + runtime = snode.get_estimated_runtime() + else: + assert callable(config.estimate_op_runtime) + runtime = config.estimate_op_runtime(snode) + return runtime + + +def node_summary(snode): + detail = "" + if isinstance(snode.node, ir.ExternKernelOut): + detail = f" ({snode.node.python_kernel_name})" + out_tensor_info = "" + if ( + hasattr(snode.node, "layout") + and hasattr(snode.node.layout, "size") + and hasattr(snode.node.layout, "stride") + ): + out_tensor_info = ( + f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})" + ) + node_name = "" + if hasattr(snode.node, "name"): + node_name = snode.node.name + return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})" + + +def visualize_overlap(order): + total_est_runtime: float = 0.0 + cur_comm_node = None + for snode in order: + if cur_comm_node is None: + if contains_collective(snode): + total_est_runtime += estimate_op_runtime(snode) + cur_comm_node = snode.node + elif is_wait(snode.node): + raise AssertionError( + "Wait is not expected when there is no collective running" + ) + else: # exposed compute op + total_est_runtime += estimate_op_runtime(snode) + overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 + else: # cur_comm_node is not None + if contains_collective(snode): + raise AssertionError( + "Found two collectives running at the same time. " + "`visualize_overlap` needs to be updated to handle this case" + ) + elif is_wait(snode.node): # end of this comm op + overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 + cur_comm_node = None + else: # overlapped compute op + overlap_log.debug(f"| {node_summary(snode)}") # noqa: G004 + overlap_log.debug( + f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004 + ) + + +def reorder_compute_and_comm_for_overlap( + snodes: List[BaseSchedulerNode], +) -> List[BaseSchedulerNode]: + order = snodes + + for p in config.reorder_for_compute_comm_overlap_passes: + if isinstance(p, str) and p in globals(): + p = globals()[p] # it is a builtin pass + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap before reordering pass {p} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug(str(e)) + order = p(order) # type: ignore[operator] + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap after reordering pass {p} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug(str(e)) + return order + + +def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None: + try: + import torch.distributed._composable.fsdp._fsdp_collectives + + assert torch.distributed.is_available() + # Assert existence of these ops + assert ( + torch.ops._c10d_functional.all_gather_into_tensor + and torch.ops._c10d_functional.all_gather_into_tensor_out + ) + except (ImportError, AttributeError, AssertionError): + return + + from .pattern_matcher import ( + CallFunction, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + ) + + """ + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); + getitem = all_gather_copy_in[0]; + (getitem_1 = all_gather_copy_in[1];) # optional + + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...); + + -> + + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); + getitem = all_gather_copy_in[0]; + getitem_1 = all_gather_copy_in[1]; + + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1); + """ + + def remove_unused_getitem(g): + # Remove `getitem_X = all_gather_copy_in[1]` which is never used. + node_list = list(g.nodes) + for n in node_list: + if ( + n.target == operator.getitem + and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default + and n.args[1] == 1 + ): + g.erase_node(n) + + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunction( + torch.ops._c10d_functional.all_gather_into_tensor.default, + CallFunction( + operator.getitem, + CallFunction( + torch.ops.fsdp.all_gather_copy_in.default, + KeywordArg("all_gather_inputs"), + KeywordArg("inp_split_sizes"), + KeywordArg("all_gather_input_numel"), + KeywordArg("world_size"), + KeywordArg("rank"), + KeywordArg("dtype"), + KeywordArg("device"), + ), + KeywordArg("item_idx"), + ), + KeywordArg("group_size"), + KeywordArg("group_name"), + ), + pass_dict=graph_pass, + extra_check=lambda match: match.kwargs["item_idx"] == 0, + ) + def reinplace_all_gather(match: Match, *args, **kwargs): + def repl( + *args, + ): + copy_in_args = args[:-2] + group_size = args[-2] + group_name = args[-1] + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default( + *copy_in_args + ) + getitem = all_gather_copy_in[0] + getitem_1 = all_gather_copy_in[1] + all_gather_into_tensor = ( + torch.ops._c10d_functional.all_gather_into_tensor_out.default( + getitem, group_size, group_name, out=getitem_1 + ) + ) + return all_gather_into_tensor + + match.replace_by_example( + repl, + [ + kwargs["all_gather_inputs"], + kwargs["inp_split_sizes"], + kwargs["all_gather_input_numel"], + kwargs["world_size"], + kwargs["rank"], + kwargs["dtype"], + kwargs["device"], + kwargs["group_size"], + kwargs["group_name"], + ], + ) + + remove_unused_getitem(graph) + graph_pass.apply(graph) # type: ignore[arg-type] + + +def get_op_idx(snode): + assert not isinstance( + snode, + ( + torch._inductor.scheduler.FusedSchedulerNode, + torch._inductor.scheduler.GroupedSchedulerNode, + ), + ) + return int(snode.get_name()[2:]) + + +def enforce_comm_ordering_for_fsdp( + snodes: List[torch._inductor.scheduler.BaseSchedulerNode], + name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer], + name_to_fused_node: Dict[str, BaseSchedulerNode], +) -> List[torch._inductor.scheduler.BaseSchedulerNode]: + from . import scheduler + + new_order: list[BaseSchedulerNode] = [] + scheduled = set() + ag_exists = False + rs_exists = False + ag_grouped_node_to_wait_grouped_node = {} + rs_grouped_node_to_wait_grouped_node = {} + snode_name_to_final_snode = {} + + def _create_group_node(snodes_to_group): + group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group) + for snode in snodes_to_group: + snode_name_to_final_snode[snode.get_name()] = group_node + snode_name_to_final_snode[group_node.get_name()] = group_node + return group_node + + # Create grouped nodes for specific sets of ops + for snode in snodes: + # Case 1: Handle AllGather + if is_collective( + snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default + ) and any( + is_fallback_op( + name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default + ) + for x in snode.ancestors + ): + ag_exists = True + ag_snode = snode + ag_related_snode_set: set[scheduler.BaseSchedulerNode] = set() + + # Find the "cast + copy_in + getitem + all_gather" code block + find_recursive_deps_of_node( + ag_snode, + ag_related_snode_set, + name_to_buf, + name_to_fused_node, + ) + + # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block + allowed_ops = { + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + torch.ops._c10d_functional.wait_tensor.default, + torch.ops.fsdp.split_with_sizes_copy.default, + torch.ops.aten.set_.source_Tensor, + } + find_recursive_users_of_node( + ag_snode, + ag_related_snode_set, + name_to_buf, + name_to_fused_node, + criteria_cb=lambda x: not ( + isinstance(x, scheduler.NopKernelSchedulerNode) + or ( + isinstance(x, scheduler.ExternKernelSchedulerNode) + and x.node.op_overload in allowed_ops # type: ignore[union-attr] + ) + ), + ) + + # sort nodes by original operation order + ag_related_snodes = sorted( + ag_related_snode_set, key=lambda x: get_op_idx(x) + ) + + # In the "reuse layer" case, some ops in the 2nd all-gather code block could also + # depend on ops in the 1st all-gather code block, and we don't want to group them together. + end_idx_of_current_ag_block = len(ag_related_snodes) + copy_out_count = 0 + for i in range(len(ag_related_snodes)): + cur_snode = ag_related_snodes[i] + if is_fallback_op( + cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default + ): + copy_out_count += 1 + if copy_out_count > 1: + end_idx_of_current_ag_block = i + break + + ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block] + + # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode + wait_node_idx = None + for i in range(len(ag_related_snodes) - 1): + if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel): + wait_node_idx = i + 1 + break + assert wait_node_idx is not None + ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx]) + + # Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode + ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:]) + + ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node + + # Case 2: Handle ReduceScatter + elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default): + rs_exists = True + rs_snode = snode + + # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block + rs_related_snode_set: set[scheduler.BaseSchedulerNode] = set() + find_recursive_users_of_node( + rs_snode, + rs_related_snode_set, + name_to_buf, + name_to_fused_node, + ) + + # sort nodes by original operation order + rs_related_snodes = sorted( + rs_related_snode_set, key=lambda x: get_op_idx(x) + ) + + # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode + wait_node_idx = None + for i in range(len(rs_related_snodes) - 1): + if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel): + wait_node_idx = i + 1 + break + assert wait_node_idx is not None + rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx]) + + # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode + rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:]) + + rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node + + assert len(snode_name_to_final_snode) > 0 + if ag_exists: + assert len(ag_grouped_node_to_wait_grouped_node) > 0 + if rs_exists: + assert len(rs_grouped_node_to_wait_grouped_node) > 0 + + # Build the new node schedule, taking GroupedSchedulerNode into account + for snode in snodes: + if snode.get_name() in snode_name_to_final_snode: + snode = snode_name_to_final_snode[snode.get_name()] + if snode in scheduled: + continue + new_order.append(snode) + scheduled.add(snode) + + # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run + # before next AllGather's "copy_in then AG" group node + prev_ag_wait = None + for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items(): + if prev_ag_wait is not None: + mutating_buf = next(iter(ag_group_node.get_buffer_names())) + for o in prev_ag_wait.get_outputs(): + ag_group_node.add_fake_dep( + WeakDep(o.get_name(), mutating_buf=mutating_buf) + ) + prev_ag_wait = wait_group_node + + # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run + # before next ReduceScatter's "copy_in then RS" group node + prev_rs_wait = None + for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items(): + if prev_rs_wait is not None: + mutating_buf = next(iter(rs_group_node.get_buffer_names())) + for o in prev_rs_wait.get_outputs(): + rs_group_node.add_fake_dep( + WeakDep(o.get_name(), mutating_buf=mutating_buf) + ) + prev_rs_wait = wait_group_node + + return new_order # type: ignore[return-value] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..db4c1e0eddfeceb3b65a7551c9e3615eabebc257 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py @@ -0,0 +1,1629 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import contextlib +import functools +import io +import itertools +import logging +import os +import sys +import time +import warnings +from itertools import count +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from unittest import mock + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch.fx +import torch.utils._pytree as pytree +from functorch.compile import min_cut_rematerialization_partition +from torch._dynamo import ( + compiled_autograd, + config as dynamo_config, + logging as dynamo_logging, + utils as dynamo_utils, +) +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.repro.after_aot import wrap_compiler_debug +from torch._dynamo.utils import ( + counters, + detect_fake_mode, + flatten_graph_inputs, + lazy_format_graph_code, +) +from torch._functorch import config as functorch_config +from torch._functorch.aot_autograd import aot_export_module, make_boxed_func +from torch._inductor.codecache import ( + _StrideExprStr, + code_hash, + CompiledFxGraph, + FxGraphCache, +) +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + CudagraphCachedInfo, + get_placeholder_info, + log_cudagraph_skip_and_bump_counter, + PlaceholderInfo, +) +from torch._inductor.debug import save_args_for_compile_fx_inner +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import ( + BoxedBool, + count_tangents, + fresh_inductor_cache, + InputType, + is_gpu, + should_assume_input_aligned, + tensor_is_aligned, +) +from torch._logging import trace_structured +from torch._ops import OpOverload +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.monitor import _WaitCounter + +from .._dynamo.backends.common import aot_autograd +from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined] +from ..fx.graph import _PyTreeCodeGen +from . import config, metrics +from .debug import DebugContext +from .decomposition import select_decomp_table +from .fx_passes.joint_graph import joint_graph_passes +from .fx_passes.post_grad import post_grad_passes, view_to_reshape +from .fx_passes.pre_grad import pre_grad_passes +from .graph import GraphLowering +from .ir import ExternKernelNode +from .utils import ( + align_inputs_from_check_idxs, + clone_preserve_strides, + copy_misaligned_inputs, + get_cloned_parameter_buffer_name, + has_incompatible_cudagraph_ops, + maybe_get_suppress_shape_guards_ctx, + output_node, + remove_unaligned_input_idxs, + shape_env_from_inputs, +) +from .virtualized import V + + +if config.is_fbcode(): + from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log +else: + # no-op decorator + def time_and_log(attr: str): + return dynamo_utils.identity + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs") +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +# copy_ fails when trying to write to tensors with memory overlap, +# for expanded dimensions (a dimension which used to have size 1 -> ?) +# we can select one element from that dimension and write to it +# to achieve writing to all values of that dimension of the input tensor +def get_expanded_dims(t): + if not isinstance(t, torch.Tensor): + return None + return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1] + + +def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor: + for expanded_dim in expanded_dims: + t = torch.ops.aten.slice(t, expanded_dim, 0, 1) + return t + + +def complex_memory_overlap(t: torch.Tensor) -> bool: + # if torch._debug_has_internal_overlap thinks this tensor potentially has + # memory overlap internally, let's dig deeper to find out whether it's true. + # + # Call squeeze() so that dimension with size 1 does not cause false positive. + t = index_expanded_dims(t, get_expanded_dims(t)).squeeze() + if torch._debug_has_internal_overlap(t) != 0: + strides = t.stride() + sizes = t.shape + indices = list(range(len(strides))) + indices = [x for _, x in sorted(zip(strides, indices))] + for i in range(len(strides)): + prev_stride = 1 if i == 0 else strides[indices[i - 1]] + prev_size = 1 if i == 0 else sizes[indices[i - 1]] + if strides[indices[i]] < prev_stride * prev_size: + return True + return False + + +def get_static_input_idxs(num_fixed): + # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes + # of cudagraphs. Rather than copying these into cudagraph-owned memory + # like we do for normal inputs on each run, we will re-record a cudagraph if these + # parameter locations change. + context = torch._guards.TracingContext.try_get() + fixed = list(range(num_fixed)) + if not context or not context.fw_metadata: + return fixed + + return fixed + context.fw_metadata.static_input_indices + + +@functools.lru_cache(None) +def _step_logger(): + return dynamo_logging.get_step_logger(log) + + +@functools.lru_cache(None) +def _warn_tf32_disabled(): + if ( + torch.cuda.is_available() + and not torch.backends.cuda.matmul.allow_tf32 + and torch.cuda.get_device_capability() >= (8, 0) + ): + warnings.warn( + "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. " + "Consider setting `torch.set_float32_matmul_precision('high')` for better performance." + ) + + +def _unlift_graph(mod, gm, graph_signature): + from torch.export.unflatten import _assign_attr, _AttrKind + + state_dict = {} + for name, param in mod.named_parameters(remove_duplicate=False): + state_dict[name] = param + _assign_attr( + param, + gm, + name, + attr_kind=_AttrKind.PARAMETER, + ) + for name, buffer in mod.named_buffers(remove_duplicate=False): + state_dict[name] = buffer + _assign_attr( + buffer, + gm, + name, + attr_kind=_AttrKind.BUFFER, + ) + + placeholder_nodes = gm.graph.find_nodes(op="placeholder") + lifted_inputs = [] + + # In AOTI, module parameters and buffers are not lifted as graph inputs. + # As a result, mutation to buffers has side effect which makes their initial + # values different from Eager. So we clone them here as a copy. + # We are not cloning for parameters, although it will be needed if we want to + # support training. + for node in placeholder_nodes: + node_name = node.name + if node_name in graph_signature.inputs_to_parameters: + parameter_name = graph_signature.inputs_to_parameters[node_name] + lifted_inputs.append(parameter_name) + elif node_name in graph_signature.inputs_to_buffers: + buffer_name = graph_signature.inputs_to_buffers[node_name] + lifted_inputs.append(buffer_name) + gm.meta[ + get_cloned_parameter_buffer_name(buffer_name) + ] = clone_preserve_strides(state_dict[buffer_name]) + else: + assert node_name in graph_signature.user_inputs + lifted_inputs.append(None) + + from torch.export._unlift import _unlift + + outputs = list(gm.graph.nodes)[-1].args[0] + mutated_outputs = [] + buffer_mutations = graph_signature.buffers_to_mutate + user_input_mutations = graph_signature.user_inputs_to_mutate + output_tokens = graph_signature.output_tokens + for idx, out in enumerate(outputs): + value = None + + if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): + if out.name in buffer_mutations: + value = buffer_mutations[out.name] + elif out.name in user_input_mutations: + value = user_input_mutations[out.name] + + mutated_outputs.append(value) + + unlifted_gm = _unlift( + gm, + lifted_inputs, + mutated_outputs, + pytree.LeafSpec(), + None, + state_dict, + {}, + ) + return unlifted_gm + + +def _get_subgraph_names(gm): + for node in sorted( + itertools.chain( + gm.graph.find_nodes(op="call_function", target=torch.ops.higher_order.cond), + gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.while_loop + ), + ) + ): + if node.target == torch.ops.higher_order.cond: + true_subgraph_name = node.args[1].name + false_subgraph_name = node.args[2].name + yield true_subgraph_name + yield false_subgraph_name + elif node.target == torch.ops.higher_order.while_loop: + cond_subgraph_name = node.args[0].name + body_subgraph_name = node.args[1].name + yield cond_subgraph_name + yield body_subgraph_name + + +def _recursive_pre_grad_passes(gm, example_inputs): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + # as we don't have recursive example inputs, passing None here + new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None) + setattr(gm, subgraph_name, new_subgraph) + return pre_grad_passes(gm, example_inputs) + + +def _recursive_joint_graph_passes(gm): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + _recursive_joint_graph_passes(subgraph) + joint_graph_passes(gm) + + +def _recursive_post_grad_passes(gm, is_inference: bool = False): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + _recursive_post_grad_passes(subgraph, is_inference) + post_grad_passes(gm, is_inference) + + +def split_const_gm( + gm: torch.fx.GraphModule, + lifted_constants: Optional[Dict[str, Any]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> Tuple[torch.fx.GraphModule, Dict[str, int]]: + """ + This function takes an GraphModule input "gm". + The gm will be split into 2 components, + 1) const_gm, which consists the subgraph of gm that can be constant folded. + 2) gm (being inplace modified,) which returns the graph after constant folding. + + If an additional "lifted_constants" argument is passed in, we will assume the gm has + been lifted and run the transformation accordingly. + + When a "skip_folding_node_fn" callback is passed, we will skip constant folding on + the nodes for which the callback returns True. + + const_output_index is a mapping of corresponding node name from gm to the + output index of const_gm. + Returns (const_gm, const_output_index) + """ + from torch._inductor.constant_folding import ( + CONST_MODULE_TAG, + META_TAG, + MODULE_TAG, + replace_node_with_constant, + run_and_get_constant_graph, + ) + + const_gm, const_result = run_and_get_constant_graph( + gm, lifted_constants, skip_folding_node_fn + ) + + const_outputs = { + x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0]) + } + + to_erase_node = [] + to_replace_node = [] + const_output_index = {} + for node in gm.graph.nodes: + if node.name in const_outputs: + to_replace_node.append(node) + elif node.meta[META_TAG] == CONST_MODULE_TAG and node.op != "placeholder": + to_erase_node.append(node) + + for node in to_replace_node: + new_const_name = "_FOLDED_CONST_" + node.name + replace_node_with_constant( + gm, + node, + const_result[const_outputs[node.name]], + new_const_name, + ) + const_output_index[new_const_name] = const_outputs[node.name] + for node in to_erase_node[::-1]: + if node.users: + for n in node.users: + assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty." + else: + gm.graph.erase_node(node) + gm.recompile() + + return const_gm, const_output_index + + +def is_tf32_warning_applicable(gm: torch.fx.GraphModule): + aten = torch.ops.aten + tf32_ops = { + aten.mm.default, + aten.addmm.default, + aten.bmm.default, + aten.baddbmm.default, + } + for target in tf32_ops: + for node in gm.graph.find_nodes(op="call_function", target=target): + if ( + isinstance(node.meta.get("val", None), torch.Tensor) + and node.meta["val"].dtype == torch.float32 + and node.meta["val"].device.type == "cuda" + ): + return True + return False + + +def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]): + """ + For CPU backend, enable comprehensive padding causes some unit tests + fail due to changing number of generated kernels. Skip for now. + """ + has_gpu = any( + is_gpu(t.device.type) for t in example_inputs if isinstance(t, torch.Tensor) + ) + + if config.disable_padding_cpu and config.comprehensive_padding and not has_gpu: + perf_hint_log.info("Skip comprehensive padding on CPU") + return config.patch(comprehensive_padding=False) + else: + return contextlib.nullcontext() + + +def fake_tensor_prop( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + force_allow_non_fake_inputs: bool = False, +): + """ + If we can not detect fake mode from the context of inputs, create one. + + The created fake mode will be returned. + """ + fake_mode = detect_fake_mode(example_inputs) + if not fake_mode: + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) + else: + ctx = ( + contextlib.nullcontext() + if not force_allow_non_fake_inputs + else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) + ) + with ctx: # type: ignore[attr-defined] + FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( + *example_inputs + ) + + return fake_mode + + +def should_use_remote_fx_graph_cache(): + if config.fx_graph_remote_cache is not None: + return config.fx_graph_remote_cache + if not config.is_fbcode(): + return False + + if torch._utils_internal.is_fb_unit_test(): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + jk_name = "pytorch/remote_cache:fx_graph_memcache_version" + if torch.version.hip is not None: + jk_name = "pytorch/remote_cache:fx_graph_memcache_version_amd" + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(jk_name) + + +# pass config dict back to user +def get_patched_config_dict(config_patches=None) -> Dict[str, Any]: + with config.patch(config_patches): + return config.get_config_copy() + + +@contextlib.contextmanager +def with_fresh_cache_if_config(): + if config.force_disable_caches: + # Don't delete the cache dir because it has to survive beyond the + # compile_fx call. Let's put the temp dirs under the default cache + # dir so they're easier to locate. + with fresh_inductor_cache(dir=cache_dir(), delete=False): + yield + else: + yield + + +def compile_fx_inner(*args, **kwargs): + # Need with_fresh_cache_if_config for compile_fx_inner even if we already have one for + # compile_fx. The reason is the compilation for backward graph may happen after + # compile_fx return and we may want to use the _LazyGraphModule for compiling + # the backward graph as well. + with contextlib.ExitStack() as stack: + stack.enter_context(torch.utils._python_dispatch._disable_current_modes()) + stack.enter_context(_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)) + stack.enter_context( + dynamo_utils.dynamo_timed( + "compile_fx_inner", phase_name="inductor_compile", fwd_only=False + ) + ) + stack.enter_context(with_fresh_cache_if_config()) + stack.enter_context(DebugContext()) + + return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( + *args, **kwargs + ) + + +@time_and_log(attr="compilation time (in seconds)") +def _compile_fx_inner( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + cudagraphs: Optional[BoxedBool] = None, + static_input_idxs: Optional[List[int]] = None, + is_backward: bool = False, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + is_inference: bool = False, + boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, + user_visible_outputs: Optional[Dict[str, None]] = None, + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None, +) -> Union[CompiledFxGraph, str]: + """ + Inductor API that compiles a single graph. + + If you change the argument list for this function, make sure you + also update the call to save_args_for_compile_fx_inner below accordingly. + """ + if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: + # trigger the real recompilation for _LazyGraphModule before returning + # the forward method. + from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(gm) + return make_boxed_func(gm.forward) + + if static_input_idxs is None: + static_input_idxs = [] + + static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs) + + assert isinstance( + next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list) + ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" + + if config.save_args: + save_args_for_compile_fx_inner( + gm, + example_inputs, + cudagraphs=cudagraphs, + static_input_idxs=static_input_idxs, + is_backward=is_backward, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + is_inference=is_inference, + boxed_forward_device_index=boxed_forward_device_index, + user_visible_outputs=user_visible_outputs, + layout_opt=layout_opt, + ) + + if cudagraphs is None: + cudagraphs = BoxedBool(config.triton.cudagraphs) + + # Inputs to fx_codegen_and_compile + # Anything that affects codegen should go here, so if the signature + # of fx_codegen_and_compile changes, the dict should be updated accordingly + graph_kwargs = { + "cudagraphs": cudagraphs, + "static_input_idxs": static_input_idxs, + "is_backward": is_backward, + "graph_id": graph_id, + "cpp_wrapper": cpp_wrapper, + "aot_mode": aot_mode, + "is_inference": is_inference, + "user_visible_outputs": user_visible_outputs, + "layout_opt": layout_opt, + "extern_node_serializer": extern_node_serializer, + } + + start = time.time() + + fx_graph_remote_cache = should_use_remote_fx_graph_cache() + + inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) # type: ignore[arg-type] + + def codegen_and_compile( + gm, + example_inputs, + inputs_to_check, + fx_kwargs, + ): + """ + This function calls fx_codegen_and_compile and also adds some extra metadata to the resulting + compiled fx graph. The metadata is saved to FXGraphCache. + """ + compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs) + if isinstance(compiled_graph, str): + # We only return a string in aot mode, in which case we don't + # need to do any post-compilation steps: we just return the string, + # which is the filename of the compiled code. + return compiled_graph + cudagraph_info = None + if cudagraphs: + # check cudagraph disabling reasons from inductor lowering + if compiled_graph.disabled_cudagraphs_reason: + if "cuda" in compiled_graph.device_types: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}" + ) + else: + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + else: + complex_memory_overlap_inputs = any( + complex_memory_overlap(t) + for t in example_inputs + if isinstance(t, torch.Tensor) + ) + + if not config.triton.cudagraph_support_input_mutation: + # Skip supports for cudagraph-managed tensors + from torch._inductor.cudagraph_utils import ( + check_for_mutation_ignore_cuda_graph_managed_tensor, + ) + + has_mutation_str = ( + check_for_mutation_ignore_cuda_graph_managed_tensor( + gm, + compiled_graph, + static_input_idxs, # type:ignore[arg-type] + ) + ) + has_mutation = has_mutation_str is not None + + if has_mutation: + compiled_graph.disabled_cudagraphs_reason = has_mutation_str + else: + # Check mutation later to support cudagraph-managed tensors + has_mutation = None + + cudagraph_tests = [ + (not has_mutation, "mutated inputs"), + (not has_incompatible_cudagraph_ops(gm), "incompatible ops"), + (not complex_memory_overlap_inputs, "complex memory overlap"), + ( + all( + isinstance(t, (torch.Tensor, torch.SymInt)) + for t in example_inputs + ), + "non-Tensor inputs", + ), + ] + output = output_node(gm) + # output args are tuple of first argument + assert len(output.args) == 1 + stack_traces = [ + (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) + for arg in output.args[0] + ] + cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b] + placeholders = tuple(get_placeholder_info(gm.graph)) + cudagraph_info = CudagraphCachedInfo( + placeholders, stack_traces, cudagraph_fail_reasons + ) + + compiled_graph.cudagraph_info = cudagraph_info + compiled_graph.inputs_to_check = inputs_to_check + compiled_graph.fx_kwargs = fx_kwargs + # TODO: should this be part of fx_kwargs + compiled_graph.boxed_forward_device_index = boxed_forward_device_index + return compiled_graph + + with _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _: + if ( + not config.force_disable_caches + and (config.fx_graph_cache or fx_graph_remote_cache) + and not aot_mode + ): + for i, input in enumerate(example_inputs): + if ( + isinstance(input, torch.Tensor) + and input.device.type == "cuda" + and i in static_input_idxs + ): + input._is_inductor_static = True # type: ignore[attr-defined] + + compiled_graph = FxGraphCache.load( + codegen_and_compile, + gm, + example_inputs, + graph_kwargs, + inputs_to_check, + local=config.fx_graph_cache, + remote=fx_graph_remote_cache, + ) + else: + compiled_graph = codegen_and_compile( + gm, example_inputs, inputs_to_check, graph_kwargs # type: ignore[arg-type] + ) + if aot_mode: + # AOT mode is special because codegen_and_compile returns a string. + # In that case, we don't need to run all post compilation steps, we just need + # to return the string directly. + return compiled_graph + compiled_graph = FxGraphCache.post_compile( + compiled_graph, example_inputs, cudagraphs + ) + + log.debug("FX codegen and compilation took %.3fs", time.time() - start) + + _step_logger()( + logging.INFO, + "torchinductor done compiling " + f"{'BACKWARDS' if is_backward else 'FORWARDS'} " + f"graph {graph_id}", + ) + # aot autograd needs to know to pass in inputs as a list + compiled_graph._boxed_call = True + return compiled_graph + + +def fx_codegen_and_compile( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + cudagraphs: Optional[BoxedBool] = None, + static_input_idxs: Optional[List[int]] = None, + is_backward: bool = False, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + is_inference: bool = False, + # Use a dict with None value rather than a set for deterministic + # iteration order just in case. + user_visible_outputs: Optional[Dict[str, None]] = None, + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None, +) -> Union[CompiledFxGraph, str]: + if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None: + import time + + log.warning("Sleeping for %s since sleep_sec_TESTING_ONLY is set", sleep_sec) + time.sleep(sleep_sec) + + with dynamo_utils.preserve_rng_state(): + if is_tf32_warning_applicable(gm): + _warn_tf32_disabled() + + inductor_counters = counters["inductor"].copy() + + # lift the maximum depth of the Python interpreter stack + # to adapt large/deep models + sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000)) + + _step_logger()( + logging.INFO, + "torchinductor compiling " + f"{'BACKWARDS' if is_backward else 'FORWARDS'} " + f"graph {graph_id}", + ) + + def log_graph_runnable(): + fd = io.StringIO() + torch._dynamo.repro.after_aot.save_graph_repro( + fd, gm, example_inputs, "inductor", save_dir=None + ) + return fd.getvalue() + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_runnable", + "encoding": "string", + }, + payload_fn=lambda: log_graph_runnable(), + ) + + V.debug.fx_graph(gm, example_inputs) + # TODO: Should we actually dump this? It should be redundant with the aot + # structured logs... + # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False)) + + shape_env = shape_env_from_inputs(example_inputs) + + # Convert view to reshape in the graph. This is necessary primarily for + # layout optimization. Do it unconditionally for uniformity. + # + # It's needed because when we do layout optimization, an contiguous tensor + # in eager mode may becomes a channels last tensor. A view op previously + # can be applied to the contiguous tensor may not be able to be applied + # on the channels tensor any more. An error like + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + # will be printed. + # + # Replace view op to reshape op in this case. + # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this. + # + # Also this has to be done before FakeTensorProp below to avoid the failed + # .view() call. + view_to_reshape(gm) + + # It is safe to run FakeTensorProp under no_grad because by the time + # we're in inductor, we assume that AOTAutograd has already "taken care" + # of autograd, so there should be no more autograd-related API's in the + # graph. + with torch.no_grad(): + fake_mode = fake_tensor_prop(gm, example_inputs) + + # pattern matcher passes might not preserve striding information + # on node.meta["val"]. if in the future we rely on these being + # correct we will need to fix. + + with V.set_fake_mode(fake_mode): + # has some issues with memory in training + _recursive_post_grad_passes(gm, is_inference=is_inference) + V.debug.fx_graph_transformed(gm, example_inputs) + post_grad_graphs_log.debug( + "%s", + lazy_format_graph_code( + "AFTER POST GRAD", + gm, + include_stride=True, + include_device=True, + colored=True, + ), + ) + trace_structured( + "inductor_post_grad_graph", + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + if config.is_fbcode(): + log_optimus_to_scuba( + extra_logging={"pt2_configs": str(get_patched_config_dict())} + ) + + with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding( + example_inputs + ): + const_output_index = None + const_graph = None + const_code = None + + if aot_mode and config.aot_inductor.use_runtime_constant_folding: + const_gm, const_output_index = split_const_gm(gm) + + const_graph = GraphLowering( + const_gm, + example_inputs=[], + shape_env=shape_env, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + user_visible_outputs=user_visible_outputs, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + is_const_graph=True, + ) + with V.set_graph_handler(const_graph): + assert cpp_wrapper, "AOT mode only supports C++ wrapper" + const_graph.run() + + const_code, _ = const_graph.codegen_with_cpp_wrapper() + + graph = GraphLowering( + gm, + # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning. + # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass, + # we currently use fake tensors and defake them later. + example_inputs=example_inputs, + shape_env=shape_env, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + user_visible_outputs=user_visible_outputs, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + const_output_index=const_output_index, + const_code=const_code, + const_module=const_graph, + ) + metrics_helper = metrics.CachedMetricsHelper() + with V.set_graph_handler(graph): + graph.run(*example_inputs) + output_strides: List[Optional[Tuple[_StrideExprStr, ...]]] = [] + if graph.graph_outputs is not None: + # We'll put the output strides in the compiled graph so we + # can later return them to the caller via TracingContext + p = SymExprPrinter() + for out in graph.graph_outputs: + if ( + hasattr(out, "layout") + and len(free_unbacked_symbols(out.layout.stride)) == 0 + ): + # Convert to string for eval on the load path + output_strides.append( + tuple(p.doprint(s) for s in out.layout.stride) + ) + else: + output_strides.append(None) + + _check_triton_bf16_support(graph) + compiled_fn = graph.compile_to_fn() + num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() + metrics.num_bytes_accessed += num_bytes + metrics.node_runtimes += node_runtimes + metrics.nodes_num_elem += nodes_num_elem + + if ( + cudagraphs + and config.triton.cudagraph_skip_dynamic_graphs + and not V.graph.disable_cudagraphs_reason + and torch._inductor.utils.any_is_symbolic(*example_inputs) + ): + stack_trace = None + for node in gm.graph.nodes: + meta_val = node.meta.get("val", None) + if ( + node.op == "placeholder" + or not isinstance(meta_val, torch.Tensor) + or not torch._inductor.utils.any_is_symbolic(meta_val) + ): + continue + + if stack_trace := node.meta.get("stack_trace", None): + break + disable = "graph with symbolic shapes inputs and config.triton.cudagraph_skip_dynamic_graphs=True." + if stack_trace: + disable = f"{disable} Found from {stack_trace}\n" + else: + disable = f"{disable}\n" + V.graph.disable_cudagraphs_reason = disable + + if V.aot_compilation is True: + return compiled_fn + + if cudagraphs and not V.graph.disable_cudagraphs_reason: + from torch._inductor.cudagraph_utils import ( + check_lowering_disable_cudagraph, + ) + + V.graph.disable_cudagraphs_reason = ( + check_lowering_disable_cudagraph(V.graph.device_node_mapping) + ) + + compiled_graph = CompiledFxGraph( + compiled_fn, + graph, + output_strides, + V.graph.disable_cudagraphs_reason, + metrics_helper.get_deltas(), + counters["inductor"] - inductor_counters, + ) + + return compiled_graph + + +def get_input_idxs_to_check( + inputs: List[InputType], + static_input_idxs: Sequence[int], +) -> Sequence[int]: + """ + This function runs at compile time, and generates a list of indices for which we + might need to do a copy to preserve alignment requirements. + """ + ids_to_check = [] + + for i, input in enumerate(inputs): + if not isinstance(input, torch.Tensor): + # non-tensors don't need alignment + continue + if not is_gpu(input.device.type): + # right now we only care for gpu tensors + continue + with maybe_get_suppress_shape_guards_ctx(): + # suppress guards so that tensor_is_aligned and should_assume_input_aligned + # do not add guards on input's storage offset + if i in static_input_idxs and tensor_is_aligned(input): + continue + if not should_assume_input_aligned(input): + continue + + # if we get here, then + # (a) our triton code assumes that the input is aligned + # (b) we can't be sure ahead of time that the input will actually be aligned. + # therefore, at runtime, we'll need to check that the input is aligned + # (and if not, clone it to make it aligned.) + ids_to_check.append(i) + + return ids_to_check + + +def cudagraphify( + model: Callable[..., Any], + static_input_idxs: Sequence[int] = (), + *, + device_index: int, + stack_traces: List[Optional[str]], + is_backward: bool, + is_inference: bool, + constants: Tuple[torch.Tensor, ...] = (), + placeholders: Sequence[PlaceholderInfo] = (), + mutated_input_idxs: Tuple[int, ...] = (), +) -> Callable[..., Any]: + from torch._inductor.cudagraph_trees import ( + cudagraphify_impl as new_cudagraphify_impl, + ) + + cudagraphify_fn: Callable[..., Any] + if config.triton.cudagraph_trees: + cudagraphify_fn = functools.partial( + new_cudagraphify_impl, + device_index=device_index, + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=constants, + placeholders=placeholders, + mutated_input_idxs=mutated_input_idxs, + ) + else: + cudagraphify_fn = cudagraphify_impl + + compiled_fn = None + + def run(new_inputs): + nonlocal compiled_fn + if compiled_fn is None: + with dynamo_utils.dynamo_timed( + "cudagraphify" + ), dynamo_utils.preserve_rng_state(): + compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) + return compiled_fn(new_inputs) + + return run + + +def static_input(x: torch.Tensor) -> torch.Tensor: + """ + Copy and input while preserving strides + """ + return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device) + + +def index_expanded_dims_and_copy_( + dst: torch.Tensor, + src: torch.Tensor, + expanded_dims: List[int], +): + "Index into expanded dimensions of both dst and src then copy_" + dst = index_expanded_dims(dst, expanded_dims) + src = index_expanded_dims(src, expanded_dims) + dst.copy_(src) + + +def cudagraphify_impl( + model: Callable[..., Any], + inputs: List[torch.Tensor], + static_input_idxs: Sequence[int] = (), +): + """ + Assumes inputs[static_input_idxs[i]] are always the same memory address + """ + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] + static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] + copy_misaligned_inputs(inputs, check_input_idxs) # type: ignore[arg-type] + + assert isinstance(inputs, list) + + inps_expanded_dims = [ + get_expanded_dims(x) if idx not in static_input_idxs else [] + for idx, x in enumerate(inputs) + ] + + # allocate static tensor inputs + static_inputs = [ + x + if not isinstance(x, torch.Tensor) + else static_input(x) + if idx not in static_input_idxs + else x.detach() + for idx, x in enumerate(inputs) + ] + + # copy over input values for fresh allocations + for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)): + if isinstance(x, torch.Tensor) and idx not in static_input_idxs: + index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims) + + # warmup + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + # copy static_inputs because it will be cleared in model + with torch.cuda.stream(stream): + model(list(static_inputs)) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + # record + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"): + static_outputs = model(list(static_inputs)) + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + if config.size_asserts: + + def run(new_inputs): + assert len(static_inputs) == len(new_inputs) + for idx, (dst, src, expanded_dims) in enumerate( + zip(static_inputs, new_inputs, inps_expanded_dims) + ): + if not isinstance(dst, torch.Tensor): + pass + elif idx in static_input_idxs: + assert dst.data_ptr() == src.data_ptr() + else: + # TODO - could make one single op of multiple slices + # and avoid dispatch. + # Could also pre-index the `dst` tensors + index_expanded_dims_and_copy_(dst, src, expanded_dims) + new_inputs.clear() + graph.replay() + return static_outputs + + else: + copy_indices = [ + idx for idx in range(len(static_inputs)) if idx not in static_input_idxs + ] + + def run(new_inputs): + for idx in copy_indices: + expanded_dims = inps_expanded_dims[idx] + index_expanded_dims_and_copy_( + static_inputs[idx], new_inputs[idx], expanded_dims + ) + new_inputs.clear() + graph.replay() + return static_outputs + + return align_inputs_from_check_idxs(run, check_input_idxs) + + +def compile_fx_aot( + model_: torch.fx.GraphModule, + example_inputs_: List[torch.Tensor], + inner_compile: Callable[..., Any] = compile_fx_inner, + config_patches: Optional[Dict[str, Any]] = None, +): + config_patches: Dict[str, Any] = ( + {"cpp_wrapper": True} + if config_patches is None + else {**config_patches, "cpp_wrapper": True} + ) + if ( + "aot_inductor.output_path" not in config_patches + and not config.aot_inductor.output_path + ): + config_patches = { + **config_patches, + "aot_inductor.output_path": code_hash(model_.code), + } + + extern_node_serializer = config_patches.pop("extern_node_serializer", None) + with V.set_aot_compilation(True): + compiled_lib_path = compile_fx( + model_, + example_inputs_, + inner_compile=functools.partial( + inner_compile, + aot_mode=True, + extern_node_serializer=extern_node_serializer, + ), + config_patches=config_patches, + ) + assert os.path.exists( + compiled_lib_path + ), f"AOTInductor compiled library does not exist at {compiled_lib_path}" + return compiled_lib_path + + +_graph_counter = count(0) + + +def fw_compiler_freezing( + aot_autograd_model: torch.fx.GraphModule, + aot_example_inputs: List[torch.Tensor], + dynamo_model: torch.fx.GraphModule, + num_example_inputs: int, + inner_compile: Callable[..., Any], + cudagraphs: BoxedBool, + graph_id: int, + forward_device: BoxedDeviceIndex, +): + from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze + + # partition_fn won't be called + _recursive_joint_graph_passes(aot_autograd_model) + + layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True) + if layout_opt: + # make sure meta['val'] is properly setup + fake_tensor_prop(aot_autograd_model, aot_example_inputs, True) + convert_conv_weights_to_channels_last(aot_autograd_model) + + opt_model, preserved_arg_indices = freeze( + dynamo_model, + aot_autograd_model, + aot_example_inputs, # type: ignore[arg-type] + ) + + aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices] + num_fixed = len(preserved_arg_indices) - num_example_inputs + + fake_mode = detect_fake_mode(aot_example_inputs) + + # for freezing, all graph outputs should be user visible + *_, model_outputs_node = opt_model.graph.nodes + model_outputs = model_outputs_node.args[0] + user_visible_outputs = dict.fromkeys( + n.name for n in model_outputs if isinstance(n, torch.fx.Node) + ) + + static_input_idxs = list(range(num_fixed)) + # constant params will be real tensors, not fake + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context is not None: + params_flat = tracing_context.params_flat + assert params_flat is not None + for i in range(len(params_flat)): + if i not in preserved_arg_indices: + params_flat[i] = None + + if tracing_context.fw_metadata: + static_input_idxs += tracing_context.fw_metadata.static_input_indices + + with mock.patch.object(fake_mode, "allow_non_fake_inputs", True): + optimized_function = inner_compile( + opt_model, + aot_example_inputs, + static_input_idxs=static_input_idxs, + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=True, + boxed_forward_device_index=forward_device, + layout_opt=layout_opt, + user_visible_outputs=user_visible_outputs, + ) + + # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper + # that drops constant-ified params + if V.aot_compilation is True: + return optimized_function + + def wrapper(args): + args_new = [args[i] for i in preserved_arg_indices] + args.clear() + return optimized_function(args_new) + + wrapper._boxed_call = True # type: ignore[attr-defined] + + return wrapper + + +def compile_fx( + model_: torch.fx.GraphModule, + example_inputs_: List[torch.Tensor], + inner_compile: Callable[..., Any] = compile_fx_inner, + config_patches: Optional[Dict[str, Any]] = None, + decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None, +): + with _use_lazy_graph_module(dynamo_config.use_lazy_graph_module): + """Main entrypoint to a compile given FX graph""" + if config_patches: + with config.patch(config_patches): + return compile_fx( + model_, + example_inputs_, + # need extra layer of patching as backwards is compiled out of scope + inner_compile=config.patch(config_patches)(inner_compile), + decompositions=decompositions, + ) + + if config.cpp_wrapper: + with config.patch( + { + "cpp_wrapper": False, + # For triton.autotune_at_compile_time, disable by default for + # FBCode, but enabled by default for OSS. + "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time + if config.is_fbcode() + else os.environ.get( + "TORCHINDUCTOR_TRITON_AUTOTUNE_AT_COMPILE_TIME", "1" + ) + == "1", + "triton.autotune_cublasLt": False, + "triton.cudagraphs": False, + "triton.store_cubin": True, + } + ), V.set_real_inputs(example_inputs_): + inputs_ = example_inputs_ + if isinstance(model_, torch.fx.GraphModule): + fake_inputs = [ + node.meta.get("val") + for node in model_.graph.nodes + if node.op == "placeholder" + ] + if all(v is not None for v in fake_inputs): + # Validate devices before switching to fake tensors. + for idx, fi, i in zip(count(), fake_inputs, inputs_): + if fi.device != i.device: + raise ValueError( + f"Device mismatch between fake input and example input at position #{idx}: " + f"{fi.device} vs {i.device}. If the model was exported via torch.export(), " + "make sure torch.export() and torch.aot_compile() run on the same device." + ) + inputs_ = fake_inputs + return compile_fx( + model_, + inputs_, + inner_compile=functools.partial(inner_compile, cpp_wrapper=True), + decompositions=decompositions, + ) + + recursive_compile_fx = functools.partial( + compile_fx, + inner_compile=inner_compile, + decompositions=decompositions, + ) + + if not graph_returns_tuple(model_): + return make_graph_return_tuple( + model_, + example_inputs_, + recursive_compile_fx, + ) + + if isinstance(model_, torch.fx.GraphModule): + if isinstance(model_.graph._codegen, _PyTreeCodeGen): + # this graph is the result of dynamo.export() + return handle_dynamo_export_graph( + model_, + example_inputs_, + recursive_compile_fx, + ) + + model_ = _recursive_pre_grad_passes(model_, example_inputs_) + + if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_): + return flatten_graph_inputs( + model_, + example_inputs_, + recursive_compile_fx, + ) + + assert not config._raise_error_for_testing + num_example_inputs = len(example_inputs_) + cudagraphs = BoxedBool(config.triton.cudagraphs) + forward_device = BoxedDeviceIndex(None) + + graph_id = next(_graph_counter) + + decompositions = ( + decompositions if decompositions is not None else select_decomp_table() + ) + + def fw_compiler_base( + model: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + is_inference: bool, + ): + with dynamo_utils.dynamo_timed("compile_fx..fw_compiler_base"): + return _fw_compiler_base(model, example_inputs, is_inference) + + def _fw_compiler_base( + model: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + is_inference: bool, + ): + if is_inference: + # partition_fn won't be called + _recursive_joint_graph_passes(model) + + fixed = torch._inductor.utils.num_fw_fixed_arguments( + num_example_inputs, len(example_inputs) + ) + + user_visible_outputs = {} + + if config.keep_output_stride: + model_outputs_node = output_node(model) + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + num_model_outputs = len(model_outputs) + + context = torch._guards.TracingContext.try_get() + # See Note [User Outputs in the inductor graph] + if context is not None and context.fw_metadata and not is_inference: + original_output_start_index = ( + context.fw_metadata.num_mutated_inp_runtime_indices + ) + else: + original_output_start_index = 0 + + if isinstance(model_, torch.fx.GraphModule): + *_, orig_model_outputs_node = model_.graph.nodes + assert orig_model_outputs_node.op == "output" + orig_model_outputs, _ = pytree.tree_flatten( + orig_model_outputs_node.args + ) + num_orig_model_outputs = len(orig_model_outputs) + else: + num_orig_model_outputs = num_model_outputs + + assert num_orig_model_outputs <= num_model_outputs + + # Note [User Outputs in the inductor graph] + # We makes the following assumption + # For inference + # len(orig_model_outputs) == len(model_outputs) + # For training + # len(orig_model_outputs) <= len(model_outputs) + # During training, most of the time the model_outputs starts with + # original module's outputs followed by saved activations. + # But this can be not true if the model have inplace updated tensors. + # AOTAutograd will make those tensors being returned before the original + # module's output. + # To make things safe, we'll use original_output_start_index field + # set by AOTAutograd to decide where the original module outputs start. + orig_output_end_idx = ( + original_output_start_index + num_orig_model_outputs + ) + # Sanity chec: we are about to splice out the "user" outputs from the full set + # of "graph" outputs. Make sure we're within bounds. + assert orig_output_end_idx <= num_model_outputs + + user_visible_outputs = dict.fromkeys( + n.name + for n in model_outputs[ + original_output_start_index:orig_output_end_idx + ] + if isinstance(n, torch.fx.Node) + ) + + return inner_compile( + model, + example_inputs, + static_input_idxs=get_static_input_idxs(fixed), + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=is_inference, + boxed_forward_device_index=forward_device, + user_visible_outputs=user_visible_outputs, + ) + + fw_compiler = functools.partial(fw_compiler_base, is_inference=False) + + if config.freezing and not torch.is_grad_enabled(): + inference_compiler = functools.partial( + fw_compiler_freezing, + dynamo_model=model_, + num_example_inputs=num_example_inputs, + inner_compile=inner_compile, + cudagraphs=cudagraphs, + graph_id=graph_id, + forward_device=forward_device, + ) + else: + inference_compiler = functools.partial(fw_compiler_base, is_inference=True) + + def partition_fn(graph, joint_inputs, **kwargs): + _recursive_joint_graph_passes(graph) + return min_cut_rematerialization_partition( + graph, joint_inputs, **kwargs, compiler="inductor" + ) + + def bw_compiler( + model: torch.fx.GraphModule, example_inputs: List[torch.Tensor] + ): + with dynamo_utils.dynamo_timed("compile_fx..bw_compiler"): + user_visible_outputs = {} + + if config.bw_outputs_user_visible: + model_outputs_node = output_node(model) + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + user_visible_outputs = dict.fromkeys( + n.name for n in model_outputs if isinstance(n, torch.fx.Node) + ) + fixed = count_tangents(model) + return inner_compile( + model, + example_inputs, + static_input_idxs=list(range(fixed)), + cudagraphs=cudagraphs, + is_backward=True, + graph_id=graph_id, + boxed_forward_device_index=forward_device, + user_visible_outputs=user_visible_outputs, + ) + + # TODO: can add logging before/after the call to create_aot_dispatcher_function + # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func + # once torchdynamo is merged into pytorch + + fake_mode = detect_fake_mode( + example_inputs_ + ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + tracing_context = ( + torch._guards.TracingContext.try_get() + or torch._guards.TracingContext(fake_mode) + ) + + if V.aot_compilation is True: + with functorch_config.patch(unlift_effect_tokens=True): + gm, graph_signature = aot_export_module( + model_, + example_inputs_, + trace_joint=False, + decompositions=decompositions, + ) + unlifted_gm = _unlift_graph(model_, gm, graph_signature) + if "dynamo_flat_name_to_original_fqn" in model_.meta: + unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[ + "dynamo_flat_name_to_original_fqn" + ] + + # Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515) + # In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into + # _sfdp_init() to register patterns. + # When fallback_random is set to True, the sdpa patterns will be traced during runtime. + # If amp is turned on, the traced FP32 patterns will have prims.convert_element_type which + # will be the same as the generated FP16 patterns. + disable_amp = torch._C._is_any_autocast_enabled() + context = ( + torch._C._DisableAutocast if disable_amp else contextlib.nullcontext + ) + with V.set_fake_mode(fake_mode), compiled_autograd.disable(), context(): + return inference_compiler(unlifted_gm, example_inputs_) + + with V.set_fake_mode(fake_mode), torch._guards.tracing( + tracing_context + ), compiled_autograd.disable(), functorch_config.patch( + unlift_effect_tokens=True + ): + return aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + decompositions=decompositions, + partition_fn=partition_fn, + keep_inference_input_mutations=True, + cudagraphs=cudagraphs, + )(model_, example_inputs_) + + +def graph_returns_tuple(gm: torch.fx.GraphModule): + """True if a FX graph returns a tuple""" + if not isinstance(gm, torch.fx.GraphModule): + return True # can't check this, assume true + (rv,) = output_node(gm).args + if isinstance(rv, (list, tuple)): + return True + if ( + isinstance(rv, torch.fx.node.Node) + and hasattr(rv.target, "_schema") + and len(rv.target._schema.returns) > 1 + and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns) + ): + # for graphs whose result is one node with multiple outputs + return True + return False + + +def make_graph_return_tuple( + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + compile_gm: Callable[..., Any], +): + """ + Mutate gm so it returns a tuple. This is only needed for graphs + not created by torchdynamo that return non-tuples. + """ + node = output_node(gm) + (rv,) = node.args + rv, spec = pytree.tree_flatten(rv) + with gm.graph.inserting_before(node): + gm.graph.output(rv) + gm.graph.erase_node(node) + assert graph_returns_tuple(gm) + + compiled_fn = compile_gm(gm, inputs) + + @functools.wraps(compiled_fn) + def wrapper(*args, **kwargs): + return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) + + return wrapper + + +def handle_dynamo_export_graph( + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + compile_gm: Callable[..., Any], +): + """ + `torch._dynamo.export` embeds pytrees in the FX graph codegen object, + convert that to a normal FX graph so inductor can compile it. + """ + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs)) + + @functools.wraps(compiled_fn) + def wrapper(*args): + return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args))) + + return wrapper + + +def _check_triton_bf16_support(graph: GraphLowering) -> None: + def warn_and_skip(device) -> None: + from torch._dynamo.exc import SkipFrame + + device_interface = get_interface_for_device(device.type) + device_props = device_interface.get_device_properties(device) + warnings.warn( + f"{device_props.name} does not support bfloat16 compilation natively, skipping" + ) + raise SkipFrame("BF16 is not supported") + + for inp in graph.graph_inputs.values(): + device = getattr(inp, "get_device", lambda: torch.device("meta"))() + if (not is_gpu(device.type)) or inp.get_dtype() != torch.bfloat16: + continue + # Print warning and skip frame if attempting to compile for bfloat16 + # on device without hardware support for dtype + device_interface = get_interface_for_device(device.type) + if device_interface.is_bf16_supported(including_emulation=False): + return + warn_and_skip(device) + + for out in graph.graph_outputs: + device = getattr(out, "get_device", lambda: torch.device("meta"))() + if (not is_gpu(device.type)) or out.get_dtype() != torch.bfloat16: + continue + # Print warning and skip frame if attempting to compile for bfloat16 + # on device without hardware support for dtype + device_interface = get_interface_for_device(device.type) + if device_interface.is_bf16_supported(including_emulation=False): + return + warn_and_skip(device) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/config.py b/.venv/lib/python3.11/site-packages/torch/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..80b65fbc31af9173e89c7005c67718caa3f51c9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/config.py @@ -0,0 +1,1241 @@ +import os # noqa: C101 +import sys +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union + +import torch + + +def is_fbcode() -> bool: + return not hasattr(torch.version, "git_version") + + +def fx_graph_remote_cache_default() -> Optional[bool]: + if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "1": + return True + if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "0": + return False + return None + + +def autotune_remote_cache_default() -> Optional[bool]: + if os.environ.get("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") == "1": + return True + if os.environ.get("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") == "0": + return False + return None + + +# Enable auto_functionalized_v2 (enabled by default) +enable_auto_functionalized_v2 = ( + os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "0") == "1" +) + +# add some debug printouts +debug = False + +# Whether to disable a progress bar for autotuning +disable_progress = True + +# Whether to enable printing the source code for each future +verbose_progress = False + +# use fx aot graph codegen cache +fx_graph_cache = ( + os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE", "0" if is_fbcode() else "1") == "1" +) + +# use remote fx aot graph codegen cache +# False: Disables the cache +# True: Enables the cache +# None: Not set -- Off for OSS, JustKnobs based for internal +fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default() + +# enable autotune local cache +autotune_local_cache = True + +# enable autotune remote cache +# False: Disables the cache +# True: Enables the cache +# None: Not set -- Off for OSS, JustKnobs based for internal +autotune_remote_cache: Optional[bool] = autotune_remote_cache_default() + +# Force disabled all inductor level caching -- This will override any other caching flag +force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1" + +# sleep in inductor for testing +sleep_sec_TESTING_ONLY: Optional[int] = None + +# The default layout constraint for custom operators. +# This must be the name of one of the layout constraint tags +# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}), +# If the custom op does not have a layout constraint tag already +# then we assume the following applies. +custom_op_default_layout_constraint = "flexible_layout" + +# use cpp wrapper instead of python wrapper +cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" + +# codegen cpp wrapper code in an ABI compatible mode +abi_compatible = ( + os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1" +) + +c_shim_version = os.environ.get("TORCHINDUCTOR_C_SHIM_VERSION", "2") + +# dead code elimination +dce = False + +# assume weight tensors are fixed size +static_weight_shapes = True + +# put correctness assertions in generated code +size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1" +nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1" + +# enable loop reordering based on input orders +pick_loop_orders = True + +# reuse a kernel input as the output +inplace_buffers = True + +# reuse a buffer for an unrelated purpose +allow_buffer_reuse = True + +# Enable pooled allocations for non-output tensors +memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1" + +# How to organize memory under memory_planning=True: +# - "none": do not try to pool storage, just reuse +# - "intermediates": all non-outputs share storage, outputs each get unique storage +# - "outputs": two pools, one for intermediates (freed on return) and one for outputs +# - "combined": a single pool for both intermediates and outputs +memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates") + +# codegen benchmark harness +benchmark_harness = True + +# fuse pointwise into templates +epilogue_fusion = True + +# do epilogue fusions before other fusions +epilogue_fusion_first = False + +# enable pattern match+replace optimizations +pattern_matcher = True + +# set to True to enable the back-to-back GEMM pass +b2b_gemm_pass = False + +# register custom graph optimization pass hook. so far, pre/post passes are +# only applied before/after pattern_matcher in post_grad_passes. +# +# def my_custom_pre_pass(graph: torch.fx.graph.Graph): +# # my custom graph optimization pass +# ... +# +# def my_custom_post_pass(graph: torch.fx.graph.Graph): +# # my custom graph optimization pass +# ... +# +# torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass +# torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass +post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None +post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Registers a custom joint graph pass. +joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None +joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None + +# Registers a custom pregrad pass. Note that the pre-grad IR is 1. +# non-functional, 2. non-normalized, and 3. prone to change. Ideally we should +# use post-grad passes. +pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Registers a custom pass to be run right before fusion in Inductor scheduler. +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +_pre_fusion_custom_pass: Optional[ + Callable[ + [List["torch._inductor.scheduler.BaseSchedulerNode"]], + List["torch._inductor.scheduler.BaseSchedulerNode"], + ] +] = None + +# Deprecated +split_cat_fx_passes = True + +# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability. +efficient_conv_bn_eval_fx_passes = False + +# Enable predispatch aten IR for export +is_predispatch = False + +# Deprecated +group_fusion = False + +# Deprecated +batch_fusion = True + +# Pre grad fusion and options in order, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions. +# batch fusion options: +# batch_linear +# batch_linear_lhs +# batch_layernorm +# batch_tanh +# batch_relu +# batch_sigmoid + +# split cat fusion options: +# normalization_pass +# remove_split_with_size_one_pass +# merge_getitem_cat_pass +# merge_stack_tahn_unbind +# merge_splits_pass +# mutate_cat_pass +# split_cat_pass +pre_grad_fusion_options: Dict[str, Dict[str, Any]] = { + "batch_linear": {}, + "batch_linear_lhs": {}, + "batch_layernorm": {}, + "batch_tanh": {}, + "batch_relu": {}, + "batch_sigmoid": {}, +} + +# Post grad fusion and options, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions. +post_grad_fusion_options: Dict[str, Dict[str, Any]] = {} + +# enable reordering pass for improving memory locality +reorder_for_locality = True + +# Scale down RBLOCK for better occupancy +dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1" + +# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32 +# but the mul gets fused with other pointwise ops instead. +force_fuse_int_mm_with_mul = False + +# for pattern torch.mm(a, b.to(dtype)) with cuda tensors, +# enable torch._inductor.kernel.mm.tuned_mixed_mm fused kernel. +# Autotune will compare perf with normal cast->then->mm option +use_mixed_mm = True + +# enable runtime numeric check for pre/post grad fx passes +# floating point provides limited accuracy (about 7 decimal digits for single precision +# floating point numbers,about 16 decimal digits for double precision floating point numbers) +# according to PyTorch documentation. +# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations +fx_passes_numeric_check: Dict[str, Any] = { + "pre_grad": False, + "precision": 1e-4, + "num_iterations": 1, + "requires_optimizer": True, +} + +# mixed_mm_choice can be used to control the behaviour for pattern torch.mm(a, b.to(dtype)) with cuda tensors. +# The fallback aten implementation is normal cast->then->mm option. +# If mixed_mm_choice is "default": this flag will be ignored. +# If mixed_mm_choice is "triton": +# - Always use torch._inductor.kernel.mm.tuned_mixed_mm's fused kernel. +# - Autotune will not compare with fallback. +# If mixed_mm_choice is "aten": always use the fallback aten implementation. +# If mixed_mm_choice is "heuristic": +# - Enables the heuristic. +# - If the heuristic decides to add a config, it will add the config as the first choice. +# - If autotune is disabled, this config will always be chosen. +# - If autotune is enabled, it will also compare with fallback aten implementation and fused kernel. +# The use_mixed_mm flag will be ignored if mixed_mm_choice != "default". +mixed_mm_choice = "heuristic" + +# enable reordering pass for increasing overlap between compute and communication +reorder_for_compute_comm_overlap = False + +# passes (in execution order) for increasing overlap between compute and communication +# for built-in passes, use string name; for user-defined passes, pass in the function handle +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +reorder_for_compute_comm_overlap_passes = [ + "reorder_compute_for_overlap", + "sink_waits", + "raise_comms", +] + +# runtime estimation function for ops +# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle +estimate_op_runtime = "default" + +# unit: GB/s, uni-directional P2P bandwidth per card +# default value is NVLink +intra_node_bw = 300 + +# unit: GB/s, uni-directional P2P bandwidth per node +# default value is InfiniBand +inter_node_bw = 25 + +# enable slow autotuning passes to select algorithms +max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1" + +# enable slow autotuning passes to select pointwise/reductions algorithms +max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1" + +# enable slow autotuning passes to select gemm algorithms +max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" + +# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations +# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations +# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure +# that triton does not use TF32 wherever cublas would not use TF32 +force_same_precision = ( + True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" +) + +# Specify candidate backends for gemm autotune. +# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CPP. +# ATen: default Pytorch ATen kernels. +# Triton: Triton templates defined in torch inductor (AMD and NVidia GPUs). +# CUTLASS: Cutlass templates and kernels (NVidia GPUs only). +# CK: Composable Kernel templates and kernels (AMD Instinct GPUs only). +# CPP: CPP templates and kernels for CPU. +max_autotune_gemm_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" +).upper() + +# As above, specify candidate backends for conv autotune. +# NB: in some cases for 1x1 convs we emit as matmul, +# which will use the backends of `max_autotune_gemm_backends` +max_autotune_conv_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_CONV_BACKENDS", "ATEN,TRITON" +).upper() + + +# Specify the size of the search space for GEMM autotuning. +# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_gemm_search_space = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT" +).upper() + +# Whether we fall back to ATen or hard error when no matches are found during autotuning +autotune_fallback_to_aten = ( + os.environ.get("TORCHINDUCTOR_AUTOTUNE_FALLBACK_TO_ATEN", "1") == "1" +) + +# the value used as a fallback for the unbacked SymInts +# that can appear in the input shapes (e.g., in autotuning) +unbacked_symint_fallback = 8192 + +# DEPRECATED, DO NOT USE +search_autotune_cache = False + +save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" + +# We will disable creating subprocess for autotuning if this is False +autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1" + +# The following three timeouts are applicable if autotune_in_subproc is True: + +# Max time that a a valid benchmark result may take during autotuning +max_autotune_subproc_result_timeout_seconds = 60.0 +# Additional time we allow subprocesses to terminate gracefully after the timeout until we send a SIGTERM +max_autotune_subproc_graceful_timeout_seconds = 1.0 +# Additional time that we grant after a SIGTERM until we do a hard SIGKILL of subprocesses +max_autotune_subproc_terminate_timeout_seconds = 2.0 + +# If autotuning in subprocess, whether to use multiple devices +autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1" + +coordinate_descent_tuning = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1" +) +coordinate_descent_check_all_directions = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1" +) +coordinate_descent_search_radius = int( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1") +) + +# AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and +# generate the learned heursitic to code which is shipped with the compiler +# Specify a list of comma separated optimizations to collect data for +autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "") +# Specify a list of comma separated optimizations to use learned heuristics for +autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm") + + +def run_autoheuristic(name: str) -> bool: + return collect_autoheuristic(name) or use_autoheuristic(name) + + +def collect_autoheuristic(name: str) -> bool: + return name in torch._inductor.config.autoheuristic_collect.split(",") + + +def use_autoheuristic(name: str) -> bool: + return name in torch._inductor.config.autoheuristic_use.split(",") + + +# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py. +# If set to another path, autoheuristic will instead log results to the given path. +autoheuristic_log_path = os.environ.get( + "TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH", "DEFAULT" +) + +# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions +layout_opt_default = "1" if not torch.version.hip else "0" +layout_optimization = ( + os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1" +) + +force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1" + + +# Whether to keep the output strides the same as eager after layout optimization. +keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1" + +# Enabling this will let compiler print warning messages if a generated triton +# kernel has inputs with mixed layouts. This is helpful for perf debugging +# since kernel with mixed layout inputs may run much slower then one whose inputs +# have uniform layouts. +warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1" + +# control store vs recompute heuristic +# For fanouts, rematerialization can lead to exponential blowup. So, have +# smaller threshold +realize_reads_threshold = 4 +realize_opcount_threshold = 30 + +# Threshold to prevent excessive accumulation of ops in one buffer during lowering +realize_acc_reads_threshold = 8 + +# fallback to eager for random/dropout, this is slow but useful for debugging +fallback_random = False + +# automatically create fallbacks when encountering an unhandled op +implicit_fallbacks = True + +# fuse even in cases without common reads +aggressive_fusion = False + +# For each fused kernel in the wrapper, comment with the nodes that get fused. +# Useful for debugging fusion. +debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" +benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" +enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "") +loop_ordering_after_fusion = ( + os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1" +) + +# For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel +benchmark_epilogue_fusion = ( + os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1" +) + +# Take how many of the top triton kernels to benchmark epilogue +max_epilogue_benchmarked_choices = 1 + +# how many nodes to allow into a single fusion +max_fusion_size = 64 + +# max number of inputs to generate cat as a pointwise op with masked laods +max_pointwise_cat_inputs = 8 + +# replace small reductions with pointwise, disable with `= 1` +unroll_reductions_threshold = 8 + +# Add extra comments to output code (causes compile cache misses) +comment_origin = False + +# Convert 1x1 convs into matmuls +conv_1x1_as_mm = False + +# Enable split reductions for better utilization when the dimension +# being reduced over is large (by splitting it) +split_reductions = True + +benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1" + +# Enable constant and index_expr folding +constant_and_index_propagation = True + +# we always add constants into graph.constants without +# performing any constant-inlining optimization +always_keep_tensor_constants = False + +# assert that indirect indexing does not read / write out of bounds +assert_indirect_indexing = True + +# compute CSE bounds on variables that do not appear in the FX graph +compute_all_bounds = False + +# enable the combo kernel that combines data-independent kernels (additional +# to foreach kernels) into a single one (Experimental) +combo_kernels = False +# benchmark combo kernels and only allow ones with perf gains +benchmark_combo_kernel = False +# combo_kernel autotuning options: 0 - disable, 1 - enable except for foreach, +# 2 - enable for all +combo_kernels_autotune = 1 +# Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable +# for all except for foreach, 2 - enable for all +combo_kernel_allow_mixed_sizes = 1 +# Enable dynamic shapes for foreach kernels +combo_kernel_foreach_dynamic_shapes = False + +# constant folding on the joint graph +joint_graph_constant_folding = True + +# Enable indirect_indexing asserts for decompositions and lowerings +debug_index_asserts = False + +# Mode to emulate pytorch eager numerics for lower precision (fp16, bf16) +# Pytorch eager computes bf16/fp16 by upcasting inputs to fp32 and downcasting after +# For multiple, fused pointwise nodes, inductor will elide the intermediary upcasts and downcasts +# Typically this should be closer to fp64 ref numerics. However, it can be useful for debugging +# to emulate the eager numerics. +emulate_precision_casts = False + +# warnings intended for PyTorch developers, disable for point releases +is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__ +developer_warnings = is_fbcode() or is_nightly_or_source + +# This pattern matches a special usage of scatter +# 1. It's applied to a constant tensor +# 2. The index tensor has size 1 in the scatter dimension +# Such pattern generates a sparse matrix when the const tensor is all-zero. +# We can lower this pattern to a pointwise kernel for more fusion opportunities +# and saving memory footprint. +optimize_scatter_upon_const_tensor = ( + os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1" +) + + +# The multiprocessing start method to use for inductor workers in the codecache. +# Can be "subprocess" or "fork". +def decide_worker_start_method() -> str: + start_method = os.environ.get( + "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess" + ) + assert start_method in ( + "subprocess", + "fork", + ), f"Invalid start method: {start_method}" + return start_method + + +worker_start_method = decide_worker_start_method() + +# Flags to turn on all_reduce fusion. These 2 flags should be automaticaly turned +# on by DDP and should not be set by the users. +_fuse_ddp_communication = False +_fuse_ddp_bucket_size = 25 + +# Flag to control which fusion passes to apply. Functions in the list will +# be applied in order. There are two different different fusion passes +# --"fuse_ddp_with_concat_op" and "fuse_ddp_with_coalesced_op". The default +# one is "fuse_ddp_with_concat_op". Users can also change this to a customized +# fusion function. +# +# The fusion currently does not support multiple DDP with different PG or +# data type. This feature will be added in the future PRs. +# +# "schedule_comm_wait" is used to delay the wait ops to maximize comm/comp +# overlapping. At this moment, this pass performs better than +# reorder_for_compute_comm_overlap_passes but we will add the logic of +# "schedule_comm_wait" in the future and remove the one here. +_fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [ + "fuse_ddp_with_concat_op", + "schedule_comm_wait", +] + +_micro_pipeline_tp: bool = False + + +def decide_compile_threads() -> int: + """ + Here are the precedence to decide compile_threads + 1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by + setting this to 1 to make pdb happy. + 2. Set to 1 if it's win32 platform + 3. decide by the number of CPU cores + """ + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + elif sys.platform == "win32": + return 1 + elif is_fbcode(): + return 1 + else: + cpu_count = ( + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count() + ) + assert cpu_count + return min(32, cpu_count) + + +compile_threads = decide_compile_threads() + +# gemm autotuning global cache dir +if is_fbcode(): + try: + from libfb.py import parutil + + if __package__: + global_cache_dir = parutil.get_dir_path( + os.path.join(__package__.replace(".", os.sep), "fb/cache") + ) + else: + global_cache_dir = parutil.get_dir_path("fb/cache") + except (ValueError, ModuleNotFoundError): + global_cache_dir = None + +else: + global_cache_dir = None + +# If kernel is fused, the name is generated from the origin node op names +# for larger kernels limit this +kernel_name_max_ops = 10 + +# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs +shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1" + +# Control if we will do padding for pointwise/reductions +comprehensive_padding = ( + os.environ.get("TORCHINDUCTOR_COMPREHENSIVE_PADDING", "1") == "1" +) +pad_channels_last = False + +# Disable comprehensive padding on the CPU +disable_padding_cpu = True + +# The width of comprehensive padding, in bytes. +# CUDA max memory transaction size is 128 bytes for a warp. +padding_alignment_bytes = 128 + +# Threshold on the minimum stride that will be padded. +# +# Don't align a too small stride since that causes too much memory increase. +# Pad too small stride may also cause perf loss. We may result in many tiny data blocks +# with gaps in between. That causes less coalesced GPU memory access! +# +# Initially we pick 320 as the threshold since for alignement=16, +# that results in at most 5% memory cost. +# +# But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. +# Let's say an inner reduction has a row size 513. Inductor will generate +# persistent reduction code. +# If we do padding, the strides are not contiguous any more. Inductor +# uses a much smaller threshold for persistent reduction in this case and +# generates potentially worse non-persistent reduction code. +# +# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. +# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) +padding_stride_threshold = 1024 + +# Enable padding outputs, even if they would not be padded in eager mode. +# By default, we use the same strides as eager mode. +pad_outputs = False + +# Whether to treat output of the backward graph as user visible. +# For user visible outputs, inductor will make sure the stride matches with eager. +bw_outputs_user_visible = True + +# Whether to always use shape padding if it is enabled and possible +force_shape_pad: bool = False + +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + +# Mark the wrapper call in PyTorch profiler +profiler_mark_wrapper_call = False + +# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for +# every intermediate for which we can correlate it with an intermediate +# from the original FX graph +generate_intermediate_hooks = False + +# Populate traceback field on IRNode; good for debugging why origin_node is +# not populated, or finding out where an IRNode was constructed +debug_ir_traceback = False + +# used for debugging to make sure config is properly set +_raise_error_for_testing = False + +_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "") +profile_bandwidth = _profile_var != "" +profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var +# Specify a file where we print out the profiling results. +# None means we do not dump results to a file. +profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None) +# Switch to do_bench_using_profiling to exclude the CPU overheads +profile_bandwidth_with_do_bench_using_profiling = ( + os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1" +) + + +# TODO: remove later +disable_cpp_codegen = False + + +# Freezing will attempt to inline weights as constants in optimization +# and run constant folding and other optimizations on them. After freezing, weights +# can no longer be updated. +freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1" + +# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead +# of potentially keeping multiple copies of weights. +freezing_discard_parameters: bool = False + +# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests +# should be run with this flag both on and off to make sure we have coverage. +allow_stack_allocation: bool = ( + os.environ.get("TORCHINDUCTOR_STACK_ALLOCATION", "1" if is_fbcode() else "0") == "1" +) + +# Enables an alternate DSO interface (the "minimal ArrayRef interface") intended +# to maximize performance for use cases that it can accommodate at the expense of +# generality. In brief: +# - inputs and outputs are ArrayRefTensor (note that strides are required, but the +# tensor must be contiguous) +# - constant handling is unchanged because it is not a per-inference-iteration bottleneck +# +# When the DSO is generated in this mode, the usual interface will also be supported, +# but performance for that interface may be degraded. +use_minimal_arrayref_interface: bool = False + +# decompose some memory bound matmul/bmm to mul +decompose_mem_bound_mm: bool = False + +# assume_aligned_inputs means that we assume that inputs will be aligned; we generate +# code using this assumption, and clone tensors before use if they aren't aligned. +# In the common case, most inputs will be aligned. +assume_aligned_inputs: bool = False + +# For the user-written Triton kernels compiled with the model, ignore the unsupported +# arguments passed to the @triton.autotune in the user's code; this is unsafe, as +# ignoring the unsupported args may lead to unexpected autotuning behavior: don't +# set unless you know what you're doing. +unsafe_ignore_unsupported_triton_autotune_args: bool = False + +# When True, we will check in scheduler.py _codegen that there are no "loops" +# in the call stack; that is to say, the same frame multiple times. This +# ensures that a cProfile trace to this frame will be a straight line without +# any cycles. +check_stack_no_cycles_TESTING_ONLY: bool = False + + +# config specific to codegen/cpp.py +class cpp: + # set to torch.get_num_threads() + threads = -1 + + # Do not generate loops when the condition doesn't hold, like: + # for(long i0=4096; i0<4096; i0+=1) + no_redundant_loops = ( + os.environ.get("TORCHINDUCTOR_CPP_NO_REDUNDANT_LOOPS", "1") == "1" + ) + + # Assume number of threads is dynamic, don't specialize thread number. + # Kernels don't recompile on thread number changes with this flag on. + # For single-threaded workload, turning it on would incur a slight + # performance degradation. + dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1" + + simdlen: Optional[int] = None + min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096")) + cxx = ( + None, # download gcc12 from conda-forge if conda is installed + # "g++-12", + # "g++-11", + # "g++-10", + # "clang++", + os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"), + # "g++.par", + ) + # Allow kernel performance profiling via PyTorch profiler + enable_kernel_profile = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1" + ) + + # enable weight prepacking to get a better performance; may lead to large memory footprint + weight_prepack = os.environ.get("TORCHINDUCTOR_CPP_WEIGHT_PREPACK", "1") == "1" + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + inject_log1p_bug_TESTING_ONLY: Optional[str] = None + + # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise, + # force usage as specified, without testing. + vec_isa_ok: Optional[bool] = None + + # similar to config.triton.descriptive_names + descriptive_names = "original_aten" + + # how many nodes to allow into a single horizontal fusion + max_horizontal_fusion_size = int( + os.environ.get("TORCHINDUCTOR_CPP_MAX_HORIZONTAL_FUSION_SIZE", "16") + ) + + # Make scatter_reduce fallback when reduce is sum to avoid performance regression + # using atomic_add. + fallback_scatter_reduce_sum = ( + os.environ.get("TORCHINDUCTOR_CPP_FALLBACK_SCATTER_REDUCE_SUM", "1") == "1" + ) + + # Use funsafe-math-optimizations when compiling + enable_unsafe_math_opt_flag = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_UNSAFE_MATH_OPT_FLAG", "0") == "1" + ) + + # Use ffp-contract when compiling + enable_floating_point_contract_flag = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_FLOATING_POINT_CONTRACT_FLAG", "0") + == "1" + ) + + # Disable the tiling select heuristic + enable_tiling_heuristics = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1" + ) + + # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls + # the maximal parallelism of K-slicing. Since K-slicing requires extra thread + # synchronization and buffers, the maximal number of slices is limited to + # mitigate the sync overhead and memory usage. + # When set to 0, the number of slices is unlimited. + gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1")) + + # For perf tuning and debugging purpose, configure the pre-defined cache blocking for + # MxNxK dims respectively. The blockings are separated by comma and the unit is + # the number of register blocks. + # For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively. + gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None) + + # For perf tuning and debugging purpose, configure the pre-defined thread blocking factors for + # MxNxK dims respectively. The factors are separated by comma and their product + # should be the same as the total number of threads. + # For example, if the total number of threads is 56, "7,4,2" means the work is + # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM. + gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None) + + # Whether to enable masked vectorization for the tail_loop. + enable_loop_tail_vec = True + + +# config specific to codegen/triton.py +class triton: + # Use cudagraphs on output code + cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1" + + # Use cudagraph trees for memory pooling if `cudagraphs` is True + cudagraph_trees = True + + # Should we skip cudagraphing graphs with dynamic shape inputs + # If False, we will re-record a graph for each unique set of shape inputs + cudagraph_skip_dynamic_graphs = False + + # assertions not on the fast path, steady state + slow_path_cudagraph_asserts = True + + # TODO - need to debug why this prevents cleanup + cudagraph_trees_history_recording = False + + # Enable cudagraph support for mutated inputs from prior cudagraph pool + cudagraph_support_input_mutation = False if is_fbcode() else True + + # Maximal number of allowed cudagraph re-record for a function and + # a cudagraph node due to static input tensor address changes or + # cudagraph managed tensor data pointer changed. + # i.e., allow num_recording <= cudagraph_unexpected_rerecord_limit + # note: we are conservative here and choose a large limit. + cudagraph_unexpected_rerecord_limit = 128 + + # Warn loudly when the number of cudagraphs due to dynamic shape + # exceeds this limit + cudagraph_dynamic_shape_warn_limit: Optional[int] = 50 + + # synchronize after cudagraph invocation + force_cudagraph_sync = False + + # always run cudagraphs in the eager warmup stage + # instead of recording and executing cudagraphs + force_cudagraphs_warmup = False + + # assertions on the fast path + fast_path_cudagraph_asserts = False + + # skip warmup for cudagraph trees + skip_cudagraph_warmup = False + + # Synchronize before and after every compiled graph. + debug_sync_graph = False + + # Synchronize after every kernel launch, to help pinpoint bugs + debug_sync_kernel = False + + # Always load full blocks (rather than broadcasting inside the block) + dense_indexing = False + + # limit tiling dimensions + max_tiles = 2 + + # Prefer higher dimensional tilings. This simplifies indexing expressions, making + # it easier to identify block pointers. + prefer_nd_tiling: bool = False + + # use triton.autotune for pointwise ops with complex layouts + # this should only be disabled for debugging/testing + autotune_pointwise = True + + # max autotune gemm with cublasLt + autotune_cublasLt = True + + # Tune the generated Triton kernels at compile time instead of first time they run + autotune_at_compile_time = False + + # should we stop a fusion to allow better tiling? + tiling_prevents_pointwise_fusion = True + tiling_prevents_reduction_fusion = True + + # should we give different names to kernels + # Note: This is orthogonal to descriptive_names - this is deciding whether + # our triton kernel names should all be `triton_` (to maximize caching) or + # whether they should be unique. + unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1" + + # should we put op names in kernel names + # False: No special names (just triton__1, triton__2, etc.) + # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.) + # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions) + # "inductor_node": Maps to the node name in the FX graph passed to Inductor + descriptive_names = "original_aten" + + # use alternate codegen for smaller reductions + persistent_reductions = ( + os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1" + ) + + # 0/False: disable + # 1/True: enable, use tuning to pick between different subkernels + # 2: enable, force using persistent reduction (for debugging) + # 3: enable, force using non-persistent reduction (for debugging) + multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0")) + + # hint to Triton when arguments are divisible by 16 + divisible_by_16 = True + + # Minimum RBLOCK to be used for a TritonSplitScanKernel + # NOTE: This also indirectly controls the size of workspace buffer required + min_split_scan_rblock = 256 + + # Store the generated cubin files for cpp wrapper code to load + store_cubin = False + + # the max number of spills we allow for the configs we benchmark. + # Setting this to 0 means we skip a config if it spills even a single + # register. + # Setting it to a larger value allows a config spilling a small amount + # of registers being benchmarked. + # + # NOTE: triton will always report >0 register spills for kernels using sin/cos. + # (check this issue https://github.com/openai/triton/issues/1756 ) + # So far we see a fixed 8 spilled registers for kernels using sin/cos. + # Raise the threshold to 16 to be safe. + # We should revisit this once we understand more of the source of register spills. + spill_threshold: int = 16 + + # Generate code containing the newer tl.make_block_ptr() API for loads/store + use_block_ptr = False + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + + # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental) + codegen_upcast_to_fp32 = True + + +class aot_inductor: + # AOTInductor output path + # If an absolute path is specified, the generated lib files will be stored under the directory; + # If a relative path is specified, it will be used as a subdirectory under the default caching path; + # If not specified, a temp directory will be created under the default caching path. + # If the specified path contains something like "model.so", the sub-string will be used + # to name the generated library. + output_path = "" + + debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + + debug_dump_consts_bin: bool = ( + os.environ.get("AOT_INDUCTOR_DEBUG_DUMP_CONSTS_BIN", "0") == "1" + ) + + # option for debug printing/saving for intermediate tensor values for aot inductor + # 0: disable debug dumping + # 1: enable saving intermediate tensor values + # 2: enable printing intermediate tensor values + debug_intermediate_value_printer = os.environ.get( + "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0" + ) + + # filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2 + filtered_kernel_names = os.environ.get( + "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None + ) + + # Serialized tree spec for flattening inputs + serialized_in_spec = "" + + # Serialized tree spec for flattening outputs + serialized_out_spec = "" + + # flag to decide whether to create a submodule for constant graph. + use_runtime_constant_folding: bool = False + + # flag to force weight to be appened to the shared library and mmaped by the runtime + # rather than embedded into the data section. Needed to support 1B+ parameter models + force_mmap_weights: bool = False + + package: bool = False + + +class cuda: + # CUDA arch to use for CUDA template kernel compilation. + # e.g. "70", "75", "80", "90", etc. + # When arch is None, Inductor uses torch.cuda.get_device_capability(0). + arch: Optional[str] = None + + # CUDA version to use for CUDA template kernel compilation. + # e.g. "11.4", "12.1", etc. + # When version is None, Inductor uses torch.version.cuda. + version: Optional[str] = None + + # Optimization level for the host compiler. + compile_opt_level = "-O1" + + # Whether to enable device LTO (link-time-optimization). + enable_cuda_lto = False + + # Whether to keep intermediate files dring compilation. + enable_ptxas_info = False + + # Whether to enable debug info, e.g. line number, cutlass debug info. + enable_debug_info = False + + # Whether to use fast math. + use_fast_math = False + + # Path to the CUTLASS repo root directory. + # The default path only works under PyTorch local development environment. + cutlass_dir = os.environ.get( + "TORCHINDUCTOR_CUTLASS_DIR", + os.path.abspath( + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/") + ), + ) + + # Configures the maximum number of CUTLASS configs to profile in max_autotune. + # By default it's None, so that all CUTLASS configs are tuned. + # This is mainly used to reduce test time in CI. + cutlass_max_profiling_configs: Optional[int] = None + + # Path to CUDA NVCC. + # NVCC search order: + # 1) cuda_cxx set in this config + # 2) CUDACXX environment variable + # 3) CUDA_HOME environment variable + # 4) default system search PATH. + cuda_cxx: Optional[str] = None + + # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops. + cutlass_backend_min_gemm_size: int = 1 + + # enable generation of inline standalone runner in CUDA CPP generated code + # which allows to compile the generated code into a standalone executable. + generate_test_runner: bool = ( + os.environ.get("INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE", "1") == "1" + ) + + # Keep only Cutlass op configs which contain this regular expression pattern + # Set this to "warpspecialized_cooperative_epi_tma" to enable only SM90 TMA Cutlass Kernels for large GEMMs + cutlass_op_allowlist_regex: Optional[str] = None + + # Note: Names of Cutlass ops names can be obtained by calling + # op.configuration_name() on a Cutlass op instance, for example those + # returned from cutlass_utils.gen_ops() or the op argument passed to + # CUTLASSGemmTemplate.render(...) + + # Filter Cutlass configs which contain this regular expression pattern + # Set this to "pingpong" to avoid numerical issues + # caused by the op ordering of the "pingpong" memory access + # pattern used by some Cutlass Kernels. + cutlass_op_denylist_regex: Optional[str] = "pingpong" + + +class rocm: + # Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"]. + # If empty, the `native` arch is used + arch: List[str] = [] + + # Enable the CK backend for CDNA2 and CDNA3 only (for now) + # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors + ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"] + + # Optimization level, use to balance compilation speed and runtime performance + compile_opt_level = "-O2" + + # Flag to keep debug information in compiled objects + is_debug = False + + # Flag to keep intermediate files (assembly listings, preprocessed sources, etc.) + save_temps = False + + # Flag to add `-ffast-math`` to compile flags + use_fast_math = True + + # Flag to add `-fgpu-flush-denormals-to-zero` to compile flags + flush_denormals = True + + # Flag to print register and LDS usage during compilation + print_kernel_resource_usage = False + + # Path to ROCm installation, if None, use env variable ROCM_HOME + rocm_home: Optional[str] = None + + # Path to Composable Kernel library. + # Install with `pip install git+https://github.com/rocm/composable_kernel@develop`. + ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR") + + # Number of op instance choices to trade off between runtime perf and compilation time + n_max_profiling_configs: Optional[int] = None + + # Flag to use a short list of CK instances which perform well across a variety of shapes. + # Currently RCR and F16 only + use_preselected_instances: bool = False + + +# Backend to use for CPU codegen either "cpp" or "halide" (experimental) +cpu_backend = "cpp" + +# Backend to use for CUDA codegen either "triton" or "halide" (experimental) +cuda_backend = "triton" + + +class halide: + # Base halide target to use for CPU devices + cpu_target = "host" + + # Base halide target to use for CUDA devices + gpu_target = "host-cuda" + + # Halide autoscheduler to use, choices are: + # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only) + scheduler_cuda = "Anderson2021" + scheduler_cpu = "Adams2019" + + # Controls `no_asserts` flag passed to Halide target (warning: can false positive) + asserts = False + + # Controls `debug` flag passed to Halide target + debug = False + + # Enable (or fallback on) scan kernels such as cumsum + # Halide autoschedulers struggle with these kernels + scan_kernels = False + + +# create a directory containing lots of debug information +class trace: + # master switch for all debugging flags below + enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + # Save debug information to a temporary directory + # If not specified, a temp directory will be created by system + debug_dir: Optional[str] = None + + # Save python logger call >=logging.DEBUG + debug_log = False + + # Save python logger call >=logging.INFO + info_log = False + + # Save input FX graph (post decomps, pre optimization) + fx_graph = True + + # Save FX graph after transformations + fx_graph_transformed = True + + # Save TorchInductor IR before fusion pass + ir_pre_fusion = True + + # Save TorchInductor IR after fusion pass + ir_post_fusion = True + + # Copy generated code to trace dir + output_code = True + + # SVG figure showing post-fusion graph + graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1" + + # SVG figure showing fx with fusion + draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1" + + # We draw our fx graphs with the "record" shape attribute by default. + # Sometimes, when the graph is very complex, we may hit dot errors like below: + # "flat edge between adjacent nodes one of which has a record shape - + # replace records with HTML-like labels" + # and thus fail to generate a graph. So, let's give the user an option + # to specify the shape attribute for the dot graph. For example, passing + # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like lables + # to workaround the above failure. + dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None) + + # If not None, this is the URL that saves the SVG files of the input/output + # graph of each pass that changed the graph + # The nodes that are being transformed in each pass will be colored in yellow + # URL only supports local directory for now + log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None) + + # Store cProfile (see snakeviz to view) + compile_profile = False + + # Upload the .tar.gz file + # Needs to be overriden based on specific environment needs + upload_tar: Optional[Callable[[str], None]] = None + + log_autotuning_results: bool = False + + +_save_config_ignore = [ + # workaround: "Can't pickle " + "trace.upload_tar", + "post_grad_custom_post_pass", + "post_grad_custom_pre_pass", + "joint_custom_pre_pass", + "joint_custom_post_pass", + "pre_grad_custom_pass", +] + +_cache_config_ignore_prefix = [ + # trace functions are not relevant to config caching + "trace", + # uses absolute path + "cuda.cutlass_dir", + # not relevant + "compile_threads", +] + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + +from torch.utils._config_module import install_config_module + + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py b/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..72f34d32475f39c94a30308455edbcfcb8e03d08 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py @@ -0,0 +1,348 @@ +import collections +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.utils._pytree as pytree + + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. +# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + + +def replace_node_with_constant( + gm: torch.fx.GraphModule, + node: torch.fx.Node, + constant: torch.Tensor, + name: Optional[str] = None, +) -> None: + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 # type: ignore[assignment] + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + + +def is_const_source( + node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]] +) -> bool: + return node.op == "get_attr" or ( + node.op == "placeholder" + and lifted_constants is not None + and node.name in lifted_constants + ) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm: torch.fx.GraphModule, + skip_constructors: bool = False, + lifted_constants: Optional[Dict[str, torch.Tensor]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + ) -> None: + super().__init__(gm) + self.node_replacements: Dict[torch.fx.Node, Any] = {} + self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + self.lifted_constants = lifted_constants + + def _support_dynamic_shape(self) -> bool: + # ConstantFolder not support dynamic shape now + return False + + def _deduce_value(self, node: torch.fx.Node) -> Any: + return super().run_node(node) + + def is_impure(self, node: torch.fx.node.Node) -> bool: + if ( + node.target == torch.ops.prims.convert_element_type.default + and is_const_source(node.args[0], self.lifted_constants) # type: ignore[arg-type] + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ): + # For int8_weight -> dq -> bf16_weight + return True + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + return False + + def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]: + last_non_output_use = collections.defaultdict(list) + seen_uses = set() + output_node = next(iter(reversed(self.module.graph.nodes))) + + for node in reversed(self.module.graph.nodes): + if node.target == "output": + continue + + def add_use(inp: torch.fx.Node) -> None: + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node: torch.fx.Node) -> Any: + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get error in interpreter + def set_env(arg: torch.fx.Node) -> None: + self.env[arg] = self.unknown_value + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + # We need to do this weird thing because in cases where flattened_inputs + # contains a ScriptObject, equality checking results in a type error if + # the types are different. + if any( + type(self.unknown_value) == type(input_) and self.unknown_value == input_ + for input_ in flattened_inputs + ): + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target == aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and not is_const_source(node, self.lifted_constants) + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + out = self._deduce_value(node) + if out == self.unknown_value: + return self.unknown_value + + if not is_const_source(node, self.lifted_constants) and isinstance( + out, torch.Tensor + ): + if out.device.type == "meta": + return out + + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def run(self) -> Any: # type: ignore[override] + env: Dict[torch.fx.Node, Any] = {} + self.insert_placerholder_values(env) + return super().run(initial_env=env) + + def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None: + for n in self.module.graph.find_nodes(op="placeholder"): + if self.lifted_constants is not None and n.name in self.lifted_constants: + env[n] = self.lifted_constants[n.name] + else: + env[n] = self.unknown_value # type: ignore[assignment] + + +def constant_fold( + gm: torch.fx.GraphModule, + constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + for node in gm.graph.find_nodes(op="get_attr"): + if len(node.users) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def constant_graph_tag( + gm: torch.fx.GraphModule, + lifted_constants: Optional[Dict[str, Any]], + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]], +) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder( + gm, skip_constructors=True, lifted_constants=lifted_constants + ) + cf.run() + + for node in gm.graph.nodes: + if skip_folding_node_fn is not None and skip_folding_node_fn(node): + node.meta[META_TAG] = MODULE_TAG + continue + if ( + is_const_source(node, lifted_constants) + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph( + gm: torch.fx.GraphModule, + lifted_constants: Optional[Dict[str, Any]], + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]], +) -> Tuple[torch.fx.GraphModule, Tuple[torch.Tensor, ...]]: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag(gm, lifted_constants, skip_folding_node_fn) + + def untag(node: torch.fx.Node) -> bool: + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + return used_to_fold + + const_args = [] + if lifted_constants is not None: + placeholders = list(gm.graph.find_nodes(op="placeholder")) + for node in placeholders: + if node.meta[META_TAG] == MODULE_TAG: + continue + if untag(node): + const_args.append(lifted_constants[node.name]) + + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. + for node in gm.graph.find_nodes(op="get_attr"): + untag(node) + + new_graph = torch.fx.Graph() + + node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + const_result = new_gm(*const_args) + return new_gm, const_result diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/cpu_vec_isa.py b/.venv/lib/python3.11/site-packages/torch/_inductor/cpu_vec_isa.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4838e5f168550656e4e0024d8fb2003ccf06a9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/cpu_vec_isa.py @@ -0,0 +1,373 @@ +# mypy: allow-untyped-defs +import dataclasses +import functools +import os +import platform +import re +import subprocess +import sys +from typing import Any, Callable, Dict, List + +import torch +from torch._inductor import config + + +_IS_WINDOWS = sys.platform == "win32" + + +def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str: + # ISA dry compile will cost about 1 sec time each startup time. + # Please check the issue: https://github.com/pytorch/pytorch/issues/100378 + # Actually, dry compile is checking compile capability for ISA. + # We just record the compiler version, isa options and pytorch version info, + # and generated them to output binary hash path. + # It would optimize and skip compile existing binary. + from torch._inductor.cpp_builder import get_compiler_version_info, get_cpp_compiler + + compiler_info = get_compiler_version_info(get_cpp_compiler()) + torch_version = torch.__version__ + fingerprint = f"{compiler_info}={isa_flags}={torch_version}" + return fingerprint + + +class VecISA: + _bit_width: int + _macro: List[str] + _arch_flags: str + _dtype_nelements: Dict[torch.dtype, int] + + # Note [Checking for Vectorized Support in Inductor] + # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions + # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions + # like exp, pow, sin, cos and etc. + # But PyTorch and TorchInductor might use different compilers to build code. If + # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so + # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass + # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest + # gcc/g++ compiler by default while it could support the AVX512 compilation. + # Therefore, there would be a conflict sleef version between PyTorch and + # TorchInductor. Hence, we dry-compile the following code to check whether current + # HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM + # also needs the logic + # In fbcode however, we are using the same compiler for pytorch and for inductor codegen, + # making the runtime check unnecessary. + _avx_code = """ +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) +#include +#include +#endif + +alignas(64) float in_out_ptr0[16] = {0.0}; + +extern "C" void __avx_chk_kernel() { + auto tmp0 = at::vec::Vectorized(1); + auto tmp1 = tmp0.exp(); + tmp1.store(in_out_ptr0); +} +""" # noqa: B950 + + _avx_py_load = """ +import torch +from ctypes import cdll +cdll.LoadLibrary("__lib_path__") +""" + + def bit_width(self) -> int: + return self._bit_width + + def nelements(self, dtype: torch.dtype = torch.float) -> int: + return self._dtype_nelements[dtype] + + def build_macro(self) -> List[str]: + return self._macro + + def build_arch_flags(self) -> str: + return self._arch_flags + + def __hash__(self) -> int: + return hash(str(self)) + + def check_build(self, code: str) -> bool: + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write + from torch._inductor.cpp_builder import ( + CppBuilder, + CppTorchOptions, + normalize_path_separator, + ) + + key, input_path = write( + code, + "cpp", + extra=_get_isa_dry_compile_fingerprint(self._arch_flags), + ) + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_dir = os.path.dirname(input_path) + buid_options = CppTorchOptions(vec_isa=self, warning_all=False) + x86_isa_help_builder = CppBuilder( + key, + [input_path], + buid_options, + output_dir, + ) + try: + # Check if the output file exist, and compile when not. + output_path = normalize_path_separator( + x86_isa_help_builder.get_target_file_path() + ) + if not os.path.isfile(output_path): + status, target_file = x86_isa_help_builder.build() + + # Check build result + subprocess.check_call( + [ + sys.executable, + "-c", + VecISA._avx_py_load.replace("__lib_path__", output_path), + ], + cwd=output_dir, + stderr=subprocess.DEVNULL, + env={**os.environ, "PYTHONPATH": ":".join(sys.path)}, + ) + except Exception as e: + return False + + return True + + @functools.lru_cache(None) # noqa: B019 + def __bool__(self) -> bool: + if config.cpp.vec_isa_ok is not None: + return config.cpp.vec_isa_ok + + if config.is_fbcode(): + return True + + return self.check_build(VecISA._avx_code) + + +@dataclasses.dataclass +class VecNEON(VecISA): + _bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h + _macro = ["CPU_CAPABILITY_NEON"] + if sys.platform == "darwin" and platform.processor() == "arm": + _macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF") + _arch_flags = "" # Unused + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "asimd" # detects the presence of advanced SIMD on armv8-a kernels + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecAVX512(VecISA): + _bit_width = 512 + _macro = ["CPU_CAPABILITY_AVX512"] + _arch_flags = ( + "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + if not _IS_WINDOWS + else "/arch:AVX512" + ) # TODO: use cflags + _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32} + + def __str__(self) -> str: + return "avx512" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecAMX(VecAVX512): + _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8" + + def __str__(self) -> str: + return super().__str__() + " amx_tile" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + _amx_code = """ +#include +#include + +struct amx_tilecfg { + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[16]; + uint8_t rows[16]; +}; + +extern "C" void __amx_chk_kernel() { + amx_tilecfg cfg = {0}; + _tile_loadconfig(&cfg); + _tile_zero(0); + _tile_dpbf16ps(0, 1, 2); + _tile_dpbusd(0, 1, 2); +} +""" + + @functools.lru_cache(None) # noqa: B019 + def __bool__(self) -> bool: + if super().__bool__(): + if config.is_fbcode(): + return False + if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx(): + return True + return False + + +@dataclasses.dataclass +class VecAVX2(VecISA): + _bit_width = 256 + _macro = ["CPU_CAPABILITY_AVX2"] + _arch_flags = ( + "-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2" + ) # TODO: use cflags + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "avx2" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecZVECTOR(VecISA): + _bit_width = 256 + _macro = [ + "CPU_CAPABILITY_ZVECTOR", + "CPU_CAPABILITY=ZVECTOR", + "HAVE_ZVECTOR_CPU_DEFINITION", + ] + _arch_flags = "-mvx -mzvector" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "zvector" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecVSX(VecISA): + _bit_width = 256 # VSX simd supports 128 bit_width, but aten is emulating it as 256 + _macro = ["CPU_CAPABILITY_VSX"] + _arch_flags = "-mvsx" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "vsx" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +class InvalidVecISA(VecISA): + _bit_width = 0 + _macro = [""] + _arch_flags = "" + _dtype_nelements = {} + + def __str__(self) -> str: + return "INVALID_VEC_ISA" + + def __bool__(self) -> bool: # type: ignore[override] + return False + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +def x86_isa_checker() -> List[str]: + supported_isa: List[str] = [] + + def _check_and_append_supported_isa( + dest: List[str], isa_supported: bool, isa_name: str + ) -> None: + if isa_supported: + dest.append(isa_name) + + Arch = platform.machine() + """ + Arch value is x86_64 on Linux, and the value is AMD64 on Windows. + """ + if Arch != "x86_64" and Arch != "AMD64": + return supported_isa + + avx2 = torch.cpu._is_avx2_supported() + avx512 = torch.cpu._is_avx512_supported() + amx_tile = torch.cpu._is_amx_tile_supported() + + _check_and_append_supported_isa(supported_isa, avx2, "avx2") + _check_and_append_supported_isa(supported_isa, avx512, "avx512") + _check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile") + + return supported_isa + + +invalid_vec_isa = InvalidVecISA() +supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()] + + +# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content +# might have too much redundant content that is useless for ISA check. Hence, +# we only cache some key isa information. +@functools.lru_cache(None) +def valid_vec_isa_list() -> List[VecISA]: + isa_list: List[VecISA] = [] + if sys.platform == "darwin" and platform.processor() == "arm": + isa_list.append(VecNEON()) + + if sys.platform not in ["linux", "win32"]: + return isa_list + + arch = platform.machine() + if arch == "s390x": + with open("/proc/cpuinfo") as _cpu_info: + while True: + line = _cpu_info.readline() + if not line: + break + # process line + featuresmatch = re.match(r"^features\s*:\s*(.*)$", line) + if featuresmatch: + for group in featuresmatch.groups(): + if re.search(r"[\^ ]+vxe[\$ ]+", group): + isa_list.append(VecZVECTOR()) + break + elif arch == "ppc64le": + isa_list.append(VecVSX()) + elif arch == "aarch64": + isa_list.append(VecNEON()) + elif arch in ["x86_64", "AMD64"]: + """ + arch value is x86_64 on Linux, and the value is AMD64 on Windows. + """ + _cpu_supported_x86_isa = x86_isa_checker() + for isa in supported_vec_isa_list: + if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa: + isa_list.append(isa) + + return isa_list + + +def pick_vec_isa() -> VecISA: + if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]): + return VecAVX2() + + _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list() + if not _valid_vec_isa_list: + return invalid_vec_isa + + # If the simdlen is None, it indicates determine the vectorization length automatically + if config.cpp.simdlen is None: + assert _valid_vec_isa_list + return _valid_vec_isa_list[0] + + for isa in _valid_vec_isa_list: + if config.cpp.simdlen == isa.bit_width(): + return isa + + return invalid_vec_isa diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py b/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..5a33de0e3668998ac6f6d71ab16c696938ab1265 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py @@ -0,0 +1,2441 @@ +""" +CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, +which share the same memory pool. Sharing a memory pool is an extremely +important optimization when chaining multiple CUDA graphs together, as it +prevents you from needing to copy intermediate tensors from one graph to the +next, and reduces overall memory usage by allowing dead memory from the first +pool to be reused in the second. + +The standard graph/make_graph_callables support sharing memory pool, but +with a lot of caveats. CUDA graph trees remove these restrictions: + +* Previously, if you recorded graphs A, B, you had to replay A, B in that + order. With CUDA graph trees, after replaying A, you can change your + mind and record/replay a different graph B'; we will support efficient + execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In + other words: we support arbitrary trees of CUDA graph operations, not just + sequences (this is why this feature is called CUDA graph trees.) + +* Previously, if you executed graph A, some non-CUDA graph code, and then + graph B, after executing graph B, it was not safe to retain any references + to intermediates produced by A. With CUDA graph trees, we track if any +outputs of graph A are still live by the time graph B is run, and make + sure graph B doesn't clobber there memory when reusing the CUDA graphs + pool. You'll get a separate recording of B depending on what tensors + stay live or dead. + +CUDA graph trees are flexible enough to be used in Dynamo across graph breaks, +which is their primary use case. + +The ability to switch from replay to record is fairly nontrivial: remember that +when you replay a CUDA graph, you only replay CUDA operations; no CPU side state +is updated. In particular, the CPU-side book-keeping for the allocator is not +reconstructed. However, to record a new child CUDA graph, we must restore this +book-keeping. This is what checkpoint pool state is used for. +""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import gc +import itertools +import operator +import sys +import threading +import traceback +import warnings +import weakref +from collections import defaultdict +from enum import auto, Enum +from typing import ( + Any, + Callable, + cast, + ContextManager, + Dict, + Generator, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) + +import torch.fx +from torch import Tensor +from torch._dynamo.mutation_guard import GenerationTracker +from torch._dynamo.utils import counters, preserve_rng_state +from torch._inductor.compile_fx import ( + align_inputs_from_check_idxs, + copy_misaligned_inputs, + get_expanded_dims, + get_input_idxs_to_check, + index_expanded_dims, + remove_unaligned_input_idxs, + static_input, +) +from torch._inductor.cudagraph_utils import ( + check_for_mutation, + CheckInvariantStatus, + FunctionID, + log_cudagraph_skip_and_bump_counter, + log_data_ptr_mismatch, + maybe_warning_due_to_dynamic_shape, + ModelType, + OutputType, + PlaceholderInfo, + WrappedFunction, +) +from torch.multiprocessing.reductions import StorageWeakRef +from torch.storage import UntypedStorage +from torch.utils import _pytree as pytree +from torch.utils.weak import TensorWeakRef + + +if TYPE_CHECKING: + from torch._inductor.utils import InputType + from torch.types import _bool + +StorageWeakRefPointer = int +StorageDataPtr = int +NBytes = int +S = TypeVar("S", bound="StorageWeakRefWrapper") + + +if torch.backends.cuda.is_built(): + from torch._C import ( + _cuda_CUDAAllocator_AllocatorState as AllocatorState, + _set_cached_tensors_enabled as _set_cached_tensors_enabled, + ) +else: + + class AllocatorState: # type: ignore[no-redef] + pass + + def _set_cached_tensors_enabled(enabled: _bool) -> None: + pass + + +log = torch._logging.getArtifactLogger(__name__, "cudagraphs") + + +from . import config + + +@dataclasses.dataclass(frozen=True) +class GraphID: + "Unique counter of a cuda graph recording" + id: int + + +def clear_cublass_cache() -> None: + """ + Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for + doing warmup within a CUDAGraph private pool because we do not want persistent allocations from + one one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors + from the previous generation are freed. This frees them the memory pool, but not elsewhere. + A tensor in the cublas workspace would continue to be in use the workspace but would also get allocated + in the next run. The memory would be in use in two places. + + To solve this, we clear cublas caches before and after warming up or recording. If a workspace is required + it will be allocated to the cudagraph private pool and accounted for in the allocator for the duration of the + program. There is no overhead to this on replay since cudagraphs removes allocation overhead. + """ + torch._C._cuda_clearCublasWorkspaces() + + +@contextlib.contextmanager +def clear_cublas_manager() -> Generator[None, None, None]: + "Context manager around clearing cublas caches that will clear on enter and exit" + clear_cublass_cache() + try: + yield + finally: + clear_cublass_cache() + + +@contextlib.contextmanager +def disable_conv_cache_emptying() -> Generator[None, None, None]: + prev = torch._C._cuda_get_conv_benchmark_empty_cache() + torch._C._cudnn_set_conv_benchmark_empty_cache(False) + try: + yield + finally: + torch._C._cudnn_set_conv_benchmark_empty_cache(prev) + + +@contextlib.contextmanager +def enable_history_recording() -> Generator[None, None, None]: + "Turns on history recording in the CUDA Caching Allocator" + enabled = torch._C._cuda_isHistoryEnabled() + try: + if not enabled: + torch.cuda.memory._record_memory_history() + yield + finally: + if not enabled: + torch.cuda.memory._record_memory_history(None) + + +def get_history_recording() -> ContextManager[None]: + # TODO - remove, prevents cleanup + if not config.triton.cudagraph_trees_history_recording: + return contextlib.nullcontext() + return enable_history_recording() + + +class TreeManagerContainer: + """ + Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator, + the tree and its corresponding memory pool should be kept alive as long as any outstanding + graph or tensor which is an output of a graph remains alive. + + There is a single tree manager container per device. + + The lifecycle of a tree_manager is: + - Is constructed, no graph, no fns, no tensors + - Tree manager is fetched, resulting in tree manager being allocated + - We generate a bunch of functions, calling add_strong_reference + - These functions die, calling finalize_reference + - When all the functions die, we finalize_tree_manager. + + TODO: in the future, we would like to do the following once storage weak refs land + - We look for all the live storages and add references to THOSE + - We count as storages die + - All the storages are dead, we deallocate the tree manager + """ + + def __init__(self, device_index: int) -> None: + # This class keeps a strong reference to tree_manager, + # but upon all other strong references to the tree_manager will reset it to None. + # We need a strong reference so that we can still access its attributes upon cleanup. + self.tree_manager: Optional[CUDAGraphTreeManager] = None + + # Number of outstanding references to the current tree manager + self.live_cudagraphify_fns = 0 + + self.device_index = device_index + + # Following two objects are only set in the case that Tensor outputs outlive + # the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from + # deallocation. + self.live_storages_count = 0 + self.graph: Optional[torch.cuda.CUDAGraph] = None + + self.lock = threading.Lock() + + def _finalize_tensor(self) -> None: + with self.lock: + self.live_storages_count -= 1 + if self.live_storages_count == 0: + self.graph = None + + # manager was used again after existing cleanup, + # we shouldnt set it to None + if self.live_cudagraphify_fns == 0: + self.tree_manager = None + + def finalize_cudagraphify_fn(self) -> None: + with self.lock: + self.live_cudagraphify_fns -= 1 + if self.live_cudagraphify_fns == 0: + self._finalize_tree_manager() + + def _finalize_tree_manager(self) -> None: + assert self.lock.locked() + self.tree_manager = None + + # TODO - when issue #91395 is landed, we can set a weakref on + # storages and trigger a deallocation when all outputs of the + # cudagraph are dead. + + # live_storages = list( + # tree_manager.live_cudagraph_pool_storages_in_curr_execution() + # ) + + # # Maintain reference to graph to keep tensors alive + # assert len(tree_manager.roots) > 0, "expected at least one use" + # root = next(tree_manager.get_roots()) + # self.graph = root.graph + # seen_storages = set() + # for stor in live_storages: + # if stor in seen_storages: + # continue + # seen_storages.add(stor) + # self.live_storages_count += 1 + # . weakref.finalize(stor, self._finalize_tensor) + + def add_strong_reference(self, fn: Callable[..., Any]) -> None: + with self.lock: + self.live_cudagraphify_fns += 1 + + weakref.finalize(fn, self.finalize_cudagraphify_fn) + + def get_tree_manager(self) -> CUDAGraphTreeManager: + with self.lock: + if self.tree_manager is None: + self.tree_manager = CUDAGraphTreeManager(self.device_index) + return self.tree_manager + + +local = threading.local() + +# one tree manager per device +local.tree_manager_containers = {} +local.tree_manager_locks = defaultdict(threading.Lock) + + +# only incremented by user call of mark_step_begin +class MarkStepBox: + mark_step_counter = 0 + + +# We need to register this as an object that will be copied over as TLS when new +# threads are created in autograd +torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers) +torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks) + + +def mark_step_begin() -> None: + "Indicates that a new iteration of inference or training is about to begin." + + # iterate down to distinguish from GenerationTracking counter + MarkStepBox.mark_step_counter -= 1 + + +def reset_cudagraph_trees() -> None: + "Clear all cudagraph trees" + # see shutdown below for why this is necessary + container_dict = get_obj(local, "tree_manager_containers") + locks_dict = get_obj(local, "tree_manager_locks") + for device, lock in locks_dict.items(): + with lock: + container = container_dict.get(device) + if not container or not container.tree_manager: + continue + + container.tree_manager.shutdown() + + _set_cached_tensors_enabled(False) + container_dict.clear() + + MarkStepBox.mark_step_counter = 0 + + +def get_obj(local: Any, attr_name: str) -> Any: + if hasattr(local, attr_name): + return getattr(local, attr_name) + else: + assert torch._C._is_key_in_tls(attr_name) + return torch._C._get_obj_in_tls(attr_name) + + +def get_container(device_index: int) -> TreeManagerContainer: + container_dict = get_obj(local, "tree_manager_containers") + lock = get_obj(local, "tree_manager_locks")[device_index] + + with lock: + if device_index not in container_dict: + container_dict[device_index] = TreeManagerContainer(device_index) + + return container_dict[device_index] + + +def get_manager( + device_index: int, create_if_none_exists: bool = True +) -> Optional[CUDAGraphTreeManager]: + if create_if_none_exists: + return get_container(device_index).get_tree_manager() + return get_container(device_index).tree_manager + + +def cudagraphify_impl( + model: ModelType, + inputs: List[InputType], + static_input_idxs: Sequence[int], + *args: Any, + **kwargs: Any, +) -> ModelType: + fn_cache: Dict[Tuple[int, ...], Callable[..., Any]] = {} + + # Detect int inputs: we need to index on these + int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)] + get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None + + has_warn = False + + del inputs + + def deferred_cudagraphify(inputs: List[InputType]) -> OutputType: + nonlocal has_warn + + int_key = get_ints(inputs) + fn = fn_cache.get(int_key) + if fn is not None: + return fn(inputs) + + if int_key is None: + log.info("recording cudagraph tree for graph without symints") + else: + log.info("recording cudagraph tree for symint key %s", int_key) + + if not has_warn: + has_warn = maybe_warning_due_to_dynamic_shape(fn_cache, int_key) + + # first get indices we need to check to align, then update our static inputs, + # and finally copy + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) + new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) + copy_misaligned_inputs(inputs, check_input_idxs) + + fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs) + fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs) + fn_cache[int_key] = fn + + return out + + return deferred_cudagraphify + + +def cudagraphify( + model: ModelType, + inputs: List[InputType], + static_input_idxs: Sequence[int] = (), + *, + device_index: int, + is_backward: bool, + is_inference: bool, + stack_traces: Optional[StackTraces] = None, + constants: Tuple[torch.Tensor, ...] = (), + placeholders: Tuple[PlaceholderInfo, ...] = (), + mutated_input_idxs: Tuple[int, ...] = (), +) -> Tuple[ModelType, OutputType]: + manager = get_container(device_index).get_tree_manager() + assert not (is_backward and is_inference) + mode = ( + CompilationMode.BACKWARD + if is_backward + else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD) + ) + + return manager.add_function( + model, + inputs, + static_input_idxs, + stack_traces, + mode, + constants, + placeholders, + mutated_input_idxs, + ) + + +class StorageWeakRefWrapper: + """ + Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked. + """ + + __slots__ = ["ref", "_data_ptr", "extra_ref_check"] + + storage_ref: Optional[StorageWeakRef] + + def __init__( + self, + inp: Union[Tensor, UntypedStorage], + extra_ref_check: Optional[Callable[[], bool]] = None, + ) -> None: + """ + extra_ref_check is an additional check we need to run to check if the + weak ref has expired. in checking storage use count we assume extra_ref_check + will hold an additional reference to the storage. + """ + if isinstance(inp, Tensor): + stor = inp.untyped_storage() + else: + assert isinstance(inp, UntypedStorage) + stor = inp + self.ref = StorageWeakRef(stor) + self._data_ptr = stor.data_ptr() + self.extra_ref_check = extra_ref_check + + @classmethod + def from_weakref_and_data_ptr( + cls: Type[S], + cdata: Any, + data_ptr: int, + extra_ref_check: Optional[Callable[[], bool]] = None, + ) -> StorageWeakRefWrapper: + instance = cls.__new__(cls) + instance._data_ptr = data_ptr + instance.ref = StorageWeakRef.from_weakref(cdata) + instance.extra_ref_check = extra_ref_check + return instance + + def __call__(self) -> Optional[StorageWeakRefPointer]: + if self.expired(): + return None + + return self.ref.cdata + + def swap_weakref(self, cdata: Any) -> None: + self.ref.__del__() + self.ref.cdata = cdata + + def data_ptr(self) -> int: + "NB: returns the data ptr even if the storage has expired" + return self._data_ptr + + def remove_extra_reference(self) -> None: + self.extra_ref_check = None + + def expired(self) -> bool: + if self.extra_ref_check is not None and not self.extra_ref_check(): + return False + + # if extra_ref_check is not None we expect an additional reference + stor_count = torch._C._storage_Use_Count(self.ref.cdata) + return (stor_count - (self.extra_ref_check is not None)) == 0 + + def __repr__(self) -> str: + if self.ref is None or self.ref.expired(): + return f"StorageWeakRefWrapper to {self.data_ptr()}; dead" + else: + return f"StorageWeakRefWrapper to {self.data_ptr()}; alive" + + +def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool: + return maybe_deref(weak_ref) is not None + + +def maybe_deref( + weak_ref: Optional[StorageWeakRefWrapper], +) -> Optional[Tuple[StorageWeakRefPointer, int]]: + if weak_ref is None: + return None + r = weak_ref() + if r is None: + return None + # NB: r.data_ptr() does not necessarily equal weak_ref.data_ptr() + return r, weak_ref.data_ptr() + + +@contextlib.contextmanager +def _use_cuda_memory_pool_manager( + device: int, mem_pool: Tuple[int, int], stream: torch.cuda.Stream +) -> Generator[None, None, None]: + """ + Context manager to use cuda graph pool for new allocations. If you use this manager + all cudagraph tensors in use should be reflected in the allocator or they will be overwritten. + existing_graph should already have been used in a capture, and the mem_pool must already exist, + because this manager will not preserve a reference to the pool which keeps it alive. + """ + torch.cuda.synchronize() + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream), torch.device(device): + torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool) + try: + yield + finally: + torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool) + torch._C._cuda_releasePool(device, mem_pool) + + torch.cuda.current_stream().wait_stream(stream) + + +def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]: + if not isinstance(t, torch.Tensor): + assert t is None + return None + return StorageWeakRefWrapper(t) + + +# A path index of (depth, offset) indices into a graph that is `depth`` number of nodes from the root +# at graph output offset +PathOutputIndex = Tuple[int, int] + +# For each node in the path, for each output, is the output alive +PathLiveness = List[List[bool]] + +StackTraces = List[Optional[str]] + + +class CUDAWarmupNode: + """ + Simplified Wrapper around A CUDA Model that wraps outputs in storage refs and exposes + apis to get the live storages in the current chain of warmup. + + A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have + CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable + memory addresses. + + CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes. + - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of first recording. In the + first instance of warmup, these are not finalized yet. + - All Inputs to the RecordedFunction must be copied over to the cuda graph memory pool, this is unnecessary in warmup. + - CUDAWarmup is only used once and so does not need to optimize as much bookkeeping. It is much simpler. + + NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and + `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility. + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]], + cuda_graphs_pool: Tuple[int, int], + existing_cuda_graph: Optional[torch.cuda.CUDAGraph], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + already_warm: bool, + id: GraphID, + ) -> None: + self.wrapped_function = wrapped_function + self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent + self.cuda_graphs_pool = cuda_graphs_pool + self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = [] + self.tensor_weakrefs: List[Optional[TensorWeakRef]] = [] + self.existing_cuda_graph = existing_cuda_graph + self.has_run = False + self.device_index = device_index + self.stack_traces = stack_traces + self.stream = stream + self.already_warm = already_warm + self.id = id + + def run(self, new_inputs: Any) -> OutputType: + assert not self.has_run, "Wrapped function should never be run twice" + + # See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created + # storages in path_live_weakrefs. + existing_path_data_ptrs = { + t.data_ptr() for t in self.path_live_weakrefs() if t() + } + + def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]: + non_cudagraph_inps = [] + for t in itertools.chain(new_inputs, self.wrapped_function.constants): + if ( + isinstance(t, torch.Tensor) + and t.untyped_storage().data_ptr() not in existing_path_data_ptrs + ): + non_cudagraph_inps.append(weakref.ref(t.untyped_storage())) + return non_cudagraph_inps + + non_cudagraph_inps_storages = get_non_cudagraph_inps() + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, refs) + + with torch.cuda.device( + self.device_index + ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager( + self.device_index, self.cuda_graphs_pool, self.stream + ), get_history_recording(): + out = self.wrapped_function.model(new_inputs) + + # We need to know which outputs are allocated within the cudagraph pool + # so that we can deallocate them at the beginning of the next cudagraph step, + # and set their access to error. + # We use a weakref to the inputs storage, in case a block which was previously + # allocated to the general caching allocator pool gets reallocated to a private pool. + + non_cudagraph_inps_storage_ptrs = set() + for storage in non_cudagraph_inps_storages: + s = storage() + if s is not None: + non_cudagraph_inps_storage_ptrs.add(s._cdata) + + assert len(new_inputs) == 0 + + # sdpa returns cpu tensors when not recording cuda graph + def add_ref(o: Any) -> bool: + return ( + isinstance(o, torch.Tensor) + and o.is_cuda + and o.untyped_storage()._cdata not in non_cudagraph_inps_storage_ptrs + and o.untyped_storage().data_ptr() != 0 + ) + + self.outputs_weakrefs.extend( + [map_to_ref(o) if add_ref(o) else None for o in out] + ) + self.tensor_weakrefs.extend( + [TensorWeakRef(o) if add_ref(o) else None for o in out] + ) + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + out_refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs) + + return out + + @property + def _path_from_root( + self, + ) -> Generator[Union[CUDAGraphNode, CUDAWarmupNode], None, None]: + nodes = [] + node: Union[CUDAGraphNode, CUDAWarmupNode] = self + while node: + nodes.append(node) + node = node.parent # type: ignore[assignment] + + yield from reversed(nodes) + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + "Returns all live storages weakrefs that created by nodes in this path" + for node in self._path_from_root: + for output in node.outputs_weakrefs: + if is_live(output): + yield output # type: ignore[misc] + + def all_outputs_are_dead(self) -> bool: + return not list(self.path_live_weakrefs()) + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool: + for storage_weak_ref in self.path_live_weakrefs(): + if t.untyped_storage().data_ptr() == storage_weak_ref.data_ptr(): + return True + return False + + +# Aliases for List that say what the indices denote +InputList = List # input indexes +OutputList = List # output indexes +LevelList = List # levels (distance from root of tree) + + +class OutputAliasInfo: + pass + + +class _UnaliasedStorage(OutputAliasInfo): + "Singleton to mark that the graph output constructs a new alias or is None" + + +UnaliasedStorage = _UnaliasedStorage() + + +class AliasesPriorGraphOutput(OutputAliasInfo): + "Marks that the graph output aliases an output of a prior graph" + __slots__ = ["index"] + + index: PathOutputIndex + + def __init__(self, index: PathOutputIndex) -> None: + assert isinstance(index, tuple) + self.index = index + + +class AliasesNewOutput(OutputAliasInfo): + "Marks that the graph output aliases an index in the new, returned outputs" + + __slots__ = ["index"] + + index: int + + def __init__(self, index: int) -> None: + assert isinstance(index, int) + self.index = index + + +class CUDAGraphNode: + """ + A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool + and are structured into a tree, where there is a single recording that can precede it (parent) and multiple + subsequent recordings that may follow (children). A node will have no parent if it is the first recording + in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which + would force a dependency. + + On first recording, all of the live tensors in the current CUDA Graph Node path will be + reflected in the corresponding private pool. On subsequent executions, the caching allocator + is unaffected when the graph is replayed. + + In order to support recording a subsequent cuda graph recording after execution of this graph, + we checkpoint the state of the memory pool so that it may later be resumed. + + WrappedFunction should have already been warmed up prior to invocation. + + See [setCheckpointPoolState] for further explanation, as well as + https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + id: GraphID, + parent: Optional[CUDAGraphNode], + inputs: List[InputType], + cuda_graphs_pool: Tuple[int, int], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + ) -> None: + assert isinstance(inputs, (list, tuple)) + + self.wrapped_function = wrapped_function + self.id = id + self.device = device_index + self.stack_traces = stack_traces + self.stream = stream + + # Enable re-record a cudagraph when static tensor address changed. + # if not we should error when it changed. + self.rerecord_if_static_inputs_change = ( + torch._dynamo.config.inline_inbuilt_nn_modules + or torch._inductor.config.triton.cudagraph_support_input_mutation + ) + + # if this is a root parent will be None. use weakref to prevent reference cycle + self._parent = weakref.ref(parent) if parent is not None else None + # reference to the shared memory pool for the entire cuda graphs tree + self.cuda_graphs_pool = cuda_graphs_pool + + # A single wrapped function may be recorded multiple times if memory patterns or + # invariants change from one execution to the next + self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) + + # StorageWeakRef maintains whether the Storage C++ object remains allocated, + # not whether the corresponding memory has been deallocated. In order + # to use them to track memory deallocations we must maintain a single StorageWeakRef + # for all Storages that reference that memory (even if we are constructing Storages + # that do not have a deallocator function). We maintain one single storage_cache + # as we execute any tree path. When we retrieve a storage from the cache we + # check that it is still alive, and we hash based on observed recording data ptr + # and storage cdata. + + # we preserve a single reference to executed outputs that is then referenced + # in children to avoid children having to chase parent pointers in the hot path + # DO NOT reassign output_weakrefs, only call `clear()` + # Path is a series of nodes from root to the current node + self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = [] + self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [ + node.outputs_weakrefs for node in self._path_from_root + ] + self.path_stacktraces: LevelList[Optional[StackTraces]] = [ + node.stack_traces for node in self._path_from_root + ] + self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = [] + + # tensors which are outputs of previous graphs in the tree + self.cudagraph_managed_idxs: List[int] = [ + idx + for idx, t in enumerate(inputs) + if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t) + ] + + self.static_input_idxs: List[int] = list( + set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs) + ) + + self.non_static_input_idx: LevelList[int] = [ + i for i in range(len(inputs)) if i not in self.static_input_idxs + ] + + counters["inductor"]["cudagraph_recorded_non_static_inputs"] += len( + self.non_static_input_idx + ) + + self.non_managed_static_input_idxs: LevelList[int] = [ + i + for i in wrapped_function.static_input_idxs + if i not in self.cudagraph_managed_idxs + ] + + def maybe_get_static_data_ptr( + idx: int, + inputs: List[Union[torch.Tensor, int]], + static_input_idxs: List[int], + ) -> Optional[int]: + inp = inputs[idx] + if isinstance(inp, torch.Tensor) and idx in static_input_idxs: + return inp.data_ptr() + return None + + self.static_input_data_ptrs: InputList[Optional[int]] = [ + maybe_get_static_data_ptr(i, inputs, self.static_input_idxs) + for i in range(len(inputs)) + ] + + # When we checkpoint, and free generations, we will be manually freeing the outputs + # of CUDAGraphNodes. We should not be freeing parameters, not do we need to account for + # their liveness (they are static), so we need to compute which outputs are aliases of + # parameters. Some static inputs are saved tensors from the forward that die in the backward. + # Their locations are static but lifetimes are not. We only include the persistent static + # data ptrs below because the non persistent data ptrs may be outputs of this record and + # fresh allocations. + + # precompute expanded dims to avoid computing in the hot path + self.expanded_dims: List[List[int]] = [ + get_expanded_dims(x) + if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs + else [] + for idx, x in enumerate(inputs) + ] + + # For each node in path, which outputs were observed to be live + # before invoking graph recording, and after graph recording + self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = [] + self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = [] + + # List of Tuples of (depth, output_index) that index into node at depth + # number of nodes from root and output_index of outputs. Will index into + # path_weakrefs. + self.expected_dead_indices_before_graph: List[PathOutputIndex] = [] + self.expected_dead_indices_after_graph: List[PathOutputIndex] = [] + + # all live indices after graph recording + self.live_indices_after_graph: List[PathOutputIndex] = [] + + if self.parent is not None: + previous_liveness = self.parent.recorded_liveness_after_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + different_indices = self._get_different_indices( + previous_liveness, curr_liveness + ) + + self.recorded_liveness_before_graph = curr_liveness + self.expected_dead_indices_before_graph = different_indices + + recording_inputs = self._allocate_and_copy_recording_inputs(inputs) + # recording inputs will copy over memory, so we can free non recording inputs + inputs.clear() + del inputs + + # graph used for recording model invocation + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + + # we allocate non-static inputs within the same memory pool as the CUDAGraph + # which we will record the model with. For memory efficiency, it is important + # to reclaim the input memory when the inputs are no longer live. To accomplish this, + # we reconstruct tensors at the correct data pointers of our inputs which are + # non owning and do not prevent deallocation. On subsequent executions, input values + # will be copied over to these tensors. + self.reconstructed_inputs: List[InputType] = [ + self._reconstruct_from_tensor_metadata(self._tensor_metadata(x)) + if isinstance(x, torch.Tensor) + else x + for x in recording_inputs + ] + + # DO THE RECORDING!!! + # We record the CUDA graph in the constructor of CUDAGraphNode, which + # gives you what the CPU side compute of the function would do. We + # don't throw the recording outputs away: their memory is + # correctly accounted for in the CUDAGraphs caching allocator. This + # means on the very FIRST run of the CUDA graph node, we can directly + # do more recording, because we have a valid caching allocator state. + # NB: This relies on run() being called immediately after the + # constructor, otherwise this optimization would not be valid. + + # initialized below in _record + + self.checkpointed_caching_state: Optional[AllocatorState] = None + + # Output Storage Alias information, can be: + # - A new, unaliased storage, or the output is None + # - An alias of an output of a prior graph + # - An alias of an output already created in the reconstructed outputs + # This is None if the output in question is an int + self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = [] + + # is the output Storage unaliased in subsequent outputs, of all subsequent paths + # if it is, we cached the output tensor and adjust storage liveness tracking to also + # check if the output tensor does not have an additional python reference. + # If a descendent node discovers it has an alias of a prior output, then the output + # will no longer be cached in the ancestor. + # The large majority of tensors are unaliased, and preserving aliased output tensors would add + # significant additional complexity with marginal gains + # The cached tensor outputs are added on the first execution, and cleared whenever we need + # to do subsequent recording + self.unaliased_in_all_paths: OutputList[bool] = [] + self.cached_tensor_outputs: OutputList[Optional[Tensor]] = [] + + # if an output aliases a static, persistent input then the corresponding Tensor will + # be set here. These are different than cached tensors, because they are tensors that + # are aliases of parameters that are always live. + self.static_output_tensors: OutputList[Optional[Tensor]] = [] + + # Cleared after recording + self.recording_outputs: Optional[OutputType] = self._record( + wrapped_function.model, recording_inputs + ) + self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = [] + + # As with inputs, we do not want to keep the outputs permanently alive because that would prevent + # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata + # needed to reconstruct instead. + assert self.recording_outputs is not None + for out in self.recording_outputs: + if isinstance(out, torch.Tensor): + self.outputs_metadata.append( + self._tensor_metadata(out, ignore_storage_offset=False) + ) + else: + assert isinstance(out, (int, type(None))), type(out) + self.outputs_metadata.append(out) + + self.graph.replay() + + def _copy_inputs_and_remove_from_src( + self, dsts: List[InputType], srcs: List[InputType] + ) -> None: + dst_tensors = [] + src_tensors = [] + for idx in self.non_static_input_idx: + if not isinstance(srcs[idx], torch.Tensor): + continue + expanded_dims = self.expanded_dims[idx] + dst_tensors.append(index_expanded_dims(dsts[idx], expanded_dims)) # type: ignore[arg-type] + src_tensors.append(index_expanded_dims(srcs[idx], expanded_dims)) # type: ignore[arg-type] + srcs[idx] = None # type: ignore[call-overload] + # Fails on empty lists + if dst_tensors: + torch._foreach_copy_(dst_tensors, src_tensors) + + def check_static_inputs_are_stable(self, new_inputs: List[InputType]) -> None: + # avoid checking managed tensor static points since we already checked those in check_invariants + if ( + not self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + new_inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + ) + ): + # this should error + error_msg = log_data_ptr_mismatch( + self.wrapped_function.placeholders, + new_inputs, + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + CheckInvariantStatus.StaticInputIdxMismatch, + ) + torch._check(False, lambda: error_msg) + + def run_first_inputs(self, new_inputs: List[InputType]) -> OutputType: + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_before_invocation() + + # graph is already invoked in the __init__ + # inputs are copied over in _allocate_recording_inputs and subsequently cleared + assert len(new_inputs) == 0 + outputs = self.recording_outputs + self.recording_outputs = None + assert outputs is not None + return outputs + + def run(self, new_inputs: List[InputType]) -> OutputType: + self.check_static_inputs_are_stable(new_inputs) + + self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs) + new_inputs.clear() + + self.run_graph() + + outputs = self.reconstruct_outputs() + + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_after_invocation() + + if config.triton.force_cudagraph_sync: + torch.cuda.synchronize() + + # Reset this to run the check in the future + self.static_inputs_stable = False + + return outputs + + def reconstruct_outputs(self) -> OutputType: + "Reconstruct output tensors according to their saved metadata and alias information" + + # Cached tensors will not yet be set on the first execution + # They are also cleared in checkpointing, so if we checkpoint this node + # and then execute it again we will need to repopulate cached tensors + if not self.cached_tensor_outputs: + self._initialize_cached_tensors() + + outputs: OutputType = [] + + for i, (storage_info, metadata) in enumerate( + zip(self.output_storage_alias, self.outputs_metadata) + ): + if not isinstance(metadata, dict): # tensor metadata + assert isinstance(metadata, (int, type(None))) + outputs.append(metadata) + continue + + cached_t = self.cached_tensor_outputs[i] + if cached_t is not None: + # this output represents a fresh allocated tensor. + # We return the same TensorImpl from run to run to avoid overhead. + # autograd.Function will reset the Autograd meta of output tensors + # as part of aot_autograd, but _backward_hooks are stored on tensors separately, + # so we need to manually reset hooks. + if cached_t._backward_hooks is not None: + cached_t._backward_hooks = None + + # No need to update weakrefs, already correctly initialized + outputs.append(cached_t) + continue + + static_t = self.static_output_tensors[i] + if static_t is not None: + assert self.outputs_weakrefs[i] is None + outputs.append(static_t) + continue + + storage = self.prepare_alias_info_for_tensor_construction( + storage_info, metadata + ) + + if isinstance(storage, UntypedStorage) or storage is None: + out = self._reconstruct_from_tensor_metadata(metadata, storage) + else: + assert isinstance(storage, int) + out = self._reconstruct_from_tensor_metadata( + metadata, cast(torch.Tensor, outputs[storage]).untyped_storage() + ) + + outputs.append(out) + w = self.outputs_weakrefs[i] + assert w is not None + w.swap_weakref(out.untyped_storage()._weak_ref()) + + return outputs + + def prepare_alias_info_for_tensor_construction( + self, + out_alias_info: Optional[OutputAliasInfo], + metadata: Union[Dict[str, Any], int, None], + ) -> Union[UntypedStorage, None, int]: + if ( + isinstance(metadata, (int, type(None))) + or out_alias_info is UnaliasedStorage + ): + return None + + if isinstance(out_alias_info, AliasesPriorGraphOutput): + depth, existing_output_index = out_alias_info.index + ref = self.path_weakrefs[depth][existing_output_index] + assert ref is not None + return torch.UntypedStorage._new_with_weak_ptr(ref()) + + assert isinstance(out_alias_info, AliasesNewOutput) + return out_alias_info.index + + def prepare_storages_for_construction( + self, + ) -> List[Union[UntypedStorage, None, int]]: + output_storages = [] + for output_storage_alias, metadata in zip( + self.output_storage_alias, self.outputs_metadata + ): + output_storages.append( + self.prepare_alias_info_for_tensor_construction( + output_storage_alias, metadata + ) + ) + + return output_storages + + def run_graph(self) -> None: + assert self.graph is not None + self.graph.replay() + + def all_outputs_are_dead(self) -> bool: + "All outputs of the path from this node to its root are dead" + for depth, output_index in self.live_indices_after_graph: + if is_live(self.path_weakrefs[depth][output_index]): + return False + return True + + def _record(self, model: ModelType, inputs: List[InputType]) -> OutputType: + "Record the model" + + def static_input_iter() -> Generator[torch.Tensor, None, None]: + for i in self.wrapped_function.static_input_idxs: + _inp = inputs[i] + if isinstance( + _inp, torch.Tensor + ) and not self._is_cuda_graph_recorded_tensor(_inp): + yield _inp + + # see: output_is_alias_of_persistent_static_inputs above + static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = { + inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp) + for inp in itertools.chain( + static_input_iter(), self.wrapped_function.constants + ) + } + + if config.triton.slow_path_cudagraph_asserts: + # need to use parent live weakrefs because live_indices isnt set yet + memory = ( + [] if self.parent is None else list(self.parent.path_live_weakrefs()) + ) + memory += [ + StorageWeakRefWrapper(elem) + for i, elem in enumerate(inputs) + if isinstance(elem, torch.Tensor) + and i not in self.wrapped_function.static_input_idxs + and elem.untyped_storage().data_ptr() != 0 + ] + check_memory_pool(self.device, self.cuda_graphs_pool, memory) + + with preserve_rng_state(), torch.cuda.device( + self.device + ), clear_cublas_manager(), torch.cuda.graph( + self.graph, + stream=self.stream, + pool=self.cuda_graphs_pool, + capture_error_mode="thread_local", + ), get_history_recording(): + static_outputs = model(inputs) + + # running model should reclaim memory + assert len(inputs) == 0 + + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs) + + return static_outputs + + def _add_first_outputs( + self, + outputs: OutputType, + static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper], + ) -> None: + "Add the outputs from the first invocation of the node and set up metadata" + + # getting liveness before we have added the outputs to path, so the length + # of the two lists is equal + prev_liveness = self.recorded_liveness_before_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + delta = self._get_different_indices(prev_liveness, curr_liveness) + self.expected_dead_indices_after_graph = delta + + assert len(self.outputs_weakrefs) == 0 + # index from data pointer to index in outputs + output_new_storages_index: Dict[StorageDataPtr, int] = {} + + self.unaliased_in_all_paths = [False for _ in range(len(outputs))] + self.static_output_tensors = [None for _ in range(len(outputs))] + + for i, o in enumerate(outputs): + if o is None or not isinstance(o, torch.Tensor): + self.output_storage_alias.append(UnaliasedStorage) + continue + + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" + ), + ), + + ref = static_input_persistent_storage_ptrs.get( + o.untyped_storage().data_ptr(), None + ) + # also treat empty storages as static outputs because we do not need to manage their lifetime + # and they should not participate in checkpointing + is_empty_storage = o.untyped_storage().data_ptr() == 0 + if (ref and ref() is not None) or is_empty_storage: + self.output_storage_alias.append(None) + self.static_output_tensors[i] = o + continue + + path_ref = self._is_alias_of_live_recorded_tensor(o) + if path_ref is not None: + self._mark_prior_graph_output_as_aliased(path_ref) + self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref)) + continue + + if o.untyped_storage().data_ptr() in output_new_storages_index: + index = output_new_storages_index[o.untyped_storage().data_ptr()] + self.unaliased_in_all_paths[index] = False + self.output_storage_alias.append(AliasesNewOutput(index)) + continue + + output_new_storages_index[o.untyped_storage().data_ptr()] = i + self.output_storage_alias.append(UnaliasedStorage) + self.unaliased_in_all_paths[i] = True + + if self.stack_traces is None: + self.stack_traces = [None for _ in range(len(outputs))] + else: + assert len(self.stack_traces) == len( + outputs + ), "Wrong number of stack traces passed in" + + assert not self.outputs_weakrefs + for out, static_output_tensor in zip(outputs, self.static_output_tensors): + if not isinstance(out, torch.Tensor) or static_output_tensor is not None: + self.outputs_weakrefs.append(None) + self.tensor_weakrefs.append(None) + else: + self.outputs_weakrefs.append(StorageWeakRefWrapper(out)) + self.tensor_weakrefs.append(TensorWeakRef(out)) + + self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs) + self.checkpointed_caching_state = torch._C._cuda_getCheckpointState( + self.device, self.cuda_graphs_pool + ) + + # now, get liveness with outputs added + for depth in range(len(self.path_weakrefs)): + for output_index in range(len(self.path_weakrefs[depth])): + if is_live(self.path_weakrefs[depth][output_index]): + self.live_indices_after_graph.append((depth, output_index)) + + self.debug_check_invariants_after_invocation() + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs()) + ) + + def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex) -> None: + "Remove a graph output from the unaliased, cached tensors in an ancestor node" + depth, output_index = index + node = list(self._path_from_root)[depth] + node.unaliased_in_all_paths[output_index] = False + x = self.path_weakrefs[depth][output_index] + assert x is not None + x.remove_extra_reference() + + def _initialize_cached_tensors(self) -> None: + # we should not be clearing output_weakrefs, and they should be set in the first + # record run + assert len(self.outputs_weakrefs) == len(self.outputs_metadata) + + for i, (storage_info, metadata, make_cached) in enumerate( + zip( + self.output_storage_alias, + self.outputs_metadata, + self.unaliased_in_all_paths, + ) + ): + if not make_cached: + self.cached_tensor_outputs.append(None) + continue + + assert storage_info is UnaliasedStorage + assert isinstance(metadata, dict) + s = self.create_storage(metadata) + out = self._reconstruct_from_tensor_metadata(metadata, storage=s) # type: ignore[arg-type] + + # XXX: let autograd know that there will be an additional reference to the tensor + # that can be ignored when deciding whether to do gradient buffer inplacing. + # Otherwise, inplacing could differ between tracing and subsequent execution. + # For some models we tested this led to inputs no longer being in cudagraph pools, + # leading to spurious re-recordings. + # It also tells AMP cache that even though the tensor impls cannot be cached + # in dtype conversions. + + torch._C._add_cached_tensor(out) + + self_ref = weakref.ref(self) + + # one reference in our array, and calling sys.getrefcount bumps the refcount by one + def check_refcount(i: int) -> bool: + self_loc = self_ref() + if self_loc is None: + return False + return self_loc.get_output_refcount(i) == 2 + + check = functools.partial(check_refcount, i=i) + + self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check) + self.cached_tensor_outputs.append(out) + + def get_output_refcount(self, index: int) -> int: + return sys.getrefcount(self.cached_tensor_outputs[index]) + + @property + def parent(self) -> Optional[CUDAGraphNode]: + "unwraps the weakref to _parent" + return self._parent() if self._parent is not None else None + + @property + def _path_to_root(self) -> Generator[CUDAGraphNode, None, None]: + "Returns all nodes in the path starting at self and ending at root" + node = self + while node: + yield node + node = node.parent # type: ignore[assignment] + + @property + def _path_from_root(self) -> Generator[CUDAGraphNode, None, None]: + "Returns all nodes in the path starting at the root and ending at self" + nodes = reversed(list(self._path_to_root)) + yield from nodes + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool: + "Is this tensor an output of a node in this path" + for output_refs in self.path_weakrefs: + for storage_weak_ref in output_refs: + if storage_weak_ref is None: + continue + # don't need to check liveness of storage since the cuda graph managed + # memory is never released. + data_ptr = storage_weak_ref.data_ptr() + if t.untyped_storage().data_ptr() == data_ptr: + return True + + return False + + def _is_alias_of_live_recorded_tensor( + self, t: torch.Tensor + ) -> Optional[PathOutputIndex]: + for depth, output_refs in enumerate(self.path_weakrefs): + for output_index, storage_ref in enumerate(output_refs): + if (storage_and_ptr := maybe_deref(storage_ref)) is not None: + storage, ptr = storage_and_ptr + if ptr == t.untyped_storage().data_ptr(): + return (depth, output_index) + + return None + + @staticmethod + def _check_liveness( + indices: List[PathOutputIndex], + output_refs: List[List[Optional[StorageWeakRefWrapper]]], + ) -> bool: + "Check that all of the indices specified are dead references" + for depth, output_index in indices: + w = output_refs[depth][output_index] + assert w is not None + if w() is not None: + return False + return True + + def add_child(self, function_id: FunctionID, node: CUDAGraphNode) -> None: + "Adds node as a a child of self" + self.children[function_id].append(node) + + @staticmethod + def _get_different_indices( + prev: List[List[bool]], curr: List[List[bool]] + ) -> List[PathOutputIndex]: + "Find indices where the two lists differ." + dead_indices = [] + assert len(prev) <= len(curr) + for i, (outputs1, outputs2) in enumerate(zip(prev, curr)): + assert len(outputs1) == len(outputs2) + for j, (output1, output2) in enumerate(zip(outputs1, outputs2)): + if output1 != output2: + dead_indices.append((i, j)) + + return dead_indices + + @staticmethod + def _get_liveness( + weakrefs: List[List[Optional[StorageWeakRefWrapper]]], + ) -> List[List[bool]]: + "Maps weakrefs to true if the reference is alive and false otherwise" + if len(weakrefs) == 0: + return [] + + return [pytree.tree_map(is_live, outputs) for outputs in weakrefs] + + def debug_assert_invariants( + self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex] + ) -> None: + if not config.triton.fast_path_cudagraph_asserts: + return + + for i, node in enumerate(self._path_from_root): + assert self.path_weakrefs[i] is node.outputs_weakrefs + + nodes = list(self._path_from_root) + + live_blocks = get_block_addrs(self.cuda_graphs_pool) + + live_storage_data_ptrs = set() + live_storage_weak_ptrs = set() + + for depth, outputs_liveness in enumerate(expected_liveness): + for output_idx, output_liveness in enumerate(outputs_liveness): + # tensor can die early, but it can't be alive when it should be dead + w = self.path_weakrefs[depth][output_idx] + if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None: + assert output_liveness + stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr + assert (stor_data_ptr in live_storage_data_ptrs) == ( + stor_weak_ptr in live_storage_weak_ptrs + ) + live_storage_data_ptrs.add(stor_data_ptr) + live_storage_weak_ptrs.add(stor_weak_ptr) + + is_persistent_alias = ( + nodes[depth].static_output_tensors[output_idx] is not None + ) + + if is_persistent_alias: + assert stor_data_ptr not in live_blocks + + for depth, output_index in newly_dead: + assert not is_live(self.path_weakrefs[depth][output_index]) + + def debug_check_invariants_before_invocation(self) -> None: + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph + ) + + def debug_check_invariants_after_invocation(self) -> None: + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph + ) + + def data_ptrs_dead_since_invocation(self) -> List[int]: + """ + Since this node was invoked, return data ptrs of all tensor outputs that have died + in the current executing tree path. + """ + curr_liveness = self._get_liveness(self.path_weakrefs) + _get_different_indices = self._get_different_indices( + self.recorded_liveness_after_graph, curr_liveness + ) + + path = list(self._path_from_root) + ptrs_to_deallocate = [] + for depth, output_index in _get_different_indices: + ptrs_to_deallocate.append( + path[depth].outputs_metadata[output_index]["data_ptr"] # type: ignore[index] + ) + + return ptrs_to_deallocate + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + for i, j in self.live_indices_after_graph: + out = self.path_weakrefs[i][j] + if out is not None and is_live(out): + yield out + + def remove_node_cached_tensors(self) -> None: + for t in self.cached_tensor_outputs: + if t is not None: + torch._C._remove_cached_tensor(t) + self.cached_tensor_outputs.clear() + + for i, unaliased in enumerate(self.unaliased_in_all_paths): + if unaliased: + n = self.outputs_weakrefs[i] + assert n is not None + n.remove_extra_reference() + + def remove_path_cached_tensors(self) -> None: + for node in self._path_from_root: + node.remove_node_cached_tensors() + + def clear_path_state(self) -> None: + "Clear the path state in this current executing node" + # this doesnt actually do anything right now, leaving it as placeholder + + @staticmethod + def _tensor_metadata( + x: torch.Tensor, ignore_storage_offset: bool = True + ) -> Dict[str, Any]: + assert isinstance(x, torch.Tensor) + # We ignore the storage offset for inputs, but not for outputs + # TODO: - should we make the storage resizable ? + return { + "nbytes": x.untyped_storage().nbytes(), + "data_ptr": x.untyped_storage().data_ptr(), + "size": x.shape, + "stride": x.stride(), + "dtype": x.dtype, + "device": x.device, + "storage_offset": x.storage_offset() if not ignore_storage_offset else 0, + } + + def _reconstruct_from_tensor_metadata( + self, metadata: Dict[str, Any], storage: Optional[UntypedStorage] = None + ) -> Tensor: + s = self.create_storage(metadata) if storage is None else storage + return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) # type: ignore[arg-type] + + def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage: + return torch._C._construct_storage_from_data_pointer( + metadata["data_ptr"], metadata["device"], metadata["nbytes"] + ) + + def _allocate_and_copy_recording_inputs( + self, inputs: List[InputType] + ) -> List[Union[torch.Tensor, int]]: + """ + Allocate inputs for non static, non cudagraph managed tensors in the memory pool + and copy over the tensor values. + """ + + torch.cuda.synchronize() + self.stream.wait_stream(torch.cuda.current_stream()) + recording_inputs: List[InputType] = [] + + with warnings.catch_warnings(record=True), torch.cuda.device( + self.device + ), _use_cuda_memory_pool_manager( + self.device, + mem_pool=self.cuda_graphs_pool, + stream=self.stream, + ): + for i, inp in enumerate(inputs): + if not isinstance(inp, torch.Tensor): + assert isinstance(inp, int) + recording_inputs.append(inp) + elif i not in self.static_input_idxs: + # static_input does an allocation! + recording_inputs.append(static_input(inp)) + else: + recording_inputs.append(inp) + + self._copy_inputs_and_remove_from_src(recording_inputs, inputs) + + return recording_inputs + + def check_invariants( + self, inputs: List[InputType] + ) -> Tuple[CheckInvariantStatus, Callable[..., str]]: + """ + Checks if this node can be run. The same pattern of tensor liveness, static inputs, + and tensors managed in the cudagraph private pool must remain stable. + """ + + _logger = functools.partial( + log_data_ptr_mismatch, + self.wrapped_function.placeholders, + inputs, + self.static_input_data_ptrs, + ) + + # previously managed data pointers remain stable + # this is on the hot path so moved to C++. equivalent to: + # return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs)) + if not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.cudagraph_managed_idxs, + ): + status = CheckInvariantStatus.CudagraphManagedIdxMismatch + _logger = functools.partial( + _logger, + self.cudagraph_managed_idxs, + status, + ) + return status, _logger + + if not self._check_liveness( + self.expected_dead_indices_before_graph, self.path_weakrefs + ): + status = CheckInvariantStatus.ExpectedDeadIndicesBeforeGraphMismatch + return status, lambda: f"{status}" + + # static input data pointers should remain stable + # if we are inlining builtin nn modules we re-record in this case + # if we are not inlining builtin nn modules, we check this in check_static_inputs_are_stable + # and error if they are not stable + if ( + self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.static_input_idxs, + ) + ): + status = CheckInvariantStatus.StaticInputIdxMismatch + _logger = functools.partial( + _logger, + self.static_input_idxs, + status, + ) + return status, _logger + + # the cudagraph managed tensors which died upon recording must also die upon + # this invocation. it is too late to check after we've replayed the graph, + # because we would have already written over their memory. + for idx in self.cudagraph_managed_idxs: + inputs[idx] = None # type: ignore[call-overload] + + torch._check( + self._check_liveness( + self.expected_dead_indices_after_graph, self.path_weakrefs + ), + lambda: "TODO: graph recording observed an input tensor deallocate during graph " + " recording that did not occur during replay. Please file an issue.", + ) + return CheckInvariantStatus.SUCCESS, lambda: f"{CheckInvariantStatus.SUCCESS}" + + def num_descendants(self) -> int: + "Total number of descendents of this node" + num_desc = 0 + for children in self.children.values(): + for child in children: + num_desc += 1 + num_desc += child.num_descendants() + return num_desc + + +def get_cudagraph_segments(pool_id: Tuple[int, int]) -> Any: + segments = torch.cuda.memory_snapshot() + return [segment for segment in segments if segment["segment_pool_id"] == pool_id] + + +def get_block_addrs(pool_id: Tuple[int, int], live_only: bool = True) -> List[int]: + blocks = [] + + for segment in get_cudagraph_segments(pool_id): + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated" or not live_only: + blocks.append(addr) + + addr += block["size"] + + return blocks + + +def format_tb(frames: List[Any]) -> str: + formatted_traceback = [] + + for entry in frames: + formatted_traceback.append( + traceback.FrameSummary(entry["filename"], entry["line"], entry["name"]) + ) + + return "".join(traceback.format_list(formatted_traceback)) + + +def check_memory_pool( + device: int, + pool_id: Tuple[int, int], + live_storages_ptrs: List[StorageWeakRefWrapper], +) -> None: + assert all( + isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs + ) # noqa: C419 + unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} + + # check if there is a divergence first, then do the expensive snapshot call after + # we know it will error + if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages): + return + + # at this point we are past the fast-path. we have seen rare cases where a dead tensor is dead, + # but hasn't been gc'd yet, and gives false positive for allocated_not_in_live_storages + gc.collect() + + segments = get_cudagraph_segments(pool_id) + + allocated_not_in_live_storages = {} + + for segment in segments: + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated": + if addr not in unique_storages: + allocated_not_in_live_storages[addr] = block + else: + unique_storages.remove(addr) + + addr += block["size"] + + torch._check( + len(unique_storages) == 0, + lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}", + ) + + if len(allocated_not_in_live_storages) != 0: + formatted = [] + for dp, block in allocated_not_in_live_storages.items(): + trace = format_tb(block.get("frames", [])) + formatted.append(f"Data Pointer: {dp}, history: \n{trace}") + formatted_s = "\n".join(formatted) + msg = ( + f"These live storage data ptrs are in the cudagraph pool but not " + f"accounted for as an output of cudagraph trees: \n\n{formatted_s}" + ) + raise RuntimeError(msg) + + +class ExecutionState(Enum): + """ + Represents the state of the CUDAGraph Tree. Will be None if there is no live current memory allocated + in the cuda graph pool. Otherwise will reflect the state of the most recently executed node. + """ + + NONE = auto() + WARMUP = auto() + RECORDING = auto() + EXECUTION = auto() + + +class CompilationMode(Enum): + FORWARD = auto() + BACKWARD = auto() + INFERENCE = auto() + + +class CUDAGraphTreeManager: + """ + Groups individual recordings or executions of cuda graphs into a tree of recordings, + and checks required invariants, and manages warmups of graphs. + + When graphs are recorded in the same tree, it enforces subsequent execution + to follow the same order and have the same output tensor livespans. To remove + unnecessary coupling of cuda graphs (and additional imposed invariants), + the tree manager will end a currently recording tree whenever it is valid - when + the memory pool no longer has any live allocations. + + We ignore outputs from a previous generation that correspond to prior model outputs. + Currently this is hardcoded `GenerationTracker.generation` tracked in torch dynamo. + # TODO: make generation increment configurable, warn on overwrite. + + We run graph warmups in the cudagraph memory pool and return the result on the first invocation + of a function. For many models it is important to reclaim activations as you run the backward. + If we were to warm up the model and keep an extra copy of the inputs around to subsequently + use for recording, we would incur a memory penalty. Additionally, if we are part way through training + your model and need to recompile, memory will be allocated to the cuda graph pool, so we run this + warmup run in the cuda graph memory pool. As for recording, warm up needs the state of live tensors + to be accurately reflected so we checkpoint the allocator state if we need to warm up following graph + replay. + """ + + def __init__(self, device_index: int) -> None: + # roots are functions which have no dependencies on an other node. I.e., + # when they are first invoked, none of their inputs are outputs are outputs + # of another node, nor are there any live outputs of another node whose + # liveness would create a dependency. + self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) + + # mapping from function id to wrapped function + self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {} + + self.ids_to_stack_traces: Dict[FunctionID, Optional[StackTraces]] = {} + + self.warmed_up_functions: Set[FunctionID] = set() + # if we fail to increment generation, and are stuck warming up, + # only warn on each function once + self.warned_functions: Set[FunctionID] = set() + torch._C._set_cached_tensors_enabled(True) + + # warn only once if a function mutates inputs + self.warned_mutation: Set[FunctionID] = set() + + # NB: cuda caching allocator will remember the stream a segment is allocated to + # and only allocate that segment to the same stream. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be reused; separate recordings would have use the same memory pool, but not + # the same memory. + + with torch.cuda.device(device_index): + torch.cuda.synchronize() + self.stream = torch.cuda.Stream() + self.stream.wait_stream(torch.cuda.current_stream()) + + # Keeps Memory Pool Alive + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle() + + with warnings.catch_warnings(record=True), torch.cuda.graph( + self.graph, + pool=self.cuda_graphs_thread_pool, + stream=self.stream, + capture_error_mode="thread_local", + ): + pass + + self.graph_counter = itertools.count(0) + self.func_counter = itertools.count(0) + + # mapping from graph_id to (function id to mutation type hint) since we are + # specializing on a particular combination of Parent Node -> Function ID. + self.non_cudagraph_managed_mutation_hint: Dict[ + Optional[GraphID], Dict[FunctionID, bool] + ] = defaultdict(dict) + self.warmup_node_counter = itertools.count(start=-1, step=-1) + + # mapping from graph_id to (function id to re-record count). We fall back to + # eager function if a function is re-recorded frequently on a node. + self.num_rerecord: Dict[Optional[GraphID], Dict[FunctionID, int]] = defaultdict( + lambda: defaultdict(lambda: 0) + ) + + # whether we the current node is in a state of warmup, recording, execution. If + # there is no current node the state will be ExecutionState.None. + self.path_state = ExecutionState.NONE + self.device_index = device_index + + # the most recently invoked cudagraph wrapping of a function. Will be None + # when there is no output from a previous recording or execution whose memory + # we need to respect in the cuda caching allocation. If you incremented generation, + # this will also be none, as ignore those allocations. + self.current_node: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = None + + # current generation of cudagraph invocations. when torch.compile is run + # we increment the current generation. are willing to ignore live outputs + # of a previous generation in checking liveness. + self.current_gen: int = -1 + + # number of instances we are in execution and failed to match to an + # existing child + self.debug_fail_counter = 0 + # number of instances we had to checkpoint the function + self.debug_checkpointing_counter = 0 + + self.id_to_mode: Dict[FunctionID, CompilationMode] = {} + + # Note: [Backward Generation Handling] + # We generally perform a sequence of forward executions followed by backward executions. + # If multiple torch.compile wrapped forwards are executed with their backwards pending, + # we should not disregard the outputs from a prior torch.compile since the entire training + # loop hasn't completed. Occasionally, a backward pass corresponding to a forward pass may + # not be executed, so we cannot wait for all pending forward pass backward completions, so + # we cannot wait for all backwards to have been invoked. Instead we wait for a single backward + # invocation. Triggering a backward pass typically doesn't lead to another torch.compile + # invocation, making it less likely for the generation to increase between multiple + # backward calls. The following use case is covered by this approach: + # mod1 = torch.compile(...) + # mod2 = torch.compile(...) + # mod2(mod1(x)).sum().backward() + + self.running_forwards_with_pending_backwards = False + + def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType: + assert self.graph is not None, "Running CUDAGraph after shutdown" + out = self._run(new_inputs, function_id) + + # The forwards are only pending following invocation, not before + mode = self.id_to_mode[function_id] + if mode == CompilationMode.FORWARD: + self.running_forwards_with_pending_backwards = True + elif mode == CompilationMode.BACKWARD: + self.running_forwards_with_pending_backwards = False + + return out + + def set_to_running_backward(self) -> None: + self.running_forwards_with_pending_backwards = False + + def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]: + return ( + self.current_node._is_cuda_graph_recorded_tensor + if isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)) + else lambda _: False + ) + + def new_warmup_node_id(self) -> GraphID: + return GraphID(next(self.warmup_node_counter)) + + def _update_non_cudagraph_managed_mutation( + self, function_id: FunctionID, inputs: List[InputType] + ) -> None: + node_id = self._get_node_id() + if maybe_mutation_str := check_for_mutation( + self.ids_to_funcs[function_id], + inputs, + self._get_cuda_graph_recorded_tensor_checker(), + ): + self.non_cudagraph_managed_mutation_hint[node_id][function_id] = True + # warn once per function_id + if function_id in self.warned_mutation: + return + self.warned_mutation.add(function_id) + log_cudagraph_skip_and_bump_counter(maybe_mutation_str) + else: + self.non_cudagraph_managed_mutation_hint[node_id][function_id] = False + + def _get_node_id(self) -> Optional[GraphID]: + if self.current_node is None: + return None + elif isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)): + return self.current_node.id + else: + raise RuntimeError(f"Unknown node type {type(self.current_node)}") + + def exceed_rerecord_limit( + self, node_id: Optional[GraphID], function_id: FunctionID + ) -> bool: + if torch._dynamo.config.inline_inbuilt_nn_modules: + return False + + return ( + self.num_rerecord[node_id][function_id] + > torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit + ) + + def _run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType: + # we will try to end the current execution lazily, since + # we dont want to do unnecessary checking of the existing outputs + # on the hot path, but both recording and warmup only happen once + # so we check up front + if self.in_recording: + self.try_end_curr_recording(function_id) + + if self.in_warmup: + self.try_end_curr_warmup(function_id) + + node_id = self._get_node_id() + if function_id not in self.non_cudagraph_managed_mutation_hint[node_id]: + self._update_non_cudagraph_managed_mutation(function_id, new_inputs) + + # Early exit if the function mutates inputs which are neither parameters/buffers nor + # cudagraph recorded tensors. This check should happen after `try_end_curr_recording` + # and `try_end_curr_warmup` which may change self.current_node. + if self.non_cudagraph_managed_mutation_hint[node_id][ + function_id + ] or self.exceed_rerecord_limit(node_id, function_id): + return self.ids_to_funcs[function_id].model(new_inputs) + + # warming up a function and subsequentally recording may use different memory addresses + # because both depend on the state of the caching allocator. if we warm up graph A, + # then warm up graph B and make more allocations, the subsequent recording of A will not + # necessarily use the same addresses as in the warm up. Thus any warm up of a node can only + # be followed by warm up runs. + if ( + ( + not ( + function_id in self.warmed_up_functions + or config.triton.skip_cudagraph_warmup + ) + ) + or self.in_warmup + or config.triton.force_cudagraphs_warmup + ): + # If we are in the middle of executing cuda graphs, then we need to checkpoint memory state. + # Both Recording and Warmup will be reflected in the allocator and dont need changes + if self.path_state == ExecutionState.EXECUTION: + self.apply_checkpoint_execution_state_in_allocator() + + return self.run_eager(new_inputs, function_id) + + assert not isinstance(self.current_node, CUDAWarmupNode) + child_nodes = ( + self.roots if self.current_node is None else self.current_node.children + ) + + if not self.in_recording: + unexpected_rerecord, unexpected_rerecord_reason = False, lambda: "" + for child in child_nodes[function_id]: + # here we are checking memory consistency between recording and execution, + # as well as things like stability of tensor locations, etc + # and other + status, status_logger = child.check_invariants(new_inputs) + if status == CheckInvariantStatus.SUCCESS: + return self.execute_node(child, new_inputs) + + if ( + status == CheckInvariantStatus.StaticInputIdxMismatch + or status == CheckInvariantStatus.CudagraphManagedIdxMismatch + ): + unexpected_rerecord = True + unexpected_rerecord_reason = status_logger + + # now that we know the new function can't be run as a child of the + # current node, if it is a root, try to end the current execution. + # as noted above, we want to do this lazily to avoid having to + # check all existing outputs + if self.current_node is not None and function_id in self.roots: + self.try_end_curr_execution() + + # run again to hit the root matching case which must succeed + if self.current_node is None: + return self.run(new_inputs, function_id) + + if len(self.ids_to_funcs[function_id].mutated_input_idxs) > 0: + self._update_non_cudagraph_managed_mutation(function_id, new_inputs) + if self.non_cudagraph_managed_mutation_hint[self._get_node_id()][ + function_id + ]: + return self.ids_to_funcs[function_id].model(new_inputs) + + # nb: run before checkpointing because checkpointing is slow, and we will + # be using the eager caching allocator pool which does not require live + # accounting of tensors in cudagraph allocator + if unexpected_rerecord: + curr_node_id = self._get_node_id() + self.num_rerecord[curr_node_id][function_id] += 1 + if self.exceed_rerecord_limit(curr_node_id, function_id): + _id = curr_node_id.id if curr_node_id else None + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraph due to function {function_id.id} exceeding max " + f"re-recording limit " + f"(={torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit}) " + f"on cudagraph node {_id} due to {unexpected_rerecord_reason()}." + ) + return self.ids_to_funcs[function_id].model(new_inputs) + + # at this point, we necessarily will do a new recording + self.debug_fail_counter += 1 + + self.try_end_curr_execution() + if self.current_node is not None: + self.apply_checkpoint_execution_state_in_allocator() + + # now, we are in a recording state ! + return self.record_function(new_inputs, function_id) + + def shutdown(self) -> None: + """ + Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn + might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown + to avoid a reference cycle. + """ + nodes = [] + for roots in self.roots.values(): + nodes.extend(roots) + + while nodes: + node = nodes.pop() + for children in node.children.values(): + nodes.extend(children) + node.remove_node_cached_tensors() + node.graph = None + + self.graph = None + self.roots = None # type: ignore[assignment] + self.current_node = None + + def record_function( + self, new_inputs: List[InputType], function_id: FunctionID + ) -> OutputType: + assert not isinstance(self.current_node, CUDAWarmupNode) + graph_id = self.new_graph_id() + log.debug( + "Recording function %d of graph recording id %d", + function_id.id, + graph_id.id, + ) + torch.cuda.synchronize() + node = CUDAGraphNode( + self.ids_to_funcs[function_id], + graph_id, + self.current_node, + new_inputs, + self.cuda_graphs_thread_pool, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + ) + if self.current_node is None: + self.roots[function_id].append(node) + else: + self.current_node.add_child(function_id, node) + self.current_node = node + self.path_state = ExecutionState.RECORDING + self.update_generation() + torch.cuda.synchronize() + return node.run_first_inputs(new_inputs) + + def execute_node( + self, node: CUDAGraphNode, new_inputs: List[InputType] + ) -> OutputType: + self.current_node = node + self.path_state = ExecutionState.EXECUTION + self.update_generation() + return node.run(new_inputs) + + def run_eager( + self, new_inputs: List[InputType], function_id: FunctionID + ) -> OutputType: + # this is only stored on current node, because when we start a new path, + # we will deallocate it + already_warm = function_id in self.warmed_up_functions + if not already_warm: + log.debug("Running warmup of function %d", function_id.id) + else: + log.debug( + "Running eager of function %d because ancestor needed to warm up", + function_id.id, + ) + self.warmed_up_functions.add(function_id) + node = CUDAWarmupNode( + self.ids_to_funcs[function_id], + self.current_node, + self.cuda_graphs_thread_pool, + self.graph, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + already_warm, + self.new_warmup_node_id(), + ) + self.current_node = node + self.path_state = ExecutionState.WARMUP + self.update_generation() + return node.run(new_inputs) + + def new_graph_id(self) -> GraphID: + return GraphID(next(self.graph_counter)) + + def new_func_id(self) -> FunctionID: + return FunctionID(next(self.func_counter)) + + def add_function( + self, + model: ModelType, + inputs: List[InputType], + static_input_idxs: Sequence[int], + stack_traces: Optional[StackTraces], + mode: CompilationMode, + constants: Tuple[torch.Tensor, ...], + placeholders: Tuple[PlaceholderInfo, ...], + mutated_input_idxs: Tuple[int, ...], + ) -> Tuple[ModelType, OutputType,]: + id = self.new_func_id() + self.ids_to_stack_traces[id] = stack_traces + self.ids_to_funcs[id] = WrappedFunction( + model, + list(static_input_idxs), + id, + tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda), + placeholders, + mutated_input_idxs, + ) + self.id_to_mode[id] = mode + fn = functools.partial(self.run, function_id=id) + + # container needs to set clean up when fn dies + get_container(self.device_index).add_strong_reference(fn) + return fn, fn(inputs) + + @property + def in_recording(self) -> bool: + return self.path_state == ExecutionState.RECORDING + + @property + def in_warmup(self) -> bool: + return self.path_state == ExecutionState.WARMUP + + def get_roots(self) -> Iterator[CUDAGraphNode]: + for nodes in self.roots.values(): + yield from nodes + + @property + def current_node(self) -> Optional[Union[CUDAGraphNode, CUDAWarmupNode]]: + return self._current_node + + @current_node.setter + def current_node( + self, value: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] + ) -> None: + self._current_node = value + if value is None: + self.path_state = ExecutionState.NONE + + def update_generation(self) -> None: + self.current_gen = self.get_curr_generation() + + @staticmethod + def get_curr_generation() -> int: + if MarkStepBox.mark_step_counter != 0: + return MarkStepBox.mark_step_counter + + return GenerationTracker.generation + + @staticmethod + def user_invoked_mark_step() -> bool: + return MarkStepBox.mark_step_counter != 0 + + def can_start_new_generation(self) -> bool: + if not self.in_new_torch_compile_invocation(): + return False + + if self.user_invoked_mark_step(): + return True + + return not self.running_forwards_with_pending_backwards + + def in_new_torch_compile_invocation(self) -> bool: + return self.current_gen != self.get_curr_generation() + + def try_end_curr_recording(self, function_id: FunctionID) -> None: + """ + Check if the current recording can be terminated, either because all outputs of the + previously recorded node are dead or because it was executed in a different + generation. Will set current_node to None and in_recording to False if successful. + """ + assert self.in_recording + assert self.current_node is not None + + # multiple invocations, allow overwriting the previous generation + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def try_end_curr_execution(self) -> None: + """ + Check if the current executing node can be terminated, either because all outputs of the + previously executed node are dead or because it was executed in a different generation. + Will set current_node to None if successful. + """ + + assert not self.in_recording + if self.current_node is None: + return + + if self.can_start_new_generation(): + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + + def try_end_curr_warmup(self, function_id: FunctionID) -> None: + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.current_node = None + return + + assert self.current_node is not None + if self.current_node.all_outputs_are_dead(): + self.current_node = None + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> None: + "Warn if we in a potential loop where we are unable to hit fast path" + if ( + function_id in self.warned_functions + or not self.in_new_torch_compile_invocation() + ): + return + + assert self.current_node is not None + existing_nodes = [ + node + for node in self.current_node._path_from_root + if node.wrapped_function.id == function_id + ] + + if len(existing_nodes) <= 1: + return + + # repeated same pattern + parents = { + n.parent.wrapped_function.id + for n in itertools.chain(existing_nodes, (self.current_node,)) + if n.parent is not None + } + if len(parents) == len(existing_nodes): + return + + self.warned_functions.add(function_id) + warnings.warn( + "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. " + "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() " + "before each model invocation" + ) + + def dealloc_current_path_weakrefs(self) -> None: + assert self.current_node is not None + # TODO: we could also allow the these weak refs to continue to be allocated, + # but that adds some complications. + for node in self.current_node._path_from_root: + assert node.stack_traces is not None + assert len(node.tensor_weakrefs) == len(node.stack_traces) + for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces): + ten = None if t is None else t() + if ten is None: + continue + + stack_trace = ( + stack_trace.strip() + if stack_trace + else "[Could not find stack trace]" + ) + msg = ( + "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. " + f"Stack trace: {stack_trace}. " + "To prevent overwriting, clone the tensor outside of torch.compile() " + "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation." + ) + torch._C._set_storage_access_error_msg(ten, msg) + + deleted = set() + for storage_ref in self.current_node.path_live_weakrefs(): + _storage_deref = storage_ref() + if _storage_deref and storage_ref.data_ptr() not in deleted: + deleted.add(storage_ref.data_ptr()) + torch._C._free_And_Remove_DeleterFn(_storage_deref) + + def clear_current_path_state_and_set_to_none(self) -> None: + assert isinstance(self.current_node, CUDAGraphNode) + self.current_node.clear_path_state() + self.current_node = None + + def apply_checkpoint_execution_state_in_allocator(self) -> None: + """ + Checkpoint the current execution state in the caching allocator so that + additional cudagraph recordings can be made respecting existent live storages. + """ + assert isinstance(self.current_node, CUDAGraphNode) + self.debug_checkpointing_counter += 1 + log.debug( + "Checkpointing cuda caching allocator state. Number of checkpoints %d", + self.debug_checkpointing_counter, + ) + + state = self.current_node.checkpointed_caching_state + device = self.current_node.device + assert state is not None and device is not None + + # currently we deallocate on instead of allowing stale recordings + stale_storages: List[int] = [] + + # remove cached tensors, otherwise they would prevent memory from being + # reclaimed in subsequent recordings + self.current_node.remove_path_cached_tensors() + live_storages_wrappers = list(self.current_node.path_live_weakrefs()) + + # path_live_weakrefs guarantees that t() will not be None + live_storages_weak_refs: list[int] = [t() for t in live_storages_wrappers] # type: ignore[misc] + ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation() + torch._C._cuda_setCheckpointPoolState( + device, state, stale_storages, live_storages_weak_refs + ) + + # NB: deduplicate aliased outputs + for ptr in set(ptrs_to_deallocate): + torch._C._cuda_cudaCachingAllocator_raw_delete(ptr) + + # Now the live blocks should be exactly equal to the live storages in private pool + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers + ) + for wrapper in live_storages_wrappers: + storage_ptr = wrapper() + assert storage_ptr is not None + assert torch._C._has_Standard_Deleter(storage_ptr) + assert wrapper.data_ptr() not in ptrs_to_deallocate + + def live_cudagraph_pool_storages_in_curr_execution( + self, + ) -> List[StorageWeakRefPointer]: + if self.current_node is None: + return [] + # explicitly ignoring previous recorded outputs from past path + # path_live_weakrefs() guarantees that t() will not be None + return [t() for t in self.current_node.path_live_weakrefs()] # type: ignore[misc] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4f97e1daf60f623f3baea38f8b60f0466af81385 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py @@ -0,0 +1,330 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import torch +from torch._dynamo.utils import counters +from torch._inductor.utils import InputType + + +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +OutputType = List[Optional[Union[int, torch.Tensor]]] +ModelType = Callable[[List[InputType]], OutputType] + + +@dataclasses.dataclass(frozen=True) +class FunctionID: + "Unique counter of a function wrapped in cudagraphify_impl" + id: int + + +@dataclasses.dataclass(frozen=True) +class PlaceholderInfo: + """ + A serializable version of torch.fx.Node that contains information + pertinent to placeholder stack traces. We use these in logging and error messages + related to cudagraphs, and will cache these results. + """ + + name: str + stack_trace: Optional[str] + # This field is recursive, but never cyclic (since a node never uses itself) + users: List[PlaceholderInfo] + mutating_use_stack_trace: Optional[str] + + +@dataclasses.dataclass(frozen=True) +class WrappedFunction: + """ + Represents a function that you want to record for CUDA graph replay, + with a little more metadata so we can identify if we have an applicable + CUDA graph in our CUDA graph tree for it. + """ + + model: Callable[..., Any] + static_input_idxs: Sequence[int] + id: FunctionID + constants: Tuple[torch.Tensor, ...] + placeholders: Sequence[PlaceholderInfo] + mutated_input_idxs: Sequence[int] + + +def get_mutating_use_stack_trace_from_node( + placeholder_node: torch.fx.Node, +) -> Optional[str]: + # reinplaced uses might have a single, non-copy_ use + if len(placeholder_node.users) == 1: + return next(iter(placeholder_node.users)).meta.get("stack_trace", None) + + for use in placeholder_node.users: + if use.target == torch.ops.aten.copy_.default: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + + return None + + +def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]: + return placeholder_info.mutating_use_stack_trace + + +def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo: + name = placeholder_node.name + stack_trace = placeholder_node.meta.get("stack_trace", None) + users = [] + mutating_use_stack_trace = None + # Only recurse to users once, since we only care about user's stack traces + if placeholder_node.op == "placeholder": + users = [to_placeholder_info(i) for i in placeholder_node.users] + mutating_use_stack_trace = get_mutating_use_stack_trace_from_node( + placeholder_node + ) + + return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace) + + +def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]: + return [ + to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder" + ] + + +def format_default_skip_message(reason: str) -> str: + return f"skipping cudagraphs due to {reason}" + + +def get_mutation_stack_trace( + placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int] +) -> str: + stack_trace: Optional[str] = "" + + for idx in mutation_indices: + placeholder = placeholders[idx] + if stack_trace := get_mutating_use_stack_trace(placeholder): + break + + msg = format_default_skip_message( + f"mutated inputs ({len(mutation_indices)} instances)" + ) + if stack_trace: + return f"{msg}. Found from : \n {stack_trace}" + + return msg + + +def check_for_mutation( + func: WrappedFunction, + inputs: List[InputType], + is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool], +) -> Optional[str]: + # doesnt work for non-trees because the warmup run would apply mutation twice + if torch._inductor.config.triton.cudagraph_trees: + # checking if mutation is only on parameters/static inputs + mutation_indices: Sequence[int] = [ + idx + for idx in func.mutated_input_idxs + if not ( + idx in func.static_input_idxs + or is_cuda_graph_recorded_tensor(inputs[idx]) # type: ignore[arg-type] + ) + ] + else: + mutation_indices = func.mutated_input_idxs + + static_inputs_log.debug( + "check mutation static input indices: %s", func.static_input_idxs + ) + static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices) + + return ( + get_mutation_stack_trace(func.placeholders, mutation_indices) + if mutation_indices + else None + ) + + +def _get_use_stack_trace(node) -> Optional[str]: + for use in node.users: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + return None + + +def check_multiple_devices_or_any_cpu_nodes( + device_node_mapping: Dict[torch.device, torch.fx.Node] +) -> Optional[str]: + if cpu_node := device_node_mapping.get(torch.device("cpu")): + msg = f"cpu device ({cpu_node.name})" + if stack_trace := _get_use_stack_trace(cpu_node): + return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}") + + return format_default_skip_message(msg) + + if ( + len(device_node_mapping) == 1 + and next(iter(device_node_mapping.keys())).type == "cuda" + ): + return None + + keys_repr = (repr(key) for key in device_node_mapping.keys()) + return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}") + + +def check_lowering_disable_cudagraph( + device_node_mapping: Dict[torch.device, torch.fx.Node] +): + return check_multiple_devices_or_any_cpu_nodes(device_node_mapping) + + +def log_cudagraph_skip_and_bump_counter(msg): + perf_hint_log.warning(msg) + counters["inductor"]["cudagraph_skips"] += 1 + + +@dataclasses.dataclass +class BoxedDeviceIndex: + value: Optional[int] + + def set(self, device_idx: Optional[int]): + assert device_idx is None or isinstance(device_idx, int) + self.value = device_idx + + +def check_for_mutation_ignore_cuda_graph_managed_tensor( + gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: Sequence[int] +) -> Optional[str]: + default_msg = format_default_skip_message("mutated inputs") + + # doesnt work for non-trees because the warmup run would apply mutation twice + if torch._inductor.config.triton.cudagraph_trees: + unique_idxs = set(static_input_idxs) + # checking if mutation is only on parameters/static inputs + mutation_indices = [ + idx for idx in compiled_graph.mutated_input_idxs if idx not in unique_idxs + ] + has_mutation = len(mutation_indices) != 0 + if not has_mutation: + return None + placeholders = get_placeholder_info(gm.graph) + return get_mutation_stack_trace(placeholders, mutation_indices) + + else: + has_mutation = len(compiled_graph.mutated_inputs) != 0 + return None if not has_mutation else default_msg + + +def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]: + """ + Gets the first non-empty stack trace of a placeholder or its users. + """ + if placeholder.stack_trace: + return placeholder.stack_trace + + for user in placeholder.users: + if user.stack_trace: + return user.stack_trace + + return None + + +class CheckInvariantStatus(Enum): + # Check invariant succeeded + SUCCESS = 1 + + # Previously managed data pointers are not stable + CudagraphManagedIdxMismatch = 2 + + # Static tensor input addresses are not stable + StaticInputIdxMismatch = 3 + + # Expected dead indices before graph are live + ExpectedDeadIndicesBeforeGraphMismatch = 4 + + def __str__(self) -> str: + if self.name == "CudagraphManagedIdxMismatch": + return "cudagraph managed tensor data pointer changed" + elif self.name == "StaticInputIdxMismatch": + return "static input data pointer changed" + elif self.name == "ExpectedDeadIndicesBeforeGraphMismatch": + return "expected dead indices before graph are live" + else: + return f"{self.name}: {self.value}" + + +def log_data_ptr_mismatch( + placeholders: Sequence[PlaceholderInfo], + inputs: List[InputType], + recorded_data_ptr: Sequence[Optional[int]], + target_idxs: Sequence[int], + mismatch: CheckInvariantStatus, +) -> str: + """ + Logs the mismatch between input data pointers and recorded data pointers. + This checks only idxs in target_idxs. + """ + assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len( + placeholders + ), "length mismatch between inputs, recorded_data_ptr, and placeholders" + + t_tensors = [inputs[i] for i in target_idxs] + t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs] + error_msg = f"{mismatch}.\n" + for i, (tensor, data_ptr) in enumerate(zip(t_tensors, t_data_ptrs)): + assert isinstance(tensor, torch.Tensor) + index = target_idxs[i] + if tensor.data_ptr() != data_ptr: + placeholder = placeholders[index] + error_msg = ( + f"{error_msg}input name: {placeholder.name}. " + f"data pointer changed from {data_ptr} to {tensor.data_ptr()}. " + f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n" + ) + return error_msg + + +def maybe_warning_due_to_dynamic_shape( + fn_cache: Dict[Tuple[int, ...], Callable[..., Any]], + new_int_key: Any, +) -> bool: + num_cudagraphs = len(fn_cache.keys()) + 1 + + def warn_msg(): + return ( + "CUDAGraph supports dynamic shapes by recording a new graph for each " + "distinct input size. Recording too many CUDAGraphs may lead to " + f"extra overhead. We have observed {num_cudagraphs} distinct sizes. " + "Please consider the following options for better performance: " + "a) padding inputs to a few fixed number of shapes; or b) set " + "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. " + "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None " + "to silence this warning." + ) + + if ( + torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit + and num_cudagraphs + > torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit + ): + perf_hint_log.warning(warn_msg()) + return True + + return False + + +@dataclasses.dataclass(frozen=True) +class CudagraphCachedInfo: + """ + Info needed to realign inputs + """ + + placeholders: Sequence[PlaceholderInfo] + stack_traces: List[Optional[str]] + cudagraph_fail_reasons: List[str] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/debug.py b/.venv/lib/python3.11/site-packages/torch/_inductor/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..868833a425be4b13d76dbca5318498ee3b7e5a46 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/debug.py @@ -0,0 +1,693 @@ +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import os +import os.path +import pickle +import pstats +import shutil +import subprocess +from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union +from unittest.mock import patch + +import torch +from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled +from torch import fx as fx +from torch._dynamo.repro.after_aot import save_graph_repro +from torch._dynamo.utils import get_debug_dir +from torch.fx.graph_module import GraphModule +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.fx.passes.tools_common import legalize_graph +from torch.utils._pytree import tree_map + +from . import config, ir # noqa: F811, this is needed +from .scheduler import ( + BaseSchedulerNode, + FusedSchedulerNode, + NopKernelSchedulerNode, + OutputNode, + SchedulerNode, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + +SchedulerNodeList = List[Any] +BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"]) +GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"] + + +@functools.lru_cache(None) +def has_dot() -> bool: + try: + subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE) + return True + except subprocess.SubprocessError: + return False + + +def draw_buffers( + nodes: List[BaseSchedulerNode], + print_graph: bool = False, + fname: Optional[str] = None, +) -> None: + """ + Draw a graph in fname.svg. + """ + if not has_dot(): + log.warning("draw_buffers() requires `graphviz` package") + return + + if fname is None: + fname = get_graph_being_compiled() + + graph = create_fx_from_snodes(nodes) + + for node in graph.nodes: + if "fusion_meta" not in node.meta: + continue + group = node.meta["fusion_meta"].group + if isinstance(group, tuple): + if isinstance(group[1], int): + group = (group[1],) + else: + group = group[1] + + # gather meta data + dtype = None + if isinstance(node, ir.ComputedBuffer): + dtype = node.data.dtype + + metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type] + node.meta["tensor_meta"] = metadata + + if print_graph: + print(graph) + + gm = GraphModule({}, graph) + legalize_graph(gm) + gm.graph.lint() + draw_graph( + gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape + ) + + +def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph: + """ + Creates a FX Graph from a list of SchedulerNode objects. + """ + + def get_fake_func(name: str) -> Callable[..., int]: + def func1(*args: Any) -> int: + return 0 + + func1.__name__ = name + return func1 + + FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"]) + + buf_to_fx_node = {} + node_to_fx_node = {} + graph = torch.fx.Graph() + first_node = None + + outputs = [] + group: Any = None + # create call_function node for each Buffer and Kernel + for snode in snodes: + if snode.is_extern(): + node_type = "extern" + group = node_type + elif snode.is_template(): + node_type = "template" + group = node_type + elif isinstance(snode, NopKernelSchedulerNode): + node_type = "nop" + group = node_type + elif isinstance(snode, SchedulerNode): + node_type = "compute" + group = snode.group + elif isinstance(snode, FusedSchedulerNode): + node_type = "fused" + group = snode.group + else: + raise RuntimeError("Unknown node type") + + fused_name = torch._inductor.utils.get_fused_kernel_name( + snode.get_nodes(), "original_aten" + ) + func_name = f"{node_type}: {fused_name}" + node_func = get_fake_func(func_name) + kwargs = {} + if hasattr(snode, "get_device"): + kwargs = {"device": snode.get_device()} + fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) # type: ignore[arg-type] + + def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: + if isinstance(snode, FusedSchedulerNode): + return any(in_output(x) for x in snode.snodes) + return any( + isinstance(user.node, OutputNode) + for buf in snode.get_outputs() + for user in buf.users + ) + + if in_output(snode): + outputs.append(fx_node) + name = snode.get_name() + fx_node.name = name + + fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type) + + node_to_fx_node[name] = fx_node + for buf in snode.get_outputs(): + buf_to_fx_node[buf.get_name()] = fx_node + + if first_node is None: + first_node = fx_node + + # create edges between nodes + for snode in snodes: + name = snode.get_name() + deps = snode.read_writes.reads + + fx_node = node_to_fx_node[name] + new_args = [] + for dep in deps: + if dep.name in buf_to_fx_node: + dep_node = buf_to_fx_node[dep.name] + else: + with graph.inserting_before(first_node): + dep_node = graph.placeholder(dep.name) + buf_to_fx_node[dep.name] = dep_node + if dep_node == fx_node: # to avoid cycles + continue + new_args.append(dep_node) + + fx_node.args = tuple(new_args) + + graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) + return graph + + +def update_orig_fx_node_name_to_buf_name( + nodes: Optional[SchedulerNodeList], + node_name_to_buf_name: Dict[str, str], + parent_buf_name: Optional[str] = None, + n_origins: int = 0, +) -> None: + if nodes is None: + return + for node in nodes: + # for FusedSchedulerNode, traverse recursively into get_nodes() + buf_name = node.get_name() + children_nodes = node.get_nodes() + if children_nodes is not None and len(children_nodes) > 1: + update_orig_fx_node_name_to_buf_name( + children_nodes, + node_name_to_buf_name, + buf_name if parent_buf_name is None else parent_buf_name, + ) + continue + else: + assert len(children_nodes) == 1 and children_nodes[0] == node + + ir_node = node.node + if ir_node is None or ir_node.origins is None: + continue + for origin in ir_node.origins: + node_name = origin.name + # when buf1 and buf2 both have origin=node1 + # we draw node1 according to buf1 + if node_name not in node_name_to_buf_name: + node_name_to_buf_name[node_name] = ( + buf_name if parent_buf_name is None else parent_buf_name + ) + + +def get_node_name_to_buf_meta( + node_name_to_buf_name: Dict[str, str] +) -> Dict[str, BufMeta]: + buf_name_to_n_node = {} + for node_name, buf_name in node_name_to_buf_name.items(): + if buf_name not in buf_name_to_n_node: + buf_name_to_n_node[buf_name] = {node_name} + else: + buf_name_to_n_node[buf_name].add(node_name) + + node_name_to_buf_meta = {} + for node_name, buf_name in node_name_to_buf_name.items(): + n_node = len(buf_name_to_n_node[buf_name]) + node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node) + return node_name_to_buf_meta + + +def annotate_orig_fx_with_snodes( + gm: torch.fx.GraphModule, + snodes: SchedulerNodeList, +) -> None: + """ + Creates a FX Graph from a list of SchedulerNode objects. + """ + node_name_to_buf_name: Dict[str, str] = {} + update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name) + if node_name_to_buf_name is None: + return + node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name) + for node in gm.graph.nodes: + if node.name in node_name_to_buf_meta: + node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name) + + +@contextlib.contextmanager +def enable_aot_logging() -> Iterator[None]: + compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + import torch._functorch.aot_autograd + + log = logging.getLogger(torch._functorch.aot_autograd.__name__) + + stack = contextlib.ExitStack() + if not compile_debug: + try: + yield + finally: + stack.close() + return + + # Enable all graphs to be logged to a file by setting the flags to True + # and the log level of the file logger to DEBUG + stack.enter_context(patch("functorch.compile.config.debug_partitioner", True)) + + path = os.path.join(get_debug_dir(), "torchinductor") + os.makedirs(path, exist_ok=True) + + fh = logging.FileHandler( + os.path.join( + path, + f"aot_{get_aot_graph_name()}_debug.log", + ) + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(fh) + try: + yield + finally: + log.removeHandler(fh) + stack.close() + + +class DebugContext: + _counter = itertools.count() + + @staticmethod + def create_debug_dir(folder_name: str) -> Optional[str]: + debug_dir = config.trace.debug_dir or get_debug_dir() + for n in DebugContext._counter: + dirname = os.path.join( + debug_dir, + "torchinductor", + f"{folder_name}.{n}", + ) + if not os.path.exists(dirname): + os.makedirs(dirname) + return dirname + return None + + def __init__(self) -> None: + self._prof = None + self._path = None + self._stack = contextlib.ExitStack() + + def copy(self, new_path: str) -> None: + if not self._path: + return + assert new_path.endswith(".debug"), new_path + from filelock import FileLock + + try: + with FileLock(f"{new_path}.lock"): + if os.path.exists(new_path): + shutil.rmtree(new_path) + shutil.copytree(self._path, new_path) + except OSError: + log.warning( + "Failed to copy debug files from %s to %s", self._path, new_path + ) + + def fopen( + self, + filename: str, + write_mode: str = "w", + *args: Any, + **kwargs: Any, + ) -> IO[Any]: + assert self._path + return open(os.path.join(self._path, filename), write_mode, *args, **kwargs) + + @contextlib.contextmanager + def fopen_context( + self, + filename: str, + write_mode: str = "w", + *args: Any, + **kwargs: Any, + ) -> Iterator[IO[Any]]: + assert self._path + with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f: + yield f + + def filename(self, suffix: str) -> str: + assert self._path + return os.path.join(self._path, suffix) + + def upload_tar(self) -> None: + if config.trace.upload_tar is not None: + import tarfile + + assert self._path + tar_file = os.path.join( + self._path, f"{os.path.basename(self._path)}.tar.gz" + ) + with tarfile.open(tar_file, "w:gz") as tar: + tar.add(self._path, arcname=os.path.basename(self._path)) + config.trace.upload_tar(tar_file) + + def __enter__(self) -> None: + if config.debug: + log = logging.getLogger("torch._dynamo") + prev_level = log.level + log.setLevel(logging.DEBUG) + + def reset_log_level(level: Any) -> None: + log.setLevel(level) + + self._stack.callback(reset_log_level, prev_level) + + self._stack.enter_context(V.set_debug_handler(self)) + + if not config.trace.enabled: + return + + self._path = self.create_debug_dir(get_aot_graph_name()) # type: ignore[assignment] + + if config.trace.debug_log: + self._setup_log_capture("debug.log", logging.DEBUG) + if config.trace.info_log: + self._setup_log_capture("info.log", logging.INFO) + + def _setup_log_capture( + self, + filename: str, + level: int, + ) -> None: + log = logging.getLogger("torch._inductor") + fd = self._stack.enter_context(self.fopen(filename)) + ch = logging.StreamHandler(fd) + ch.setLevel(level) + ch.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(ch) + log.setLevel(min(log.level, level)) + self._stack.callback(log.removeHandler, ch) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + if self._prof: + self._prof.disable() + self._save_profile_data() + + if self._path: + self.upload_tar() + log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path) + self._stack.close() + + def _save_profile_data(self) -> None: + assert self._prof + self._prof.dump_stats(self.filename("compile.prof")) + with self.fopen("compile.stats") as fd: + stats = pstats.Stats(self._prof, stream=fd) + stats.strip_dirs() + stats.sort_stats("cumtime") + stats.print_stats(100) + stats.sort_stats("tottime") + stats.print_stats(100) + + def __getattr__(self, name: str) -> Optional[Callable[..., None]]: + if config.trace.enabled and getattr(config.trace, name): + try: + return getattr(DebugFormatter(self), name) + except Exception: + log.warning("Ignoring exception in debug code", exc_info=True) + return None + else: + + def ignored(*args: Any, **kwargs: Any) -> None: + pass + + return ignored + + +class DebugFormatter: + def __init__(self, handler: DebugContext) -> None: + self.fopen = handler.fopen + self.fopen_context = handler.fopen_context + self.filename = handler.filename + self.handler = handler + + def fx_graph( + self, + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + ) -> None: + with self.fopen("fx_graph_runnable.py") as fd: + save_graph_repro(fd, gm, inputs, "inductor") + + with self.fopen("fx_graph_readable.py") as fd: + fd.write(gm.print_readable(print_output=False)) + + def fx_graph_transformed( + self, + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + ) -> None: + with self.fopen("fx_graph_transformed.py") as fd: + fd.write(gm.print_readable(print_output=False)) + + def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None: + self._write_ir("ir_pre_fusion.txt", nodes) + + def ir_post_fusion(self, nodes: SchedulerNodeList) -> None: + self._write_ir("ir_post_fusion.txt", nodes) + + def _write_ir( + self, + filename: str, + nodes: SchedulerNodeList, + ) -> None: + with self.fopen(filename) as fd: + log.info("Writing debug ir to %s", fd.name) + for node in nodes: + fd.write(node.debug_str()) + fd.write("\n\n\n") + + def graph_diagram(self, nodes: SchedulerNodeList) -> None: + draw_buffers(nodes, fname=self.filename("graph_diagram.svg")) + + def draw_orig_fx_graph( + self, + gm: torch.fx.GraphModule, + nodes: SchedulerNodeList, + ) -> None: + annotate_orig_fx_with_snodes(gm, nodes) + draw_graph( + gm, + fname=self.filename("orig_fx_graph_diagram.svg"), + clear_meta=False, + prog=GRAPHVIZ_COMMAND_SCALABLE, + parse_stack_trace=True, + dot_graph_shape=config.trace.dot_graph_shape, + ) + + def output_code(self, filename: str) -> None: + shutil.copy(filename, self.filename("output_code.py")) + + def log_autotuning_results( + self, + name: str, + input_nodes: List[ir.IRNode], + timings: Dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821 + elapse: float, + precompile_elapse: float, + ) -> None: + import json + + from .ir import FixedLayout + + def build_node_info(node: ir.IRNode) -> Dict[str, str]: + if hasattr(node, "name"): + node_name = node.name + else: + node_name = "" + node_info = { + "name": node_name, + "type": type(node).__name__, + } + try: + layout = node.get_layout() + if isinstance(layout, FixedLayout): + offset = 0 + try: + offset = int(layout.offset) + except Exception: + try: + offset = V.graph.sizevars.size_hint( + layout.offset, fallback=0 + ) + except Exception: + pass + static_layout = FixedLayout( + layout.device, + dtype=layout.dtype, + size=list(V.graph.sizevars.size_hints(layout.size)), + stride=list(V.graph.sizevars.size_hints(layout.stride)), + offset=offset, + ) + node_info["layout"] = str(static_layout) + else: + node_info["layout"] = str(node.get_layout()) + except Exception as e: + pass + try: + node_info["dtype"] = str(node.get_dtype()) + except Exception as e: + pass + try: + node_info["device"] = str(node.get_device()) + except Exception as e: + pass + try: + node_info["stride"] = str( + V.graph.sizevars.size_hints(node.get_stride()) + ) + except Exception as e: + pass + try: + node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size())) + except Exception as e: + pass + try: + node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel())) + except Exception as e: + pass + if hasattr(node, "data") and isinstance(node.data, ir.IRNode): + node_info["data"] = build_node_info(node.data) + return node_info + + general_properties = { + "op_name": name, + "cuda_device_name": torch.cuda.get_device_name(), + "cuda_device_count": torch.cuda.device_count(), + "input_nodes": [build_node_info(node) for node in input_nodes], + "autotuning_time": elapse, + "precompile_time": precompile_elapse, + } + with self.fopen_context( + "autotuning_result_json_list.txt", "at", encoding="utf-8" + ) as fd: + for caller, time in timings.items(): + info_dict = dict(caller.info_dict()) + info_dict.update(general_properties) + info_dict["benchmark_result"] = time + json.dump(info_dict, fd) + fd.write("\n") + + +@dataclasses.dataclass +class TensorMetadataHolder: + tensor_metadata: TensorMetadata + device: torch.device + + +save_args_cnt = itertools.count() + + +def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: + """ + This function is used to save arguments for a compile_fx_inner function call + to the file system. Later on one can replay the compile_fx_inner call + with the saved arguments using load_args_and_run_compile_fx_inner. + """ + + folder = "/tmp/inductor_saved_args" + if not os.path.exists(folder): + os.mkdir(folder) + + def handle_tensor(x: Any) -> Any: + """ + Pickle FakeTensor will result in error: + AttributeError: Can't pickle local object 'WeakValueDictionary.__init__..remove' + + Convert all Tensor to metadata. This may also makes pickle faster. + """ + if isinstance(x, torch.Tensor): + return TensorMetadataHolder(_extract_tensor_metadata(x), x.device) + else: + return x + + args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs)) + + fn_name = "compile_fx_inner" + path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl" + with open(path, "wb") as f: + pickle.dump((args_to_save, kwargs_to_save), f) + + if log.isEnabledFor(logging.DEBUG): + message = f""" +Arguments for a compile_fx_inner call is saved to {path}. To replay the call, +run the following: + +from torch._inductor.debug import load_args_and_run_compile_fx_inner +load_args_and_run_compile_fx_inner({path!r}) + """ + # call print rather than log.debug. log.debug will print message + # prefix for each line which makes the code snippet harder to be + # copied. + # Not a big deal since the code is already been guarded by checking + # the log level. + print(message) + + +def load_args_and_run_compile_fx_inner(path: str) -> Any: + from torch._inductor.compile_fx import compile_fx_inner + + with open(path, "rb") as f: + args, kwargs = pickle.load(f) + + def handle_tensor(x: Any) -> Any: + if isinstance(x, TensorMetadataHolder): + return torch._dynamo.testing.rand_strided( + x.tensor_metadata.shape, + x.tensor_metadata.stride, + x.tensor_metadata.dtype, + x.device, + ) + else: + return x + + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + with fake_mode, config.patch("save_args", False): + args, kwargs = tree_map(handle_tensor, (args, kwargs)) + return compile_fx_inner(*args, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py b/.venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..0e067395f807121789f2a51fe4919f122f03f170 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py @@ -0,0 +1,980 @@ +# mypy: allow-untyped-decorators +import functools +import logging +import math +import sys +import typing +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch._decomp as decomp +import torch._prims_common as utils +import torch.ao.quantization.fx._decomposed +from torch._decomp import ( + core_aten_decompositions, + get_decompositions, + remove_decompositions, +) +from torch._decomp.decompositions import ( + _grid_sampler_2d as decomp_grid_sampler_2d, + pw_cast_for_opmath, +) +from torch._decomp.decompositions_for_rng import extra_random_decomps +from torch._dynamo.utils import counters +from torch._higher_order_ops.out_dtype import out_dtype +from torch._inductor.utils import pad_listlike +from torch._prims_common import ( + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + type_to_dtype, +) +from torch.fx.experimental.symbolic_shapes import definitely_true, guard_size_oblivious + +from . import config, inductor_prims +from .utils import ( + is_gpu, + needs_fallback_due_to_atomic_add_limitations, + use_scatter_fallback, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims +quantized = torch.ops.quantized +_quantized = torch.ops._quantized +quantized_decomposed = torch.ops.quantized_decomposed + +inductor_decompositions = get_decompositions( + [ + aten._adaptive_avg_pool2d_backward, + aten.addmv, + aten.arange, + aten.bitwise_and_, + aten.bitwise_or_, + aten.clamp_min_, + aten.dist, + aten.empty_like, + aten.flip, + aten.gelu, + aten.hardtanh, + aten.index_select, + aten.lcm, + aten.leaky_relu, + aten.linalg_vector_norm, + aten._log_softmax, + aten.max_pool2d_with_indices_backward, + aten._native_batch_norm_legit, + aten._native_batch_norm_legit_functional, + aten._native_batch_norm_legit_no_training, + aten._batch_norm_with_update, + aten._batch_norm_with_update_functional, + aten._batch_norm_no_update, + aten.batch_norm_backward, + aten.native_batch_norm, + aten.native_group_norm, + aten.native_layer_norm, + aten.nll_loss2d_backward, + aten._softmax, + aten.sin_, + aten.sqrt_, + out_dtype, + aten._to_copy, + aten.tril_indices, + aten.triu_indices, + aten.upsample_bilinear2d.vec, + quantized.linear_dynamic_fp16_unpacked_weight, + _quantized.wrapped_quantized_linear, + ] +) +decompositions = {**core_aten_decompositions(), **inductor_decompositions} + +# Remove unwanted decompositions included via the core ATen decompositions from +# the Inductor decomp table. +decomps_to_exclude = [ + aten._unsafe_index, + aten._unsafe_masked_index, + aten._unsafe_masked_index_put_accumulate, + aten._scaled_dot_product_flash_attention_for_cpu.default, # See comments in torch/_decomp/decompositions.py + aten._softmax_backward_data, + aten.clamp_max, + aten.clamp_min, + aten.glu, # inductor lowers this directly + aten.select_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass + aten.slice_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass + aten.split.Tensor, # inductor lowers this directly + aten.squeeze, # inductor lowers this directly + aten.sum, # inductor lowers this directly + aten.unbind, # inductor lowers this directly +] + +remove_decompositions(decompositions, decomps_to_exclude) + + +def register_decomposition( + ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]] +) -> Callable[..., Any]: + for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined] + if op in decompositions: + log.warning("duplicate decomp: %s", ops) + return decomp.register_decomposition(ops, decompositions) + + +# TODO: for now, inductor doesn't handle asserts +# because the condition is symbol -> tensor in the graph. +@register_decomposition([aten._assert_async.msg]) +def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None: + return + + +# Following `assert_async_msg_decomp` and implement as non-op. +@register_decomposition([aten._functional_assert_async.msg]) +def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None: + return + + +@register_decomposition([aten.sym_constrain_range_for_size.default]) +def sym_constrain_range_for_size( + symbol: torch.SymInt, + *, + min: Optional[torch.types.Number] = None, + max: Optional[torch.types.Number] = None, +) -> None: + return + + +@register_decomposition([aten.clamp]) +@pw_cast_for_opmath +def clamp( + x: torch.Tensor, + min: Optional[torch.types.Number] = None, + max: Optional[torch.types.Number] = None, +) -> torch.Tensor: + if min is not None: + x = x.clamp_min(min) + if max is not None: + x = x.clamp_max(max) + return x + + +@register_decomposition([aten.full]) +def full( + size: List[Union[int, torch.SymInt]], + fill_value: torch.types.Number, + **kwargs: Any, +) -> torch.Tensor: + dtype = kwargs.get("dtype") + if dtype is None: + kwargs["dtype"] = type_to_dtype(type(fill_value)) + return torch.full(size, fill_value, **kwargs) + return NotImplemented + + +# Not really sure how to put this into the main library. PrimTorch wants +# empty_permuted to go to the prim, and typically users don't really want +# to decompose to empty_strided (but inductor is OK with it, because we are +# cool with strides and everything goes to empty_strided) +@register_decomposition([aten.empty_permuted.default]) +def empty_permuted( + size: List[Union[int, torch.SymInt]], + physical_layout: List[int], + **kwargs: Any, +) -> torch.Tensor: + perm = [0] * len(size) + for p, l in enumerate(physical_layout): + perm[l] = p + return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm) + + +@register_decomposition([aten.convolution_backward]) +def convolution_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes: List[int], + stride: Union[int, List[int]], + padding: Union[int, List[int]], + dilation: Union[int, List[int]], + transposed: bool, + output_padding: List[int], + groups: int, + output_mask: List[bool], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not output_mask[2] or not is_gpu(grad_output.device.type): + return NotImplemented + grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim()))) + grad_inp, grad_weight, _ = aten.convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + [output_mask[0], output_mask[1], False], + ) + return (grad_inp, grad_weight, grad_bias) + + +@register_decomposition([aten.round.decimals]) +def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor: + ten_pow_decimals = 10.0**decimals + return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals) + + +@register_decomposition([aten.bmm]) +@pw_cast_for_opmath +def bmm( + self: torch.Tensor, + batch2: torch.Tensor, +) -> torch.Tensor: + if config.coordinate_descent_tuning: + if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious( + batch2.shape[2] == 1 + ): + out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2) + return out + if self.device.type == "cpu": + if guard_size_oblivious(self.size(1) == 1) and guard_size_oblivious( + batch2.size(-1) == 1 + ): + counters["inductor"]["decompose_bmm"] += 1 + return torch.sum( + self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True + ).unsqueeze(1) + return NotImplemented + + +@register_decomposition([aten.addmm]) +@pw_cast_for_opmath +def addmm( + self: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + beta: torch.types.Number = 1, + alpha: torch.types.Number = 1, +) -> torch.Tensor: + if self.device.type == "cpu": + if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious( + mat2.size(-1) == 1 + ): + counters["inductor"]["decompose_addmm"] += 1 + out = torch.sum( + mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return alpha * out + beta * self + if ( + guard_size_oblivious(mat1.size(0) == 1) + and definitely_true(mat2.size(0) <= 16) + and definitely_true(mat2.size(1) <= 16) + ): + counters["inductor"]["decompose_addmm"] += 1 + out = (mat1.T * mat2).sum(dim=0, keepdim=True) + return alpha * out + beta * self + return NotImplemented + + +@register_decomposition([aten.mm]) +@pw_cast_for_opmath +def mm( + self: torch.Tensor, + input2: torch.Tensor, +) -> torch.Tensor: + # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. + # todo: Look into why and fix it (hopefully) + if config.coordinate_descent_tuning: + if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious( + input2.shape[1] == 1 + ): + return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1) + if self.device.type == "cpu": + if ( + guard_size_oblivious(self.size(-1) == 1) + and guard_size_oblivious(self.size(0) > 0) + and guard_size_oblivious(input2.size(0) == 1) + and (self.dtype == input2.dtype) + and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32) + ): + counters["inductor"]["decompose_mm"] += 1 + return torch.cat([self[i, :] * input2 for i in range(self.size(0))]) + if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious( + input2.size(-1) == 1 + ): + counters["inductor"]["decompose_mm"] += 1 + return torch.sum( + self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return NotImplemented + + +# This pass does two things: +# - Eliminate cat when there is only one tensor input +# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we +# don't remove ALL empty tensors, only the naughty ones) +@register_decomposition([aten.cat.default]) +def cat( + tensors: List[torch.Tensor], + dim: int = 0, +) -> torch.Tensor: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + def non_empty_tensor(x: torch.Tensor) -> bool: + # For better or worse, this is a valid cat: + # + # torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)]) + # + # We'd like to eliminate naughtiness like this for downstream passes + # like split_cat. The easiest way is to just drop such inputs + # (guarding that they are non-zero). + # + # Is it permissible for this filtering to be size-oblivious? A case + # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0 + # happened to be zero, we would have liked to have filtered it out. + # But actually, the ONLY way this could have passed is if u0 == 0, + # so by the time we get here we have already installed a deferred + # runtime assert forcing u0 to be zero. So if this hasn't happened, + # we know that the unbacked SymInt has appropriate size and there are + # no problems. + if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0): + return False + + if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0): + return False + + return True + + filtered_tensors = list(filter(non_empty_tensor, tensors)) + + if len(filtered_tensors) == 1: + return filtered_tensors[0].clone() + elif 1 < len(filtered_tensors) < len(tensors): + # on the first call, when we remove empty tensors, we redispatch recursively + return aten.cat.default(filtered_tensors, dim) + + # optimization, avoid concat for single, repeated input + if len(filtered_tensors) > 1 and all( + t is filtered_tensors[0] for t in filtered_tensors + ): + inp = filtered_tensors[0] + shape = list(inp.shape) + dim = dim + len(inp.shape) if dim < 0 else dim + shape.insert(dim, len(filtered_tensors)) + return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone() + + # when no 'filtering' has occurred, we raise to prevent infinite recursion (no more decomposition needed) + return NotImplemented + + +@register_decomposition([aten.angle]) +def angle(x: torch.Tensor) -> torch.Tensor: + if x.is_complex(): + return torch.where( + torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real) + ) + + # when x is real number + # if x >= 0, return 0 + # if x < 0, return pi + # if x is nan, return nan + _, dtype = elementwise_dtypes( + x, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device) + ret = torch.where(x < 0, pi, 0.0) + return torch.where(torch.isnan(x), float("nan"), ret) + + +@register_decomposition([aten.add]) +def add( + x: torch.Tensor, + y: torch.Tensor, + *, + alpha: Optional[torch.types.Number] = None, +) -> torch.Tensor: + # Require both x and y to be complex tensors. + x_is_complex_tensor = torch.is_tensor(x) and x.is_complex() + y_is_complex_tensor = torch.is_tensor(y) and y.is_complex() + if not x_is_complex_tensor or not y_is_complex_tensor: + return NotImplemented + z = y + if alpha is not None: + z = alpha * y + complex_type = torch.promote_types(x.dtype, y.dtype) + + # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problem + # when broadcasting the add. + def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor: + """Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]""" + # Get the current shape of the tensor + *initial_dims, last_dim = tensor.shape + + # Check if the last dimension is even. We should never reach here since `x.view(x.real.dtype)` + # doubles the last dimension for complex numbers. + if last_dim % 2 != 0: + raise AssertionError( + "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]" + ) + + # Reshape the tensor + new_shape = (*initial_dims, last_dim // 2, 2) + reshaped_tensor = tensor.view(new_shape) + return reshaped_tensor + + x_reshaped = reshape_tensor_complex(x.view(x.real.dtype)) + z_reshaped = reshape_tensor_complex(z.view(y.real.dtype)) + result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type) + return result + + +@register_decomposition([aten.conj_physical]) +def conj_physical(self: torch.Tensor) -> torch.Tensor: + assert not self.is_complex(), "TODO: implement this" + return self + + +@register_decomposition([aten.lift, aten.detach_]) +def lift(self: torch.Tensor) -> torch.Tensor: + return self + + +@register_decomposition([aten.bernoulli.default]) +def bernoulli( + self: torch.Tensor, + *, + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + assert generator is None + return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) + + +@register_decomposition([aten.fmin, prims.fmin]) +def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.where(torch.isnan(other) | (other > self), self, other) + + +@register_decomposition([aten.fmax, prims.fmax]) +def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.where(torch.isnan(other) | (other < self), self, other) + + +@register_decomposition(aten.amax) +def amax( + self: torch.Tensor, + dim: Optional[int] = None, + keepdim: bool = False, +) -> torch.Tensor: + if self.dtype == torch.bool: + return torch.any(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition(aten.amin) +def amin( + self: torch.Tensor, + dim: Optional[int] = None, + keepdim: bool = False, +) -> torch.Tensor: + if self.dtype == torch.bool: + return torch.all(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition([aten.narrow_copy]) +def narrow_copy( + self: torch.Tensor, + dim: int, + start: int, + length: int, +) -> torch.Tensor: + return torch.narrow(self, dim, start, length).clone() + + +@register_decomposition([aten.view_copy.default]) +def view_copy_default( + self: torch.Tensor, + size: List[Union[int, torch.SymInt]], +) -> torch.Tensor: + return aten.view(self, size).clone() + + +@register_decomposition([aten.view_copy.dtype]) +def view_copy_dtype( + self: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + return self.to(dtype).clone() + + +def get_like_layout( + tensor: torch.Tensor, + memory_format: Optional[torch.memory_format] = None, +) -> torch.memory_format: + # TODO: _to_copy tensor to stride permutation + if memory_format is torch.preserve_format or memory_format is None: + return utils.suggest_memory_format(tensor) + else: + return memory_format + + +@register_decomposition(aten.rand_like) +def rand_like( + self: torch.Tensor, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return torch.rand( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randn_like) +def randn_like( + self: torch.Tensor, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return torch.randn( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.full_like) +def full_like( + self: torch.Tensor, + fill_value: Union[int, float], + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[torch.device] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> torch.Tensor: + return torch.full( + [*self.size()], + fill_value, + dtype=dtype or self.dtype, + layout=layout or self.layout, + device=device or self.device, + requires_grad=requires_grad, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint_like.default) +def randint_like( + self: torch.Tensor, + high: int, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low( + 0, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint_like.low_dtype) +def randint_like_low( + self: torch.Tensor, + low: int, + high: int, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low( + low, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint.default) +def randint( + high: int, + size: List[Union[int, torch.SymInt]], + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low(0, high, size, **kwargs) + + +@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default) +def linear_dynamic_fp16_unpacked_weight( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight) + return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight( + input, packed_weight, bias, weight.size()[0] + ) + + +@register_decomposition(_quantized.wrapped_quantized_linear.default) +def wrapped_quantized_linear( + input: torch.Tensor, + input_scale: torch.Tensor, + input_zero_point: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + bias: torch.Tensor, + out_scale: torch.Tensor, + out_zero_point: torch.Tensor, + out_channel: int, +) -> torch.Tensor: + packed_weight = torch.ops._quantized._wrapped_linear_prepack( + weight, weight_scale, weight_zero_point, bias + ) + return torch.ops._quantized._wrapped_quantized_linear_prepacked( + input, + input_scale, + input_zero_point, + packed_weight, + out_scale, + out_zero_point, + out_channel, + ) + + +@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack) +def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor: + def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor: + x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3)) + if sys.byteorder == "little": + return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None] + else: + return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None] + + scales = bitcast_u8_to_f32(packed[..., -8:-4]) + offsets = bitcast_u8_to_f32(packed[..., -4:]) + return packed[..., :-8].to(torch.float32) * scales + offsets + + +@register_decomposition([aten.grid_sampler_2d]) +@pw_cast_for_opmath +def grid_sampler_2d( + a: torch.Tensor, + grid: torch.Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> torch.Tensor: + # We do not expand the grid (_expand_grid=False) on cpu for performance reasons + # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x + # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2) + # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. + # Thus we apply this hack to not expand the grid for this case. + _expand_grid = not ( + a.device == torch.device("cpu") + and interpolation_mode == 0 + and a.is_contiguous(memory_format=torch.contiguous_format) + ) + + output = decomp_grid_sampler_2d( + a, + grid=grid, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + _expand_grid=_expand_grid, + ) + return output + + +@register_decomposition(aten._foreach_addcmul.Scalar) +def _foreach_addcmul_scalar( + self: List[torch.Tensor], + left_tensors: List[torch.Tensor], + right_tensors: List[torch.Tensor], + scalar: float = 1, +) -> List[torch.Tensor]: + return aten._foreach_add.List( + self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_addcdiv.Scalar) +def _foreach_addcdiv_scalar( + self: List[torch.Tensor], + left_tensors: List[torch.Tensor], + right_tensors: List[torch.Tensor], + scalar: float = 1, +) -> List[torch.Tensor]: + return aten._foreach_add.List( + self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_lerp.Scalar) +def _foreach_lerp_scalar( + start_tensors: List[torch.Tensor], + end_tensors: List[torch.Tensor], + weight: torch.types.Number, +) -> List[torch.Tensor]: + return aten._foreach_add.List( + start_tensors, + aten._foreach_mul.Scalar( + aten._foreach_sub.List(end_tensors, start_tensors), weight + ), + ) + + +@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd) +@register_decomposition(aten.miopen_batch_norm) +def miopen_batch_norm( + input: torch.Tensor, + weight: torch.Tensor, + bias: typing.Optional[torch.Tensor], + running_mean: typing.Optional[torch.Tensor], + running_var: typing.Optional[torch.Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + a, b, c = aten.native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + ) + + if training: + return (a, b, c) + return ( + a, + weight.new_zeros((0,)), + weight.new_zeros((0,)), + ) + + +@functools.lru_cache(None) +def fast_random_decomps() -> Dict[Any, Callable[..., Any]]: + return {**decompositions, **extra_random_decomps} + + +# TODO(aakhundov): replace this (and the above) Any by more +# specific type and fix all the cascading mypy errors +def select_decomp_table() -> Dict[Any, Callable[..., Any]]: + """decomps can change based on config""" + if config.fallback_random: + return decompositions + return fast_random_decomps() + + +@register_decomposition(aten.masked_scatter) +def masked_scatter( + self: torch.Tensor, + mask: torch.Tensor, + source: torch.Tensor, +) -> torch.Tensor: + from .codegen.common import BackendFeature, has_backend_feature + + if has_backend_feature(self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX): + # This two-step algorithm is the same as eager CUDA, for eager CPU we + # use a 1-shot serial iteration. + self, mask = aten.broadcast_tensors([self, mask]) + source_idx = mask.reshape(-1).cumsum(0) - 1 + self_flat, mask_flat, source_flat = (x.flatten() for x in (self, mask, source)) + result = aten._unsafe_masked_index(source_flat, mask_flat, [source_idx], 0) + return torch.where(mask_flat, result, self_flat).view(self.shape) + return NotImplemented + + +@register_decomposition(quantized_decomposed.choose_qparams.tensor) +def choose_qparams_tensor( + input: torch.Tensor, + quant_min: int, + quant_max: int, + eps: float, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + min_val, max_val = torch.aminmax(input) + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.max(scale, torch.Tensor([eps])) + zero_point = quant_min - torch.round(min_val / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + return scale.to(torch.float64), zero_point.to(torch.int64) + + +@register_decomposition(aten.put) +def put( + self: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + accumulate: bool = False, +) -> torch.Tensor: + flattened = self.flatten() + flattened = torch.index_put( + flattened, [index], source.reshape(index.shape), accumulate + ) + return flattened.reshape(self.shape) + + +@register_decomposition(aten.put_) +def put_( + self: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + accumulate: bool = False, +) -> torch.Tensor: + out = aten.put(self, index, source, accumulate=accumulate) + return self.copy_(out) + + +@register_decomposition(aten._softmax_backward_data.default) +@pw_cast_for_opmath +def _softmax_backward_data( + grad_output: torch.Tensor, + output: torch.Tensor, + dim: int, + input_dtype: torch.dtype, +) -> torch.Tensor: + new_grad_output = grad_output * output + sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True) + # grad_input = new_grad_output - output * sum_new_grad + grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + if grad_output.dtype != input_dtype: + grad_input = grad_input.to(input_dtype) + return grad_input.contiguous() + + +@register_decomposition(aten.index_reduce) +def index_reduce( + self: torch.Tensor, + dim: int, + index: torch.Tensor, + src: torch.Tensor, + reduction_type: str, + *, + include_self: bool = True, +) -> torch.Tensor: + if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations( + self.dtype + ): + true_division = self.dtype.is_floating_point or self.dtype.is_complex + ones = torch.ones_like(src) + if include_self: + out = self + counts = torch.ones_like(self).index_add(dim, index, ones) + else: + out = self.index_fill(dim, index, 0) + counts = torch.zeros_like(self).index_add(dim, index, ones) + counts = counts.masked_fill(counts < 1, 1) + out = out.index_add(dim, index, src) + return out / counts if true_division else out // counts + + if use_scatter_fallback( + aten.scatter_reduce_.two, + reduction_type, + self.dtype, + src.dtype, + src.device.type, + True, + ): + return NotImplemented + + repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel() + index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim]) + perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim)) + scatter_index = ( + index.to(torch.int64) + .repeat_interleave(repeats) + .reshape(index_shape) + .permute(perm) + ) + return self.scatter_reduce( + dim, + scatter_index, + src, + reduction_type, + include_self=include_self, + ) + + +@register_decomposition(aten.max_pool2d_with_indices) +def max_pool2d_with_indices( + x: torch.Tensor, + kernel_size: List[int], + stride: Optional[Union[int, List[int]]] = None, + padding: Union[int, List[int]] = 0, + dilation: Union[int, List[int]] = 1, + ceil_mode: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if dilation == 1: + dilation = [1, 1] + + if padding == 0: + padding = [0, 0] + + if not stride: + stride = kernel_size + + kernel_size = pad_listlike(kernel_size, 2) + dilation = pad_listlike(dilation, 2) + padding = pad_listlike(padding, 2) + stride = pad_listlike(stride, 2) + + window_size = kernel_size[0] * kernel_size[1] + # We fallback when using non-default dilation or when the window size is too large + if ( + torch._inductor.lowering.should_fallback_max_pool2d_with_indices( + kernel_size, dilation + ) + or window_size > torch.iinfo(torch.int8).max + ): + return NotImplemented + + vals, offsets = prims._low_memory_max_pool2d_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + indices = prims._low_memory_max_pool2d_offsets_to_indices( + offsets, + kernel_size[1], + x.size(-1), + stride, + padding, + ) + return vals, indices diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py b/.venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..95a48281977c87002e9ee486ae80e0b58b706d33 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py @@ -0,0 +1,745 @@ +# mypy: allow-untyped-defs +import abc +import dataclasses +import itertools +import logging +import re +import typing +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from unittest.mock import patch + +import sympy + +import torch +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.utils._ordered_set import OrderedSet + +from .codegen.common import index_prevent_reordering +from .utils import ( + get_dtype_size, + reduction_num_outputs, + sympy_index_symbol, + sympy_str, + sympy_subs, + VarRanges, +) +from .virtualized import OpsHandler, ReductionType, V + + +log = logging.getLogger(__name__) +is_indirect = re.compile(r"indirect|tmp").search + + +class Dep(abc.ABC): + name: str + index: sympy.Expr + + @abc.abstractmethod + def rename(self, renames: Dict[str, str]) -> "Dep": + pass + + @abc.abstractmethod + def get_numel(self) -> sympy.Expr: + pass + + @abc.abstractmethod + def numbytes_hint(self): + pass + + @abc.abstractmethod + def has_unbacked_symbols(self) -> bool: + pass + + @abc.abstractmethod + def is_contiguous(self) -> bool: + pass + + def normalize_with_stride_order(self, prefix="t"): + return self + + +@dataclasses.dataclass(frozen=True) +class MemoryDep(Dep): + name: str + index: sympy.Expr + var_names: Tuple[sympy.Symbol, ...] + size: Tuple[sympy.Expr, ...] + mode: Optional[str] = None + + def __repr__(self) -> str: + return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}, {self.mode})" + + @property + def num_vars(self): + return len(self.var_names) + + def decide_loop_order_to_match(self, other): + """ + Can return None if not able to decide loop orders. + """ + assert self.num_vars == other.num_vars + + # ignore broadcast for now since broadcast causes extra 0 strides + # which makes it hard to decide the correct loop orders. + if self.num_vars != len(self.index.free_symbols): + return None + if other.num_vars != len(other.index.free_symbols): + return None + + # bail out if any size is 0 or 1 + # For size == 0, it's an empty tensor, any strides for that dimension + # are equivalent. Skip for simplicity and it may not matter that much. + # + # For size == 1, it cause cause tie for strides of different dimensions. + # Also when we first time create LoopBody in ComputedBuffer.simplify_and_reorder + # we can dependencies.index_vars_squeeze which should already sqeeuze + # the size == 1 dimensions. + if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)): + return None + + # Extract strides for both expression + self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names) + + # Even if the shape contains no 0/1, some complex index expression may + # still have duplicate stride values. Here is an example: + # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129 + # We don't reorder the loop for these cases for now, but in theory + # we could improve the algorithm to detect the correct loop orders. + if len(set(self_strides)) != len(self_strides) or len( + set(other_strides) + ) != len(other_strides): + log.debug( + "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s", + self, + other, + self_strides, + other_strides, + ) + return None + + # May hanppen if self and other are as follows + # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None) + # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None) + if set(self_strides) != set(other_strides): + return None + + stride_to_index = {s: i for i, s in enumerate(self_strides)} + order = [] + for s in other_strides: + order.append(stride_to_index[s]) + + assert set(order) == set(range(0, self.num_vars)) + return order + + def get_offset(self): + """ + Return the offset by setting every variable to be 0. + """ + return sympy_subs(self.index, dict.fromkeys(self.var_names, 0)) + + def normalize(self) -> "MemoryDep": + """ + Normalize by merging loops. The different to normalize_with_stride_order is, + this method does not reorder loops while normalize_with_stride_order reorder + loops based on stride order. + """ + return MemoryDep( + self.name, + *_RecordLoadStoreInner._normalize(self.index, self.ranges), # type: ignore[arg-type] + self.mode, + ) + + def normalize_with_stride_order(self, prefix="t"): + r""" + Used to decide if two MemoryDep does not equal due to different loop orders. + More specifically, when dep1 and dep2 are not equal, we can normalize + both and check if they are equal after that. If yes, then the mismatch is + caused by different loop orders. + """ + # import here to avoid circular import + from torch._inductor import ir + + strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + + # pick a loop order with stride ordered decreasingly + order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) + stride_reorder = ir.same_reorder(order) + sizes = self.size + var_names = self.var_names + + new_reordered_sizes = stride_reorder(sizes) + new_reordered_var_names = stride_reorder(var_names) + + new_simplified_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + new_reordered_var_names, + new_reordered_sizes, + index_prevent_reordering( + [self.index], new_reordered_var_names, new_reordered_sizes + ), + ) + + # now let's create new symbols with the passed in prefix + var_ranges, add_var = var_builder(prefix) + replacement = dict( + zip( + new_reordered_var_names, + reindex([add_var(x) for x in new_simplified_sizes]), + ) + ) + new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR + + out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())) # type: ignore[arg-type] + return out + + @property + def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]: + """{c0: 128, c1: 512, ...}""" + return dict(zip(self.var_names, self.size)) + + def get_numel(self) -> sympy.Expr: + if self.is_indirect(): + numel = V.graph.get_numel(self.name) + else: + vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols) + numel = sympy.Integer(1) + for var, size in zip(self.var_names, self.size): + if var in vars: + numel = numel * size + return numel # type: ignore[return-value] + + def rename(self, renames: Dict[str, str]) -> "MemoryDep": + if self.name in renames: + return MemoryDep( + renames[self.name], + self.index, + var_names=self.var_names, + size=self.size, + mode=self.mode, + ) + return self + + def numbytes_hint(self): + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + + def has_unbacked_symbols(self): + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + return isinstance(self.index, sympy.Symbol) and self.index in self.var_names + + def stride1_for_last_dim(self, result_for_complex_expression=True) -> bool: + """ + Whether the stride for the last dimension is 1. + """ + # python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16 + # will exercise thru this corner case. + if len(self.var_names) == 0: + return True + + terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index] + + last_sym = self.var_names[-1] + for term in terms: + if term is last_sym: + return True + + # Having a >1 stride for the last dimension is bad for perf + # return False. + if ( + isinstance(term, sympy.Mul) + and len(term.args) == 2 + and term.args[1] is last_sym + and isinstance(term.args[0], (int, sympy.Integer)) + and term.args[0] > 1 + ): + return False + + return result_for_complex_expression + + def is_scalar(self) -> bool: + if isinstance(self.index, sympy.Symbol): + return self.index not in self.var_names and not self.is_indirect() + return isinstance(self.index, (int, sympy.Integer)) + + def is_indirect(self) -> bool: + return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined] + + +@dataclasses.dataclass(frozen=True) +class StarDep(Dep): + name: str + mode: Optional[str] = None + + # depends on the entire buffer + @property + def index(self): + raise NotImplementedError("StarDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return V.graph.get_numel(self.name) # type: ignore[return-value] + + def rename(self, renames: Dict[str, str]) -> "StarDep": + if self.name in renames: + return StarDep(renames[self.name], self.mode) + return self + + def numbytes_hint(self): + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + + def has_unbacked_symbols(self): + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + return False + + def is_scalar(self) -> bool: + return False + + def is_indirect(self) -> bool: + return False + + +# Used for tracking mutation ordering +# if A reads a buffer and B mutates it +# B must be ordered after A +# +# This is useful for a variety of reasons. +# For example, if A's read is never actually used, we can eliminate it. +# Another case is if A's buffer ends up being fused away, we never need to +# materialize that buffer +@dataclasses.dataclass(frozen=True) +class WeakDep(Dep): + # Fake dependency on unused buffer + name: str + # Buffer that is doing the mutation + mutating_buf: str + + @property + def index(self): + raise NotImplementedError("WeakDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return sympy.Integer(1) + + def rename(self, renames: Dict[str, str]) -> "WeakDep": + if self.name in renames: + return WeakDep(renames[self.name], self.mutating_buf) + return self + + def numbytes_hint(self): + return 1 # Purely inserted for ordering, not an actual dep + + def has_unbacked_symbols(self): + return False + + def is_contiguous(self) -> bool: + return False + + +@dataclasses.dataclass(frozen=True) +class IndexExprDep: + index: sympy.Expr # type: ignore[assignment] + var_names: Tuple[sympy.Symbol, ...] + size: Tuple[sympy.Expr, ...] + + +@dataclasses.dataclass +class ReadWrites: + reads: OrderedSet[Dep] + writes: OrderedSet[Dep] + index_exprs: OrderedSet[IndexExprDep] + range_vars: Optional[List[sympy.Expr]] = None + var_ranges: Optional[VarRanges] = None + + def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": + return ReadWrites( + OrderedSet(dep.rename(renames) for dep in self.reads), + OrderedSet(dep.rename(renames) for dep in self.writes), + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites": + assert isinstance(dep, (WeakDep, StarDep, set)) + if not isinstance(dep, set): + dep = {dep} + return ReadWrites( + OrderedSet.union(self.reads, dep), + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def merge(self, other: "ReadWrites"): + reads = OrderedSet.union(self.reads, other.reads) + writes = OrderedSet.union(self.writes, other.writes) + index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs) + return ReadWrites(reads - writes, writes, index_exprs) + + @staticmethod + def merge_list(read_writes: List["ReadWrites"]): + all_writes = OrderedSet.union(*[rw.writes for rw in read_writes]) + all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes + all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes]) + return ReadWrites(all_reads, all_writes, all_index_exprs) + + def remove_reads(self, rem_reads): + return ReadWrites( + self.reads - rem_reads, + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def reads_and_writes(self): + return itertools.chain(self.reads, self.writes) + + def buffer_names(self, ignore_integer_index=True): + """ + Integer index is used for load_seed. + """ + names: OrderedSet[str] = OrderedSet() + for dep in self.reads_and_writes(): + if not isinstance(dep, MemoryDep): + continue + if not ignore_integer_index or not isinstance( + dep.index, (int, sympy.Integer) + ): + names.add(dep.name) + return names + + +class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: + super().__init__() + self._reads: OrderedSet[Dep] = OrderedSet() + self._writes: OrderedSet[MemoryDep] = OrderedSet() + self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet() + self._var_ranges: VarRanges = var_ranges + self._should_normalize: bool = normalize + + @staticmethod + def drop_unused_symbols(index, var_names, sizes): + """ + Reduction has last (reduced) dim in its sizes, but + downstream users won't. Normalize this away. + """ + if not isinstance(index, sympy.Expr): + # index can be an int + return + free_symbols = index.free_symbols + while var_names and var_names[-1] not in free_symbols: + var_names.pop() + sizes.pop() + + @classmethod + def _normalize( + cls, index: sympy.Expr, var_ranges: VarRanges + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: + # Try to further simplify the indexes even if simplify_loops didn't + # convert it to the simplest form because of the interference from + # different indexing formulas. + index_vars = [*var_ranges.keys()] + sizes = tuple(var_ranges.values()) # type: ignore[assignment] + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, + sizes, + index_prevent_reordering([index], index_vars, sizes), + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + new_vars, add_var = var_builder(canonicalization_prefix()) + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + index = sympy_subs(sympy.expand(index), replacement) + + new_vars = [*new_vars.keys()] + new_sizes = [*new_sizes] + cls.drop_unused_symbols(index, new_vars, new_sizes) + return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type] + + def canonicalize( + self, index: sympy.Expr + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: + if not self._should_normalize: + sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] + var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1] + sizes = [v for v in sizes if v != 1] + + self.drop_unused_symbols(index, var_names, sizes) + + return index, tuple(var_names), tuple(sizes) # type: ignore[return-value, arg-type] + var_ranges = { + k: V.graph.sizevars.simplify(v) + for k, v in self._var_ranges.items() + # TODO(jansel): explore this further normalization + # if k in free_symbols + } + return self._normalize(index, var_ranges) + + def load(self, name: str, index: sympy.Expr) -> str: + self._reads.add(MemoryDep(name, *self.canonicalize(index))) + return f"load({name}, {sympy_str(index)})" + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + return self.load(name, sympy.Integer(index)) + + def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str: + self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode)) + return f"store({name}, {sympy_str(index)}, {value}, {mode})" + + def store_reduction(self, name: str, index, value) -> str: + return self.store(name, index, f"store_reduction({value})") + + def index_expr(self, index: sympy.Expr, dtype) -> str: + self._index_exprs.add(IndexExprDep(*self.canonicalize(index))) + return f"index_expr({sympy_str(index)}, {dtype})" + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + self._reads.add(StarDep(offsets_name)) + return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})" + + +class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: + parent_handler = _RecordLoadStoreInner( + var_ranges=var_ranges, normalize=normalize + ) + super().__init__(parent_handler=parent_handler) + + +# TODO: check call sites +def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]: + cnt = itertools.count() + var_ranges: VarRanges = {} + + def add_var(length: sympy.Expr) -> sympy.Symbol: + v = sympy_index_symbol(f"{prefix}{next(cnt)}") + var_ranges[v] = length + return v + + return var_ranges, add_var + + +def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str): + var_ranges, add_var = var_builder(prefix) + args: List[List[sympy.Symbol]] = [] + for size in argsizes: + args.append(list(map(add_var, size))) + return args, var_ranges + + +def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"): + from .ir import SqueezeView + + var_ranges, add_var = var_builder(prefix) + args: List[List[sympy.Expr]] = [] + new_sizes: List[List[sympy.Expr]] = [] + for size in argsizes: + new_size, reindex = SqueezeView.squeezer(size) + new_sizes.append(new_size) + args.append(reindex(list(map(add_var, new_size)))) + return args, var_ranges + + +def extract_read_writes( + fn: Callable[..., Any], + *argsizes: Tuple[sympy.Expr, ...], + normalize: bool = False, + prefix: str = "d", + hidden_args=(), +): + args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix) + + from .loop_body import LoopBody, MemoryUsageType + + if isinstance(fn, LoopBody): + # Fast path to avoid tracing when we already have a LoopBody + inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize) + name_to_index = fn.indexing_from_args([*args, *hidden_args]) + if fn.indirect_vars: + # mimic the `tmpX` naming tracing gives us + repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)} + name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} + for entry in fn.memory_usage[MemoryUsageType.LOAD]: + inner.load(entry.buffer_name, name_to_index[entry.index_name]) + for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: + inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) + for entry in fn.memory_usage[MemoryUsageType.STORE]: + inner.store( + entry.buffer_name, name_to_index[entry.index_name], None, entry.mode + ) + for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: + inner.store_reduction( + entry.buffer_name, name_to_index[entry.index_name], None + ) + for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: + inner.index_expr(name_to_index[entry.index_name], None) + for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: + inner.bucketize( + None, entry.buffer_name, name_to_index[entry.index_name], None, None + ) + # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped + else: + # Slow path tracing the function + rw = RecordLoadStore(var_ranges, normalize=normalize) + with V.set_ops_handler(rw): + fn(*args, *hidden_args) + inner = rw.parent_handler + + if normalize: + range_vars = [] # Number of vars could differ due to normalization + else: + range_vars = [*itertools.chain.from_iterable(args)] + + return ReadWrites( + OrderedSet(inner._reads), + OrderedSet(inner._writes), + inner._index_exprs, + range_vars, + var_ranges, + ) + + +def extract_input_node_reduction_ranges( + input_node: "torch._inductor.ir.TensorBox", +) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]: + """ + Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same. + It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes. + In this case, reduction_sizes of the Reduction nodes need to be the same. + Otherwise returns (None, None). + """ + + from .ir import ComputedBuffer, Loops + + if isinstance(input_node.data, ComputedBuffer): + # Input node has already been realized. Return its size and reduction_size. + size = input_node.get_size() + reduction_size = input_node.get_reduction_size() + if len(reduction_size) > 0: + return (size, reduction_size) + else: + return (None, None) + + if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined] + # Other IRNodes do not have reduction_ranges. + return (None, None) + + # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes? + # The current method still uses reduction ranges from the dependent realized node, which is not ideal. + # Is there a way to check whether there are permutations inbetween? + reads = input_node.get_reads() + reduction_size = None + size = None + while reduction_size is None and len(reads) > 0: + seen: OrderedSet[str] = OrderedSet() + new_reads = [] + for read in reads: + if not isinstance(read, MemoryDep): + continue + if read.name in seen: + continue + seen.add(read.name) + buffer = V.graph.try_get_buffer(read.name) + if buffer is None: + continue + op = buffer.get_defining_op() + if op is None: + continue + + if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0: + if reduction_size is None: + reduction_size = op.get_reduction_size() + size = op.get_size() + elif reduction_size != op.get_reduction_size() or size != op.get_size(): + return (None, None) + else: + new_reads.extend(op.get_reads()) + if reads == new_reads: + return (size, reduction_size) + else: + reads = new_reads + return (size, reduction_size) + + +def canonicalization_prefix(): + return "c" + + +# ops handler which computes all the free unbacked symbols for an IR +class FreeUnbackedSymbolsOpsHandler: + symbols: OrderedSet[sympy.Symbol] + + def __init__(self) -> None: + self.symbols = OrderedSet() + + def __getattr__(self, name: str) -> Callable[..., Any]: + def inner(*args, **kwargs): + for a in itertools.chain(args, kwargs.values()): + if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): + self.symbols |= free_unbacked_symbols(a) + + return inner + + def indirect_indexing( + self, index_var, size, check=True, wrap_neg=True + ) -> sympy.Symbol: + assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean)) + self.symbols |= free_unbacked_symbols(size) + return sympy_index_symbol(f"({str(index_var)})") + + def frexp(self, x): + return (None,) * 2 + + def scan(self, dtypes, combine_fn, values): + return (None,) * len(values) + + def sort(self, dtypes, values, stable, descending): + return (None,) * len(values) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[None, Tuple[None, ...]], + ) -> Union[None, Tuple[None, ...]]: + num_values = reduction_num_outputs(reduction_type) + return (None,) * num_values if num_values > 1 else None + + +def _typecheck_FreeUnbackedSymbolsOpsHandler( + h: FreeUnbackedSymbolsOpsHandler, +) -> OpsHandler[None]: + return h + + +def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None): + from .ir import FlexibleLayout + + args = [index, rindex] if rindex is not None else [index] + handler = FreeUnbackedSymbolsOpsHandler() + # NB: I cargo culted the allow_indexing patch here, I don't understand why + # people do this all over + with V.set_ops_handler(handler), patch.object( + FlexibleLayout, "allow_indexing", True + ): + fn(*args) + return handler.symbols diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/exc.py b/.venv/lib/python3.11/site-packages/torch/_inductor/exc.py new file mode 100644 index 0000000000000000000000000000000000000000..728f652032d52affff0bfe216e920a4a97cd51d3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/exc.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import os +import tempfile +import textwrap +from functools import lru_cache + + +if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1": + + @lru_cache(None) + def _record_missing_op(target): + with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd: + fd.write(str(target) + "\n") + +else: + + def _record_missing_op(target): # type: ignore[misc] + pass + + +class OperatorIssue(RuntimeError): + @staticmethod + def operator_str(target, args, kwargs): + lines = [f"target: {target}"] + [ + f"args[{i}]: {arg}" for i, arg in enumerate(args) + ] + if kwargs: + lines.append(f"kwargs: {kwargs}") + return textwrap.indent("\n".join(lines), " ") + + +class MissingOperatorWithoutDecomp(OperatorIssue): + def __init__(self, target, args, kwargs) -> None: + _record_missing_op(target) + super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}") + + +class MissingOperatorWithDecomp(OperatorIssue): + def __init__(self, target, args, kwargs) -> None: + _record_missing_op(target) + super().__init__( + f"missing decomposition\n{self.operator_str(target, args, kwargs)}" + + textwrap.dedent( + f""" + + There is a decomposition available for {target} in + torch._decomp.get_decompositions(). Please add this operator to the + `decompositions` list in torch._inductor.decomposition + """ + ) + ) + + +class LoweringException(OperatorIssue): + def __init__(self, exc: Exception, target, args, kwargs) -> None: + super().__init__( + f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}" + ) + + +class SubgraphLoweringException(RuntimeError): + pass + + +class InvalidCxxCompiler(RuntimeError): + def __init__(self) -> None: + from . import config + + super().__init__( + f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}" + ) + + +class CppWrapperCodeGenError(RuntimeError): + def __init__(self, msg: str) -> None: + super().__init__(f"C++ wrapper codegen error: {msg}") + + +class CppCompileError(RuntimeError): + def __init__(self, cmd: list[str], output: str) -> None: + if isinstance(output, bytes): + output = output.decode("utf-8") + + super().__init__( + textwrap.dedent( + """ + C++ compile error + + Command: + {cmd} + + Output: + {output} + """ + ) + .strip() + .format(cmd=" ".join(cmd), output=output) + ) + + +class CUDACompileError(CppCompileError): + pass diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/extern_node_serializer.py b/.venv/lib/python3.11/site-packages/torch/_inductor/extern_node_serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed505d3e60ee9896aed5953ab114779ecd54dcd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/extern_node_serializer.py @@ -0,0 +1,25 @@ +import json +from typing import List + +from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node +from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder +from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode + + +def serialize_extern_kernel_node( + extern_kernel_node: inductor_ExternKernelNode, +) -> ExternKernelNode: + assert isinstance(extern_kernel_node.node, Node) + return ExternKernelNode( + name=extern_kernel_node.name, + node=extern_kernel_node.node, + ) + + +def extern_node_json_serializer( + extern_kernel_nodes: List[inductor_ExternKernelNode], +) -> str: + serialized_nodes = ExternKernelNodes( + nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes] + ) + return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/freezing.py b/.venv/lib/python3.11/site-packages/torch/_inductor/freezing.py new file mode 100644 index 0000000000000000000000000000000000000000..9b936bf7bb1b1b4ef00195874a8d51fc737d0b16 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/freezing.py @@ -0,0 +1,269 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import itertools +import logging +import weakref +from typing import Any, List, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code +from torch._functorch.aot_autograd import MutationType +from torch._functorch.compile_utils import fx_graph_cse +from torch._inductor.constant_folding import constant_fold, replace_node_with_constant +from torch._inductor.fx_passes.freezing_patterns import freezing_passes +from torch._inductor.fx_passes.post_grad import view_to_reshape + +from . import config + + +aten = torch.ops.aten +prims = torch.ops.prims + +log = logging.getLogger(__name__) + + +def replace_params_with_constants( + gm: torch.fx.GraphModule, + flat_params: list[Any], + fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta, +) -> List[int]: + """ + Replaces the parameters of a PyTorch GraphModule with constants wherever possible. + Returns a list of indices representing the input parameters that were not converted to constants. + """ + params = gm.graph.find_nodes(op="placeholder") + fake_inp_nodes = params[: len(params)] + preserved_arg_indices = [] + aliased_input_args = [ + out_info.base_idx + for out_info in fw_metadata.output_info + if out_info.base_idx is not None + ] + + # TODO (tmanlaibaatar) figure out why this is different + # from mutated_inp_runtime_indices + mutated_inps = [ + i + for i, m in enumerate(fw_metadata.input_info) + if m.mutation_type + in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH) + ] + + for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)): + if i in mutated_inps or i in aliased_input_args: + preserved_arg_indices.append(i) + continue + replace_node_with_constant(gm, node, real_input) + # add on non param inputs + preserved_arg_indices.extend(range(len(flat_params), len(params))) + # is this necessary ? + gm.recompile() + return preserved_arg_indices + + +def freeze( + dynamo_gm: torch.fx.GraphModule, + aot_autograd_gm: torch.fx.GraphModule, + example_inputs: List[torch._subclasses.FakeTensor], +) -> Tuple[torch.fx.GraphModule, List[int]]: + """ + Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation + and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency. + + Assumes that this function is run in dynamo tracing post aot_autograd. + + Args: + dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule. + aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen. + example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process. + + Returns: + Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices + of the inputs that were preserved (not turned into constants). + """ + # We have convert conv's weight to channels last which may meet error for .view + # when doing fake_tensor_prop. So we need to convert view to reshape first. + # See the details in fx_codegen_and_compile of compile_fx.py. + view_to_reshape(aot_autograd_gm) + + if tracing_context := torch._guards.TracingContext.try_get(): + fw_metadata = tracing_context.fw_metadata + params_flat = tracing_context.params_flat + assert fw_metadata is not None and params_flat is not None + + preserved_arg_indices = replace_params_with_constants( + aot_autograd_gm, params_flat, fw_metadata + ) + else: + inputs = aot_autograd_gm.graph.find_nodes(op="placeholder") + preserved_arg_indices = list(range(len(inputs))) + + # TODO - further restrict cse ? right now needed to dedup aliasing ops + cse_graph = fx_graph_cse(aot_autograd_gm.graph) + aot_autograd_gm.graph = cse_graph + aot_autograd_gm.recompile() + + aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices] + freezing_passes(aot_autograd_gm, aot_example_inputs) + + constant_fold(aot_autograd_gm) + # invalidate nn Modules + if config.freezing_discard_parameters: + invalidate_eager_modules() + discard_traced_gm_params(dynamo_gm) + + log.debug( + "%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm, colored=True) + ) + + return aot_autograd_gm, preserved_arg_indices + + +class ErasedTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem, name, owning_mod): + return super().__new__(cls, elem.to(device="meta")) + + def __init__(self, elem, name: Optional[str], mod) -> None: + self.erased_name = name + self.owning_mod_ref = weakref.ref(mod) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + erased_tensors = [ + e + for e in pytree.arg_tree_leaves(*args, **kwargs) + if isinstance(e, ErasedTensor) + ] + assert len(erased_tensors) > 0 + e = erased_tensors[0] + + raise RuntimeError( + f"Trying to run Pytorch Eager Module after Dynamo Freezing. " + "The original parameters have been discarded for memory efficiency. " + f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}" + ) + + +def invalidate_eager_modules(): + with torch.utils._python_dispatch._disable_current_modes(): + for ( + mod + ) in torch._guards.TracingContext.get().module_context.nn_modules.values(): + if not isinstance(mod, torch.nn.Module): + continue + + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), + mod.named_buffers(recurse=False), + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True # type: ignore[attr-defined] + setattr(mod, attr_name, e_t) + + +def discard_traced_gm_params(mod: torch.fx.GraphModule): + with torch.utils._python_dispatch._disable_current_modes(): + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), mod.named_buffers(recurse=False) + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True # type: ignore[attr-defined] + setattr(mod, attr_name, e_t) + + +def enforce_output_layout(gm: torch.fx.GraphModule): + """ + Make sure the output node's layout does not change due to compiler optimizations + by adding aten.as_strided nodes with the expected strides. + + Only used for inference so we can assume all graph outputs are model outputs. + """ + *_, output_node = gm.graph.nodes + out_list = output_node.args[0] + with gm.graph.inserting_before(output_node): + for n in out_list: + if not isinstance( + n.meta["val"], torch.Tensor + ) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]): + continue + + # add a node to enforce eager layout + ft = n.meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n, ft.stride()) + ) + + # can not call + # n.replace_all_uses_with(new_node) + # since it will replace the usage of n in new_node itself. + output_node.replace_input_with(n, new_node) + + gm.graph.lint() + gm.recompile() + + +def enforce_as_strided_input_layout(gm: torch.fx.GraphModule): + """ + Make sure the as_strided node's input's layout does not change due to compiler + optimizations, because the as_strided strides info depends on input tensor stride info. + """ + + as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + ] + strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops] + for n in strided_nodes: + with gm.graph.inserting_before(n): + # add a node to enforce eager layout + ft = n.args[0].meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n.args[0], ft.stride()) + ) + n.replace_input_with(n.args[0], new_node) + + gm.graph.lint() + gm.recompile() + + +def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule): + """ + Convert 4d convolution weight tensor to channels last format. + + This pass is performed before freezing so the added nodes can be constant + folded by freezing. + """ + with dynamo_timed("convert_conv_weights_to_channels_last"): + convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default] + for conv in convs: + weight_node = conv.args[1] + if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[ + "val" + ].is_contiguous(memory_format=torch.channels_last): + # not a 4d tensor or already channels last, skip + continue + + with gm.graph.inserting_before(conv): + new_node = gm.graph.call_function( + aten.clone.default, + (weight_node,), + {"memory_format": torch.channels_last}, + ) + conv.replace_input_with(weight_node, new_node) + + enforce_as_strided_input_layout(gm) + enforce_output_layout(gm) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1791004996204a0e2cb201c4a706f2e99408086 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py @@ -0,0 +1,251 @@ +# mypy: allow-untyped-defs +import operator +from collections import defaultdict +from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type + +import sympy + +import torch +import torch.fx +from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + statically_known_true, + sym_eq, +) +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map + +from .virtualized import V + + +# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched. +# Works for length 2 patterns with 1 module and 1 function/method. +def matches_module_function_pattern( + pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]], + node: torch.fx.node.Node, + modules: Dict[str, torch.nn.modules.Module], +) -> bool: + if len(node.args) == 0: + return False + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node, torch.fx.Node + ): + return False + # the first node is call_module + if node.args[0].op != "call_module": + return False + if not isinstance(node.args[0].target, str): + return False + if node.args[0].target not in modules: + return False + if type(modules[node.args[0].target]) is not pattern[0]: + return False + # the second node is call_function or call_method + if node.op != "call_function" and node.op != "call_method": + return False + if node.target != pattern[1]: + return False + # make sure node.args[0] output is only used by current node. + if len(node.args[0].users) > 1: + return False + return True + + +class FakeTensorUpdater: + """ + The main idea here is that it's difficult to maintain accurate fake + tensors (our primary form of metadata) for each node in our graph as we + transform it. + + The most reliable way to obtain this information is by rerunning + faketensor propagation. However, in general, faketensor propagation is + fairly expensive. So, instead we'd like to only rerun faketensor + propagation on nodes that have changed. + + In order to detect which nodes have changed, we first hash its node, + target, and argument lists (which are immutable in FX). + + Then, whenever we call incremental_update, we check which FX nodes have a + new hash, and recompute the faketensor metadata for that node. Then, we + continue to recursively compute the faketensors for all users until the + fake tensors stop changing. + """ + + def __init__(self, graph: torch.fx.Graph) -> None: + self.processed_hashes = set() + self.graph = graph + + for node in self.graph.nodes: + self.processed_hashes.add(self.hash_node(node)) + + def hash_node(self, node: torch.fx.Node): + # todo(chilli): Not a great hash function + return (node, node.target, id(node.args), id(node.kwargs)) + + def incremental_update(self): + processed = set() + existing_storages: DefaultDict[Optional[int], int] = defaultdict(int) + for node in self.graph.nodes: + existing_storages[get_node_storage(node)] += 1 + + def is_intlist_same(new, old): + return statically_known_true(sym_eq(new, old)) + + def is_fake_tensor_same(new, old): + if type(new) != type(old): + return False + if isinstance(new, (list, tuple)): + if len(new) != len(old): + return False + return all( + is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old) + ) + if new is None: + return old is None + if not isinstance(new, torch.Tensor): + assert isinstance( + new, (torch.SymInt, torch.SymBool, torch.SymFloat) + ), f"Unknown type {type(new)} in {self.graph}" + return ( + new.node.shape_env._maybe_evaluate_static( + sympy.Eq(new.node.expr, old.node.expr) + ) + == sympy.true + ) + if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout: + return False + if new.layout == torch.strided and ( + not is_intlist_same(new.stride(), old.stride()) + or not statically_known_true( + new.storage_offset() == old.storage_offset() + ) + ): + return False + + if new.device != old.device: + return False + + if get_storage(new) == get_storage(old): + return True + + # This is the case where it returns a completely fresh storage that's used nowhere else. + if ( + existing_storages[get_storage(old)] == 1 + and get_storage(new) not in existing_storages + ): + return True + return False + + def should_process_node(node): + # node.target for nodes returning true from this function + # are called under fake mode and does not work for inductor + # lowerings. We check if the node.target is an aten operator + # or operator.getitem which is used when returning multiple + # tensors from an op. + return node.op == "call_function" and ( + isinstance(node.target, torch._ops.OpOverload) + or node.target == operator.getitem + ) + + to_process = set() + for node in self.graph.nodes: + if ( + self.hash_node(node) in self.processed_hashes + and id(node) not in to_process + ): + continue + + if not should_process_node(node): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + with V.fake_mode: + new_fake_tensor = node.target(*args, **kwargs) + if "val" in node.meta and is_fake_tensor_same( + new_fake_tensor, node.meta["val"] + ): + continue + + rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor) + + node.meta["val"] = new_fake_tensor + if (shape_env := V.fake_mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor) + ): + # Refresh the bindings to the new symbols + node.meta["unbacked_bindings"] = symbol_to_path + + existing_storages[get_node_storage(node)] += 1 + + to_process.update([id(user) for user in node.users]) + + self.processed_hashes.add(self.hash_node(node)) + + +def get_storage(t: torch.Tensor) -> int: + return t.untyped_storage()._cdata + + +def get_node_storage(node: torch.fx.Node) -> Optional[int]: + if "val" not in node.meta: + return None + if not isinstance(node.meta["val"], torch.Tensor): + return None + if not torch._C._has_storage(node.meta["val"]): + return None + return get_storage(node.meta["val"]) + + +def get_fake(x): + if isinstance(x, torch.fx.Node): + if "val" not in x.meta: + return x + return x.meta["val"] + return x + + +def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]: + """ + First value returns a boolean if any of the input nodes don't have a faketensor. + """ + args, kwargs = tree_map(get_fake, (x.args, x.kwargs)) + if any( + isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs) + ): + return False, args, kwargs + return True, args, kwargs + + +def is_node_realized(node: torch.fx.Node) -> bool: + """Returns true if a node is always realized when lowered to inductor IR. + + NOTE: This may return some false negatives. e.g. it doesn't + handle buffers realized heuristically during lowering, or + buffers realized indirectly through view ops. + """ + from torch._inductor.lowering import fallbacks, needs_realized_inputs + + def is_buffer(node: torch.fx.Node) -> bool: + if node.op == "call_function" and node.target is operator.getitem: + # For nodes with multiple outputs, we get the fx graph: + # foo = torch.ops.aten.foo(...) + # getitem = foo[0] + # getitem_1 = foo[1] + # where we need to check if foo is a fallback kernel + return is_buffer(node.args[0]) # type: ignore[arg-type] + return node.op in ("placeholder", "output") or node.target in fallbacks + + if is_buffer(node): + return True + + def realizes_inputs(node: torch.fx.Node) -> bool: + return node.op == "output" or node.target in needs_realized_inputs + + if any(realizes_inputs(user) for user in node.users): + return True + + # Otherwise, assume node isn't realized + return False diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py b/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..3448ea6eb9620ad0094a7c86aa35244574f1ff2b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py @@ -0,0 +1,1930 @@ +import functools +import itertools +import logging +import operator +import os +import re +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from types import ModuleType +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Iterable, + List, + NoReturn, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy +from sympy import Expr + +import torch +import torch._logging +import torch.fx +from torch import device, Tensor +from torch._decomp import get_decompositions +from torch._dynamo.utils import defake, dynamo_timed +from torch._logging import LazyString, trace_structured +from torch._prims_common import make_channels_last_strides_for +from torch._subclasses.fake_tensor import FakeTensor +from torch.fx import GraphModule +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import ( + free_unbacked_symbols, + has_free_symbols, + resolve_unbacked_bindings, + RuntimeAssert, + ShapeEnv, + SymTypes, +) +from torch.fx.graph import Graph +from torch.fx.node import Node +from torch.utils._mode_utils import no_dispatch +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo + +from . import config, ir +from .codegen.common import ( + BackendFeature, + DeviceOpOverrides, + get_backend_features, + get_device_op_overrides, + get_wrapper_codegen_for_device, + init_backend_registration, +) +from .exc import ( + CppWrapperCodeGenError, + LoweringException, + MissingOperatorWithDecomp, + MissingOperatorWithoutDecomp, +) +from .ir import ( + Constant, + FixedLayout, + get_device_type, + InputBuffer, + Pointwise, + Reduction, + StorageBox, + TensorBox, + TorchBindObject, +) +from .lowering import ( + FALLBACK_ALLOW_LIST, + fallback_handler, + fallback_node_due_to_unsupported_type, + lowerings, + make_fallback, + maybe_layout_constraints, + needs_realized_inputs, + unsupported_output_tensor, +) +from .scheduler import BaseSchedulerNode +from .sizevars import SizeVarAllocator +from .utils import ( + convert_shape_to_inductor, + gather_origins, + get_cloned_parameter_buffer_name, + get_sympy_Expr_dtype, + maybe_get_suppress_shape_guards_ctx, + should_assume_input_aligned, +) +from .virtualized import NullHandler, V + + +if TYPE_CHECKING: + from torch._higher_order_ops.effects import _EffectType + from .codegen.wrapper import WrapperCodeGen + +from torch._inductor.codecache import output_code_log + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") + +aten = torch.ops.aten + +_post_grad_graph_counter = itertools.count() + +if config.is_fbcode(): + from torch._inductor.fb.utils import log_module_code +else: + + def log_module_code(*args: Any, **kwargs: Any) -> None: + pass + + +def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool: + supported_dtype = { + torch.float32, + torch.float64, + torch.int64, + torch.int32, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + torch.bfloat16, + torch.complex32, + torch.complex64, + torch.complex128, + torch.float16, + } + if cuda: + supported_dtype.add(torch.float8_e4m3fn) + supported_dtype.add(torch.float8_e5m2) + supported_dtype.add(torch.float8_e4m3fnuz) + supported_dtype.add(torch.float8_e5m2fnuz) + + return dtype in supported_dtype + + +def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]: + assert isinstance( + constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) + ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer" + if isinstance(constant_buffer, sympy.core.numbers.Integer): + return torch.int64 + + if isinstance(constant_buffer, sympy.Expr): + return get_sympy_Expr_dtype(constant_buffer) + + if constant_buffer.is_integer: + return torch.int64 + elif constant_buffer.is_float: + return torch.float32 + else: + return None + + +def is_magic_method(op: Any) -> bool: + magic_ops = {method_to_operator(m) for m in magic_methods} + return op in magic_ops + + +def getattr_recursive( + obj: GraphModule, target: str +) -> Union[Tensor, torch._C.ScriptObject, GraphModule]: + target_atoms = target.split(".") + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def mark_nodes_dislike_padding( + g: Graph, user_visible_outputs: Optional[Dict[str, None]] +) -> None: + """ + Nodes like convolution/convolution_backward want its input to be dense. + If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction. + + The pass finds nodes that dislike padding. These are nodes that can be reached + from a convolution/convolution_backward in the backward direction without + going thru a reduction. + """ + if not config.comprehensive_padding: + return + ops_dislike_padding = { + aten.convolution, + aten.convolution_backward, + } + # what's a better way to collect the reduction ops? + ops_like_padding = { + aten.var_mean, + aten.sum, + aten.mean, + aten.prod, + aten.any, + aten.amin, + aten.amax, + aten.min, + aten.max, + aten.argmin, + aten.argmax, + aten.scatter_reduce, + } + + def _get_overload_packet( + node: torch.fx.Node, + ) -> Optional[torch._ops.OpOverloadPacket]: + return ( + node.target._overloadpacket + if node.op == "call_function" + # hasattr on OpOverloadPacket is slow, do isinstance first + and isinstance(node.target, torch._ops.OpOverload) + and hasattr(node.target, "_overloadpacket") + else None + ) + + for cur in reversed(g.nodes): + op = _get_overload_packet(cur) + if not op: + continue + if op in ops_dislike_padding: + cur.meta["dislike_padding"] = True + + if cur.meta.get("dislike_padding", False): + # propagate + for prior in cur.all_input_nodes: + prior_op = _get_overload_packet(prior) + if not prior_op: + continue + if prior_op not in ops_like_padding: + prior.meta["dislike_padding"] = True + # We only want to mark output nodes. So, move it after the above prior nodes process. + if ( + not config.pad_outputs + and user_visible_outputs + and cur.name in user_visible_outputs + ): + cur.meta["dislike_padding"] = True + + +class GraphLowering(torch.fx.Interpreter): + graph_outputs: List[ir.IRNode] + + def symbolic_sizes_strides( + self, ex: torch.Tensor + ) -> Tuple[Union[List[int], List[Expr]], Union[List[int], List[Expr]]]: + """ + Support dynamic shapes and dynamic strides by assigning variables + to each dimension. We duck-shape tensors, so if two tensors + have the same size they get assigned the same symbolic variable. + """ + if self.reuse_shape_env: + return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor( + ex.stride() + ) + else: + from torch._dynamo.source import ConstantSource + + # TODO: this should not be needed once #93059 lands + # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816 + # TODO: make a dedicated UnknownSource for this? + # NB: This is using the legacy default behavior from + # create_symbolic_sizes_strides_storage_offset but we hope we can + # just delete this entirely + source = ConstantSource( + f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}" + ) + ( + size, + stride, + _, + ) = self._shape_env.create_symbolic_sizes_strides_storage_offset( + ex, + source, + ) + + size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size] + stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride] + return size, stride + + def static_sizes_strides( + self, ex: torch.Tensor + ) -> Tuple[List[sympy.Expr], List[sympy.Expr]]: + """ + Primarily used to weights + """ + size = [sympy.Integer(i) for i in ex.size()] + stride = [sympy.Integer(i) for i in ex.stride()] + return size, stride + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: Optional[List[torch.Tensor]] = None, + shape_env: Optional[ShapeEnv] = None, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + user_visible_outputs: Optional[Dict[str, None]] = None, + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[ + Callable[[List[ir.ExternKernelNode]], Any] + ] = None, + is_inference: bool = False, + is_const_graph: bool = False, + const_output_index: Optional[Dict[str, int]] = None, + const_code: Optional[str] = None, + const_module: Optional["GraphLowering"] = None, + name: Optional[str] = None, + ) -> None: + super().__init__(gm) + self.example_inputs = example_inputs + self.layout_opt = ( + layout_opt + if layout_opt is not None + else self.decide_layout_opt(gm, is_inference=is_inference) + ) + self.num_channels_last_conv = 0 + self.is_inference = is_inference + self.is_const_graph = is_const_graph + self.const_code = const_code + self.const_module = const_module + + self.extra_traceback = False # we do our own error wrapping + if shape_env is None: + shape_env = ShapeEnv() + self.reuse_shape_env = False + else: + self._shape_env = shape_env + self.reuse_shape_env = True + self._shape_env = shape_env + # We are going to start code generating runtime asserts, so make sure + # you don't start adding new ones in the lowering process + shape_env.freeze_runtime_asserts() + # We're going to mutate ras_by_symbol as we finish generating them + self.ras_by_symbol: Dict[ + sympy.Symbol, List[RuntimeAssert] + ] = shape_env.deferred_runtime_asserts.copy() + self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet() + self.sizevars = SizeVarAllocator(shape_env) + self.graph_input_names: List[str] = [] + self.graph_inputs: Dict[str, TensorBox] = {} + self.graph_inputs_original: Dict[str, InputBuffer] = {} + self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet() + self.device_types: OrderedSet[str] = ( + const_module.device_types if const_module else OrderedSet() + ) + self.device_idxs: OrderedSet[int] = ( + const_module.device_idxs if const_module else OrderedSet() + ) + self.cuda = False + self.buffers: List[ir.Buffer] = [] + self.operations: List[ir.Operation] = [] + self.const_output_index: Dict[str, int] = ( + const_output_index if const_output_index else {} + ) + self.folded_constants: OrderedSet[str] = ( + OrderedSet(const_output_index.keys()) + if const_output_index + else OrderedSet() + ) + self.constants: Dict[str, torch.Tensor] = ( + const_module.constants if const_module else {} + ) + self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {} + self.constant_reprs: Dict[str, str] = {} + self.removed_operations: OrderedSet[str] = OrderedSet() + self.removed_buffers: OrderedSet[str] = OrderedSet() + self.removed_inplace_buffers: OrderedSet[str] = OrderedSet() + self.mutated_buffers: OrderedSet[str] = OrderedSet() + self.never_reuse_buffers: OrderedSet[str] = OrderedSet() + self.inplaced_to_remove: OrderedSet[str] = OrderedSet() + self.device_ops: DeviceOpOverrides = None # type: ignore[assignment] + self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment] + # See `ProxyExecutor Design Note` in ir.py for more details + self.extern_kernel_nodes: List[ir.ExternKernelNode] = [] + + from torch._inductor.extern_node_serializer import extern_node_json_serializer + + self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = ( + extern_node_serializer + if config.is_fbcode() and extern_node_serializer + else extern_node_json_serializer + ) + + self.current_node: torch.fx.Node = None # type: ignore[assignment] + self.lists: Dict[str, List[str]] = {} + self.mutated_inputs: OrderedSet[str] = OrderedSet() + self.mutated_input_idxs: List[int] = [] + self.name_to_buffer: Dict[str, ir.Buffer] = {} + self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list) + self.name_to_op: Dict[str, ir.Operation] = {} + self.creation_time = time.time() + self.name = name # type: ignore[assignment] + self.cpp_wrapper = cpp_wrapper + + # record multi_kernel choice for cpp_wrapper so the second pass knows + # which sub-kernel is picked. Copy cpp_wrapper to another variable + # since cpp_wrapper flag is OrderedSet to false for the first pass of codegen. + self.record_multi_kernel_choice = cpp_wrapper + self.multi_kernel_to_choice: Dict[str, int] = {} + + self.aot_mode = aot_mode + self.graph_id = graph_id + self.post_grad_graph_id = next(_post_grad_graph_counter) + self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment] + self.nodes_prefer_channels_last = ( + self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet() + ) + self._warned_fallback = {"aten.convolution_backward"} + self.user_visible_outputs = ( + user_visible_outputs if user_visible_outputs is not None else {} + ) + mark_nodes_dislike_padding(gm.graph, user_visible_outputs) + self.cache_key: str = "" # This is the cache key for the compiled artifact + self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored + self.cache_linemap: List[ + Tuple[int, str] + ] = ( + [] + ) # This is the linemap used by the profiler to mark custom compiled kernels getting run + # Used if lowering encounters cases where cudagraphs are not supported + self.disable_cudagraphs_reason: Optional[str] = None + + # only keeping one node per device for stack trace purposes + self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {} + self.orig_gm: torch.fx.GraphModule = gm.__copy__() + self.dynamo_flat_name_to_original_fqn = self.module.meta.get( + "dynamo_flat_name_to_original_fqn", {} + ) + self.allocated_constant_name: Dict[str, str] = ( + const_module.allocated_constant_name if const_module is not None else {} + ) + init_backend_registration() + self.get_backend_features = functools.lru_cache(None)(get_backend_features) + + self.effectful_ops: Dict[_EffectType, ir.Buffer] = {} + self.aligned_inputs: OrderedSet[str] = OrderedSet() + self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet() + + # Below field is related to printing debug intermediate tensor values info for debugging + self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet() + + def has_feature( + self, device: Union[torch._inductor.ir.IRNode, device], feature: BackendFeature + ) -> bool: + assert isinstance(feature, BackendFeature), feature + return feature in self.get_backend_features(get_device_type(device)) + + @staticmethod + def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool: + """ + Decide if we should enable layout optimization for this graph based on + heuristics. + """ + if not config.layout_optimization: + return False + + if config.force_layout_optimization: + return True + + conv_nodes = [ + n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default + ] + nconv = len(conv_nodes) + + if nconv == 0: + return False + + # For cpu backend and mkldnn enabled, we always use channels_last for better performance. + if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and all( + n.args[idx].meta["val"].device == torch.device("cpu") + for n in conv_nodes + for idx in [0, 1] + ) + ): + return True + + # Following models are skipped due to this: + # jx_nest_base + # volo_d1_224 + if len(list(gm.graph.nodes)) >= 300 * nconv: + log.debug("Skipped layout opt because only a few conv") + return False + + if any( + has_free_symbols(n.args[idx].meta["val"]) + for n in conv_nodes + for idx in [0, 1] + ): + log.debug( + "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670" + ) + return False + + def is_grouped(n: Any) -> bool: + meta_val = n.args[1].meta["val"] # type: ignore[union-attr, operator] + assert isinstance(meta_val, torch.Tensor) + return n.args[-1] > 1 and meta_val.size(1) > 1 # type: ignore[union-attr, operator] + + def is_in_out_channel(n: torch.fx.Node) -> bool: + return ( + n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1) # type: ignore[union-attr, operator] + and n.args[1].meta["val"].size(2) > 1 # type: ignore[union-attr, operator] + ) + + def is_small_channel(n: torch.fx.Node) -> bool: + return ( + n.args[1].meta["val"].size(0) <= 64 # type: ignore[union-attr, operator] + and n.args[1].meta["val"].size(1) <= 64 # type: ignore[union-attr, operator] + ) + + # only grouped convolutions benchmarked as slower in conv samples for inference only + if is_inference: + from torch.utils.flop_counter import FlopCounterMode + + flop_counts: Dict[str, float] = defaultdict(float) + for node in conv_nodes: + success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs( + node + ) + + if success: + with FlopCounterMode(display=False) as flop_counter_mode: + with V.fake_mode: + node.target(*args, **kwargs) + + counted_flops = flop_counter_mode.get_total_flops() + if is_grouped(node): + node_type = "grouped" + elif is_small_channel(node): + node_type = "small" + elif is_in_out_channel(node): + node_type = "in_out" + else: + node_type = "default" + + flop_counts[node_type] += counted_flops + else: + log.debug("Conv inputs meta not found") + + # average benchmarked channels last speedup / slowdown, < 1 is speedup. + # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/ + # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb + GROUPED_MULTIPLIER = 1.358 + DEFAULT_MULTIPLIER = 0.823 + IN_OUT_MULTIPLIER = 0.725 + SMALL_MULTIPLIER = 0.783 + + total_flops = sum(flop_counts.values()) + # TODO - get different values per hardware + weighted_flops = ( + flop_counts["grouped"] * GROUPED_MULTIPLIER + + flop_counts["small"] * SMALL_MULTIPLIER + + flop_counts["in_out"] * IN_OUT_MULTIPLIER + + flop_counts["default"] * DEFAULT_MULTIPLIER + ) + do_layout_opt = weighted_flops <= total_flops + if not do_layout_opt: + log.debug( + "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d", + total_flops, + weighted_flops, + ) + return do_layout_opt + + # Channels last layout can dramatically hurt grouped conv perf. E.g. + # Conv with arguments like + # {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 2} + # slows down 31x using channels last.. + + # But a lot of timm models use depthwise separable convolution which will + # result in grouped convolution with in-channel size == 1. + # For those grouped convolution, channels last still helps a lot. + # E.g. + # Conv with arguments + # {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 58} + # get 1.86x speedup with channels last layout. + # + # The following heuristics skip using channels-last if the model contains + # grouped convolution with in-channels > 1. + if any(map(is_grouped, conv_nodes)): + log.debug( + "Skip layout opt because found grouped convolution with >1 in_channels!" + ) + return False + + # For some models that contain convolution with larger in-channel than out-channel, applying + # channels last hurts performance. + # Following models are skipped due to this: + # - pytorch_unet + # - phlippe_densenet (slightly worse) + # - Background_Matting (1.22x -> 0.821x) + # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x) + if any(map(is_in_out_channel, conv_nodes)): + log.debug( + "Skip layout opt because some convolutions have smaller out_channel" + ) + return False + + # Following models are skipped due to this: + # - functorch_maml_omniglot + if all(map(is_small_channel, conv_nodes)): + log.debug("Skip layout opt because all convolution channels are too small") + return False + + return True + + def qualify_name(self, name: str) -> str: + """Prepend the given name with the graph name if any.""" + if self.name is not None: + return f"{self.name}_{name}" + return name + + def make_subgraph( + self, + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + subgraph_name: str, + ) -> "GraphLowering": + """ + Make a subgraph of the current graph with all inherited + parts, except the graph module (`gm`) and `example_inputs`. + The subgraphs are lowered separately, but intended to be + inlined in the parent graph's codegening. Hence the need + for maintaining the same `shape_env` and other properties. + The subgraph name is qualified by the parent graph's name. + """ + return GraphLowering( + gm=gm, + example_inputs=example_inputs, + shape_env=self._shape_env, + cpp_wrapper=self.cpp_wrapper, + aot_mode=self.aot_mode, + extern_node_serializer=self.extern_node_serializer, + is_inference=self.is_inference, + name=self.qualify_name(subgraph_name), + ) + + def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]: + """ + The rule to decide if an node prefer channels last is simple. + 1. if it's input/output of a convolution + 2. if one of its user prefers channels last + + We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs; + Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers + channels last. + + Consider the scenario: conv -> batch-norm -> relu -> conv + Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies: + 1. the output of batch-norm should be channels last initially since its input is a conv's output. + Forcing the batch-norm's output to be contiguous results in the first copy + 2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output. + We need convert it to channels last layout which results in the second copy. + With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies + can be saved. + """ + output_set: OrderedSet[Node] = OrderedSet() + for n in reversed(self.module.graph.nodes): + if n.target == torch.ops.aten.convolution.default: + output_set.add(n) + continue + + for user in n.users: + if user in output_set: + output_set.add(n) + break + + # need a second pass to add downstream nodes of those channel last nodes to the sets. + # This pass is especially needed to avoid mix-layout kernel inputs in backward pass. + # + # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned + # from the fwd graph. Without this second pass, we will force relu's output to be contiguous. + # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last + # tensors and passed to a kernel. + # + # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x. + # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x . + # This also helps the following models: + # - res2net101_26w_4s + # - res2net50_14w_8s + # - sebotnet33ts_256 + for n in self.module.graph.nodes: + if n in output_set: + output_set.update(n.users) + + return output_set + + def warn_fallback(self, name: str) -> None: + if name not in self._warned_fallback: + self._warned_fallback.add(name) + perf_hint_log.info("Using FallbackKernel: %s", name) + + def add_device_info(self, device: torch.device) -> None: + self.device_types.add(device.type) + if device.index is not None: + self.device_idxs.add(device.index) + if V.graph.current_node and device not in self.device_node_mapping: + self.device_node_mapping[device] = V.graph.current_node + + @property + def fake_mode(self) -> torch._subclasses.fake_tensor.FakeTensorMode: + return V.fake_mode + + def try_get_buffer( + self, buffer_name: str + ) -> Optional[Union[ir.TensorBox, ir.Buffer]]: + if buffer_name in self.name_to_buffer: + return self.name_to_buffer[buffer_name] + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name] + if buffer_name in self.constants: + data = V.graph.constants[buffer_name] + return ir.ConstantBuffer( + buffer_name, + ir.FixedLayout( + data.device, data.dtype, *V.graph.static_sizes_strides(data) + ), + ) + + return None + + def get_buffer(self, buffer_name: str) -> Union[ir.TensorBox, ir.Buffer]: + buf = self.try_get_buffer(buffer_name) + if buf is not None: + return buf + raise RuntimeError(f"Failed to find buffer matching name {buffer_name}") + + def get_dtype(self, buffer_name: str) -> torch.dtype: + if buffer_name in self.constants: + return self.constants[buffer_name].dtype + if buffer_name in self.name_to_buffer: + return self.name_to_buffer[buffer_name].get_dtype() + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name].get_dtype() + m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name) + if m: + return self.get_dtype(m.group(1)) + raise KeyError(f"could not find {buffer_name}") + + def get_numel(self, buffer_name: str) -> Union[int, Expr]: + from .ir import MultiOutputLayout + + if buffer_name in self.constants: + return self.constants[buffer_name].numel() + if buffer_name in self.name_to_buffer: + buf = self.name_to_buffer[buffer_name] + if isinstance(getattr(buf, "layout", None), MultiOutputLayout): + return 1 + return buf.get_numel() + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name].get_numel() + raise KeyError(f"could not find {buffer_name}") + + def run(self, *args: Any) -> Any: # type: ignore[override] + with dynamo_timed("GraphLowering.run"): + return super().run(*args) + + def register_operation(self, op: ir.Operation) -> str: + assert op.operation_name is None, f"Operation registered twice: {op}" + assert isinstance(op, ir.Operation) + name = self.qualify_name(f"op{len(self.operations)}") + self.operations.append(op) + self.name_to_op[name] = op + op.operation_name = name + return name + + def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: + name = self.qualify_name(f"buf{len(self.buffers)}") + self.buffers.append(buffer) + self.name_to_buffer[name] = buffer + if ( + # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144 + not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements()) + and buffer.get_device() is not None + ): + self.add_device_info(buffer.get_device()) + + if set_name: + buffer.name = name + return name + + def register_operation_list(self, operation_names: List[str]) -> str: + name = self.qualify_name("list_" + "_".join(operation_names)) + self.lists[name] = operation_names + return name + + def register_users_of( + self, node_output: Union[Iterable[ir.IRNode], ir.IRNode] + ) -> None: + def register(value: Union[Iterable[ir.IRNode], ir.IRNode]) -> None: + if isinstance(value, (list, tuple)): + for x in value: + register(x) + if isinstance(value, ir.TensorBox): + for read_name in value.get_read_names(): + self.name_to_users[read_name].append(value) + + register(node_output) + + def mark_buffer_mutated(self, name: str) -> None: + """ + When a buffer is mutated we need to make sure all the reads to + the old version are realized before the mutation happens. + """ + assert isinstance(name, str) + self.mutated_buffers.add(name) + + if name not in self.name_to_users: + return + + for user in self.name_to_users[name]: + user.realize() + + def get_original_value_of_constant(self, name: str) -> torch.Tensor: + """ + In AOTI, module buffers may have been mutated during the tracing and compilation. + Thus we need to read from previously stored original buffers, to make sure the + generated model.so uses correct initial values. + """ + assert name in self.allocated_constant_name and name in self.constants, ( + "Can not find the original value for " + name + ) + orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name]) + return ( + self.module.meta[orig_name] + if orig_name in self.module.meta + else self.constants[name] + ) + + def allocate_non_dup_const_name( + self, name: Optional[str], data: Union[Tensor] + ) -> str: + orig_name = name + if not config.aot_inductor.use_runtime_constant_folding: + for constant_name, value in self.constants.items(): + if ( + not data.is_mkldnn + and data.size() == value.size() + and data.stride() == value.stride() + and data.dtype == value.dtype + and data.device == value.device + and data.untyped_storage().data_ptr() + == value.untyped_storage().data_ptr() + and data.storage_offset() == value.storage_offset() + ): + return constant_name + + if name is None: + name = f"constant{len(self.constants)}" + assert name is not None + if name[0].isdigit(): + name = f"constant_{name}" + name = self.qualify_name(name) + # We may generate a var name for each constant in the codegen. + # Let's only keep sane characters. + prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name) + name = prefix + cnt = 0 + while name in self.constants: + name = f"{prefix}_{cnt}" + cnt += 1 + self.constants[name] = data + self.constant_reprs[name] = ( + f"{data.device!r} {data.dtype!r} " + f"{tuple(data.size())!r} {tuple(data.stride())!r} " + f"{hash(data):x}" + ) + self.allocated_constant_name[name] = orig_name # type: ignore[assignment] + return name + + def add_tensor_constant( + self, data: Tensor, name: Optional[str] = None + ) -> TensorBox: + new_name = self.allocate_non_dup_const_name(name, data) + return TensorBox.create( + ir.ConstantBuffer( + new_name, + FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)), + ) + ) + + def constant_name(self, name: str, device_override: Optional[torch.device]) -> str: + """ + We AOT copy constants to the devices they are needed on. + If device_override doesn't match the constant's device, then + copy it and return a different name. + """ + if self.constants[name].device == device_override or device_override is None: + return name + with torch.utils._python_dispatch._disable_current_modes(): + # caller might have OrderedSet fake tensor mode which will create a fake tensor + # when calling .to, so unset modes here + return self.allocate_non_dup_const_name( + f"{name}_{device_override.type}{device_override.index or 0}", + self.constants[name].to(device_override), + ) + + def placeholder( + self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override] + ) -> Union[Expr, TensorBox, None]: + example = super().placeholder(target, args, kwargs) # type: ignore[arg-type] + self.graph_input_names.append(target) + if isinstance(example, SymTypes): + expr = example.node.expr + self.graph_inputs[target] = expr + return expr + elif isinstance(example, (int, bool, float)): + expr = sympy.sympify(example) + self.graph_inputs[target] = expr + return expr + elif example is None: + return None + if isinstance(example, BackwardState): + # Ignored arg, must be unused + # Alternately we could filter this out in AotAutograd + return None + assert isinstance(example, torch.Tensor), example + # todo(chilli): We can remove the last check once we turn buffers into + # static shape tensors. That's a hack to workaround Inductor believing + # the buffer should be static but us passing in a fake tensor with + # symbolic shapes. + if not example._has_symbolic_sizes_strides: + # the first N inputs are weights + sizes, strides = self.static_sizes_strides(example) + else: + sizes, strides = self.symbolic_sizes_strides(example) # type: ignore[assignment] + # TODO(jansel): handle input aliasing + target = self.qualify_name(target) + tensor = TensorBox.create( + InputBuffer( + target, + FixedLayout(example.device, example.dtype, sizes, strides), + ) + ) + self.graph_inputs[target] = tensor + self.graph_inputs_original[target] = tensor.data.data + if self.current_node.users: # cudagraphs should work with an unused CPU input + self.add_device_info(example.device) + + # Note: [Input Alignment handling in Inductor] + # Alignment matters for generating efficient code. Some operations, + # e.g. vectorized loads, can only be performed on aligned inputs. + # + # But if we codegen assuming aligned inputs and then get unaligned + # inputs at runtime, then we are forced to clone - which is bad for + # both perf and memory usage. + # + # One option would be to guard on storage_offset%ALIGNMENT, and then + # codegen based on this. But storage_offset guards turned out to be + # expensive and cause recompiles; Instead, we're generating code + # based on the alignment of the example input without guarding. + with maybe_get_suppress_shape_guards_ctx(): + if should_assume_input_aligned(example): + self.aligned_inputs.add(target) + return tensor + + def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override] + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) + + # hasattr on OpOverloadPacket is slow, check isinstance first + if not isinstance(target, torch._ops.OpOverloadPacket) and hasattr( + target, "_inductor_lowering_function" + ): + # passthrough lowerings from .pattern_matcher + return target(*args, **kwargs) + + if target not in lowerings: + assert isinstance( + target, torch._ops.OpOverload + ), f"{target} is not an OpOverload" + base_name = target.name().split(".")[0] + if base_name in FALLBACK_ALLOW_LIST: + make_fallback(target) + elif config.implicit_fallbacks: + error = ( + MissingOperatorWithDecomp + if get_decompositions([target]) + else MissingOperatorWithoutDecomp + ) + log.info( + "Creating implicit fallback for:\n%s", + error.operator_str(target, args, kwargs), + ) + make_fallback(target) + + elif get_decompositions([target]): + # There isn't a good way to dynamically patch this in + # since AOT Autograd already ran. The error message tells + # the user how to fix it. + raise MissingOperatorWithDecomp(target, args, kwargs) + else: + raise MissingOperatorWithoutDecomp(target, args, kwargs) + + try: + log.debug(" via %s", lowerings[target]) # type: ignore[index] + out = lowerings[target](*args, **kwargs) # type: ignore[index] + return out + except Exception as e: + raise LoweringException(e, target, args, kwargs).with_traceback( + e.__traceback__ + ) from None + + @staticmethod + def can_inline_constant(t: torch.Tensor) -> bool: + """ + True if this is a small constant attr that will be inlined. + """ + return len(t.shape) == 1 and t.shape[0] <= 8 + + def get_attr( + self, target: str, args: Tuple[()], kwargs: Dict[str, object] # type: ignore[override] + ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]: + # this is a constant + value = getattr_recursive(self.module, target) # type: ignore[arg-type] + + if isinstance(value, torch.fx.GraphModule): + return ir.Subgraph(name=target, graph_module=value) + + if isinstance(value, torch._C.ScriptObject): + self.torchbind_constants[target] = value + self.constant_reprs[target] = "" + return TorchBindObject(target, value) + + assert isinstance(value, torch.Tensor) + if ( + config.aot_inductor.use_runtime_constant_folding + or config.always_keep_tensor_constants + or unsupported_output_tensor(value) + ): + return self.add_tensor_constant(value, target) + + with no_dispatch(): + if value.shape == (): + return Constant(value.item(), value.dtype, value.device) + if self.can_inline_constant(value): + log.debug("Inlining constant: %s ", str(target)) + # tensor lowering has constant inlining logic + from .lowering import tensor + + return tensor(value.tolist(), dtype=value.dtype, device=value.device) + + return self.add_tensor_constant(value, target) + + def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn: + raise AssertionError + + def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn: + raise AssertionError + + def output( + self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override] + ) -> None: + result = super().output(target, args, kwargs) # type: ignore[arg-type] + if not isinstance(result, (tuple, list)): + # nested subgraphs can have singleton outputs + result = (result,) + assert isinstance(result, (tuple, list)), type(result) + assert all( + isinstance( + x, + ( + TensorBox, + ir.Constant, + type(None), + ir.ConstantBuffer, + sympy.Expr, + sympy.logic.boolalg.Boolean, + int, + ir.EffectfulKernel, + ), + ) + for x in result + ), result + + fx_node_args = V.graph.current_node.args[0] # type: ignore[arg-type] + if not isinstance(fx_node_args, (tuple, list)): + # nested subgraphs can have singleton outputs + fx_node_args = (fx_node_args,) + result = [ir.ExternKernel.realize_input(x) for x in result] + result_correct_strides = [] + + assert len(fx_node_args) == len(result) + for r, fx_node in zip(result, fx_node_args): + if not isinstance(r, (ir.TensorBox, ir.BaseView)): + result_correct_strides.append(r) + else: + # AOT Autograd tries to detect stride divergence of inductor from output metadata. + # Here, we try to avoid spurious divergence by matching insignificant strides such as + result_correct_strides.append( + self.try_match_insignificant_strides( + r, fx_node.meta["val"].stride() + ) + ) + + self.graph_outputs = result_correct_strides + value: ir.IRNode + for name, value in self.graph_inputs.items(): + assert isinstance( + value, (TensorBox, sympy.Expr) + ), f"Unsupported inductor graph input type: {type(value)}" + if not isinstance(value, TensorBox): + continue + value.realize() + assert isinstance(value, TensorBox) + value = value.data + assert isinstance(value, ir.StorageBox) + value_storage_box = value + value = value.data + if not isinstance(value, InputBuffer) or value.get_name() != name: + # one of our inputs was mutated, need to turn that into a copy + ir.MutationLayoutSHOULDREMOVE.realize_into( + value, self.graph_inputs_original[name] + ) + # replace output with mutated input + try: + ind = self.graph_outputs.index(value_storage_box) + self.graph_outputs[ind] = self.graph_inputs_original[name] + except ValueError: + pass + + self.finalize() + log.debug( + "Force channels last inputs for %d conv for the current graph with id %d", + self.num_channels_last_conv, + self.graph_id if self.graph_id is not None else -1, + ) + + def finalize(self) -> None: + for buf in self.buffers: + buf.decide_layout() + + @contextmanager + def set_current_node(self, node: torch.fx.Node): # type: ignore[no-untyped-def] + old = self.current_node + try: + self.current_node = node + yield + finally: + self.current_node = old + + def try_match_insignificant_strides( + self, + tensor: Union[ir.TensorBox, ir.BaseView], + meta_strides_inp: Tuple[Union[int, torch.SymInt], ...], + ) -> Union[ir.TensorBox, ir.BaseView]: + """ + Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant + dimensions - size 0 or 1 - will be updated. + + If there are real stride differences (NHWC vs NCHW) then the input will be returned. + """ + + # should have already been realized + assert torch._inductor.ir.is_storage_and_layout(tensor) + + meta_strides = [ + s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_strides_inp + ] + + if all( + self.sizevars.statically_known_equals(s1, s2) + for s1, s2 in zip(meta_strides, tensor.get_stride()) + ): + return tensor # type: ignore[arg-type] + + def significant_strides_equal( + shape: Sequence[Union[Expr, int]], + meta_strides: Sequence[Union[Expr, int]], + tensor_strides: Sequence[Union[Expr, int]], + ) -> bool: + for dim, s1, s2 in zip(shape, meta_strides, tensor_strides): + if self.sizevars.statically_known_leq(dim, 1): # type: ignore[arg-type] + continue + + if not self.sizevars.statically_known_equals(s1, s2): + return False + + return True + + if not significant_strides_equal( + tensor.get_size(), meta_strides, tensor.get_stride() + ): + return tensor + + storage, old_layout = torch._inductor.ir.as_storage_and_layout(tensor) + new_stride = list(old_layout.stride) + for i, s in enumerate(tensor.get_size()): + if self.sizevars.statically_known_leq(s, 1): # type: ignore[arg-type] + new_stride[i] = meta_strides[i] + + new_layout = torch._inductor.ir.FixedLayout( + old_layout.device, + old_layout.dtype, + old_layout.size, + new_stride, + old_layout.offset, + ) + return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout)) + + def propagate_mutation( + self, + fx_node: torch.fx.Node, + old_args: Tuple[Any], + old_kwargs: Dict[str, Any], + new_args: Tuple[Any], + new_kwargs: Dict[str, Any], + ) -> None: + """Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs. + + Assumes we may have cloned old_args/old_kwargs into new_args/new_kwargs + and then called fx_node(*new_args, **new_kwargs). + + If fx_node mutates any of new_args/new_kwargs, and they are different from + old_args/old_kwargs, then we need to update the original tensor. + """ + assert isinstance(fx_node.target, torch._ops.OpOverload) + assert len(old_args) == len(new_args) + assert len(old_kwargs) == len(new_kwargs) + + def maybe_propagate( + schema_arg: torch._C.Argument, old_arg: ir.IRNode, new_arg: ir.IRNode + ) -> None: + if old_arg is new_arg: + return + if schema_arg.alias_info is not None and schema_arg.alias_info.is_write: + # The lowering for copy_ is smart enough to "replace" old_arg with + # new_arg in all future uses so a copy_ kernel never gets emitted. + self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {}) + + schema = fx_node.target._schema + for idx, (old_arg, new_arg) in enumerate(zip(old_args, new_args)): + schema_arg = schema.arguments[idx] + maybe_propagate(schema_arg, old_arg, new_arg) + + schema_kwargs = {arg.name: arg for arg in schema.arguments} + + for key in old_kwargs.keys(): + old_arg = old_kwargs[key] + new_arg = new_kwargs[key] + schema_arg = schema_kwargs[key] + maybe_propagate(schema_arg, old_arg, new_arg) + + def run_node(self, n: torch.fx.Node) -> object: + def debug(msg: str) -> None: + log.debug("lowering %s %s", LazyString(n.format_node), msg) + + buffer_watermark = len(self.buffers) + operation_watermark = len(self.operations) + + origins = {n} + is_call_function = n.op == "call_function" + if is_call_function: + args, kwargs = self.fetch_args_kwargs_from_env(n) + origins |= gather_origins(args, kwargs) + with ir.IRNode.current_origins(origins), self.set_current_node( # type: ignore[arg-type] + n + ), V.set_current_node( + n + ): + if ( + n.op == "call_function" + and n.target is not operator.getitem + and fallback_node_due_to_unsupported_type(n) + ): + debug("fallback_handler") + result = fallback_handler(n.target, add_to_fallback_set=False)( + *args, **kwargs # type: ignore[possibly-undefined] + ) + elif n.op == "call_function" and ( + layout_constraints := maybe_layout_constraints(n.target) # type: ignore[arg-type] + ): + debug("layout_constraints") + old_args = args # type: ignore[possibly-undefined] + old_kwargs = kwargs # type: ignore[possibly-undefined] + args, kwargs = layout_constraints(n, *args, **kwargs) # type: ignore[index] + result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type] + # layout_constraints are allowed to make new copies of the inputs. + # if they do, and if the target is mutable, then we need to + # write the new values back into the original inputs. + self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined] + elif is_magic_method(n.target): + # TODO: this is sus, it probably should be handled in the + # lowerings themselves similarly to sym_size/sym-stride + # https://github.com/pytorch/pytorch/issues/127789 + debug("is_magic_method") + if isinstance( + n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) + ): + result = n.meta["val"].node.expr + else: + result = super().run_node(n) + else: + debug("") + result = super().run_node(n) + + # require the same stride order for dense outputs, + # 1. user-land view() will not throw because inductor + # output different strides than eager + # long term the solution is to make view() always succeed + # with infallible strides. + # 2: as_strided ops, we need make sure its input has same size/stride with + # eager model to align with eager behavior. + as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + torch.ops.aten.resize.default, + torch.ops.aten.resize_as.default, + ] + is_output = any(user.op == "output" for user in n.users) + is_input_for_as_strided = any( + user.target in as_strided_ops for user in n.users + ) + + if n.meta.get("inductor_realize_to_strides", False) and isinstance( + result, TensorBox + ): + result.realize() + strides = n.meta["val"].stride() + sym_strides = torch._inductor.utils.any_is_symbolic(*strides) + if ( + not hasattr(result, "get_stride") + or result.get_stride() != strides + and not sym_strides + ): + stride_order = ir.get_stride_order(strides) + result = ir.ExternKernel.require_stride_order(result, stride_order) + if ( + is_output + and isinstance(result, TensorBox) + and isinstance(result.data, ir.BaseView) + ): + # Realize so that outputs are correctly aliased + result.realize() + + if (is_output or is_input_for_as_strided) and isinstance( + n.meta["val"], torch.Tensor + ): + strides = n.meta["val"].stride() + if len(strides): + allow_padding = ( + config.pad_outputs or n.name not in self.user_visible_outputs + ) and not is_input_for_as_strided + dense = torch._prims_common.is_non_overlapping_and_dense( + n.meta["val"] + ) + unbacked_symbols_in_strides = ( + len(free_unbacked_symbols(strides)) > 0 + ) + if ( + not unbacked_symbols_in_strides + and dense + and len(result.get_size()) == 4 + and n in self.nodes_prefer_channels_last + and n.name not in self.user_visible_outputs + and not is_input_for_as_strided + ): + strides = ir.FlexibleLayout.stride_ordered_for_memory_format( + result.get_size(), torch.channels_last + ) + if not unbacked_symbols_in_strides and len(strides): + # To avoid converting possible view ops to a copy kernel, we use the previous + # require_exact_strides to handle views. But ultimately it's better to require + # the right strides at the tensor definition. + if n.meta["val"]._is_view() or isinstance( + result.data, ir.BaseView + ): + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order(strides), + allow_padding=allow_padding, + ) + else: + strides = [ + s.node.expr if isinstance(s, torch.SymInt) else s + for s in strides + ] + result = ir.ExternKernel.require_exact_strides( + result, strides, allow_padding=allow_padding + ) + + # Realize if (1) any user need inputs realized, or (2) there is + # already too many reads and rematerializing can be bad. + num_users = len(OrderedSet(n.users)) + if num_users > 1 and isinstance(result, TensorBox): + for user in n.users: + if user.target in needs_realized_inputs: + result.realize_hint() + # This inclusion is somewhat controversial (from + # discussion between Horace, Natalia, and Elias). + # Currently, it's not very clear why this is helpful. + # The general idea here is that even though a node may + # have FlexibleLayout, we still often *treat* it as if + # it was contiguous. This appears to sometimes result in + # suboptimal behavior. + # + # When we do a better job selecting layout, we should + # revisit this. + need_fixed_layout = [ + torch.ops.aten.convolution_backward.default, + torch.ops.aten.mm.default, + torch.ops.aten._int_mm.default, + ] + need_fixed_channels_last_layout = [] + if not self.layout_opt: + need_fixed_layout.append(torch.ops.aten.convolution.default) + if torch._C._has_mkldnn: + need_fixed_layout += [ + torch.ops.mkldnn._linear_pointwise.default, + torch.ops.mkldnn._linear_pointwise.binary, + torch.ops.aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qlinear_pointwise.default, + torch.ops.onednn.qlinear_pointwise.tensor, + torch.ops.onednn.qlinear_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.binary_tensor, + ] + need_fixed_channels_last_layout += [ + torch.ops.mkldnn._convolution_pointwise.default, + torch.ops.mkldnn._convolution_pointwise.binary, + torch.ops.mkldnn._convolution_pointwise_.binary, + torch.ops.mkldnn._convolution_transpose_pointwise.default, + torch.ops.onednn.qconv2d_pointwise.default, + torch.ops.onednn.qconv2d_pointwise.binary, + ] + if torch._C.has_mkl: + need_fixed_layout += [torch.ops.mkl._mkl_linear.default] + if user.target in need_fixed_layout: + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order(n.meta["val"].stride()), + allow_padding=True, + ) + if ( + user.target in need_fixed_channels_last_layout + and n is user.args[0] + ): + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order( + make_channels_last_strides_for(n.meta["val"].shape) + ), + ) + if user.op == "output": + if isinstance(result.data.data, (Pointwise, Reduction)): + result.realize() + + # TODO(jansel): introduce a store vs inline choice + result.mark_reuse(len(n.users)) + + # Realize if the IRNode already has accumulated lots of reads + if isinstance(result, TensorBox) and result.has_exceeded_max_reads(): + # Prevent excessive accumulation in a computed buffer, when + # there are multiple branches each with small number of memory + # reads, but they converge to a user. + result.realize_hint() + + # Realize if a Pointwise has too much stuff to be inlined. + # As this may cause RecursionError during Inductor's evaluation. + if isinstance(result, TensorBox) and isinstance(result.data, StorageBox): + curr = result.data.data + if isinstance(curr, Pointwise): + # Use inner fn as a rough proxy. Good enough. + if curr.has_large_inner_fn(): + result.realize() + + # This is not complete, but it doesn't have to be: origin_node + # tracking is best effort. The logic here critically relies on direct + # TensorBox -> StorageBox denoting a non-view; we don't bother trying + # to get views to work. Feel free to add any extra cases as needed. + # + # Note: we can't YOLO tree_map over this result, because if there are + # buffers or a view involved, we might not be able to validly assign + # the origin_node here. + if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox): + if isinstance(result.data.data, ir.Loops): + result.data.data.origin_node = n + elif isinstance(result.data.data, ir.Buffer): + result.data.data.origin_node = n + if isinstance(result.data.data, ir.ComputedBuffer) and isinstance( + result.data.data.data, ir.Loops + ): + result.data.data.data.origin_node = n + # Not really multi-output, can straightforwardly recurse in + elif ( + isinstance(result.data.data, ir.MultiOutput) + and not result.data.data.indices + ): + if isinstance(result.data.data.inputs[0], ir.Buffer): + result.data.data.inputs[0].origin_node = n + + self.register_users_of(result) + + new_unbacked_defs: OrderedSet[sympy.Symbol] = OrderedSet() + for buf in self.buffers[buffer_watermark:]: + new_unbacked_defs |= buf.get_unbacked_symbol_defs() + for op in self.operations[operation_watermark:]: + new_unbacked_defs |= op.get_unbacked_symbol_defs() + + def format_new_defs() -> str: + r = [] + for buf in self.buffers[buffer_watermark:]: + r.append( + f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n" + ) + for op in self.operations[operation_watermark:]: + r.append( + f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n" + ) + return "***\n".join(r) + + if n.op != "placeholder": + # Note [Backwards runtime asserts] + # Backwards poses an interesting problem for deferred runtime + # asserts. In the easy case, we may solely close over data + # dependent sized tensors, and there are no binding sites for + # unbacked SymInts. In this case, we can just drop all the + # runtime asserts on the floor: no non-placeholder bindings, no + # problem. + # + # However, it is *possible* for a fresh runtime assert to show up + # between forwards and backwards. Right now, the freezing process + # that happens when we lower forwards means that we will freeze + # runtime asserts, and then the moment the backwards lowering + # process attempts to add a new deferred runtime assert, we will + # fail. Let's say you remove that assert. Now when we get here, + # we need to make sure we actually emit these asserts (because we + # can't emit them in forwards, we already compiled it). So we + # have to do something here. But we don't want to reemit ALL + # deferred runtime asserts, we only want to emit the NEW ones. + # Therefore needing some sort of stratification in the ShapeEnv. + # This is all doable, it just hasn't been done yet. + shape_env = V.graph.sizevars.shape_env + + def make_assert(expr: Expr, msg: str) -> None: + assert_op = ir.AssertScalar(expr, msg) + self.register_buffer(assert_op, set_name=True) + self.register_operation(assert_op) + + for i0 in new_unbacked_defs: + ras = self.ras_by_symbol.pop(i0, []) + # NB: size-like not needed, we won't retrace + vr = shape_env.var_to_range[i0] + if not shape_env._default_unspecified_value_range().issubset(vr): + + def is_convertible(s: Expr) -> bool: + if s in (int_oo, -int_oo): + return False + try: + int(s) + return True + except TypeError: + return False + + if is_convertible(vr.lower): + make_assert(i0 >= vr.lower, f"{i0} >= {vr.lower}") + if is_convertible(vr.upper): + make_assert(i0 <= vr.upper, f"{i0} <= {vr.upper}") + + for ra in ras: + fvs = free_unbacked_symbols(ra.expr) + missing = fvs - self.bound_unbacked_symbols + if missing: + i1 = min(missing, key=str) + self.ras_by_symbol.setdefault(i1, []).append(ra) + else: + make_assert(ra.expr, f"{ra.expr}") + + self.bound_unbacked_symbols |= new_unbacked_defs + + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {}) + ) + # When we do lowering, it is possible we reallocate unbacked SymInts. + # So we need to line up the unbacked SymInts when performing the test + # here + # + # In principle, we could permit lowering to introduce MORE unbacked + # SymInts: as long as all the old unbacked ones are accounted for, + # it's fine for inductor to introduce extra calls to item()/unbacked() + # whatever. This actually happens in practice when an unbacked SymInt + # gets memoized away; naively, when Inductor reprocesses a kernel, it + # doesn't know that the memo still applies, and ends up allocating a + # new symbol. However, this is generally a bad thing: we may still + # end up needing to test equalities on the symbols, and a fresh + # symbol is likely to hit lots of GuardOnDataDependent errors that + # we already know facts for. + renamed_unbacked_bindings = OrderedSet( + V.fake_mode.shape_env.unbacked_renamings.get(s, s) + for s in unbacked_bindings.keys() + ) + assert new_unbacked_defs >= renamed_unbacked_bindings, ( + f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n" + f"fx node is: {n.format_node()}\n" + f"new operations are:\n\n{format_new_defs()}" + ) + + return result + + def validate_can_generate_cpp_wrapper(self) -> None: + if config.disable_cpp_codegen: + raise CppWrapperCodeGenError("C++ codegen is disabled") + + if sys.platform not in ["linux", "darwin", "win32"]: + raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}") + + for value in self.graph_inputs.values(): + dtype = None + if isinstance(value, TensorBox): + dtype = value.get_dtype() + elif isinstance( + value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) + ): + dtype = may_get_constant_buffer_dtype(value) + + if not supported_dtype_of_cpp_wrapper(dtype, self.cuda): + raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}") + + def init_wrapper_code(self) -> None: + self.cuda = "cuda" in self.device_types + if self.cpp_wrapper: + self.validate_can_generate_cpp_wrapper() + + device_types = self.device_types.copy() + device_types.discard("cpu") + device_types.discard("meta") + # TODO(Eikan): Only support mixing cpu and other device now. + assert len(device_types) <= 1, "Does not support mixing {}".format( + "+".join(device_types) + ) + only_cpu = len(device_types) == 0 + device_type = "cpu" if only_cpu else device_types.pop() + + self.device_ops = get_device_op_overrides(device_type) + wrapper_code_gen_cls = get_wrapper_codegen_for_device( + device_type, self.cpp_wrapper + ) + assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported" + self.wrapper_code = wrapper_code_gen_cls() + + if self.const_module: + # If we have const module, we could reuse the kernels + # This could avoid duplication and save time on doing recompilation (if Triton.) + self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter + self.wrapper_code.src_to_kernel = ( + self.const_module.wrapper_code.src_to_kernel + ) + + def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]: + """ + For CPU, the cpp wrapper codegen is done in one pass. + For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python + wrapper code and run it to generate autotuned kernel binaries in the first pass; and then + generate cpp wrapper code and compile it to a dynamic library in the second pass. + """ + if "cuda" in self.device_types: + # first pass + self.cpp_wrapper = False + # Although triton.store_cubin was OrderedSet in compile_fx, the backward pass didn't pick + # that up. In theory it should work by only setting triton.store_cubin to True here, + # but that will cause a problem when use_runtime_constant_folding is OrderedSet. + with config.patch({"triton.store_cubin": True}): + compiled = self.compile_to_module().call + + if not config.triton.autotune_at_compile_time: + + def materialize( + x: Union[torch.SymInt, torch.SymFloat, torch.Tensor] + ) -> Union[int, float, torch.Tensor]: + if x is None: + return None + elif isinstance(x, (torch.SymInt, torch.SymFloat)): + # Need concrete value to run dynamic shapes and tune the result + return x.node.hint + elif isinstance(x, FakeTensor): + return defake(x) + else: + assert isinstance( + x, torch.Tensor + ), "Unknown type when creating real inputs" + str(type(x)) + return x + + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context is not None and not isinstance( + V.real_inputs, NullHandler + ): + if tracing_context.output_strides: + tracing_context.output_strides.clear() + + params_flat = [ + param + for param in tracing_context.params_flat # type: ignore[union-attr] + if param is not None + ] + real_inputs = [ + materialize(x) + for x in itertools.chain(params_flat, V.real_inputs) + ] + else: + # In the backward pass, V.real_inputs is not OrderedSet. + # Generating random inputs based on self.example_inputs sometimes can be problematic, + # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process. + real_inputs = [ + materialize(x) + for x in ( + self.example_inputs + if isinstance(V.real_inputs, NullHandler) + else V.real_inputs + ) + ] + + if self.mutated_inputs: + from .compile_fx import clone_preserve_strides + + mutated_input_idxs = [ + idx + for idx, name in enumerate(self.graph_inputs) + if name in self.mutated_inputs + and isinstance(real_inputs[idx], torch.Tensor) + ] + for idx in mutated_input_idxs: + # clone mutated Tensor inputs to avoid mutating them in + # the first pass of the CPP wrapper-based compilation, as + # this will lead to a side effect on the example inputs: + # e.g. if torch.compile(f)(x) if called on input-mutating + # f, the inputs x will be mutated twice in the process: + # once here, and again when running the compiled model; + # this will also lead to a numerically incorrect output + mutated_inp = real_inputs[idx] + assert isinstance(mutated_inp, torch.Tensor) + real_inputs[idx] = clone_preserve_strides(mutated_inp) + del mutated_inp + + with torch.utils._python_dispatch._disable_current_modes(): + compiled(real_inputs) + del real_inputs + + # second pass + self.cpp_wrapper = True + self.removed_buffers.clear() + self.removed_operations.clear() + self.inplaced_to_remove.clear() + V.graph.sizevars.precomputed_replacements.clear() + V.graph.sizevars.inv_precomputed_replacements.clear() + with config.patch({"triton.autotune_at_compile_time": False}): + return self.codegen() + else: + # cpu + return self.codegen() + + def codegen(self) -> Tuple[str, List[Tuple[int, Node]]]: + from .scheduler import Scheduler + + self.init_wrapper_code() + + self.scheduler = Scheduler(self.operations) + V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes) + + self.wrapper_code.push_codegened_graph(self) + self.scheduler.codegen() + + log.debug( + "Finished codegen for all nodes. The list of kernel names available: %s", + V.graph.all_codegen_kernel_names, + ) + + result = self.wrapper_code.generate(self.is_inference) + self.wrapper_code.pop_codegened_graph() + return result + + def codegen_subgraph(self, parent_graph: "GraphLowering") -> None: + """ + This is a more compact version of the `codegen()` above + where we codegen this graph as a subgraph of some parent + graph. The parent graph is passed as an argument: the + intention is to inline codegening of the subgraph in + the parent graph's wrapper code (including the generated + kerenls). The wrapper code is not finalized (via `.generate()` + call), as this will be done in the parent graph's `codegen()`. + """ + from .scheduler import Scheduler + + self.wrapper_code = parent_graph.wrapper_code + self.device_ops = parent_graph.device_ops + self.cpp_wrapper = parent_graph.cpp_wrapper + + self.scheduler = Scheduler(self.operations) + self.scheduler.codegen() + + def count_bytes( + self, + ) -> Tuple[ + int, List[Tuple[BaseSchedulerNode, int]], List[Tuple[BaseSchedulerNode, float]] + ]: + total_bytes = 0 + node_counts = [] + node_runtimes = [] + for node in self.scheduler.nodes: + num_bytes = node.get_read_write_buffers_sizes() + total_bytes += num_bytes + node_counts.append((node, num_bytes // 4)) + node_runtimes.append((node, node.get_estimated_runtime())) + + return total_bytes, node_counts, node_runtimes + + @staticmethod + def save_output_code(code: str) -> None: + # No-op to be patched for unit tests + pass + + def compile_to_module(self) -> ModuleType: + with dynamo_timed( + "GraphLowering.compile_to_module", phase_name="code_gen", fwd_only=False + ): + return self._compile_to_module() + + def _compile_to_module(self) -> ModuleType: + from .codecache import PyCodeCache + + code, linemap = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + + GraphLowering.save_output_code(code) + output_code_log.debug("Output code: \n%s", code) + try: + linemap = [(line_no, node.stack_trace) for line_no, node in linemap] # type: ignore[misc] + key, path = PyCodeCache.write(code) + except Exception: + trace_structured( + "inductor_output_code", + # Just omit the filename, I still want the code though! + payload_fn=lambda: code, + ) + raise + else: + trace_structured( + "inductor_output_code", + lambda: {"filename": path}, + payload_fn=lambda: code, + ) + + mod = PyCodeCache.load_by_key_path( + key, + path, + linemap=linemap, # type: ignore[arg-type] + attrs={**self.constants, **self.torchbind_constants}, + ) + self.cache_key = key + self.cache_path = path + self.cache_linemap = linemap # type: ignore[assignment] + + # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029 + # TODO. Revisit this once the logging API is more mature + assert mod.__file__ is not None + + log_module_code(mod.__file__) + log.debug("Output code written to: %s", mod.__file__) + output_code_log.info("Output code written to: %s", mod.__file__) + if config.benchmark_kernel: + print(f"Compiled module path: {mod.__file__}", file=sys.stderr) + V.debug.output_code(mod.__file__) + V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") + return mod + + def compile_to_fn(self) -> Any: + if self.aot_mode: + from .codecache import AotCodeCompiler + + assert self.cpp_wrapper, "AOT mode only supports C++ wrapper" + code, linemap = self.codegen_with_cpp_wrapper() + output_code_log.debug("Output code: \n%s", code) + + serialized_extern_kernel_nodes = None + if self.extern_kernel_nodes: + serialized_extern_kernel_nodes = self.extern_node_serializer( + self.extern_kernel_nodes + ) + output_code_log.debug( + "Serialized Extern Kernel Nodes: \n%s", + serialized_extern_kernel_nodes, + ) + + # Directly return the file path with the compiled code + return AotCodeCompiler.compile( + self, code, serialized_extern_kernel_nodes, cuda=self.cuda + ) + else: + return self.compile_to_module().call + + def get_output_names(self) -> List[str]: + return [ + node.get_name() + for node in self.graph_outputs + if not isinstance(node, ir.NoneAsConstantBuffer) + and not isinstance(node, ir.ShapeAsConstantBuffer) + ] + + def is_unspec_arg(self, name: str) -> bool: + # dynamo wraps unspec variable as 0d CPU tensor, + # need to convert to scalar during codegen (triton only) + return ( + name in self.graph_inputs.keys() + and self.graph_inputs[name].get_numel() == 1 + and self.graph_inputs[name].get_device().type == "cpu" + ) or name in self.zero_dim_cpu_tensor_list diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py b/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8aeecd283185608260f0ffd16a19270f9e96b1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Callable, List, TYPE_CHECKING + + +if TYPE_CHECKING: + import torch + +# Executed in the order they're registered +INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = [] + + +@contextlib.contextmanager +def intermediate_hook(fn): + INTERMEDIATE_HOOKS.append(fn) + try: + yield + finally: + INTERMEDIATE_HOOKS.pop() + + +def run_intermediate_hooks(name, val): + global INTERMEDIATE_HOOKS + hooks = INTERMEDIATE_HOOKS + INTERMEDIATE_HOOKS = [] + try: + for hook in hooks: + hook(name, val) + finally: + INTERMEDIATE_HOOKS = hooks diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py b/.venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..f4384d51b7d4adea0d55804090d684f9b7573b05 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py @@ -0,0 +1,373 @@ +# mypy: allow-untyped-defs +"""This file implements the IndexPropagation ops handler, which wraps an +underlying handler to add a limited form of constant propagation, as well as +propagation of sympy expressions downstream of ops.index_expr calls. + +For example, say we have the IR: + + tmp0 = ops.index_expr(x, torch.int32) + tmp1 = ops.constant(2, torch.int32) + tmp2 = ops.mul(tmp0, tmp1) + tmp3 = ops.indirect_indexing(tmp2, x_size) + tmp4 = ops.load("buf0", tmp3) + +The underlying handler would just see: + + ops.load("buf0", x * 2) + +This is limited by the set of operators handled in the sympy expression +printers. So simple operations like minimum and maximum cannot be translated to +SymPy expressions yet, despite sympy.Min and sympy.Max existing. + +""" +import itertools +from dataclasses import dataclass +from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union +from typing_extensions import TypeAlias + +import sympy + +import torch +from torch._prims_common import dtype_to_type, is_integer_dtype +from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from .sizevars import evaluate_expr +from .utils import generate_assert +from .virtualized import V + + +_ExprType = Union[sympy.Expr, float, int, bool] + + +def _is_constant(val: _ExprType): + if isinstance(val, sympy.Basic): + return val.is_number + return isinstance(val, (int, float, bool)) + + +def upper_bound(val: _ExprType): + return bound_sympy(val).upper if isinstance(val, sympy.Expr) else val + + +@dataclass +class TypedExpr: + """A SymPy expression with associated type""" + + expr: _ExprType + dtype: torch.dtype + + def is_constant(self): + return _is_constant(self.expr) + + def __post_init__(self): + if _is_constant(self.expr): + self.expr = dtype_to_type(self.dtype)(self.expr) + + +class SymPyOps: + """An ops handler where all IR values are SymPy expressions + + When a value cannot be represented as a SymPy expression, the method is + either not defined, or returns NotImplemented + + """ + + @staticmethod + def identity(value: Any) -> Any: + return value + + @staticmethod + def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr: + return TypedExpr(value, dtype) + + @staticmethod + def index_expr(value: Union[sympy.Expr, int], dtype: torch.dtype) -> TypedExpr: + return TypedExpr(value, dtype) + + @staticmethod + def to_dtype( + value: TypedExpr, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = False, + ) -> TypedExpr: + return TypedExpr(value.expr, dtype) + + @staticmethod + def abs(x: TypedExpr) -> TypedExpr: + return TypedExpr(abs(x.expr), x.dtype) # type: ignore[arg-type] + + @staticmethod + def square(x: TypedExpr) -> TypedExpr: + return TypedExpr(x.expr * x.expr, x.dtype) + + @staticmethod + def add(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr + y.expr, result_type) + + @staticmethod + def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr - y.expr, result_type) + + @staticmethod + def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr * y.expr, result_type) + + @staticmethod + def neg(x: TypedExpr) -> TypedExpr: + return TypedExpr(-x.expr, x.dtype) + + @staticmethod + def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + return TypedExpr(FloorDiv(x.expr, y.expr), result_type) + + @staticmethod + def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr) + return TypedExpr(result_expr, result_type) + + @staticmethod + def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + x_expr = sympy.sympify(x.expr) + y_expr = sympy.sympify(y.expr) + # In these cases, remainder in Python == remainder in C++, so this transformation + # is sound + if ( + x_expr.is_nonnegative is not None + and x_expr.is_nonnegative == y_expr.is_positive + ): + result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr) + return TypedExpr(result_expr, result_type) + return NotImplemented + + @staticmethod + def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Min(x.expr, y.expr), result_type) + + @staticmethod + def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Max(x.expr, y.expr), result_type) + + +@dataclass +class IndexPropVar: + value: Any # Either an IR value, or TypedExpr if is_symbolic is true + is_symbolic: bool = False + + @staticmethod + def new_symbolic(expr: TypedExpr) -> "IndexPropVar": + return IndexPropVar(expr, is_symbolic=True) + + def __post_init__(self): + assert not self.is_symbolic or isinstance( + self.value, TypedExpr + ), "Symbolic IndexPropVar must contain a TypedExpr" + + +IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]] + + +class IndexPropagation: + """Ops wrapper that tries to propagate constant and index_expr values through the computation. + + This aims to maximize the compile time simplification possible, and convert + indirect indexing from arange into normal static indexing. + + """ + + def __init__( + self, + inner: Any, + iter_ranges: Dict[sympy.Symbol, sympy.Expr], + indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr], + ) -> None: + self._inner = inner + self.shape_env = V.graph.sizevars.shape_env + + var_to_range = { + k: ValueRanges(0, upper_bound(v) - 1) for k, v in iter_ranges.items() + } + self.var_to_range = tuple( + itertools.chain(self.shape_env.var_to_range.items(), var_to_range.items()) + ) + # NOTE: this is intentionally kept as a reference so the caller can + # update it in-place + self.indirect_var_ranges = indirect_var_ranges + + axioms = [] + for x, s in iter_ranges.items(): + axioms.append(0 <= x) + axioms.append(x < s) + self.axioms = tuple(axioms) + self.shape_env.get_axioms() + + def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any: + # Construct a new constant/index_expr from the SymPy expression + if _is_constant(expr): + val = dtype_to_type(dtype)(expr) + return self._inner.constant(val, dtype) + return self._inner.index_expr(expr, dtype) + + def unwrap(self, a: Union[Any, IndexPropVar]) -> Any: + if isinstance(a, (list, tuple)): + return tuple(self.unwrap(v) for v in a) + + if not isinstance(a, IndexPropVar): + return a + + # Prefer the sympy representation if possible + if a.is_symbolic: + return self.materialize_expr(a.value.expr, a.value.dtype) + + return a.value + + def wrap(self, a) -> IndexPropResult: + if isinstance(a, (list, tuple)): + return tuple(self.wrap(v) for v in a) + return IndexPropVar(a) + + @overload + def fallback( + self, + name: Literal["indirect_indexing"], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> IndexPropVar: + ... + + @overload + def fallback( + self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> IndexPropResult: + ... + + def fallback( + self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> IndexPropResult: + # Fallback to the wrapped handler + new_args = [self.unwrap(a) for a in args] + new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()} + return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs)) + + def propagate_sympy( + self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> IndexPropResult: + # Build a new SymPy expression from this ops call + def unwrap(a: Union[Any, IndexPropVar]) -> Any: + if not isinstance(a, IndexPropVar): + return a + return a.value + + new_args = [unwrap(a) for a in args] + new_kwargs = {k: unwrap(v) for k, v in kwargs.items()} + new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs) + is_valid_expr = new_expr is not NotImplemented and ( + # Inductor doesn't expect floating point in sympy expressions, but + # allow floating point constants to be propagated + new_expr.is_constant() + or new_expr.expr.is_integer + ) + if not is_valid_expr: + return self.fallback(name, args, kwargs) + return IndexPropVar.new_symbolic(new_expr) + + def __getattr__(self, name: str) -> Callable[..., IndexPropResult]: + def inner(*args: Any, **kwargs: Any) -> IndexPropResult: + if not hasattr(SymPyOps, name): + return self.fallback(name, args, kwargs) + + var_arguments = [ + a + for a in itertools.chain(args, kwargs.values()) + if isinstance(a, IndexPropVar) + ] + if not all(v.is_symbolic for v in var_arguments): + return self.fallback(name, args, kwargs) + + return self.propagate_sympy(name, args, kwargs) + + return inner + + def statically_true(self, e): + """ + Given some iter_ranges, return a function that given an expression, returns whether + it is true or false using value ranges, guard knowledge and runtime_asserts. + + FIXME I think this may not be entirely right, as we may not be able to use all runtime_asserts + If this is an issue, just use guards in `self.axioms`. + + The proper way of handling this would be to have a global shape_env that adds + runtime_asserts as they happen in the code. Then, it shuld be used in SimplifyIndexing + to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bounds also + for indirect_indexing + """ + var_to_range = ( + *self.var_to_range, + *( + (k, ValueRanges(0, upper_bound(v) - 1)) + for k, v in self.indirect_var_ranges.items() + ), + ) + return evaluate_expr(self.shape_env, e, self.axioms, var_to_range) + + def indirect_indexing( + self, + index: Union[Any, IndexPropVar], + size: Any, + check: bool = True, + wrap_neg=True, + ) -> Any: + if isinstance(index, IndexPropVar) and index.is_symbolic: + # If we find something we can convert into a direct indexing we do so + # We still need to (perhaps) wrap the expression and add bound checks + # We want to do this "constant folding", as we don't allow to fuse + # kernels into indirect indexing + + expr = sympy.sympify(index.value.expr) + + # TODO Perhaps move this logic to the simplify indexing pass + def wrap_expr(expr): + # Positive, negative, mixed + if self.statically_true(0 <= expr): + return expr + elif self.statically_true(expr < 0): + return expr + size + else: + return Where(expr < 0, expr + size, expr) + + # Sometimes it's easier to prove 0 <= expr than the weaker -size <= expr + can_prove_lower = self.statically_true(0 <= expr) or self.statically_true( + -size <= expr + ) + can_prove_upper = self.statically_true(expr < size) + if wrap_neg: + expr = wrap_expr(expr) + if generate_assert(check): + self.fallback( + "check_bounds", + (expr, size), + dict(lower=not can_prove_lower, upper=not can_prove_upper), + ) + return expr + + indirect_var = self.fallback( + "indirect_indexing", (index, size, check, wrap_neg), {} + ).value + return indirect_var diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py b/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..82a23d3e60cf10eac46fc5f30b2a8d811ebe7360 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py @@ -0,0 +1,179 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import Optional, Sequence + +import torch +from torch import _prims, Tensor + + +log = logging.getLogger(__name__) + + +def make_prim( + schema: str, + impl_aten, + return_type=_prims.RETURN_TYPE.NEW, + doc: str = "", + tags: Optional[Sequence[torch.Tag]] = None, +): + if isinstance(return_type, tuple): + + def meta(*args, **kwargs): + return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs)) + + else: + + def meta(*args, **kwargs): + return _prims.TensorMeta(impl_aten(*args, **kwargs)) + + return _prims._make_prim( + schema=schema, + return_type=return_type, + meta=meta, + impl_aten=impl_aten, + doc=doc, + tags=tags, + ) + + +def eager_force_stride(input_tensor: Tensor, stride) -> Tensor: + if input_tensor.stride() == stride: + return input_tensor + new_tensor = input_tensor.clone().as_strided( + input_tensor.shape, + stride, + ) + new_tensor.copy_(input_tensor) + return new_tensor + + +# Custom prims used for handling randomness +seed = make_prim( + "inductor_seed(Device device) -> Tensor", + lambda device: torch.randint(2**63 - 1, [], device=device), + doc="create a fresh seed (one per call) for use with inductor_rand", + tags=(torch.Tag.nondeterministic_seeded,), +) +seeds = make_prim( + "inductor_seeds(int count, Device device) -> Tensor", + lambda count, device: torch.randint(2**63 - 1, [count], device=device), + doc="Horizontal fusion of many inductor_seed() calls", + tags=(torch.Tag.nondeterministic_seeded,), +) +lookup_seed = make_prim( + # if inductor_lookup_seed changes, update partitioners.py + "inductor_lookup_seed(Tensor seeds, int index) -> Tensor", + lambda seeds, index: seeds[index], + doc="Extract a single seed from the result of inductor_seeds()", +) +random = make_prim( + "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor", + lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device), + doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused", +) +randint = make_prim( + "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor", + lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device), + doc="torch.randint() using backend-specific RNG that can be fused", +) +force_stride_order = make_prim( + "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor", + eager_force_stride, + doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise", +) +_unsafe_index_put_ = make_prim( + "_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", + lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_( + self, indices, values, accumulate + ), + doc="Unsafe index_put_ (doesn't issue device asserts)", +) +fma = make_prim( + "fma(Tensor a, Tensor b, Tensor c) -> Tensor", + lambda a, b, c: (a * b) + c, + doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication", +) + + +def _low_memory_max_pool2d_with_offsets_aten( + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, +): + vals, indices = torch.ops.aten.max_pool2d_with_indices( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + + input_width = self.shape[-1] + kernel_width = kernel_size[1] + + bh_shape = [1] * self.ndim + bh_shape[-2] = -1 + bh = torch.arange(indices.shape[-2], dtype=torch.int64, device=self.device).view( + bh_shape + ) + + bw_shape = [1] * self.ndim + bw_shape[-1] = -1 + bw = torch.arange(indices.shape[-1], dtype=torch.int64, device=self.device).view( + bw_shape + ) + + hbase = bh * stride[0] - padding[0] + wbase = bw * stride[1] - padding[1] + + ih = indices // input_width + iw = indices - (ih * input_width) + + h_inc = ih - hbase + w_inc = iw - wbase + + offsets = h_inc * kernel_width + w_inc + + return vals, offsets.to(torch.int8) + + +def _low_memory_max_pool2d_offsets_to_indices_aten( + offsets, kernel_width, input_width, stride, padding +): + offsets = offsets.to(torch.int64) + h_inc = offsets // kernel_width + w_inc = offsets - (h_inc * kernel_width) + + bh_shape = [1] * offsets.ndim + bh_shape[-2] = -1 + bh = torch.arange(offsets.shape[-2], dtype=torch.int64, device=offsets.device).view( + bh_shape + ) + + bw_shape = [1] * offsets.ndim + bw_shape[-1] = -1 + bw = torch.arange(offsets.shape[-1], dtype=torch.int64, device=offsets.device).view( + bw_shape + ) + + hbase = bh * stride[0] - padding[0] + wbase = bw * stride[1] - padding[1] + + ih = hbase + h_inc + iw = wbase + w_inc + return ih * input_width + iw + + +_low_memory_max_pool2d_with_offsets = make_prim( + "_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950 + _low_memory_max_pool2d_with_offsets_aten, + return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW), + doc="Instead of returning indices, returns indices offsets.", +) + +_low_memory_max_pool2d_offsets_to_indices = make_prim( + "_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor", # noqa: B950 + _low_memory_max_pool2d_offsets_to_indices_aten, + doc="Convert small int offsets to regular indices.", +) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py b/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..0929996fe174523fe0b3aad44e2d20130ea7ba86 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py @@ -0,0 +1,6953 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import itertools +import logging +import textwrap +import traceback +from contextlib import nullcontext +from functools import partial +from typing import ( + Any, + Callable, + ClassVar, + ContextManager, + Dict, + Iterable, + List, + Literal, + Optional, + overload, + Sequence, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import TypeAlias +from unittest.mock import patch + +import sympy +from sympy import Expr, Integer, Symbol + +import torch._export.serde.schema as export_schema +import torch._logging +import torch.fx +import torch.utils._pytree as pytree +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import identity +from torch._export.serde.serialize import GraphModuleSerializer +from torch._higher_order_ops.auto_functionalize import can_auto_functionalize +from torch._inductor import metrics +from torch._prims_common import ( + compute_required_storage_length, + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + StrideType, +) +from torch._subclasses.fake_tensor import get_schema_info +from torch.fx.experimental.symbolic_shapes import ( + CallMethodKey, + compute_unbacked_bindings, + DivideByKey, + free_unbacked_symbols, + rebind_unbacked, + resolve_unbacked_bindings, + SymTypes, +) +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import SymT + +from . import config, dependencies +from .codegen.common import BackendFeature, index_prevent_reordering +from .dependencies import ( + extract_free_unbacked_symbols, + extract_input_node_reduction_ranges, + extract_read_writes, + var_builder, +) +from .loop_body import LoopBody +from .ops_handler import OpCounterCSE, OpCountResult +from .runtime.benchmarking import benchmarker +from .runtime.hints import ReductionHint +from .utils import ( + argsort, + cache_on_self, + ceildiv, + convert_shape_to_inductor, + convert_shape_to_symint, + developer_warning, + get_kernel_metadata, + is_dynamic, + is_gpu, + sympy_dot, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_product, + sympy_subs, +) +from .virtualized import ops, OpsValue, V + + +if TYPE_CHECKING: + from .graph import GraphLowering + +_T = TypeVar("_T") +_U = TypeVar("_U") +_V = TypeVar("_V") + +_IntLike: TypeAlias = Union[int, Expr] + +log = logging.getLogger(__name__) +indent = functools.partial(textwrap.indent, prefix=" ") +aten = torch.ops.aten + +""" [Note: Inductor IR] + +Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each +lowering is registered to a particular aten operator, and expects inputs that +correspond to the aten schema. However, in place of torch Tensor inputs, lowerings +expect Inductor TensorBox inputs. + +TensorBox IR represents torch tensors. Tensors are sometimes single objects owning +storage, and sometimes views of another Tensor's storage. Mutating tensor operations +(such as add_()) affect the underlying storage and any associated views. Other operations +(such as .t_()) update metadata about the current view but don't modify the underlying storage. + +To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer. + +TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor +output from an operation. But just as torch.Tensors take different forms, TensorBox IR can +reference View IR or directly reference StorageBox IRs. + +Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops) +may take an existing TensorBox and point it to a new underlying View IR. + +Tensors that directly own storage are represented as a chain of: +TensorBox -> StorageBox -> Buffer +where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout. + +If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer +(leaving the old buffer unmodified and functionalizing the operation). + +Tensors backed by views add one more indirection to the IR. +TensorBox -> View -> StorageBox -> Buffer +In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox. + +Computation is represented by Operation nodes, with each operation producing 1 +or more output Buffers. In the case of mutations, these will be new Buffers that have the +mutated buffer listed in its get_mutation_names(). + +It is also possible to have an InputBuffer for which there is no corresponding Operation, +e.g. it may be a graph input or compile time constant. + +""" + + +_NodeOrNodes: TypeAlias = Union[ + int, + "TensorBox", + Dict[str, "TensorBox"], + "Symbol", + "IRNode", + Sequence[ + Optional[Union[int, Dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]] + ], +] + + +def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None: + def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None: + # Could expand this to check deeper properties + # (e.g. TensorBox points to View or StorageBox) + if nodes is None: + pass + elif isinstance(nodes, (list, tuple)): + for node in nodes: + _check_tensorbox(node) + elif isinstance(nodes, dict): + for node in nodes.values(): + _check_tensorbox(node) + else: + assert isinstance( + nodes, + ( + torch._inductor.ir.ExpandView, + DynamicScalar, + AssertScalar, + TensorBox, + sympy.logic.boolalg.Boolean, + Expr, + int, + EffectfulKernel, + ), + ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]" + + # Be picky about the accepted data structure (don't use pytree here) + _check_tensorbox(node_or_nodes) + + +def ops_wrapper(name: str) -> Callable[..., OpsValue]: + assert isinstance(name, str) + + def fn(*args: object, **kwargs: object) -> OpsValue: + return getattr(ops, name)(*args, **kwargs) + + return fn + + +def inverse_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]: + inv_order = dict(zip(order, range(len(order)))) + + def reindex(index: Sequence[_T]) -> Sequence[_T]: + assert len(index) == len(inv_order) + return [index[inv_order[i]] for i in range(len(index))] + + return reindex + + +def same_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]: + def reindex(index: Sequence[_T]) -> Sequence[_T]: + assert len(index) == len(order) + return [index[order[i]] for i in range(len(index))] + + return reindex + + +def fuse_reindexing( + reindex1: Callable[[Sequence[_U]], Sequence[_V]], + reindex2: Callable[[Sequence[_T]], Sequence[_U]], +) -> Callable[[Sequence[_T]], Sequence[_V]]: + def reindex(index: Sequence[_T]) -> Sequence[_V]: + return reindex1(reindex2(index)) + + return reindex + + +NHWC_STRIDE_ORDER = [3, 0, 2, 1] +NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1] + + +def stride_order2fill_order( + order: Sequence[Union[int, Integer]] +) -> Sequence[Union[int, Integer]]: + """ + Convert stride order to fill order + For channel last format, + + stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] + """ + lookup = {pos: idx for idx, pos in enumerate(order)} + fill_order = [lookup[i] for i in range(len(order))] + return fill_order + + +def get_stride_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[int]: + """ + Convert strides to stride order + """ + sorted_idx: List[int] = argsort(seq) + out = [0 for _ in range(len(seq))] + for i, elem in enumerate(sorted_idx): + out[elem] = i + return out + + +@overload +def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: + ... + + +@overload +def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: + ... + + +def ir_node_to_tensor( + x: Optional[IRNode], guard_shape: bool = True +) -> Optional[torch.Tensor]: + if x is None: + return None + + shape_fn: Callable[[Union[int, Expr]], Union[int, Expr]] + if not guard_shape: + shape_fn = V.graph.sizevars.size_hint + else: + shape_fn = identity + size = [shape_fn(s) for s in x.get_size()] + stride: StrideType + if is_storage_and_layout(x): + stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc, union-attr] + else: + stride = FlexibleLayout.contiguous_strides(size) # type: ignore[assignment] + dtype = x.get_dtype() + device = x.get_device() + size = convert_shape_to_symint(size) + stride = convert_shape_to_symint(stride) + with V.graph.sizevars.shape_env.suppress_guards(): + t = torch.empty_strided( + size=size, stride=stride, dtype=dtype, device=device + ).zero_() + return t + + +def may_convert_to_optional( + value: Optional[Sequence[_T]], +) -> Optional[Sequence[Optional[_T]]]: + if isinstance(value, list) and not value: + # [None] makes sure the cpp wrapper codegen will generate something like + # {std::nullopt} instead of {} + return [None] + return value + + +def get_device_type(x: object) -> Optional[str]: + if get_device := getattr(x, "get_device", None): + return get_device_type(get_device()) + if isinstance(x, torch.device): + return x.type + return None + + +def is_triton(x: object) -> bool: + dtype = get_device_type(x) + return bool(dtype and is_gpu(dtype)) + + +def is_cpu(x: object) -> bool: + return get_device_type(x) == "cpu" + + +class IRNode: + _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() + + @staticmethod + @contextlib.contextmanager + def current_origins(origins: OrderedSet[torch.fx.Node]): + old = IRNode._current_origins + IRNode._current_origins = old | origins + try: + yield + finally: + IRNode._current_origins = old + + def __post_init__(self): + self.origins = OrderedSet(self._current_origins) + self.traceback = traceback.format_stack() if config.debug_ir_traceback else None + + def get_read_names(self) -> OrderedSet[str]: + raise NotImplementedError(f"NYI on {type(self)}") + + def get_traceback(self): + return self.traceback + + def get_defining_op(self): + raise NotImplementedError + + def common_repr(self, shorten=True): + origins = f"origins={getattr(self, 'origins', '')}" + if shorten and len(origins) > 64: + # this can get *very* long + origins = f"{origins[:61]}..." + return [origins] + + def str_helper(self, lines, shorten=True, multiline=True): + lines = lines + self.common_repr(shorten) + lines = list(map(str, lines)) + if multiline: + new_lines = indent(",\n".join(lines)) + return f"{type(self).__name__}(\n{new_lines}\n)" + else: + return f"{type(self).__name__}({lines})" + + def get_dtype(self): + return self.dtype + + def get_layout(self): + raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!") + + def get_size(self): + raise NotImplementedError(f"get_size() is not implemented by {type(self)}!") + + @property + def shape(self): + return self.get_size() + + def get_numel(self): + return sympy_product(self.get_size()) + + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] + + def realize(self): + """ + If the IRNode refers to data which has not been materialized (e.g., + it is a Pointwise/Reduction that could potentially have more + compute fused into it), realize the IRNode into physical memory, + ending the possibility of fusing into it, but allowing, e.g., multiple + users to access the data without having to recompute. + + Check StorageBox.realize for a particularly notable implementation. + + TODO(ezyang): I think, in principle, every IRNode should have an + implementation of this, and most of the time no-op is OK, but you + really do have to audit each IRNode for this, so for now, raise + an error if it's not implemented. Note that some code in graph.py + will catch this thrown error and suppress it with a warning. + """ + raise NotImplementedError(f"realize NYI on {type(self)}") + + def codegen_reference(self, writer=None): + raise NotImplementedError(f"codegen_reference NYI on {type(self)}") + + # The abstract method declarations below serve to convince mypy that all IRNode instances have these functions + # defined, while having no effect at runtime. We cannot create stub implementations here because other parts of + # the code dynamically check for defined attributes. + get_device: Callable[[], torch.device] + dtype: torch.dtype + get_name: Callable[[], str] + get_reads: Callable[[], Any] + num_reads: Callable[[], int] + get_stride: Callable[[], Any] + get_storage_numel: Callable[[], Any] + has_exceeded_max_reads: Callable[[], bool] + make_loader: Callable[[], Callable[[Any], Any]] + make_indexer: Callable[[], Callable[[Any], Any]] + mark_reuse: Callable[[int], None] + realize_hint: Callable[[], None] + get_unbacked_symbol_uses: Callable[[], OrderedSet[sympy.Symbol]] + + +@dataclasses.dataclass +class Operation: + def __post_init__(self): + self.operation_name: Optional[str] = None + + def get_device(self): + raise NotImplementedError + + def get_origin_node(self): + assert hasattr(self, "origin_node") + return self.origin_node + + def get_origins(self): + assert hasattr(self, "origins") + return self.origins + + def get_operation_name(self) -> str: + assert self.operation_name is not None + return self.operation_name + + def is_extern(self): + return False + + def is_no_op(self): + return False + + def get_read_writes(self): + raise NotImplementedError + + def is_user_of(self, name): + return name in self.get_read_names() + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(dep.name for dep in self.get_reads()) + + def get_reads(self): + return self.get_read_writes().reads + + def get_outputs(self) -> List[Buffer]: + raise NotImplementedError + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + """ + Returns the unbacked symbols which are required to be in scope in + order to successfully perform codegen for this buffer. For example, + a buffer that corresponds to an extern kernel call that takes i0 as + an argument would return {i0} here. This is used to generate necessary + dependencies that ensure we actually bind i0 in codegen before you + try to use it. + + Note that this is NOT transitive; in particular, if this buffer takes + in as input another buffer with dynamic shape (e.g., (i0,)), we will + not report it here, because you will already have a dependency + on that buffer, which will eventually have a dependency on i0 if + necessary. + """ + return OrderedSet() + + def get_workspace_size(self): + """ + Gets extra global memory size needed by this buffer. + Some algorithms (e.g. group gemm) may require extra global memory in the generated code. + """ + return 0 + + +@dataclasses.dataclass +class Loops(IRNode): + device: torch.device + dtype: torch.dtype + inner_fn: Callable[..., Any] + ranges: List[Expr] + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet().union( + *(free_unbacked_symbols(e) for e in self.ranges), + self.inner_fn_free_unbacked_symbols(), + ) + + def __str__(self, names=("ranges",)): + return self.str_helper( + [ + f"'{self.device.type}'", + str(self.dtype), + self.inner_fn_str(), + ] + + [f"{name}={getattr(self, name)}" for name in names] + + [f"origin_node={self.origin_node!r}"] + ) + + def __post_init__(self): + super().__post_init__() + self.origin_node = None + + __repr__ = __str__ + + def get_device(self): + return self.device + + def get_origin_node(self): + return self.origin_node + + def get_size(self): + return self.ranges + + def get_pointwise_size(self): + return self.ranges + + def is_extern(self): + return False + + @classmethod + def create(cls, *args, **kwargs): + origin_node = kwargs.pop("origin_node", None) + tb = kwargs.pop("traceback", None) + r = cls(*args, **kwargs) + r.origin_node = origin_node + r.traceback = ( + tb or traceback.format_stack() if config.debug_ir_traceback else None + ) + return TensorBox.create(r) + + @staticmethod + def _index(ranges, prefix=SymT.INDEX): + return [ + sympy.Integer(0) if s == 1 else sympy_index_symbol_with_prefix(prefix, n) + for n, s in enumerate(ranges) + ] + + @cache_on_self + def inner_fn_opcount(self) -> OpCountResult: + opcounter = OpCounterCSE(V.MockHandler()) + with V.set_ops_handler(opcounter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + self.inner_fn(*self.inner_fn_args()) + return opcounter.getvalue() + + def inner_fn_args(self): + return (self._index(self.ranges),) + + @cache_on_self + def inner_fn_str(self): + return V.KernelFormatterHandler.ir_to_string( + self.inner_fn, *self.inner_fn_args() + ) + + def has_large_inner_fn(self): + return self.inner_fn_opcount().num_ops > config.realize_opcount_threshold + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + return extract_free_unbacked_symbols(self.inner_fn, index) + + def get_reads(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.get_reduction_type(): + return extract_read_writes( + self.make_loader(), + self.get_size(), + self.get_reduction_size(), + ).reads + else: + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(self.inner_fn_opcount().read_buffers) + + def num_reads(self): + return len(self.inner_fn_opcount().read_buffers) + + def get_reduction_size(self): + raise NotImplementedError( + f"get_reduction_size() is not implemented by {type(self)}!" + ) + + def get_reduction_type(self): + raise NotImplementedError( + f"get_reduction_type() is not implemented by {type(self)}!" + ) + + def constant_to_device(self, device): + raise NotImplementedError( + f"constant_to_device() is not implemented by {type(self)}!" + ) + + +def nop_loader_fn(idx: Union[Expr, Sequence[Expr]], *, dtype: torch.dtype) -> OpsValue: + if dtype.is_floating_point: + return ops.constant(float("nan"), dtype) + else: + return ops.constant(0, dtype) + + +class Pointwise(Loops): + def make_loader(self): + # Make zero-element loops into a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.dtype) + + return self.inner_fn + + def get_reduction_size(self): + return [] + + def get_reduction_type(self): + return None + + def store_output(self, output_name, indexer, vars): + loader = self.make_loader() + return ops.store(output_name, indexer(vars), loader(vars)) + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise(device, self.dtype, loader, self.ranges) + + +@dataclasses.dataclass +class Scatter(Pointwise): + output_indexer: Callable[[List[Expr]], Expr] + scatter_mode: Optional[str] = None + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Scatter( + device, + self.dtype, + loader, + self.ranges, + self.output_indexer, + self.scatter_mode, + ) + + def store_output(self, output_name, indexer, vars): + loader = self.make_loader() + return ops.store( + output_name, + indexer(self.output_indexer(vars)), + loader(vars), + mode=self.scatter_mode, + ) + + +REDUCTION_COMBINE_FN: Dict[str, Callable[..., OpsValue]] = { + "any": ops_wrapper("logical_or"), + "max": ops_wrapper("maximum"), + "min": ops_wrapper("minimum"), + "prod": ops_wrapper("mul"), + "sum": ops_wrapper("add"), + "xor_sum": ops_wrapper("bitwise_xor"), +} + + +def get_reduction_combine_fn( + reduction_type: str, dtype: torch.dtype, arg_break_ties_left: bool = True +) -> Callable[..., object]: + if reduction_type in REDUCTION_COMBINE_FN: + return REDUCTION_COMBINE_FN[reduction_type] + + elif reduction_type in ("argmax", "argmin"): + + def argmax_combine_fn( + a: Tuple[object, object], b: Tuple[object, object] + ) -> Tuple[OpsValue, OpsValue]: + a_value, a_index = a + b_value, b_index = b + + if reduction_type == "argmin": + mask = ops.lt(a_value, b_value) + else: + mask = ops.gt(a_value, b_value) + + equal = ops.eq(a_value, b_value) + if is_float_dtype(dtype): + a_isnan = ops.ne(a_value, a_value) + b_isnan = ops.ne(b_value, b_value) + mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan)) + equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan)) + + tie = ( + ops.lt(a_index, b_index) + if arg_break_ties_left + else ops.gt(a_index, b_index) + ) + mask = ops.logical_or(mask, ops.logical_and(equal, tie)) + return ( + ops.where(mask, a_value, b_value), + ops.where(mask, a_index, b_index), + ) + + return argmax_combine_fn + + elif reduction_type == "welford_combine": + + def welford_combine_fn( + a: Tuple[OpsValue, OpsValue, OpsValue], + b: Tuple[OpsValue, OpsValue, OpsValue], + ) -> Tuple[OpsValue, OpsValue, OpsValue]: + a_mean, a_m2, a_weight = a + b_mean, b_m2, b_weight = b + + delta = b_mean - a_mean + new_weight = a_weight + b_weight + w2_over_w = b_weight / new_weight + return ( + a_mean + delta * w2_over_w, + a_m2 + b_m2 + delta * delta * a_weight * w2_over_w, + new_weight, + ) + + return welford_combine_fn + + else: + raise NotImplementedError(f"unknown reduction_type={reduction_type}") + + +def significant_strides_equal( + strides1: Sequence[_IntLike], strides2: Sequence[_IntLike], size: Sequence[_IntLike] +) -> bool: + """ + Returns true if the strides are equal, ignoring dimensions of size 1 . + """ + non_1_indices = [ + i + for i, dim in enumerate(size) + if V.graph.sizevars.size_hint(dim, fallback=2) != 1 + ] + strides1 = [V.graph.sizevars.size_hint(strides1[i]) for i in non_1_indices] + strides2 = [V.graph.sizevars.size_hint(strides2[i]) for i in non_1_indices] + return strides1 == strides2 + + +@dataclasses.dataclass +class Reduction(Loops): + reduction_ranges: List[Expr] + reduction_type: str + # self.dtype represents the dst dtype + src_dtype: torch.dtype + reduction_hint: ReductionHint + + def __str__(self) -> str: # type: ignore[override] + return Loops.__str__( # type: ignore[call-arg] + self, names=("ranges", "reduction_ranges", "reduction_type") + ) + + def __repr__(self) -> str: # type: ignore[override] + return self.__str__() + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return super().get_unbacked_symbol_uses() | OrderedSet().union( + *(free_unbacked_symbols(e) for e in self.reduction_ranges) + ) + + def get_reduction_size(self): + return self.reduction_ranges + + def get_reduction_type(self): + return self.reduction_type + + def store_reduction(self, output_name, indexer, vars, reduction_vars): + value = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + return ops.store_reduction(output_name, indexer(vars), value) + + def index_length(self): + return len(self.ranges) + len(self.reduction_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, SymT.RINDEX) + return (index, rindex) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, SymT.RINDEX) + return extract_free_unbacked_symbols(self.inner_fn, index, rindex) + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Reduction( + device, + self.dtype, + loader, + self.ranges, + self.reduction_ranges, + self.reduction_type, + self.src_dtype, + ReductionHint.DEFAULT, + ) + + @staticmethod + def num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node: Optional[IRNode] = None, + ): + def _is_static(x): + return isinstance(x, (int, sympy.Integer)) + + reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel) + numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges)) + + should_split = ( + not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT) + and reduction_type + not in ( + "argmax", + "argmin", + ) + and config.split_reductions + # We don't support unbacked symints + and _is_static(reduction_numel_hint) + and _is_static(numel_hint) + ) + if not should_split: + return ReductionHint.DEFAULT, 1 + + device_interface = get_interface_for_device(get_device_type(device)) # type: ignore[arg-type] # next PR + device_properties = device_interface.Worker.get_device_properties(device) + if get_device_type(device) == "xpu": + num_sm = device_properties.gpu_subslice_count + else: + # default is cuda behavior + num_sm = device_properties.multi_processor_count + + min_elements_per_thread = 32 + max_elements_per_thread = 512 + threads_per_sm = 2048 + min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm + max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm + + def inner_reduction_splits(reduction_numel_hint, numel_hint): + # do heuristics that's close to eager mode for split inner reduction + # we leak reduction autotune configs here, and will need to refactor to avoid this later + num_warps = 8 + num_threads = 32 * num_warps + if numel_hint >= 2 * num_sm: # don't split if there are enough outputs + return 1 + if reduction_numel_hint <= 8192: + return 1 + if reduction_numel_hint * numel_hint <= min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (2 * num_threads) + blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint + tmp_split_size = ( + reduction_numel_hint + num_threads * blocks_per_output - 1 + ) // (num_threads * blocks_per_output) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(closest - tmp_split_size) < 30: + # prefer even splits, but never smalle than min_elements_per_thread + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + return (reduction_numel_hint + split_size * num_threads - 1) // ( + split_size * num_threads + ) + + def outer_reduction_splits(reduction_numel_hint, numel_hint): + # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128 + # extend to even smaller number of outputs + num_warps = 8 + num_threads = num_warps * 32 + rvals_per_thread = 4 # comes from heuristics, refactor to not leak here + xvals_per_block = 128 + xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block + if reduction_numel_hint * numel_hint < min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (num_threads) + target_blocks = (target_blocks + xblocks - 1) // xblocks + tmp_split_size = ( + reduction_numel_hint + rvals_per_thread * target_blocks - 1 + ) // (rvals_per_thread * target_blocks) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(tmp_split_size - closest) < 20: + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + + return (reduction_numel_hint + rvals_per_thread * split_size - 1) // ( + rvals_per_thread * split_size + ) + + # easy cases + if numel_hint == 1: + split = inner_reduction_splits(reduction_numel_hint, numel_hint) + if split == 1: + # No need to split. + return ReductionHint.INNER, split + if input_node is not None and isinstance(input_node, TensorBox): + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node + ) + if new_ranges is not None and new_reduction_ranges is not None: + extracted_numel_hint = V.graph.sizevars.symbolic_hint( + sympy_product(new_ranges + new_reduction_ranges) + ) + if reduction_numel_hint == extracted_numel_hint: + log.debug( + "Use previous IRNode's range and reduction_ranges instead of split. " + "current ranges: %s, current reduction ranges: %s, current split: %d, " + "new ranges: %s, new reduction ranges: %s", + ranges, + reduction_ranges, + split, + new_ranges, + new_reduction_ranges, + ) + # If the input_node or its dependent nodes are also Reduction nodes, + # use reduction_sizes of this node or its dependent nodes directly. + return ReductionHint.INNER, -1 + return ReductionHint.INNER, split + if ( + reduction_numel_hint <= min_elements_per_thread + or numel_hint >= num_sm * 2 * 32 + ): + return ReductionHint.DEFAULT, 1 + + r = Reduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + src_dtype, + ReductionHint.DEFAULT, + ) + + def get_read_indices(r): + cb = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=r.get_device(), + dtype=r.get_dtype(), + size=r.get_size(), + ), + data=r, + ) + read_writes = cb.get_read_writes() + # try finding the full size producer + # TODO this will fail for something like ((1, N) * (N, 1)).sum() + # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare + range_vars = [ + r + for r in read_writes.range_vars + if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number) + ] + indices = [] + changed = False + for md in sorted(read_writes.reads, key=lambda x: x.name): + if all(r in md.index.free_symbols for r in range_vars): + indices.append(md.index) + if md.name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[md.name] + original_stride = buf.layout.stride + buf.decide_layout() + if buf.layout.stride != original_stride: + changed = True + return indices, changed + + indices, changed = get_read_indices(r) + if changed: + indices, _ = get_read_indices(r) + + if len(indices) == 0: + # TODO determine splits when all inputs are broadcast + return ReductionHint.DEFAULT, 1 + + (_, reduction_vars), ranges = dependencies.index_vars_squeeze( + r.get_size(), r.get_reduction_size() + ) + num_outer = 0 + num_inner = 0 + for i in indices: + i = V.graph.sizevars.simplify_with_ranges(i, ranges) + strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys()) + outer = all(s > 1 for s in strides) + if outer: + num_outer += 1 + else: + num_inner += 1 + if num_inner > num_outer: + return ReductionHint.INNER, inner_reduction_splits( + reduction_numel_hint, numel_hint + ) + else: + return ReductionHint.OUTER, outer_reduction_splits( + reduction_numel_hint, numel_hint + ) + + @staticmethod + def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type, src_dtype): + """Convert inner_fn from a reduction to an pointwise""" + reduction_ranges = [ + V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges + ] + + combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) + + def fn(index): + return functools.reduce( + combine_fn, + ( + value_fn(index, rindex) + for rindex in itertools.product( + *[range(x) for x in reduction_ranges] + ) + ), + ) + + if reduction_type in ("argmin", "argmax"): + flatten_index = FixedLayout( + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + reduction_ranges, + FlexibleLayout.contiguous_strides(reduction_ranges), + ).make_indexer() + + def value_fn(index, rindex): + rindex = [sympy.expand(i) for i in rindex] + return ( + inner_fn(index, rindex), + ops.index_expr(flatten_index(rindex), torch.int64), + ) + + return lambda index: fn(index)[1] + else: + value_fn = inner_fn + return fn + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + input_node: Optional[IRNode] = None, + ): + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + if reduction_numel == 0: + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val): + return ( + bool(val) + if dst_dtype == torch.bool + else float(val) + if dst_dtype.is_floating_point + else int(val) + ) + + rtypes_to_inits = { + "sum": py_cnst(0), + "xor_sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert ( + reduction_type in rtypes_to_inits.keys() + ), f"{reduction_type} not supported for zero-dimension tensors!" + + def const_fn(index): + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + ) + + if reduction_numel == 1: + # this reduction is actually a pointwise op + if reduction_type in ("argmin", "argmax"): + + def fn(index): + return ops.constant(0, dst_dtype) + + else: + + def fn(index): + reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + return inner_fn(index, reduction_index) + + return Pointwise.create(device, dst_dtype, fn, ranges) + + if ( + isinstance(reduction_numel, sympy.Integer) + and V.graph.sizevars.size_hint(reduction_numel) + < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ): + return Pointwise.create( + device, + dst_dtype, + cls._unroll_reduction_fn( + inner_fn, reduction_ranges, reduction_type, src_dtype + ), + ranges, + ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = cls.num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split == -1: + assert input_node is not None + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node # type: ignore[arg-type] + ) + assert new_ranges is not None + assert new_reduction_ranges is not None + return cls.create_multilayer_existing_ranges( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + elif split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + return TensorBox.create( + Reduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + src_dtype, + reduction_hint, + ) + ) + + @staticmethod + def default_accumulator(reduction_type, dtype): + if reduction_type in ("max", "argmax"): + if is_float_dtype(dtype): + return float("-inf") + elif is_boolean_dtype(dtype): + return 0 + else: + return torch.iinfo(dtype).min + if reduction_type in ("min", "argmin"): + if is_float_dtype(dtype): + return float("inf") + elif is_boolean_dtype(dtype): + return 1 + else: + return torch.iinfo(dtype).max + + return { + "sum": 0, + "prod": 1, + "xor_sum": 0, + "any": 0, + "welford_reduce": (0, 0, 0), + "welford_combine": (0, 0, 0), + }[reduction_type] + + @staticmethod + def default_value(reduction_type, dtype): + if reduction_type == "welford_reduce": + return 0 + return Reduction.default_accumulator(reduction_type, dtype) + + @staticmethod + def _multilayer_second_step_hint( + split: int, numel_hint: int, reduction_hint: ReductionHint + ) -> ReductionHint: + if split == -1: + return reduction_hint + if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER: + return ReductionHint.OUTER_TINY + if ( + split <= 1024 + and numel_hint <= 256 + and reduction_hint == ReductionHint.OUTER + ): + return ReductionHint.OUTER_TINY + + return reduction_hint + + @classmethod + def _multilayer_wrap_loader( + cls, + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default, + ): + reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel]) + need_mask = not V.graph.sizevars.is_expr_static_and_true( + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] + ) + + def wrapper_fn(index, reduction_index): + (reduction_index,) = reduction_index + *new_index, reduction_block = index + indices = block_size * reduction_block + reduction_index + + def body(): + return loader(new_index, reindex([indices])) + + if need_mask: + mask = ops.lt( + ops.index_expr(indices, torch.int32), + ops.index_expr(reduction_numel, torch.int32), + ) + return ops.masked(mask, body, default) + else: + return body() + + return wrapper_fn + + @classmethod + def _multilayer_wrap_loader_existing_ranges( + cls, + loader, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + default, + ): + assert all( + r == 1 for r in original_ranges + ), f"Only enabled for numel_hint == 1, found {original_ranges=}" + reindex = View.dynamic_reshape_indexer( + original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) + ) + + def wrapper_fn(merged_index, new_reduction_index): + original_idx = merged_index[: len(original_ranges)] + new_index = merged_index[len(original_ranges) :] + return loader( + original_idx, + reindex(tuple(new_index) + tuple(new_reduction_index)), + ) + + return wrapper_fn + + @classmethod + def create_multilayer_helper( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + wrapper_fn: Callable[..., Any], + original_ranges: List[Expr], + original_reduction_ranges: List[Expr], + new_ranges: List[Expr], + new_reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # triton will automatically compute reductions in fp32 if reducing over fp16/bf16 + # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction + # in fp32 and not reduce precision by breaking up the kernel into multiple layers + intermediate_dtype = ( + dst_dtype + if dst_dtype not in (torch.float16, torch.bfloat16) + else torch.float + ) + intermediate = Reduction.create( + device, + intermediate_dtype, + src_dtype, + wrapper_fn, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + intermediate.realize() + intermediate_loader = intermediate.make_loader() + + def intermediate_fn(index, reduction_index): + return intermediate_loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + + assert original_ranges == new_ranges[: len(original_ranges)] + return TensorBox.create( + Reduction( + device, + dst_dtype, + intermediate_fn, + original_ranges, + new_ranges[len(original_ranges) :], + reduction_type, + src_dtype, + reduction_hint, + ) + ) + + @classmethod + def create_multilayer( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # TODO(jansel): realize the reduction so we can do dynamic indexing + reduction_numel = sympy_product(reduction_ranges) + block_size = FloorDiv(reduction_numel + (split - 1), split) + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader( + inner_fn, reduction_ranges, reduction_numel, split, block_size, default + ) + + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + ranges, + reduction_ranges, + [*ranges, split], # type: ignore[list-item] + [block_size], + reduction_type, + split, + reduction_hint, + ) + + @classmethod + def create_multilayer_existing_ranges( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + original_ranges: List[Expr], + original_reduction_ranges: List[Expr], + new_ranges: List[Expr], + new_reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader_existing_ranges( + inner_fn, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + default, + ) + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + original_ranges, + original_reduction_ranges, + [*original_ranges, *new_ranges], + new_reduction_ranges, + reduction_type, + -1, + reduction_hint, + ) + + +class WelfordReduction(Reduction): + output_index: int + + def __init__( + self, + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + reduction_hint, + output_index, + ): + if len(inner_fns) == 1: + loader = inner_fns[0] + else: + + def loader(idx, reduction_idx): + return tuple(fn(idx, reduction_idx) for fn in inner_fns) + + super().__init__( + device, + dtype, + loader, + ranges, + reduction_ranges, + reduction_type, + dtype, + reduction_hint, + ) + self.output_index = output_index + + def store_reduction(self, output_name, indexer, vars, reduction_vars): + values = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + value = values[self.output_index] + return ops.store_reduction(output_name, indexer(vars), value) + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + ): + assert reduction_type in ("welford_reduce", "welford_combine") + + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + def const(val): + def inner_fn(idx): + return ops.constant( + val, + dtype, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_numel == 0: + mean = const(0) + m2 = const(0) + weight = const(0) + return mean, m2, weight + + if reduction_numel == 1: + + def copy(loader): + def inner_fn(idx): + reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + return loader(idx, reduction_index) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_type == "welford_reduce": + return copy(inner_fns[0]), const(0), const(1) + else: + return tuple(copy(fn) for fn in inner_fns) + + # TODO: Unrolled reduction + # if ( + # isinstance(reduction_numel, sympy.Integer) + # and V.graph.sizevars.size_hint(reduction_numel) + # < config.unroll_reductions_threshold + # and sympy_product(ranges) != 1 + # ): + # return Pointwise.create( + # device, + # dst_dtype, + # cls._unroll_reduction_fn( + # inner_fn, reduction_ranges, reduction_type, src_dtype + # ), + # ranges, + # ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = Reduction.num_splits( + device, + dtype, + dtype, + inner_fns[0], + ranges, + reduction_ranges, + reduction_type=reduction_type, + reduction_numel=reduction_numel, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + results = [ + TensorBox.create( + WelfordReduction( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + reduction_hint, + output_idx, + ) + ) + for output_idx in range(3) + ] + for t in results: + t.realize() + return results + + @staticmethod + def default_value(reduction_type, dtype): + return (0, 0, 0) + + @classmethod + def create_multilayer( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + reduction_numel = sympy_product(reduction_ranges) + need_mask = not V.graph.sizevars.is_expr_static_and_true( + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] + ) + + if need_mask and reduction_type != "welford_combine": + # If we need mask, then "welford_reduce" doesn't work because + # masked inputs shouldn't count towards the welford weight + + def constant(idx, reduction_idx, value): + return ops.constant(value, dtype) + + return cls.create_multilayer( + device=device, + dtype=dtype, + inner_fns=( + inner_fns[0], + partial(constant, value=0), + partial(constant, value=1), + ), + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type="welford_combine", + split=split, + reduction_hint=reduction_hint, + ) + + block_size = FloorDiv(reduction_numel + (split - 1), split) + intermediates = WelfordReduction.create( + device, + dtype, + tuple( + cls._multilayer_wrap_loader( + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default=0, + ) + for loader in inner_fns + ), + [*ranges, split], # type: ignore[list-item] + [block_size], + reduction_type, + reduction_hint, + ) + for i in intermediates: + i.realize() + + i_loaders = [i.make_loader() for i in intermediates] + + def intermediate_loader_fn(index, reduction_index, loader): + return loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + return WelfordReduction.create( + device, + dtype, + tuple( + partial(intermediate_loader_fn, loader=i.make_loader()) + for i in intermediates + ), + ranges, + [split], # type: ignore[list-item] + # welford_reduce turns one input into three outputs, which are combined with welford_combine + "welford_combine", + reduction_hint, + ) + + +@dataclasses.dataclass +class Scan(Loops): + scan_ranges: List[Expr] + size: List[Expr] + combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]] + reindex: Callable[[List[Expr], List[Expr]], List[Expr]] + reduction_hint: ReductionHint + output_index: int + # output_index indexes the following tuples + dtypes: Tuple[torch.dtype, ...] + inner_fns: Tuple[Callable[..., Any], ...] + + # HACK we mimick reduction + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we + # need to explicitly represent the closure so we can pull out unbacked + # symbols here + return ( + super().get_unbacked_symbol_uses() + | OrderedSet().union(*(free_unbacked_symbols(e) for e in self.scan_ranges)) + | OrderedSet().union(*(free_unbacked_symbols(e) for e in self.size)) + ) + + def __post_init__(self): + assert len(self.ranges) + len(self.scan_ranges) == len(self.size) + super().__post_init__() + + def store_reduction(self, output_name, indexer, vars, scan_vars): + idx = self.reindex(vars, scan_vars) + values = [inner_fn(idx) for inner_fn in self.inner_fns] + result = ops.scan(self.dtypes, self.combine_fn, values) + return ops.store(output_name, indexer(idx), result[self.output_index]) + + def get_reduction_type(self): + # return self.scan_op + return "custom" + + def get_reduction_size(self): + return self.scan_ranges + + def get_size(self): + return self.size + + def get_pointwise_size(self): + return self.ranges + + def index_length(self): + return len(self.ranges) + len(self.scan_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, SymT.RINDEX) + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, SymT.RINDEX) + idx = self.reindex(index, rindex) + return extract_free_unbacked_symbols(self.inner_fn, idx) + + @classmethod + def create( + cls, + device: torch.device, + dtypes: Tuple[torch.dtype, ...], + inner_fns: Tuple[Callable[[List[Expr]], Any], ...], + size: List[Expr], + axis: int, + combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]], + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + *, + # Whether we have the option to fallback to aten + can_fallback_to_aten: bool = True, + **kwargs, + ) -> List[Optional[TensorBox]]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + scan_ranges = [size[axis]] + + if not V.graph.has_feature(device, BackendFeature.SCAN): + return [None] * len(dtypes) + + if len(dtypes) > 1 and not V.graph.has_feature( + device, BackendFeature.TUPLE_REDUCTION + ): + return [None] * len(dtypes) + + sizevars = V.graph.sizevars + scan_numel = sizevars.simplify(sympy_product(scan_ranges)) + + assert len(dtypes) == len(inner_fns) + + # Scan with a single element is just a copy + if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type] + return [ + Pointwise.create( + device=device, + dtype=dtypes[output_index], + inner_fn=inner_fns[output_index], + ranges=size, + ) + for output_index in range(len(dtypes)) + ] + + reduction_hint, num_splits = cls.num_splits( + device=device, + dtype=dtypes[0], + inner_fn=inner_fns[0], + axis=axis, + pointwise_ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + scan_numel=scan_numel, + ) + scan_type = Scan + if num_splits > 1: + supports_split = torch.version.hip is None and len(dtypes) == 1 + if not supports_split: + if can_fallback_to_aten: + # Fallback to ATen + return [None] * len(dtypes) + else: + num_splits = 1 + else: + scan_type = SplitScan + + def reindex(index, scan_index): + assert len(scan_index) == len(scan_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *scan_index, *index[axis:]] + + results = [ + TensorBox.create( + scan_type( + device=device, + dtype=dtypes[output_index], + dtypes=dtypes, + inner_fn=inner_fns[output_index], + inner_fns=inner_fns, + size=size, + ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + reindex=reindex, + reduction_hint=reduction_hint, + output_index=output_index, + **kwargs, + ) + ) + for output_index in range(len(dtypes)) + ] + + for result in results: + result.realize() + + return results + + @classmethod + def num_splits( + cls, + device: torch.device, + dtype: torch.dtype, + inner_fn: Callable[[List[Expr]], Any], + axis: int, + pointwise_ranges: List[Expr], + scan_ranges: List[Expr], + combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]], + scan_numel: Expr, + ): + # TODO: custom splitting heuristic for scan + def wrapper_fn(idx, reduction_idx): + return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]]) + + return Reduction.num_splits( + device=device, + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=wrapper_fn, + ranges=pointwise_ranges, + reduction_ranges=scan_ranges, + reduction_type="sum", + reduction_numel=scan_numel, + ) + + +# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA. +@dataclasses.dataclass +class SplitScan(Scan): + pass + + +@dataclasses.dataclass +class Sort(Loops): + # Sorts a tuple of key, value pairs + sort_ranges: List[Expr] + size: List[Expr] + reindex: Callable[[List[Expr], List[Expr]], List[Expr]] + reduction_hint: ReductionHint + output_index: int + # output_index indexes the following tuples + dtypes: Tuple[torch.dtype, ...] + inner_fns: Tuple[Callable[..., Any], ...] + + stable: bool + descending: bool + + # HACK we mimick reduction + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return ( + super().get_unbacked_symbol_uses() + | OrderedSet().union(*(free_unbacked_symbols(e) for e in self.sort_ranges)) + | OrderedSet().union(*(free_unbacked_symbols(e) for e in self.size)) + ) + + def __post_init__(self): + assert len(self.ranges) + len(self.sort_ranges) == len(self.size) + super().__post_init__() + + def store_reduction(self, output_name, indexer, vars, sort_vars): + idx = self.reindex(vars, sort_vars) + values = [inner_fn(idx) for inner_fn in self.inner_fns] + result = ops.sort(self.dtypes, values, self.stable, self.descending) + return ops.store(output_name, indexer(idx), result[self.output_index]) + + def get_reduction_type(self): + return "sort" + + def get_reduction_size(self): + return self.sort_ranges + + def get_size(self): + return self.size + + def get_pointwise_size(self): + return self.ranges + + def index_length(self): + return len(self.ranges) + len(self.sort_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.sort_ranges, SymT.RINDEX) + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.sort_ranges, SymT.RINDEX) + idx = self.reindex(index, rindex) + return extract_free_unbacked_symbols(self.inner_fn, idx) + + @classmethod + def create( + cls, + device: torch.device, + dtypes: Tuple[torch.dtype, ...], + inner_fns: Tuple[Callable[[List[Expr]], Any], ...], + size: List[Expr], + axis: int, + stable: bool, + descending: bool, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + **kwargs, + ) -> List[Optional[TensorBox]]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + sort_ranges = [size[axis]] + + if not V.graph.has_feature(device, BackendFeature.SORT): + return [None] * len(dtypes) + + sizevars = V.graph.sizevars + sort_numel = sizevars.simplify(sympy_product(sort_ranges)) + + # Heuristic, smallest rblock where triton usually outperforms aten.sort + # It also isn't bandwidth bound so fusion is unlikely to help. + max_rblock = 512 + is_persistent_kernel = ( + config.triton.persistent_reductions + and sizevars.is_expr_static_and_true(sympy.Le(sort_numel, max_rblock)) + ) + if not is_persistent_kernel: + # We only support persistent triton kernels + return [None] * len(dtypes) + + assert len(dtypes) == len(inner_fns) + + # Sort with a single element is just a copy + if sizevars.is_expr_static_and_true(sympy.Le(sort_numel, 1)): # type: ignore[arg-type] + return [ + Pointwise.create( + device=device, + dtype=dtypes[output_index], + inner_fn=inner_fns[output_index], + ranges=size, + ) + for output_index in range(len(dtypes)) + ] + + def reindex(index, sort_index): + assert len(sort_index) == len(sort_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *sort_index, *index[axis:]] + + results = [ + TensorBox.create( + Sort( + device=device, + dtype=dtypes[output_index], + dtypes=dtypes, + inner_fn=inner_fns[output_index], + inner_fns=inner_fns, + size=size, + ranges=pointwise_ranges, + sort_ranges=sort_ranges, + reindex=reindex, + reduction_hint=reduction_hint, + output_index=output_index, + stable=stable, + descending=descending, + **kwargs, + ) + ) + for output_index in range(len(dtypes)) + ] + + for result in results: + result.realize() + + return results + + +def is_storage_and_layout(x: IRNode) -> bool: + try: + as_storage_and_layout(x, freeze=False) + return True + except NotImplementedError: + return False + + +def is_contiguous_storage_and_layout(x: IRNode) -> bool: + try: + buffer, layout = as_storage_and_layout(x, freeze=False) + # pad the stride here so we will NOT claim an tensor as contiguous + # if a padding is gonna happen. + if layout.should_pad_strides(): + layout.pad_strides() + return layout.is_contiguous() + except NotImplementedError: + return False + + +def as_storage_and_layout( + x: IRNode, + freeze: bool = True, + want_contiguous: bool = False, + stride_order: Optional[Sequence[Union[int, Integer]]] = None, + allow_padding: bool = False, + exact_strides: Optional[Sequence[Union[int, Integer]]] = None, +) -> Tuple[StorageBox, Layout]: + """ + Try to simplify x into a StorageBox and a Layout. + + allow_padding only affect how we apply stride_order. When allow_padding + is True, we have the freedom to add padding when applying the stride_order. + """ + if isinstance(x, TensorBox): + return as_storage_and_layout( + x.data, + freeze=freeze, + want_contiguous=want_contiguous, + stride_order=stride_order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + if isinstance(x, StorageBox) and isinstance(x.data, Buffer): + if freeze: + if want_contiguous: + x.data.freeze_layout() + assert x.data.layout.is_contiguous() + elif stride_order is not None: + x.data.freeze_layout_with_stride_order( + stride_order, allow_padding=allow_padding + ) + elif exact_strides is not None: + x.data.freeze_layout_with_exact_strides( + exact_strides, allow_padding=allow_padding + ) + else: + x.data.decide_layout() + return x, x.data.layout + if isinstance(x, ReinterpretView): + # making the base of x contiguous or stride_ordered will not necessarily make + # the ReinterpretView either, so don't pass along those arguments + buffer, _ = as_storage_and_layout( + x.data, + freeze=freeze, + ) + return buffer, x.layout + raise NotImplementedError + + +as_contiguous_storage_and_layout = functools.partial( + as_storage_and_layout, want_contiguous=True +) + + +def is_stride_order_storage_and_layout( + x: IRNode, stride_order: Sequence[Union[int, Integer]] +) -> bool: + try: + buffer, layout = as_storage_and_layout(x, freeze=False) + return layout.is_stride_ordered(stride_order) + except NotImplementedError: + return False + + +@dataclasses.dataclass +class BaseView(IRNode): + data: IRNode + + def get_unbacked_symbol_uses(self): + return self.data.get_unbacked_symbol_uses() + + def make_reindexer(self): + raise NotImplementedError(f"make_reindexer NYI on {self}") + + def make_indexer(self): + inner = self.data.make_indexer() + reindex = self.make_reindexer() + + def indexer(idx): + return inner(reindex(idx)) + + return indexer + + def make_loader(self): + inner = self.data.make_loader() + reindex = self.make_reindexer() + + def loader(idx): + return inner(reindex(idx)) + + return loader + + @property + def dtype(self): + return self.data.dtype + + def get_layout(self): + return self.data.get_layout() + + def get_device(self): + return self.data.get_device() + + def get_origin_node(self): + return None + + def get_name(self): + return self.data.get_name() + + def get_pointwise_size(self): + return self.get_size() + + def mark_reuse(self, users): + return self.data.mark_reuse(users) + + def has_exceeded_max_reads(self): + return self.data.has_exceeded_max_reads() + + def realize(self): + return self.data.realize() + + def realize_hint(self): + return self.data.realize_hint() + + def get_storage_numel(self): + return self.data.get_storage_numel() + + def is_extern(self): + return self.data.is_extern() # type: ignore[attr-defined] + + def is_module_buffer(self): + return self.data.is_module_buffer() # type: ignore[attr-defined] + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_reads(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def unwrap_view(self): + x: IRNode = self + while isinstance(x, BaseView): + x = x.data + return x + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise(device, self.get_dtype(), loader, self.get_size()) + + +@dataclasses.dataclass +class ExpandView(BaseView): + size: List[Expr] + + @staticmethod + def _normalize_size(x, new_size): + """Replace `-1` with correct sizes""" + sizevars = V.graph.sizevars + new_size = list(map(sympy.expand, new_size)) + old_size = x.get_size() + old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) + assert len(new_size) == len(old_size) + for i in range(len(new_size)): + if new_size[i] == -1: + assert old_size[i] is not None + new_size[i] = old_size[i] + elif old_size[i] is None or old_size[i] == 1: + pass + else: + # Sanity check: Expect broadcast compatibility + # + # NB: new_size[i] == old_size[i] is expected to already be + # guarded because the meta formula was expected to have taught + # us this equality. + assert ( + sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0 + ), "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}" + return new_size + + @classmethod + def create(cls, x, new_size): + new_size = cls._normalize_size(x, new_size) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + skip = len(new_size) - len(old_layout.size) + assert skip >= 0 + new_stride = [sympy.Integer(0)] * skip + for stride, size in zip(old_layout.stride, old_layout.size): + new_stride.append(stride if size != 1 else sympy.Integer(0)) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(new_size), + new_stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + return ExpandView(x, new_size) + + def get_size(self): + return self.size + + def make_reindexer(self): + target = self.get_size() + actual = self.data.get_size() + skip = len(target) - len(actual) + + def reindex(index): + index = list(index[skip:]) + assert len(index) == len(actual) + for i in range(len(actual)): + if actual[i] == 1: + # zero out broadcast dimension + index[i] = sympy.Integer(0) + return index + + return reindex + + +@dataclasses.dataclass +class PermuteView(BaseView): + dims: List[Expr] + + @classmethod + def create(cls, x, dims): + dims = cls._map_neg_dims(dims) + assert OrderedSet(dims) == OrderedSet(range(len(dims))) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + [old_layout.size[i] for i in dims], + [old_layout.stride[i] for i in dims], + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + return PermuteView(x, dims) + + @classmethod + def _map_neg_dims(cls, dims): + return [dim if dim >= 0 else len(dims) + dim for dim in dims] + + def get_size(self): + assert OrderedSet(self._map_neg_dims(self.dims)) == OrderedSet( + range(len(self.dims)) + ) + size = self.data.get_size() + return [size[i] for i in self.dims] + + def make_reindexer(self): + inv = {j: i for i, j in enumerate(self.dims)} + inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index] + assert OrderedSet(inv) == OrderedSet(range(len(self.dims))) + + def reindex(index): + return [index[i] for i in inv] + + return reindex + + +class SqueezeView(BaseView): + @classmethod + def create(cls, x, *, dim=None): + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_size = [] + new_stride = [] + if dim is not None: + assert isinstance(dim, int), "expected integer dim argument" + assert 0 <= dim and dim < len(old_layout.size) + + for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): + if dim is None: + if size != 1: + new_size.append(size) + new_stride.append(stride) + else: + if i != dim: + new_size.append(size) + new_stride.append(stride) + else: + assert size == 1, "expected squeezed size to be 1" + + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + if dim is None: + # redirect to a generic view + return View.create(x, [s for s in x.get_size() if s != 1]) + else: + assert x.get_size()[dim] == 1 + return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) + + @staticmethod + def squeezer(size: Tuple[sympy.Expr, ...]): + new_size = [s for s in size if s != 1] + not_one = [i for i, s in enumerate(size) if s != 1] + length = len(size) + + def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]: + assert len(index) == len(not_one), f"{index} {not_one}" + new_index = [sympy.Integer(0)] * length + for idx, s in zip(not_one, index): + new_index[idx] = s + return tuple(new_index) + + return new_size, reindex + + def __init__(self, data): + raise AssertionError("use SqueezeView.create()") + + +@dataclasses.dataclass +class GenericView(BaseView): + size: List[Expr] + reindex: Callable[..., Any] + + def make_reindexer(self): + return self.reindex + + def reindex_str(self): + index_old = [ + sympy_index_symbol_with_prefix(SymT.INDEX, n) for n in range(len(self.size)) + ] + index_new = list(self.reindex(index_old)) + return f"lambda {', '.join(map(str, index_old))}: {index_new}" + + def __str__(self) -> str: + return self.str_helper( + [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"] + ) + + __repr__ = __str__ + + @classmethod + def create(cls, x, new_size, reindex): + return cls(x, list(new_size), reindex) + + def get_size(self): + return self.size + + +@dataclasses.dataclass +class View(GenericView): + @staticmethod + def handle_negative_index(idx, size): + idx = sympy.expand(idx) + size = sympy.expand(size) + evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr + if evaluate_expr(sympy.Lt(idx, 0)): + idx = idx + size + return idx + + @classmethod + def create(cls, x, new_size): + assert isinstance(new_size, (tuple, list)) + old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) + + # Skip pointless views + if V.graph.sizevars.statically_known_list_equals(old_size, new_size): + return x + + unbacked_symbols_in_sizes = False + if ( + len(free_unbacked_symbols(old_size)) > 0 + or len(free_unbacked_symbols(new_size)) > 0 + ): + unbacked_symbols_in_sizes = True + + if 0 in new_size: + + def fake_reindex(index): + return tuple([0] * len(old_size)) + + return cls(x, list(new_size), fake_reindex) + # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout + elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: + if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)): + # realize x; otherwise, the dynamic_reshape_indexer below will fail + # due to the size_hint's inability to process unbacked SymInts + x = ExternKernel.realize_input(x) + + storage, old_layout = as_contiguous_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + FlexibleLayout.contiguous_strides(new_size), + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + reindex = cls.dynamic_reshape_indexer(old_size, new_size) + return cls(x, list(new_size), reindex) + + @staticmethod + def resolve_negative_size(old_size, new_size): + new_size = [V.graph.sizevars.simplify(x) for x in new_size] + old_size = [V.graph.sizevars.simplify(x) for x in old_size] + + new_size = list(new_size) + for i in range(len(new_size)): + if new_size[i] == -1: + new_size[i] = sympy.Integer(1) + new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) + break + + V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) + return old_size, new_size + + @classmethod + def dynamic_reshape_indexer(cls, old_size, new_size): + try: + reindex = cls._dynamic_reshape_indexer(old_size, new_size) + except (AssertionError, IndexError): + # optimistic algorithm failed, lets do a fallback + flat = [sympy_product(old_size)] + reindex1 = cls._dynamic_reshape_indexer(old_size, flat) + reindex2 = cls._dynamic_reshape_indexer(flat, new_size) + reindex = fuse_reindexing(reindex1, reindex2) + return reindex + + @staticmethod + def _dynamic_reshape_indexer(old_size, new_size): + """ + Perform a reshape entirely by modifying indexing math + """ + size_hint = V.graph.sizevars.size_hint + # TODO: These symbols may not escape, if they don't assert so and + # treat them as temporary + vars = [ + sympy_index_symbol_with_prefix(SymT.VIEW, i) for i in range(len(new_size)) + ] + + stack_new = list(zip(vars, new_size)) + stack_old = list(old_size) + + view_expr = [] + while stack_new and stack_old: + size_old = stack_old.pop() + var, size_new = stack_new.pop() + if size_old == 1: + view_expr.append(sympy.Integer(0)) + stack_new.append((var, size_new)) # re-add + elif size_new == 1: + stack_old.append(size_old) # re-add + elif size_hint(size_new) == size_hint(size_old): + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) < size_hint(size_old): + while size_hint(size_new) < size_hint(size_old): + var2, size_new2 = stack_new.pop() + var = var2 * size_new + var + size_new = size_new * size_new2 + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) > size_hint(size_old): + divisor = sympy.Integer(1) + modulus = size_old + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + while size_hint(size_new) > size_hint(size_old): + modulus = stack_old.pop() + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + size_old = size_old * modulus + V.graph.sizevars.guard_equals(size_new, size_old) + else: + raise AssertionError + + while stack_old: + size_old = stack_old.pop() + V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type] + view_expr.append(sympy.Integer(0)) + + while stack_new: + var, size_new = stack_new.pop() + V.graph.sizevars.guard_equals(size_new, 1) # type: ignore[arg-type] + + view_expr.reverse() + assert len(view_expr) == len(old_size) + + def reindex(index): + assert len(index) == len(vars), (len(index), len(vars)) + replacements = dict(zip(vars, index)) + return tuple(sympy_subs(x, replacements) for x in view_expr) # type: ignore[arg-type] + + return reindex + + +@dataclasses.dataclass +class ReinterpretView(BaseView): + """Pretend our storage has a different layout""" + + layout: Layout + + def __post_init__(self): + super().__post_init__() + if isinstance(self.data, BaseView): + self.data = self.data.unwrap_view() + + def __str__(self) -> str: + return self.str_helper( + [ + self.data, + self.layout, + ] + ) + + __repr__ = __str__ + + def get_name(self): + return self.data.get_name() + + def get_device(self): + return self.layout.device + + def get_origin_node(self): + return None + + @property + def dtype(self): + return self.layout.dtype + + def get_size(self): + return list(self.layout.size) + + def get_stride(self): + return list(self.layout.stride) + + def make_loader(self): + def loader(index): + indexer = self.layout.make_indexer() + tmp_loader = ops.load(self.get_name(), indexer(index)) + if self.layout.dtype != self.data.dtype: + return ops.to_dtype_bitcast(tmp_loader, self.dtype, self.data.dtype) + else: + return tmp_loader + + return loader + + def make_indexer(self): + return self.layout.make_indexer() + + def get_layout(self): + return self.layout + + def freeze_layout(self): + pass + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return ( + free_unbacked_symbols(self.layout.size) + | free_unbacked_symbols(self.layout.stride) + | free_unbacked_symbols(self.layout.offset) + ) + + def codegen_reference(self, writer=None): + # reinterpret_tensor is similar to as_strided except: + # - offset is added to the existing offset (rather than replacing it) + # - view tracking is disabled similar to unsafe_view + return V.graph.wrapper_code.codegen_reinterpret_view( + self.data, + self.layout.size, + self.layout.stride, + self.layout.offset, + writer, + dtype=self.layout.dtype, + ) + + def num_reads(self): + return 1 + + +@dataclasses.dataclass +class DtypeView(BaseView): + """Pretend our storage has a different type""" + + target_dtype: torch.dtype + + @classmethod + def create(cls, x, new_dtype): + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + new_dtype, + old_layout.size, + old_layout.stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + return DtypeView(x, new_dtype) + + def __str__(self) -> str: + return self.str_helper([self.data, self.target_dtype]) + + __repr__ = __str__ + + @property + def dtype(self): + return self.target_dtype + + def get_size(self): + return self.data.get_size() + + def make_loader(self): + inner = self.data.make_loader() + + def loader(idx): + return ops.to_dtype_bitcast(inner(idx), self.target_dtype, self.data.dtype) + + return loader + + +class SliceView(View): + @classmethod + def normalize_start_end(cls, x, dim, start, end): + """ + Normalize start and end such that both are in the range + [0, x.get_size()[dim]] and start <= end. + """ + sizevars = V.graph.sizevars + dim_size = x.get_size()[dim] + + if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): + + def clamp(x, lower, upper): + return sympy.Min(sympy.Max(x, lower), upper) + + else: + + def clamp(x, lower, upper): + return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper) + + def clamp_wrap(val, lower, upper, default): + if val is None: + return default + val = cls.handle_negative_index(val, dim_size) + return clamp(val, lower, upper) + + start = clamp_wrap(start, 0, dim_size, 0) + end = clamp_wrap(end, start, dim_size, dim_size) + return start, end + + @classmethod + def create(cls, x, dim, start, end, step=1, clamp=True): + step = sympy.expand(step) + assert isinstance(step, sympy.Expr) or step > 0 + try: + if start == 0 and end >= 2**63 - 1 and step == 1: + return x + except TypeError: + pass + + sizevars = V.graph.sizevars + new_size = list(x.get_size()) + + # NB: Ordinarily we default to clamping. + # We only don't clamp for split_with_sizes. For split_with_sizes, sizes should be already valid + # failing in this situation is ok, since invalid sizes could trigger silent errors. + if clamp: + start, end = cls.normalize_start_end(x, dim, start, end) + + new_size[dim] = FloorDiv(end - start + (step - 1), step) + + if is_storage_and_layout(x): + # Fast path + storage, old_layout = as_storage_and_layout(x) + new_stride = list(old_layout.stride) + new_stride[dim] = new_stride[dim] * step + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset + old_layout.stride[dim] * start, + ) + return ReinterpretView(storage, new_layout) + + def reindex(index): + assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" + index = list(index) + index[dim] = index[dim] * step + start + return index + + # redirect to a generic view + return SliceView(x, size=new_size, reindex=reindex) + + +class BaseConstant(IRNode): + dtype: torch.dtype + device: torch.device + + def get_size(self): + return () + + def get_device(self): + return self.device + + def get_origin_node(self): + return None + + def mark_reuse(self, users): + pass + + def has_exceeded_max_reads(self): + return False + + def get_reads(self): + return () + + def is_extern(self): + return False + + +@dataclasses.dataclass +class Constant(BaseConstant): + value: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self): + def loader(index): + return ops.constant(self.value, self.dtype) + + return loader + + def realize(self): + pass + + def constant_to_device(self, device): + return Constant(self.value, self.dtype, device) + + +@dataclasses.dataclass +class IndexingConstant(BaseConstant): + index: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self): + def loader(index): + return ops.index_expr(self.index, self.dtype) + + return loader + + def constant_to_device(self, device): + return IndexingConstant(self.index, self.dtype, device) + + +def is_contiguous_strides_for_shape( + stride: Sequence[_IntLike], shape: Sequence[_IntLike] +) -> bool: + return all( + size == 1 or left == right + for left, right, size in zip( + stride, FlexibleLayout.contiguous_strides(shape), shape + ) + ) + + +def get_align_for_dtype(dtype: torch.dtype) -> int: + return config.padding_alignment_bytes // dtype.itemsize + + +@dataclasses.dataclass +class Layout(IRNode): + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: List[Expr], + stride: Optional[Sequence[Union[Expr, int]]], + offset: Expr = Integer(0), + ): + assert stride is None or len(size) == len( + stride + ), f"size={size}, stride={stride}" + self.device = device + self.dtype = dtype + assert all(isinstance(s, (Expr, int)) for s in size) + self.size = size + self._stride = stride + self.offset = offset + + @property + def stride(self): + return self._stride + + def __str__(self) -> str: + offset = "" + if self.offset != 0: + offset = f", offset={self.offset}" + return ( + f"{type(self).__name__}('{self.device.type}', {self.dtype}, " + f"size={self.size}, stride={self.stride}{offset})" + ) + + __repr__ = __str__ + + def is_contiguous(self): + return is_contiguous_strides_for_shape(self.stride, self.size) + + @staticmethod + def is_channels_last_contiguous(shape, strides): + ndim = len(shape) + if ndim not in [4, 5] or shape[1] == 1: + return False + for left, right, size in zip( + strides, make_channels_last_strides_for(shape), shape # type: ignore[arg-type] + ): + if size != 1 and left != right: + return False + return True + + def is_transposed(self): + for left, right, size in zip( + self.stride, + reversed(FlexibleLayout.contiguous_strides(list(reversed(self.size)))), + self.size, + ): + if size != 1 and left != right: + return False + return True + + def is_stride_ordered(self, order): + assert len(self.stride) == len(order) + + # ignore dimensions of size 1, they dont affect layout + non_1_indices = [ + i + for i, dim in enumerate(self.size) + if V.graph.sizevars.size_hint(dim, fallback=2) != 1 + ] + + stride = [self.stride[i] for i in non_1_indices] + order = [order[i] for i in non_1_indices] + + def sorted_indices(arr): + sorted_arr = sorted(arr) + return [sorted_arr.index(element) for element in arr] + + # since we may have removed dimensions, need to re-sort & re-index order + order = sorted_indices(order) + + # reorder the stride given order + stride_ordered = [-1] * len(order) + for i in range(len(order)): + stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i]) + # check if it is in ascending order + for i in range(len(order) - 1): + if stride_ordered[i] > stride_ordered[i + 1]: + return False + return True + + def is_channels_last_stride_ordered(self): + # create channels_last order(NCHW, NCDHW, the C is the first order). + order = [0] + list(reversed(range(1, len(self.stride) - 1))) + order = [len(order)] + order + return self.is_stride_ordered(order) + + @staticmethod + def _pad_strides(in_strides, size, dtype): + """ + The padding does not change stride order but makes sure all strides larger + than the threshold are multiple of align. + """ + align = get_align_for_dtype(dtype) + if len(in_strides) == 0: + return in_strides + + if not config.pad_channels_last and Layout.is_channels_last_contiguous( + size, in_strides + ): + return in_strides + + current_fx_node = V.get_current_node() + if hasattr(current_fx_node, "meta") and current_fx_node.meta.get( + "dislike_padding", False + ): + return in_strides + + # get_stride_order does not work with dynamic shape. Also we can not + # statically decide if a padding is needed or how much padding we should + # do for dynamic shape. + # + # Skip padding the strides for dynamic shape for now. + if not all( + isinstance(s, (int, sympy.Integer)) + for s in itertools.chain(in_strides, size) + ): + return in_strides + + stride_order = get_stride_order(in_strides) + fill_order = stride_order2fill_order(stride_order) + + new_strides = [0 for _ in range(len(in_strides))] + # since we pad when the layout is flexible, we can decide the + # smallest stride to be 1. + new_strides[fill_order[0]] = 1 + + padded = False + for rank, idx in enumerate(fill_order[1:], start=1): + prev_idx = fill_order[rank - 1] + stride = new_strides[prev_idx] * size[prev_idx] + + if stride > config.padding_stride_threshold and stride % align != 0: + stride = ceildiv(stride, align) * align + padded = True + new_strides[idx] = stride + + if not padded: + # Consider a tensor with shape [256, 1, 5, 5] + # Avoid strides like [25, 5, 5, 1] being padded to equivalent strides + # [25, 25, 5, 1]. + return in_strides + + metrics.num_comprehensive_padding += 1 + return new_strides + + def pad_strides(self): + assert isinstance(self, FlexibleLayout) + assert self._stride is not None + self._stride = self._pad_strides(self._stride, self.size, self.dtype) + + def should_pad_strides(self): + return config.comprehensive_padding and isinstance(self, FlexibleLayout) + + def as_fixed(self): + if isinstance(self, FixedLayout): + return self + + if self.should_pad_strides(): + self.pad_strides() + return FixedLayout( + self.device, + self.dtype, + self.size, + self.stride, + self.offset, + ) + + def make_indexer(self): + assert ( + FlexibleLayout.allow_indexing + ), f"convert {type(self).__name__} to FixedLayout first" + return self.as_fixed().make_indexer() + + def __eq__(self, other) -> bool: + return ( + self.device == other.device + and self.dtype == other.dtype + and self.size == other.size + and self.stride == other.stride + and self.offset == other.offset + ) + + def storage_size(self) -> sympy.Expr: + return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value] + + +class FixedLayout(Layout): + """A Tensor layout we cannot change""" + + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: Union[List[Expr], List[int]], + stride: Optional[Sequence[Union[Expr, int]]] = None, + offset: Union[Expr, int] = Integer(0), + ): + if stride is None: + stride = FlexibleLayout.contiguous_strides(size) + super().__init__( + device, + dtype, + size, # type: ignore[arg-type] + stride, + offset, # type: ignore[arg-type] + ) + + def make_indexer(self): + """A closure containing math to read a given element""" + + def indexer(index): + assert len(index) == len(self.stride) + assert len(index) == len(self.size) + result = self.offset + for idx, stride, sz in zip(index, self.stride, self.size): + if sz != 1: + result = result + idx * stride + return result + + return indexer + + +class FlexibleLayout(Layout): + """A Tensor layout we are allowed to change""" + + allow_indexing = False + + # WARNING! This doesn't handle zero size tensors correctly + @staticmethod + def contiguous_strides(sizes): + if len(sizes) == 0: + return [] + reversed_strides = [sympy.Integer(1)] + for size in reversed(sizes[1:]): + reversed_strides.append(size * reversed_strides[-1]) + return list(reversed(reversed_strides)) + + @staticmethod + def fill_ordered(sizes, order): + """ + Create a stride based on the order the dimensions should be filled in. + + In this format, channels last would be: + [1, 3, 2, 0] + """ + assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order) + next_stride = sympy.Integer(1) + strides = [None] * len(order) + + for i in order: + strides[i] = next_stride + next_stride = next_stride * sizes[i] + return strides + + @staticmethod + def stride_ordered(sizes, order): + """ + Create a stride based on the sorted order of a permuted range. + + In this format, channels last would be: + [3, 0, 2, 1] + """ + assert OrderedSet(range(len(sizes))) == OrderedSet(order) + fill_order = stride_order2fill_order(order) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + @staticmethod + def stride_ordered_for_memory_format(sizes, memory_format): + """ + Create a stride based on a memory format. + + Memory format is translasted into a stride order, + so channels_last is the same as: + FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1]) + + This interface does not support memory_format `torch.preserve_format` + which should be used to deduce a format from another source + """ + if memory_format == torch.channels_last: + return FlexibleLayout.stride_ordered(sizes, NHWC_STRIDE_ORDER) + elif memory_format == torch.channels_last_3d: + return FlexibleLayout.stride_ordered(sizes, NHWDC_STRIDE_ORDER) + elif memory_format == torch.contiguous_format: + return FlexibleLayout.contiguous_strides(sizes) + else: + log.debug( + "stride_ordered_for_memory_format, unsuppored memory_format: %s", + memory_format, + ) + raise NotImplementedError + + @staticmethod + def same_ordered(sizes, stride): + """ + Create a stride that has the same stride order as given stride + + For example, if given stride is [1000, 1, 100, 10], + the fill order should be [1, 3, 2, 0] + """ + assert len(sizes) == len(stride) + stride = [V.graph.sizevars.size_hint(x) for x in stride] + fill_order = sorted(range(len(stride)), key=stride.__getitem__) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + def as_stride_order(self, order, allow_padding=False): + new_stride = self.stride_ordered(self.size, order) + if self.should_pad_strides() and allow_padding: + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def as_exact_strides(self, exact_strides, allow_padding=False): + new_stride = exact_strides + if self.should_pad_strides() and allow_padding: + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def as_fill_order(self, order): + new_stride = self.fill_ordered(self.size, order) + if self.should_pad_strides(): + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def as_same_order(self, stride): + new_stride = self.same_ordered(self.size, stride) + if self.should_pad_strides(): + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def __init__(self, device, dtype, size, stride_order=None): + if stride_order: + strides = FlexibleLayout.fill_ordered(size, stride_order) + else: + strides = FlexibleLayout.contiguous_strides(size) + super().__init__(device, dtype, size, strides) + + +class NonOwningLayout(Layout): + """Is a view into the storage of another tensor""" + + def __init__(self, view: Union[BaseView, TensorBox]): + layout = view.get_layout() + super().__init__( + layout.device, + layout.dtype, + layout.size, + layout.stride, + ) + self.view = view + + def make_indexer(self): + return self.as_fixed().make_indexer() + + def maybe_guard_aligned(self): + offset = self.view.get_layout().offset + if offset == 0: + return True + from .utils import ALIGNMENT + + return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type] + + +class NoneLayout(IRNode): + # This is janky, I figured out what fields to populate by just running + # the model I was interested in and adding properties/methods as needed. + # This doesn't inherit from Layout because Layout assumes you have stuff + # like sizes, but I don't really have anything here. + # + # If you have an ir.Node with NoneLayout, you probably need to setup + # dependencies manually in scheduler + + def __init__(self, device): + self.device = device + self.size = [0] + self.stride = [0] + + def storage_size(self): + return 0 + + def as_fixed(self): + return self + + +class MutationLayoutSHOULDREMOVE(Layout): + def __init__(self, target: IRNode): + super().__init__( + target.get_device(), + target.get_dtype(), + target.get_size(), + None, + ) + self.target = target + name = self.get_buffer().get_name() + V.graph.mark_buffer_mutated(name) + + @Layout.stride.getter # type: ignore[attr-defined] + def stride(self): + return self.real_layout().stride + + def storage_size(self) -> sympy.Expr: + return self.real_layout().storage_size() + + def get_buffer(self) -> Buffer: + def unwrap_views(target): + if isinstance(target, MutationLayoutSHOULDREMOVE): + return unwrap_views(target.target) + if isinstance(target, BaseView): + return unwrap_views(target.unwrap_view()) + if isinstance(target, MutableBox): + return unwrap_views(target.data) + return target + + result = unwrap_views(self.target) + assert isinstance( + result, Buffer + ), "MutationLayoutSHOULDREMOVE must refer to a buffer" + return result + + def real_layout(self): + return self.get_buffer().layout + + @classmethod + def realize_into(cls, src, dst, unsafe_alias=False): + dst.realize() + # NOTE: We must realize users of `dst` before we realize `src`, since + # realization order determines scheduling order. Otherwise, src's + # mutation would be scheduled before the existing users of dst! + V.graph.mark_buffer_mutated(dst.get_name()) + + if isinstance(src, TensorBox): + src = src.data + + # We copy the contents of src into dst. In most cases this should + # be fused into a single kernel by the scheduler. + # NOTE: We cannot change src's layout to mutate dst directly as this + # would alias src to dst, which is not correct as further mutations to + # dst would effect users of src. However if there are no more users of + # dst, we can alias src to dst. + src.realize_hint() + + if not unsafe_alias: + src = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ).data + + src.realize() + assert isinstance(src.data.layout, FlexibleLayout) + src.data.layout = MutationLayoutSHOULDREMOVE(dst) + return src.data + + def as_fixed(self): + return self + + def make_indexer(self): + return self.target.make_indexer() + + +@dataclasses.dataclass +class Buffer(IRNode): + # Name is sometimes None; e.g., ForceInPlace, where there isn't + # a meaningful name + name: Optional[str] + layout: Layout + + # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly, + # MultiOutput does NOT define this! + + def __post_init__(self): + super().__post_init__() + self.origin_node = None + + def make_indexer(self): + return self.layout.make_indexer() + + def get_name(self) -> str: + assert self.name, self + return self.name + + def get_device(self): + return self.layout.device + + def get_origin_node(self): + return self.origin_node + + def get_defining_op(self) -> Optional[Operation]: + return None + + @property + def dtype(self): + return getattr(self.layout, "dtype", None) + + def get_size(self): + return list(self.layout.size) + + def get_stride(self): + return list(self.layout.stride) + + def get_offset(self): + return self.layout.offset + + def get_layout(self): + return self.layout + + def get_storage_numel(self): + return self.get_numel() + + def is_extern(self): + return False + + def freeze_layout(self): + if not isinstance(self.layout, (MultiOutputLayout, NonOwningLayout)): + self.layout = self.layout.as_fixed() + + def freeze_layout_with_stride_order(self, order, allow_padding=False): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding) + + def freeze_layout_with_fill_order(self, order): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_fill_order(order) + + def freeze_layout_with_same_order(self, stride): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_same_order(stride) + + def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_exact_strides( + exact_strides, allow_padding=allow_padding + ) + + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] + + def make_loader(self): + # Loading from a zero-element buffer is a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.get_dtype()) + + def loader(index): + indexer = self.layout.make_indexer() + return ops.load(self.name, indexer(index)) + + return loader + + def codegen_reference(self, writer=None): + return self.get_name() + + def decide_layout(self): + pass + + def get_inputs_that_alias_output(self): + if isinstance(self.layout, NonOwningLayout): + return [self.layout.view.get_name()] + return () + + def get_mutation_names(self): + if isinstance(self.layout, MutationLayoutSHOULDREMOVE): + return [self.layout.target.get_name()] + return () + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet([self.get_name()]) + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def realize(self): + pass + + def should_allocate(self): + # Returns False by default. + return False + + +@dataclasses.dataclass +class OperationBuffer(Buffer, Operation): + # An operation that produces a single output buffer + def get_outputs(self) -> List[Buffer]: + return [self] + + def get_defining_op(self) -> Operation: + return self + + def __post_init__(self): + Buffer.__post_init__(self) + Operation.__post_init__(self) + + +class InputBuffer(Buffer): + def num_reads(self): + return 1 + + +class ConstantBuffer(InputBuffer): + override_device: Optional[torch.device] = None + + def make_loader(self): + def loader(index): + indexer = self.layout.make_indexer() + return ops.load( + V.graph.constant_name(self.get_name(), self.override_device), + indexer(index), + ) + + return loader + + def constant_to_device(self, device): + return ConstantBuffer( + V.graph.constant_name(self.get_name(), device), self.layout + ) + + +class NoneAsConstantBuffer(IRNode): + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def codegen_reference(self, writer=None): + return V.graph.wrapper_code.none_str + + +class ShapeAsConstantBuffer(IRNode): + def __init__(self, shape): + super().__init__() + self._shape = shape + + @property + def shape(self): + return self._shape + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return free_unbacked_symbols(self.shape) + + def codegen_reference(self, writer=None): + return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.shape)) + + +@dataclasses.dataclass +class ComputedBuffer(OperationBuffer): + data: Loops + + def get_computed_buffer_name(self): + """ + Returns self.name if it exists, otherwise returns the name of the data node if that exists. + If neither exist, returns None. + """ + if self.name is not None: + return self.name + if hasattr(self.data, "name"): + return self.data.name + return None + + def num_reads(self): + return self.data.num_reads() + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_read_writes(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.data.get_reduction_type(): + return extract_read_writes( + self.get_store_function(), + self.data.get_pointwise_size(), + self.data.get_reduction_size(), + ) + else: + return extract_read_writes( + self.get_store_function(), + self.data.get_size(), + ) + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + # Ordinarily, we'd like to just peek at the arguments list, + # but ComputedBuffers have no argument list. + # + # Morally, this logic needs to be synchronized with the + # KernelArgs.size calls, which are responsible for making symbols make + # there way as kernel arguments (and it is precisely passing in one of + # those symbols that establishes a dependency). However, we haven't + # started codegen yet so we can't directly reuse that logic. + # + # For now, I'm just yoloing with the size of the buffer. Not sure if + # it is enough. + # + # One thing you might wonder is if this is enough for a ComputedBuffer + # denoting a reduction over i0. Empirically, it is enough, but for an + # unusual reason: we only need accurate dependencies for item() call, + # but it's impossible to end up with a reduction over i0 from an + # item() call without a regular non-reduction buffer first. + return ( + free_unbacked_symbols(self.get_size()) + | free_unbacked_symbols(self.get_stride()) + | free_unbacked_symbols(self.get_offset()) + | self.data.get_unbacked_symbol_uses() + ) + + def make_loader(self): + # Inline constants and index_expressions + if ( + hasattr(self.data, "make_loader") + and self.name not in V.graph.mutated_buffers + and self.num_reads() == 0 + ): + # can be inlined + return self.data.make_loader() + return super().make_loader() + + def get_store_function(self): + indexer = self.layout.as_fixed().make_indexer() + if isinstance(self.data, (Reduction, Scan, Sort)): + return partial(self.data.store_reduction, self.name, indexer) + else: + assert isinstance(self.data, Pointwise) + return partial(self.data.store_output, self.name, indexer) + + def get_fill_order(self): + """ + If our layout is still flexible, try to determine the stride order based on stride orders of reads. + + TODO(jansel): A better algorithm here would look at downstream consumers of this + value and try to do global graph-level layout optimization. + This is also something just begging to be autotuned. + """ + if isinstance(self.layout, FlexibleLayout): + (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size() + ) + reads = self.get_read_writes().reads + # only consider reads to buffer of same size + # ignore StarDeps because they don't contribute stride information + assert all( + isinstance(r, (dependencies.StarDep, dependencies.MemoryDep)) + for r in reads + ) + reads = [ + sympy_subs( + r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0} + ) + for r in reads + if isinstance(r, dependencies.MemoryDep) + ] + + if reads: + if isinstance(self.data, (Scan, Sort)): + indices = self.data.reindex(index_vars, reduction_vars) + else: + indices = index_vars + stride_lengths = [ + V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type] + ] + from .scheduler import pick_loop_order + + return pick_loop_order(stride_lengths, self.get_size()) + + return None + + def decide_layout(self): + if isinstance(self.layout, FlexibleLayout): + order = self.get_fill_order() + if order: + self.freeze_layout_with_fill_order(order) + else: + self.freeze_layout() + + @cache_on_self + def get_default_sizes_body(self): + args, var_ranges = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" + ) + with patch.object(ConstantBuffer, "override_device", self.get_device()): + body = LoopBody( + self.get_store_function(), + (args if self.get_reduction_type() else args[:1]), + var_ranges, + *args, + ) + index_vars = [] + reduce_vars: List[Any] = [] + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + assert not reduce_vars + index_vars.append(v) + index_size.append(s) + else: + assert v in args[1] + reduce_vars.append(v) + reduce_size.append(s) + return (index_size, reduce_size), body, (index_vars, reduce_vars) + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ): + """ + This is a main place where we do loop transformations in a + backend-agnostic way. + + Here we: + 1) Remove any 1 dimensions + 2) Fuse contiguous dimensions together + 3) Reorder dimensions based on stride orders + + Optional argument extra_indexing_constraints can be used to append additional + indexing expressions to existing ones derived from buffer's body. This can be useful + to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) + on CPU by preventing indexing simplifications and obtaining index/reduce ranges for + the scheduler node compatible with other nodes. + Optional argument recompute_sizes_body_func can be used to recompute sizes and body + on the default body. This can be useful to append additional loop transformations. + """ + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = self.get_default_sizes_body() + + if recompute_sizes_body_func: + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = recompute_sizes_body_func( + (index_size, reduce_size), body, (index_vars, reduce_vars) + ) + + index_formulas = [*body.indexing_exprs.values()] + if extra_indexing_constraints is not None: + assert ( + isinstance(extra_indexing_constraints, tuple) + and len(extra_indexing_constraints) == 2 + ) + extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints + assert isinstance(extra_indexing_ranges, dict) + assert isinstance(extra_indexing_expr, list) + assert all(isinstance(f, Expr) for f in extra_indexing_expr) + + expected_var_ranges = body.var_ranges + assert expected_var_ranges == extra_indexing_ranges, ( + expected_var_ranges, + extra_indexing_ranges, + ) + # remove already existing expressions + extra_indexing_expr = [ + e for e in extra_indexing_expr if e not in index_formulas + ] + index_formulas += extra_indexing_expr + + memory_addrs = [*body.get_write_exprs()] + if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER): + memory_addrs.extend(body.get_read_exprs()) + + def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): + sizes, reindex0, reindex1 = self._apply_loop_reordering( + x_vars, support_vars, sizes, memory_addrs + ) + # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] + x_vars = reindex0(x_vars) + + if simplify_loops: + sizes, reindex2, prune = V.graph.sizevars._simplify_loops( + x_vars, + sizes, + index_prevent_reordering(index_formulas, x_vars, sizes), + ) + reindex = fuse_reindexing(reindex1, reindex2) + else: + reindex = reindex1 + return sizes, reindex, reindex1 + + support_vars = index_vars + reduce_vars + should_merge_loops = ( + self.get_device().type != "cuda" or not config.loop_ordering_after_fusion + ) + iter_ranges, iter_reindex, _ = simplify_and_reorder( + index_vars, + support_vars, + index_size, + should_merge_loops, + ) + + # Like iteration dimensions, we may also want to delay merging reduction dimensions. + # E.g., if we reduce a tensor [M, N, K] for its M and N dimensions followed by a pointwise + # kernel, merging M and N dimension too early makes it hard to decide what loop order + # we should pick for the piontwise kernel so that it is fusible with the reduction. + reduce_ranges, reduce_reindex, _ = simplify_and_reorder( + reduce_vars, support_vars, reduce_size, should_merge_loops + ) + + # retrace the loop body with simplification and reordering applied + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + iter_ranges, + reduce_ranges, + prefix="z", + ) + body = LoopBody( + body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + return (iter_ranges, reduce_ranges), body + + @staticmethod + def _apply_loop_reordering( + index_vars, + support_vars, + sizes, + memory_addrs, + priority_idx=None, + ): + """ + Shuffle the order of loops around to hopefully improve performance. + """ + from .scheduler import pick_loop_order + + if priority_idx is None: + priority_idx = [] + + try: + strides = [ + V.graph.sizevars.stride_hints(expr, index_vars, support_vars) + for expr in memory_addrs + ] + assert len(strides) == len(memory_addrs) and len(strides[0]) == len( + index_vars + ) + order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) + except Exception: + if config.debug: + log.warning( + "Did not simplify complex index:\n%s\n%s", + dict(zip(index_vars, sizes)), + memory_addrs, + ) + order = list(range(len(sizes))) + sizes = [sizes[i] for i in order] + return sizes, same_reorder(order), inverse_reorder(order) + + def get_reduction_size(self): + return self.data.get_reduction_size() + + def get_reduction_type(self): + return self.data.get_reduction_type() + + def is_no_op(self): + return self.data.is_zero_elements() + + def should_allocate(self): + return True + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + return self.data.constant_to_device(device) + + +class TemplateBuffer(OperationBuffer): + """ + Represents a Triton (in the future other type) of template operator + that we can fuse an epilogue onto. + """ + + def __init__(self, layout, inputs, make_kernel_render): + super().__init__(name=None, layout=layout) + self.inputs = InputsKernel.unwrap_storage(inputs) + self.make_kernel_render = make_kernel_render + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def get_read_writes(self): + return self.extract_read_writes(normalize=True) + + def extract_read_writes(self, normalize): + name = self.get_name() + indexer = self.layout.make_indexer() + + def dummy(index, rindex): + assert len(rindex) == 0 + return ops.store(name, indexer(index), "fake") + + deps = dependencies.extract_read_writes( + dummy, self.get_size(), (), normalize=normalize + ) + deps.reads = OrderedSet(dependencies.StarDep(x.get_name()) for x in self.inputs) + return deps + + def get_reduction_size(self): + return 1 + + def get_reduction_type(self): + return None + + def is_no_op(self): + return False + + def should_allocate(self): + return True + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ): + return ( + ( + self.get_size(), + (), + ), + None, + ) + + +class TritonTemplateBuffer(TemplateBuffer): + def __init__( + self, + layout, + inputs, + make_kernel_render, + debug_extra=None, + mutated_inputs: Optional[Iterable[IRNode]] = None, + ): + """ + NOTE:[TritonTemplates with multiple outputs] + We want the ability for TritonTemplates to output multiple tensors. Triton + kernels have no notion of outputs and this is done by creating tensors that + are then mutated by the kernel. Currenlty our STORE_OUTPUT codegen doesn't + support creating multinode outputs for triton templates. + We work around this by creating an extra input buffer during the lowering + and we mark them as mutated inputs. + """ + super().__init__(layout, inputs, make_kernel_render) + self.debug_extra = debug_extra + self.mutated_inputs = mutated_inputs + self.outputs: List[Buffer] = [self] + if mutated_inputs is not None: + # Ensure that the mutated inputs are only allowed for certain nodes + allowed_set = ( + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + ) + current_node = V.graph.current_node.target + assert ( + current_node in allowed_set + ), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" + device = self.inputs[0].get_device() + self.outputs += [ + MutationOutput(NoneLayout(device), buf, self) for buf in mutated_inputs + ] + + def get_outputs(self) -> List[Buffer]: + return self.outputs + + def __str__(self) -> str: + out = f"TritonTemplateBuffer(layout={self.layout}, {self.debug_extra})" + return out + + +PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]] + + +class ChoiceCaller: + """ + Represents a possible choice used in autotune_process.py. + During autotuning, self.benchmark() is first called to get benchmark result, + and if this choice is selected, self.output_node() is called to get the output_node. + + Children classes: TritonTemplateCaller, CUDATemplateCaller. + """ + + def __init__(self, name, input_nodes, layout): + super().__init__() + self.name = name + self.layout = layout + self.input_nodes = input_nodes + + def benchmark(self, *args, out) -> float: + algo = self.to_callable() + return benchmarker.benchmark(algo, args, {"out": out}) + + def call_name(self) -> str: + raise NotImplementedError + + def to_callable(self): + raise NotImplementedError + + def hash_key(self) -> str: + raise NotImplementedError + + def output_node(self) -> TensorBox: + raise NotImplementedError + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return {} + + def autoheuristic_id(self) -> str: + return "unsupported_choice" + + +class TritonTemplateCallerBase(ChoiceCaller): + def get_make_kernel_render(self) -> Any: + raise NotImplementedError + + +class MultiTemplateBuffer(TritonTemplateBuffer): + """ + Represents a Buffer with multiple backing implementation choices. + + Choices can be TritonTemplates or ExternKernels. During scheduling if there is a potential + epilogue we will benchmark each of the choices with the epilogue to determine an implementation. + Otherwise, the fastest base choice will be chosen. + """ + + def __init__( + self, + layout: Layout, + inputs: List[IRNode], + choice_timings: Callable[[], Dict[ChoiceCaller, float]], + ): + super().__init__(layout=layout, inputs=inputs, make_kernel_render=None) + self._choice_timings_fn = choice_timings + self._choice_timings: Optional[Dict[ChoiceCaller, float]] = None + self.original_inputs = inputs + + @property + def choice_timings(self) -> Dict[ChoiceCaller, float]: + if self._choice_timings is None: + self._choice_timings = self._choice_timings_fn() + return self._choice_timings + + @contextlib.contextmanager + def swap_as_triton_caller(self, caller: TritonTemplateCallerBase): + assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) + assert self.layout == caller.layout + + render = self.make_kernel_render + self.make_kernel_render = caller.get_make_kernel_render() + try: + yield + finally: + self.make_kernel_render = render + + def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase): + assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) + assert self.layout.size == caller.layout.size + assert self.layout.stride == caller.layout.stride + self.make_kernel_render = caller.get_make_kernel_render() + + def get_min_choice(self) -> Tuple[ChoiceCaller, float]: + min_choice = min(self.choice_timings, key=self.choice_timings.get) # type: ignore[arg-type] + return (min_choice, self.choice_timings[min_choice]) + + +class CUDATemplateBuffer(TemplateBuffer): + def __init__( + self, + layout, + inputs, + make_kernel_render, + workspace_size: int, + template: CUDATemplate, # type: ignore[name-defined] # noqa: F821 + ): + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self): + return self.workspace_size if self.workspace_size is not None else 0 + + +class CppTemplateBuffer(TemplateBuffer): + def __init__(self, layout, inputs, make_kernel_render, template, choice): + super().__init__(layout, inputs, make_kernel_render) + self.template = template + self.choice = choice + + +@dataclasses.dataclass +class InputsKernel(OperationBuffer): + inputs: List[Buffer] + + def get_read_writes(self): + reads: OrderedSet[dependencies.Dep] = OrderedSet() + StarDep = dependencies.StarDep + for input in self.inputs: + if isinstance(input, list): + reads.update(StarDep(x.get_name()) for x in input) + else: + reads.add(StarDep(input.get_name())) + + writes: OrderedSet[dependencies.Dep] = OrderedSet( + StarDep(buf.get_name()) for buf in self.get_outputs() + ) + + return dependencies.ReadWrites( + reads=reads, + writes=writes, + index_exprs=OrderedSet(), + ) + + @classmethod + def unwrap_storage_for_input(cls, x): + if isinstance(x, TensorBox): + x = x.data + if isinstance(x, StorageBox): + x = x.data + if isinstance(x, BaseView) and not isinstance(x, ReinterpretView): + x = ExternKernel.realize_input(x) + if isinstance(x, TensorBox): + # when converting to ReinterpretView fails in the + # realize_input call above, the result will be wrapped + # into TensorBox / StorageBox pair as a result of the + # cls.copy_input call; so we should unwrap recursively + return cls.unwrap_storage_for_input(x) + if isinstance(x, TorchBindObject): + return x + assert isinstance(x, (Buffer, ReinterpretView)), x + return x + + @staticmethod + def unwrap_storage(inputs): + inputs_new = [] + for x in inputs: + if isinstance(x, list): + x = [InputsKernel.unwrap_storage_for_input(i) for i in x] + else: + x = InputsKernel.unwrap_storage_for_input(x) + inputs_new.append(x) + return inputs_new + + def is_extern(self): + return True + + def num_reads(self): + return 1 + + +class NopKernel(InputsKernel): + def is_no_op(self): + return True + + +class ConcatKernel(NopKernel): + """ + There isn't actually a real kernel for concat, we just change the + storage for the upstream data. + """ + + @classmethod + def create(cls, inputs, dim): + device = inputs[0].get_device() + dtype = inputs[0].get_dtype() + new_size = list(inputs[0].get_size()) + offsets_start = [0] + offsets_end = [new_size[dim]] + assert 0 <= dim < len(new_size) + for i in range(1, len(inputs)): + input_size = inputs[i].get_size() + offsets_start.append(new_size[dim]) + assert len(input_size) == len(new_size) + assert inputs[i].get_dtype() == dtype + assert inputs[i].get_device() == device + for j in range(len(new_size)): + if j == dim: + new_size[j] = new_size[j] + input_size[j] + else: + new_size[j] = V.graph.sizevars.guard_equals( + new_size[j], input_size[j] + ) + offsets_end.append(new_size[dim]) + + output_stride = FlexibleLayout.contiguous_strides(new_size) + # If any of the inputs is in CL format, use CL format for the output + for i in range(len(inputs)): + x = inputs[i] + if is_storage_and_layout(x): + layout = x.get_layout() + if isinstance( + layout, FixedLayout + ) and Layout.is_channels_last_contiguous(layout.size, layout.stride): + # use CL stride for the output + output_stride = make_channels_last_strides_for(new_size) + break + any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs) + fx_node_args = V.graph.current_node.args[0] + assert isinstance(fx_node_args, list) + # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output + if any_input_is_storage_and_layout is False and any( + "val" in arg.meta + and ( + arg.meta["val"].is_contiguous(memory_format=torch.channels_last) + or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d) + ) + for arg in fx_node_args + ): + output_stride = make_channels_last_strides_for(new_size) + + concat_kernel = ConcatKernel( + name=None, + layout=FixedLayout( + device=device, + dtype=dtype, + size=new_size, + stride=output_stride, + ), + inputs=[], + ) + kernel = StorageBox(concat_kernel) + op_names = [] + for i in range(len(inputs)): + input_buffer = cls.realize_into( + inputs[i], + SliceView.create( + kernel, dim, offsets_start[i], offsets_end[i], clamp=False + ), + ) + concat_kernel.inputs.append(input_buffer) + + if isinstance(inputs[i].data, BaseView): + input_unwrapped = inputs[i].data.unwrap_view() + else: + input_unwrapped = inputs[i].data + + if ( + input_unwrapped.is_input_buffer() + and is_gpu(inputs[i].get_device().type) + and not is_dynamic(input_buffer) + ): + op_names.append(input_buffer.get_operation_name()) + + if len(op_names) > 1 and V.graph.has_feature(device, BackendFeature.FOREACH): + V.graph.register_operation_list(op_names) + + concat_kernel.name = V.graph.register_buffer(concat_kernel) + concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs) + V.graph.register_operation(concat_kernel) + + return kernel + + @classmethod + def can_realize_into_without_copy(cls, src): + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.can_realize_into_without_copy(src.data) + + return isinstance(src.data.layout, FlexibleLayout) and not isinstance( + src.data, ExternKernelAlloc + ) + + @classmethod + def realize_into(cls, src, dst): + # Attempt to turn this into a ReinterpretView rather than assert. + # This has concessions around layout, as as_storage_and_layout + # can cause us to go from flexible to fixed layout. + if not isinstance(dst, ReinterpretView): + if is_storage_and_layout(dst): + storage, layout = as_storage_and_layout(dst) + dst = ReinterpretView(storage, layout) + assert isinstance(dst, ReinterpretView), dst + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.realize_into(src.data, dst) + if isinstance(src, StorageBox): + src.realize() + # ExternKernelAlloc has specific requirements for output layout, should create a copy + assert hasattr(src.data, "layout") + if cls.can_realize_into_without_copy(src): + src.data.layout = NonOwningLayout(dst) + return src.data + # introduce a copy + pw = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ) + return cls.realize_into(pw, dst) + + def should_allocate(self): + return True + + +@dataclasses.dataclass +class ExternKernel(InputsKernel): + constant_args: Tuple[Any, ...] = () + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + output_view: Optional[ReinterpretView] = None + python_kernel_name: Optional[str] = None + cpp_kernel_name: Optional[str] = None + # FIXME: in some cases we sill need to explicitly pass in ordered_kwargs_for_cpp_kernel + # We shouldn't need to do this since the information can be retrieved from op_overload._schema. + ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field( + default_factory=list + ) + op_overload: Optional[ + Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] + ] = None + arg_properties: Optional[List[Dict[str, Any]]] = None + kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None + unbacked_bindings: Dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field( + default_factory=dict + ) + mutation_outputs: List[MutationOutput] = dataclasses.field(default_factory=list) + + def __init__( + self, + name, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + name, + layout, + inputs, + ) + self.constant_args = constant_args + self.kwargs = kwargs if kwargs else {} + self.output_view = output_view + self.op_overload = op_overload + self.set_cpp_kernel_name(cpp_kernel_name) + self.set_python_kernel_name(python_kernel_name) + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + self.collect_arg_kwarg_properties() + self.unbacked_bindings = {} + self.mutation_outputs = [] + self.fx_node = V.graph.current_node + + def get_outputs(self) -> List[Buffer]: + return [self, *self.mutation_outputs] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def collect_arg_kwarg_properties(self): + # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional + # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen + self.arg_properties = ( + [ + { + "name": x.name, + "type": x.real_type, + "default_value": x.default_value, + } + for x in self.op_overload._schema.arguments + if not x.kwarg_only + ] + if isinstance(self.op_overload, torch._ops.OpOverload) + else [{} for i in range(len(self.inputs))] + ) + self.allarg_properties = ( + { + x.name: {"type": x.real_type, "default_value": x.default_value} + for x in self.op_overload._schema.arguments + } + if isinstance(self.op_overload, torch._ops.OpOverload) + else {} + ) + # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes + # ordered_kwargs_for_cpp_kernel is explicilty passed in. + if ( + isinstance(self.op_overload, torch._ops.OpOverload) + and not self.ordered_kwargs_for_cpp_kernel + ): + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in self.op_overload._schema.arguments if x.kwarg_only + ] + + def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): + # Previously, we want to maintain forward-compatibility by skipping + # default args in the serialized artifacts in fbcode. However, + # some of our shim interfaces require default values being OrderedSet. + # Discussed with Sherlock offline and we decided to allow serializing + # default args into the C++ wrapper code for now. We will refine this + # part if we see real FC requirement. More details related to FC + # can be found at: + # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing + assert isinstance(args, (list, tuple)) + if isinstance(args, tuple): + args = list(args) + assert self.arg_properties, "ExternKernel.arg_properties should not be empty" + + n_args = len(args) + n_pos_args = len(self.arg_properties) + # For cpp wrapper, if some positional args are not provided, we need to check + # if they're in the kwargs or use their default value + if n_args < n_pos_args: + log.debug( + "%s has %d unprovided positional arguments. " + "Will check if they are in the keyword arguments or will use default values.", + self.op_overload, + n_pos_args - n_args, + ) + for i in range(n_args, n_pos_args): + arg_name = self.arg_properties[i]["name"] + args.append( + kwargs[arg_name] + if arg_name in kwargs + else self.arg_properties[i]["default_value"] + ) + return args + + def decide_layout(self): + if isinstance(self.layout, FlexibleLayout): + self.apply_constraint() + self.freeze_layout() + + def codegen_comment(self, wrapper): + origin_str, detailed_origin_str = get_kernel_metadata(self, wrapper) + if origin_str: + wrapper.writeline(origin_str) + + def codegen(self, wrapper): + raise NotImplementedError + + def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None): + self.cpp_kernel_name = cpp_kernel_name + self.cpp_kernel_overload_name = None + self.cpp_kernel_key = None + self.cpp_op_schema = None + if not V.graph.cpp_wrapper or not isinstance( + self.op_overload, torch._ops.OpOverload + ): + return + + kernel = self.op_overload + if self.cpp_kernel_name is None: + # Try to construct cpp_kernel_name from op_overload + if kernel.namespace == "aten": + # Calling with the default kernel name can lead to ambiguous behavior like the following example. + # repeat_interleave(const at::Tensor & repeats, c10::optional output_size=std::nullopt) + # repeat_interleave(const at::Tensor & self, int64_t repeats, + # c10::optional dim=std::nullopt, c10::optional output_size=std::nullopt) + opname = ( + kernel.__name__.split(".")[0] + if kernel._overloadname == "default" + else kernel.__name__.replace(".", "_") + ) + self.cpp_kernel_name = f"at::_ops::{opname}::call" + else: + self.cpp_kernel_name = kernel._schema.name + + # Set up info for runtime schema lookup + # TODO: The logics here may be further simplified. + from .codegen.wrapper import get_cpp_op_schema + + self.cpp_kernel_overload_name = kernel._schema.overload_name + self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] + try: + self.cpp_op_schema = get_cpp_op_schema(kernel) + except Exception: + self.cpp_op_schema = "" + + def set_python_kernel_name(self, python_kernel_name: Optional[str]): + self.python_kernel_name = python_kernel_name + if python_kernel_name is not None: + return + + kernel = self.op_overload + if kernel is None: + pass + elif isinstance(kernel, torch._ops.HigherOrderOperator): + self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}" + else: + self.python_kernel_name = ( + f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" + ) + + def get_kernel_name(self): + return ( + ( + V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name) # type: ignore[attr-defined] + if config.abi_compatible + else self.cpp_kernel_name + ) + if V.graph.cpp_wrapper + else self.python_kernel_name + ) + + @staticmethod + def copy_input(x): + pw = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=x.get_size(), + origin_node=x.get_origin_node(), + traceback=x.get_traceback(), + ) + pw.realize() + return pw + + @classmethod + def process_kernel( + cls, kernel, *args, **kwargs + ) -> Tuple[ + Any, + List[Any], + List[Any], + Callable[[Any, Any], Any], + Optional[Dict[sympy.Symbol, pytree.KeyPath]], + ]: + binded_args = {"args": args, "kwargs": kwargs} + + args_flat, args_spec = pytree.tree_flatten(binded_args) + + is_arg_tensor = [] + tensor_args = [] + non_tensor_args: List[Any] = [] + for arg in args_flat: + is_arg_tensor.append(isinstance(arg, IRNode)) + if is_arg_tensor[-1]: + tensor_args.append(arg) + else: + if isinstance(arg, sympy.Expr): + arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) + non_tensor_args.append(arg) + + def unflatten_args(new_tensor_args, new_non_tensor_args): + result = [] + it_tensors = iter(new_tensor_args) + it_non_tensors = iter(new_non_tensor_args) + for is_tensor in is_arg_tensor: + if is_tensor: + result.append(next(it_tensors)) + else: + result.append(next(it_non_tensors)) + r = pytree.tree_unflatten(result, args_spec) + return r.get("args", []), r.get("kwargs", {}) + + tensor_args = [cls.realize_input(x) for x in tensor_args] + + # freeze layout otherwise our output stride calculation might + # become incorrect + for x in tensor_args: + if is_storage_and_layout(x): + as_storage_and_layout(x, freeze=True) + + # Rerun fake tensor propagation, because Inductor may have changed the + # strides of inputs and we need to determine accurately what the + # output stride will be. + example_args: List[Union[torch.Tensor, torch._C.ScriptObject]] = [] + + # We need to retain the constant values of fake tensors that we originally + # propagated the graph with, because for some operators running without a + # constant would trigger an error / DataDependentException + for x in tensor_args: + # if x is a view of a constant, we need to realize the view + # (we can't pass the constant into the kernel directly) + if not isinstance(x, BaseView) and x.get_name() in V.graph.constants: + example_args.append(V.graph.constants[x.get_name()]) + elif ( + not isinstance(x, BaseView) + and x.get_name() in V.graph.torchbind_constants + ): + example_args.append(V.graph.torchbind_constants[x.get_name()]) + else: + example_args.append(ir_node_to_tensor(x, guard_shape=True)) + + new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) + example_output = kernel(*new_args, **new_kwargs) + + unbacked_bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]] = None + if shape_env := V.fake_mode.shape_env: + rebind_unbacked(shape_env, V.current_node, example_output) + unbacked_bindings = compute_unbacked_bindings( + shape_env, example_output, V.current_node.meta.get("val") + ) + + example_out_li = ( + [example_output] + if not isinstance(example_output, (list, tuple)) + else example_output + ) + for t in example_out_li: + if isinstance(t, torch.Tensor) and t.is_sparse: + msg = "sparsity not handled. Please file issue for sparse inference weights." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + return ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) + + @classmethod + def convert_to_reinterpret_view(cls, x): + """ + In order to pass this to an extern kernel we need a + ReinterpretView not a View. This allows us to avoid some + unneeded copies. + """ + assert isinstance(x, BaseView) + if isinstance(x, ReinterpretView): + return x + + # NOTE: Don't use extract_read_writes here as it fails when + # make_loader() inlines the computation + x_unwrap_view = x.unwrap_view() + buf = V.graph.get_buffer(x_unwrap_view.get_name()) + assert buf is not None + x_unwrap_view_fx_node = buf.get_origin_node() + # Prefer channels last format according to how the format is set from eager. + if ( + x_unwrap_view_fx_node is not None + and "val" in x_unwrap_view_fx_node.meta + and isinstance(x_unwrap_view.layout, FlexibleLayout) + and ( + x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last + ) + or x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last_3d + ) + ) + ): + x_unwrap_view.freeze_layout_with_same_order( + make_channels_last_strides_for(x_unwrap_view.get_size()) + ) + else: + x_unwrap_view.freeze_layout() + + index_args, var_ranges = dependencies.index_vars_squeeze( + x.get_size(), prefix="r" + ) + range_vars = index_args[0] + index = x.make_indexer()(range_vars) + + index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) + strides = V.graph.sizevars.stride_vars(index, range_vars) + offset = V.graph.sizevars.offset_var(index, range_vars) + expected = sympy_dot(range_vars, strides) + offset + + if index != expected: + log.debug( + "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", + strides, + offset, + index, + ) + raise NotImplementedError + + return ReinterpretView( + data=x.data, + layout=FixedLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=x.get_size(), + stride=strides, + offset=offset, + ), + ) + + @classmethod + def realize_input(cls, x): + if x is None: + return NoneAsConstantBuffer() + if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): + return ShapeAsConstantBuffer(x) + if isinstance(x, Constant): + return V.graph.add_tensor_constant( + torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) + ) + if isinstance(x, ConstantBuffer): + return x + if isinstance(x, TensorBox): + return cls.realize_input(x.data) + if isinstance(x, ReinterpretView): + return ReinterpretView(cls.realize_input(x.data), x.get_layout()) + if isinstance(x, BaseView): + x.realize() + if is_storage_and_layout(x.unwrap_view()): + try: + return cls.convert_to_reinterpret_view(x) + except NotImplementedError: + pass + if isinstance(x, StorageBox): + # TODO(jansel): impose layout preference on realized buffer + x.realize() + return x + if isinstance(x, TorchBindObject): + return x + return cls.copy_input(x) + + @classmethod + def require_stride1(cls, x): + if is_storage_and_layout(x): + if len(x.get_stride()) == 0: + return x + for stride in x.get_stride(): + if stride == 1: + return x + return cls.copy_input(x) + + @classmethod + def require_strides( + cls, + x, + order: Optional[Sequence[int]] = None, + exact_strides: Optional[Sequence[_IntLike]] = None, + allow_padding=False, + ): + assert order is not None or exact_strides is not None + if x.get_numel() == 0: # Layout doesn't matter + return x + # require x to have the layout + if is_storage_and_layout(x): + while isinstance(x.get_layout(), NonOwningLayout): + x = x.get_layout().view + if isinstance(x.get_layout(), FlexibleLayout): + if order: + # If the the FlexibleLayout already has the size and stride in the required order, + # freeze it to a FixedLayout by using its current size and stride. + # The behavior of using its current size and stride or the given order can be different + # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1: + # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last), + # the current size and stride already satisfies this order. + # However by freezing it to the required order, the layout will be changed to: + # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. + + # fix flexiblelayout to be FixedLayout with stride_order + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=get_stride_order( + V.graph.sizevars.size_hints(x.get_layout().stride) + ) + if is_stride_order_storage_and_layout(x, order) + else order, + allow_padding=allow_padding, + ) + return x + else: + # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides. + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=None, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + return x + elif isinstance(x.get_layout(), FixedLayout) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() + ) + ) + ): + return x + elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE): + if isinstance(x.get_layout().real_layout(), FlexibleLayout): + raise AssertionError( + "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout" + ) + elif isinstance(x.get_layout().real_layout(), FixedLayout) and ( + (order and x.get_layout().real_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, + x.get_layout().real_layout().stride, + x.get_size(), + ) + ) + ): + return x + + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() + ) + ) + ): + return x + if ( + isinstance(x, TensorBox) + and isinstance(x.data, BaseView) + and not isinstance(x.data, ReinterpretView) + and is_storage_and_layout(x.unwrap_view()) + and not isinstance(x.unwrap_view().data, ExternKernelAlloc) + ): + try: + x.data = cls.convert_to_reinterpret_view(x.data) + if order: + return cls.require_stride_order( + x, order, allow_padding=allow_padding + ) + elif exact_strides: + return cls.require_exact_strides( + x, exact_strides, allow_padding=allow_padding + ) + except NotImplementedError: + pass + # Although this is a clone, inductor is good about fusing clones into previous + # operations if they weren't realized and their layouts were flexible. + x = cls.copy_input(x) + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + if order: + assert is_stride_order_storage_and_layout(x, order) + return x + + @classmethod + def require_exact_strides(cls, x, exact_strides, allow_padding=False): + return cls.require_strides( + x, exact_strides=exact_strides, allow_padding=allow_padding + ) + + @classmethod + def require_stride_order(cls, x, order, allow_padding=False): + return cls.require_strides(x, order=order, allow_padding=allow_padding) + + @classmethod + def require_channels_last(cls, x): + return cls.require_stride_order(x, NHWC_STRIDE_ORDER) + + @classmethod + def require_channels_last_3d(cls, x): + return cls.require_stride_order(x, NHWDC_STRIDE_ORDER) + + @classmethod + def require_contiguous(cls, x): + return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) + + def apply_constraint(self): + pass + + def codegen_const_args(self, names: Optional[List[str]] = None): + if V.graph.cpp_wrapper: + result = [] + # Aten ops follow the convention that tensor args are before non-tensor args, + # in which case the following 'len(self.inputs) + i' logic works. But this + # may not be true for other ops, and if that is the case, caller needs to + # pass in a list of const arg names for arg_properties lookup. + name_to_arg_properties = None + if names and self.arg_properties: + assert len(self.constant_args) == len( + names + ), "names passed to codegen_const_args does not match self.constant_args" + name_to_arg_properties = { + arg.get("name"): arg for arg in self.arg_properties + } + + for i, x in enumerate(self.constant_args): + if name_to_arg_properties is not None: + prop = name_to_arg_properties.get(names[i]) # type: ignore[index] + type_ = prop.get("type") if prop else None + else: + idx = len(self.inputs) + i + type_ = ( + self.arg_properties[idx].get("type") + if self.arg_properties and idx < len(self.arg_properties) + else None + ) + result.append( + V.graph.wrapper_code.val_to_arg_str(x, type_) # type: ignore[arg-type] + ) + return result + else: + return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) + + def codegen_args(self): + args = [] + for i, x in enumerate(self.inputs): + if isinstance(x, list): + names = [i.codegen_reference() for i in x] + codegen_reference = f'[{", ".join(names)}]' + args.append(codegen_reference) + else: + if V.graph.cpp_wrapper: + assert self.arg_properties and i < len( + self.arg_properties + ), "Invalid access to ExternKernel.arg_properties" + type_ = self.arg_properties[i].get("type") + args.append( + V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] + x, type_ + ) + ) + else: + args.append(x.codegen_reference()) + args.extend(self.codegen_const_args()) + return args + + def get_kwargs_value(self, arg_name): + if arg_name in self.kwargs: + return self.kwargs.get(arg_name) + if self.allarg_properties and self.allarg_properties.get(arg_name): + return self.allarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] + else: + raise AssertionError(f"{arg_name} not in self.allarg_properties") + + def codegen_kwargs(self, skip_out=False): + if V.graph.cpp_wrapper: + kwargs = [] + for arg_name in self.ordered_kwargs_for_cpp_kernel: + if skip_out and arg_name == "out": + # ExternKernelOut has its own logic for inserting the out parameter + continue + + v = self.get_kwargs_value(arg_name) + if isinstance(v, sympy.Expr): + kwargs.append(v) + else: + type_ = ( + self.allarg_properties.get(arg_name).get("type") # type: ignore[union-attr] + if self.allarg_properties and arg_name in self.allarg_properties + else None + ) + kwargs.append( + V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] + v, type_ + ) + ) + else: + kwargs = [ + f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc] + for k, v in self.kwargs.items() + ] + return kwargs + + def codegen_size_asserts(self, wrapper): + if config.size_asserts and not V.graph.cpp_wrapper: + # comparing strides for 0 size tensor is tricky. Ignore them for now. + if sympy_product(self.get_size()) == 0: + return + size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) + stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) + wrapper.writeline( + f"assert_size_stride({self.get_name()}, {size}, {stride})" + ) + + def get_group_stride(self): + """ + get output sizes and strides, for template_codegen + """ + _size = self.get_size() + _stride = self.get_stride() + # iter_ranges = _size of output tensor, reduce_range = [] because no reduction + return [_size, []], _stride + + def canonicalize(self): + """ + Manually get canonicalization of the output index + """ + # manually generate index formula for conv + sizevars = V.graph.sizevars + sizes = self.get_size() + strides = self.get_stride() + strides = [sizevars.size_hint(x) for x in strides] + # TODO: I can't tell if the symbols here are temporary + index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))] + # reorder index vars according to stride + index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) + lookup = {pos: idx for idx, pos in enumerate(index_order)} + order = [lookup[i] for i in range(len(lookup))] + index_vars = [index_vars[i] for i in order] + indexer = self.make_indexer() + index = indexer(index_vars) + + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, [index] + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + _, add_var = var_builder("c") + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + + index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type] + return index, tuple(new_sizes) + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + # NB: It's not necessary to check regular inputs as we automatically + # have dependencies on them + r: OrderedSet[sympy.Symbol] = OrderedSet() + for arg in self.constant_args: + r |= maybe_free_unbacked_symbols(arg) + for arg in self.kwargs.values(): + r |= maybe_free_unbacked_symbols(arg) + return r + + def __str__(self) -> str: + kernel_name = getattr(self, "python_kernel_name", None) + lines = [ + f"python_kernel_name={kernel_name!r}", + ] + lines += [ + f"{field.name}={getattr(self, field.name)}" + for field in dataclasses.fields(self) + ] + lines.append(f"origin_node={self.origin_node!r}") + return self.str_helper(lines) + + __repr__ = __str__ + + +@dataclasses.dataclass +class ExternKernelOut(ExternKernel): + def codegen(self, wrapper): + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs(skip_out=True)] + kernel_name = self.get_kernel_name() + if ( + V.graph.cpp_wrapper + and self.cpp_kernel_name == "torch::inductor::_mm_plus_mm" + ): + # For https://github.com/pytorch/pytorch/issues/128474 + kernel_name = ( + "aoti_torch__mm_plus_mm_out" + if config.abi_compatible + else "torch::inductor::_mm_plus_mm_out" + ) + else: + kernel_name = self.get_kernel_name() + wrapper.generate_extern_kernel_out( + kernel_name, + self.codegen_reference(), + self.output_view.codegen_reference() if self.output_view else None, + args, + ) + + def __init__( + self, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def should_allocate(self): + return True + + +class RandomSeeds(ExternKernelOut): + def __init__(self, count: int, device: torch.device): + limits = torch.iinfo(torch.int64) + super().__init__( + layout=FixedLayout( + device=device, + dtype=torch.int64, + size=[count], + ), + inputs=[], + constant_args=[limits.min, limits.max, [count]], + python_kernel_name="aten.randint.low_out", + # FIXME: Ideally we should only use at::_ops::randint_low_out::call here, + # but the signature is different from is at::randint_out. Again, + # we can simplify the code when only keeping an ABI-compatible version. + cpp_kernel_name="at::_ops::randint_low_out::call" + if config.abi_compatible + else "at::randint_out", + op_overload=aten.randint.low_out, + ) + + +class ExternKernelAlloc(ExternKernel): + def codegen(self, wrapper): + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + V.graph.wrapper_code.generate_extern_kernel_alloc(self, args) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def __init__( + self, + layout, + inputs, + constant_args=(), + kwargs=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def should_allocate(self): + return False + + def apply_constraint(self): + raise NotImplementedError + + +class MutationOutput(Buffer): + """ + An output buffer that represents the mutation of a pre-existing buffer + """ + + def __init__(self, layout, mutated_node, mutating_node: Operation): + super().__init__(name=None, layout=layout) + mutated_node_name = mutated_node.get_name() + V.graph.mark_buffer_mutated(mutated_node_name) + self.mutation_names = [mutated_node_name] + self.mutating_node: Operation = mutating_node + self.name = V.graph.register_buffer(self) + + def get_defining_op(self) -> Operation: + return self.mutating_node + + def get_mutation_names(self): + return self.mutation_names + + def should_allocate(self): + return False + + +class UserDefinedTritonKernel(ExternKernel): + def get_kernel_and_configs(self): + from triton.runtime.autotuner import Autotuner + + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + kernel = kernel_side_table.get_kernel(self.kernel_idx) + configs = [] + if isinstance(kernel, Autotuner): + configs = kernel.configs + kernel = kernel.fn + return kernel, configs + + def codegen(self, wrapper): + kernel, configs = self.get_kernel_and_configs() + + # Definition of kernel + new_name, triton_meta = wrapper.define_user_defined_triton_kernel( + kernel, configs, self.kwargs + ) + raw_args = [ + self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel + ] + + # NOTE: raw_args doesn't include autotuned args. + # But, kernel.constexprs includes indices of autotuned args. + # So, let's recalculate constexpr indices wrt to raw_args. + constexpr_indices = [] + for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel): + if kernel.arg_names.index(kwarg) in kernel.constexprs: + constexpr_indices.append(idx) + + # Call to kernel + self.codegen_comment(wrapper) + wrapper.generate_user_defined_triton_kernel( + new_name, raw_args, self.grid, configs, triton_meta, constexpr_indices + ) + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + # add unbacked symbols used in the grid to the ones used + # in the kwargs (the latter is generated by ExternKernel) + return super().get_unbacked_symbol_uses() | free_unbacked_symbols(self.grid) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__(self, *, kernel_idx, grid, kernel_args): + inputs = [] + kwargs = {} + constant_args = [] + for k, v in kernel_args.items(): + if isinstance(v, TensorBox): + t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) + inputs.append(t) + kwargs[k] = t + else: + constant_args.append(v) + kwargs[k] = v + + assert len(inputs) != 0 + self.device = inputs[0].get_device() + + super().__init__( + None, + NoneLayout(self.device), # type: ignore[arg-type] + inputs, + tuple(constant_args), + kwargs, + ) + self.kernel_idx = kernel_idx + self.grid = grid + + kernel, configs = self.get_kernel_and_configs() + # If we are autotuning, not all arguments will be passed + self.ordered_kwargs_for_cpp_kernel = [ + arg for arg in kernel.arg_names if arg in kernel_args + ] + + from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors + + autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {} + self.mutable_args = [ + kernel_args[key] + for key in identify_mutated_tensors( + kernel, {**kernel_args, **autotuned_kwargs} + ) + ] + + self.mutation_outputs = [ + MutationOutput(NoneLayout(self.device), buf, self) + for buf in self.mutable_args + ] + V.graph.register_operation(self) + + def get_outputs(self) -> List[Buffer]: + return list(self.mutation_outputs) + + def get_device(self) -> torch.device: + return self.device + + +class InplaceBernoulliFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + (x,) = (t.codegen_reference() for t in self.inputs) + + if V.graph.cpp_wrapper and config.abi_compatible: + # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, + # which needs to be explicitly generated for cpp wrapper + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}" + ) + else: + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__(self, op_overload, x, *constant_args): + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage([x]), + constant_args, + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(x.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + if not config.abi_compatible: + # TODO: this should be simplified once we switch to ABI-compatible only + self.cpp_kernel_name = "at::native::bernoulli_" + + +# Used to deal with torch.complex types +class InplaceCopyFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + (dst, src, non_blocking) = self.codegen_args() + wrapper.codegen_device_copy(src, dst) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( + self, + layout, + inputs, + constant_args, + ): + super().__init__( + None, + layout, + inputs, + constant_args, + python_kernel_name="aten.copy_", + cpp_kernel_name=( + "aoti_torch_copy_" if config.abi_compatible else "at::_ops::copy_::call" + ), + ) + V.graph.mark_buffer_mutated(inputs[0].get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create(cls, dst, src, non_blocking: bool = False): + inputs = [cls.realize_input(t) for t in [dst, src]] + constant_args = (non_blocking,) + result = InplaceCopyFallback( + NoneLayout(dst.get_device()), # type: ignore[arg-type] + inputs, + constant_args, + ) + return result + + +class MutatingFirstArgExternKernel(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + argrefs = [ + *(t.codegen_reference() for t in self.inputs), + *map(repr, self.constant_args), + ] + wrapper.writeline( + f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def has_side_effects(self): + return True + + +class ResizeStorageBytes(MutatingFirstArgExternKernel): + def __init__(self, variable, new_size): + assert isinstance(new_size, int), "TODO: dynamic shapes" + super().__init__( + None, + NoneLayout(variable.get_device()), # type: ignore[arg-type] + self.unwrap_storage([variable]), + constant_args=(new_size,), + ) + V.graph.mark_buffer_mutated(variable.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + self.python_kernel_name = "inductor_ops.resize_storage_bytes_" + self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_" + V.graph.never_reuse_buffers.add(variable.data.get_name()) + + +class SetSourceTensorKernel(ExternKernelAlloc): + def __init__(self, self_tensor, storage_tensor): + self_tensor.freeze_layout() + super().__init__( + self_tensor.get_layout(), + [self_tensor, storage_tensor], + python_kernel_name="torch.ops.aten.set_.source_Tensor", + op_overload=torch.ops.aten.set_.source_Tensor, + ) + V.graph.never_reuse_buffers.add(self_tensor.data.get_name()) + V.graph.never_reuse_buffers.add(storage_tensor.get_name()) + V.graph.never_reuse_buffers.add(self.get_name()) + device = storage_tensor.get_device() + self.mutation_outputs = [ + MutationOutput(NoneLayout(device), self_tensor, self), + MutationOutput(NoneLayout(device), storage_tensor, self), + ] + + def get_inputs_that_alias_output(self): + return [self.inputs[0].get_name(), self.inputs[1].get_name()] + + +class ScatterFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly. + This class handles both aten.scatter_ and aten.scatter_reduce_. + It also handle the case `src` being a scalar properly. + """ + + def codegen(self, wrapper): + reduce = self.kwargs["reduce"] + if V.graph.cpp_wrapper: + # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum + get_operator_enum = {"add": "sum", "multiply": "prod"} + if reduce in get_operator_enum: + reduce = get_operator_enum[reduce] + + if self.src_is_tensor: + (x, index, src) = (t.codegen_reference() for t in self.inputs) + else: + (x, index) = (t.codegen_reference() for t in self.inputs) + src = self.constant_args[1] + wrapper.generate_scatter_fallback( + x, + [x, self.constant_args[0], index, src], + self.cpp_kernel_name, + self.python_kernel_name, + self.src_is_tensor, + reduce, + self.codegen_kwargs(), + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( + self, + op_overload, + x, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, + ): + self.src_is_tensor = isinstance(src, TensorBox) + + constant_args: Tuple[Any, ...] + if self.src_is_tensor: + tensors = [self.realize_input(t) for t in [x, index, src]] + constant_args = (dim,) + else: + tensors = [self.realize_input(t) for t in [x, index]] + constant_args = (dim, src) + + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage(tensors), + constant_args, + {"reduce": reduce, "include_self": include_self}, + python_kernel_name=str(op_overload), + ordered_kwargs_for_cpp_kernel=["reduce", "include_self"], + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(x.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +class IndexPutFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation and indices properly + """ + + def codegen(self, wrapper): + (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs) + indices = [] + iter_valid_indices = iter(valid_indices) + for i, _ in enumerate(self.indices): + if self.indices[i] is not None: + indices.append(next(iter_valid_indices)) + else: + indices.append(V.graph.wrapper_code.none_str) + + wrapper.generate_index_put_fallback( + self.get_kernel_name(), x, indices, values, *self.codegen_const_args() + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__(self, op_overload, x, indices, values, accumulate): + self.indices = indices + valid_indices = [i for i in indices if i is not None] + tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] + cpp_kernel_name = ( + "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out" + ) + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage(tensors), + (accumulate,), + python_kernel_name="aten.index_put_", + cpp_kernel_name=cpp_kernel_name, + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(self.inputs[0].get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +class DeviceCopy(ExternKernelOut): + @classmethod + def create(cls, x, device): + if ( + not x.is_extern() + and all(r in V.graph.constants for r in x.get_read_names()) + and not config.aot_inductor.use_runtime_constant_folding + ): + return x.constant_to_device(device) + + V.graph.add_device_info(device) + V.graph.add_device_info(x.get_device()) + + developer_warning("DeviceCopy in input program") + return DeviceCopy( + FlexibleLayout( + device=device, + dtype=x.get_dtype(), + size=x.get_size(), + ), + [cls.realize_input(x)], + ) + + def codegen(self, wrapper): + args = self.codegen_args() + assert len(args) == 1 + if self.output_view: + wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference()) + else: + wrapper.codegen_device_copy(args[0], self.codegen_reference()) + + +class DynamicScalar(ExternKernel): + """ + The result of a call to aten._local_scalar_dense. + """ + + def get_reads(self): + return () + + def should_allocate(self): + return False + + def __init__(self, sym, keypath, data): + data.realize() + super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data])) # type: ignore[arg-type] + self.sym = sym + self.keypath = keypath + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet([self.sym]) + + def codegen(self, wrapper): + wrapper.codegen_dynamic_scalar(self) + + +class AssertScalar(ExternKernel): + """ + The result of a call to aten._assert_scalar + """ + + def get_reads(self): + return () + + def should_allocate(self): + return False + + def __init__(self, scalar, msg): + super().__init__( + # Buffer(name, layotu) + None, + NoneLayout(torch.device("cpu")), # type: ignore[arg-type] + # InputsKernel(inputs) + [], + ) # type: ignore[arg-type] + self.scalar = scalar + self.msg = msg + + def has_side_effects(self): + return True + + def get_unbacked_symbol_uses(self): + return free_unbacked_symbols(self.scalar) + + def codegen(self, wrapper): + if V.graph.cpp_wrapper: + pass + else: + # NB: It is EXTREMELY important not to simplify the scalar under + # assertion here, because simplify is done with respect to + # runtime asserts. So if you have "u0 == 0" in the runtime + # asserts, if you subsequently try to simplify(u0 == 0), you will + # get True (because we've already runtime assert'ed that it's + # true). But we're code generating the actual runtime assert + # here!! + wrapper.writeline( + f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar, simplify=False)}:" + ) + wrapper.writeline(f" raise RuntimeError({repr(self.msg)})") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + wrapper.writeline(f"{self.get_name()} = None") + + +@dataclasses.dataclass +class ExternKernelNode: + name: str + node: export_schema.Node + + +has_c_shim = OrderedSet( + [ + aten._embedding_bag.default, + aten._fft_c2c.default, + aten._scaled_dot_product_efficient_attention.default, + aten._scaled_dot_product_flash_attention.default, + aten._scaled_dot_product_cudnn_attention.default, + aten._scaled_mm.default, + aten.addmm.out, + aten.bmm.out, + aten.copy_.default, + aten.mm.out, + aten.repeat_interleave.Tensor, + aten.nonzero.default, + aten.view.dtype, + aten.view_as_real.default, + ] +) + + +class FallbackKernel(ExternKernelAlloc): + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, + ): + if ( + kernel == aten.mul.Tensor + and len(tensor_args) == 1 + and len(nontensor_args) == 1 + ): + # When aten.mul.Tensor's second arg is constant, cpp wrapper expects + # to call mul_Scalar. A more proper fix is to do it in decomposition. + # See https://github.com/pytorch/pytorch/issues/123478 + kernel = aten.mul.Scalar + + super().__init__( + layout, + tuple(tensor_args), + tuple(nontensor_args), + op_overload=kernel, + ) + + # We need output buffers for generating kernel arguments in the + # abi-compatible mode, where we retrieve outputs by pass each individual + # output through the abi-compatible interface. + self.outputs: Sequence[Any] = [] + self.use_runtime_dispatch = False + self.unbacked_bindings = unbacked_bindings + + assert isinstance( + kernel, + ( + torch._ops.OpOverload, + torch._ops.HigherOrderOperator, + ), + ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" + self.op_overload = kernel + self.unflatten_args = unflatten_args + self.kwargs = {} if kwargs is None else kwargs + V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type] + + # args that are aliased + self.alias_names: List[str] = [] + # args that are mutated AND returned from the op + self.mutation_names: List[str] = [] + + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + # We assume here that HOPs with FallbackKernel are functional. + # This may not always be true! HOPs must individually opt-in to + # FallbackKernel, so please check this if you opt-in. + return + + if "_c10d_functional" in self.op_overload.name(): + # _c10d_functional kernels are lowered into _CollectiveKernel which + # derives from FallbackKernel for the cpp codegen. The kernels + # don't pass the can_auto_functionalize check, but their mutation + # is handled properly by _CollectiveKernel. + return + + schema = self.op_overload._schema + + # NOTE: [FallbackKernel supported operators] + # We only support three types of operators: + # - functional ops + # - view ops + # - inplace aten ops + # - mutating ops that are auto-functionalizable. That is, + # the operator may mutate any number of inputs, but its outputs + # may not alias any of the inputs. + # + # The unsupported cases usually do not show up here (because + # AOTAutograd functionalized them away); the only way for an in-place + # op to show up here is if a lowering or pass introduced it. + if torch._library.utils.mutates_and_returns_first_arg(self.op_overload): + self.mutation_names.append(tensor_args[0].get_name()) + return + + if schema.is_mutable and not can_auto_functionalize(kernel): + raise NotImplementedError( + f"NYI: Can't generate FallbackKernel for {kernel}" + ) + + schema_args = schema.arguments + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + + def handle_aliasing_and_mutation(info, arg): + # Assertions to make sure we didn't mismatch args + if isinstance(info.type, torch.ListType): + assert isinstance(arg, (list, tuple)) + is_optional_tensor = isinstance( + info.type, torch.OptionalType + ) and isinstance(info.type.getElementType(), torch.TensorType) + is_list_tensor = isinstance(info.type, torch.ListType) and isinstance( + info.type.getElementType(), torch.TensorType + ) + if is_optional_tensor or isinstance(info.type, torch.TensorType): + # PyTorch also accepts None and scalar types for args marked as "Tensor". + # We're not going to check all of them here. + assert not isinstance(arg, (tuple, list)) + + if arg is None: + return + if info.alias_info is None: + return + + def add_alias(t): + self.alias_names.append(t.get_name()) + if info.alias_info.is_write: + self.mutation_outputs.append( + MutationOutput(NoneLayout(t.get_device()), t, self) + ) + + if is_list_tensor: + for tensor_arg in arg: + add_alias(tensor_arg) + else: + assert isinstance(info.type, torch.TensorType) or is_optional_tensor + add_alias(arg) + + for info, arg in torch._library.utils.zip_schema(schema, args, kwargs): + handle_aliasing_and_mutation(info, arg) + + def codegen_unbacked_symbol_defs(self, wrapper): + if not hasattr(self, "unbacked_bindings"): + return + + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, self.unbacked_bindings + ) + + if not unbacked_bindings: + return + + for s, keypath in unbacked_bindings.items(): + + def go(expr, keypath): + if keypath == (): + return expr + + if ( + len(keypath) >= 2 + and isinstance(keypath[0], CallMethodKey) + and isinstance(keypath[1], pytree.SequenceKey) + ): + return go( + f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:] + ) + elif isinstance(keypath[0], CallMethodKey): + return go(f"{expr}.{keypath[0].name}()", keypath[1:]) + elif isinstance(keypath[0], pytree.SequenceKey): + return ( + go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:]) + if V.graph.cpp_wrapper + else go(f"{expr}[{keypath[0].idx}]", keypath[1:]) + ) + elif isinstance(keypath[0], DivideByKey): + # TODO: need to assert divisibility + # TODO: this is invalid C++ codegen + return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:]) + else: + raise AssertionError(f"unrecognized keypath {keypath}") + + def go_outer(): + if V.graph.cpp_wrapper and config.abi_compatible: + # Special handling for the top level buffer access, + # because self.get_name() is actually never bound; the + # individual output arguments are bound by + # generate_c_shim_fallback_kernel + if len(self.outputs) == 1: + return go(self.outputs[0].get_name(), keypath) + else: + assert isinstance(keypath[0], pytree.SequenceKey) + return go(self.outputs[keypath[0].idx].get_name(), keypath[1:]) + else: + return go(self.get_name(), keypath) + + wrapper.writeline( + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {go_outer()}{wrapper.ending}" + ) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + if unbacked_bindings := getattr(self, "unbacked_bindings", None): + return resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ).keys() + else: + return OrderedSet() + + def codegen_args(self): + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self) -> str: + return self.ref + + tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] + args, kwargs = self.unflatten_args(tensor_args, self.constant_args) + if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): + args = self.fill_non_provided_args(args, kwargs) + args = [ + V.graph.wrapper_code.val_to_arg_str(x, param.real_type) + for param, x in zip(self.op_overload._schema.arguments, args) + ] + else: + args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] + + # let self.codegen_kwargs handle kwargs + self.kwargs.update(kwargs) + return args + + @staticmethod + def find_device(tensor_args, example_output): + if tensor_args: + devices = [arg.get_device() for arg in tensor_args if arg.get_device()] + return devices[0] + if isinstance(example_output, torch.Tensor): + return example_output.device + if isinstance(example_output, (list, tuple)): + device_set = OrderedSet( + FallbackKernel.find_device(None, x) for x in example_output + ) + # Remove None + devices = [device for device in device_set if device] + if len(devices) == 1: + return devices[0] + for device in devices: + if is_gpu(device.type): + return device + return devices[0] + return None + + def has_side_effects(self): + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + return False + return get_schema_info(self.op_overload).is_mutable() + + def get_inputs_that_alias_output(self): + return self.alias_names + + def get_mutation_names(self): + assert len(self.mutation_names) <= 1 + return self.mutation_names + + # ProxyExecutor Design Note + # We export the ExternFallbackNodes (for custom ops) into a serialized file + # and run it with a host side proxy executor to address the ABI problem + # This is currently only implemented for fbcode. Eventually, we will also make this work for OSS. + # Detailed design doc can be found at + # https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing + def export_extern_kernel_node(self): + assert isinstance(self, FallbackKernel) + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + args = self.fill_non_provided_args(args, kwargs) + ordered_kwargs = [ + kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel + ] + if not V.graph.aot_mode: + # No need to serialize in the cpp wrapper JIT mode + return [*args, *ordered_kwargs] + + serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] + named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # type: ignore[arg-type] + + # serialize_outputs + def handle_single_output(return_type, output): + if isinstance(return_type, torch.TensorType): + # For single Tensor + out = output + if isinstance(output, (list, tuple)): + assert len(output) == 1 + out = output[0] + return export_schema.Argument.create( + as_tensor=export_schema.TensorArgument(name=out.get_name()) + ) + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ): + # For single TensorList + return export_schema.Argument.create( + as_tensors=[ + export_schema.TensorArgument(name=out.get_name()) + for out in output + ] + ) + else: + raise RuntimeError(f"Unsupported return type {type(return_type)}") + + target = self.op_overload + returns = target._schema.returns # type: ignore[union-attr] + if len(returns) == 1: + return_type = returns[0].real_type + output_arguments = [handle_single_output(return_type, self.outputs)] + else: + # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" + assert isinstance(self.outputs, tuple) + assert len(returns) == len(self.outputs) + output_arguments = [ + handle_single_output(return_schema.real_type, output) + for return_schema, output in zip(returns, self.outputs) + ] + + node = ExternKernelNode( + name=self.get_name(), + node=export_schema.Node( + target=self.op_overload.name(), # type: ignore[union-attr] + inputs=named_arguments, + outputs=output_arguments, + metadata={}, + ), + ) + + V.graph.extern_kernel_nodes.append(node) + + return [*args, *ordered_kwargs] + + def codegen(self, wrapper): + kernel = self.op_overload + if kernel.namespace == "aten": # type: ignore[union-attr] + # Aten Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) + if V.graph.cpp_wrapper: + from torchgen.aoti.fallback_ops import inductor_fallback_ops + + if config.abi_compatible and str(kernel) not in inductor_fallback_ops: + # C shim v2 is torchgen-ed, which should cover all aten ops. + # If you do hit a missed op, please update fallback_ops.py. + log.warning( + "%s is missing a c-shim implementation, using proxy executor as fallback", + kernel, + ) + self.use_runtime_dispatch = True + elif kernel.namespace == "_quantized": # type: ignore[union-attr] + # Internal Quantized Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) + if V.graph.cpp_wrapper: + if not config.abi_compatible: + self.use_runtime_dispatch = True + else: + # For non-aten OpOverload, i.e. custom ops + if V.graph.cpp_wrapper: + self.use_runtime_dispatch = True + + if self.use_runtime_dispatch: + self.codegen_comment(wrapper) + + exported_args = None + args = None + if config.abi_compatible: + exported_args = self.export_extern_kernel_node() + else: + args = [*self.codegen_args(), *self.codegen_kwargs()] + + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + exported_args, + self.outputs, + ) + else: + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + V.graph.wrapper_code.generate_fallback_kernel(self, args) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + self.codegen_unbacked_symbol_defs(wrapper) + + @staticmethod + def tensor_to_layout(output: torch.Tensor): + return FixedLayout( + output.device, + output.dtype, + convert_shape_to_inductor(output.size()), + convert_shape_to_inductor(output.stride()), + ) + + @classmethod + def create(cls, kernel, *args, **kwargs): + fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) + context: ContextManager[None] = ( + V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() # type: ignore[assignment] + ) + with context: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, *args, **kwargs) + + device = cls.find_device(tensor_args, example_output) + if example_output is None: + packed = cls( + NoneLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + else: + assert device, "Not sure where to find device info" + packed = cls( + MultiOutputLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + def generate_output(output, indices): + if isinstance(output, (list, tuple)): + return type(output)( + generate_output(output[i], indices + [(type(output), i)]) + for i in range(len(output)) + ) + elif isinstance(output, dict): + return { + key: generate_output(val, indices + [(type(output), key)]) + for key, val in output.items() + } + elif isinstance(output, torch.Tensor): + return MultiOutput( + cls.tensor_to_layout(output), + packed, + indices, + ) + elif isinstance(output, int): + return output + elif isinstance(output, torch.SymInt): + return output.node.expr + else: + assert ( + output is None + ), f"FallbackKernel output type {type(output)} is not supported" + return None + + outputs = generate_output(example_output, []) + if isinstance(outputs, (list, tuple, dict)): + packed.outputs = outputs # type: ignore[assignment] + else: + packed.outputs = [outputs] + return outputs + + def apply_constraint(self): + return super().apply_constraint() + + +@dataclasses.dataclass +class ComplexView(FallbackKernel): + """View a complex number as two dtyped numbers or vice versa""" + + def should_allocate(self): + return False + + def get_inputs_that_alias_output(self): + # Signal to codegen that our output buffer isn't safe to reuse + return [self.inputs[0].get_name()] + + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + *, + unbacked_bindings=None, + ): + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + +@dataclasses.dataclass +class MultiOutputLayout(IRNode): + device: torch.device + + +class MultiOutput(ExternKernel): + # Given an input MultiOutputLayout buffer, indexes out an actual buffer + # from that result. This doesn't actually produce multiple outputs, + # that's MultiOutputLayout! + def codegen_list_tuple_access(self, basename, indices): + if len(indices) > 0: + itype, i = indices[0] + if issubclass(itype, list): + return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) + elif issubclass(itype, tuple): + # cpp wrapper code needs to use std::get<> to access a tuple + tuple_access = V.graph.wrapper_code.codegen_tuple_access( + basename, self.get_name(), str(i) + ) + return self.codegen_list_tuple_access(tuple_access, indices[1:]) + elif issubclass(itype, dict): + return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:]) + else: + raise AssertionError("non supported index type: ", itype) + else: + return basename + + def codegen(self, wrapper): + wrapper.codegen_multi_output( + self.get_name(), + self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices), + ) + + def __init__(self, layout, input, indices: List[Tuple[Any, ...]]): + super().__init__(None, layout, [input], ()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + self.indices = indices + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return self.inputs[0].get_unbacked_symbol_uses() + + def should_allocate(self): + return False + + def get_inputs_that_alias_output(self): + return [ + inp.get_name() + for inp in self.inputs + if isinstance(inp, FallbackKernel) + and len(inp.get_inputs_that_alias_output()) > 0 + ] + + +@dataclasses.dataclass +class MutableBox(IRNode): + """ + TensorBox / StorageBox allow in-place mutation of Tensors + """ + + data: IRNode + + def __getattr__(self, name): + fn = getattr(self.data, name) + if callable(fn): + return fn + raise AttributeError(f"{type(self.data).__name__}.{name} not callable") + + def realize(self): + return self.data.realize() + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return self.data.get_unbacked_symbol_uses() + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_defining_op(self): + return self.data.get_defining_op() + + def codegen_reference(self, writer=None): + return self.data.codegen_reference(writer) + + @property + def layout(self): + return self.data.get_layout() + + def get_layout(self): + return self.layout + + def get_size(self): + return self.data.get_size() + + @property + def dtype(self): + return self.data.dtype + + def __str__(self) -> str: + if isinstance(self.data, MutableBox): + line0 = f"{type(self).__name__}({type(self.data).__name__}(" + endl = "))" + inner = self.data.data + else: + line0 = f"{type(self).__name__}(" + inner = self.data + endl = ")" + + lines = [ + line0, + indent(str(inner)), + endl, + ] + return "\n".join(lines) + + __repr__ = __str__ + + +class TensorBox(MutableBox): + @staticmethod + def create(data): + return TensorBox(StorageBox(data)) + + +class StorageBox(MutableBox): + def is_input_buffer(self): + if isinstance(self.data, (InputBuffer, ReinterpretView)): + return self.data.get_name() in V.graph.graph_inputs + return False + + def is_module_buffer(self): + return ( + isinstance(self.data, (ConstantBuffer)) + and self.data.get_name() in V.graph.constants + ) + + def realize(self): + if isinstance( + self.data, + ( + ComputedBuffer, + InputsKernel, + InputBuffer, + ReinterpretView, + TemplateBuffer, + ), + ): + return self.data.get_name() + assert isinstance(self.data, (Pointwise, Reduction, Scan, Sort)), type( + self.data + ) + origin_node = self.data.get_origin_node() + traceback = self.data.get_traceback() + self.data = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=self.data.get_device(), + dtype=self.data.get_dtype(), + size=self.data.get_size(), + ), + data=self.data, + ) + self.data.name = V.graph.register_buffer(self.data) + V.graph.register_operation(self.data) + self.data.origins = self.origins + self.data.origin_node = origin_node + self.data.traceback = traceback + return self.data.name + + def realize_hint(self): + """ + Called on buffers we expect to be forced to realize later. + """ + if ( + isinstance(self.data, (Pointwise, Reduction)) + and self.data.inner_fn_opcount().nontrivial_read_count > 1 + ): + self.realize() + + def has_exceeded_max_reads(self): + return isinstance(self.data, Pointwise) and ( + self.num_reads() > config.realize_acc_reads_threshold + or self.has_large_inner_fn() + ) + + def should_realize_on_reuse(self, users): + """ + A heuristic to decide if we should realize a tensor + that is used multiple times. + """ + if users > 1 and isinstance(self.data, (Pointwise, Reduction)): + if is_cpu(self.data): + # Heuristic for realizing reused result of heavy ops on cpu + opcount = self.data.inner_fn_opcount() + heavy_ops = ["exp", "sigmoid"] # a list of heavy ops + if any(x in opcount.used_ops for x in heavy_ops): + return True + return ( + self.num_reads() > config.realize_reads_threshold + or self.has_large_inner_fn() + ) + return False + + def mark_reuse(self, users): + if self.should_realize_on_reuse(users): + self.realize() + + def num_reads(self): + return self.data.num_reads() + + +@dataclasses.dataclass +class Subgraph(IRNode): + name: str + graph_module: torch.fx.GraphModule + graph: Optional[GraphLowering] = None + + +def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool: + buffers = [ + buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer + for buffer in buffers + ] + # assuming the same buffer is represented by the same IRNode object + return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers) + + +@dataclasses.dataclass +class Conditional(ExternKernel): + predicate: Optional[IRNode] = None + operands: Optional[List[TensorBox]] = None + true_subgraph: Optional[Subgraph] = None + false_subgraph: Optional[Subgraph] = None + outputs: Optional[List[MultiOutput]] = None + + def __init__( + self, + predicate: IRNode, + operands: List[TensorBox], + true_subgraph: Subgraph, + false_subgraph: Subgraph, + layout: MultiOutputLayout, + ): + self.predicate = predicate + self.operands = operands + self.true_subgraph = true_subgraph + self.false_subgraph = false_subgraph + + inputs = [] + if not isinstance(predicate, ShapeAsConstantBuffer): + inputs.append(predicate) + inputs.extend(operands) + + super().__init__( + name=None, + layout=layout, # type: ignore[arg-type] + inputs=inputs, # type: ignore[list-item] + ) + + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create( + cls, + predicate: TensorBox, + true_fn: Subgraph, + false_fn: Subgraph, + operands: List[TensorBox], + ): + predicate = cls.realize_input(predicate) + operands = [cls.realize_input(x) for x in operands] + + fx_operands = V.graph.current_node.args[-1] + fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] + + for subgraph in (true_fn, false_fn): + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fake_operands, + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_operands) + + true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] + false_outputs = false_fn.graph.graph_outputs # type: ignore[union-attr] + + for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): + if _has_aliased_buffers(true_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.cond. " + f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}" + ) + + # make sure true and false outputs are structurally equivalent + assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs) + for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)): + assert to.get_size() == fo.get_size(), (i, to, fo) + assert to.get_stride() == fo.get_stride(), (i, to, fo) + assert to.get_device() == fo.get_device(), (i, to, fo) + assert to.get_dtype() == fo.get_dtype(), (i, to, fo) + assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo) + + if not isinstance(predicate, ShapeAsConstantBuffer): + # use predicate device for consistent codegen-ing + device = predicate.get_device() + else: + # predicate is not a Tensor: use first operand's device + assert ( + len(operands) > 0 + ), "When predicate is not a Tensor, there must be at least one operand in torch.cond." + device = operands[0].get_device() + + conditional = Conditional( + predicate=predicate, + operands=operands, + true_subgraph=true_fn, + false_subgraph=false_fn, + layout=MultiOutputLayout(device), + ) + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + conditional, + [(list, i)], + ) + # as the true and false outputs are equivalent, + # we can use either of them here as a "template" + for i, output in enumerate(true_outputs) + ] + + conditional.outputs = outputs + return outputs + + def codegen(self, wrapper): + wrapper.codegen_conditional(self) + + +@dataclasses.dataclass +class WhileLoop(ExternKernel): + carried_inputs: Optional[List[TensorBox]] = None + additional_inputs: Optional[List[TensorBox]] = None + cond_subgraph: Optional[Subgraph] = None + body_subgraph: Optional[Subgraph] = None + outputs: Optional[List[MultiOutput]] = None + + def __init__( + self, + carried_inputs: List[TensorBox], + additional_inputs: List[TensorBox], + cond_subgraph: Subgraph, + body_subgraph: Subgraph, + layout: MultiOutputLayout, + ): + self.carried_inputs = carried_inputs + self.additional_inputs = additional_inputs + self.cond_subgraph = cond_subgraph + self.body_subgraph = body_subgraph + + super().__init__( + name=None, + layout=layout, # type: ignore[arg-type] + inputs=carried_inputs + additional_inputs, # type: ignore[list-item] + ) + + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create( + cls, + cond_fn: Subgraph, + body_fn: Subgraph, + carried_inputs: List[TensorBox], + additional_inputs: List[TensorBox], + ): + carried_inputs = [cls.realize_input(x) for x in carried_inputs] + additional_inputs = [cls.realize_input(x) for x in additional_inputs] + all_inputs = carried_inputs + additional_inputs + + fx_all_inputs = V.graph.current_node.args[-2] + V.graph.current_node.args[-1] # type: ignore[operator] + fake_all_inputs = [x.meta["val"] for x in fx_all_inputs] # type: ignore[union-attr] + + for subgraph in (cond_fn, body_fn): + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fx_all_inputs, # type: ignore[arg-type] + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_all_inputs) + + cond_outputs = cond_fn.graph.graph_outputs # type: ignore[union-attr] + body_outputs = body_fn.graph.graph_outputs # type: ignore[union-attr] + + if _has_aliased_buffers(body_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.while_loop. " + f"The outputs of the body_fn subgraph of torch.while_loop are aliased: {body_outputs}" + ) + + # make sure cond_fn returns a boolean scalar Tensor + assert len(cond_outputs) == 1, cond_outputs + assert cond_outputs[0].get_dtype() == torch.bool, cond_outputs + assert len(cond_outputs[0].get_size()) == 0, cond_outputs + + assert ( + len(all_inputs) > 0 + ), "torch.while_loop is assumed to have at least one operand." + + device = all_inputs[0].get_device() + + # make sure carried_inputs and body outputs are structurally equivalent + assert len(carried_inputs) == len(body_outputs), (carried_inputs, body_outputs) + for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)): + assert op.get_size() == bo.get_size(), (i, op, bo) + assert op.get_stride() == bo.get_stride(), (i, op, bo) + # assume all carried_inputs and outputs are on the same device + # as the MultiOutputLayout below requires single device + assert op.get_device() == bo.get_device() == device, (i, op, bo, device) + assert op.get_dtype() == bo.get_dtype(), (i, op, bo) + assert op.get_layout().offset == bo.get_layout().offset, (i, op, bo) + + while_loop = WhileLoop( + carried_inputs=carried_inputs, + additional_inputs=additional_inputs, + cond_subgraph=cond_fn, + body_subgraph=body_fn, + # asserted above that there is at least one operand + layout=MultiOutputLayout(device), + ) + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + while_loop, + [(list, i)], + ) + for i, output in enumerate(body_outputs) + ] + + for inp, out in zip(carried_inputs, outputs): + if inp.get_name() in V.graph.graph_inputs: + # if a carried input of the while_loop is a graph input, + # it can be returned as is when the number of iterations + # is zero. due to this, we can't (generally) reuse the + # output buffers corresponding to the graph inputs, as + # the inputs may end up being mutated. + V.graph.never_reuse_buffers.add(out.get_name()) + + while_loop.outputs = outputs + return outputs + + def codegen(self, wrapper): + wrapper.codegen_while_loop(self) + + +class EffectfulKernel(FallbackKernel): + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, + ): + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + + from torch._higher_order_ops.effects import get_effect_key + + effect_type = get_effect_key(kernel, (*nontensor_args, *tensor_args), kwargs) + assert effect_type is not None + self.effect_type = effect_type + self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None) + V.graph.effectful_ops[effect_type] = self + + def get_read_writes(self): + read_writes = super().get_read_writes() + + if self.prev_effect_buffer is not None: + read_writes.reads.add( + dependencies.StarDep(self.prev_effect_buffer.get_name()) + ) + + return read_writes + + def has_side_effects(self): + return True + + +@dataclasses.dataclass +class TorchBindObject(IRNode): + name: str + value: torch._C.ScriptObject + + def get_name(self): + return self.name + + def get_device(self): + return None # is there a device?? + + def codegen_reference(self, writer=None): + return self.name + + +class _CollectiveKernel(FallbackKernel): + def should_allocate(self): + return False + + def has_side_effects(self): + return True + + # This is identical to FallbackKernel.set_cpp_kernel(), minus the + # part that checks against input aliasing and mutation. + def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None): + from .codegen.wrapper import get_cpp_op_schema + + assert ( + type(self.op_overload) is torch._ops.OpOverload + ), "Setting cpp kernel needs a valid op_overload" + kernel = self.op_overload + self.cpp_kernel_name = kernel._schema.name + self.cpp_kernel_overload_name = kernel._schema.overload_name + self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] + + self.cpp_op_schema = get_cpp_op_schema(kernel) + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in kernel._schema.arguments if x.kwarg_only + ] + + # NOTE: [In-Place Collective Safety] + # Between the initiation and completion of an in-place collective, the + # input buffers are subject to both volatile reads and volatile writes. + # They must not be read, written to or reused by another kernel. To ensure + # the constraints, we model collective -> wait_tensor as as two-step + # mutation of the input buffers. + @classmethod + def create_inplace( + cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs + ) -> None: + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + assert not unbacked_bindings, f"{kernel} {unbacked_bindings}" + for tensor_arg in tensor_args: + tensor_arg.realize() + + device = tensor_args[0].get_device() + packed = cls( + NoneLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + + inps = pytree.tree_leaves(inputs) + packed.mutation_outputs.extend( + [MutationOutput(NoneLayout(device), buf, packed) for buf in inps] + ) + + # For inplace collective ops, the input is guaranteed to be alias of the returned value of op. + packed.alias_names.extend([inp.get_name() for inp in inps]) + if "out" in kwargs: + packed.mutation_outputs.append( + MutationOutput(NoneLayout(device), kwargs["out"], packed) + ) + # For out-variant collective ops, the `out=` arg is guaranteed to be alias of the returned value of op. + packed.alias_names.append(kwargs["out"].get_name()) + + # NOTE: [Out-of-Place Collective Safety] + # Between the initiation and completion of an out-of-place collective: + # + # Input buffers: + # - Are subject to volatile reads + # - Can be read by another kernel + # - Must not be written to or reused by another kernel + # + # Output buffers: + # - Are subject to volatile writes + # - Must not be read, written to or reused by another kernel + # + # To ensure the safety of input buffers without sacrificing read + # availability, we add input buffers as read deps of wait_tensor kernels. + # + # To ensure the safety of output buffers, we model wait_tensor as a + # mutation to the output buffer. Note we also assumes the user program being + # correct and the output buffer is not consumed by kernels other than + # wait_tensor. + # + # TODO(yifu): add a pre-grad pass to validate the correctness of collective + # usage in the user program. + @classmethod + def create_out_of_place( + cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs + ): + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + assert not unbacked_bindings, f"{kernel}, {unbacked_bindings}" + for tensor_arg in tensor_args: + tensor_arg.realize() + + if isinstance(example_output, list): + device = cls.find_device(tensor_args, example_output) + packed = cls( + MultiOutputLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.outputs = [ + MultiOutput( + cls.tensor_to_layout(tensor), + packed, + [(list, i)], + ) + for i, tensor in enumerate(example_output) + ] + return packed.outputs + else: + packed = cls( + cls.tensor_to_layout(example_output), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.outputs = [packed] + return packed + + +class _WaitKernel(_CollectiveKernel): + def get_volatile_reads(self): + inp = self.inputs[0] + if isinstance(inp, _CollectiveKernel): + # Out-of-place single-output + return [inp.inputs[0]] + elif isinstance(inp, MultiOutput): + # This can be two things: + # 1. Out-of-place multi-output coll + # 2. In-place coll with inputs coming from another MultiOutput + coll = inp.inputs[0] + # Case 1 + if isinstance(coll, _CollectiveKernel): + _, idx = inp.indices[0] + return [coll.inputs[idx]] + # Case 2 + return [] + else: + # In-place requires no additional deps handling for volatile + # reads since the inputs are mutated. + return [] + + @classmethod + def create_wait(cls, kernel, inp: TensorBox) -> None: + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inp) + assert not unbacked_bindings, f"{kernel} {unbacked_bindings}" + packed = cls( + NoneLayout(inp.get_device()), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.mutation_outputs.append( + MutationOutput(NoneLayout(inp.get_device()), inp, packed) + ) + + def get_read_writes(self): + read_writes = super().get_read_writes() + # See [Out-of-Place Collective Safety]. + volatile_reads = self.get_volatile_reads() + for vr in volatile_reads: + read_writes.reads.add(dependencies.StarDep(vr.get_name())) + return read_writes + + +# NB: recursive structure here reflects val_to_arg_str, avoid +# calling free_unbacked_symbols on "exotic" types that don't get pexpr +# treatment +def maybe_free_unbacked_symbols(s: object) -> OrderedSet[Symbol]: + if isinstance(s, (SymTypes, Expr)): + # This branch should be impossible in return position + return free_unbacked_symbols(s) + elif isinstance(s, (tuple, list)): + r: OrderedSet[sympy.Symbol] = OrderedSet() + for t in s: + r |= maybe_free_unbacked_symbols(t) + return r + elif isinstance(s, torch.Tensor): + # This branch is impossible in constant-args position + return free_unbacked_symbols(s) + else: + return OrderedSet() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/jagged_lowerings.py b/.venv/lib/python3.11/site-packages/torch/_inductor/jagged_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..c96c9f4ae2d443779f7514e3077287c327c92123 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/jagged_lowerings.py @@ -0,0 +1,264 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import List, Optional, Tuple, Union + +import sympy + +import torch + +from .ir import Pointwise, TensorBox +from .lowering import fallback_handler, is_integer_type, register_lowering +from .virtualized import ops + + +# pyre-ignore[2,3] +def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len): + # jagged_len + 1 is used as the upper bound, + # because the last sequence length may be zero + begin_idx = ops.indirect_indexing( + offsets_loader([batch_idx]), + jagged_len + 1, + ) + end_idx = offsets_loader([batch_idx + 1]) + jagged_idx = begin_idx + seq_idx + return jagged_idx, end_idx + + +def get_inverse_offsets( + offsets: TensorBox, + jagged_len: Union[int, sympy.Expr], + realize: bool = True, +) -> TensorBox: + """ + Returns "inverse_offsets" - the inverse of the offsets array. + offsets maps batch index (dense) to jagged index (i.e. offset into jagged tensor). + inverse_offsets maps jagged index to batch index. + + e.g. for offsets [0, 3, 4, 9, 10] this will return + inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3] + + For the given offsets, the computed inverse_offsets are cached + on the first call and reused in the further calls. + """ + + if hasattr(offsets, "inverse_offsets"): + # inverse_offsets are already computed + # for these offsets: can reuse + return offsets.inverse_offsets + + # ops.bucketize takes offsets.get_name() which doesn't exist on Pointwise + # kernels, i.e. we need to realize it before using. In other words, we need + # offsets to be in global memory so that we can binary search over the + # entire tensor + offsets.realize() + device: torch.device = offsets.get_device() + dtype: torch.dtype = offsets.get_dtype() + + # pyre-ignore[2,3] + def inner_fn(index): + idx = index[0] + bucket = ops.bucketize( + values=ops.index_expr(idx, dtype), + offsets_name=offsets.get_name(), + offsets_size=offsets.get_size()[0], + indexing_dtype=dtype, + right=True, + ) + # ops.bucketize above returns 1-based bucket indices, + # but we need 0-based, hence we subtract 1 from batch + return bucket - 1 + + inverse_offsets = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[jagged_len], + ) + + if realize: + # "freeze" the node so that it doesn't get inlined downstream. + inverse_offsets.realize() + + # cache inverse_offsets for further reuse + offsets.inverse_offsets = inverse_offsets # type: ignore[attr-defined] + + return inverse_offsets + + +def jagged_idx_to_dense_idx( + jagged_idx, # pyre-ignore[2] + inverse_offsets_loader, # pyre-ignore[2] + offsets_loader, # pyre-ignore[2] + batch_size: Union[int, sympy.Expr], + max_seq_len: Union[int, sympy.Expr], + offsets_dtype: torch.dtype, +) -> Tuple[sympy.Expr, sympy.Expr]: + batch_idx = ops.indirect_indexing( + inverse_offsets_loader([jagged_idx]), + batch_size + 1, + ) + batch_start = offsets_loader([batch_idx]) + seq = ops.index_expr(jagged_idx, offsets_dtype) - batch_start + # check=False because there may be sequences longer than max_seq_len + seq_idx = ops.indirect_indexing(seq, max_seq_len, check=False) + return batch_idx, seq_idx + + +def register_jagged_ops(): + # pyre-ignore[56] + @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default) + def _jagged_to_padded_dense_forward( + jagged_values: TensorBox, + jagged_offsets: List[TensorBox], + max_lengths: List[int], # list of ints/SymInts + padding_value: float = 0.0, + ) -> TensorBox: + device = jagged_values.get_device() + dtype = jagged_values.get_dtype() + + jagged_values_size = jagged_values.get_size() + + # only handle the common case of a single jagged dimension + if ( + len(jagged_offsets) != 1 + or device.type != "cuda" + or device != jagged_offsets[0].get_device() + or len(jagged_values_size) != 2 + or len(jagged_offsets[0].get_size()) != 1 + or len(max_lengths) != len(jagged_offsets) + or not is_integer_type(jagged_offsets[0]) + ): + return fallback_handler( + torch.ops.aten._jagged_to_padded_dense_forward.default, + add_to_fallback_set=False, + )( + jagged_values, + jagged_offsets, + max_lengths, + padding_value, + ) + + offsets: TensorBox = jagged_offsets[0] + offsets_len = offsets.get_size()[0] + offsets_dtype = offsets.get_dtype() + batch_size = offsets_len - 1 + max_seq_len = max_lengths[0] + embedding_len = jagged_values_size[1] + jagged_len = jagged_values_size[0] + + output_size = [batch_size, max_seq_len, embedding_len] + + values_loader = jagged_values.make_loader() + offsets_loader = offsets.make_loader() + + # pyre-ignore[2,3,53] + def inner_fn(index): + # dense tensor size: [B, N, D] + batch_idx, seq_idx, emb_idx = index + jagged_idx, end_idx = dense_idx_to_jagged_idx( + batch_idx=batch_idx, + seq_idx=seq_idx, + offsets_loader=offsets_loader, + jagged_len=jagged_len, + ) + return ops.masked( + ops.lt( + ops.index_expr(jagged_idx, offsets_dtype), + end_idx, + ), + lambda: values_loader([jagged_idx, emb_idx]), + padding_value, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=output_size, + ) + + def _dense_to_jagged_forward_impl( + fallback_op, # pyre-ignore[2] + dense: TensorBox, + jagged_offsets: List[TensorBox], + jagged_len: Optional[int] = None, + ) -> TensorBox: + device = dense.get_device() + dtype = dense.get_dtype() + + dense_size = dense.get_size() + + # only handle the common case of a single jagged dimension + if ( + len(jagged_offsets) != 1 + or device.type != "cuda" + or device != jagged_offsets[0].get_device() + or len(jagged_offsets[0].get_size()) != 1 + or len(dense_size) != 3 + or jagged_len is None + or not is_integer_type(jagged_offsets[0]) + ): + return fallback_handler(fallback_op, add_to_fallback_set=False)( + dense, + jagged_offsets, + jagged_len, + ) + + offsets: TensorBox = jagged_offsets[0] + offsets_dtype = offsets.get_dtype() + batch_size = dense_size[0] + max_seq_len = dense_size[1] + embedding_len = dense_size[-1] + + output_size = [jagged_len, embedding_len] + + dense_loader = dense.make_loader() + offsets_loader = offsets.make_loader() + + inverse_offsets = get_inverse_offsets( + offsets=offsets, + jagged_len=jagged_len, + ) + inverse_offsets_loader = inverse_offsets.make_loader() + + # pyre-ignore[2,3,53] + def inner_fn(index): + # jagged tensor size: [sum_B(N_B), D] + jagged_idx, emb_idx = index + batch_idx, seq_idx = jagged_idx_to_dense_idx( + jagged_idx=jagged_idx, + offsets_loader=offsets_loader, + inverse_offsets_loader=inverse_offsets_loader, + batch_size=batch_size, + max_seq_len=max_seq_len, + offsets_dtype=offsets_dtype, + ) + return ops.masked( + ops.lt( + ops.index_expr(seq_idx, offsets_dtype), + ops.index_expr(max_seq_len, offsets_dtype), + ), + lambda: dense_loader([batch_idx, seq_idx, emb_idx]), + 0.0, # jagged sequence longer than max_seq_len + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=output_size, + ) + + # pyre-ignore[56] + @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward) + def _dense_to_jagged_forward( + dense: TensorBox, + jagged_offsets: List[TensorBox], + jagged_len: Optional[int] = None, + ) -> TensorBox: + return _dense_to_jagged_forward_impl( + fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default, + dense=dense, + jagged_offsets=jagged_offsets, + jagged_len=jagged_len, + ) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py b/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..adf921913edfcdfc902ab70307edae5bf6c46528 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py @@ -0,0 +1,6449 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import itertools +import logging +import math +import operator +import os +import warnings +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from unittest.mock import patch + +import sympy + +import torch +import torch.ao.quantization.fx._decomposed +import torch.fx +import torch.utils._pytree as pytree +from torch._higher_order_ops.associative_scan import associative_scan_op +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation +from torch._prims_common import ( + canonicalize_dim, + canonicalize_dims, + check, + dtype_to_type, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + get_computation_dtype, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, + Number, +) +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.utils._sympy.functions import ( + CeilDiv, + FloorDiv, + Identity, + IntTrueDiv, + ModularIndexing, +) + +from .._dynamo.utils import import_submodule +from . import config, inductor_prims, ir, test_operators # NOQA: F401 +from .decomposition import decompositions, get_decompositions +from .ir import ( + DtypeView, + ExpandView, + IndexingConstant, + is_triton, + ops_wrapper, + PermuteView, + Pointwise, + Reduction, + SqueezeView, + TensorBox, + validate_ir, + View, +) +from .utils import ( + ceildiv, + decode_device, + is_dynamic, + is_gpu, + is_pointwise_use, + needs_fallback_due_to_atomic_add_limitations, + pad_listlike, + sympy_product, + use_scatter_fallback, +) +from .virtualized import ops, V + + +log = logging.getLogger(__name__) +lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {} +# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints +_maybe_layout_constraints: Dict[ + torch._ops.OpOverload, Optional[Callable[..., Any]] +] = {} +fallbacks: Set[torch._ops.OpOverload] = set() +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +needs_realized_inputs: Set[torch._ops.OpOverload] = set() +foreach_ops: Set[torch._ops.OpOverload] = set() +inplace_foreach_ops: Set[torch._ops.OpOverload] = set() +inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {} +quantized_decomposed = torch.ops.quantized_decomposed + + +def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]: + """Get layout constraints. Returns None if there are no layout constraints.""" + if not isinstance(fn, torch._ops.OpOverload): + # Only OpOverloads have layout constraints. + return None + if fn in _maybe_layout_constraints: + return _maybe_layout_constraints[fn] + # OpOverload with custom lowerings override tag-based layout constraints + if fn in lowerings: + _maybe_layout_constraints[fn] = None + return None + # We lazily register tag-based layout constraints. + + def handle_layout_constraint_tag(tag): + if tag is torch._C.Tag.needs_fixed_stride_order: + _maybe_layout_constraints[fn] = constrain_to_fx_strides + return _maybe_layout_constraints[fn] + elif tag is torch._C.Tag.flexible_layout: + _maybe_layout_constraints[fn] = None + return None + else: + raise AssertionError(f"Unknown layout constraint tag: {tag}") + + tag = get_layout_constraint_tag(fn) + return handle_layout_constraint_tag(tag) + + +def get_layout_constraint_tag(fn): + tags_by_priority = [ + torch._C.Tag.needs_fixed_stride_order, + torch._C.Tag.flexible_layout, + ] + for tag in tags_by_priority: + if tag in fn.tags: + return tag + return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) + + +def assert_nyi(cond, msg): + if not cond: + raise NotImplementedError(f"inductor does not support {msg}") + + +def add_needs_realized_inputs(fn): + if isinstance(fn, (list, tuple, set)): + return [add_needs_realized_inputs(x) for x in fn] + needs_realized_inputs.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + needs_realized_inputs.update( + getattr(fn, overload) for overload in fn.overloads() + ) + + +def add_layout_constraint(fn, constraint): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + _maybe_layout_constraints[getattr(fn, overload)] = constraint + else: + _maybe_layout_constraints[fn] = constraint + + +add_needs_realized_inputs( + [ + aten.as_strided, + aten.as_strided_copy, + aten.avg_pool2d, + aten.avg_pool2d_backward, + aten.bmm, + aten.convolution, + aten.convolution_backward, + aten.max_pool2d_with_indices, + aten.max_pool2d_with_indices_backward, + aten.mm, + aten.upsample_nearest2d, + aten._upsample_nearest_exact2d, + aten._int_mm, + ] +) + +# TODO(jansel): ezyang says we won't need this in the future, try removing it +# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28 +DTYPE_ID_LOOKUP = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.int16, + 3: torch.int32, + 4: torch.int64, + 5: torch.float16, + 6: torch.float32, + 7: torch.float64, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex32, + 11: torch.bool, + 15: torch.bfloat16, + # TODO(jansel): add quantized types? + # _(c10::qint8, QInt8) /* 12 */ + # _(c10::quint8, QUInt8) /* 13 */ + # _(c10::qint32, QInt32) /* 14 */ + # _(c10::quint4x2, QUInt4x2) /* 16 */ + # _(c10::quint2x4, QUInt2x4) /* 17 */ +} + + +def decode_dtype(dtype: int): + if not isinstance(dtype, int): + return dtype + assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP" + dtype = DTYPE_ID_LOOKUP[dtype] + return dtype + + +def is_integer_type(x): + if isinstance(x, TensorBox): + return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + elif isinstance(x, sympy.Expr): + return x.is_integer is True # type: ignore[attr-defined] + else: + return isinstance(x, int) + + +def is_boolean_type(x): + if isinstance(x, TensorBox): + return is_boolean_dtype(x.get_dtype()) + else: + return isinstance(x, bool) + + +def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): + def construct_input(inp): + if isinstance(inp, (Number, sympy.Basic)): + return inp + else: + assert hasattr(inp, "get_dtype") + dim = len(inp.get_size()) + # construct a tmp tensor to feed into torch.result_type + return torch.zeros([1] * dim, dtype=inp.get_dtype()) + + inps = [construct_input(arg) for arg in args] + _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind) + return dtype + + +def get_overloads(aten_fn): + if not isinstance(aten_fn, (list, tuple)): + aten_fn = [aten_fn] + else: + aten_fn = list(aten_fn) + + for fn in list(aten_fn): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + if other_fn not in lowerings: + aten_fn.append(other_fn) + + return aten_fn + + +def in_namespace(op, namespace): + if isinstance(op, torch._ops.OpOverloadPacket): + return namespace in op._qualified_op_name + elif isinstance(op, torch._ops.OpOverload): + return namespace in op.name() + return False + + +def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool): + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + if (type_promotion_kind or convert_input_to_bool) and indices: + if convert_input_to_bool: + dtype = torch.bool + else: + # FIXME that's a crude approximation for promoting args + promoting_args = [ + a + for a in args + if isinstance(a, (Number, sympy.Basic)) + or getattr(a, "dtype", None) is not None + ] + dtype = get_promoted_dtype( + *promoting_args, type_promotion_kind=type_promotion_kind + ) + + # sometimes args are an immutable list so we can't mutate them + def promote(arg): + if isinstance(arg, TensorBox): + return to_dtype(arg, dtype) + elif isinstance(arg, ir.Constant): + return ir.Constant(arg.value, dtype, args[indices[0]].get_device()) + else: + return arg + + args = [promote(a) for a in args] + if broadcast and indices: + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + + return args + + +def _register_foreach_lowering(aten_fn, decomp_fn): + """ + Add a foreach lowering to lowerings dict. + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + assert len(args) <= 2 + out = decomp_fn(*args, **kwargs) + validate_ir(out) + return out + + aten_fns = get_overloads(aten_fn) + foreach_ops.update(aten_fns) + lowerings.update(dict.fromkeys(aten_fns, wrapped)) + return wrapped + + +def _register_lowering( + aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool +): + """ + Add a lowering to lowerings dict + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + args: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]] = list(args) + unpacked = False + # TODO maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = args[0] + + # kwargs tensors not supported yet unless it's a fallback op + if not all( + (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn + ): + assert not any(isinstance(x, TensorBox) for x in kwargs.values()) + # explicitly assert for "out=" ops for better error messages + assert not any( + x == "out" for x in kwargs.keys() + ), "out= ops aren't yet supported" + + args = transform_args( + args, broadcast, type_promotion_kind, convert_input_to_bool + ) + + if unpacked: + args = [args] + + out = decomp_fn(*args, **kwargs) + validate_ir(out) + + return out + + aten_fn = get_overloads(aten_fn) + + lowerings.update(dict.fromkeys(aten_fn, wrapped)) + return wrapped + + +def register_lowering( + aten_fn, + broadcast=False, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, +): + """ + Shim to support decorator syntax. + """ + return functools.partial( + _register_lowering, + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + ) + + +def broadcast_symbolic_shapes(a, b): + """ + Broadcasting logic based on symbolic shapes. + + We give the shapes 0 and 1 concrete values, while all other shapes + are symbolic sympy formulas. + """ + output = [] + for x, y in itertools.zip_longest( + reversed(a), reversed(b), fillvalue=sympy.Integer(1) + ): + if y == 1: + output.append(x) + elif x == 1: + output.append(y) + else: + V.graph.sizevars.guard_equals(x, y) + if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): + output.append(y) # prefer shorter formula + else: + output.append(x) + return tuple(reversed(output)) + + +def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): + assert ( + override_return_dtype is None or type_promotion_kind is None + ), "only one of override_return_dtype or type_promotion_kind may be given" + + if override_return_dtype is None and type_promotion_kind is None: + type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + + if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs): + return inputs + if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs): + dtype = override_return_dtype or get_promoted_dtype( + *inputs, type_promotion_kind=type_promotion_kind + ) + + def const_func(x): + if isinstance(x, sympy.Basic): + return ir.IndexingConstant(x, dtype, decode_device(None)) + else: + return ir.Constant(x, dtype, decode_device(None)) + + return [const_func(x) for x in inputs] + ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant))) + out = [] + for x in inputs: + if isinstance(x, (int, float)): + out.append( + ExpandView.create( + ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) + ) + ) + elif isinstance(x, sympy.Basic): + out.append( + ExpandView.create( + IndexingConstant(x, ex.get_dtype(), ex.get_device()), + list(ex.get_size()), + ) + ) + else: + out.append(x) + + return out + + +def make_pointwise( + fn, + override_return_dtype=None, + override_device=None, + override_fn_when_input_bool=None, + override_fn_when_gpu_float64=None, + allow_alpha=False, + triton_fallback=None, +): + def inner(*inputs: List[TensorBox], alpha=None): + if triton_fallback is not None and any(map(is_triton, inputs)): + assert not allow_alpha # not implemented + return triton_fallback(*inputs) + + inputs = promote_constants(inputs, override_return_dtype) + if allow_alpha: + if alpha is not None and alpha != 1: + inputs = list(inputs) + inputs[-1] = mul(inputs[-1], alpha) + else: + assert alpha is None + loaders = [x.make_loader() for x in inputs] + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + is_gpu_device = is_gpu(decode_device(inputs[0].get_device()).type) + + for other in inputs[1:]: + assert isinstance(other, ir.BaseConstant) or len(ranges) == len( + other.get_size() + ), f"ndim mismatch {fn} {ranges} {other.get_size()}" + + # in tracing, we will annotate pointwise nodes that correspond to the output of + # a pointwise node that would have been run in eager. intermediary pointwise nodes + # during decompositions are not annotated. + emulate_precision_casts = ( + V.graph is not None + and getattr(V.graph, "current_node", None) is not None + and V.graph.current_node.meta is not None + and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False) + and dtype in (torch.bfloat16, torch.float16) + ) + + def inner_fn(index): + assert len(index) == len(ranges), f"wrong ndim {index} {ranges}" + if dtype == torch.bool and override_fn_when_input_bool is not None: + return override_fn_when_input_bool(*[load(index) for load in loaders]) + elif ( + override_fn_when_gpu_float64 + and is_gpu_device + and dtype == torch.float64 + ): + return override_fn_when_gpu_float64(*[load(index) for load in loaders]) + else: + inputs_loaded = [] + for load in loaders: + out = load(index) + if emulate_precision_casts: + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + out = ops.to_dtype(downcast, dtype) + inputs_loaded.append(out) + + out = fn(*inputs_loaded) + if emulate_precision_casts: + # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here, + # then upcasting again, to emulate casts that eager would do. + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + return ops.to_dtype(downcast, dtype) + return out + + if not override_device: + device = None + for i in inputs: + if is_gpu(i.get_device().type): + device = i.get_device() + break + if not device: + device = inputs[0].get_device() + + device = override_device or device + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + return inner + + +def make_foreach_pointwise(pw_fn, allow_alpha=False): + def inner(*inputs: List[List[TensorBox]], alpha=1): + # group by device, whether any of the inputs are dynamic, and whether their types match + # (proxy for type promotion) + def group_args(arg_pairs): + out = defaultdict(list) + for i, args in enumerate(arg_pairs): + use_foreach = ( + not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes + ) + device = None + for t in args: + if isinstance(t, TensorBox): + device = t.data.get_device() + break + assert ( + device is not None + ), "foreach op should have at least one tensor arg" + out[(device, use_foreach)].append((i, args)) + return out + + realize_outputs = ( + len(V.graph.current_node.users) == 0 + or V.graph.current_node.target in inplace_foreach_ops + ) + for node in V.graph.current_node.users: + for user in node.users: + if not (user.op == "call_function" and (user.target in foreach_ops)): + realize_outputs = True + + a_list_input = None + for input in inputs: + if isinstance(input, (list, tuple)): + a_list_input = input + break + assert ( + a_list_input is not None + ), "at least one input must be a list to a foreach op" + + # broadcast scalar inputs to match length of list inputs + broadcast_inputs = [] + for input in inputs: + if not isinstance(input, (list, tuple)): + broadcast_inputs.append([input] * len(a_list_input)) + else: + broadcast_inputs.append(input) + + groups = group_args(zip(*broadcast_inputs)) + + outputs = [None] * len(a_list_input) + for (device, use_foreach), group in groups.items(): + operation_list: List[str] = [] + for ( + output_ind, + args, + ) in group: + if allow_alpha: + output = pw_fn(*args, alpha=alpha) + else: + output = pw_fn(*args) + + outputs[output_ind] = output + + if ( + V.graph.has_feature(device, BackendFeature.FOREACH) + and use_foreach + and realize_outputs + ): + output.realize() + operation_list.append(output.get_operation_name()) + + if operation_list: + V.graph.register_operation_list(operation_list) + + assert all(x is not None for x in outputs) + return outputs + + return inner + + +def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): + src_dtype = x.get_dtype() + if src_dtype == dtype: + return clone(x) if copy else x + + def _to_dtype(x): + return ops.to_dtype(x, dtype, src_dtype=src_dtype) + + return make_pointwise(_to_dtype, override_return_dtype=dtype)(x) + + +@register_lowering(prims.convert_element_type, type_promotion_kind=None) +def _convert_element_type(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + if x.get_size(): + # Decompose since aa aten fallback is more friendly for c++ codegen. + # This decomposition doesn't work for empty tensor, which needs more investigation. + dst = empty_like(x, dtype=dtype) + ir.InplaceCopyFallback.create(dst, x) + return dst + else: + return fallback_handler( + prims.convert_element_type.default, add_to_fallback_set=False + )(x, dtype) + return to_dtype(x, dtype, copy=True) + + +def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False): + x_dtype = x.get_dtype() + if x_dtype == dtype: + return clone(x) if copy else x + + def _get_primitive_bitwidth(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits + else: + return torch.iinfo(dtype).bits + + src_bits = _get_primitive_bitwidth(x_dtype) + dst_bits = _get_primitive_bitwidth(dtype) + if src_bits != dst_bits: + # fallback to aten eager implementation for differing bitwidths + return fallback_handler(aten.view.dtype)(x, dtype) + else: + return TensorBox(DtypeView.create(x, dtype)) + + +@register_lowering(aten.view.dtype, type_promotion_kind=None) +def _view_dtype(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + return TensorBox.create( + ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype) + ) + return to_dtype_bitcast(x, dtype) + + +def to_device(x: TensorBox, device: torch.device, *, copy=False): + device = decode_device(device) + if x.get_device() == device: + return clone(x) if copy else x + return TensorBox.create(ir.DeviceCopy.create(x, device)) + + +@register_lowering(prims.device_put, type_promotion_kind=None) +def _device_put(x: TensorBox, device: torch.device): + return to_device(x, device, copy=True) + + +def register_pointwise( + aten_fn, + name=None, + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + override_return_dtype=None, + override_fn_when_input_bool=None, + allow_alpha=False, + use_libdevice_for_f64=False, + triton_fallback=None, +): + """A pointwise function that maps ops.{name} to inputs""" + name = name or aten_fn.__name__ + fn = ops_wrapper(name) + if use_libdevice_for_f64: + fn_libdevice = ops_wrapper("libdevice_" + name) + if override_fn_when_input_bool is not None: + override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) + + fn = make_pointwise( + fn, + override_return_dtype=override_return_dtype, + override_fn_when_input_bool=override_fn_when_input_bool, + override_fn_when_gpu_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined] + allow_alpha=allow_alpha, + triton_fallback=triton_fallback, + ) + fn = register_lowering( + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + convert_input_to_bool=convert_input_to_bool, + )(fn) + return fn + + +def register_frexp(): + """A pointwise function that maps ops.frexp to inputs""" + name = "frexp" + frexp = ops_wrapper("frexp") + + def frexp0(*args, **kwargs): + return frexp(*args, **kwargs)[0] # type: ignore[index] # next PR + + def frexp1(*args, **kwargs): + return frexp(*args, **kwargs)[1] # type: ignore[index] # next PR + + pw_fns = [ + make_pointwise(frexp0), + make_pointwise(frexp1, override_return_dtype=torch.int32), + ] + + def fn(*args, **kwargs): + return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs) + + fn = register_lowering( + aten.frexp, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + )(fn) + return fn + + +register_frexp() + + +def register_foreach_pointwise( + aten_fn, + pointwise_lowering_fn, + allow_alpha=False, +): + fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha) + fn = _register_foreach_lowering(aten_fn, fn) + return fn + + +@register_lowering(aten.where, broadcast=False, type_promotion_kind=None) +def where(cond, a, b): + def fn(*args): + return ops.where(*args) + + if isinstance(a, (float, int)): + a = constant_like(a)(b) + if isinstance(b, (float, int)): + b = constant_like(b)(a) + + args = [cond, a, b] + dtype = get_promoted_dtype( + args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + return make_pointwise(fn, override_return_dtype=dtype)( + args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) + ) + + +@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) +def broadcast_tensors(*inputs): + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + return broadcast_tensors(*inputs[0]) + target: List[sympy.Expr] = functools.reduce( + broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] + ) + outputs = [] + for x in inputs: + sizes = x.get_size() + if len(sizes) != len(target) or any( + ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target) + ): + x = expand(x, target) + outputs.append(x) + return outputs + + +@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of]) +def nop(x): + return x # AOT autograd handles this for us + + +if hasattr(aten, "lift_fresh"): + register_lowering(aten.lift_fresh)(nop) + + +@register_lowering(aten.squeeze, type_promotion_kind=None) +def squeeze(x, dim=None): + assert isinstance(x, TensorBox) + if dim is None: + return TensorBox(SqueezeView.create(x.data)) + + dim = ( + V.graph.sizevars.evaluate_static_shape(dim) + if isinstance(dim, (int, sympy.Expr)) + else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim) + ) + dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] + dims = set((dim,) if not isinstance(dim, tuple) else dim) + + new_shape = [] + for d, s in enumerate(x.get_size()): + if not (d in dims and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1))): + new_shape.append(s) + + # squeeze does nothing if the size isn't 1 + return view(x, new_shape) if new_shape != x.get_size() else x + + +@register_lowering(aten.squeeze_copy, type_promotion_kind=None) +def squeeze_copy(x, dim=None): + return clone(squeeze(x, dim)) + + +@register_lowering([aten.squeeze_]) +def squeeze_(x, dim=None): + val = squeeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +@register_lowering(aten.isinf) +def isinf(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isinf") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.isnan) +def isnan(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isnan") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.ceil) +def ceil(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("ceil") + return make_pointwise(fn)(x) + + +@register_lowering(aten.floor) +def floor(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("floor") + return make_pointwise(fn)(x) + + +@register_lowering(aten.round.default) +def round(x): + if is_integer_type(x): + return clone(x) + else: + fn = ops_wrapper("round") + return make_pointwise(fn)(x) + + +@register_lowering(aten.trunc) +def trunc(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("trunc") + return make_pointwise(fn)(x) + + +@register_lowering(aten.expand, type_promotion_kind=None) +def expand(x, sizes): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + (x,) = promote_constants([x]) + if isinstance(x, ir.BaseConstant): + return ExpandView.create(x, tuple(sizes)) + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + if tuple(x.get_size()) == tuple(sizes): + return x + + if not free_unbacked_symbols(x.get_size()): + x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size())) + # TODO: It would be better to realize the input if any of its sizes + # are unbacked, because typically the size will be non-zero. However, + # this cannot be done directly as below as we'll choke on the size_hint + # here + if x_size_product > 0 and not free_unbacked_symbols(sizes): + # maybe realize input before broadcasting it + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product + ) + return TensorBox(ExpandView.create(x.data, tuple(sizes))) + + +@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None) +def broadcast_in_dim(a, shape, broadcast_dimensions): + s = list(shape) + for broadcast_dimension in broadcast_dimensions: + s[broadcast_dimension] = -1 + + v = a + for idx, x in enumerate(s): + if x != -1: + v = unsqueeze(v, idx) + + return expand(v, shape) + + +@register_lowering(aten.expand_as, type_promotion_kind=None) +def expand_as(x, y): + return expand(x, y.get_size()) + + +@register_lowering(aten.repeat) +def repeat(x, repeats): + old_size = list(x.get_size()) + if len(repeats) > len(old_size): + old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size + x = view(x, list(old_size)) + assert len(repeats) == len(x.get_size()) + + new_size = list(x.get_size()) + + zero_tensor = False + for i in range(len(repeats)): + if repeats[i] == 0: + zero_tensor = True + new_size[i] = new_size[i] * repeats[i] + + if zero_tensor: + return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) + if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): + return clone(expand(x, new_size)) + + x_loader: Callable[[Any], Any] + + def inner_fn(index): + assert len(index) == len(repeats) + index = list(index) + for i in range(len(repeats)): + if repeats[i] != 1: + if old_size[i] == 1: + index[i] = sympy.Integer(0) + else: + index[i] = ModularIndexing(index[i], 1, old_size[i]) + return x_loader(index) + + old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) + if old_size_product > 0: + # maybe realize the input + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product + ) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(new_size), + ) + + +@register_lowering(aten._unsafe_view, type_promotion_kind=None) +@register_lowering(aten.view, type_promotion_kind=None) +@register_lowering(aten.reshape, type_promotion_kind=None) +def view(x, sizes): + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + return TensorBox(View.create(x.data, sizes)) + + +@register_lowering(aten.permute, type_promotion_kind=None) +def permute(x, dims): + assert isinstance(x, TensorBox) + assert isinstance(dims, (list, tuple)) + return TensorBox(PermuteView.create(x.data, tuple(dims))) + + +@register_lowering(aten.slice, type_promotion_kind=None) +def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True): + assert isinstance(x, TensorBox) + dim = _validate_dim(x, dim, 0) + return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)) + + +@register_lowering(aten.as_strided, type_promotion_kind=None) +def as_strided(x, size, stride, storage_offset=None): + if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView): + # as_strided ignores views + x = x.data.unwrap_view() + x.realize() + if not ir.is_storage_and_layout(x): + raise NotImplementedError(f"unrealized as_strided({x}, ...)") + storage, old_layout = ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + [sympy.expand(s) for s in size], + [sympy.expand(s) for s in stride], + sympy.expand(storage_offset or 0), + ) + return TensorBox(ir.ReinterpretView(storage, new_layout)) + + +@register_lowering(aten.as_strided_, type_promotion_kind=None) +def as_strided_(x, size, stride, storage_offset=None): + assert isinstance(x, TensorBox) + x.data = as_strided(x, size, stride, storage_offset).data + return x + + +@register_lowering(aten.as_strided_copy, type_promotion_kind=None) +def as_strided_copy(x, size, stride, storage_offset=None): + result = as_strided(x, size, stride, storage_offset) + return clone(result) + + +def pointwise_cat(inputs, dim=0): + # (inclusive, exclusive) + inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = [] + prev_end = 0 + for inp in inputs: + inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] + prev_end = inputs_ranges[-1][-1] # type: ignore[assignment] + + inputs_loaders = [inp.make_loader() for inp in inputs] + + def inner_fn(idx): + idx_dim = ops.index_expr(idx[dim], torch.int64) + + masks = [] + masked_loads = [] + for i in range(len(inputs)): + start = ( + ops.constant(0, torch.int64) + if i == 0 + else ops.index_expr(inputs_ranges[i][0], torch.int64) + ) + end = ops.index_expr(inputs_ranges[i][1], torch.int64) + + start_cond = ops.ge(idx_dim, start) + end_cond = ops.lt(idx_dim, end) + if i == 0: + mask = end_cond + elif i == len(inputs) - 1: + mask = start_cond + else: + mask = ops.and_(start_cond, end_cond) + + masks.append(mask) + idx_load = list(idx) + + # if we're concatting [4], [2] + # when we index the second tensor for 5 we want to index 5 - 4 + # Use Identity to prevent expansion of index * stride to keep expression + # in same int bitwidth as shape + idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0]) + + masked_loads.append( + ops.masked( + mask, + lambda: inputs_loaders[i](idx_load), + 0.0, # this value should be unused + ), + ) + + next_val = masked_loads[-1] + for i in range((len(inputs)) - 2, -1, -1): + next_val = ops.where( + masks[i], + masked_loads[i], + next_val, + ) + return next_val + + new_size = list(inputs[0].get_size()) + new_size[dim] = inputs_ranges[-1][-1] + + return Pointwise.create( + device=inputs[0].get_device(), + dtype=inputs[0].get_dtype(), + inner_fn=inner_fn, + ranges=new_size, + ) + + +@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None) +def quantized_decomposed_quantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.int32: + zero_point = ops.to_dtype(zero_point, torch.int32) + inv_scale = ops.reciprocal(scale) + val = ops.round(input * inv_scale) + zero_point + clamped = ops.maximum(qmin, ops.minimum(qmax, val)) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_channel, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.float32: + zero_point = ops.to_dtype(zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None +) +def quantized_decomposed_quantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=functools.partial( + inner_fn, scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=functools.partial( + inner_fn, scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_quantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( + len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.round(input * ops.reciprocal(_scale)) + _zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( + len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering(aten.cat) +def cat(inputs, dim=0): + cpu_device = inputs[0].get_device().type == "cpu" + if cpu_device and all( + input.get_dtype() in [torch.int8, torch.uint8] for input in inputs + ): + # TODO Remove this fallback when we support vectorization + # code gen with uint8 data type directly. + for input in inputs: + input.realize() + if all(len(input.get_size()) == 4 for input in inputs): + inputs, _ = require_channels_last(aten.cat, *inputs) + return fallback_handler(aten.cat.default)(inputs, dim) + + if len(inputs) == 1: + return clone(inputs[0]) + + dim = _validate_dim(inputs[0], dim, 0) + dtype = get_promoted_dtype( + *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + inputs = [to_dtype(inp, dtype) for inp in inputs] + + def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode: + if isinstance(x, TensorBox): + if isinstance(x.data, ir.BaseView): + return x.data.unwrap_view() + else: + return x.data + + if isinstance(x, ir.StorageBox): + return x.data + + return x + + def is_reduction(t): + return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction) + + def can_fuse_reduction(t): + if isinstance(t, (TensorBox, ir.StorageBox)): + return can_fuse_reduction(unwrap_tensor(t)) + return ( + is_reduction(t) + or isinstance(t, ir.Pointwise) + and any( + can_fuse_reduction(V.graph.get_buffer(read)) + for read in t.get_read_names() + ) + ) + + # fusing reducutions into computed concat buffer can cause regressions. + fusable_reduction = any(can_fuse_reduction(t) for t in inputs) + + def should_lower_cat_input(x) -> bool: + # Unrealized inputs will not be storage and layouts, and we dont want to realize + # them in case we want to fuse + if ir.is_storage_and_layout(x): + storage, _ = ir.as_storage_and_layout(x, freeze=False) + return not ir.ConcatKernel.can_realize_into_without_copy(storage) + + if isinstance(x, (TensorBox, ir.StorageBox)): + return should_lower_cat_input(unwrap_tensor(x)) + + if isinstance(x, ir.Pointwise): + return True + + return False + + # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it. + # We will revisit this later after enabling vectorization on index_expr. + if cpu_device: + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + def op_count(x): + if isinstance(x, (TensorBox, ir.StorageBox)): + return op_count(unwrap_tensor(x)) + + # this will correspond to a direct memory read + if not isinstance(x, ir.Pointwise): + return 0 + + count = x.inner_fn_opcount().num_ops + for read in x.get_read_names(): + count += op_count(V.graph.get_buffer(read)) + + return count + + # as of inputs increase, possibility for register spilling also increases + # past a certain threshold of inputs we only fuse if the if the input kernels + # are simple + # not sure if we want to expose to users via config since logic may change in future + MAX_COMPLEX_POINTWISE_CAT = 8 + MAX_SIMPLE_OP_COUNT = 2 + + def additional_pointwise_ops(op: torch._ops.OpOverload): + return op in (aten.cat.default, aten.constant_pad_nd.default) + + if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or ( + (len(inputs) <= config.max_pointwise_cat_inputs) + and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs) + ): + pointwise_uses = all( + is_pointwise_use(use, additional_pointwise_ops) + for use in V.current_node.users + ) + # fuse in case we will be used in a pointwise node, and there are any inputs we + # we can prevent materialization of. + fuse_pointwise_use = ( + any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses + ) + + # horizontal fuse in case all inputs will require a copy kernel anyway. + # only horizontally fuse pointwise kernels + horizontal_fuse_cat = all( + should_lower_cat_input(inp) for inp in inputs + ) and not any(can_fuse_reduction(t) for t in inputs) + if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction): + return pointwise_cat(inputs, dim) + + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + +@register_lowering(aten.diagonal, type_promotion_kind=None) +def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + original_shape = input.get_size() + num_dims = len(original_shape) + dim1 = canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = canonicalize_dim(idx=dim2, rank=num_dims) + + check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0)) + if offset_negative: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1] + offset, original_shape[dim2] + ), + 0, # type: ignore[arg-type] + ) + else: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1], original_shape[dim2] - offset + ), + 0, # type: ignore[arg-type] + ) + + base_idx = (0, 0) + if offset_negative: + base_idx = (-offset, 0) + else: + base_idx = (0, offset) + + sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)] + sizes.append(diag_size) + + def reindexer(idx): + diag_idx = idx[-1] + original_idx = [0] * len(original_shape) + cur_dim = 0 + for d in range(num_dims): + if d == dim1: + original_idx[d] = diag_idx + base_idx[0] + elif d == dim2: + original_idx[d] = diag_idx + base_idx[1] + else: + original_idx[d] = idx[cur_dim] + cur_dim += 1 + + assert cur_dim == len(original_shape) - 2 + return original_idx + + return TensorBox(ir.GenericView.create(input, sizes, reindexer)) + + +@register_lowering(aten.diagonal_copy, type_promotion_kind=None) +def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + return clone(diagonal(input, offset, dim1, dim2)) + + +@register_lowering(aten.diagonal_scatter, type_promotion_kind=None) +def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): + output = clone(input) + target = diagonal(output, offset, dim1, dim2) + mutate_to(target, src) + return output + + +@register_lowering(aten.select, type_promotion_kind=None) +def select(x, dim, idx): + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) + + +@register_lowering(aten.split, type_promotion_kind=None) +def split(x, sizes, dim=0, clamp=True): + dim = _validate_dim(x, dim, 0) + if isinstance(sizes, sympy.Expr): + # TODO: We don't have to guard on sizes per se, but the number + # of splits must stay constant + sizes = V.graph.sizevars.evaluate_static_shape(sizes) + if isinstance(sizes, (int, sympy.Integer)): + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + sizes = [sizes] * ((x_size + sizes - 1) // sizes) + result = [] + start = 0 + for size in sizes: + end = start + size + result.append(slice_(x, dim, start, end, clamp=clamp)) + start = end + return result + + +@register_lowering(aten.split_with_sizes, type_promotion_kind=None) +def split_with_sizes(x, sizes, dim=0): + return split(x, sizes, dim, clamp=False) + + +@register_lowering(aten.unbind, type_promotion_kind=None) +def unbind(x, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + result = [] + for i in range(x_size): + result.append(select(x, dim, i)) + return result + + +@register_lowering(aten.unfold, type_promotion_kind=None) +def unfold(x, dimension, size, step): + sizes = x.get_size() + ndim = len(sizes) + dim = canonicalize_dim(ndim, dimension) + + if ndim == 0: + return slice_(unsqueeze(x, 0), end=size) + + dim_size = sizes[dim] + sizevars = V.graph.sizevars + sizevars.guard_leq(size, dim_size) + sizevars.guard_lt(0, step) # type: ignore[arg-type] + + new_dim_size = FloorDiv(dim_size - size, step) + 1 + if sizevars.size_hint(dim_size) > 0: + x.mark_reuse(sizevars.size_hint(CeilDiv(new_dim_size * size, dim_size))) + + out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size] + + def reindexer(idx): + dim_idx = idx[-1] + idx[dim] * step + return (*idx[:dim], dim_idx, *idx[dim + 1 : -1]) + + return TensorBox(ir.GenericView.create(x, out_size, reindexer)) + + +@register_lowering(aten.unsqueeze, type_promotion_kind=None) +def unsqueeze(x, dim): + dim = _validate_dim(x, dim, 1) + new_shape = list(x.get_size()) + new_shape.insert(dim, sympy.Integer(1)) + return view(x, new_shape) + + +@register_lowering(aten.unsqueeze_, type_promotion_kind=None) +def unsqueeze_(x, dim): + val = unsqueeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +def _validate_dim(x, dim, offset=0): + dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim)) + ndim = len(x.get_size()) + if dim < 0: + dim += ndim + offset + assert 0 <= dim < ndim + offset + return dim + + +@register_lowering(aten.glu) +def glu(x, dim=-1): + dim = _validate_dim(x, dim, 0) + # TODO: don't guard on static shape here + new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2 + a = slice_(x, dim, 0, new_len) + b = slice_(x, dim, new_len, new_len * 2) + return mul(a, sigmoid(b)) + + +def fallback_handler(kernel, add_to_fallback_set=True): + if add_to_fallback_set: + fallbacks.add(kernel) + + def handler(*args, **kwargs): + def wrap_tensors(x): + return TensorBox.create(x) if isinstance(x, ir.IRNode) else x + + return pytree.tree_map( + wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs) + ) + + return handler + + +@functools.lru_cache(None) +def _warn_complex_not_supported(): + warnings.warn( + "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." + ) + + +# There are some types (CPU) which we accept as input but not as +# output. +def unsupported_input_tensor(t: torch.Tensor, parent=None): + "Do not support reading or writing to this tensor" + if t.is_complex(): + # Complex views are supported with IR ComplexView + if parent and parent.target in ( + torch.ops.aten.view.dtype, + torch.ops.prims.convert_element_type.default, + ): + return False + _warn_complex_not_supported() + return True + return False + + +def unsupported_output_tensor(t: torch.Tensor, parent=None): + "Do not support writing tensor but can read from it" + if unsupported_input_tensor(t, parent): + return True + return t.is_cpu and config.disable_cpp_codegen + + +def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True): + # Custom fallback lowering + if node.target is aten.view_as_complex.default: + return False + + # We should be able to remove this special case once `disable_cpp_codegen` is killed. + if node.target is aten.lift_fresh_copy.default: + return False + + def check_skip_condition(node, parent, is_output): + if not isinstance(node, torch.fx.Node): + return False + + if "val" not in node.meta: + return False + + for meta in pytree.tree_leaves(node.meta["val"]): + if not isinstance(meta, torch._subclasses.FakeTensor): + continue + + if is_output: + if unsupported_output_tensor(meta, parent): + return True + else: + if unsupported_input_tensor(meta, parent): + return True + + return False + + # only skip codegen if there is a cpu output, not input + for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs): + if check_skip_condition(arg, node, is_output=False): + return True + + return check_skip_condition(node, node, is_output=True) + + +def make_fallback(op, layout_constraint=None, warn=True): + assert op not in decompositions, f"both a fallback and a decomp for same op: {op}" + if ( + warn + and bool(os.getenv("CI")) + and get_decompositions([op]) + # if fallback_random, we allow not decomposing random + and not ( + config.fallback_random + and op in torch._decomp.decompositions_for_rng.extra_random_decomps + ) + ): + # Note: 'warn' is holdover from when this was a warning, but for ops that previously + # set warn=False we do not want a CI error. + # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not + # likely to be triggered preferentially on one CI config over another. + if torch._dynamo.config.suppress_errors: + torch._dynamo.config.suppress_errors = False + log.warning( + "A make_fallback error occurred in suppress_errors config," + " and suppress_errors is being disabled to surface it." + ) + raise AssertionError( + f"make_fallback({op}): a decomposition exists, we should switch to it." + " To fix this error, either add a decomposition to core_aten_decompositions (preferred)" + " or inductor_decompositions, and delete the corresponding `make_fallback` line." + " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.", + ) + + def register_fallback(op_overload): + add_needs_realized_inputs(op_overload) + if layout_constraint is not None: + add_layout_constraint(op_overload, layout_constraint) + return register_lowering(op_overload, type_promotion_kind=None)( + fallback_handler(op_overload) + ) + + if isinstance(op, torch._ops.OpOverloadPacket): + for ol in op.overloads(): + op_overload = getattr(op, ol) + register_fallback(op_overload) + elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + register_fallback(op) + else: + raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}") + + +def philox_rand_offset(shape): + """ + TorchInductor offset calculation differs from PyTorch eager offset + calculation for random ops (tl.rand vs torch.rand). In future, we should + strive for same impl for tl.rand and torch.rand. + """ + numel = 1 + for s in shape: + numel = numel * s + return tensor(numel, dtype=torch.int64) + + +@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None) +def philox_rand(size, seed, offset, stride, device, dtype): + # stride arg is optional and will be used in future for distributed random + # ops. Currently, its unused. + random_pos = ir.FixedLayout( + device, + dtype, + size, + ir.FlexibleLayout.contiguous_strides(size), + ).make_indexer() + seed_loader = seed.make_loader() + offset_loader = offset.make_loader() + + def inner_fn(index): + # Both seed and offset in the philox_rand op are tensors. + # torch seed and offsets are of type int64, but tl.rand accepts int32 + seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32) + offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32) + # Get the offset'd position + rand_index_expr = ops.add( + ops.index_expr(random_pos(index), torch.int32), offset_index_expr + ) + result = ops.rand( + seed_index_expr, + rand_index_expr, + ) + return ops.to_dtype(result, dtype) + + random_values_node = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + offset_node = philox_rand_offset(size) + return random_values_node, offset_node + + +@register_lowering(aten.native_dropout, type_promotion_kind=None) +def native_dropout(x, p, train): + if config.fallback_random: + return pytree.tree_map( + TensorBox.create, + ir.FallbackKernel.create(aten.native_dropout.default, x, p, train), + ) + else: + raise AssertionError("should be handled in replace_random.py") + + +@register_lowering(aten.bernoulli_, type_promotion_kind=None) +def bernoulli_(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + x.realize() + op_overload = ( + aten.bernoulli_.float + if len(args) == 0 or isinstance(args[0], float) + else aten.bernoulli_.Tensor + ) + ir.InplaceBernoulliFallback(op_overload, x, *args) + return x + + +@register_lowering(aten.bernoulli.p, type_promotion_kind=None) +def bernoulli_p(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + return bernoulli_(clone(x), *args) + + +# This shouldn't be called in general +@register_lowering(aten._foobar) +def _foobar(_): + raise AssertionError + + +@functools.lru_cache(1) +def _warn_triton_random(salt): + log.info("using triton random, expect difference from eager") + + +def warn_triton_random(): + # only warn once per graph + _warn_triton_random(V.graph.creation_time) + + +fallback_rand_default = fallback_handler(aten.rand.default) +fallback_rand_generator = fallback_handler(aten.rand.generator) +fallback_randn_default = fallback_handler(aten.randn.default) +fallback_randn_generator = fallback_handler(aten.randn.generator) +make_fallback(aten.randint) + + +@register_lowering(aten.rand) +def rand(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_rand_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_rand_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(aten.randn) +def randn(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_randn_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_randn_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None) +def inductor_force_stride_order(input_tensor, stride): + stride_order = ir.get_stride_order(stride) + return ir.ExternKernel.require_stride_order(input_tensor, stride_order) + + +@register_lowering(inductor_prims.seed, type_promotion_kind=None) +def inductor_seed(device: torch.device): + raise AssertionError("should be handled in fuse_seed_creation_pass()") + + +@register_lowering(inductor_prims.seeds, type_promotion_kind=None) +def inductor_seeds(count, device): + warn_triton_random() + return TensorBox.create(ir.RandomSeeds(count, decode_device(device))) + + +@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None) +def inductor_lookup_seed(seeds, index): + def inner_fn(_): + return ops.load_seed(seeds.get_name(), index) + + return Pointwise.create( + device=seeds.get_device(), + dtype=seeds.get_dtype(), + inner_fn=inner_fn, + ranges=[], + ) + + +@register_lowering(inductor_prims.random, type_promotion_kind=None) +def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0): + assert not config.fallback_random + assert mode in ("rand", "randn") + size = [*size] + dtype = torch.float32 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return getattr(ops, mode)( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ) + + result = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + result.realize() + return result + + +@register_lowering(inductor_prims.randint, type_promotion_kind=None) +def inductor_randint( + low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0 +): + assert not config.fallback_random + size = [*size] + dtype = torch.int64 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return ops.randint64( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ops.index_expr(low, torch.int64), + ops.index_expr(high, torch.int64), + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + + +@register_lowering(aten.bucketize, type_promotion_kind=None) +def bucketize( + input: TensorBox, + boundaries: TensorBox, + *, + out_int32: bool = False, + right: bool = False, +): + assert len(boundaries.get_size()) == 1 + + if not ( + V.graph.has_feature(input, BackendFeature.BUCKETIZE) + and V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE) + ): + return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)( + input, boundaries, out_int32=out_int32, right=right + ) + + # The entire boundaries tensor needs to be used by ops.bucketize, so we + # need to realize it into global memory; or in other words, we can't + # guarantee that boundaries.get_name() (used below) will exist unless + # we call boundaries.realize(). + boundaries.realize() + boundaries_size = boundaries.get_size()[0] + device = input.get_device() + input_loader = input.make_loader() + + index_dtype = torch.int32 if out_int32 else torch.int64 + + def inner_fn(index): + val = input_loader(index) + indices = ops.bucketize( + val, + boundaries.get_name(), + boundaries_size, + index_dtype, + right, + ) + + return indices + + return Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +def require_dense(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs) + ) + return args, kwargs + + +def require_contiguous(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs) + ) + return args, kwargs + + +def require_channels_last(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs) + ) + return args, kwargs + + +def constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, ir.IRNode): + stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) + return ir.ExternKernel.require_stride_order(arg, stride_order) + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# TODO(jansel): we should implement decomps or lowerings for these +# https://github.com/pytorch/torchdynamo/issues/327 +FALLBACK_ALLOW_LIST = { + "torchvision::roi_align", +} + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension] + + def apply_constraint(arg, fx_arg): + if not isinstance(arg, ir.IRNode): + return arg + + meta_val = fx_arg.meta["val"] + meta_stride = meta_val.stride() + + stride_order = ir.get_stride_order(meta_stride) + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + if not meta_val.is_cuda: + return ir.ExternKernel.require_stride_order(arg, stride_order) + + # This is the minimum alignment required by SDPA kernels for attention_bias. + # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask + ALIGNMENT = 8 + + assert isinstance(arg, TensorBox) + if len(arg.get_size()) not in (3, 4): + return arg + + def is_aligned_realized_tensor(x): + aligned_strides = all( + (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0 + for i in range(len(x.get_stride()) - 1) + ) + return ( + V.graph.sizevars.size_hint(x.get_stride()[-1]) + ) == 1 and aligned_strides + + try: + arg.get_stride() + if is_aligned_realized_tensor(arg): + return V.graph.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride + ) + except AttributeError: + pass + + def is_aligned(x): + return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 + + if isinstance(arg.data, ir.BaseView): + if not is_aligned(arg): + if is_aligned(arg.unwrap_view()): + return V.graph.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride + ) + + return ir.ExternKernel.require_stride_order(arg, stride_order) + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# WIP +make_fallback(aten._adaptive_avg_pool3d) # @isuruf +make_fallback(aten.adaptive_max_pool3d) # @isuruf +make_fallback(aten.fractional_max_pool3d) # @isuruf +make_fallback(aten.max_pool3d_with_indices) # @isuruf (can this one be implemented?) + + +# 1) Easy +make_fallback(aten.uniform, warn=False) +make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py) +make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks +make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl? +make_fallback(aten.searchsorted) # bucketized is implemented (see eager impl) + + +# 1.5) Easy or Impossible +make_fallback(aten._cdist_forward) # p=2 should be feasible +make_fallback(aten._cdist_backward) + +# 2) Medium +make_fallback(aten.max_unpool2d) +make_fallback(aten.max_unpool3d) +make_fallback(aten._trilinear) + + +# 3) Difficult +# Scans +# See the discussion at +# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19 +make_fallback(aten.segment_reduce.default) +make_fallback(aten._segment_reduce_backward.default) + +# Histogram (need to implement Histogram IR) +make_fallback(aten.histc) +make_fallback(aten.histogram.bin_ct) +make_fallback(aten._histogramdd_bin_edges.default) +make_fallback(aten._histogramdd_from_bin_cts.default) + +# Need templated kernel +make_fallback(aten.addbmm) +make_fallback(aten._addmm_activation, warn=False) + +# Need templated kernel. Probably impossible to write efficiently +make_fallback(aten.convolution_backward, constrain_to_fx_strides) +make_fallback(aten._cudnn_rnn, require_dense) +make_fallback(aten._cudnn_rnn_backward, require_contiguous) + +# Haven't checked but sound difficult / impossible +make_fallback(aten._embedding_bag, require_contiguous) +make_fallback(aten._embedding_bag_forward_only, require_contiguous) +make_fallback(aten._embedding_bag_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._fused_moving_avg_obs_fq_helper) +make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) + + +# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp +make_fallback(aten.max_pool3d_with_indices_backward) +make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) +make_fallback(aten._adaptive_avg_pool3d_backward) +make_fallback(aten.adaptive_max_pool2d_backward) +make_fallback(aten.adaptive_max_pool3d_backward) +make_fallback(aten.fractional_max_pool2d_backward) +make_fallback(aten.fractional_max_pool3d_backward) +make_fallback(aten.replication_pad1d_backward) +make_fallback(aten.replication_pad2d_backward) +make_fallback(aten.upsample_linear1d_backward) +make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) +make_fallback(aten.upsample_trilinear3d_backward) +make_fallback(aten.grid_sampler_2d_backward, require_dense) +make_fallback(aten._pdist_backward) + + +# 5) Impossible (missing triton/CPU features) + +# Sorting / Sorting-like +make_fallback(aten.sort) +make_fallback(aten.sort.stable) +make_fallback(aten.kthvalue) +make_fallback(aten.topk) +make_fallback(aten.mode) +make_fallback(aten.median) +make_fallback(aten.nanmedian) +make_fallback(aten.randperm) +# see: https://github.com/pytorch/pytorch/pull/121354 +make_fallback(aten.resize_) +make_fallback(aten.resize_as_) + +# Linalg +make_fallback(aten._linalg_det) +make_fallback(aten.linalg_householder_product) +make_fallback(aten.linalg_inv_ex) +make_fallback(aten.linalg_ldl_factor_ex) +make_fallback(aten.linalg_ldl_solve) +make_fallback(aten.linalg_lu) +make_fallback(aten.linalg_lu_factor_ex) +make_fallback(aten.linalg_lu_solve) +make_fallback(aten.linalg_matrix_exp) +make_fallback(aten.linalg_qr) +make_fallback(aten._linalg_slogdet) +make_fallback(aten._linalg_solve_ex) +make_fallback(aten.linalg_solve_triangular) +make_fallback(aten._linalg_svd) +make_fallback(aten.lu_unpack) +make_fallback(aten.ormqr) +make_fallback(aten._linalg_check_errors) +make_fallback(aten.linalg_pinv.atol_rtol_tensor) +make_fallback(aten._linalg_eigh) +make_fallback(aten.triangular_solve) +make_fallback(aten.linalg_cholesky_ex) +make_fallback(aten.cholesky_inverse) +make_fallback(aten.cholesky_solve) +make_fallback(aten.geqrf) +make_fallback(aten._fft_r2c) # needs complex as well + +# Data dependent (are these necessary?) +make_fallback(aten.nonzero.default) + +# Misc +make_fallback(aten.gcd.default, warn=False) +make_fallback(aten._thnn_fused_lstm_cell, require_dense) +make_fallback(torch._prims.rng_prims.run_and_save_rng_state) +make_fallback(torch._prims.rng_prims.run_with_rng_state) + +# Implmented / Half implemented +# Scans. Implemented for CUDA, missing CPU +make_fallback(aten.masked_scatter) +make_fallback(aten.masked_scatter_backward) + +# Complex number support +make_fallback(aten.view_as_complex, require_contiguous) +make_fallback(aten.angle) # needs complex + +# Needs efficentzerotensor +make_fallback(aten._efficientzerotensor) + +# Needs Sparse +make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) +make_fallback(aten.to_sparse) +make_fallback(aten._to_sparse) + +# Needs dimname support +make_fallback(aten.zeros.names) + +# 6) Pattern-matched +make_fallback( + aten._scaled_dot_product_efficient_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_efficient_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback(aten._flash_attention_forward.default, sdpa_constraint) +make_fallback(aten._flash_attention_backward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_forward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_backward.default, sdpa_constraint) + +# index_reduce requires fallback when use_scatter_fallback(...) returns True +make_fallback(aten.index_reduce) + + +# Register with type_promotion_kind None. +# For example, fp16.copy_(fp32) should **not** promote the first input's dtype. +@register_lowering(aten.copy, type_promotion_kind=None) +def copy(self, src, non_blocking=False): + x = src + if self.get_device() != src.get_device(): + x = to_device(x, self.get_device()) + if self.get_dtype() != src.get_dtype(): + x = to_dtype(x, self.get_dtype()) + + if self.get_size() != src.get_size(): + out = expand(x, self.get_size()) + return clone(out) + return clone(x) + + +@register_lowering(aten.clone) +def clone(x, *, memory_format=None): + # TODO(jansel): memory format + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=list(x.get_size()), + ) + + +def clone_preserve_reinterpret_view(x): + reinterpret_view_layouts = [] + if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView): + x = x.data # unwrap TensorBox + while isinstance(x, ir.ReinterpretView): + reinterpret_view_layouts.append(x.get_layout()) + x = x.data + x = TensorBox(x) + + x = clone(x) + + if reinterpret_view_layouts: + x = x.data # unwrap TensorBox + for layout in reinterpret_view_layouts[::-1]: + x = ir.ReinterpretView(x, layout) + x = TensorBox(x) + + return x + + +if hasattr(aten, "lift_fresh_copy"): + register_lowering(aten.lift_fresh_copy)(clone) + + +@register_lowering(prims.iota) +def iota( + length, + *, + start, + step, + dtype, + device, + requires_grad, +): + def fn(index): + return ops.index_expr(step * index[0] + start, dtype=dtype) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=fn, + ranges=[length], + ) + + +@register_lowering(aten.select_scatter, type_promotion_kind=None) +def select_scatter(x, src, dim: int, index: int): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): + index = index + x.get_size()[dim] + V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + src = expand(unsqueeze(src, dim), x.get_size()) + src_loader = src.make_loader() + + def inner_fn(idx): + return ops.where( + ops.eq( + ops.index_expr(idx[dim], torch.int32), + ops.index_expr(index, torch.int32), + ), + src_loader(idx), + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +@register_lowering(aten.slice_scatter, type_promotion_kind=None) +def slice_scatter(x, src, dim=0, start=None, end=None, step=1): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + + start, end = ir.SliceView.normalize_start_end(x, dim, start, end) + + src_size = list(x.get_size()) + src_size[dim] = FloorDiv(end - start + (step - 1), step) + src = expand(src, src_size) + src_loader = src.make_loader() + + def inner_fn(idx): + if start == 0 and end == dim_size and step == 1: + # selecting every element is the same as just src.clone() + return src_loader(idx) + + idx_dim = ops.index_expr(idx[dim], torch.int64) + src_idx = list(idx) + src_idx[dim] = FloorDiv(idx[dim] - start, step) + + mask = [] + if start != 0: + mask.append( + ops.ge( + idx_dim, + ops.index_expr(sympy.expand(start), torch.int64), + ) + ) + if end != dim_size: + mask.append( + ops.lt( + idx_dim, + ops.index_expr(sympy.expand(end), torch.int64), + ) + ) + if step != 1: + mask.append( + ops.eq( + ops.index_expr( + ModularIndexing(idx[dim] - start, 1, step), torch.int64 + ), + ops.constant(0, torch.int64), + ) + ) + assert mask + mask = functools.reduce(ops.and_, mask) + src_val = ops.masked( + mask, + lambda: src_loader(src_idx), + 0 if is_integer_type(x) else 0.0, + ) + return ops.where( + mask, + src_val, + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +def _unwrap(x): + if isinstance(x, (list, tuple)) and len(x) > 0: + return _unwrap(x[0]) + return x + + +@register_lowering([torch.tensor, aten.scalar_tensor]) +def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + if isinstance(_unwrap(data), int): + dtype = dtype or torch.int64 + else: + dtype = dtype or torch.get_default_dtype() + + ranges: List[sympy.Expr] = [] + + if isinstance(data, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(data, dtype) + + elif isinstance(data, (float, int)): + + def inner_fn(index): + return ops.constant(data, dtype) + + elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8: + # inline small tensors + ranges.append(sympy.Integer(len(data))) + + def inner_fn(index): + def binary_search(start, end): + assert start < end + if end - start == 1: + return ops.constant(data[start], dtype) + mid = (end - start) // 2 + start + return ops.where( + ops.lt( + ops.index_expr(index[0], torch.int64), + ops.constant(mid, torch.int64), + ), + binary_search(start, mid), + binary_search(mid, end), + ) + + if len(data) == 0: + return ops.constant(0, dtype) + return binary_search(0, len(data)) + + else: + return V.graph.add_tensor_constant( + torch.tensor(data, dtype=dtype, device=device) + ) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if isinstance(data, TensorBox): + if dtype is not None: + data = to_dtype(data, dtype) + if device is not None: + data = to_device(data, device) + return data + return tensor(data, dtype=dtype, device=device) + + +@register_lowering(torch.LongTensor) +def long_tensor(data): + return tensor(data, dtype=torch.int64) + + +@register_lowering(aten._local_scalar_dense) +def _local_scalar_dense(data): + from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings + + # This is interesting! Most lowerings return tensors, so you can just + # return the buffer you allocated and it will get used (or not used, if + # it's dead.) But _local_scalar_dense (aka item) returns an int, + # not a Tensor, so you would have a type mismatch if you return a buffer; + # we are obligated to return a sympy expression instead. However, + # we need to actually codegen the .item() call somehow. We do this + # by registering a faux buffer for the DynamicScalar IR node, which is + # solely responsible for generating this .item(). The buffer is + # not used for anything (notice we discard it); at codegen time, + # the "buffer" just gets assigned None. + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] + ) + assert len(unbacked_bindings) == 1, unbacked_bindings + # NB: Have to be very careful here. V.graph.current_node.meta["val"] + # seemingly also contains a symbol which you want to do binding for, + # but it actually isn't. In particular, if we have later performed + # a deferred runtime assert saying that u0 == s0, you will actually + # see s0 from expr! This is bad because we need to actually generate + # the assert that says u0 == s0, so we need to know where to get u0 + # from (this call). In particular, we must use unbacked_bindings, which + # is guaranteed to have the original, unreplaced symbol in question. + # + # NB2: Another thing we have to be very careful about are symbol bindings + # that require nontrivial refinement, e.g., when you have a binding site + # x: Sym(u0 * 4) = y.item(). Here, the code generation must do a division + # in order to appropriately bind u0. This is communicated via the keypath + # in unbacked_bindings, and we need to hold onto it in order to generate + # code appropriately for this case. + binding_sym, keypath = next(iter(unbacked_bindings.items())) + buffer = ir.DynamicScalar(binding_sym, keypath, data) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + # NB: the replaced expr is OK to use directly downstream, we want + # simplifications in this case! + val = V.graph.current_node.meta["val"] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + return val.node.expr + else: + return sympy.sympify(val) + + +@register_lowering(aten._assert_scalar) +def _assert_scalar(data, msg): + # NB: These will be handled at codegen time + # Not sure if we are guaranteed to be able to serve out truth from the + # deferred_runtime_asserts, TODO: try this assert out + # assert bool(data.scalar), data + return None + + +def _full(fill_value, device, dtype, size): + value = fill_value + if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): + value = value.value + + if isinstance(value, (int, float)): + + def inner_fn(index): + return ops.constant(value, dtype) + + elif isinstance(value, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(value, dtype) + + else: + assert len(value.get_size()) == 0 + value_loader = value.make_loader() + + def inner_fn(index): + return value_loader([]) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + +@register_lowering(aten.full_like, type_promotion_kind=None) +def full_like(x, fill_value, **kwargs): + return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) + + +def tensor_constructor(fill_value): + # torch.zeros, torch.ones, etc + def inner( + *size, + names=None, + dtype=None, + device=None, + layout=None, + pin_memory=False, + memory_format=None, + ): + assert_nyi(names is None, "named tensors") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + device = decode_device(device) + dtype = dtype or torch.get_default_dtype() + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + # See https://github.com/pytorch/pytorch/issues/118102 + # All sizes at lowering time should be sympy.Symbol, not SymInt! + for s in size: + assert not isinstance(s, torch.SymInt) + size = [sympy.expand(s) for s in size] + return _full(fill_value, device, dtype, size) + + return inner + + +@register_lowering([torch.empty, aten.empty]) +def empty( + *size, + names=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + assert_nyi(names is None, "named tensors") + device = decode_device(device) + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +def create_tensor_like(creation_fn): + """ + Shim to convert X_like(...) into X(...). For example zeros_like() into zeros(). + """ + + def _constant_like( + x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None + ): + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + if dtype is None: + dtype = x.get_dtype() + else: + dtype = decode_dtype(dtype) + device = device or x.get_device() + size = list(x.get_size()) + return creation_fn( + size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory + ) + + return _constant_like + + +def constant_like(fill_value): + return create_tensor_like(tensor_constructor(fill_value)) + + +empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty)) +ones_like = create_tensor_like(tensor_constructor(1)) +zeros_like = create_tensor_like(tensor_constructor(0)) + + +def new_constant(fill_value): + def _new_constant( + x, size, *, dtype=None, layout=None, device=None, pin_memory=None + ): + assert isinstance(size, (list, tuple)) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or x.get_dtype() + device = device or x.get_device() + size = [sympy.Integer(s) for s in size] + return _full(fill_value, device, dtype, size) + + return _new_constant + + +@register_lowering(aten.new_empty) +def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(aten.empty_strided) +def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + assert isinstance(size, (list, tuple)) + assert isinstance(stride, (list, tuple, type(None))) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or torch.get_default_dtype() + device = device or torch.tensor(0.0).device + pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) + pointwise.realize() + buffer = pointwise.data.data + # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode + buffer.data.ranges = [0] * len(size) + assert isinstance(buffer, ir.ComputedBuffer) + size = [sympy.expand(s) for s in size] + stride = ( + [sympy.expand(s) for s in stride] + if stride + else ir.FlexibleLayout.contiguous_strides(size) + ) + buffer.layout = ir.FixedLayout( + device=device, + dtype=dtype, + size=size, + stride=stride, + ) + return pointwise + + +@register_lowering(aten.new_empty_strided) +def new_empty_strided( + x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(prims.copy_strided.default) +def copy_strided(x, stride): + stride = [V.graph.sizevars.size_hint(s) for s in stride] + stride_order = sorted(range(len(stride)), key=stride.__getitem__) + return ir.ExternKernel.require_stride_order(x, stride_order) + + +@register_lowering([torch.full, aten.full]) +def full(size, fill_value, **kwargs): + assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition" + return tensor_constructor(fill_value)(size, **kwargs) + + +@register_lowering(aten.gather, type_promotion_kind=None) +def gather(x, dim, index, sparse_grad=False): + # sparse_grad doesn't affect forward computation, + # and backward tracing is taken care of by AOT Autograd + assert isinstance(x, TensorBox) + if index.get_numel() == 0: + # Empty index case. Return an empty array with the same shape + return new_empty(x, index.get_size()) + + assert index.get_dtype() == torch.int64 + size = x.get_size() + offset = len(size) == 0 + dim = _validate_dim(x, dim, offset) + + if offset: + x = expand(x, [1]) + size = [1] + + x_loader = x.make_loader() + index_loader = index.make_loader() + + def fn(idx): + idx = list(idx) + gather_idx = ops.indirect_indexing(index_loader(idx), size[dim]) + if len(idx) == 0: + idx = [gather_idx] + else: + idx[dim] = gather_idx + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + ) + + +@register_lowering(aten.embedding, type_promotion_kind=None) +def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + assert not sparse + assert isinstance(weight, TensorBox) + assert isinstance(indices, TensorBox) + assert "int" in str(indices.get_dtype()) + + weight_loader = weight.make_loader() + indices_loader = indices.make_loader() + indices_ndim = len(indices.get_size()) + weight_size = weight.get_size() + new_size = [*indices.get_size(), *weight_size[1:]] + + def fn(idx): + assert len(idx) == len(new_size), f"{idx} != {new_size}" + var_index = indices_loader(idx[:indices_ndim]) + weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [ + *idx[indices_ndim:] + ] + return weight_loader(weight_idx) + + return Pointwise.create( + device=weight.get_device(), + dtype=weight.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + + +def check_and_broadcast_indices(indices, device): + assert all( + i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) + for i in indices + if i is not None + ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}" + if any( + i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None + ): + raise NotImplementedError("Fallback for bool indices") + + valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)] + assert len(valid_idxs) > 0, "requires at least 1 non-None index" + new_indices = [None] * len(indices) + for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])): + # Eager allows indices to be CPU tensor when running on CUDA + # FIXME: Calling to_device(x, device) should work but + # test_advancedindex_mixed_cpu_devices still fails + if x.get_device() != device: + raise NotImplementedError("Fallback when indices is on a different device") + new_indices[i] = x + return new_indices, valid_idxs + + +def index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + x_loader, + check, +): + # Note that behavior of indexing differs when there are non consecutive + # tensors. In this case, the tensor index is pulled to the beginning. + # + # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7) + # x = torch.tensor[1,2] + # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will + # be pulled to the front. + non_consecutive_tensors = False + for previous, current in zip(tensor_indices, tensor_indices[1:]): + if current - previous != 1: + non_consecutive_tensors = True + + output_size = [x_size[i] for i, val in enumerate(indices) if val is None] + output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]] + + first_tensor_index = tensor_indices[0] + if non_consecutive_tensors: + output_size = tensor_size + output_size + else: + output_size = ( + output_size[:first_tensor_index] + + tensor_size + + output_size[first_tensor_index:] + ) + + def fn(idx): + assert len(idx) == len(output_size) + assert len(indices_loaders) == len(indexed_size) + + rank = len(tensor_size) + new_index = [] + first_tensor_index = tensor_indices[0] + start_offset = 0 if non_consecutive_tensors else first_tensor_index + next_idx = 0 + for i in range(tensor_indices[-1] + 1): + if i == start_offset: + next_idx += rank + if indices[i] is None: + assert next_idx < len(idx) + new_index.append(idx[next_idx]) + next_idx += 1 + else: + loader = indices_loaders[i] + assert loader is not None + size = indexed_size[i] + new_index.append( + ops.indirect_indexing( + loader(idx[start_offset : start_offset + rank]), + size, + check=check, + ) + ) + new_index = [ + *new_index, + *idx[next_idx:], + ] + return new_index if x_loader is None else x_loader(new_index) + + return output_size, fn + + +def index_impl(x, indices, check): + output_size, inner_fn, _ = index_impl_helper(x, indices, check) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=output_size, + ) + + +def index_impl_helper(x, indices, check): + assert isinstance(indices, (list, tuple)) + x_loader = x.make_loader() + indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) + assert len(tensor_indices) > 0, "Must have at least one valid idx" + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + # no guards on output size, all the guards are set in broadcast_tensors + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + + x_size = x.get_size() + + indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None] + if check and 0 in indexed_size and 0 not in tensor_size: + raise IndexError("index is out of bounds for dimension with size 0") + + indexed_size = [x_size[i] for i in range(len(indices))] + output_size, index_inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + ) + + def inner_fn(idx): + return x_loader(index_inner_fn(idx)) + + return output_size, inner_fn, index_inner_fn + + +@register_lowering(aten.index, type_promotion_kind=None) +def index(x, indices): + try: + return index_impl(x, indices, check=True) + except NotImplementedError: + # Fallback to ATen for boolean indexing + x.realize() + return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)( + x, indices + ) + + +@register_lowering(aten._unsafe_index, type_promotion_kind=None) +def _unsafe_index(x, indices): + return index_impl(x, indices, check=False) + + +# All the indexing decompositions are written in terms of index, index_put, and index_put_ +# We cannot have this lowering as a decomposition as it introduces +# mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead +# code elimination and common subexpression elimination optimizations, which +# assume graphs to be side-effect free. More details at +# https://github.com/pytorch/torchdynamo/issues/1235 +# and +# https://github.com/pytorch/torchdynamo/issues/1863 +@register_lowering(aten.index_put) +def index_put(x, indices, values, accumulate=False): + return index_put_(clone(x), indices, values, accumulate) + + +@register_lowering(aten._unsafe_index_put) +def _unsafe_index_put(x, indices, values, accumulate=False): + return index_put_impl_(clone(x), indices, values, accumulate, check=False) + + +def index_put_as_masked_fill(self, indices, value, accumulate): + if value.get_device() != self.get_device(): + value = to_device(value, self.get_device()) + if accumulate: + value = add(self, value) + return mutate_to(self, where(indices[0], value, self)) + + +def index_put_fallback(self, indices, values, accumulate): + deterministic = torch.are_deterministic_algorithms_enabled() + if is_triton(values) and (accumulate or deterministic): + msg = ( + "index put with accumulate." + if not deterministic + else "deterministic index put." + ) + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate) + return self + + +@register_lowering(aten.index_put_, type_promotion_kind=None) +def index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=True) + + +@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None) +def _unsafe_index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=False) + + +def index_put_impl_(self, indices, values, accumulate, check): + # Dispatch to masked fill for single boolean index with single value + if ( + values.get_numel() == 1 + and len(indices) == 1 + and indices[0].get_dtype() in {torch.bool, torch.uint8} + ): + mask = indices[0] + for _ in range(len(mask.get_size()), len(self.get_size())): + mask = unsqueeze(mask, -1) + return index_put_as_masked_fill(self, [mask], values, accumulate) + + # Fallback in torch deterministic mode + if torch.are_deterministic_algorithms_enabled(): + return index_put_fallback(self, indices, values, accumulate) + + # Fallback if there is a boolean index + for index in indices: + if index is not None and index.get_dtype() in {torch.bool, torch.uint8}: + return index_put_fallback(self, indices, values, accumulate) + + x_size = self.get_size() + x_ndim = len(x_size) + + if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()): + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + self = index_put_fallback(self, indices, values, accumulate) + if x_ndim == 0: + self = view(self, []) + return self + + values = to_dtype(values, self.get_dtype()) + + try: + # Note that code will only get here when dtype is uint32 + indices, tensor_indices = check_and_broadcast_indices( + indices, self.get_device() + ) + except NotImplementedError: + return index_put_fallback(self, indices, values, accumulate) + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + + assert isinstance(self, TensorBox) + self.realize() + + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + ) + + values = expand(values, expected_vals_size) + # all guards are set above during broadcast_tensors and expand + + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add" if accumulate else None, + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayoutSHOULDREMOVE(self), + scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if x_ndim == 0: + self = view(self, []) + return self + + +fallback__unsafe_masked_index = fallback_handler( + aten._unsafe_masked_index.default, add_to_fallback_set=False +) + +fallback__unsafe_masked_index_put_accumulate = fallback_handler( + aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False +) + + +@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None) +def _unsafe_masked_index(self, mask, indices, fill): + ranges, _, _unsafe_index_fn = index_impl_helper(self, indices, check=False) + mask_loader = mask.make_loader() + self_loader = self.make_loader() + + def inner_fn(idx): + if mask.dtype != torch.bool: + mask_val = ops.to_dtype(mask_loader(idx), torch.bool) + else: + mask_val = mask_loader(idx) + return ops.masked(mask_val, lambda: self_loader(_unsafe_index_fn(idx)), fill) + + return Pointwise.create( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None) +def _unsafe_masked_index_put_accumulate(x, mask, indices, values): + masked_value = where(mask, values, 0) + shape = x.get_size() + clamped_indices = [ + clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None + for i in range(len(indices)) + ] + # TODO: use a masked store for this. currently only triton + # supports masked stores and cpp backend does not. + return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True) + + +@make_pointwise +def clamp(a, min, max): + return ops.maximum(min, ops.minimum(max, a)) + + +@register_lowering(aten.as_strided_scatter, type_promotion_kind=None) +def as_strided_scatter(self, src, size, stride, storage_offset=None): + output = clone(self) + output_view = as_strided(output, size, stride, storage_offset) + copy_(output_view, src) + return output + + +@register_lowering(aten.scatter, type_promotion_kind=None) +def scatter(x, dim: int, index, src, **kwargs): + return scatter_(clone(x), dim, index, src, **kwargs) + + +def scatter_fallback( + op_overload: torch._ops.OpOverload, + self, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, +): + src_is_tensor = isinstance(src, TensorBox) + if use_scatter_fallback( + op_overload, + reduce, + self.get_dtype(), + src.get_dtype() if src_is_tensor else type(src), + src.get_device().type if src_is_tensor else "not impl", + src_is_tensor, + ): + ir.ScatterFallback( + op_overload, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + return self + + return None + + +@register_lowering(aten.scatter_, type_promotion_kind=None) +def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None): + assert reduce in {None, "add", "multiply"} + if reduce is None: + op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname) # type: ignore[union-attr] + fallback_result = scatter_fallback( + op_overload, self, dim, index, src, reduce=reduce + ) + if fallback_result is not None: + return fallback_result + + if reduce == "add": + reduce = "sum" + elif reduce == "multiply": + reduce = "prod" + return scatter_reduce_(self, dim, index, src, reduce) + + +@register_lowering(aten.scatter_add, type_promotion_kind=None) +def scatter_add(x, dim: int, index, src): + return scatter_add_(clone(x), dim, index, src) + + +@register_lowering(aten.scatter_add_, type_promotion_kind=None) +def scatter_add_(x, dim: int, index, src): + return scatter_reduce_(x, dim, index, src, "sum") + + +@register_lowering(aten.scatter_reduce, type_promotion_kind=None) +def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): + return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs) + + +@register_lowering(aten.scatter_reduce_, type_promotion_kind=None) +def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): + assert reduce in {None, "sum", "prod", "mean", "amax", "amin"} + assert ( + len(aten.scatter_reduce_.overloads()) == 1 + and "two" in aten.scatter_reduce_.overloads() + ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_" + + if isinstance(src, Number): + src = full_like(self, src) + + fallback_result = scatter_fallback( + aten.scatter_reduce_.two, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + + if fallback_result: + return fallback_result + + assert isinstance(self, TensorBox) + assert "int" in str(index.get_dtype()) + + ndim = len(self.get_size()) + if ndim == 0: + self = view(self, [1]) + + if isinstance(src, TensorBox) and len(src.get_size()) == 0: + src = view(src, [1]) + + if isinstance(index, TensorBox) and len(index.get_size()) == 0: + index = view(index, [1]) + + if index.get_numel() == 0: + return self + + dim = _validate_dim(self, dim) + + self.realize() + index_loader = index.make_loader() + src_loader = src.make_loader() if isinstance(src, TensorBox) else None + + def output_indexer(idx): + # self is captured from the end of the function, so it may have 0 dim + shape = self.get_size() + ndim = len(shape) + indirect_idx = list(idx) + indirect_idx[dim] = ops.indirect_indexing( + index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False + ) + return indirect_idx + + def fn(idx): + if src_loader: + return src_loader(idx) + else: + # src is a scalar + return ops.constant(src, self.get_dtype()) + + def backend_reduce_str(reduce): + if reduce == "sum": + return "atomic_add" + else: + # TODO: Need to support more reduction type + assert reduce is None + return None + + if not include_self: + # zero out the corresponding elements first + zero_out = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=lambda index: ops.constant(0, self.get_dtype()), + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=None, + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayoutSHOULDREMOVE(self), + zero_out, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=backend_reduce_str(reduce), + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayoutSHOULDREMOVE(self), + scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if ndim == 0: + self = view(self, []) + return self + + +def upsample_nearestnd( + x, + output_size, + scales_x: Tuple[Optional[float], ...], + n: int = 2, + exact: bool = False, +): + x.realize_hint() # elements are reused + x_loader = x.make_loader() + i_sizes = x.get_size()[-n:] + batch = x.get_size()[:-n] + i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes] + + assert len(scales_x) == n + o_sizes = output_size + + inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)] + for i, scale in enumerate(scales_x): + if scale is not None: + inv_scales[i] = 1.0 / scale + + def scale_fn(x, scale, size): + # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5) + # = floor(scale * (output_index + 0.5)) + # Nearest: input_index = floor(scale * output_index) + x = ops.index_expr(x, torch.float32) + if exact: + x = ops.add(x, ops.constant(0.5, torch.float32)) + x = ops.mul(x, ops.constant(scale, torch.float32)) + x = ops.to_dtype(x, torch.int32) + return ops.indirect_indexing(x, size, check=False) + + def fn(idx): + x = idx[-n:] + b = idx[:-n] + return x_loader( + [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]] + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=[*batch, *o_sizes], + ) + + +@register_lowering(aten.upsample_nearest1d.default) +def upsample_nearest1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1) + + +@register_lowering(aten._upsample_nearest_exact1d.default) +def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True) + + +@register_lowering(aten.upsample_nearest2d.default) +def upsample_nearest2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2) + + +@register_lowering(aten._upsample_nearest_exact2d.default) +def _upsample_nearest_exact2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True) + + +@register_lowering(aten.upsample_nearest3d.default) +def upsample_nearest3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3) + + +@register_lowering(aten._upsample_nearest_exact3d.default) +def _upsample_nearest_exact3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd( + x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True + ) + + +def _create_constants(*args, dtype): + return tuple(ops.constant(a, dtype) for a in args) + + +@register_lowering(prims.rev.default) +def rev(x, dims): + # note - dims pre-canonicalized + x_loader = x.make_loader() + sizes = x.get_size() + + def loader(idx): + idx = list(idx) + assert len(idx) == len(sizes) + for dim in dims: + idx[dim] = (sizes[dim] - 1) - idx[dim] + + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=loader, + ranges=sizes, + ) + + +@register_lowering(aten.constant_pad_nd, type_promotion_kind=None) +def constant_pad_nd(x, padding, fill_value=0): + assert (len(padding) % 2) == 0 + if all(p == 0 for p in padding): + return clone(x) + + sizes = x.get_size() + + bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) + n = len(sizes) - len(bounds) + + # if padding is a complicated expression, hoist it + bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] + for l, h in bounds: + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] + + output_size = list(sizes[:n]) + mask_sizes = [] + for (low, high), size in zip(bounds, sizes[n:]): + mask_sizes.append(size) + output_size.append(sympy.expand(size + low + high)) + assert len(output_size) == len(sizes) + fill_value = dtype_to_type(x.get_dtype())(fill_value) + + def mask(index): + mask = [] + for idx, (low, high), length in zip(index[n:], bounds, mask_sizes): + if low != 0: + mask.append(range_mask_low(idx, 0)) + if high != 0: + mask.append(range_mask_high(idx, length)) + mask = functools.reduce(ops.and_, mask) + return ops.masked(mask, lambda: x_loader(index), fill_value) + + def offset_fn(index): + new_index = list(index[:n]) + for idx, (low, high) in zip(index[n:], bounds_precomp): + new_index.append(idx - low) + assert len(new_index) == len(index) + return mask(new_index) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=offset_fn, + ranges=output_size, + ) + + +def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]): + return ops.ge( + ops.index_expr(i, torch.int64), + ops.index_expr(sympy.Integer(low), torch.int64), + ) + + +def range_mask_high(i: sympy.Expr, high: sympy.Expr): + return ops.lt( + ops.index_expr(i, torch.int64), + ops.index_expr(high, torch.int64), + ) + + +def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr): + return ops.and_( + range_mask_low(i, low), + range_mask_high(i, high), + ) + + +def constant_boundary_condition( + x, fill_value, padding=None, pad_fill_value=1.0, dim=None +): + h = x.get_size()[-dim:] + x_loader = x.make_loader() + padding_h = padding or [0] * dim + + def load(index): + prefix = index[:-dim] + ih = index[-dim:] + + mask = functools.reduce( + ops.and_, + [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)], + ) + return ( + ops.masked( + mask, + lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)( + [*prefix, *ih] + ), + fill_value, + ) + if padding + else ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value) + ) + + return load + + +def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): + x_out = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i] + ) + + if ceil_mode: + x_alt = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i] + ) + if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: + # Sliding windows must start within the input or left padding + x_alt -= 1 # type: ignore[assignment] + V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] + if V.graph.sizevars.size_hint(x_out - x_alt) == 0: + # ceil mode is actually a no-op, lets guard on that + V.graph.sizevars.guard_equals(x_out, x_alt) + ceil_mode = False + else: + x_out = x_alt + return x_out, ceil_mode + + +def should_fallback_max_pool2d_with_indices(kernel_size, dilation): + kernel_size = pad_listlike(kernel_size, 2) + window_size = kernel_size[0] * kernel_size[1] + return (window_size > 25) or any(d > 1 for d in dilation) + + +def max_pool2d_checks( + x, kernel_size, stride, padding, dilation, *, assert_fallback=None +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation) + if assert_fallback is not None: + assert use_fallback == assert_fallback + + return kernel_size, stride, padding, dilation, use_fallback + + +@register_lowering(prims._low_memory_max_pool2d_with_offsets, type_promotion_kind=None) +def _low_memory_max_pool2d_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode=False, +): + # assert we are not on a fallback path, the inductor decomp should have guaranteed this + kernel_size, stride, padding, dilation, _ = max_pool2d_checks( + x, kernel_size, stride, padding, dilation, assert_fallback=False + ) + + x.realize_hint() + *batch, h, w = x.get_size() + + h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode) + + dtype = x.dtype + min_value = ( + False + if dtype is torch.bool + else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min) + ) + + new_size = list(batch) + [h_out, w_out] + if padding[0] or padding[1] or ceil_mode1 or ceil_mode2: + x_loader = constant_boundary_condition(x, min_value, dim=2) + else: + x_loader = x.make_loader() + + def fn(idx, return_index): + *prefix, bh, bw = idx + maxval = None + maxindex = None + for h_inc, w_inc in itertools.product( + range(kernel_size[0]), range(kernel_size[1]) + ): + ih = bh * stride[0] + h_inc - padding[0] + iw = bw * stride[1] + w_inc - padding[1] + val = x_loader([*prefix, ih, iw]) + if return_index: + index = ops.index_expr(h_inc * kernel_size[1] + w_inc, torch.int8) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + out = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + offsets = Pointwise.create( + device=x.get_device(), + dtype=torch.int8, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + return out, offsets + + +@register_lowering( + prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None +) +def _low_memory_max_pool2d_offsets_to_indices( + offsets, kernel_width, input_width, stride, padding +): + # TODO: Generalize to other max pooling flavors, and arbitrary dim + + offsets_loader = offsets.make_loader() + + def increments_to_index(h_inc, w_inc, bh, bw): + w_in = ops.index_expr(input_width, torch.int64) + hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64) + wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64) + ih = hbase + h_inc + iw = wbase + w_inc + return ih * w_in + iw + + def offsets_to_indices(idx): + *prefix, bh, bw = idx + offset = offsets_loader([*prefix, bh, bw]) + kw_const = ops.constant(kernel_width, torch.int32) + h_inc = offset // kw_const + w_inc = offset - (h_inc * kw_const) + return increments_to_index(h_inc, w_inc, bh, bw) + + indices = Pointwise.create( + device=offsets.get_device(), + dtype=torch.int64, + inner_fn=offsets_to_indices, + ranges=offsets.get_size(), + ) + return indices + + +# Fallback selected when we do not decompose to the low-memory path. +make_fallback(aten.max_pool2d_with_indices) + + +fallback_max_pool2d_with_indices_backward = fallback_handler( + aten.max_pool2d_with_indices_backward.default, + add_to_fallback_set=False, +) + + +@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) +def max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + # we will read this many times, so make sure it is computed + grad_output.realize_hint() + try: + gO_stride = grad_output.get_stride() + except AttributeError: + # some classes don't have `get_stride` + # TODO will need a better way of determining if inputs are channels-last + gO_stride = None + if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] + data = x.data.data # type: ignore[attr-defined] + x_buffer = ir.ComputedBuffer( + name=None, + layout=ir.FlexibleLayout( + device=data.get_device(), + dtype=data.get_dtype(), + size=data.get_size(), + ), + data=data, + ) + x_buffer.decide_layout() + x_stride = x_buffer.get_stride() + else: + try: + x_stride = x.get_stride() + except AttributeError: + x_stride = None + + is_channels_last = (x_stride is not None and x_stride[1] == 1) or ( + gO_stride is not None and gO_stride[1] == 1 + ) + if any(d != 1 for d in dilation): + # dilation NYI + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + *batch, height, width = x.get_size() + *_, pooled_height, pooled_width = grad_output.get_size() + + indices_loader = indices.make_loader() + grad_loader = grad_output.make_loader() + new_size = list(x.get_size()) + + h_window_size = max( + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + indices_size = indices.get_size() + + def fn(idx): + *prefix, h, w = idx + index_test = ops.index_expr(h * width + w, torch.int32) + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + grad_index = [ + *prefix, + ops.indirect_indexing( + ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), + indices_size[-2], + check=False, + ), + ops.indirect_indexing( + ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), + indices_size[-1], + check=False, + ), + ] + + index_actual = indices_loader(grad_index) + grad_part = grad_loader(grad_index) + check = ops.eq(index_actual, index_test) + + if gradient is None: + # don't need mask for 0, 0 + gradient = ops.where( + check, grad_part, ops.constant(0.0, torch.float32) + ) + else: + mask = ops.and_( + ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ), + check, + ) + gradient = ops.where(mask, ops.add(gradient, grad_part), gradient) + assert gradient is not None + return gradient + + out = Pointwise.create( + device=grad_output.get_device(), + dtype=grad_output.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + if is_channels_last: + return ir.ExternKernel.require_channels_last(out) + else: + return out + + +def pad_adaptive_loader(x, pad_val=0.0): + *_, h, w = x.get_size() + x_loader = x.make_loader() + + def load(prefix, increments, start_indices, end_indices): + ih, iw = increments + h_start_index, w_start_index = start_indices + h_end_index, w_end_index = end_indices + + mask = ops.and_( + ops.lt( + ops.index_expr(h_start_index + ih, torch.int64), + ops.index_expr(h_end_index, torch.int64), + ), + ops.lt( + ops.index_expr(w_start_index + iw, torch.int64), + ops.index_expr(w_end_index, torch.int64), + ), + ) + + return ops.masked( + mask, + lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]), + pad_val, + ) + + return load + + +def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out): + h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) + + w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + return h_start_index, h_end_index, w_start_index, w_end_index + + +def _adaptive_pooling_fn( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + result = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + if result is None: + result = val + else: + result = pooling_fn(val, result) + return result + + return fn + + +def _adaptive_pooling_fn_with_idx( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + index = ops.index_expr( + (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 + ) + + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + + if maxval is None: + maxval = val + else: + maxval = pooling_fn(val, maxval) + + return maxindex + + return fn + + +fallback_adaptive_avg_pool2d = fallback_handler( + aten._adaptive_avg_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten._adaptive_avg_pool2d) +def _adaptive_avg_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + # no-op if the same input and output + if h_in == h_out and w_in == w_out: + return clone(x) + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return avg_pool2d(x, kernel_size) + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_adaptive_avg_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.add, + ) + + ones_loader = pad_adaptive_loader(ones_like(x)) + + def fn(idx): + return ops.truediv( + fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader) + ) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO: should we force these to be realized? + return rv + + +fallback_adaptive_max_pool2d = fallback_handler( + aten.adaptive_max_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.adaptive_max_pool2d) +def adaptive_max_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty( + o_size, dtype=torch.int64, device=x.get_device() + ) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + if should_fallback_max_pool2d_with_indices(kernel_size, dilation=[1, 1]): + return max_pool2d_with_indices(x, kernel_size) # type: ignore[name-defined] # noqa: F821 + else: + v, offsets = _low_memory_max_pool2d_with_offsets( + x, + kernel_size, + stride=kernel_size, + padding=[0, 0], + dilation=[1, 1], + ceil_mode=False, + ) + indices = _low_memory_max_pool2d_offsets_to_indices( + offsets, kernel_size[1], w_in, kernel_size, padding=[0, 0] + ) + return v, indices + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_adaptive_max_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + inner_func_max_val = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + inner_func_max_idx = _adaptive_pooling_fn_with_idx( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + def inner_fn_max_val(idx): + return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf"))) + + def inner_fn_max_idx(idx): + return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf"))) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=inner_fn_max_val, + ranges=new_size, + ) + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=inner_fn_max_idx, + ranges=new_size, + ) + return rv, ri + + +fallback_fractional_max_pool2d = fallback_handler( + aten.fractional_max_pool2d.default, add_to_fallback_set=False +) + + +def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): + out_sz = out_sz[dim] + in_sz = in_sz[dim] + kernel_sz = kernel_sz[dim] + alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) + samples_loader = samples.make_loader() + + def load(prefix, i): + sample = samples_loader([*prefix, dim]) + i_expr = ops.index_expr(i, samples.get_dtype()) + alpha_expr = ops.index_expr(alpha, samples.get_dtype()) + seq_i = ops.floor((i_expr + sample) * alpha_expr) - ops.floor( + sample * alpha_expr + ) + seq_i = ops.to_dtype(seq_i, torch.int64) + + mask = ops.lt( + i_expr, + ops.index_expr(out_sz - 1, torch.int64), + ) + return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)) + + return load + + +@register_lowering(aten.fractional_max_pool2d) +def fractional_max_pool2d(x, kernel_size, output_size, random_samples): + x.realize_hint() + *batch, inp_h, inp_w = x.get_size() + kernel_h, kernel_w = kernel_size + h_out, w_out = output_size + + if kernel_h * kernel_w >= 25: + return fallback_fractional_max_pool2d( + x, kernel_size, output_size, random_samples + ) + + gen_offsets_for_dim = functools.partial( + _fractional_pooling_offsets, + samples=random_samples, + in_sz=[inp_h, inp_w], + out_sz=output_size, + kernel_sz=kernel_size, + ) + + h_index_fn = gen_offsets_for_dim(dim=0) + w_index_fn = gen_offsets_for_dim(dim=1) + x_loader = x.make_loader() + + def fn(idx, return_index): + *prefix, bh, bw = idx + + h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h) + w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): + val = x_loader([*prefix, h_start_index + ih, w_start_index + iw]) + if return_index: + index = ops.index_expr( + (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64 + ) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where( + ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex + ) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + new_size = list(batch) + [h_out, w_out] + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + return rv, ri + + +@register_lowering(aten.upsample_nearest2d_backward.default) +def upsample_nearest2d_backward( + x, output_size=None, input_size=None, scales_h=None, scales_w=None +): + x.realize_hint() + + *batch, inp_h, inp_w = x.get_size() + inp_h = V.graph.sizevars.evaluate_static_shape(inp_h) + inp_w = V.graph.sizevars.evaluate_static_shape(inp_w) + + *batch, out_h, out_w = input_size + + if inp_h % out_h == 0 and inp_w % out_w == 0: + return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1) + + h_kernel_max = ceildiv(inp_h, out_h) + w_kernel_max = ceildiv(inp_w, out_w) + + def start_index(index, out_dim, inp_dim): + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) + + def end_index(index, out_dim, inp_dim): + return start_index((index + 1), out_dim, inp_dim) + + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[inp_h, inp_w], + out_sizes=[out_h, out_w], + pooling_fn=ops.add, + ) + + def fn(idx): + return fn_sum(idx, pad_adaptive_loader(x)) + + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=list(input_size), + ) + + return rv + + +fallback_avg_pool2d = fallback_handler( + aten.avg_pool2d.default, add_to_fallback_set=False +) +fallback_avg_pool3d = fallback_handler( + aten.avg_pool3d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d, type_promotion_kind=None) +def avg_pool2d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=2, + ) + + +@register_lowering(aten.avg_pool3d, type_promotion_kind=None) +def avg_pool3d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=3, + ) + + +def _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim, +): + if not stride: + stride = kernel_size + if not padding: + padding = [0] * dim + kernel_size = pad_listlike(kernel_size, dim) + stride = pad_listlike(stride, dim) + padding = pad_listlike(padding, dim) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == dim + assert len(stride) == dim + assert len(padding) == dim + assert len(x.get_size()) in (dim + 1, dim + 2) + + x.realize_hint() + batch = x.get_size()[:-dim] + h = x.get_size()[-dim:] + + h_out, ceil_modes = zip( + *[ + pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode) + for i in range(dim) + ] + ) + + if any(padding) or any(ceil_modes): + x_loader = constant_boundary_condition(x, 0.0, dim=dim) + had_padding = True + else: + x_loader = x.make_loader() + had_padding = False + + new_size = list(batch) + list(h_out) + dtype = x.get_dtype() + + window_size = functools.reduce(operator.mul, kernel_size) + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + if dim == 2: + fallback = fallback_avg_pool2d + elif dim == 3: + fallback = fallback_avg_pool3d + else: + raise ValueError(f"Unknown dim: {dim}") + + return fallback( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def fn_sum(idx, loader): + prefix = idx[:-dim] + b = idx[-dim:] + total = None + for ih in itertools.product(*[range(kernel_size[i]) for i in range(dim)]): + inp = [b[i] * stride[i] + ih[i] - padding[i] for i in range(dim)] + val = loader([*prefix, *inp]) + if total is None: + total = val + else: + total = ops.add(val, total) + return total + + if not had_padding or divisor_override: + if divisor_override: + scale = 1 / divisor_override + else: + scale = 1.0 / window_size + + def fn(idx): + return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype)) + + else: + + def fn(idx): + prefix = idx[:-dim] + bh = idx[-dim:] + + divide_factors = [] + for i in range(dim): + hstart = bh[i] * stride[i] - padding[i] + hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i]) + if not count_include_pad: + hstart = sympy.Max(hstart, 0) + hend = sympy.Min(hend, h[i]) + factor = ops.index_expr(hend - hstart, torch.int32) + divide_factors.append(factor) + divide_factor = functools.reduce(ops.mul, divide_factors) + return ops.truediv(fn_sum(idx, x_loader), divide_factor) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO(jansel): should we force these to be realized? + return rv + + +fallback_avg_pool2d_backward = fallback_handler( + aten.avg_pool2d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None) +def avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(x.get_size()) in (3, 4) + + grad_output.realize_hint() # we will read this many times, so make sure it is computed + + *batch, height, width = x.get_size() + + h_out, ceil_mode1 = pooling_size(height, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + + had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2 + + *_, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + h_window_size = max( + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(ph, pw): + """ + This computes the scaling factor that we will divide an element + by when `count_include_pad=False` + """ + stride_h = ops.constant(stride[0], torch.int32) + stride_w = ops.constant(stride[1], torch.int32) + pad_h = ops.constant(padding[0], torch.int32) + pad_w = ops.constant(padding[1], torch.int32) + kernel_h = ops.constant(kernel_size[0], torch.int32) + kernel_w = ops.constant(kernel_size[1], torch.int32) + hstart = ops.sub(ops.mul(ph, stride_h), pad_h) + wstart = ops.sub(ops.mul(pw, stride_w), pad_w) + hend = ops.minimum( + ops.add(hstart, kernel_h), + ops.add(ops.index_expr(height, torch.int32), pad_h), + ) + wend = ops.minimum( + ops.add(wstart, kernel_w), + ops.add(ops.index_expr(width, torch.int32), pad_w), + ) + hstart = ops.maximum(hstart, ops.constant(0, torch.int32)) + wstart = ops.maximum(wstart, ops.constant(0, torch.int32)) + hend = ops.minimum(hend, ops.index_expr(height, torch.int32)) + wend = ops.minimum(wend, ops.index_expr(width, torch.int32)) + divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart)) + return divide_factor + + def fn(idx): + *prefix, h, w = idx + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] + else: + scale = compute_pool_size_without_padding(ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where(mask, part, ops.constant(0.0, torch.float32)) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +fallback_avg_pool3d_backward = fallback_handler( + aten.avg_pool3d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None) +def avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 3 + assert len(stride) == 3 + assert len(padding) == 3 + assert len(x.get_size()) in (4, 5) + + grad_output.realize_hint() + + *batch, depth, height, width = x.get_size() + + d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode) + h_out, ceil_mode_h = pooling_size( + height, 1, kernel_size, stride, padding, ceil_mode + ) + w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w + + *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + d_window_size, h_window_size, w_window_size = ( + max( + max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1) + for d in range(kernel_size[i] * 2) + ) + for i in range(3) + ) + + window_size = d_window_size * h_window_size * w_window_size + if window_size > 125: + # Kernel size too big. Results in hard-to-optimize Triton code. + return fallback_avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(pd, ph, pw): + stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride) + pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding) + kernel_d, kernel_h, kernel_w = ( + ops.constant(k, torch.int32) for k in kernel_size + ) + + dstart, hstart, wstart = ( + ops.sub(ops.mul(p, s), pad) + for p, s, pad in zip( + [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w] + ) + ) + dend, hend, wend = ( + ops.minimum( + ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad) + ) + for start, k, dim, pad in zip( + [dstart, hstart, wstart], + [kernel_d, kernel_h, kernel_w], + [depth, height, width], + [pad_d, pad_h, pad_w], + ) + ) + dstart, hstart, wstart = ( + ops.maximum(start, ops.constant(0, torch.int32)) + for start in [dstart, hstart, wstart] + ) + dend, hend, wend = ( + ops.minimum(end, ops.index_expr(dim, torch.int32)) + for end, dim in zip([dend, hend, wend], [depth, height, width]) + ) + divide_factor = ops.mul( + ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart) + ) + return divide_factor + + def fn(idx): + *prefix, d, h, w = idx + d, h, w = (v + pad for v, pad in zip([d, h, w], padding)) + + pdstart, phstart, pwstart = ( + ops.index_expr(FloorDiv(v - k + s, s), torch.int32) + for v, k, s in zip([d, h, w], kernel_size, stride) + ) + + pdend, phend, pwend = ( + ops.index_expr(FloorDiv(v, s) + 1, torch.int32) + for v, s in zip([d, h, w], stride) + ) + + pdstart, phstart, pwstart = ( + ops.maximum(pstart, ops.constant(0, torch.int32)) + for pstart in [pdstart, phstart, pwstart] + ) + pdend, phend, pwend = ( + ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32)) + for pend, pooled_dim in zip( + [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width] + ) + ) + + gradient = None + # Iterate over the 3D region to accumulate gradients + for pd_ in range(d_window_size): + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + pd, ph, pw = ( + ops.add(pstart, ops.constant(p_, torch.int32)) + for pstart, p_ in zip( + [pdstart, phstart, pwstart], [pd_, ph_, pw_] + ) + ) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] * kernel_size[2] + else: + scale = compute_pool_size_without_padding(pd, ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + pd, ops.sub(pdend, ops.constant(1, torch.int32)) + ), + pooled_depth, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where( + mask, part, ops.constant(0.0, torch.float32) + ) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +def _validate_reduction_axis(x, axis): + size = x.get_size() + if isinstance(axis, int): + axis = [axis] + elif not axis: + axis = range(len(size)) + if len(size) == 0: + assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}" + return [] + axis = list(axis) + for i in range(len(axis)): + if axis[i] < 0: + axis[i] += len(size) if len(size) else 1 + assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0) + assert len(set(axis)) == len(axis), "reduction axis not unique" + return axis + + +def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = set(_validate_reduction_axis(x, axis)) + + kept_sizes = [] + kept_idx = [] + reduced_sizes = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + reduced_sizes.append(size[i]) + else: + kept_idx.append(i) + kept_sizes.append(size[i]) + + def loader(index, reduction_index): + assert len(reduction_index) == len(reduced_idx) + if keepdims: + assert len(index) == len(size) + index = [index[i] for i in kept_idx] + assert len(index) == len(kept_idx) + new_index = [None] * (len(index) + len(reduction_index)) + for idx, var in itertools.chain( + zip(kept_idx, index), zip(reduced_idx, reduction_index) + ): + new_index[idx] = var + return inner_loader(new_index) + + if keepdims: + new_size = list(size) + for i in reduced_idx: + new_size[i] = sympy.Integer(1) + else: + new_size = kept_sizes + + inner_loader = x.make_loader() + return dict( + device=x.get_device(), + dst_dtype=override_return_dtype or x.get_dtype(), + src_dtype=x.get_dtype(), + inner_fn=loader, + ranges=new_size, + reduction_ranges=reduced_sizes, + ) + + +def make_reduction(reduction_type: str, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + ) + result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) + if isinstance( + result.data.data, Reduction + ): # Only realize if reduction isn't unrolled + result.realize() + return result + + return inner + + +def _make_scan_inner(x, *, axis, dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_dim(x, axis) + + return dict( + device=x.get_device(), + dtypes=(x.get_dtype(),), + inner_fns=(x.make_loader(),), + size=x.get_size(), + axis=axis, + ) + + +@register_lowering(aten.mean) +def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + +def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + +def use_two_step_variance(x, axis, keepdim): + # Instead of unrolling welford, just unroll the simpler two-step var + axis = _validate_reduction_axis(x, axis) + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + + ranges = kwargs["ranges"] + reduction_numel = sympy_product(kwargs["reduction_ranges"]) + return ( + isinstance(reduction_numel, sympy.Integer) + and int(reduction_numel) < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ) + + +def var_mean_welford_(x, axis, *, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + loader = kwargs.pop("inner_fn") + kwargs.pop("dst_dtype") + kwargs.pop("src_dtype") + + mean, m2, _ = ir.WelfordReduction.create( + inner_fns=(loader,), + reduction_type="welford_reduce", + dtype=x.get_dtype(), + **kwargs, + ) + m2.realize() + + dtype = x.get_dtype() + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + rnumel = sympy_product(size[i] for i in axis) + + def get_constant_or_index_expr(x, dtype): + if isinstance(x, sympy.Expr) and not x.is_number: + return ops.to_dtype(ops.index_expr(x, torch.int64), dtype) + return ops.constant(x, dtype) + + def scale_fn(data): + c = get_constant_or_index_expr(correction, dtype) + N = get_constant_or_index_expr(rnumel, dtype) + zero = ops.constant(0, dtype) + return data / ops.maximum(zero, N - c) + + var = make_pointwise(scale_fn)(m2) + + if return_mean: + mean.realize() + return var, mean + return (var,) + + +def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + if use_two_step_variance(x, axis=axis, keepdim=keepdim) + else var_mean_welford_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + +@register_lowering([aten.var, prims.var]) +def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + +@register_lowering(aten.var_mean) +def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + +def pow_recursive(x, y, dtype): + if y < 0: + return pow_recursive(ops.reciprocal(x), -y, dtype) + if y == 0: + return ops.constant(1, dtype) + if y == 1: + return x + + result = pow_recursive(x, y // 2, dtype) + result = ops.mul(result, result) + if (y % 2) == 1: + result = ops.mul(result, x) + return result + + +@make_pointwise +def pow_native(a, b): + return ops.pow(a, b) + + +fallback_pow_tensor_tensor = fallback_handler( + aten.pow.Tensor_Tensor, add_to_fallback_set=False +) +fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False) +fallback_pow_tensor_scalar = fallback_handler( + aten.pow.Tensor_Scalar, add_to_fallback_set=False +) + + +@register_lowering(aten.pow, broadcast=True) +def pow(a, b): + if isinstance(b, float) and b == int(b): + return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) + elif isinstance(b, int) and b == 1: + return clone(a) + + # Type promotion ensures all tensor arguments have the same type + dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) + is_integer_pow = is_integer_dtype(dtype) + + # Optimize away small fixed powers, or for integers avoid falling back to ATen + embed_exponent = isinstance(b, int) and ( + -32 < b < 32 or (is_integer_pow and b >= 0) + ) + if embed_exponent: + loader = a.make_loader() + + def fn(idx): + return pow_recursive(loader(idx), b, a.get_dtype()) + + return Pointwise.create( + device=a.get_device(), + dtype=a.get_dtype(), + inner_fn=fn, + ranges=a.get_size(), + ) + + if isinstance(a, Number): + if a == 1: + return full_like(b, 1) + if a == 2 and is_float_dtype(b.get_dtype()): + return exp2(b) + + if is_integer_pow: + # ops.pow doesn't work for integers + if isinstance(a, Number): + return fallback_pow_scalar(a, b) + elif isinstance(b, Number): + return fallback_pow_tensor_scalar(a, b) + else: + return fallback_pow_tensor_tensor(a, b) + + return pow_native(a, b) + + +def mutate_to(changed, val, unsafe_alias=False): + if isinstance(changed, TensorBox): + changed_data = changed.data + else: + changed_data = changed + if isinstance(val, TensorBox): + val = val.data + + if not isinstance(val, ir.StorageBox): + # introduce a copy to handle views + val = Pointwise.create( + device=changed.get_device(), + dtype=changed.get_dtype(), + inner_fn=val.make_loader(), + ranges=changed.get_size(), + ).data + assert isinstance(val, ir.StorageBox) + + if isinstance(changed_data, ir.StorageBox) and not ( + changed_data.is_input_buffer() + # In AOTI, module parameters and buffers are not lifted as graph inputs + or changed_data.is_module_buffer() + or isinstance(changed_data.data, ir.NopKernel) + ): + # Fast path, just swing the data pointer + val.realize() + changed_data.data = val.data + return changed + + ir.MutationLayoutSHOULDREMOVE.realize_into( + val, changed_data, unsafe_alias=unsafe_alias + ) + return changed + + +@register_lowering(aten.fill_) +def fill_(x, fill_value): + return mutate_to(x, full_like(x, fill_value)) + + +@register_lowering(aten.copy_, type_promotion_kind=None) +def copy_(dst, src, non_blocking=False): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@make_pointwise +def floordiv(a, b): + return ops.floordiv(a, b) + + +@make_pointwise +def truncdiv(a, b): + return ops.truncdiv(a, b) + + +@register_lowering(aten.div, broadcast=True) +def div_mode(a, b, rounding_mode=None): + both_integer = is_integer_type(a) and is_integer_type(b) + both_boolean = is_boolean_type(a) and is_boolean_type(b) + + # floordiv and truncdiv need special handling for integer tensors on Triton, + # see the discussion at https://github.com/openai/triton/issues/605 + if rounding_mode == "floor": + assert not both_boolean, "floordiv operands can not be boolean at the same time" + return floordiv(a, b) if both_integer else floor(div(a, b)) + if rounding_mode == "trunc": + assert not both_boolean, "truncdiv operands can not be boolean at the same time" + return truncdiv(a, b) if both_integer else trunc(div(a, b)) + return div(a, b) + + +@register_lowering([aten.mul], broadcast=True) +def mul(a, b): + both_bool = is_boolean_type(a) and is_boolean_type(b) + if both_bool: + return logical_and(a, b) + else: + fn = ops_wrapper(aten.mul.__name__) + return make_pointwise(fn)(a, b) + + +def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]: + """Try convert an arbitrary IR node into an ir.Constant value""" + + # First try unwrapping the IRNode to see if it is already an ir.Constant + # Optional step, but avoids unnecessary inner_fn evaluation. + if isinstance(x, ir.MutableBox): + return get_constant_value(x.data) + if isinstance(x, ir.BaseView): + return get_constant_value(x.unwrap_view()) + if isinstance(x, ir.Constant): + return x + + # If the unwrapped node is not an ir.Constant, try evaluating inner_fn + # to see if the returned value is from an `ops.constant` call + if not isinstance(x, ir.Loops): + return None + + handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device()) + with V.set_ops_handler(handler), patch.object( + ir.FlexibleLayout, "allow_indexing", True + ): + out = x.inner_fn(*x.inner_fn_args()) + + assert isinstance(out, torch._inductor.virtualized.OpsValue) + if isinstance(out.value, ir.Constant): + return out.value + return None + + +# NOTE: prims.div maps to a / b in C, so performs truncation division on +# integer inputs and true division for floating and complex inputs. +@register_lowering([prims.div], broadcast=True) +def div_prim(a, b): + is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b]) + + if is_integral: + return truncdiv(a, b) + + if (divisor := get_constant_value(b)) is not None: + # Replace divide by constant with multiply by reciprocal + if divisor.value == 0: + reciprocal = math.copysign(float("inf"), divisor.value) + else: + reciprocal = 1.0 / divisor.value + return mul(a, reciprocal) + + def fn(*args): + return ops.truediv(*args) + + return make_pointwise(fn)(a, b) + + +@register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def div(a, b): + a, b = promote_constants( + (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return div_prim(a, b) + + +@register_lowering([aten.fmod, prims.fmod], broadcast=True) +def fmod(a, b): + is_integral = is_boolean_type(a) or is_integer_type(a) + + if is_integral: + + def fn(a, b): + return ops.mod(a, b) + + else: + + def fn(a, b): + return ops.fmod(a, b) + + return make_pointwise(fn)(a, b) + + +@register_lowering(aten.rsqrt) +def rsqrt(x): + dtype = x.get_dtype() + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + x = to_dtype(x, torch.get_default_dtype()) + + def _rsqrt(x): + return ops.rsqrt(x) + + return make_pointwise(_rsqrt)(x) + + +@register_lowering([aten.sum, prims.sum]) +def sum_(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("sum", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +fallback_cumsum = fallback_handler(aten.cumsum.default) +fallback_cumprod = fallback_handler(aten.cumprod.default) +fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default) +fallback_cummax = fallback_handler(aten.cummax.default) +fallback_cummin = fallback_handler(aten.cummin.default) + + +@register_lowering(aten.cumsum) +def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + def combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.add(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumsum(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.cumprod) +def cumprod(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + def combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.mul(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumprod(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.logcumsumexp) +def logcumsumexp(x, dim): + def log_add_exp_helper(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + min_v = ops.minimum(a, b) + max_v = ops.maximum(a, b) + mask = (min_v != max_v) | (~ops.isinf(min_v)) + return (ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a),) + + dtype = x.get_dtype() + if len(x.get_size()) == 0: + assert dim in [0, -1] + return clone(x) + + kwargs = _make_scan_inner(x, axis=dim, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper) + if result is None: + return fallback_logcumsumexp(x, dim=dim) + return result + + +@register_lowering(aten.cummax, type_promotion_kind=None) +def cummax(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmax", dtype=dtype, arg_break_ties_left=False + ) + + min_value = ( + False + if dtype is torch.bool + else ( + torch.finfo(dtype).min + if dtype.is_floating_point + else torch.iinfo(dtype).min + ) + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex") + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] # next PR + if values is None: + return fallback_cummax(x, dim=axis) + return values, indices + + +@register_lowering(aten.cummin, type_promotion_kind=None) +def cummin(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmin", dtype=dtype, arg_break_ties_left=False + ) + + max_value = ( + True + if dtype is torch.bool + else ( + torch.finfo(dtype).max + if dtype.is_floating_point + else torch.iinfo(dtype).max + ) + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex") + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] # next PR + if values is None: + return fallback_cummin(x, dim=axis) + return values, indices + + +@register_lowering(aten.prod) +def prod(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("prod", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +@register_lowering(aten.any) +def reduce_any(x, dim=None, keepdim=False): + x = to_dtype(x, torch.bool) + return make_reduction("any")(x, axis=dim, keepdims=keepdim) + + +@register_lowering(aten.max, type_promotion_kind=None) +def reduce_max(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + +@register_lowering(aten.min, type_promotion_kind=None) +def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amin(x, axis=None, keepdims=keepdim) + + +register_lowering(prims.xor_sum)(make_reduction("xor_sum")) +reduce_amax = register_lowering(aten.amax)(make_reduction("max")) +reduce_amin = register_lowering(aten.amin)(make_reduction("min")) +reduce_argmax = register_lowering(aten.argmax)( + make_reduction("argmax", override_return_dtype=torch.int64) +) +reduce_argmin = register_lowering(aten.argmin)( + make_reduction("argmin", override_return_dtype=torch.int64) +) + +add = register_pointwise( + aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or" +) + +sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False) + + +@register_lowering(aten.sort.stable, type_promotion_kind=None) +def sort_stable(x, *, stable=None, dim=-1, descending=False): + if stable is None: + stable = False + + shape = x.get_size() + device = x.get_device() + dim = canonicalize_dim(len(shape), dim) + if len(shape) == 0: + return clone(x), _full(0, device, torch.int64, shape) + + dim_size = shape[dim] if len(shape) else 1 + if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max): + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + indices = iota( + dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False + ) + view_shape = [1] * len(shape) + if len(shape): + view_shape[dim] = dim_size + indices = view(indices, view_shape) + indices = expand(indices, shape) + + values, indices = ir.Sort.create( + device=device, + dtypes=(x.dtype, indices.dtype), + inner_fns=(x.make_loader(), indices.make_loader()), + size=shape, + axis=dim, + stable=stable, + descending=descending, + ) + if values is None: + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + assert indices is not None + return values, to_dtype(indices, torch.int64) + + +@register_lowering(aten.sort.default, type_promotion_kind=None) +def sort(x, dim=-1, descending=False): + return sort_stable(x, stable=False, dim=dim, descending=descending) + + +def register_pointwise_numeric(op, name=None, triton_fallback=None): + return register_pointwise( + op, + name=name, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + triton_fallback=triton_fallback, + ) + + +def register_pointwise_numeric_ldf64(op): + return register_pointwise( + op, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + use_libdevice_for_f64=True, + ) + + +exp = register_pointwise_numeric_ldf64(aten.exp) +exp2 = register_pointwise_numeric(aten.exp2) +expm1 = register_pointwise_numeric(aten.expm1) +relu = register_pointwise(aten.relu) +sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid) +sqrt = register_pointwise_numeric_ldf64(aten.sqrt) +square = register_pointwise(aten.square) +sub = register_pointwise(aten.sub, allow_alpha=True) +register_pointwise_numeric_ldf64(aten.cos) +register_pointwise_numeric_ldf64(aten.sin) +abs = register_pointwise(aten.abs) +bitwise_and = register_pointwise(aten.bitwise_and) +bitwise_left_shift = register_pointwise(aten.bitwise_left_shift) +bitwise_not = register_pointwise( + aten.bitwise_not, override_fn_when_input_bool="logical_not" +) +bitwise_or = register_pointwise(aten.bitwise_or) +bitwise_right_shift = register_pointwise(aten.bitwise_right_shift) +bitwise_xor = register_pointwise(aten.bitwise_xor) +register_pointwise_numeric(aten.lgamma) +erf = register_pointwise_numeric(aten.erf) +register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +)(erf) + +register_pointwise_numeric(aten.log1p) +register_pointwise_numeric(aten.tan) +register_pointwise_numeric(aten.tanh) +register_pointwise_numeric_ldf64(aten.log) +logical_and = register_pointwise( + aten.logical_and, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_not = register_pointwise( + aten.logical_not, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_or = register_pointwise( + aten.logical_or, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_xor = register_pointwise( + aten.logical_xor, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +maximum = register_pointwise(aten.maximum) +minimum = register_pointwise(aten.minimum) +register_lowering(aten.clamp_min)(maximum) +register_lowering(aten.clamp_max)(minimum) +neg = register_pointwise(aten.neg) +abs = register_pointwise(aten.abs) +reciprocal = register_pointwise_numeric(aten.reciprocal) +register_pointwise(aten.remainder) +sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity") +register_pointwise(aten.ceil) +register_pointwise(aten.signbit, override_return_dtype=torch.bool) + +register_lowering(aten._neg_view)(neg) + +register_pointwise(aten.le, override_return_dtype=torch.bool) +register_pointwise(aten.lt, override_return_dtype=torch.bool) +register_pointwise(aten.ge, override_return_dtype=torch.bool) +gt = register_pointwise(aten.gt, override_return_dtype=torch.bool) +register_pointwise(aten.eq, override_return_dtype=torch.bool) +register_pointwise(aten.ne, override_return_dtype=torch.bool) + +register_pointwise_numeric(aten.cosh) +register_pointwise_numeric(aten.sinh) +register_pointwise_numeric(aten.acos) +register_pointwise_numeric(aten.acosh) +register_pointwise_numeric(aten.asin) +register_pointwise_numeric(aten.asinh) +register_pointwise_numeric(aten.atan2) +register_pointwise_numeric(aten.atan) +register_pointwise_numeric(aten.atanh) +register_pointwise_numeric(aten.copysign) +register_pointwise_numeric(aten.erfc) +register_pointwise_numeric(aten.erfinv) +register_pointwise_numeric(aten.hypot) +register_pointwise_numeric(aten.log10) +register_pointwise_numeric(aten.log2) +register_pointwise_numeric(aten.nextafter) + +from .codegen.common import BackendFeature, pointwise_overrides_data + + +def _get_pointwise_overrides(ns, name): + data = pointwise_overrides_data[name] + op = getattr(ns, data.name, None) + if op is None: + return + + def make_triton_fallback(op): + if data.triton is None: + return fallback_handler(op) + + if isinstance(op, torch._ops.OpOverloadPacket): + for olname in op.overloads(): + ol = getattr(op, olname) + yield ol, data.type_promotion_kind, make_triton_fallback(ol) + else: + yield op, data.type_promotion_kind, make_triton_fallback(op) + + +for name in pointwise_overrides_data: + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + aten, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + prims, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + +foreach_add_list = register_foreach_pointwise( + aten._foreach_add.List, add, allow_alpha=True +) +foreach_add_scalar = register_foreach_pointwise( + aten._foreach_add.Scalar, add, allow_alpha=True +) +register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True) +foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul) +foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul) +register_foreach_pointwise(aten._foreach_sub.List, sub) +register_foreach_pointwise(aten._foreach_sub.Scalar, sub) +register_foreach_pointwise(aten._foreach_neg.default, neg) +register_foreach_pointwise(aten._foreach_abs.default, abs) +register_foreach_pointwise(aten._foreach_pow.Scalar, pow) +register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow) +foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div) +foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div) +register_foreach_pointwise(aten._foreach_sqrt, sqrt) +register_foreach_pointwise(aten._foreach_maximum.List, maximum) +register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum) +register_foreach_pointwise(aten._foreach_minimum.List, minimum) +register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum) +register_foreach_pointwise(aten._foreach_clamp_min.List, maximum) +register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum) +register_foreach_pointwise(aten._foreach_clamp_max.List, minimum) +register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum) +register_foreach_pointwise(aten._foreach_reciprocal, reciprocal) +register_foreach_pointwise(aten._foreach_sign, sign) +register_foreach_pointwise(aten._foreach_copy, copy) + + +# these are only encountered as outputs of the graph +# reinplacing epilogue copies improves compile time +# by removing extra buffers sent to the scheduler. +def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op): + inplaceable_foreach_ops[outplace_aten_op] = aten_op + inplace_foreach_ops.add(aten_op) + + def fn(*args, **kwargs): + results = outplace_op(*args, **kwargs) + mut_results = [] + for arg, result in zip(args[0], results): + mut_results.append(mutate_to(arg, result, unsafe_alias=True)) + + return mut_results + + _register_foreach_lowering(aten_op, fn) + + +register_foreach_inplace( + aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list +) +register_foreach_inplace( + aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar +) +register_foreach_inplace( + aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list +) +register_foreach_inplace( + aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar +) +register_foreach_inplace( + aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list +) +register_foreach_inplace( + aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar +) + + +def register_inplace(aten_op, outplace_op): + @register_lowering(aten_op, type_promotion_kind=None) + def fn(*args, **kwargs): + result = outplace_op(*args, **kwargs) + result = to_dtype(result, args[0].get_dtype()) + return mutate_to(args[0], result) + + return fn + + +register_inplace(aten.add_, add) +register_inplace(aten.bitwise_and_, bitwise_and) +register_inplace(aten.bitwise_left_shift_, bitwise_left_shift) +register_inplace(aten.bitwise_not_, bitwise_not) +register_inplace(aten.bitwise_or_, bitwise_or) +register_inplace(aten.bitwise_right_shift_, bitwise_right_shift) +register_inplace(aten.bitwise_xor_, bitwise_xor) +register_inplace(aten.mul_, mul) +register_inplace(aten.div_.Tensor, div) +register_inplace(aten.div_.Tensor_mode, div_mode) +register_inplace(aten.logical_and_, logical_and) +register_inplace(aten.logical_not_, logical_not) +register_inplace(aten.logical_or_, logical_or) +register_inplace(aten.logical_xor_, logical_xor) +register_inplace(aten.sub_, sub) +register_inplace(aten.relu_, relu) +register_inplace(aten.sigmoid_, sigmoid) + + +register_lowering(aten.__and__)(bitwise_and) +register_lowering(aten.__lshift__)(bitwise_left_shift) +register_lowering(aten.__or__)(bitwise_or) +register_lowering(aten.__rshift__)(bitwise_right_shift) +register_lowering(aten.__xor__)(bitwise_xor) + +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) + + +@register_lowering(aten.sym_constrain_range) +def sym_constrain_range(a, min=None, max=None): + return None + + +@register_lowering(aten.sym_size.int) +def sym_size(a, dim): + val = V.graph.current_node.meta["val"] + # Note [Can val be an int?] + # ~~~~~~~~~~~~~~~~~~~~~~~~~ + # In principle, someone could construct an FX graph where + # a call to size/stride has a val that is a plain int (not + # SymInt). However, we will maintain the invariant that + # this is not possible: if you are constructing an FX graph + # where there is a call to size/stride that returns an + # int, but you KNOW that int must always be a constant, + # then you do not need trace that call at all (and just + # constant propagate the integer as is.) + assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_stride.int) +def sym_stride(a, dim): + val = V.graph.current_node.meta["val"] + # See Note [Can val be an int?] + assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_numel) +def sym_numel(a): + return a.get_numel() + + +for method, func in magic_methods.items(): + register_lowering(method_to_operator(method))(func) + + +@register_lowering(aten._foobar) +def foobar(self, *args, **kwargs): + raise NotImplementedError("Helpful for debugging") + + +@register_lowering(torch.ops._inductor_test.realize) +def _realize(x): + x.realize() + return clone(x) + + +@register_lowering(torch.ops.inductor.resize_storage_bytes_) +def resize_storage_bytes_(variable, new_size): + variable.realize() + ir.ResizeStorageBytes(variable, new_size) + return variable + + +@register_lowering(torch.ops.aten.set_.source_Tensor) +def set__source_tensor(self, source_tensor): + self.realize() + source_tensor.realize() + return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor)) + + +if hasattr(torch.ops.fsdp, "set_"): + + @register_lowering(torch.ops.fsdp.set_.default) + def fsdp_set_(self, source_tensor): + self.realize() + source_tensor.realize() + ir.SetSourceTensorKernel(self, source_tensor) + + +@register_lowering(torch.ops.aten.resize) +def resize(x, size, *, memory_format=None): + assert isinstance(x, TensorBox) + assert isinstance(size, (list, tuple)) + + if memory_format is None: + memory_format = torch.contiguous_format + if memory_format == torch.preserve_format: + raise RuntimeError(f"unsupported memory format: {memory_format}") + + if memory_format == torch.channels_last: + assert len(size) == 4 + if memory_format == torch.channels_last_3d: + assert len(size) == 5 + + old_numel = x.get_numel() + dtype = x.get_dtype() + device = x.get_device() + + if isinstance(x.data, ir.BaseView): + x.data = x.data.unwrap_view() + + if ( + torch.are_deterministic_algorithms_enabled() + and torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] + ): + if is_float_dtype(dtype): + uninitalized_val = float("nan") + elif is_integer_dtype(dtype): + uninitalized_val = torch.iinfo(dtype).max + else: + uninitalized_val = True + else: + # using zero as that is what empty does + uninitalized_val = 0.0 + + if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type] + return full(size, uninitalized_val, dtype=dtype, device=device) + + x_flat = as_strided( + x, + [ + old_numel, + ], + [ + 1, + ], + ) + flat_loader = x_flat.make_loader() + out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format) + out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer() + + def inner_fn(idx): + flat_index = out_indexer(idx) + flat_index_expr = ops.index_expr(flat_index, torch.int64) + limit = ops.index_expr(old_numel, torch.int64) + mask = ops.lt(flat_index_expr, limit) + return ops.masked(mask, lambda: flat_loader([flat_index]), uninitalized_val) + + out = Pointwise.create( + device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size) + ) + return out + + +from torch._higher_order_ops.auto_functionalize import auto_functionalized + + +make_fallback(auto_functionalized) + + +@register_lowering(triton_kernel_wrapper_mutation) +def triton_kernel_wrap_(*, kernel_idx, constant_args_idx, grid, kwargs): + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + ir.UserDefinedTritonKernel( + kernel_idx=kernel_idx, + grid=grid, + kernel_args={**kwargs, **constant_args}, + ) + return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} + + +@register_lowering(torch.ops.higher_order.cond) +def cond(pred, true_fn, false_fn, operands): + if is_triton(pred) or any(map(is_triton, operands)): + msg = "control flow operator: torch.cond." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.Conditional.create(pred, true_fn, false_fn, operands) + return list(map(TensorBox.create, result)) + + +@register_lowering(torch.ops.higher_order.while_loop) +def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): + if any(map(is_triton, carried_inputs + additional_inputs)): + msg = "control flow operator: torch.while_loop." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs) + return list(map(TensorBox.create, result)) + + +@register_lowering(associative_scan_op, type_promotion_kind=None) +def associative_scan(combine_fn: ir.Subgraph, input, dim: int): + from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph + + subgraph_inputs = [ + InputDescriptor(dtype=x.get_dtype(), device=x.get_device()) + for x in itertools.chain(input, input) + ] + lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs) # type: ignore[var-annotated] + + def wrapped_combine_fn(lhs, rhs): + return lowered_combine_fn( + *pytree.tree_leaves(lhs), + *pytree.tree_leaves(rhs), + ) + + kwargs = _make_scan_inner(input[0], axis=dim, dtype=None) + kwargs["dtypes"] = tuple(x.get_dtype() for x in input) + kwargs["inner_fns"] = tuple(x.make_loader() for x in input) + result = ir.Scan.create( + combine_fn=wrapped_combine_fn, + can_fallback_to_aten=False, + **kwargs, + ) + if result[0] is None: + raise RuntimeError("Unable to generate code for associative_scan op") + return result + + +@register_lowering(torch.ops.prims._sink_tokens.default) +def _sink_tokens(tokens): + return None + + +@register_lowering(torch.ops.higher_order.with_effects) +def with_effects(token, op, *args, **kwargs): + result = ir.EffectfulKernel.create(op, *args, **kwargs) + + from torch._higher_order_ops.effects import get_effect_key + + effect_type = get_effect_key(op, args, kwargs) + assert effect_type is not None + effectful_kernel = V.graph.effectful_ops[effect_type] + + if result is None: + return (effectful_kernel,) + + result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result) + if not isinstance(result, (list, tuple)): + return (effectful_kernel, result) + else: + return (effectful_kernel, *result) + + +try: + import torch.distributed._functional_collectives + + _c10d_functional = torch.ops._c10d_functional + + @register_lowering(_c10d_functional.all_reduce) + def _all_reduce(inp, reduce_op, group_name): + inp = clone(inp) + if config.reorder_for_compute_comm_overlap: + # The horizontal fusion of this clone often severely delays the + # scheduling of the all_reduce_ node. Horizontally fusing this + # clone can almost never out-perform scheduling the all_reduce_ + # earlier. Also in most cases, this clone is eliminated via + # in-place reuse. Therefore, we tell the scheduler to not fuse it. + inp.realize() + V.graph.no_fuse_buffer_names.add(inp.get_name()) + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_lowering(_c10d_functional.all_reduce_) + def _all_reduce_(inp, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_lowering(_c10d_functional.all_reduce_coalesced) + def _all_reduce_coalesced(inputs, reduce_op, group_name): + inputs = [clone(inp) for inp in inputs] + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + @register_lowering(_c10d_functional.all_reduce_coalesced_) + def _all_reduce_coalesced_(inputs, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + @register_lowering(_c10d_functional.all_gather_into_tensor) + def _all_gather_into_tensor(inp, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_gather_into_tensor.default, + inp, + group_size, + group_name, + ) + ) + + @register_lowering(_c10d_functional.all_gather_into_tensor_coalesced) + def _all_gather_into_tensor_coalesced(inputs, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_gather_into_tensor_coalesced.default, + inputs, + group_size, + group_name, + ), + ) + + @register_lowering(_c10d_functional.all_gather_into_tensor_out) + def _all_gather_into_tensor_out(inp, group_size, group_name, *, out): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_gather_into_tensor_out.default, + inp, + group_size, + group_name, + out=out, + ) + return out + + @register_lowering(_c10d_functional.reduce_scatter_tensor) + def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.reduce_scatter_tensor.default, + inp, + reduce_op, + group_size, + group_name, + ) + ) + + @register_lowering(_c10d_functional.reduce_scatter_tensor_coalesced) + def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.reduce_scatter_tensor_coalesced.default, + inputs, + reduce_op, + group_size, + group_name, + ), + ) + + @register_lowering(_c10d_functional.all_to_all_single) + def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_to_all_single.default, + inp, + output_split_sizes, + input_split_sizes, + group_name, + ) + ) + + @register_lowering(_c10d_functional.broadcast) + def _broadcast(inp, src, group_name): + inp = clone(inp) + ir._CollectiveKernel.create_inplace( + _c10d_functional.broadcast_.default, inp, src, group_name + ) + return inp + + @register_lowering(_c10d_functional.broadcast_) + def _broadcast_(inp, src, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.broadcast_.default, inp, src, group_name + ) + return inp + + @register_lowering(_c10d_functional.wait_tensor) + def _wait_tensor(inp): + ir._WaitKernel.create_wait(_c10d_functional.wait_tensor.default, inp) + return inp + + @register_lowering(torch.ops._dtensor.shard_dim_alltoall) + def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + torch.ops._dtensor.shard_dim_alltoall.default, + inp, + gather_dim, + shard_dim, + group_name, + ) + ) + +except (AttributeError, ImportError): + log.info( + "Inductor support for distributed collectives depends on building torch.distributed" + ) + +# populate lowerings defined in kernel/* +from . import kernel + + +import_submodule(kernel) + +from . import quantized_lowerings + + +quantized_lowerings.register_quantized_ops() +quantized_lowerings.register_woq_mm_ops() + +from . import mkldnn_lowerings + + +mkldnn_lowerings.register_onednn_fusion_ops() + +from . import jagged_lowerings + + +jagged_lowerings.register_jagged_ops() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/metrics.py b/.venv/lib/python3.11/site-packages/torch/_inductor/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..fe77279800e3da5ae581cb684c340726b9b3827d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/metrics.py @@ -0,0 +1,436 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import csv +import dataclasses +import inspect +import os +import re +from dataclasses import dataclass +from functools import lru_cache +from typing import Dict, List, Set, Tuple, TYPE_CHECKING + +from torch._inductor import config +from torch._inductor.utils import get_benchmark_name + + +# Prevent circular import +if TYPE_CHECKING: + from torch._inductor.scheduler import BaseSchedulerNode + +# counter for tracking how many kernels have been generated +generated_kernel_count = 0 +generated_cpp_vec_kernel_count = 0 +num_bytes_accessed = 0 +nodes_num_elem: List[ + Tuple[ + BaseSchedulerNode, + int, + ] +] = [] +node_runtimes: List[Tuple[BaseSchedulerNode, float]] = [] + +# counters for tracking fusions +ir_nodes_pre_fusion = 0 + +# counters for tracking to_dtype inserted +cpp_to_dtype_count = 0 + + +@dataclasses.dataclass +class CppOuterLoopFusedCount: + inner_kernel_number: int + local_buffer_number: int = 0 + + +# The length counts the number of outer loop fusions. +cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = [] + +num_comprehensive_padding = 0 +num_matches_for_scatter_upon_const_tensor = 0 + +num_loop_reordering = 0 + + +# reset all counters +def reset(): + global generated_kernel_count + global generated_cpp_vec_kernel_count + global num_bytes_accessed, nodes_num_elem + global ir_nodes_pre_fusion + global cpp_to_dtype_count + global cpp_outer_loop_fused_inner_counts + global num_comprehensive_padding + global num_matches_for_scatter_upon_const_tensor + global num_loop_reordering + + generated_kernel_count = 0 + generated_cpp_vec_kernel_count = 0 + num_bytes_accessed = 0 + nodes_num_elem.clear() + node_runtimes.clear() + ir_nodes_pre_fusion = 0 + cpp_to_dtype_count = 0 + cpp_outer_loop_fused_inner_counts.clear() + num_comprehensive_padding = 0 + num_matches_for_scatter_upon_const_tensor = 0 + num_loop_reordering = 0 + + +@dataclass +class CachedMetricsDeltas: + """ + The subset of metrics we want update across cache hits, e.g., the + FxGraphCache. + """ + + generated_kernel_count: int + generated_cpp_vec_kernel_count: int + ir_nodes_pre_fusion: int + cpp_to_dtype_count: int + num_bytes_accessed: int + num_matches_for_scatter_upon_const_tensor: int + + +def get_metric_fields(): + return [field.name for field in dataclasses.fields(CachedMetricsDeltas)] + + +class CachedMetricsHelper: + """ + A helper class to help calculate and apply counter deltas for those + metrics we want to save with cache entries (e.g., FxGraphCache) and + apply on a cache hit. + """ + + def __init__(self) -> None: + self.cached_metrics = {} + for metric in get_metric_fields(): + self.cached_metrics[metric] = globals()[metric] + + def get_deltas(self) -> CachedMetricsDeltas: + delta_metrics = {} + for metric in get_metric_fields(): + delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric] + + return CachedMetricsDeltas(**delta_metrics) + + @staticmethod + def apply_deltas(delta: CachedMetricsDeltas): + for metric in get_metric_fields(): + globals()[metric] += getattr(delta, metric) + + +REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {} + + +@dataclass +class MetricTable: + table_name: str + column_names: List[str] + + num_rows_added: int = 0 + + def add_row(self, row_fn): + if self.table_name not in enabled_metric_tables(): + return + + row_dict = row_fn() + assert len(self.column_names) == len( + row_dict + ), f"{len(self.column_names)} v.s. {len(row_dict)}" + assert set(self.column_names) == set( + row_dict.keys() + ), f"{set(self.column_names)} v.s. {set(row_dict.keys())}" + + row = [ + get_benchmark_name(), + ] + row += [row_dict[column_name] for column_name in self.column_names] + self._write_row(row) + + def output_filename(self): + return f"metric_table_{self.table_name}.csv" + + def write_header(self): + filename = self.output_filename() + with open(filename, "w") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(["model_name"] + self.column_names) + + def _write_row(self, row): + filename = self.output_filename() + if self.num_rows_added == 0 and not os.path.exists(filename): + self.write_header() + + self.num_rows_added += 1 + + for idx, orig_val in enumerate(row): + if isinstance(orig_val, float): + new_val = f"{orig_val:.6f}" + elif orig_val is None: + new_val = "" + else: + new_val = orig_val + row[idx] = new_val + + with open(filename, "a") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(row) + + @staticmethod + def register_table(name, column_names): + table = MetricTable(name, column_names) + REGISTERED_METRIC_TABLES[name] = table + + +MetricTable.register_table( + "slow_fusion", + [ + "kernel1_path", + "kernel1_latency", + "kernel2_path", + "kernel2_latency", + "fused_kernel_path", + "fused_kernel_latency", + "slow_down_ratio", + ], +) + +# track the fusion statistics for each graph +MetricTable.register_table( + "graph_stats", + [ + "graph_id", + "num_nodes_before_fusion", + "num_nodes_after_fusion", + ], +) + +# track the perf difference between persistent reduction and non-persistent +# reductions +MetricTable.register_table( + "persistent_red_perf", + [ + "kernel1_name", + "kernel2_name", + "kernel1_latency", + "kernel2_latency", + "size_hints", + "reduction_hint", + "speedup", + ], +) + +# Log the fusion failures due to indexing mismatch +MetricTable.register_table( + "fusion_failure_due_to_indexing_mismatch", + [ + "pre_grad_graph_id", + "post_grad_graph_id", + "node1_name", + "node2_name", + "node1_debug_str", + "node2_debug_str", + "common_buffer_names", + "failure_reason", + ], +) + +# Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint +MetricTable.register_table( + "kernel_metadata", + [ + "kernel_name", + "kernel_path", + "kernel_category", # pointwise/reduction/foreach etc. + "size_hints", + "reduction_hint", + "line_of_code", + "num_load", + "num_store", + "num_for_loop", + "num_atomic_add", + "num_args", + # xyz numel can be different to size_hints since size_hints are rounded + # up to the nearest power of 2. + # Inductor kernel will burn in the xyz numel in kernel code for static + # shape kernels. + # Logging them will be helpful to find unaligned shape for reduction + "xnumel", + "ynumel", + "rnumel", + "kernel_args_num_gb", + ], +) + + +def _parse_kernel_fn_code(kernel_module_code): + """ + The kernel_module_code is the python module that contains kernel function code. + kernel function is the proper triton kernel function annotated with + @triton.jit + """ + from .codecache import PyCodeCache + from .wrapper_benchmark import get_triton_kernel + + mod = PyCodeCache.load(kernel_module_code) + kernel = get_triton_kernel(mod) + # kernel is a CachingAutotune; kernel.fn is the JITFunction; + # kernel.fn.fn is the function being decorate by triton.jit + return inspect.getsource(kernel.fn.fn) + + +def _parse_kernel_line_of_code(proper_kernel_fn_code): + """ + Return the line of code for the kernel excluding the decorators. + """ + return len(proper_kernel_fn_code.splitlines()) + + +def _parse_size_hints(kernel_module_code, kernel_category): + if kernel_category == "foreach": + # foreach kernel does not have size_hints + return None + m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code) + assert m, "size_hints missing!" + return m.group(1) + + +def _parse_reduction_hint(kernel_category, kernel_module_code): + if kernel_category not in ("reduction", "persistent_reduction"): + return None + m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code) + assert m, "reduction_hint not found in kernel source code!" + return m.group(1) + + +def _count_pattern(proper_kernel_fn_code, pattern): + return proper_kernel_fn_code.count(pattern) + + +def _count_args(proper_kernel_fn_code): + def_line = proper_kernel_fn_code.splitlines()[0] + assert def_line.startswith("def ") + start_idx = def_line.index("(") + end_idx = def_line.index("):") + decl_csv = def_line[start_idx + 1 : end_idx] + comps = decl_csv.split(",") + return len(comps) + + +def _parse_proper_kernel_fn_code(kernel_fn_code): + """ + Skip decorators. + """ + start_pos = kernel_fn_code.index("def ") + return kernel_fn_code[start_pos:] + + +def _parse_numel(proper_kernel_fn_code, numel_arg_name): + m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code) + if m: + return int(m.group(1)) + else: + return None + + +def _parse_kernel_args_num_gb(kernel_fn_code, kernel_category): + """ + inductor meta looks like: + inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0}, + """ + m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code) + if m: + return float(m.group(1)) + else: + """ + There are a few cases that kernel_num_gdb field can be missing: + 1. the field will be missing if config.benchmark_kernel and + config.profile_bandwidth are false + 2. even if config.benchmark_kernel or config.profile_bandwidth is true. + foreach kernel does not have kernel_num_gb field in the metadata + """ + return None + + +def log_kernel_metadata(kernel_name, kernel_path, kernel_module_code): + """ + An utility to log kernel metadata. We may parse metadata from kernel source code here. + + It's fine to parse the generated kernel code here since the logging is + disabled by default. It would hurt compilation time. + """ + from .wrapper_benchmark import get_kernel_category_by_source_code + + kernel_category = get_kernel_category_by_source_code(kernel_module_code) + reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code) + size_hints = _parse_size_hints(kernel_module_code, kernel_category) + kernel_fn_code = _parse_kernel_fn_code(kernel_module_code) + + proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code) + + # the line of code excluding the decortors + kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code) + + get_metric_table("kernel_metadata").add_row( + lambda: { + "kernel_name": kernel_name, + "kernel_path": kernel_path, + "kernel_category": kernel_category, + "size_hints": size_hints, + "reduction_hint": reduction_hint, + "line_of_code": kernel_line_of_code, + "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"), + "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"), + "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "), + "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"), + "num_args": _count_args(proper_kernel_fn_code), + "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"), + "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"), + "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"), + "kernel_args_num_gb": _parse_kernel_args_num_gb( + kernel_fn_code, kernel_category + ), + } + ) + + +def purge_old_log_files(): + """ + Purge the old log file at the beginning when the benchmark script runs. + Should do it in the parent process rather than the child processes running + each individual model. + """ + for name, table in REGISTERED_METRIC_TABLES.items(): + if name in enabled_metric_tables(): + filename = table.output_filename() + if os.path.exists(filename): + os.unlink(filename) + + table.write_header() + + +@lru_cache +def enabled_metric_tables() -> Set[str]: + config_str = config.enabled_metric_tables + + enabled = set() + for name in config_str.split(","): + name = name.strip() + if not name: + continue + assert ( + name in REGISTERED_METRIC_TABLES + ), f"Metric table name {name} is not registered" + enabled.add(name) + return enabled + + +def is_metric_table_enabled(name): + return name in enabled_metric_tables() + + +def get_metric_table(name): + assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined" + return REGISTERED_METRIC_TABLES[name] diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_ir.py b/.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..12634c632b80a9e6c9dd52ea9b8ddf2b190b1943 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_ir.py @@ -0,0 +1,1881 @@ +# mypy: allow-untyped-defs +from typing import Any, List, Optional + +import sympy + +import torch +from torch._prims_common import make_channels_last_strides_for +from torch.utils._ordered_set import OrderedSet + +from .ir import ( + ExternKernelAlloc, + FixedLayout, + FlexibleLayout, + ir_node_to_tensor, + IRNode, + is_contiguous_storage_and_layout, + Layout, + may_convert_to_optional, + MultiOutput, + MultiOutputLayout, + MutationOutput, + NoneLayout, + TensorBox, +) +from .utils import convert_shape_to_inductor, pad_listlike +from .virtualized import V + + +def _prepare_convolution_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding: List[int], + stride: List[int], + dilation: List[int], + groups: int, + transposed: bool = False, + output_padding: Optional[List[int]] = None, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for convolution post-op fusion's create function, including deciding the output + layout (channels first or channels last), realizing inputs and make them etc. The + function only supports the CPU device since conv post-op fusion kernel is only + supported on CPU right now. + """ + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size + def _conv_input_size( + output_size, weight_size, padding, output_padding, stride, dilation, groups + ): + assert len(output_size) == len(weight_size), "Expect input dim == weight dim" + dim = len(output_size) + assert dim > 2, "Expect input dim > 2" + + BATCH_DIM = 0 + WEIGHT_INPUT_CHANNELS_DIM = 1 + input_size = [] + input_size.append(output_size[BATCH_DIM]) + input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) + for d in range(2, dim): + kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 + input_size_d = ( + (output_size[d] - 1) * stride[d - 2] + - (padding[d - 2] * 2) + + kernel + + output_padding[d - 2] + ) + input_size.append(input_size_d) + return list(map(int, input_size)) + + # The size of prepacked_weight is the prepacked weight size of deconv: + # Groups > 1: [g*o, i/g, ...] + # Groups == 1: [o, i, ...] + # Returns original weight size in [i, o, ...] + def _original_deconv_weight_size( + prepacked_weight, + groups, + ): + prepacked_weight_size = prepacked_weight.size() + dim = len(prepacked_weight_size) + assert dim > 2, "Expect weight dim > 2" + if groups > 1: + weight_size = [] + weight_size.append(prepacked_weight_size[1] * groups) + weight_size.append(prepacked_weight_size[0] / groups) + for d in range(2, dim): + weight_size.append(prepacked_weight_size[d]) + else: + weight_size = prepacked_weight.transpose(0, 1).size() + return weight_size + + x.realize() + weight.realize() + if bias is not None: + bias.realize() + with V.graph.fake_mode: + # TODO cleaned up the fake_tensor trace as Linear implementation + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + dims = len(x_fake.size()) - 2 + assert 0 < len(padding) <= dims + assert 0 < len(dilation) <= dims + assert 0 < len(stride) <= dims + padding = pad_listlike(padding, dims) + dilation = pad_listlike(dilation, dims) + stride = pad_listlike(stride, dims) + if output_padding is None: + output_padding = pad_listlike([0], dims) + else: + assert 0 < len(output_padding) <= dims + output_padding = pad_listlike(output_padding, dims) + assert isinstance(groups, (int, sympy.core.numbers.Integer)) + if transposed: + # When transposed, the size of the prepacked oneDNN weight is different + # from the PyTorch weight. We're not able to run aten conv with such + # size. We infer the output size from the input params here: + weight_size = _original_deconv_weight_size(weight_fake, groups) + input_size = x_fake.size() + output_size = _conv_input_size( + input_size, + weight_size, + padding, + output_padding, + stride, + dilation, + groups, + ) + else: + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + output_size = output.size() + + req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) + req_stride_order = [len(req_stride_order)] + req_stride_order + + x = cls.require_stride_order(x, req_stride_order) + + # We won't do weight prepack for Conv if dynamic_shapes. + # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. + # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), + # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order + # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel, + # this tensor is considered as channels first and the output will be in contiguous format. + # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. + dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) + if dynamic_shapes and is_contiguous_storage_and_layout(x): + output_stride = FlexibleLayout.contiguous_strides(output_size) + else: + output_stride = make_channels_last_strides_for(output_size) + + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + convert_shape_to_inductor(output_size), + convert_shape_to_inductor(output_stride), + ) + constant_args = [padding, stride, dilation, groups] + if transposed: + constant_args.insert(1, output_padding) + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +def _prepare_linear_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", +): + """ + This function is a helper function to prepare inputs, layout and constant args + for linear post-op fusion's create function. The function only supports the CPU device + since linear post-op fusion kernel is only supported on CPU right now. + """ + x.realize() + weight.realize() + if bias is not None: + bias.realize() + + *m, _ = x.get_size() + # The weight has been transposed during the qlinear weight prepack process. + # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ + # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 + _, oc = weight.get_size() + output_size = list(m) + [oc] + req_stride_order = list(reversed(range(len(x.get_size())))) + + x = cls.require_stride_order(x, req_stride_order) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + output_stride = FlexibleLayout.contiguous_strides(output_size) + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ) + constant_args: List[Any] = [] + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +class ConvolutionUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise.default, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + op_overload=self.op_overload, + raw_args=[*self.inputs, *self.constant_args], + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + cpp_constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise.binary, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm)""" + self.cpp_constant_args = cpp_constant_args + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + [*self.inputs, *self.constant_args], + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + return ConvolutionBinary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinaryInplace(ExternKernelAlloc): + def __init__( + self, + kernel_layout, + inputs, + constant_args=(), + ) -> None: + # Due to constrain of op.call, other (Tensor&) should be at input[0] + reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] + + super().__init__( + kernel_layout, + reordered_inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise_.binary, + ) + # TODO: op.call: input[0] should be at::Tensor& + self.cpp_op_schema = """ + at::Tensor&( + at::Tensor& other_t, + const at::Tensor& input_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm)""" + + self.mutation_outputs = [ + MutationOutput(NoneLayout(inputs[0].get_device()), inputs[0], self), + MutationOutput(NoneLayout(inputs[1].get_device()), inputs[1], self), + ] + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + [*self.inputs, *self.constant_args], + ) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + _, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinaryInplace( + kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] + inputs=inputs, + constant_args=constant_args, + ) + # This op mutates in place which means that the result is not the + # target but rather the input that is being mutated + # init reorders the inputs, so inputs[1] becomes packed.inputs[0] + return packed.inputs[0] + + +class ConvolutionTransposeUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + output_padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups_: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + transposed = True + ( + inputs, + constant_args, + kernel_layout, + _, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups_, + transposed, + output_padding_, + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionTransposeUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 5 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qconv2d_pointwise.default, + ) + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + c10::string_view attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation + args = [x.codegen_reference() for x in self.inputs] + const_arg_names = [ + "x_scale", + "x_zero_point", + "stride", + "padding", + "dilation", + "groups", + "output_scale", + "output_zero_point", + "output_dtype", + "attr", + "scalars", + "algorithm", + ] + if not self.has_bias: + const_arg_names.insert(2, "bias") + const_args = list(self.codegen_const_args(const_arg_names)) + + x = args[0] + x_raw = self.inputs[0] + packed_weight = args[1] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[2] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[2] + w_scale, w_zp = args[-2], args[-1] + w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1] + ( + x_scale, + x_zp, + ) = const_args[:2] + ( + x_scale_raw, + x_zp_raw, + ) = self.constant_args[:2] + ( + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-10:] + ( + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-10:] + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + op_overload=self.op_overload, + raw_args=raw_args, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # qw + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, + output_scale: float, + output_zero_point: int, + output_dtype, + attr, + scalars, + algorithm, + ): + transposed = False + output_padding = None + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + ) + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + + constant_args = ( + [ + x_scale, + x_zero_point, + ] + + constant_args + + [ + output_scale, + output_zero_point, + output_dtype, + attr, + may_convert_to_optional(scalars), + algorithm, + ] + ) + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. + kernel_layout.dtype = output_dtype + + return QConvPointWisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + Needs input/weight/output qparams + if bias is not None + - inputs = [x, w, b, accum, w_scale, w_zp] + - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, accum, w_scale, w_zp] + - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, + accum_zp, o_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 6 + self.idx_for_inplace_sum = 3 if self.has_bias else 2 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qconv2d_pointwise.binary, + ) + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor accum, + double accum_scale, + int64_t accum_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + c10::string_view binary_attr, + std::optional alpha, + std::optional attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation + args = [x.codegen_reference() for x in self.inputs] + const_arg_names = [ + "x_scale", + "x_zero_point", + "accum_scale", + "accum_zero_point", + "stride", + "padding", + "dilation", + "groups", + "output_scale", + "output_zero_point", + "output_dtype", + "binary_attr", + "alpha", + "unary_attr", + "unary_scalars", + "unary_algorithm", + ] + if not self.has_bias: + const_arg_names.insert(4, "bias") + const_args = list(self.codegen_const_args(const_arg_names)) + + x = args[0] + x_raw = self.inputs[0] + packed_weight = args[1] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[4] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[4] + accum, w_scale, w_zp = args[-3], args[-2], args[-1] + accum_raw, w_scale_raw, w_zp_raw = ( + self.inputs[-3], + self.inputs[-2], + self.inputs[-1], + ) + ( + x_scale, + x_zp, + accum_scale, + accum_zp, + ) = const_args[:4] + ( + x_scale_raw, + x_zp_raw, + accum_scale_raw, + accum_zp_raw, + ) = self.constant_args[:4] + ( + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + ( + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-12:] + conv_args = ( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + accum_raw, + accum_scale_raw, + accum_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + conv_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + op_overload=self.op_overload, + raw_args=raw_args, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self): + return [self.inputs[self.idx_for_inplace_sum].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale, + x_zero_point, + qaccum: "TensorBox", + accum_scale, + accum_zero_point, + qw: "TensorBox", # packed_weight + w_scale, + w_zero_point, + bias: "TensorBox", + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, + output_scale: "TensorBox", + output_zero_point: "TensorBox", + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + ) + + qaccum = cls.require_stride_order(qaccum, req_stride_order) + inputs.append(qaccum) + + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = ( + [ + x_scale, + x_zero_point, + accum_scale, + accum_zero_point, + ] + + constant_args + + [ + output_scale, + output_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + ) + + assert ( + binary_attr == "sum" + ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + + V.graph.mark_buffer_mutated(qaccum.get_name()) + packed = QConvPointWiseBinaryPT2E( + layout=NoneLayout(qaccum.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + # Return accum since it has been inplace changed. + return packed.inputs[packed.idx_for_inplace_sum] + + +class MKLPackedLinear(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkl._mkl_linear.default, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& self, + const at::Tensor& mkl_weight_t, + const at::Tensor& origin_weight_t, + const std::optional& bias_opt, + const int64_t prepack_batch_size)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, packed_w, orig_w, B, batch_size): + x = cls.require_stride1(cls.realize_input(x)) + orig_w = cls.require_stride1(cls.realize_input(orig_w)) + *m, _ = x.get_size() + oc, _ = orig_w.get_size() + output_size = list(m) + [oc] + output_stride = FlexibleLayout.contiguous_strides(output_size) + inputs = [x, packed_w, orig_w] + constant_args = [batch_size] + if B is not None: + inputs += [B] + else: + constant_args.insert(0, None) + + return MKLPackedLinear( + layout=FixedLayout( + x.get_device(), x.get_dtype(), output_size, output_stride + ), + inputs=inputs, + constant_args=constant_args, + ) + + +class LinearUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._linear_pointwise.default, + ) + self.cpp_kernel_key = "linear_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + c10::string_view attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, w, B, attr, scalars, algorithm): + x = cls.require_contiguous(cls.realize_input(x)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + inputs = [x, w] + constant_args = [attr, scalars if scalars else [-1], algorithm] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, None) + + return LinearUnary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class LinearBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._linear_pointwise.binary, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + c10::string_view attr) + """ + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + @classmethod + def create(cls, x, y, w, B, attr): + x = cls.require_contiguous(cls.realize_input(x)) + y = cls.require_contiguous(cls.realize_input(y)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + + inputs = [x, y, w] + constant_args = [attr] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, B) + + return LinearBinary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class QLinearPointwisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default, + ) + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + c10::string_view post_op_name, + torch::List> post_op_args, + c10::string_view post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + x_raw = self.inputs[0] + packed_weight = args[1] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[0] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] + w_scale, w_zp = args[-2], args[-1] + w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 4 + x_scale, x_zp = args[-4], args[-3] + x_scale_raw, x_zp_raw = self.inputs[-4], self.inputs[-3] + ( + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-6:] + ( + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-6:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-8:] + ( + x_scale_raw, + x_zp_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-8:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + raw_args, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + post_op_name, + post_op_args, + post_op_algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): + x_scale.realize() + x_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zero_point, int) + constant_args = constant_args + [x_scale, x_zero_point] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + post_op_name, + may_convert_to_optional(post_op_args), + post_op_algorithm, + ] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp, x2] + - const_args is: [x_scale, x_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp, x2] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary, + ) + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional other, + std::optional bias, + double inv_output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double other_scale, + int64_t other_zero_point, + c10::string_view binary_post_op, + double binary_alpha, + c10::string_view unary_post_op, + torch::List> unary_post_op_args, + c10::string_view unary_post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + x_raw = self.inputs[0] + packed_weight = args[1] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[0] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] + w_scale, w_zp, other = args[-3], args[-2], args[-1] + w_scale_raw, w_zp_raw, other_raw = ( + self.inputs[-3], + self.inputs[-2], + self.inputs[-1], + ) + if self.x_scale_zp_are_tensors: + assert len(args) >= 5 + x_scale, x_zp = args[-5], args[-4] + x_scale_raw, x_zp_raw = self.inputs[-5], self.inputs[-4] + ( + o_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-10:] + ( + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-10:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + ( + x_scale_raw, + x_zp_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + other, + bias, + o_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + other_raw, + bias_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + raw_args, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self): + binary_post_op = self.constant_args[-5] + if binary_post_op == "sum": + return [self.inputs[-1].get_name()] + else: + return [] + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + other: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + unary_post_op_args, + unary_post_op_algorithm, + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): + x_scale.realize() + x_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zero_point, int) + constant_args = constant_args + [x_scale, x_zero_point] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + if binary_post_op == "sum": + other = cls.require_stride_order(other, req_stride_order) + inputs.append(other) + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + may_convert_to_optional(unary_post_op_args), + unary_post_op_algorithm, + ] + + if binary_post_op == "sum": + V.graph.mark_buffer_mutated(other.get_name()) + packed = QLinearPointwiseBinaryPT2E( + layout=NoneLayout(other.get_device()), + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + # Return other since it has been inplace changed. + return packed.inputs[-1] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwiseBinaryPT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +class MkldnnRnnLayer(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.aten.mkldnn_rnn_layer.default, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + w0: "TensorBox", + w1: "TensorBox", + w2: "TensorBox", + w3: "TensorBox", + hx: "TensorBox", + cx: "TensorBox", + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + x = cls.require_stride1(cls.realize_input(x)) + # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. + # Make sure x is contiguous in batch_first case. + x.freeze_layout() + w0 = cls.require_stride1(cls.realize_input(w0)) + w1 = cls.require_stride1(cls.realize_input(w1)) + w2 = cls.require_stride1(cls.realize_input(w2)) + w3 = cls.require_stride1(cls.realize_input(w3)) + hx = cls.require_stride1(cls.realize_input(hx)) + hx.freeze_layout() + cx = cls.require_stride1(cls.realize_input(cx)) + cx.freeze_layout() + + input_size = x.get_size() + assert len(input_size) == 3, "Expect lstm input to be 3D" + # batch_first is handled in the lstm OP. When entering + # rnn_layer here, we'll always have batch_first = False + seq_length, mini_batch, input_size = input_size + output_shape = [seq_length, mini_batch, hidden_size] + + hy_shape = hx.get_size() + cy_shape = cx.get_size() + + res: List[IRNode] = [] + + inputs = [x, w0, w1, w2, w3, hx, cx] + constant_args = [ + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ] + + packed = MkldnnRnnLayer( + MultiOutputLayout(x.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + def get_strides_of_lstm_output(output_shape, batch_first): + assert len(output_shape) == 3, "Expect output_shape to be 3D" + return FlexibleLayout.contiguous_strides(output_shape) + + output_sizes = [output_shape, hy_shape, cy_shape] + output_strides = [ + get_strides_of_lstm_output(output_shape, batch_first), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), + ] + output_ir = [ + MultiOutput( + FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ), + packed, + [(tuple, i)], + ) + for i, (output_size, output_stride) in enumerate( + zip(output_sizes, output_strides) + ) + ] + + return output_ir diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_lowerings.py b/.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cc0bc8299ebfce4bc67628ac21af27a847a936 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_lowerings.py @@ -0,0 +1,1087 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +from typing import List, Optional + +import torch +import torch.utils._pytree as pytree +from torch._inductor.kernel.mm_common import mm_args + +from . import ir +from .codegen.cpp_gemm_template import CppPackedGemmTemplate +from .codegen.cpp_utils import create_epilogue_with_attr +from .ir import TensorBox +from .lowering import ( + add, + add_needs_realized_inputs, + aten, + permute, + register_lowering, + to_dtype, + view, +) +from .select_algorithm import ( + autotune_select_algorithm, + ChoiceCaller, + ExternKernelChoice, +) +from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune +from .virtualized import ops, V + + +def register_onednn_fusion_ops(): + if torch._C._has_mkldnn: + from . import mkldnn_ir + + aten_mkldnn_linear_unary = ExternKernelChoice( + torch.ops.mkldnn._linear_pointwise, + "mkldnn::_linear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.LinearUnary.create, + ) + aten_mkldnn_linear_binary = ExternKernelChoice( + torch.ops.mkldnn._linear_pointwise.binary, + "mkldnn::_linear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.LinearBinary.create, + ) + aten_mkldnn_qlinear_unary = ExternKernelChoice( + torch.ops.onednn.qlinear_pointwise, + "onednn::qlinear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create, + ) + aten_mkldnn_qlinear_binary = ExternKernelChoice( + torch.ops.onednn.qlinear_pointwise.binary, + "onednn::qlinear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create, + ) + cpu_needs_realized_inputs = [ + torch.ops.mkldnn._convolution_pointwise, + torch.ops.mkldnn._convolution_pointwise_, + torch.ops.mkldnn._convolution_transpose_pointwise, + torch.ops.mkldnn._linear_pointwise, + aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qconv2d_pointwise, + ] + + @register_lowering(torch.ops.mkldnn._convolution_pointwise) + def convolution_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionUnary.create( + x, + weight, + bias, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary) + def convolution_binary( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionBinary.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary) + def convolution_binary_inplace( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionBinaryInplace.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise) + def linear_unary( + x: TensorBox, + w: TensorBox, + b: TensorBox, + attr, + scalars, + algorithm, + layout=None, + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + if b is not None: + b = ir.ExternKernel.realize_input(b) + choices: List[ChoiceCaller] = [] + if use_max_autotune(): + transposed_w = permute(w, [1, 0]) + *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) + if use_cpp_packed_gemm_template(layout, x, transposed_w): + + def epilogue_creator(buf): + return create_epilogue_with_attr( + buf, attr, scalars=scalars, algorithm=algorithm + ) + + kwargs = dict( + has_bias=b is not None, + trans_w=True, + epilogue_creator=None if attr == "none" else epilogue_creator, + ) + if b is not None: + kwargs["input_indices"] = [2, 0, 1] # type: ignore[assignment] + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, w] if b is None else [x, w, b], + **kwargs, # type: ignore[arg-type] + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm) + if b is None: + kwargs["B"] = None + choices.append( + aten_mkldnn_linear_unary.bind( + [x, w] if b is None else [x, w, b], + layout, + **kwargs, + ) + ) + assert w.get_name() in V.graph.constants + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + } + result = autotune_select_algorithm( + "linear_unary", + choices, + [x, w] if b is None else [x, w, b], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) + def linear_binary( + x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + y_size = y.get_size() + if len(y_size) > 2: + y = view(y, [-1, y_size[-1]]) + if b is not None: + b = ir.ExternKernel.realize_input(b) + choices: List[ChoiceCaller] = [] + if use_max_autotune(): + transposed_w = permute(w, [1, 0]) + *_, layout, x, transposed_w, y = mm_args( + x, transposed_w, y, layout=layout + ) + if use_cpp_packed_gemm_template(layout, x, transposed_w): + + def epilogue_creator(buf): + return create_epilogue_with_attr(buf, attr, other=y) + + kwargs = dict( + has_bias=b is not None, + trans_w=True, + epilogue_creator=epilogue_creator, + ) + kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1] + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, y, w] if b is None else [x, y, w, b], + **kwargs, # type: ignore[arg-type] + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict(attr=attr) + if b is None: + kwargs["B"] = None + choices.append( + aten_mkldnn_linear_binary.bind( + [x, y, w] if b is None else [x, y, w, b], + layout, + **kwargs, + ) + ) + assert w.get_name() in V.graph.constants + input_gen_fns = { + 2: lambda x: V.graph.constants[x.get_name()], + } + result = autotune_select_algorithm( + "linear_binary", + choices, + [x, y, w] if b is None else [x, y, w, b], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise) + def convolution_transpose_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionTransposeUnary.create( + x, + weight, + bias, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(aten.mkldnn_rnn_layer.default) + def mkldnn_rnn_layer( + x: TensorBox, + w0: TensorBox, + w1: TensorBox, + w2: TensorBox, + w3: TensorBox, + hx: TensorBox, + cx: TensorBox, + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + return pytree.tree_map( + TensorBox.create, + mkldnn_ir.MkldnnRnnLayer.create( + x, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ), + ) + + @register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None) + def qconvolution_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.QConvPointWisePT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering( + torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None + ) + def qconvolution_binary( + x: TensorBox, + x_scale, + x_zp, + accum: TensorBox, + accum_scale, + accum_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ): + if ( + binary_attr == "sum" + and output_dtype in [torch.float32, torch.bfloat16] + and accum.get_dtype() in [torch.float32, torch.bfloat16] + and accum.get_dtype() != output_dtype + ): + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype convertion here. + accum = to_dtype(accum, output_dtype) + return TensorBox.create( + mkldnn_ir.QConvPointWiseBinaryPT2E.create( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ) + ) + + @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None) + def qlinear_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + o_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + layout=None, + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) == float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + else: + x_scale.realize() + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) == int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + else: + x_zp.realize() + + # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer + # Refer to https://github.com/pytorch/pytorch/blob + # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 + w_scale.realize() + w_zp.realize() + if w_zp.get_dtype() != torch.int32 and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ): + # W_zp might be a ConstantBuffer with int64, convert it to int32 + w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) + w_zp = V.graph.add_tensor_constant( + torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() + ) + + bias_dtype = None if bias is None else bias.get_dtype() + + choices: List[ChoiceCaller] = [] + if use_max_autotune(): + *_, layout, x, packed_weight = mm_args( + x, packed_weight, layout=layout, out_dtype=output_dtype + ) + if ( + isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_zp), + ir.ConstantBuffer, + ) + and len(x_zp.get_layout().size) == 0 # Per tensor quant of act + and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ) + and torch.equal( + torch.zeros_like(V.graph.constants[w_zp.get_name()]), + V.graph.constants[w_zp.get_name()], + ) # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA + and use_cpp_packed_gemm_template(layout, x, packed_weight) + ): + W_tensor = V.graph.constants[packed_weight.get_name()].to_dense() + weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0) + weight_compens = V.graph.add_tensor_constant( + weight_compens_tensor, + name=packed_weight.get_name() + "_BMatrixCompens", + ) + + def epilogue_creator(input_buffer): + # Epilogue to convert from s32 to f32 for u8s8f32 + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + ] + input_loader = input_buffer.make_loader() + weight_compens_loader = weight_compens.make_loader() + x_scale_loader = x_scale.make_loader() + w_scale_loader = w_scale.make_loader() + x_zp_loader = x_zp.make_loader() + nonlocal bias + bias_loader = None + if bias is not None: + bias_loader = bias.make_loader() + + def inner_fn(index): + nonlocal bias + input = input_loader(index) + # MicroKernel Output is with int32 + # cvt to FP32 before doing compensation + input = ops.to_dtype(input, torch.float32) + weight_compens_index = (index[-1],) + _x_scale = x_scale_loader(()) + _x_zp = x_zp_loader(()) + _w_scale = w_scale_loader(weight_compens_index) + _weight_compo = weight_compens_loader(weight_compens_index) + # Step 1: Doing compensation to cvt fp32 + temp = ops.mul( + ops.mul( + input, + _x_scale, + ), + _w_scale, + ) + temp = ops.sub( + temp, + ops.mul( + ops.mul( + ops.mul( + _x_scale, + _w_scale, + ), + _x_zp, + ), + _weight_compo, + ), + ) + # Step 2: add Bias if applicable + if bias is not None: + _bias = bias_loader(weight_compens_index) + nonlocal bias_dtype + assert bias_dtype in [torch.float32, torch.bfloat16] + if bias_dtype == torch.bfloat16: + _bias = ops.to_dtype(_bias, torch.float32) + temp = ops.add(temp, _bias) + + return temp + + output_buf = ir.Pointwise( + device=input_buffer.get_device(), + dtype=torch.float32, # Hardcode to FP32 for u8s8f32 + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + # Step 3: Doing the unary post op fusion + if attr != "none": + output_buf = create_epilogue_with_attr( + output_buf, attr, scalars=scalars, algorithm=algorithm + ) + + # Step 4: Cast output to Target Dtype + if output_dtype == torch.bfloat16: + output_cast_loader = output_buf.make_loader() + + def inner_fn_cast_output_to_bf16(index): + input = output_cast_loader(index) + return ops.to_dtype(input, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device(), + dtype=output_dtype, + inner_fn=inner_fn_cast_output_to_bf16, + ranges=output_buf.get_size(), + ) + elif output_dtype == torch.uint8: + from .lowering import _create_constants + + requant_input_loader = output_buf.make_loader() + + def inner_fn_requant(index, scale, zero_point): + input = requant_input_loader(index) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + qmin, qmax = _create_constants( + 0, 255, dtype=torch.float32 + ) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, torch.uint8) + + output_buf = ir.Pointwise( + device=output_buf.get_device(), + dtype=output_dtype, + inner_fn=functools.partial( + inner_fn_requant, + scale=float(o_scale), + zero_point=int(o_zero_point), + ), + ranges=output_buf.get_size(), + ) + + return output_buf + + assert x.get_dtype() == torch.uint8 + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias], + has_bias=bias is not None, + epilogue_creator=epilogue_creator, + input_indices=[0, 3, 1, 2, 4, 5] + if bias is None + else [6, 0, 3, 1, 2, 4, 5], + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict( + output_scale=o_scale, + output_zero_point=o_zero_point, + output_dtype=output_dtype, + post_op_name=attr, + post_op_args=scalars, + post_op_algorithm=algorithm, + ) + if bias is None: + kwargs["bias"] = None + choices.append( + aten_mkldnn_qlinear_unary.bind( + (x, x_scale, x_zp, packed_weight, w_scale, w_zp) + if bias is None + else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias), + layout, + **kwargs, + ) + ) + assert packed_weight.get_name() in V.graph.constants + input_gen_fns = { + 3: lambda x: V.graph.constants[x.get_name()], + 4: lambda x: V.graph.constants[x.get_name()], + 5: lambda x: V.graph.constants[x.get_name()], + 6: lambda x: V.graph.constants[x.get_name()], # For bias + } + result = autotune_select_algorithm( + "qlinear_unary", + choices, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None + ) + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None + ) + def qlinear_binary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + x2: TensorBox, + bias: TensorBox, + o_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + layout=None, + ): + x_size = x.get_size() + x2_size = x2.get_size() + assert len(x_size) == len(x2_size) + if len(x_size) > 2 and binary_attr == "add": + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + x2 = view(x2, [-1, x2_size[-1]]) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) == float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + else: + x_scale.realize() + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) == int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + else: + x_zp.realize() + + # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer + # Refer to https://github.com/pytorch/pytorch/blob + # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 + w_scale.realize() + w_zp.realize() + if w_zp.get_dtype() != torch.int32 and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ): + w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) + w_zp = V.graph.add_tensor_constant( + torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() + ) + if binary_attr == "sum": + if output_dtype in [ + torch.float32, + torch.bfloat16, + ] and x2.get_dtype() in [torch.float32, torch.bfloat16]: + if x2.get_dtype() != output_dtype: + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype convertion here. + x2 = to_dtype(x2, output_dtype) + else: + assert ( + x2.get_dtype() == output_dtype + ), "dtype of accum for qlinear post op sum should be the same as output" + x2_dtype = x2.get_dtype() + bias_dtype = bias.get_dtype() if bias is not None else None + choices: List[ChoiceCaller] = [] + if ( + use_max_autotune() and binary_attr == "add" + ): # Support inplace sum fusion + *_, layout, x, packed_weight, x2 = mm_args( + x, packed_weight, x2, layout=layout, out_dtype=output_dtype + ) + if ( + isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_zp), + ir.ConstantBuffer, + ) + and len(x_zp.get_layout().size) == 0 # Per tensor quant of act + and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ) + and torch.equal( + torch.zeros_like(V.graph.constants[w_zp.get_name()]), + V.graph.constants[w_zp.get_name()], + ) # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA + and use_cpp_packed_gemm_template(layout, x, packed_weight) + ): + W_tensor = V.graph.constants[packed_weight.get_name()] + W_tensor = W_tensor.to_dense() + weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0) + weight_compens = V.graph.add_tensor_constant( + weight_compens_tensor, + name=packed_weight.get_name() + "_BMatrixCompens", + ) + + def epilogue_creator(input_buffer): + # Epilogue to convert from s32 to f32 for u8s8f32 + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + ] + + input_loader = input_buffer.make_loader() + x2_loader = x2.make_loader() + weight_compens_loader = weight_compens.make_loader() + x_scale_loader = x_scale.make_loader() + w_scale_loader = w_scale.make_loader() + x_zp_loader = x_zp.make_loader() + nonlocal bias + bias_loader = None + if bias is not None: + bias_loader = bias.make_loader() + + def inner_fn(index): + nonlocal bias + input = input_loader(index) + _x2 = x2_loader(index) + _x_scale = x_scale_loader(()) + _x_zp = x_zp_loader(()) + + # MicroKernel Output is with int32 + # cvt to FP32 before doing compensation + input = ops.to_dtype(input, torch.float32) + weight_compens_index = (index[-1],) + _w_scale = w_scale_loader(weight_compens_index) + _weight_compens = weight_compens_loader( + weight_compens_index + ) + # Step 1: Doing compensation to cvt fp32 + temp = ops.mul( + ops.mul( + input, + _x_scale, + ), + _w_scale, + ) + temp = ops.sub( + temp, + ops.mul( + ops.mul( + ops.mul( + _x_scale, + _w_scale, + ), + _x_zp, + ), + _weight_compens, + ), + ) + + # Step 2: add Bias if applicable + if bias is not None: + _bias = bias_loader(weight_compens_index) + nonlocal bias_dtype + assert bias_dtype in [torch.float32, torch.bfloat16] + if bias_dtype == torch.bfloat16: + _bias = ops.to_dtype(_bias, torch.float32) + temp = ops.add(temp, _bias) + + # Step 3: Binary add + nonlocal x2_dtype + assert x2_dtype in [torch.float32, torch.bfloat16] + if x2_dtype == torch.bfloat16: + _x2 = ops.to_dtype(_x2, torch.float32) + temp = ops.add(temp, _x2) + + return temp + + output_buf = ir.Pointwise( + device=input_buffer.get_device(), + dtype=torch.float32, # Hardcode to FP32 for u8s8f32 + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + # Step 4: Unary post op if has + if unary_attr != "none": + output_buf = create_epilogue_with_attr( + output_buf, + unary_attr, + scalars=unary_scalars, + algorithm=unary_algorithmm, + ) + + # Step 5: Cast output to Target Dtype + if output_dtype == torch.bfloat16: + output_cast_loader = output_buf.make_loader() + + def inner_fn_cast_output_to_bf16(index): + input = output_cast_loader(index) + return ops.to_dtype(input, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device(), + dtype=output_dtype, + inner_fn=inner_fn_cast_output_to_bf16, + ranges=output_buf.get_size(), + ) + elif output_dtype == torch.uint8: + from .lowering import _create_constants + + requant_input_loader = output_buf.make_loader() + + def inner_fn_requant(index, scale, zero_point): + input = requant_input_loader(index) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + qmin, qmax = _create_constants( + 0, 255, dtype=torch.float32 + ) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, torch.uint8) + + output_buf = ir.Pointwise( + device=output_buf.get_device(), + dtype=torch.uint8, + inner_fn=functools.partial( + inner_fn_requant, + scale=float(o_scale), + zero_point=int(o_zero_point), + ), + ranges=output_buf.get_size(), + ) + + return output_buf + + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], + has_bias=bias is not None, + epilogue_creator=epilogue_creator, + # Reorder bias and x2 + input_indices=[0, 3, 1, 2, 4, 5, 6] + if bias is None + else [7, 0, 3, 1, 2, 4, 5, 6], + ) + + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict( + output_scale=o_scale, + output_zero_point=o_zero_point, + output_dtype=output_dtype, + other_scale=x2_scale, + other_zp=x2_zp, + binary_post_op=binary_attr, + binary_alpha=alpha, + unary_post_op=unary_attr, + unary_post_op_args=unary_scalars, + unary_post_op_algorithm=unary_algorithmm, + ) + if bias is None: + kwargs["bias"] = None + choices.append( + aten_mkldnn_qlinear_binary.bind( + (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2) + if bias is None + else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias), + layout, + **kwargs, + ) + ) + assert packed_weight.get_name() in V.graph.constants + input_gen_fns = { + 3: lambda x: V.graph.constants[x.get_name()], + 4: lambda x: V.graph.constants[x.get_name()], + 5: lambda x: V.graph.constants[x.get_name()], + } + if bias is not None: + input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()] # For bias + result = autotune_select_algorithm( + "qlinear_binary", + choices, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2 and binary_attr == "add": + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + if torch._C.has_mkl: + aten_mkl_linear = ExternKernelChoice( + torch.ops.mkl._mkl_linear, + "mkl::_mkl_linear", + has_out_variant=False, + kernel_creator=mkldnn_ir.MKLPackedLinear.create, + ) + cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) + + @register_lowering(torch.ops.mkl._mkl_linear) + def mkl_packed_linear( + x: TensorBox, + packed_w: TensorBox, + orig_w: TensorBox, + b: Optional[TensorBox], + batch_size, + *, + layout=None, + ): + choices: List[ChoiceCaller] = [] + if use_max_autotune(): + transposed_w = permute(orig_w, [1, 0]) + *_, layout, x, transposed_w = mm_args( + x, transposed_w, layout=layout + ) + if use_cpp_packed_gemm_template(layout, x, transposed_w): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, packed_w, orig_w], + trans_w=True, + input_indices=[0, 2], + ) + + if len(choices) == 0 or use_aten_gemm_kernels(): + choices.append( + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ) + + assert packed_w.get_name() in V.graph.constants + assert orig_w.get_name() in V.graph.constants + # packed_w is a mkldnn tensor which we can't generate directly + # so we use the weights from the original tensor in autotune. + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + 2: lambda x: V.graph.constants[x.get_name()], + } + result: TensorBox = autotune_select_algorithm( + "packed_linear", + choices, + [x, packed_w, orig_w], + layout, + input_gen_fns=input_gen_fns, + ) + if b is not None: + result = add(result, b) + return result + + add_needs_realized_inputs(cpu_needs_realized_inputs) + else: + pass diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/package/__init__.py b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c088562100cdae1b3808437d845501b639822862 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__init__.py @@ -0,0 +1 @@ +from .package import load_package, package_aoti diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbb216d1f26be7be237b0aa78d289617ca202659 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/package.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/package.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21c2a3ebb059995345c132737847f52053df3820 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/package.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/package/build_package.py b/.venv/lib/python3.11/site-packages/torch/_inductor/package/build_package.py new file mode 100644 index 0000000000000000000000000000000000000000..9205b9ced254275018472108485173eba9479f11 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/package/build_package.py @@ -0,0 +1,15 @@ +build_package_contents = """ +import os +from pathlib import Path + +from torch._inductor.package.package import compile_so + +curr_dir = Path(__file__).parent +aoti_files = [ + os.path.join(root, file) + for root, dirs, files in os.walk(curr_dir) + for file in files +] + +output_so = compile_so(curr_dir, aoti_files, curr_dir) +""" diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/package/package.py b/.venv/lib/python3.11/site-packages/torch/_inductor/package/package.py new file mode 100644 index 0000000000000000000000000000000000000000..d1304293e3cd88bdeaa704c558ffefd3f3ef885c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/package/package.py @@ -0,0 +1,237 @@ +import glob +import json +import os +import shlex +import subprocess +import tempfile +import zipfile +from pathlib import Path +from typing import Callable, List, Optional, Union + +import torch +import torch._inductor +import torch.utils._pytree as pytree +from torch._inductor import config, exc +from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder +from torch.export._tree_utils import reorder_kwargs + +from .build_package import build_package_contents +from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION + + +class PT2ArchiveWriter: + def __init__(self, archive_path: str) -> None: + self.archive_path: str = archive_path + self.archive_file: Optional[zipfile.ZipFile] = None + + def __enter__(self) -> "PT2ArchiveWriter": + assert self.archive_file is None + self.archive_file = zipfile.ZipFile( + self.archive_path, "w", compression=zipfile.ZIP_STORED + ) + self.writestr("version", str(ARCHIVE_VERSION)) + self.writestr("archive_format", "pt2") + return self + + def __exit__(self, *args) -> None: # type: ignore[no-untyped-def] + assert self.archive_file is not None + self.archive_file.close() + self.archive_file = None + return None + + def writestr(self, name: str, data: Union[bytes, str]) -> None: + assert self.archive_file is not None + self.archive_file.writestr(name, data) + + def write_file(self, name: str, file_path: str) -> None: + """ + Copy a file into the archive. + name: The destination file inside the archive. + file_path: The source file on disk. + """ + assert Path(file_path).is_file(), f"{file_path} is not a valid file path" + assert self.archive_file is not None + self.archive_file.write(file_path, arcname=name) + + +class PT2ArchiveReader: + def __init__(self, archive_path: str) -> None: + self.archive_path: str = archive_path + self.archive_file: Optional[zipfile.ZipFile] = None + + def __enter__(self) -> "PT2ArchiveReader": + self.archive_file = zipfile.ZipFile( + self.archive_path, "r", compression=zipfile.ZIP_STORED + ) + return self + + def __exit__(self, *args) -> None: # type: ignore[no-untyped-def] + if self.archive_file is not None: + self.archive_file.close() + return None + + def read(self, name: str) -> bytes: + assert self.archive_file is not None + return self.archive_file.read(name) + + def extract_to_path(self, member: str, path: str) -> str: + assert self.archive_file is not None + return self.archive_file.extract(member, path) + + def extractall(self, path: str) -> None: + assert self.archive_file is not None + self.archive_file.extractall(path) + + def get_file_names(self) -> List[str]: + assert self.archive_file is not None + return self.archive_file.namelist() + + +def _run_command_and_check(cmd: str) -> None: + cmd = shlex.split(cmd) + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + raise exc.CppCompileError(cmd, e.output) from e + + +def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str: + def get_aoti_file_with_suffix(suffix: str) -> str: + for file in aoti_files: + if file.endswith(suffix): + return file + raise RuntimeError(f"Unable to find file with suffix {suffix}") + + # Compile all the files into a .so + cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp")) + consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o")) + + file_name = os.path.splitext(cpp_file)[0] + + # Parse compile flags and build the .o file + with open(file_name + "_compile_flags.json") as f: + compile_flags = json.load(f) + + compile_options = BuildOptionsBase(**compile_flags) + object_builder = CppBuilder( + name=file_name, + sources=cpp_file, + BuildOption=compile_options, + ) + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() + + _run_command_and_check(compile_cmd) + + # Parse linker flags and build the .so file + with open(file_name + "_linker_flags.json") as f: + linker_flags = json.load(f) + + linker_options = BuildOptionsBase(**linker_flags) + so_builder = CppBuilder( + name=os.path.split(so_path)[-1], + sources=[output_o, consts_o], + BuildOption=linker_options, + output_dir=so_path, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + _run_command_and_check(link_cmd) + + # mmapped weights + serialized_weights_filename = file_name + "_serialized_weights.bin" + if serialized_weights_filename in aoti_files: + with open(serialized_weights_filename, "rb") as f_weights: + serialized_weights = f_weights.read() + + with open(output_so, "a+b") as f_so: + so_size = f_so.tell() + # Page align the weights + f_so.write(b" " * (16384 - so_size % 16384)) + f_so.write(serialized_weights) + + return output_so + + +def package_aoti(aoti_output_dir: str) -> str: + """ + Saves the AOTInductor generated files to the PT2Archive format. + """ + + # Add a makefile and python script + build_package_filename = "build_package.py" + with open(os.path.join(aoti_output_dir, build_package_filename), "w") as f: + f.write(build_package_contents) + + with open(os.path.join(aoti_output_dir, "Makefile"), "w") as f: + f.write(f"all:\n\tpython3 {build_package_filename}\n") + + if config.aot_inductor.output_path.endswith(".so"): + raise RuntimeError( + "Unable to save package as a .so. It should be a .pt2 format or a directory." + ) + elif config.aot_inductor.output_path.endswith(".pt2"): + # Save using the PT2 packaging format + # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) + archive_path = config.aot_inductor.output_path + + with PT2ArchiveWriter(archive_path) as archive_writer: + package_files = glob.glob(f"{aoti_output_dir}/*") + + for path in package_files: + filename = os.path.basename(path) + archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path) + + return archive_path + + else: + # Directly put the files into the directory, without any archiving + return aoti_output_dir + + +def load_package(path: str, device: str) -> Callable: # type: ignore[type-arg] + if path.endswith(".so"): + raise RuntimeError( + "Unable to load .so. It should be a .pt2 format or a directory." + ) + + elif path.endswith(".pt2"): + so_path = os.path.splitext(path)[0] + with PT2ArchiveReader(path) as archive_reader: + file_names = archive_reader.get_file_names() + + with tempfile.TemporaryDirectory() as tmp_dir: + archive_reader.extractall(tmp_dir) + file_names = archive_reader.get_file_names() + aoti_files = [ + file for file in file_names if file.startswith(AOTINDUCTOR_DIR) + ] + + so_path = compile_so(tmp_dir, aoti_files, so_path) + + else: + assert os.path.isdir(path), "Must specify a directory or a .pt2 file" + aoti_files = [ + os.path.join(root, file) + for root, dirs, files in os.walk(path) + for file in files + ] + so_path = compile_so(path, aoti_files, path) + + if device == "cpu": + runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] + elif device == "cuda" or device.startswith("cuda:"): + runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] + else: + raise RuntimeError("Unsupported device " + device) + + def optimized(*args, **kwargs): # type: ignore[no-untyped-def] + call_spec = runner.get_call_spec() # type: ignore[attr-defined] + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + return pytree.tree_unflatten(flat_outputs, out_spec) + + return optimized diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/package/pt2_archive_constants.py b/.venv/lib/python3.11/site-packages/torch/_inductor/package/pt2_archive_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..7f0c91d2fcd4806c7b85b5e3b78ec53937a3b098 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/package/pt2_archive_constants.py @@ -0,0 +1,16 @@ +ARCHIVE_ROOT_NAME = "package" +ARCHIVE_FORMAT_PATH = "archive_format" +MODELS_DIR = "models/" +MODELS_FILENAME_FORMAT = "models/{}.json" +AOTINDUCTOR_DIR = "data/aotinductor/" +WEIGHTS_DIR = "data/weights/" +WEIGHT_FILENAME_PREFIX = "weight_" +CONSTANTS_DIR = "data/constants/" +TENSOR_CONSTANT_FILENAME_PREFIX = "tensor_" +CUSTOM_OBJ_FILENAME_PREFIX = "custom_obj_" +SAMPLE_INPUTS_DIR = "data/sample_inputs/" +SAMPLE_INPUTS_FILENAME_FORMAT = "data/sample_inputs/{}.pt" +EXTRA_DIR = "extra/" +MODULE_INFO_PATH = "extra/module_info.json" + +ARCHIVE_VERSION = 0 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py b/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..e43d37fd37b1a757949406ce5e906b8b08cf719b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py @@ -0,0 +1,2005 @@ +# mypy: allow-untyped-decorators +""" +# Inductor Pattern Matcher + +The pattern matcher enables search/replace within an FX graph. + +The main entrypoint to the pattern matcher is register_replacement(). Given a +search function and a replacement function this will register a replacement with +a pass (such as torch._inductor.fx_passes.joint_graph.patterns). + +Internally the pattern matcher represents patterns as a graph (a DAG). Creating +new patterns manually as a graph is cumbersome and error-prone so the standard +way to create patterns (using register_replacement()) is to provide a search +function and a replacement function which is traced and converted into a graph. + +Because the search functions are built somewhat generic (they tend to ignore +tensor sizes, for example) register_replacement() allows you to specify an +`extra_check` function which performs additional checks to verify that the +matched pattern fully matches before returning it. + +## Precompiled Patterns + +New patterns are added using register_replacement(). Patterns added in this way +can have a compile-time overhead because they need to be traced before +use. Patterns can be precompiled and added using gen_register_replacement() +instead. To do this you call gen_register_replacement() instead of +register_replacement(). The arguments are the same except for an additional +unique name which is used as a lookup key. + +## Internals + +The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr +implements a `_match` method which returns either a `Match` object for a +successful match or a `FailedMatch` object for a failure to match. +""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import importlib +import inspect +import itertools +import logging +import operator +import os +import re +import textwrap +import typing +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Generator, + Iterable, + List, + Mapping, + NoReturn, + Optional, + Protocol, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) +from typing_extensions import Self, TypeGuard + +import torch +import torch._guards +import torch.fx +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import counters +from torch._inductor.config import trace as trace_config +from torch._prims_common import is_integer_dtype +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.graph_transform_observer import GraphTransformObserver + +from .._functorch import config as functorch_config +from .._functorch.aot_autograd import aot_function, make_boxed_func +from .._functorch.partitioners import default_partition +from .._subclasses import FakeTensor, FakeTensorMode +from ..fx import Transformer +from . import config +from .decomposition import select_decomp_table +from .lowering import fallback_node_due_to_unsupported_type + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +Constant = Any +NodeOrConstant = Union[Constant, torch.fx.Node] + + +class SearchFn(Protocol): + __name__: str + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + ... + + +class ReplaceFn(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> Any: + ... + + +class TraceFn(Protocol): + def __call__( + self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any + ) -> torch.fx.GraphModule: + ... + + +T = TypeVar("T") + +# What's a better name for this? +FnsType = Union[torch.fx.node.Target, str] + + +class Multiple: + def __init__(self) -> None: + # Ensure we're really a singleton. + assert "MULTIPLE" not in globals() or self is MULTIPLE + + +# Sentinel indicating multiple quantities can be matched +MULTIPLE = Multiple() + + +class Match: + """ + Represents a successfully matched pattern. + + The `Match` object is returned to represent a successfully matched + pattern. Included in the Match are the pattern that was matched, the graph + nodes matched, and any args that were used during the matching. + + The args and kwargs are specific to the type of pattern that was matched and + provide hints about what was matched. + """ + + pattern: PatternExpr + args: List[Any] + kwargs: Dict[str, Any] + nodes: List[torch.fx.Node] + targets: Dict[_TargetExpr, torch.fx.node.Target] + ctx: MatchContext + replacement_graph: Optional[torch.fx.Graph] + + def __init__( + self, + ctx: MatchContext, + pattern: PatternExpr, + args: Optional[Sequence[Any]] = None, + kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.pattern = pattern + # The input nodes that must be passed in to the result + self.args = list(args or []) + self.kwargs = kwargs or {} + # The nodes matched in this expression + self.nodes = [] + # Mapping CallFunction to the node.target + self.targets = {} + self.ctx = ctx + self.replacement_graph = None + + @property + def graph(self) -> torch.fx.Graph: + return self.ctx.graph + + def extend(self, other: Match) -> None: + if self.kwargs: + for key in set(self.kwargs.keys()) & set(other.kwargs.keys()): + if self.kwargs[key] != other.kwargs[key]: + raise FailedMatch("kwarg mismatch: {}", key) + self.args.extend(other.args) + self.nodes.extend(other.nodes) + self.kwargs.update(other.kwargs) + self.targets.update(other.targets) + + def bundle(self) -> Match: + # Wrap args in an extra list + self.args = [tuple(self.args)] if self.args else [] + return self + + def __repr__(self) -> str: + return f"Match(..., {self.args}, {self.kwargs})" + + def erase_nodes(self) -> None: + graph = self.graph + for n in reversed(self.nodes): + if not n._erased and not n.users: + graph.erase_node(n) + + def output_nodes(self) -> List[Optional[torch.fx.Node]]: + return [ + (self.ctx.pattern_to_node[p] if p is not None else None) + for p in self.ctx.outputs + ] + + def output_node(self) -> torch.fx.Node: + return next(p for p in self.output_nodes() if p) + + def replace_with_graph( + self, replacement_graph: torch.fx.Graph, args: Sequence[Any] + ) -> None: + ReplacementPatternEntry.replace_with_graph( + self, self.ctx.graph, replacement_graph, args + ) + + def replace_by_example( + self, + replacement_fn: ReplaceFn, + args: Sequence[Any], + trace_fn: Optional[TraceFn] = None, + run_functional_passes: bool = True, + ) -> None: + """Replace with a graph generated by tracing the replacement_fn. + + Args: + run_functional_passes (bool). If we should run passes that + assume functional IR (like DCE, remove_noop_ops), on the + replacement graph. + + """ + from torch._inductor.virtualized import NullHandler, V + + context = ( + V.fake_mode + if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None)) + else contextlib.nullcontext() + ) + + with context: + if trace_fn is None: + trace_fn = functools.partial( + fwd_only, run_functional_passes=run_functional_passes + ) + replacement = trace_fn( + replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type] + ) + ReplacementPatternEntry.replace_with_graph( + self, + self.ctx.graph, + replacement, + args, + ) + + +class FailedMatch(RuntimeError): + """ + Represents a unsuccessful match. + + The `FailedMatch` object is returned to represent a failure to match a + pattern. + """ + + format_string: str + + def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None: + self.format_string = format_string + # We want to construct error messages lazily instead of eagerly, as + # constructing them eagerly can significantly worsen compile times. + if len(format_string) > 200: + raise RuntimeError( + f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}" + ) + self.args = args + self.kwargs = kwargs + + def __str__(self) -> str: + return self.format_string.format(*self.args, **self.kwargs) + + def __bool__(self) -> bool: + return False + + +MatchResult = Union[Match, FailedMatch] + + +def is_match(m: MatchResult) -> TypeGuard[Match]: + """ + TypeGuards cannot act on `self`. Thus this function exists to let mypy + recognize FailedMatch.__bool__ as a TypeGuard. + """ + return bool(m) + + +class MatchContext: + """ + Internal state needed while running PatternExpr._match(). + """ + + outputs: List[Optional[PatternExpr]] + pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]] + graph: torch.fx.Graph + exclusive_node_set: List[NodeOrConstant] + + def __init__( + self, + outputs: List[Optional[PatternExpr]], + pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None, + *, + graph: torch.fx.Graph, + ) -> None: + self.outputs = outputs + self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node) + self.graph = graph + self.exclusive_node_set = [] + + def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult: + """wrapper to check reused nodes in patterns""" + if pattern in self.pattern_to_node: + if self.pattern_to_node[pattern] == node: + return Match(self, pattern) # already checked this node + else: + return FailedMatch("repeated pattern differs") + m = pattern._match(node, self) + assert pattern not in self.pattern_to_node + self.pattern_to_node[pattern] = node if m else None + return m + + def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]: + return { + pattern: node + for pattern, node in self.pattern_to_node.items() + if pattern.has_multiple_users() and node is not None + } + + +class PatternExpr(ABC): + """ + Base class for types of patterns. + """ + + @abstractmethod + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + ... + + def match(self, node: torch.fx.Node) -> MatchResult: + try: + return MatchContext([self], graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def has_multiple_users(self) -> bool: + return False + + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + + def pattern_eq(self, other: Any) -> bool: + """ + Compare two `PatternExpr`s and return true if they are the + same. Note this is NOT matching a pattern - it is comparing the pattern + structures (for debugging). + """ + return isinstance(other, self.__class__) + + +class Arg(PatternExpr): + """ + Capture an arg which will become an input to the handler. Args are + passed in depth first order. + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, args=[node]) # matches anything + + +class Ignored(PatternExpr): + """ + Match an arg, but don't pass it to handler + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self) # matches anything + + def __repr__(self) -> str: + return "*" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + return "Ignored()" + + +class KeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. + """ + + def __init__(self, name: str) -> None: + super().__init__() + self.name = name + + def __repr__(self) -> str: + return f"KeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, kwargs={self.name: node}) # matches anything + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return super().pattern_eq(other) and self.name == other.name + + +class ExclusiveKeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. + """ + + name: str + + def __init__(self, name: str) -> None: + super().__init__() + self.name = name + + def __repr__(self) -> str: + return f"ExclusiveKeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + if node in ctx.exclusive_node_set: + return FailedMatch("exclusive arg appears twice") + + ctx.exclusive_node_set.append(node) + return Match(ctx, self, kwargs={self.name: node}) # matches anything + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return super().pattern_eq(other) and self.name == other.name + + +class _TargetExpr(PatternExpr): + """ + Base class for filtering match by node.target + """ + + fns: List[FnsType] + fns_set: Set[FnsType] + + def __init__( + self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1 + ) -> None: + super().__init__() + fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns) + for fn in fns: + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend(getattr(fn, overload) for overload in fn.overloads()) + + self.fns = fns + self.fns_set = set(fns) + self.users = users + + @property + @abstractmethod + def op(self) -> str: + ... + + def fns_repr(self) -> str: + first_repr = self.fns[0] + if not isinstance(first_repr, str): + first_repr = first_repr.__name__ + + if len(self.fns) > 1: + return f"[{first_repr}, ...]" + elif self.fns[0] is getattr(torch, first_repr, None): + return f"torch.{first_repr}" + elif isinstance(self.fns[0], torch._ops.OpOverload): + return str(self.fns[0]) + else: + return first_repr + + def __repr__(self) -> str: + if self.users is MULTIPLE: + comma_users = ", MULTIPLE" + elif self.users != 1: + comma_users = f", {self.users})" + else: + comma_users = "" + return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})" + + def has_multiple_users(self) -> bool: + return isinstance(self.users, Multiple) or self.users > 1 + + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + raise NotImplementedError + + def _match_fns(self, node: torch.fx.Node) -> bool: + return ( + isinstance(node, torch.fx.Node) + and node.op == self.op + and extract_target(node) in self.fns_set + ) + + def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool: + return ( + self in ctx.outputs + or self.users is MULTIPLE + or len(node.users) == self.users + ) + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and self.op == other.op + and self.fns == other.fns + and self.users == other.users + ) + + +_SimpleSpec = Tuple[Any, ...] + + +class _TargetArgsExpr(_TargetExpr): + """ + Base class for filtering match by node.{target,args,kwargs} + """ + + def __init__( + self, + fns: Union[torch.fx.node.Target, str, Sequence[Any]], + *args: Any, + _users: Union[int, Multiple] = 1, + **kwargs: Any, + ) -> None: + super().__init__(fns, _users) + self.args = tuple(args) + self.kwargs = dict(kwargs) + if any( + isinstance(x, (dict, list, tuple)) + for x in itertools.chain(args, kwargs.values()) + ): + self.flatten = self.pytree_flatten + else: + self.flatten = self.simple_flatten + self.flat_args_kwargs = self.flatten(self.args, self.kwargs) + + @staticmethod + def simple_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + values = (*args, *kwargs.values()) + spec = (len(args), *kwargs.keys()) + return values, spec + + @staticmethod + def pytree_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec: + if s.type is None: + return s + mapping = {immutable_list: list, tuple: list, immutable_dict: dict} + return pytree.TreeSpec( + mapping.get(s.type, s.type), + s.context, + list(map(norm_spec, s.children_specs)), + ) + + flat, spec = pytree.tree_flatten([args, kwargs]) + spec = norm_spec(spec) + return flat, spec + + def __repr__(self) -> str: + args = [ + self.fns_repr(), + *map(repr, self.args), + *[f"{k}={v}" for k, v in self.kwargs.items()], + ] + if self.users is MULTIPLE: + args.append("_users=MULTIPLE") + elif self.users != 1: + args.append(f"_users={self.users}") + return f"{self.__class__.__name__}({', '.join(args)})" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + args = [ + self.fns_repr(), + *(pp.pretty_print(x) for x in self.args), + *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()], + ] + if self.users is MULTIPLE: + args.append("_users=MULTIPLE") + elif self.users != 1: + args.append(f"_users={self.users}") + + joiner_str = ", " + return f"{self.__class__.__name__}({joiner_str.join(args)})" + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + if not self._match_fns(node) or len(node.args) != len(self.args): + return FailedMatch("function_mismatch: node={}, pattern={}", node, self) + + if not self._match_users(node, ctx): + return FailedMatch("multiple_users {}", self) + + _args = node.args + _kwargs = node.kwargs + if len(_kwargs) < len(self.kwargs): + from torch.fx.operator_schemas import normalize_function + + normalized_args_and_kwargs = normalize_function( + node.target, node.args, node.kwargs # type: ignore[arg-type] + ) + + if normalized_args_and_kwargs is None: + return FailedMatch("function_mismatch: node={}, pattern={}", node, self) + else: + _args, _kwargs = normalized_args_and_kwargs + if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs): + _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} + else: + return FailedMatch( + "function_mismatch: node={}, pattern={}", node, self + ) + else: + _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} + + node_items, node_spec = self.flatten(_args, _kwargs) + self_items, self_spec = self.flat_args_kwargs + if node_spec != self_spec: + return FailedMatch("args_structure {} {}", node_spec, self_spec) + assert len(node_items) == len(self_items) + + m = Match(ctx, self) + for i, pattern, child_node in zip(itertools.count(), self_items, node_items): + if isinstance(pattern, PatternExpr): + child_match = ctx.match(pattern, child_node) + if not is_match(child_match): + return child_match + m.extend(child_match) + elif isinstance(child_node, torch.fx.Node) or child_node != pattern: + return FailedMatch( + "constant_args: {} {!r}!={pattern!r}", node, child_node + ) + m.nodes.append(node) + m.targets[self] = node.target + return m + + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + """ + This is used when we are matching a pattern with multiple outputs. + There is a partial match (stored in ctx) and we want to walk + this pattern to find a connection to an already-matched node. + + Yields candidate nodes that `self._match` might like. + """ + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + return + + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and self.flat_args_kwargs[1] == other.flat_args_kwargs[1] + and all( + a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b + for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0]) + ) + ) + + +class CallFunction(_TargetArgsExpr): + """ + Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)` + """ + + op = "call_function" + + +class CallMethod(_TargetArgsExpr): + """ + Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)` + """ + + op = "call_method" + + +class CallModule(_TargetArgsExpr): + """ + Matches a call_module node in the FX graphs: `module(*args, **kwargs)` + """ + + op = "call_module" + + +class _TargetExprVarArgs(_TargetExpr): + """ + Matches a call_function node with any arguments which are passed into the pattern + """ + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + if not self._match_fns(node): + return FailedMatch("function_mismatch") + + if not self._match_users(node, ctx): + return FailedMatch("multiple_users") + + m = Match(ctx, self) + m.nodes.append(node) + m.targets[self] = node.target + m.args.extend(node.args) + m.kwargs.update(node.kwargs) + return m + + +class CallFunctionVarArgs(_TargetExprVarArgs): + op = "call_function" + + +class CallMethodVarArgs(_TargetExprVarArgs): + op = "call_method" + + +class CallModuleVarArgs(_TargetExprVarArgs): + op = "call_module" + + +class ListOf(PatternExpr): + """ + Matches a repeated pattern + """ + + def __init__(self, pattern: PatternExpr, partial: bool = False) -> None: + super().__init__() + assert isinstance(pattern, PatternExpr) + self.pattern = pattern + self.partial = partial + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.pattern})" + + def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override] + if not isinstance(node, (list, tuple)) or len(node) == 0: + return FailedMatch("non_list") + m = Match(ctx, self) + # Propagating patterns with multiple users will ensure we don't revisit + # the same nodes + pattern_to_node = ctx.filter_multi_user_patterns() + matched = False + for i, child_node in enumerate(node): + child_ctx = MatchContext( + ctx.outputs, pattern_to_node, graph=child_node.graph + ) + child_match = child_ctx.match(self.pattern, child_node) + pattern_to_node = child_ctx.filter_multi_user_patterns() + if not is_match(child_match): + if not self.partial: + return FailedMatch("list[{}]: {}", i, child_match) + continue + matched = True + m.extend(child_match.bundle()) + if not matched: + return FailedMatch("list: no_match") + return m.bundle() + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and self.pattern.pattern_eq(other.pattern) + and self.partial == other.partial + ) + + +class MultiOutputPattern(PatternExpr): + outputs: List[Optional[PatternExpr]] + + def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None: + super().__init__() + assert isinstance(outputs[0], _TargetExpr) + assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs + self.outputs = list(outputs) + self.op = outputs[0].op + + @property + def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]: + # This cast is checked above in __init__() + output = typing.cast(_TargetExpr, self.outputs[0]) + return output.fns + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.outputs})" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + args = [pp.pretty_print(x) for x in self.outputs] + joiner_str = f",\n{' '}" + str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}" + str_out = f"{str_out}\n])" + return str_out + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + output = typing.cast(_TargetExpr, self.outputs[0]) + m = ctx.match(output, node) + if not is_match(m): + return m + + for pattern in self.outputs[1:]: + if pattern is None: + continue + child_match = self._match_from_anchors(pattern, ctx) + if not is_match(child_match): + return child_match + m.extend(child_match) + + return m + + def _match_from_anchors( + self, pattern: PatternExpr, ctx: MatchContext + ) -> MatchResult: + prior = dict(ctx.pattern_to_node) + m: MatchResult = FailedMatch("no anchor found") + for node in pattern.find_anchor_nodes(ctx, set()): + m = ctx.match(pattern, node) + if is_match(m): + return m + # revert any partial matches + ctx.pattern_to_node = dict(prior) + return m + + def match(self, node: torch.fx.Node) -> MatchResult: + try: + return MatchContext(self.outputs, graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and len(self.outputs) == len(other.outputs) + and all( + a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b + for a, b in zip(self.outputs, other.outputs) + ) + ) + + +class RepeatedExpr(PatternExpr): + """ + Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind` + """ + + def __init__(self, inner_pattern: _TargetExpr) -> None: + super().__init__() + self.inner_pattern = inner_pattern + self.op = inner_pattern.op + + @property + def fns(self) -> Sequence[FnsType]: + return self.inner_pattern.fns + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + m = ctx.match(self.inner_pattern, node) + if not is_match(m): + return m + ctx.pattern_to_node.pop( + self.inner_pattern, + ) + # Check all anchor nodes match the pattern + for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()): + anchor_m = MatchContext([self], graph=node.graph).match( + self.inner_pattern, anchor_node + ) + if not is_match(anchor_m): + return anchor_m + m.extend(anchor_m) + return m + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return super().pattern_eq(other) and self.inner_pattern.pattern_eq( + other.inner_pattern + ) + + +class PatternPrettyPrinter: + """ + Serializes Patterns to executable python. + XXX: currently only used and tested for fuse attention patterns. May not cover + all patterns. + """ + + def __init__(self) -> None: + self.namespace = torch.fx.graph._Namespace() + self.memoized_objs_names: Dict[PatternExpr, str] = {} + self.memoized_objs_pp: Dict[PatternExpr, str] = {} + + @staticmethod + @functools.lru_cache(None) + def run(obj: PatternExpr, output_name: str = "output") -> str: + """ + Serializes obj to python code with obj written out to `output_name` + """ + + pp = PatternPrettyPrinter() + assert hasattr(obj, "pretty_print") + out_str = obj.pretty_print(pp=pp) + + output = [] + for key in pp.memoized_objs_names: + output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}") + + output.append(f"{output_name} = {out_str}") + + return "\n".join(output) + + def pretty_print(self, obj: Any) -> str: + if isinstance(obj, _TargetArgsExpr): + if memoized_name := self.memoized_objs_names.get(obj): + return memoized_name + else: + return self.memoize(obj) + if hasattr(obj, "pretty_print"): + return obj.pretty_print(self) + + return repr(obj) + + def memoize(self, obj: _TargetArgsExpr) -> str: + obj_str = obj.pretty_print(self) + obj_name = obj.fns_repr() + for prefix in ("aten.", "torch.", "prims."): + obj_name = obj_name.replace(prefix, "") + + tmp_name = self.namespace.create_name(obj_name, None) + self.memoized_objs_names[obj] = tmp_name + self.memoized_objs_pp[obj] = obj_str + return tmp_name + + +class _PassDictsType(Protocol): + def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: + ... + + +@dataclasses.dataclass +class PatternEntry: + pattern: PatternExpr + extra_check: Callable[[Match], bool] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + raise NotImplementedError + + def register( + self, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + target: Union[torch.fx.node.Target, None] = None, + prepend: bool = False, + ) -> None: + if target is None: + assert hasattr(self.pattern, "fns") + for fn in self.pattern.fns: + self.register(pass_dicts, fn, prepend=prepend) + elif isinstance(pass_dicts, (dict, PatternMatcherPass)): + assert hasattr(self.pattern, "op") + if prepend: + pass_dicts[(self.pattern.op, target)].insert(0, self) + else: + pass_dicts[(self.pattern.op, target)].append(self) + else: + pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts) + for x in pass_dicts: + self.register(x, target, prepend=prepend) + + +@dataclasses.dataclass +class LoweringPatternEntry(PatternEntry): + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + handler = functools.wraps(self.handler)(functools.partial(self.handler, match)) + with graph.inserting_before(node): + replacement = graph.call_function(handler, tuple(match.args), match.kwargs) + replacement.meta.update(node.meta) + node.replace_all_uses_with(replacement) + assert match.nodes[-1] is node + match.erase_nodes() + + +@dataclasses.dataclass +class GraphPatternEntry(PatternEntry): + """ + A pattern that runs a function on the FX graph + """ + + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + with graph.inserting_before(node): + self.handler(match, *match.args, **match.kwargs) + + +@dataclasses.dataclass +class ReplacementPatternEntry(PatternEntry): + normalize_args: Callable[..., List[Any]] + + @staticmethod + def replace_with_graph( + match: Match, + graph: torch.fx.Graph, + replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule], + args: Sequence[torch.fx.Node], + ) -> None: + class Replacer(torch.fx.Interpreter): + call_method = None # type: ignore[assignment] + call_module = None # type: ignore[assignment] + get_attr = None # type: ignore[assignment] + + def run_node(self, node: torch.fx.Node) -> Any: + if node.op in ("placeholder", "output"): + return super().run_node(node) + if node.op == "call_function": + target = node.target + args, kwargs = self.fetch_args_kwargs_from_env(node) + result = graph.call_function(target, args, kwargs) # type: ignore[arg-type] + if "val" in node.meta and "val" not in result.meta: + result.meta["val"] = node.meta["val"] + if isinstance(node.meta["val"], torch.Tensor): + assert "tensor_meta" in node.meta + result.meta["tensor_meta"] = node.meta["tensor_meta"] + return result + raise NotImplementedError(f"unhandled {node}") + + output_nodes = match.output_nodes() + + if len(output_nodes) == 1: + last_node = output_nodes[0] + else: + assert output_nodes[0] + nodes = list(output_nodes[0].graph.nodes) + indices = [ + (nodes.index(n), n) + for n in output_nodes + if isinstance(n, torch.fx.Node) + ] + last_node = min(indices, key=operator.itemgetter(0))[1] + + def percolate_tags( + node: torch.fx.Node, + tag_name: str, + tag_value: str, + input_stops: Set[torch.fx.Node], + ) -> None: + queue = [node] + visited = set() + + while queue: + arg = queue.pop() + if ( + arg not in visited + and arg not in input_stops + and hasattr(arg, "meta") + ): + visited.add(arg) + arg.meta[tag_name] = tag_value + queue.extend(arg.all_input_nodes) + + with graph.inserting_before(last_node): + replacement = Replacer(replacement_graph).run(*args) # type: ignore[arg-type] + if isinstance(replacement, torch.fx.Node): + replacement = [replacement] + + def maybe_getitem(node: torch.fx.Node) -> Any: + if node.op != "call_function": + return None + if node.target != operator.getitem: + return None + assert len(node.args) == 2 + return node.args[1] + + def replace( + old: Union[torch.fx.Node, None], + new: Union[torch.fx.Node, Sequence[torch.fx.Node], None], + ) -> None: + if old is None: + assert new is None + return + assert isinstance(old, torch.fx.Node) + if new is None: + old.replace_all_uses_with(None) # type: ignore[arg-type] + graph.erase_node(old) + return + if isinstance(new, torch.fx.Node): + if "val" not in new.meta: + new.meta.update(old.meta) + + # Preserve the recompute tags in the replacement graph. We + # look at the recompute tags of the original output node to + # propagate the tag from the output all the way to the input + # args (named as args in the replace_with_graph). + # Note that this is best effort. Since patterns are from + # many to many, there is no easy way to correctly map the + # recomputable tags. It is possible in some scenarios that we + # incorrectly tag some nodes as recomputables. + for tag_name in ["recompute", "ac_graph_id"]: + if tag_name in old.meta: + percolate_tags(new, tag_name, old.meta[tag_name], set(args)) + + old.replace_all_uses_with(new) + graph.erase_node(old) + return + + # `new` is not a node: it's a list of nodes. + # + # This happens when we want to replace a node that has a single + # packed return with multiple unpacked returns. We need to do + # some graph surgery here. + # + # Example: + # def original_graph(x): + # a = op(x) + # b = a[0] + # c = a[1] + # ... + # + # Assume that we want to replace op(x) with the graph + # def new_op(x): + # w = x + 1 + # z = x + 2 + # return (w, z) + # + # We need to replace `op` with the contents of `new_op`, + # and then rewrite a[0] to be w and a[1] to be z, as so: + # def new_graph(x): + # w = x + 1 + # z = x + 2 + # b = w + # c = z + # ... + old_uses = list(old.users.keys()) + for user in old_uses: + idx = maybe_getitem(user) + if idx is None: + raise AssertionError("can't handle") + replace(user, new[idx]) # type: ignore[index] + graph.erase_node(old) + + if len(output_nodes) == len(replacement): + for old, new in zip(output_nodes, replacement): + replace(old, new) + else: + assert len(output_nodes) == 1 + replace(output_nodes[0], replacement) + + match.erase_nodes() + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + assert match.replacement_graph is not None + self.replace_with_graph( + match, + graph, + match.replacement_graph, + self.normalize_args(*match.args, **match.kwargs), + ) + + +def _return_true(match: Match) -> bool: + return True + + +def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None: + log.info( + "Replacement pattern %s failed to apply due to shape mismatch: %s", + search_fn.__name__, + e, + ) + + +def register_replacement( + search_fn: SearchFn, + replace_fn: ReplaceFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + search_fn_pattern: Union[PatternExpr, None] = None, +) -> bool: + """ + Create a replacement rule based on example functions that get traced + to create patterns. This supports both training and inference when + run on a joint forward+backward graph. + + Args: + search_fn: traced to give original pattern + replace_fn: traced to give replacement graph + example_inputs: example inputs for initial trace + trace_fn: fwd_only or joint_fwd_bwd + pass_dict: dict of passes to register to + extra_check: additional check to run on match(using real shapes) + """ + argnames_static = [*inspect.signature(search_fn).parameters.keys()] + + def check_fn(match: Match) -> bool: + """ + Often shapes get burned into the pattern, so our initial match ran with + `ignore_types=(int, ...)`. + + Recheck the match with the correct shapes. + """ + argnames = list(argnames_static) + for name in argnames: + if name not in match.kwargs: + raise RuntimeError( + f"Not all inputs to pattern found in match.kwargs. Perhaps one " + f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}" + ) + + args = list( + torch.fx.map_arg( # type: ignore[arg-type] + [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] + ) + ) + sym_args: List[torch.SymInt] = [] + with torch._dynamo.utils.detect_fake_mode(args): + for i, grad in enumerate(requires_grad): + if isinstance(args[i], torch.Tensor): + if grad and is_integer_dtype(args[i].dtype): + return False + + args[i] = torch.empty_strided( + args[i].size(), + args[i].stride(), + dtype=args[i].dtype, + device=args[i].device, + requires_grad=grad, + ) + for v in itertools.chain(args[i].shape, args[i].stride()): + if isinstance(v, torch.SymInt) and all( + guard_size_oblivious(v != a) for a in sym_args + ): + sym_args.append(v) + + # If we were given a pre-traced pattern then use that instead of + # retracing. Note that this means the pattern has to be independent + # of its args. + specific_pattern = search_fn_pattern + + if not specific_pattern: + if sym_args: + # AOT Autograd and make fx will dedupe symbolic shape size + # accesses of sym ints that appear as inputs + # We don't want the sym_size uses to interfere with pattern matching + # so we provide them as inputs. + # Later, when we actually do the replacement, the symbolic shape + # sizes will get re-traced and added to the graph. + + def search_fn_new(*args_new: Any) -> Any: + return search_fn(*args_new[len(args_new) - len(args) :]) + + try: + specific_graph = trace_fn(search_fn_new, sym_args + args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + # correct argnames in the graph + sym_arg_names = [] + for i, placeholder in zip( + range(len(sym_args) + len(args)), + specific_graph.graph.nodes, + ): + if i < len(sym_args): + sym_arg_names.append(placeholder.target) + continue + + with specific_graph.graph.inserting_after(placeholder): + new_node = specific_graph.graph.placeholder( + argnames[i - len(sym_args)] + ) + new_node.target = new_node.name + placeholder.replace_all_uses_with(new_node) + specific_graph.graph.erase_node(placeholder) + + argnames = sym_arg_names + argnames + else: + try: + specific_graph = trace_fn(search_fn, args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + specific_pattern = fx_to_pattern( + specific_graph, + argnames=argnames, + exclusive_arg_names=exclusive_arg_names, + scalar_workaround=scalar_workaround, + ) + + node = match.output_nodes()[0] + assert node is not None + specific_pattern_match = specific_pattern.match(node) + + if is_match(specific_pattern_match) and extra_check(specific_pattern_match): + # trace the pattern using the shapes from the user program + match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment] + return True + return False + + def normalize_args(**kwargs: Any) -> List[Any]: + args = [] + for name in argnames_static: + args.append(kwargs.pop(name)) + for i in range(1, len(kwargs) + 1): + if f"tangents_{i}" not in kwargs: + break + args.append(kwargs.pop(f"tangents_{i}")) + assert not kwargs, f"leftover kwargs: {kwargs!r}" + return args + + if trace_fn is joint_fwd_bwd: + # If inference mode is enabled during compilation, assume that we don't + # want to match on any training graph patterns + if torch.is_inference_mode_enabled(): + return False + + # TODO: Revisit the functionalize_rng_ops for lowmem dropout + with functorch_config.patch(functionalize_rng_ops=False): + requires_grad: List[bool] = [ + isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs + ] + if search_fn_pattern is None: + pattern = gen_pattern( + search_fn, + example_inputs, + trace_fn, + scalar_workaround, + exclusive_arg_names, + ) + else: + pattern = search_fn_pattern + + pattern_repr = PatternPrettyPrinter.run(pattern) + assert pattern_repr not in _seen_patterns + _seen_patterns.add(pattern_repr) + pattern = ReplacementPatternEntry( + pattern=pattern, + extra_check=check_fn, + normalize_args=normalize_args, + ) + pattern.register(pass_dicts) + return pattern.pattern + + +_serialized_patterns: Set[str] = set() + + +def _serialize_pattern( + unique_name: str, + search_fn: SearchFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + scalar_workaround: Union[Dict[str, Union[float, int]], None], +) -> PatternExpr: + def get_file_template() -> str: + auto_generated_msg = textwrap.dedent( + """\ + # This is an auto-generated file. Please do not modify it by hand. + # To re-generate, run: + # cd ~/pytorch && python torchgen/fuse/gen_patterns.py + """ + ) + + file_template = textwrap.dedent( + """\ + # mypy: ignore-errors + + # noqa: F401, E501 + {msg} + import torch + import torch._inductor + + aten = torch.ops.aten + prims = torch.ops.prims + + """ + ).format(msg=auto_generated_msg) + + pattern_matcher_imports = [] + for name in dir(torch._inductor.pattern_matcher): + attr = getattr(torch._inductor.pattern_matcher, name) + if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)): + pattern_matcher_imports.append(name) + + formatted_imports = ",\n ".join(pattern_matcher_imports) + formatted_imports = f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n" + return f"{file_template}{formatted_imports}" + + if not SERIALIZED_PATTERN_PATH.is_dir(): + raise RuntimeError( + f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}" + ) + + pattern_name = search_fn.__name__ + + from torch._functorch import config as functorch_config + + with functorch_config.patch(functionalize_rng_ops=False): + pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround) + + serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name) + if pattern_name not in _serialized_patterns: + write_mode = "w" + _serialized_patterns.add(pattern_name) + else: + write_mode = "a" + + file_template = get_file_template() + + with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f: + if write_mode == "w": + f.write(file_template) + else: + f.write("\n\n") + f.write(serialized_pattern) + f.write("\n") + + return pattern + + +SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns" + +# This is the set of serialized patterns that we've registered. Used by +# test_serialized_patterns_up_to_date() to ensure the patterns are up +# to date. +_known_precompiled_patterns: List[ + Tuple[ + Any, + Iterable[Any], + Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], + Any, + PatternExpr, + ] +] = [] + + +def gen_register_replacement( + unique_name: str, + search_fn: SearchFn, + replace_fn: ReplaceFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + skip_duplicates: bool = False, +) -> None: + # Make sure the example_inputs is materialized. + example_inputs = tuple(example_inputs) + + if "PYTORCH_GEN_PATTERNS" in os.environ: + pat = _serialize_pattern( + unique_name, search_fn, example_inputs, trace_fn, scalar_workaround + ) + else: + pattern_name = search_fn.__name__ + m = importlib.import_module( + f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}" + ) + if not m or not hasattr(m, unique_name): + log.warning( + "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.", + unique_name, + ) + pat = getattr(m, unique_name) + + for arg in pytree.tree_iter(example_inputs): + if isinstance(arg, FakeTensor) and arg.constant is not None: + # This can be a problem - small fake tensors (e.g. `tensor(2)`) will + # hold onto their original constant value - and by stashing it here + # will cause a memory leak if the constant value is on GPU. + # Since this is just an optimization we can clear it out. + arg.constant = None + + if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates: + return + _known_precompiled_patterns.append( + (search_fn, example_inputs, trace_fn, scalar_workaround, pat) + ) + register_replacement( + search_fn, + replace_fn, + example_inputs, + trace_fn, + pass_dicts, + extra_check, + scalar_workaround, + exclusive_arg_names, + search_fn_pattern=pat, + ) + + +@functorch_config.patch(functionalize_rng_ops=False) +def gen_pattern( + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> PatternExpr: + argnames = [*inspect.signature(search_fn).parameters.keys()] + + if scalar_workaround is None: + scalar_workaround = {} + flat_inputs = [] + input_idx = 0 # Positional arguments index + + for argname in argnames: + if argname in scalar_workaround: + flat_inputs.append(scalar_workaround[argname]) + else: + flat_inputs.append(example_inputs[input_idx]) + input_idx += 1 + + search_gm = trace_fn(search_fn, flat_inputs) + return fx_to_pattern( + search_gm, + ignore_types=(int, float, list, torch.device, torch.dtype), + argnames=argnames, + scalar_workaround=scalar_workaround, + exclusive_arg_names=exclusive_arg_names, + ) + + +def register_lowering_pattern( + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Register an aten to inductor IR replacement pattern. The decorated + function is saved and then called a lowering time allowing direct + pattern to inductor IR conversion. + """ + + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + assert callable(handler) + LoweringPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + handler._inductor_lowering_function = True # type: ignore[attr-defined] + return handler + + return decorator + + +def register_graph_pattern( + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Register a pattern that runs a function on the FX graph, allowing + custom transformation code. + """ + + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + assert callable(handler) + GraphPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + return handler + + return decorator + + +def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool: + # first node in the graph + return node is next(iter(graph.nodes)) + + +# match: copy_, relu_, _set_grad_enabled, manual_seed, _enter_autocast, etc +# doesn't match: __rshift__, etc +_mutation_op_re = re.compile(r"(? bool: + if node.op == "call_function": + if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr] + return True + elif node.op == "call_method": + if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type] + return True + return node.kwargs.get("out") is not None + + +def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool: + assert "mutation_region_id" in a.meta + assert "mutation_region_id" in b.meta + return a.meta["mutation_region_id"] == b.meta["mutation_region_id"] + + +def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int: + n = node + while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n): + n = n.prev + mutation_region_id = n.meta.get("mutation_region_id", 0) + while n is not node: + n = n.next + if is_mutation_op(n): + mutation_region_id += 1 + n.meta["mutation_region_id"] = mutation_region_id + return mutation_region_id + + +def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool: + return "mutation_region_id" not in next(iter(graph.nodes)).meta + + +def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None: + mutation_region_id = 0 + for nd in graph.nodes: + if is_mutation_op(nd): + mutation_region_id += 1 + nd.meta["mutation_region_id"] = mutation_region_id + + +class PatternMatcherPass: + def __init__( + self, + pass_name: Optional[str] = None, + ) -> None: + super().__init__() + self.patterns: DefaultDict[ + Tuple[str, torch.fx.node.Target], List[PatternEntry] + ] = defaultdict(list) + self.pass_name = pass_name + + def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: + return self.patterns[item] + + def apply(self, gm: torch.fx.GraphModule) -> int: + if not self.patterns: + return 0 + if isinstance(gm, torch.fx.GraphModule): + graph = gm.graph + elif isinstance(gm, torch.fx.Graph): + graph = gm + gm = graph.owning_module + else: + raise RuntimeError( + f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}" + ) + if should_compute_mutation_region_ids(graph): # type: ignore[arg-type] + compute_mutation_region_ids(graph) # type: ignore[arg-type] + get_mutation_region_id_partial = functools.partial( + get_mutation_region_id, graph + ) + count = 0 + nodes = [] + has_call_module = False + for op, target in self.patterns: + if op == "call_module": + has_call_module = True + else: + nodes.append(graph.find_nodes(op=op, target=target, sort=False)) + if has_call_module: + nodes.append(graph.find_nodes(op="call_module", sort=False)) + pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher" + with GraphTransformObserver( + gm, pass_name, trace_config.log_url_for_graph_xform + ): + for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): + target = extract_target(node) + if node.op == "call_module": + if (node.op, target) not in self.patterns: + continue + + # conservatively not applying pattern for cpu input, + # since some of the patterns induce codegen and split nodes. + # Note: we will only skip cpu compute if disable_cpp_codegen=True + if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): + continue + + for entry in self.patterns[(node.op, target)]: + if node._erased: + break + m = entry.pattern.match(node) + # pattern match crosses mutation barrier - discard + if ( + is_match(m) + and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] + ): + continue + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning("%s%s %s %s", node, node.args, m, entry.pattern) + if is_match(m) and entry.extra_check(m): + count += 1 + entry.apply(m, graph, node) # type: ignore[arg-type] + counters["inductor"]["pattern_matcher_count"] += 1 + counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) + return count + + def clear(self) -> None: + self.patterns.clear() + + +def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn: + raise NotImplementedError + + +def fx_to_pattern( + gm: Union[torch.fx.GraphModule, torch.fx.Graph], + ignore_types: Sequence[Type[Any]] = (), + argnames: Sequence[str] = (), + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> PatternExpr: + """ + Convert an FX graph into a PatternExpr. This is useful for simple + patterns that can only match single functions and fixed-length lists. + """ + # scalar_workaround is a hack to capture dropout_p + # see https://github.com/pytorch/pytorch/issues/97894 + scalar_workaround = scalar_workaround or {} + inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()} + assert len(inv_scalar_workaround) == len(scalar_workaround) + + def process_arg(x: T) -> Union[T, KeywordArg, Ignored]: + if isinstance(x, (float, int)) and x in inv_scalar_workaround: + return KeywordArg(inv_scalar_workaround[x]) + if type(x) in ignore_types: + return Ignored() + if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x: + return Ignored() + return x + + argnum = itertools.count() + + class Converter(torch.fx.Interpreter): + call_method = _not_implemented + call_module = _not_implemented + get_attr = _not_implemented + + def placeholder( + self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] + ) -> Union[ExclusiveKeywordArg, KeywordArg]: + n = next(argnum) + if n < len(argnames): + name = argnames[n] + elif argnames: + assert target.startswith("tangent") + name = target + else: + target = re.sub(r"_\d+$", "", target) # de-mangle arg name + name = target + if name in exclusive_arg_names: + return ExclusiveKeywordArg(name) + else: + return KeywordArg(name) + + def call_function( + self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] + ) -> PatternExpr: + args, kwargs = pytree.tree_map(process_arg, (args, kwargs)) + if list in ignore_types: + # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...] + args = [process_arg(a) for a in args] + kwargs = {k: process_arg(a) for k, a in kwargs.items()} + return CallFunction(target, *args, **kwargs) + + def run_node(self, n: torch.fx.Node) -> Any: + rv = super().run_node(n) + if n.op == "output" and isinstance(rv, tuple): + assert len(rv) == len(n.args[0]) # type: ignore[arg-type] + for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type] + r.users = len(arg.users) + else: + rv.users = len(n.users) + return rv + + pattern = Converter(gm).run() # type: ignore[arg-type] + if not isinstance(pattern, PatternExpr): + return MultiOutputPattern(pytree.tree_leaves(pattern)) + return pattern + + +@torch.no_grad() +def fwd_only( + fn: Callable[..., Any], + args: Sequence[Any], + *, + run_functional_passes: bool = True, + get_decomp_fn: Optional[Callable[..., Any]] = None, +) -> torch.fx.GraphModule: + """Build a normalized inference graph, for use with fx_to_pattern""" + # TODO - look into using aot autograd, asserting no mutating ops here + with enable_python_dispatcher(): + decompositions = ( + get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() + ) + gm = make_fx(fn, decompositions, tracing_mode="real")(*args) + + from .fx_passes.post_grad import remove_noop_ops + + if run_functional_passes: + remove_noop_ops(gm.graph) + gm.graph.eliminate_dead_code() + + gm.recompile() + return gm + + +@torch.enable_grad() +def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule: + """Build a normalized training graph, for use with fx_to_pattern""" + gm: Optional[torch.fx.GraphModule] = None + + def record_joint_graph( + joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any + ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + nonlocal gm + assert not gm + gm = clone_graph(joint_graph) + return default_partition(joint_graph, inputs, **kwargs) + + with torch._guards.tracing(None): + aot_function( + fn, + lambda g, i: make_boxed_func(g), + partition_fn=record_joint_graph, + decompositions=select_decomp_table(), + keep_inference_input_mutations=True, + enable_log=False, + )(*args) + assert gm + + from .fx_passes.post_grad import remove_noop_ops + + remove_noop_ops(gm.graph) + + from .fx_passes.joint_graph import pointless_view + + matcher_pass = PatternMatcherPass() + + pattern = CallFunction( + torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size") + ) + GraphPatternEntry( + pattern=pattern, handler=pointless_view, extra_check=_return_true + ).register(matcher_pass.patterns) + matcher_pass.apply(gm.graph) # type: ignore[arg-type] + + # remove in/out specs + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: + args: List[torch.fx.node.Argument] = [] + torch.fx.map_arg((n.args, n.kwargs), args.append) + return args + + +def stable_topological_sort(graph: torch.fx.Graph) -> None: + # Nodes are in exactly one of these three collections: + + # - Nodes in `pending` are waiting to be processed (in reverse order): + pending = list(reversed(graph.nodes)) + + # - Nodes in `ready` have been processed and are already in the correct + # order. + ready = set() + + # - `waiting` is a mapping from a dependency to nodes which depend on that + # dependency. + waiting = defaultdict(list) + + # The cursor indicates the last processed node so we can add new nodes + # after it. + cursor = None + while pending: + node = pending.pop() + waiting_for = [x for x in _args(node) if x not in ready] + if waiting_for: + # We have unprocessed input nodes. Might as well wait for the last + # arg so an already sorted list will only recheck this node once. + waiting[waiting_for[-1]].append(node) + else: + ready.add(node) + if cursor and cursor.next is not node: + cursor.append(node) + cursor = node + # Mark the nodes that have been waiting for this node to finish as + # ready to check again. + pending.extend(reversed(waiting.pop(node, ()))) + + assert not waiting and len(ready) == len(graph.nodes) + + +def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: + """Wrapper around lazy init functions in fx_passes/""" + + @functools.lru_cache(None) + @functools.wraps(fn) + def lazy_init() -> Any: + counters_ref = counters["inductor"].copy() + + with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): + result = fn() + + # clear view matches encountered during tracing + counters["inductor"] = counters_ref + + return result + + return lazy_init + + +def config_flag(name: str) -> Callable[[Match], Any]: + """Function for extra_check to put pass behind a flag""" + + def flag_check(match: Match) -> Any: + return getattr(config, name) + + return flag_check + + +def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: + class CopyGraph(Transformer): + def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node: + new_node = super().run_node(old_node) + if isinstance(new_node, torch.fx.Proxy): + new_node.node.meta.update(old_node.meta) + new_node.node.name = self.new_graph._graph_namespace.create_name( + old_node.name, None + ) + return new_node + + return CopyGraph(input_graph).transform() + + +_seen_patterns: Set[str] = set() + + +def get_arg_value( + node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None +) -> Any: + return ( + node.args[arg_number] + if len(node.args) > arg_number + else node.kwargs.get(kwarg_name) # type: ignore[arg-type] + ) + + +def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]: + fns = [fn] + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + + return [node for node in nodes if node.target in fns] + + +def extract_target(node: torch.fx.Node) -> torch.fx.node.Target: + """For call_function and call_method, we directly use the target function; + For call_module, the target is string, and we treat the module class + as a function. + """ + if node.op == "call_module": + return getattr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type] + return node.target diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/quantized_lowerings.py b/.venv/lib/python3.11/site-packages/torch/_inductor/quantized_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..80910e67d3a61b7ba1ae115371c6b59786eff205 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/quantized_lowerings.py @@ -0,0 +1,92 @@ +# mypy: allow-untyped-defs +import logging + +import torch +from torch._inductor.kernel.mm_common import mm_args + +from . import config as inductor_config, lowering +from .codegen.cpp_gemm_template import CppPackedGemmTemplate +from .codegen.cpp_utils import create_epilogue_with_attr +from .lowering import expand, register_lowering +from .select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, +) +from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template + + +log = logging.getLogger(__name__) + +aten__weight_int8pack_mm = ExternKernelChoice( + torch._weight_int8pack_mm, "at::_weight_int8pack_mm", has_out_variant=False +) + + +quantized = torch.ops.quantized +_quantized = torch.ops._quantized +aten = torch.ops.aten + + +def register_quantized_ops(): + lowering.add_needs_realized_inputs( + [ + quantized.max_pool2d, + _quantized.wrapped_fbgemm_pack_gemm_matrix_fp16, + _quantized.wrapped_fbgemm_linear_fp16_weight, + ] + ) + + lowering.make_fallback(quantized.max_pool2d) + lowering.make_fallback(_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) + lowering.make_fallback(_quantized.wrapped_fbgemm_linear_fp16_weight) + + +def register_woq_mm_ops(): + @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None) + def int8pack_mm(input, weight, scale, *, layout=None): + _, _, _, layout, mat1, mat2 = mm_args( + input, weight, layout=layout, mat2_transposed=True + ) + assert ( + mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float] + and mat2.get_dtype() == torch.int8 + ) + aten_layout = layout + + # options to tune from + choices = ( + [aten__weight_int8pack_mm.bind((mat1, mat2, scale), aten_layout)] + if use_aten_gemm_kernels() + else [] + ) + + # scale is applied as an epilogue, and the scale tensor is expanded (with a view op) + # for broadcasting, as it's 1D. + def _mul_epilogue(buf): + return create_epilogue_with_attr( + buf, "mul", other=realize_inputs(expand(scale, layout.size)) + ) + + if use_cpp_packed_gemm_template(aten_layout, mat1, mat2, mat2_transposed=True): + CppPackedGemmTemplate.add_choices( + choices, + aten_layout, + [mat1, mat2, scale], + trans_w=True, + epilogue_creator=_mul_epilogue, + ) + + if ( + len(choices) == 0 + and inductor_config.autotune_fallback_to_aten + and not use_aten_gemm_kernels() + ): + log.warning("No choices for GEMM, using ATen backend as fallback") + return aten__weight_int8pack_mm.bind( + (mat1, mat2, scale), aten_layout + ).output_node() + + return autotune_select_algorithm( + "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout + ) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/remote_cache.py b/.venv/lib/python3.11/site-packages/torch/_inductor/remote_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..14963a1bf5d4358b012f69bb97411f3d610f6bbe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/remote_cache.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import json +import os +import typing +from abc import abstractmethod +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union +from typing_extensions import override, TypeAlias + +from torch._inductor import config + + +try: + import redis +except ImportError: + redis = None # type: ignore[assignment] + + +if config.is_fbcode(): + from rfe.scubadata.scubadata_py3 import ( # type: ignore[import-not-found] + Sample as Sample_, + ) + + Sample: TypeAlias = Sample_ +else: + Sample: TypeAlias = Type[object] # type: ignore[misc,no-redef] + + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class RemoteCacheBackend(Generic[_T]): + """ + A backend implementation for accessing a remote/distributed cache. Only + works with bytes in/out. For structured data use a RemoteCache. + """ + + @abstractmethod + def get(self, key: str) -> Optional[_T]: + pass + + @abstractmethod + def put(self, key: str, data: _T) -> None: + pass + + +# Serde that encodes from _T to _U and decodes from _U to _T. +class RemoteCacheSerde(Generic[_T, _U]): + @abstractmethod + def encode(self, data: _T) -> _U: + pass + + @abstractmethod + def decode(self, data: _U) -> _T: + pass + + +JsonDataTy = Optional[ + Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]] +] + + +class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]): + def encode(self, data: JsonDataTy) -> bytes: + return bytes(json.dumps(data), "ascii") + + def decode(self, data: bytes) -> JsonDataTy: + return json.loads(data) + + +class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]): + def encode(self, data: _T) -> _T: + return data + + def decode(self, data: _T) -> _T: + return data + + +class RemoteCache(Generic[_T]): + backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None + + def __init__( + self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U] + ) -> None: + # Support for testing. + if (override_cls := self.__class__.backend_override_cls) is not None: + self.backend = override_cls() + else: + self.backend = backend + self.serde = serde + + def get(self, key: str) -> Optional[_T]: + sample = self._create_sample() + result = self._get(key, sample) + self._log_sample(sample) + return result + + def put(self, key: str, value: _T) -> None: + sample = self._create_sample() + self._put(key, value, sample) + self._log_sample(sample) + + def _decode(self, data: _U, sample: Optional[Sample]) -> _T: + return self.serde.decode(data) + + def _encode(self, value: _T, sample: Optional[Sample]) -> Any: # returns _U + return self.serde.encode(value) + + def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]: + if data := self.backend.get(key): + return self._decode(data, sample) + return None + + def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None: + data = self._encode(value, sample) + self.backend.put(key, data) + + def _create_sample(self) -> Optional[Sample]: + return None + + def _log_sample(self, sample: Optional[Sample]) -> None: + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]): + """ + A Redis implementation of a remote/distributed cache. + """ + + _key_fmt: str + _redis: Optional[redis.Redis] = None + + def __init__(self, cache_id: str) -> None: + if not redis: + # We had trouble importing redis - just skip init. + return + + self._key_fmt = f"pt2:{cache_id}:{{key}}" + self._redis = redis.Redis( + host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"), + port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)), + ) + + def __get_key(self, key: str) -> str: + return self._key_fmt.format(key=key) + + @override + def get(self, key: str) -> Optional[bytes]: + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return None + + try: + value = self._redis.get(self.__get_key(key)) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + return None + + # In theory redis.get() can return an Awaitable as well... + assert value is None or isinstance(value, bytes) + return value + + @override + def put(self, key: str, data: bytes) -> None: + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return + + try: + self._redis.set(self.__get_key(key), data) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + + +class RedisRemoteCache(RemoteCache[JsonDataTy]): + def __init__(self, key: str) -> None: + # Special test handling: If we're just going to override the backend + # anyway don't require redis + if self.__class__.backend_override_cls: + # This is totally bogus but it works for now... + backend = typing.cast(RemoteCacheBackend[bytes], None) + else: + backend = RedisRemoteCacheBackend(key) + serde = RemoteCacheJsonSerde() + super().__init__(backend, serde) + + +class RemoteAutotuneCache(RedisRemoteCache): + pass + + +class RemoteFxGraphCache(RedisRemoteCache): + pass diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py b/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..67d9651a1e39fae326a190102fe8a23aee29fcd0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py @@ -0,0 +1,1743 @@ +# mypy: allow-untyped-defs +import builtins +import contextlib +import functools +import inspect +import itertools +import json +import logging +import math +import operator +import os +import sys +import textwrap +import time +from collections import namedtuple +from concurrent.futures import as_completed, ThreadPoolExecutor +from io import StringIO +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import sympy +from filelock import FileLock + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.testing import rand_strided +from torch._dynamo.utils import counters, identity, preserve_rng_state + +from . import config, ir +from .autotune_process import TensorMeta, TritonBenchmarkRequest +from .codecache import code_hash, PersistentCache, PyCodeCache +from .codegen.common import IndentedBuffer, KernelTemplate +from .codegen.triton import ( + gen_common_triton_imports, + texpr, + TritonKernel, + TritonPrinter, + TritonScheduling, +) +from .codegen.triton_utils import config_of, signature_to_meta +from .exc import CUDACompileError +from .ir import ChoiceCaller, PrimitiveInfoType +from .runtime.benchmarking import benchmarker +from .runtime.hints import DeviceProperties +from .utils import ( + FakeIndentedBuffer, + get_dtype_size, + Placeholder, + restore_stdout_stderr, + sympy_dot, + sympy_index_symbol, + sympy_product, + unique, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + +# correctness checks struggle with fp16/tf32 +VERIFY: Dict[str, Any] = {} +PRINT_AUTOTUNE = True +DEBUG = False + + +class KernelNamespace: + pass + + +# these objects are imported from the generated wrapper code +extern_kernels = KernelNamespace() + + +class PartialRender: + """ + Some parts of a template need to be generated at the end, but + inserted into the template at the start. This allows doing a bunch + of replacements after the initial render. + """ + + def __init__(self, code, replacement_hooks) -> None: + super().__init__() + self.code = code + self.replacement_hooks = replacement_hooks + + def finalize_hook(self, hook_key: str, strict=True) -> None: + if hook_key not in self.replacement_hooks: + if strict: + raise RuntimeError( + f"{hook_key} not registered in self.replacement_hooks" + ) + else: + return + assert ( + self.replacement_hooks[hook_key] is not None + ), "hook_key can only be called once" + self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]()) + self.replacement_hooks[hook_key] = None + + def finalize_all(self) -> str: + for key, fn in self.replacement_hooks.items(): + self.code = self.code.replace(key, fn()) + return self.code + + +# This is used to store info needed for lowering each subgraph in triton +# templates +SubgraphInfo = namedtuple( + "SubgraphInfo", + [ + "body", + "template_mask", + "template_out", + ], +) + + +class TritonTemplateKernel(TritonKernel): + def __init__( + self, + kernel_name, + input_nodes, + output_node, + defines, + num_stages, + num_warps, + grid_fn, + meta, + call_sizes, + use_jit=False, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + subgraphs: Optional[List[ir.ComputedBuffer]] = None, + *, + index_dtype, + ) -> None: + super().__init__( + sympy_product(output_node.get_size()), + sympy.Integer(1), + index_dtype=index_dtype, + ) + self.input_nodes = input_nodes + self.output_node = output_node + self.named_input_nodes = {} # type: ignore[var-annotated] + self.defines = defines + self.kernel_name = kernel_name + self.use_jit = use_jit + self.num_stages = num_stages + self.num_warps = num_warps + self.grid_fn = grid_fn + self.meta = meta + self.call_sizes = call_sizes + # for templates with fixed epilogues + self.prefix_args = prefix_args + self.suffix_args = suffix_args + self.epilogue_fn = epilogue_fn + self.render_hooks = {} # type: ignore[var-annotated] + self.triton_meta: Optional[Dict[str, object]] = None + # For Templated Attention this can be a list of ir.Subgraph + self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs + + # The following attributes (body, template_mask, output_val) are all + # used for triton kernel codegen. + # They are swapped onto the TritonTemplateKernel object by + # `set_subgraph_body` + self.subgraph_bodies: Dict[str, SubgraphInfo] = {} + + self.body: IndentedBuffer = FakeIndentedBuffer() + self.template_mask: Optional[str] = None + self.template_out: Optional[str] = None + + @contextlib.contextmanager + def set_subgraph_body(self, body_name: str): + old_body, old_mask, old_out = self.body, self.template_mask, self.template_out + assert body_name in self.subgraph_bodies, body_name + self.body, self.template_mask, self.template_out = self.subgraph_bodies[ + body_name + ] + yield + self.subgraph_bodies[body_name] = SubgraphInfo( + self.body, self.template_mask, self.template_out + ) + self.body, self.template_mask, self.template_out = old_body, old_mask, old_out + + @contextlib.contextmanager + def create_subgraph_body(self, body_name: str): + assert body_name not in self.subgraph_bodies + self.subgraph_bodies[body_name] = SubgraphInfo(IndentedBuffer(), None, None) + with self.set_subgraph_body(body_name): + yield + + def need_numel_args(self): + return False + + def estimate_kernel_num_bytes(self): + """ + Estimate the total number of bytes this kernel takes. + For in/out nodes, sizes are counted twice: once for reading and + once for writing. + """ + ninplace_args = len(unique(self.args.inplace_buffers.values())) + num_bytes = [] + for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): + size = V.graph.sizevars.size_hints(inp.get_size()) + numel = functools.reduce(operator.mul, size, 1) + dtype_size = get_dtype_size(inp.get_dtype()) + num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(num_bytes) + + def jit_lines(self): + if self.use_jit: + return "@triton.jit" + + argdefs, _, signature, _ = self.args.python_argdefs() + triton_meta = { + "signature": signature_to_meta(signature, size_dtype=self.index_dtype), + "device": DeviceProperties.create(self.output_node.get_device()), + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] + triton_meta["constants"][arg_num] = 1 # type: ignore[index] + matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0) + if matrix_instr_nonkdim != 0: + triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim + + self.triton_meta = triton_meta + + inductor_meta = { + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + **TritonKernel.inductor_meta_common(), + } + if config.profile_bandwidth or config.benchmark_kernel: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + return f""" + @triton_heuristics.template( + num_stages={self.num_stages}, + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + @triton.jit + """ + + def gen_argdefs(self): + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + return f"{', '.join(arg_defs)}" + + self.render_hooks[""] = hook + return "" + + def gen_defines(self): + return self.defines + + def def_kernel(self, *argnames): + """ + Hook called from template code to generate function def and + needed args. + """ + assert all(isinstance(x, str) for x in argnames) + renames = IndentedBuffer(initial_indent=1) + + named_args = self.input_nodes[ + self.prefix_args : len(self.input_nodes) - self.suffix_args + ] + + assert len(argnames) == len(named_args), ( + len(argnames), + len(named_args), + self.prefix_args, + len(self.input_nodes), + ) + + for input_node in self.input_nodes[: self.prefix_args]: + # get args in correct order + self.args.input(input_node.get_name()) + + for name, input_node in zip(argnames, named_args): + arg_name = f"arg_{name}" + self.named_input_nodes[name] = input_node + self.args.input_buffers[input_node.get_name()] = arg_name + + # The args may be duplicated, so renaming must be after args are de-duplicated. + for name in argnames: + input_node = self.named_input_nodes[name] + arg_name = self.args.input_buffers[input_node.get_name()] + if input_node.get_layout().offset == 0: + renames.writeline(f"{name} = {arg_name}") + else: + offset = texpr(self.rename_indexing(input_node.get_layout().offset)) + renames.writeline(f"{name} = {arg_name} + {offset}") + + for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]: + # get args in correct order + self.args.input(input_node.get_name()) + + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + code = IndentedBuffer() + code.splice(gen_common_triton_imports()) + code.splice(self.jit_lines()) + code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):") + with code.indent(): + code.splice(self.defines) + code.splice(renames.getvalue()) + return code.getvalue() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def size(self, name: str, index: int): + """ + Hook called from template code to get the size of an arg. + Will add needed args to pass it in if it is dynamic. + """ + assert isinstance(index, int) + if name is None: + val = self.output_node.get_size()[index] + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_size()[index] + return texpr(self.rename_indexing(val)) + + def stride(self, name, index=None): + """ + Hook called from template code to get the stride of an arg. + Will add needed args to pass it in if it is dynamic. + """ + if name is None: + val = self.output_node.get_stride() + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_stride() + + if isinstance(index, int): + return texpr(self.rename_indexing(val[index])) + else: + return ", ".join([texpr(self.rename_indexing(i)) for i in val]) + + def modification( + self, subgraph_number: int, output_name: str, **fixed_inputs + ) -> str: + """This creates a modification function for a subgraph. + To use this inside a template, the first argument should specify which subgraph to codegen for + + Args: + subgraph_number (int): The index of the subgraph in self.subgraphs + """ + num = 0 + while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: + num += 1 + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert ( + self.body.getvalue() == "" + ), "Body should be clear before adding a modification" + assert subgraph_number < len( + self.subgraphs + ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + + subgraph = self.subgraphs[subgraph_number] + + def add_input(name): + return self.args.input(name) + + name = f"PlaceholderSubstitution_{subgraph_number}" + + class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined] + self.name = name + + def load(self, name: str, index: sympy.Expr): + if name not in fixed_inputs: + # If it's not a fixed input, it's a load from a captured + # tensor + var = add_input(name) + return f"tl.load({var} + {index})" + + return f"({fixed_inputs[name]})" + + def indirect_indexing(self, index_var, size, check, wrap_neg=True): + return sympy_index_symbol(str(index_var)) + + with V.set_ops_handler(PlaceholderSubstitution(V.ops)): + assert isinstance( + subgraph, ir.ComputedBuffer + ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}" + if isinstance(subgraph.data, ir.InputBuffer): + out = subgraph.data.make_loader()(()) + else: + out = subgraph.data.inner_fn(()) + + self.codegen_body() + self.body.writeline(f"{output_name} = {out.value}") + + body_val = self.body.getvalue() + self.cse.invalidate(set()) # type: ignore[arg-type] + return body_val + + def store_output( + self, + indices: Union[List[Any], Tuple[Any]], + val: str, + mask: Optional[str] = None, + indent_width: int = 4, + ): + """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away. + + Args: + indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of + these indices and output strides must match `val`. + val (str): The value to store. + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. + indent_width (int): The number of spaces to use for indentation. This is used when the call to + store_output is indented in the kernel definition. + """ + with self.create_subgraph_body(""): + assert isinstance(indices, (list, tuple)) + assert isinstance(val, str) + assert isinstance(mask, (str, type(None))) + assert self.template_mask is None + indices = list(map(TritonPrinter.paren, indices)) + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + lengths = [ + V.graph.sizevars.simplify(s) for s in self.output_node.get_size() + ] + assert len(indices) == len(lengths) + + # glue to make generated code use same indexing from template + for name, range_tree_entry in zip( + indices, self.range_trees[0].construct_entries(lengths) + ): + range_tree_entry.set_name(name) + contiguous_index = sympy_dot( + ir.FlexibleLayout.contiguous_strides(lengths), index_symbols + ) + contiguous_index = self.rename_indexing(contiguous_index) + self.body.writeline("xindex = " + texpr(contiguous_index)) + self.range_trees[0].lookup( + sympy.Integer(1), sympy_product(lengths) + ).set_name("xindex") + self.template_mask = mask + self.template_out = val + self.template_indices = indices + output_index = self.output_node.get_layout().make_indexer()(index_symbols) + output_index = self.rename_indexing(output_index) + if output_index == contiguous_index: + output_index = sympy.Symbol("xindex", integer=True) + + epilogue_args = [val] + for input_node in itertools.chain( + self.input_nodes[: self.prefix_args], + self.input_nodes[len(self.input_nodes) - self.suffix_args :], + ): + input_node.freeze_layout() + epilogue_args.append(input_node.make_loader()(index_symbols)) + + V.ops.store( + self.output_node.get_name(), + output_index, + self.epilogue_fn(*epilogue_args), + ) + self.codegen_body() + + def hook(): + # more stuff might have been added since the codegen_body above + self.codegen_body() + + return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def render(self, template, kwargs): + return PartialRender( + template.render(**self.template_env(), **kwargs), + self.render_hooks, + ) + + def make_load(self, name, indices, mask): + """ + Optional helper called from template code to generate the code + needed to load from an tensor. + """ + assert isinstance(indices, (list, tuple)) + assert isinstance(name, str) + assert isinstance(mask, str) + stride = self.named_input_nodes[name].get_stride() + indices = list(map(TritonPrinter.paren, indices)) + assert len(indices) == len(stride) + index = " + ".join( + f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) + ) + return f"tl.load({name} + ({index}), {mask}, other=0.0)" + + def template_env(self): + """ + Generate the namespace visible in the template. + """ + return { + fn.__name__: fn + for fn in [ + self.def_kernel, + self.size, + self.stride, + self.store_output, + self.make_load, + self.modification, + self.gen_argdefs, + self.gen_defines, + ] + } + + def indexing( + self, + index: sympy.Expr, + *, + dense_indexing=False, + copy_shape=None, + override_mask=None, + block_ptr=False, + ): + """ + Override the default indexing to use our custom mask and force + dense indexing. + """ + return super().indexing( + index, + dense_indexing=False, + # We pass template_out as the shape to broadcast the indexing to as + # the mask might be broadcast to the output shape + copy_shape=self.template_out, + override_mask=self.template_mask, + block_ptr=block_ptr, + ) + + def codegen_range_tree(self): + pass # ignore default codegen + + def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + if V.graph.cpp_wrapper: + # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime + # if any dynamic dimension is involved. We rely on the Python version + # of the grid function to generate those grid configs, which may contain + # symbolic values. The wrapper will use cexpr to print out C++ code + # appropriately for the grid configs. + grid = self.call_sizes + [self.meta] + wrapper.generate_kernel_call( + name, + call_args, + grid=self.grid_fn(*grid), + arg_types=arg_types, + triton_meta=self.triton_meta, + ) + else: + wrapper.add_import_once(f"import {self.grid_fn.__module__}") + meta = wrapper.add_meta_once(self.meta) + grid = self.call_sizes + [meta] + wrapper.generate_kernel_call( + name, + call_args, + grid=grid, + grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}", + arg_types=arg_types, + triton_meta=self.triton_meta, + ) + + +@functools.lru_cache(None) +def _jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class TritonTemplate(KernelTemplate): + index_counter = itertools.count() + all_templates: Dict[str, "TritonTemplate"] = {} + + def __init__(self, name: str, grid: Any, source: str, debug=False) -> None: + super().__init__(name) + self.grid = grid + self.template = self._template_from_string(source) + assert name not in self.all_templates, "duplicate template name" + self.all_templates[name] = self + self.debug = debug + + def generate( # type: ignore[override] + self, + input_nodes, + layout, + num_stages, + num_warps, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + subgraphs=None, + mutated_inputs=None, + call_sizes=None, + **kwargs, + ): + """This function generates a TritonTemplateCaller + + Args: + input_nodes: List of input nodes + layout: Output layout + num_stages: Number of stages for triton launch + num_warps: Number of warps for triton launch + prefix_args: Number of input nodes to be passed as arguments + suffix_args: Number of input nodes to be passed as arguments + epilogue_fn: Optional epilogue function to be called on the output + subgraphs: Optional subgraphs to be passed as arguments, these will be inlined + into the triton template string + mutated_inputs: Optional list of input nodes that are mutated by the kernel, this is helpful + if you need to return multiple outputs. You can pass them as inputs and mark them as + being mutated by the kernel. + """ + assert self.template, "requires jinja2" + defines = StringIO() + for name, val in kwargs.items(): + defines.write(f"{name} : tl.constexpr = {val}\n") + defines = defines.getvalue() + + fake_out = ir.Buffer("buf_out", layout) + kernel_name = f"triton_{self.name}" + + numel = sympy_product(layout.size) + buffers = itertools.chain(input_nodes, (fake_out,)) + if not TritonScheduling.can_use_32bit_indexing(numel, buffers): + raise NotImplementedError( + "64-bit indexing is not yet implemented for triton templates" + ) + + if call_sizes is None: + call_sizes = layout.size + + kernel_options = dict( + input_nodes=input_nodes, + defines=defines, + num_stages=num_stages, + num_warps=num_warps, + grid_fn=self.grid, + meta=kwargs, + call_sizes=call_sizes, + prefix_args=prefix_args, + suffix_args=suffix_args, + epilogue_fn=epilogue_fn, + index_dtype="tl.int32", + subgraphs=subgraphs, + ) + + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(fake_out) + ), TritonTemplateKernel( + kernel_name=kernel_name, + output_node=fake_out, + use_jit=False, + **kernel_options, + ) as kernel: + try: + template = kernel.render(self.template, kwargs) + with kernel.set_subgraph_body(""): + code = template.finalize_all() + except ZeroDivisionError: + # TODO(nmacchioni): fix sympy division by zero + return None + if self.debug: + print("Generated Code:\n", code) + extra = ( + "-".join( + [ + *[ + f"{kwarg}={repr(kwargs[kwarg])}" + for kwarg in sorted(kwargs.keys()) + ], + f"num_stages={num_stages}", + f"num_warps={num_warps}", + ] + ) + + "-" + ) + mod = PyCodeCache.load(code, extra) + + input_call_args = tuple(kernel.args.input_buffers.keys()) + output_call_args = tuple(kernel.args.output_buffers.keys()) + + # We expect the input_buffer order to be [*input_nodes, *captured_buffers] + expected_input_args = tuple(unique(x.get_name() for x in input_nodes)) + expected_output_args = (fake_out.get_name(),) + assert input_call_args[: len(expected_input_args)] == expected_input_args, ( + input_call_args, + expected_input_args, + ) + assert output_call_args == expected_output_args, ( + output_call_args, + expected_output_args, + ) + + full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args]) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, tuple(kernel.args.sizevars.keys())), + fallback=config.unbacked_symint_fallback, + ) + + kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}" + + def make_kernel_render(out_node): + kernel = TritonTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + output_node=out_node, + use_jit=False, + **kernel_options, + ) + render = functools.partial( + kernel.render, + self.template, + kwargs, + ) + return kernel, render + + # create the BenchmarkRequest + assert mod.__file__ is not None + grid = self.grid( + *V.graph.sizevars.size_hints( + call_sizes, + fallback=config.unbacked_symint_fallback, + ), + kwargs, + ) + bmreq = TritonBenchmarkRequest( + module_path=mod.__file__, + module_cache_key=mod.key, + kernel_name=kernel_name, + grid=grid, + extra_args=extra_args, + num_stages=num_stages, + num_warps=num_warps, + matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0), + input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes), # type: ignore[arg-type] + output_tensor_meta=TensorMeta.from_irnodes(layout), + ) + + return TritonTemplateCaller( + kernel_hash_name, + full_input_nodes, + layout, + make_kernel_render, + extra.strip("-").replace("-", ", "), + bmreq, + log_info={ + "tile_shape": str( + ( + kwargs.get("BLOCK_M", -1), + kwargs.get("BLOCK_K", -1), + kwargs.get("BLOCK_N", -1), + ) + ), + "num_stages": num_stages, + "num_warps": num_warps, + "allow_tf32": str(kwargs.get("ALLOW_TF32", None)), + "acc_type": str(kwargs.get("ACC_TYPE", None)), + }, + mutated_inputs=mutated_inputs, + ) + + +class ExternKernelChoice: + def __init__( + self, + kernel, + cpp_kernel=None, + *, + name=None, + has_out_variant=True, + op_overload=None, + use_fallback_kernel=False, + kernel_creator=None, + ) -> None: + super().__init__() + name = name or kernel.__name__ + assert callable(kernel) + assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" + self.name = name + self.cpp_kernel_name = cpp_kernel + self.has_out_variant = has_out_variant + setattr(extern_kernels, name, kernel) + self.op_overload = op_overload + self.use_fallback_kernel = use_fallback_kernel + self.kernel_creator = kernel_creator + + def to_callable(self): + return getattr(extern_kernels, self.name) + + def call_name(self): + return f"extern_kernels.{self.name}" + + @functools.lru_cache(None) # noqa: B019 + def hash_key(self): + fn = self.to_callable() + parts = [ + self.name, + getattr(fn, "__name__", ""), + getattr(fn, "__module__", ""), + ] + try: + parts.append(inspect.getsource(fn)) + except Exception: + pass + return code_hash("-".join(parts)) + + def bind( + self, + input_nodes, + layout, + ordered_kwargs_for_cpp_kernel=(), + **kwargs, + ): + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + return ExternKernelCaller( + self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant + ) + + +class TritonTemplateCaller(ir.TritonTemplateCallerBase): + def __init__( + self, + name, + input_nodes, + layout, + make_kernel_render, + debug_extra, + bmreq, + log_info: Optional[ + Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] + ] = None, + mutated_inputs=None, + ) -> None: + super().__init__(name, input_nodes, layout) + self.make_kernel_render = make_kernel_render + self.debug_extra = debug_extra + self.bmreq: TritonBenchmarkRequest = bmreq + if log_info is None: + log_info = {} + self.log_info: Dict[str, Any] = log_info + self.log_info.update( + { + "backend": "Triton", + "grid": str(self.bmreq.grid), + "num_stages": self.bmreq.num_stages, + "num_warps": self.bmreq.num_warps, + } + ) + self.mutated_inputs = mutated_inputs + + def benchmark(self, *args, out): + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + def precompile(self): + assert self.bmreq is not None + self.bmreq.precompile() + + def __str__(self) -> str: + return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})" + + def call_name(self): + return f"template_kernels.{self.name}" + + def hash_key(self): + return "-".join( + [ + self.name.rsplit("_", 1)[0], + self.bmreq.module_cache_key, + ] + ) + + def output_node(self): + return ir.TensorBox.create( + ir.TritonTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + debug_extra=self.debug_extra, + mutated_inputs=self.mutated_inputs, + ) + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return self.log_info + + def get_make_kernel_render(self): + return self.make_kernel_render + + def autoheuristic_id(self): + type_name = "triton" + info = self.info_dict() + # TODO(AlnisM): Does tile_shape always exist? + tile = info["tile_shape"] + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + num_stages = info["num_stages"] + num_warps = info["num_warps"] + return f"type={type_name}_BLOCK-M={BLOCK_M}_BLOCK-K={BLOCK_K}_BLOCK-N={BLOCK_N}_numstages={num_stages}_numwarps={num_warps}" + + +class ExternKernelCaller(ChoiceCaller): + def __init__( + self, + choice: ExternKernelChoice, + input_nodes, + layout, + kwargs=None, + *, + has_out_variant=True, + ) -> None: + super().__init__(choice.name, input_nodes, layout) + self.choice = choice + self.kwargs = kwargs or {} + self.has_out_variant = has_out_variant + + def __str__(self) -> str: + return f"ExternKernelCaller({self.choice.call_name()})" + + def benchmark(self, *args, out): + if out.numel() == 0: + # no need to run the kerrnel of do benchmarking + return 0.0 + if self.has_out_variant: + return super().benchmark(*args, out=out) + else: + algo = self.to_callable() + out_new = algo(*args) + torch._C._dynamo.guards.assert_size_stride( + out_new, tuple(out.size()), tuple(out.stride()) + ) + out.copy_(out_new) # for correctness checking + return benchmarker.benchmark(algo, args, {}) + + def to_callable(self): + fn = self.choice.to_callable() + if self.kwargs: + return functools.partial(fn, **self.kwargs) + else: + return fn + + def hash_key(self): + return "-".join( + [ + self.choice.name, + *[ + f"{kwarg}={repr(self.kwargs[kwarg])}" + for kwarg in sorted(self.kwargs.keys()) + ], + self.choice.hash_key(), + ] + ) + + def output_node(self): + if config.abi_compatible and self.choice.use_fallback_kernel: + assert ( + self.choice.op_overload is not None + ), "Please provide an op_overload to use ir.FallbackKernel" + inner = ir.FallbackKernel.create( + self.choice.op_overload, *self.input_nodes, **self.kwargs + ) + elif self.choice.kernel_creator is not None: + inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) + else: + cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc + inner = cls( + layout=self.layout, + inputs=self.input_nodes, + python_kernel_name=self.choice.call_name(), + cpp_kernel_name=self.choice.cpp_kernel_name, + ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel, + op_overload=self.choice.op_overload, + kwargs=self.kwargs, + ) + + return ir.TensorBox.create(inner) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "extern", + "kernel_call_name": self.choice.call_name(), + } + + def autoheuristic_id(self): + return f"extern_{self.choice.name}" + + +@functools.lru_cache(None) +def get_mm_log_filename() -> Optional[str]: + mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None) + if not mm_file_name: + return None + + if "json" not in mm_file_name: + mm_file_name = f"{mm_file_name}.json" + + return mm_file_name + + +def append_to_log(filename, data): + lock_file = filename.replace(".json", ".lock") + lock = FileLock(lock_file) + with lock: + try: + with open(filename) as f: + log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + log_data = [] + + log_data.append(data) + + with open(filename, "w") as f: + json.dump(log_data, f, indent=4) + + +class DataProcessorChoiceCallerWrapper: + def __init__(self, wrapped, preprocessor, postprocessor) -> None: + self._wrapped = wrapped + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def benchmark(self, *args, out) -> float: + new_args, new_out = self._preprocessor(args, out) + result = self._wrapped.benchmark(*new_args, out=new_out) + new_out = self._postprocessor(new_out) + if out is not new_out: + out.copy_(new_out) + return result + + def output_node(self) -> ir.TensorBox: + result = self._wrapped.output_node() + return self._postprocessor(result) + + def __repr__(self) -> str: + return f"DataProcessorChoiceCallerWrapper({self._wrapped})" + + +class DataProcessorTemplateWrapper: + """ + A wrapper class for a kernel template. + + This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to + preprocess and postprocess data before and after using the wrapped template. A typical + usage is to reorder or filter the input nodes in order to match the expected input of other + kernel choices like a ATen kernel. A more complicated usage is to prepack the weights. + See the example from :mod:`cpp_gemm_template` for more details. + """ + + def __init__( + self, + wrapped_template_cls, + preprocessor, + postprocessor, + **kwargs, + ) -> None: + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + assert "input_nodes" in kwargs + assert "layout" in kwargs + kwargs["input_nodes"], kwargs["layout"] = preprocessor( + kwargs["input_nodes"], kwargs["layout"] + ) + self._wrapped = wrapped_template_cls(**kwargs) + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def maybe_append_choice(self, choices, **kwargs): + return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) + + def generate(self, **kwargs): + choice_caller = self._wrapped.generate(**kwargs) + return DataProcessorChoiceCallerWrapper( + choice_caller, self._preprocessor, self._postprocessor + ) + + def __repr__(self) -> str: + return f"DataProcessorTemplateWrapper({self._wrapped})" + + +class ErrorFromChoice(RuntimeError): + def __init__(self, msg, choice: ChoiceCaller, inputs_str) -> None: + msg += f"\nFrom choice {choice}\n{inputs_str}" + super().__init__(msg) + self.choice = choice + + +class NoValidChoicesError(RuntimeError): + pass + + +@functools.lru_cache(None) +def get_env_num_workers() -> Optional[int]: + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + return None + + +def create_inputs_key(input_nodes) -> str: + return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes]) + + +def create_precompile_key( + name: str, inputs_key: str, choices: List[ChoiceCaller] +) -> str: + return ":".join( + [ + name, + inputs_key, + torch.get_float32_matmul_precision(), + ] + + [choice.hash_key() for choice in choices] + ) + + +class AlgorithmSelectorCache(PersistentCache): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # the autotuning will get occur in the scheduler, so there is + # no guarantee that the first lowering for a given key will also be the + # first to benchmark it. share a single precompilation function for all lowerings + # of a particular key + self.precompile_cache: Dict[str, Callable[[], None]] = {} + # list of callbacks that are called after benchmarking + self.feedback_saver_fns: List[ + Callable[ + [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None + ] + ] = [] + + def __call__( + self, + name, + choices: List[ChoiceCaller], + input_nodes, + layout, + # optional dict mapping arg indices to the functions + # generating a torch.Tensor for that input from the + # corresponding ir.Buffer. if passed for a given + # arg, the function will be called instead of + # generating a random torch.Tensor for benchmarking. + input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, + precompilation_timeout_seconds: int = 60 * 60, + return_multi_template=False, + ): + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + # Templates selected with input_gen_fns require specific input data to avoid IMA + # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection + # TODO(jgong5): support multi-template on CPU + if input_gen_fns is not None or layout.device.type == "cpu": + return_multi_template = False + + # TODO - assert that we have not mutating kernels here + + # TODO(nmacchioni): remove once CI tests are fixed + choices = [choice for choice in choices if choice is not None] + + if mm_file_name := get_mm_log_filename(): + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + append_to_log(mm_file_name, {"invoke": str((M, K, N))}) + + if len(choices) == 0: + backend_config = ( + "max_autotune_gemm_backends" + if name != "convolution" + else "max_autotune_conv_backends" + ) + raise NoValidChoicesError( + f"No choices to select, please consider adding ATEN into {backend_config} " + "config (defined in torch/_inductor/config.py) to allow at least one choice. " + ) + log.debug("Max autotune selects from %s choices.", str(len(choices))) + + if len(choices) == 1: + if not isinstance(choices[0], CUDATemplateCaller): + # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size. + return choices[0].output_node() + + @functools.lru_cache(None) + def make_benchmark_fn(): + return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns) + + inputs_key = create_inputs_key(input_nodes) + + def precompile(choices) -> Callable[[], None]: + def no_op(*args, **kwargs): + return + + if ( + precompilation_timeout_seconds is None + or precompilation_timeout_seconds <= 0 + ): + return no_op + + env_workers = get_env_num_workers() + num_workers = env_workers if env_workers is not None else (len(choices)) + + if num_workers <= 0: + return no_op + + # https://github.com/python/cpython/issues/106905 + if ( + sys.version_info.major == 3 + and sys.version_info.minor == 11 + and sys.version_info.micro <= 8 + ): + return no_op + + # check local and global cache before precompiling + timings = self.lookup( + choices, + name, + inputs_key, + benchmark=None, + ) + + if timings: + return no_op + + precompile_key = create_precompile_key(name, inputs_key, choices) + if precompile_func := self.precompile_cache.get(precompile_key): + return precompile_func + + log.info( + "Multithreaded precompilation for %d choices using %d worker threads", + len(choices), + num_workers, + ) + + # In rare circumstances, because python threads inherit global state, + # thread pool executor can race and leave stdout/stderr in a state + # different than the original values. we explicitly restore the state + # here to avoid this issue. + + initial_stdout = sys.stdout + initial_stderr = sys.stderr + + def precompile_with_captured_stdout(choice): + with restore_stdout_stderr(initial_stdout, initial_stderr): + return choice.precompile() + + executor = ThreadPoolExecutor(max_workers=num_workers) + + futures = {} + for c in choices: + if hasattr(c, "precompile"): + future = executor.submit(precompile_with_captured_stdout, c) + futures[future] = c + + @functools.lru_cache(None) + @restore_stdout_stderr(initial_stdout, initial_stderr) + def wait_on_futures(): + counters["inductor"]["select_algorithm_precompile"] += 1 + for future in as_completed( + futures, + timeout=precompilation_timeout_seconds, + ): + if e := future.exception(): + log.error( + "Exception %s for benchmark choice %s", e, futures[future] + ) + + executor.shutdown(wait=True) + + self.precompile_cache[precompile_key] = wait_on_futures + + return wait_on_futures + + def autotune(choices): + return make_benchmark_fn()(choices) + + if config.autotune_in_subproc: + from .autotune_process import tuning_pool + + # do the optional warmup + tuning_pool.initialize() + + def do_autotuning(precompile_fn): + precompile_start_ts = time.time() + precompile_fn() + precompile_elapse = time.time() - precompile_start_ts + + autotune_start_ts = time.time() + timings = self.lookup( + choices, + name, + inputs_key, + autotune, + ) + autotune_elapse = time.time() - autotune_start_ts + + if timings and all( + not math.isfinite(timing) for timing in timings.values() + ): + raise NoValidChoicesError + + if make_benchmark_fn.cache_info().currsize: + counters["inductor"]["select_algorithm_autotune"] += 1 + + if ( + make_benchmark_fn.cache_info().currsize + or log.getEffectiveLevel() == logging.DEBUG + or config.trace.log_autotuning_results + ): + self.log_results( + name, input_nodes, timings, autotune_elapse, precompile_elapse + ) + + for feedback_fn in self.feedback_saver_fns: + feedback_fn(timings, name, input_nodes, choices) + + return timings + + precompile_fn = precompile(choices) + + if return_multi_template and (config.max_autotune or config.max_autotune_gemm): + + def get_timings(): + timings = do_autotuning(precompile_fn) + min_extern_choice = float("inf") + for choice, timing in timings.items(): + if isinstance(choice, ExternKernelCaller): + min_extern_choice = min(min_extern_choice, timing) + + timings = { + choice: time + for choice, time in timings.items() + if ( + time <= min_extern_choice + or not isinstance(choice, ExternKernelCaller) + ) + } + + return timings + + return torch._inductor.ir.TensorBox.create( + torch._inductor.ir.MultiTemplateBuffer( + layout, + input_nodes, + get_timings, + ) + ) + + # TODO - dont want to precompile if we have a cache hit + timings = do_autotuning(precompile_fn) + if timings == {} or choices[0] not in timings: + return choices[0].output_node() + + selected_key = builtins.min(timings, key=timings.__getitem__) + selected_time = timings[selected_key] + selected_choice = selected_key.output_node() + log.debug("selected choice: %s", str(selected_choice)) + return selected_choice + + @classmethod + def make_benchmark_fn( + cls, + choices, + input_nodes, + layout, + input_gen_fns=None, + ): + if input_gen_fns is None: + input_gen_fns = {} + + def get_inputs(): + # de-duplicate args + unique_example_inputs = { + x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x) + for i, x in enumerate(input_nodes) + } + example_inputs = list(unique_example_inputs.values()) + example_inputs_extern = [ + unique_example_inputs[input_node.get_name()] + if unique_example_inputs[input_node.get_name()].is_mkldnn + else torch.as_strided( + unique_example_inputs[input_node.get_name()], + V.graph.sizevars.size_hints( + input_node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + input_node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + for input_node in input_nodes + ] + + out = cls.benchmark_example_value(layout) + out_extern = torch.as_strided( + out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset) + ) + expected = None + if VERIFY: + choices[0].benchmark(*example_inputs_extern, out=out_extern) + expected = out_extern.clone() + + return example_inputs, example_inputs_extern, out, out_extern, expected + + if DEBUG: + print(f"{len(choices)} tuning requests:") + + def debug_str(example_inputs, out): + def tensor_repr(x): + return ( + f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, " + f"dtype={x.dtype!r}, device={x.device.type!r})" + ) + + lines = [ + "inputs = [", + ] + for x in example_inputs: + lines.append(f" {tensor_repr(x)},") + lines += ["]", f"out = {tensor_repr(out)}", ""] + return "\n".join(lines) + + def benchmark_choice_in_current_process( + choice, example_inputs, example_inputs_extern, out, out_extern, expected + ): + out.zero_() + if isinstance(choice, ExternKernelCaller): + # aten kernels want the offset baked in for sliced tensors + result = choice.benchmark(*example_inputs_extern, out=out_extern) + else: + # triton templates want the base pointer for sliced tensors + result = choice.benchmark(*example_inputs, out=out) + if VERIFY and expected is not None: + torch.testing.assert_close(out_extern, expected, **VERIFY) + if torch.cuda.is_available(): + torch.cuda.synchronize() # shake out any CUDA errors + return result + + def benchmark_in_current_process(choices): + inputs = get_inputs() + example_inputs, _, out, _, _ = inputs + timings = {} + for choice in choices: + try: + timing = benchmark_choice_in_current_process(choice, *inputs) + except CUDACompileError as e: + log.error( + "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.", + str(e), + ) + timing = float("inf") + except NotImplementedError as e: + log.warning("Not yet implemented: %s", e) + timing = float("inf") + except RuntimeError as e: + msg = str(e) + if "invalid argument" in msg: + msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n" + else: + if "illegal memory access" in msg: + msg += "\n\nEither error in template or triton bug.\n" + log.error( + "Runtime error during autotuning: \n%s. \nIgnoring this choice.", + msg, + ) + timing = float("inf") + except AssertionError as e: + raise AssertionError( # noqa: B904 + f"Incorrect result from choice {choice}\n\n{e}" + ) + except Exception as e: + try: + from triton.runtime.autotuner import OutOfResources + + if isinstance(e, OutOfResources): + log.warning(e) + timing = float("inf") + else: + raise e + except ImportError: + raise e from None + + timings[choice] = timing + + return timings + + def benchmark_in_sub_process(choices): + from . import autotune_process + + # only benchmark triton kernel in sub process for now. + # ATen/Extern kernel are still benchmarked in the current process. + extern = [c for c in choices if isinstance(c, ExternKernelCaller)] + triton = [c for c in choices if not isinstance(c, ExternKernelCaller)] + + timings = benchmark_in_current_process(extern) + timings.update(autotune_process.benchmark_in_sub_process(triton)) + return timings + + benchmark = ( + benchmark_in_sub_process + if config.autotune_in_subproc + else benchmark_in_current_process + ) + + return benchmark + + @staticmethod + def log_results( + name: str, + input_nodes: List[ir.IRNode], + timings: Dict[ChoiceCaller, float], + elapse: float, + precompile_elapse: float, + ): + V.debug.log_autotuning_results( + name, input_nodes, timings, elapse, precompile_elapse + ) + if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE: + return + sizes = ", ".join( + [ + "x".join( + map( + str, + V.graph.sizevars.size_hints( + n.get_size(), fallback=config.unbacked_symint_fallback + ), + ) + ) + for n in input_nodes + ] + ) + + n = None if log.getEffectiveLevel() == logging.DEBUG else 10 + top_k = sorted(timings, key=timings.__getitem__)[:n] + best = top_k[0] + + def get_choice_info(choice): + if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller): + return {"type": "cublas", "time": timings[choice]} + + assert isinstance( + choice, torch._inductor.select_algorithm.TritonTemplateCaller + ) + + info = choice.info_dict() + tile = info["tile_shape"] + + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + + return { + "type": "triton", + "time": timings[choice], + "BLOCK_M": BLOCK_M, + "BLOCK_K": BLOCK_K, + "BLOCK_N": BLOCK_N, + "num_stages": info["num_stages"], + "num_warps": info["num_warps"], + } + + mm_filename = get_mm_log_filename() + if mm_filename and "mm" in name: + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + + out_dict = { + str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()] + } + + append_to_log(mm_filename, out_dict) + + best_time = timings[best] + sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") + for choice in top_k: + result = timings[choice] + if result: + kernel_info = ( + choice.debug_extra if hasattr(choice, "debug_extra") else "" + ) + sys.stderr.write( + f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_info}\n" + ) + else: + sys.stderr.write( + f" {choice.name} {result:.4f} ms \n" + ) + + autotune_type_str = ( + "SubProcess" if config.autotune_in_subproc else "SingleProcess" + ) + sys.stderr.write( + f"{autotune_type_str} AUTOTUNE benchmarking takes {elapse:.4f} seconds and {precompile_elapse:.4f}" + " seconds precompiling\n" + ) + + @staticmethod + def benchmark_example_value(node): + """ + Convert an ir.Buffer into a concrete torch.Tensor we can use for + benchmarking. + """ + if isinstance(node, ir.Layout): + node = ir.Buffer("fake", node) + # triton templates want the base tensor. + if isinstance(node, ir.BaseView): + node = node.unwrap_view() + return AlgorithmSelectorCache.generate_example_value( + V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + node.get_device(), + node.get_dtype(), + node.layout.offset, + ) + + @staticmethod + def generate_example_value(size, stride, device, dtype, extra_size): + # preserve rng states to avoid the rand_strided call below changes + # the rng states for the real model code. + with preserve_rng_state(): + return rand_strided( + size, + stride, + device=device, + dtype=dtype, + extra_size=extra_size, + ) + + @staticmethod + def key_of(node): + """ + Extract the pieces of an ir.Buffer that we should invalidate cached + autotuning results on. + """ + sizevars = V.graph.sizevars + return ( + node.get_device().type, + str(node.get_dtype()), + *sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + *sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + + def add_feedback_saver( + self, + fn: Callable[ + [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None + ], + ): + self.feedback_saver_fns.append(fn) + + +_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None + + +def autotune_select_algorithm(*args, **kwargs): + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + + if "return_multi_template" not in kwargs: + kwargs[ + "return_multi_template" + ] = torch._inductor.config.benchmark_epilogue_fusion + + return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs) + + +def add_feedback_saver( + fn: Callable[[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None] +): + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + _ALGORITHM_SELECTOR_CACHE.add_feedback_saver(fn) + + +def realize_inputs(*args): + if len(args) == 1: + return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0])) + return [realize_inputs(x) for x in args] + + +# ensure lowering is imported so that `extern_kernels.*` is populated +from . import lowering # noqa: F401 diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/subgraph_lowering.py b/.venv/lib/python3.11/site-packages/torch/_inductor/subgraph_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..21145375ede8fa1eb803955c435564238148fb1c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/subgraph_lowering.py @@ -0,0 +1,155 @@ +"""Utilities for lowering subgraphs used by higher order operators + +""" + +import functools +import operator +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing_extensions import ParamSpec + +import torch + +from . import ir +from .exc import SubgraphLoweringException +from .ops_handler import SimpleCSEHandler +from .sizevars import SizeVarAllocator +from .virtualized import ops, V, WrapperHandler + + +T = TypeVar("T") +_P = ParamSpec("_P") + + +class PointwiseSubgraphLowering(torch.fx.Interpreter): + graph_outputs: Optional[List[ir.IRNode]] + + def __init__( + self, + gm: torch.fx.GraphModule, + root_graph_lowering: "torch._inductor.graph.GraphLowering", + ) -> None: + super().__init__(gm) + self.graph_outputs = None + self.root_graph = root_graph_lowering + + @property + def sizevars(self) -> SizeVarAllocator: + return self.root_graph.sizevars + + def mark_buffer_mutated(self, name: str) -> None: + raise SubgraphLoweringException("Mutations are not supported in this context") + + def register_buffer(self, buffer: ir.Buffer) -> str: + raise SubgraphLoweringException( + "Buffer creation is not supported in this context" + ) + + def call_function( + self, + target: Callable[[Any], Any], # type: ignore[override] + args: Any, + kwargs: Dict[str, Any], + ) -> Any: + from .lowering import lowerings + + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) + + assert isinstance(target, torch._ops.OpOverload) + + if target not in lowerings: + raise SubgraphLoweringException( + f"{target} not supported in subgraph, (missing lowering)" + ) + + if torch.Tag.pointwise not in target.tags: + raise SubgraphLoweringException( + f"Only pointwise operators are supported in this context, but got {target}" + ) + + return lowerings[target](*args, **kwargs) + + def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override] + assert len(args) == 1 + self.graph_outputs = args[0] + + +@dataclass +class InputDescriptor: + dtype: torch.dtype + device: torch.device + + +class TracingOpsHandler(WrapperHandler[T]): + def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None: + parent = tracer.create_proxy("placeholder", "ops", (), {}) + super().__init__(parent) + self.tracer = tracer + + self.placeholders = [ + self.tracer.create_proxy("placeholder", f"input{i}", (), {}) + for i in range(num_inputs) + ] + + def placeholder(self, idx: int) -> torch.fx.Proxy: + return self.placeholders[idx] + + def output(self, *args: Tuple[object]) -> torch.fx.Node: + return self.tracer.create_node( + "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {} + ) + + +def lower_pointwise_subgraph( + subgraph: ir.Subgraph, inputs: List[InputDescriptor] +) -> Callable[_P, Any]: + # Lower subgraph to ir.Pointwise nodes + def fake_inner_fn( + loop_idx: int, input_idx: int + ) -> Union[ir.Expr, ir.TensorBox, None]: + return ops.placeholder(input_idx) + + graph_inputs = [ + ir.Pointwise.create( + device=desc.device, + dtype=desc.dtype, + inner_fn=functools.partial(fake_inner_fn, input_idx=i), + ranges=[], + ) + for i, desc in enumerate(inputs) + ] + gm = subgraph.graph_module + pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*graph_inputs) + + # Combine multiple pointwise computations into a single graph module + # Do this by tracing through each individually and doing CSE + tracer = torch.fx.Tracer() + tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + trace_ops = SimpleCSEHandler(TracingOpsHandler(tracer, len(inputs))) + assert pw_subgraph.graph_outputs is not None + + with V.set_ops_handler(trace_ops): + output_irs = [] + + for out_var in pw_subgraph.graph_outputs: + assert isinstance(out_var, ir.TensorBox), type(out_var) + assert out_var.get_size() == [] + assert isinstance(out_var.data, ir.StorageBox) + assert isinstance(out_var.data.data, ir.Pointwise) + + idx = () + ir_out = out_var.data.data.inner_fn(idx) + + output_irs.append(ir_out) + + ops.output(*output_irs) + + lowered_gm = torch.fx.GraphModule({}, tracer.graph) + + def inner_fn(*args: _P.args, **kwargs: _P.kwargs) -> Any: + return lowered_gm(V.get_ops_handler(), *args, **kwargs) + + return inner_fn diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py b/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..53a791685c67ed005f8d88f8d64661187e63e0bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +import contextlib +import os + +from torch._dynamo.test_case import ( + run_tests as dynamo_run_tests, + TestCase as DynamoTestCase, +) +from torch._inductor import config +from torch._inductor.utils import fresh_inductor_cache + + +def run_tests(needs=()): + dynamo_run_tests(needs) + + +class TestCase(DynamoTestCase): + """ + A base TestCase for inductor tests. Enables FX graph caching and isolates + the cache directory for each test. + """ + + def setUp(self): + super().setUp() + self._inductor_test_stack = contextlib.ExitStack() + self._inductor_test_stack.enter_context(config.patch({"fx_graph_cache": True})) + if ( + os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1" + and os.environ.get("TORCH_COMPILE_DEBUG") != "1" + ): + self._inductor_test_stack.enter_context(fresh_inductor_cache()) + + def tearDown(self): + super().tearDown() + self._inductor_test_stack.close() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/test_operators.py b/.venv/lib/python3.11/site-packages/torch/_inductor/test_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c1d401f2d0207503273dc709b01e50a288fe1c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/test_operators.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +import torch.library +from torch import Tensor +from torch.autograd import Function + + +if not torch._running_with_deploy(): + _test_lib_def = torch.library.Library("_inductor_test", "DEF") + _test_lib_def.define( + "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag + ) + + _test_lib_impl = torch.library.Library("_inductor_test", "IMPL") + for dispatch_key in ("CPU", "CUDA", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + + class Realize(Function): + @staticmethod + def forward(ctx, x): + return torch.ops._inductor_test.realize(x) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + def realize(x: Tensor) -> Tensor: + return Realize.apply(x) diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/virtualized.py b/.venv/lib/python3.11/site-packages/torch/_inductor/virtualized.py new file mode 100644 index 0000000000000000000000000000000000000000..b00a94c5f2cee32534f94730b34dc7bffa7a9c56 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/virtualized.py @@ -0,0 +1,361 @@ +# mypy: allow-untyped-defs +""" +This file provides a number of "global" variables/handlers that are actually +thread local and dynamically scoped, with Inductor patching them to various +implementations depending on the situation. + +These handlers are interacted with in a fairly stylized way. Typically, +we will import V from this module:: + + from .virtualized import V + +Various handlers are accessible as attributes on this module; for example, +you might access ``V.graph.sizevars.size_hint`` to resolve a size hint associated with +a number. + +There are a few distinct usage patterns for virtualized global variables: + +1. Implicit argument passing. Examples: ``V.current_node``, ``V.aot_compilation``. + Use ``V.set_current_node`` to change what the current node is while we're + executing some region of code, so code inside that region can query ``V.current_node`` + to find out what it is. This is often more convenient than manually threading + the current node as an argument through all call stacks. + +2. Per-compilation global state. Examples: ``V.fake_mode``, ``V.graph``. For a + given ``compile_fx`` invocation, these typically don't change, but they are + associated with some internal state so they cannot just be global functions. + We install these objects at the beginning of compilation and then you can + conveniently access them without having to pass them around. + +3. Alternate define-by-run interpretations. Examples: ``V.ops``, ``V.kernel``. + A commonly used IR in Inductor is define-by-run: instead of maintaining + explicit syntax data structures, we instead represent loop bodies as + callable functions, which internally invoke operations defined on + ``V.ops``. To perform semantic analysis, print or code generate these + operations, we dynamically patch ``V.ops`` with an alternate handler with + the intended semantics and then run the callable function. For example, to + extract out a traditional (FX) graph representation of the define-by-run + IR, simply install a handler that records each ``ops`` call to a graph. + + TODO: Define a parent class / protocol that defines all of the operations + V.ops is expected to support. + +It is typically an error to access a virtualized global without having installed +an appropriate handler (you will get a NullHandler), although in some cases we +provide a default implementation. + +One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is +ubiquitous enough to have its own top level variable, so you will typically see +``ops.constant(...)`` rather than ``V.ops.constant(...)``. In fact, these are not +equivalent; the former interface supports arithmetic overloads like ``x + y`` +instead of forcing ``ops.add(x, y)``, so it should be preferred. + +Some operators are seemingly unused, but they are implicitly used by ops_wrapper. +In particular, we typically have an operator for every basic pointwise PyTorch operation +supported. +""" + +from __future__ import annotations + +from contextlib import AbstractContextManager, contextmanager +from threading import local +from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union + +from .ops_handler import ( # noqa: F401 + KernelFormatterHandler, + MockHandler, + OpsHandler, + ReductionType, + StoreMode, + WrapperHandler, +) + + +if TYPE_CHECKING: + import torch + from torch._inductor.codegen.cpp_utils import LocalBufferContext + from torch._inductor.debug import DebugContext + from torch._inductor.graph import GraphLowering + from torch._inductor.loop_body import InterpreterShim + from torch._subclasses import FakeTensorMode + +threadlocal = local() + +T = TypeVar("T") + + +class NullHandler: + """ + Sentinel indicating that a global variable is unset ala None. Typically, + attempting to access the global variable before it's set is an error, but with + NullHandler it won't fail until you try to access an attribute on it. + """ + + +class Virtualized(Generic[T]): + """ + Implements a global variable that redirects via thread local variable + (NB: construct this class to create the global variable; this is not + a singleton class!) + + This allows us to swap in different op implementations in codegen. + + NB: Despite the fact that we typically call these "handlers" (e.g., NullHandler is + the default value of the variable), we sometimes use these variables to + store other things, like booleans. + """ + + def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]): + self._key: str = f"__torchinductor_{vname}" + self._default = default + + def _set_handler(self, value: T) -> AbstractContextManager[None]: + prior = self._get_handler() + setattr(threadlocal, self._key, value) + + @contextmanager + def ctx(): + try: + yield + finally: + self._set_handler(prior) + + return ctx() + + def _get_handler(self) -> T: + try: + return getattr(threadlocal, self._key) + except AttributeError: + # TODO: To be honest, I feel we probably should just error in this + # case, instead of making a null handler that will probably error + # when you getattr on it + return self._default() # type: ignore[return-value] + + def __getattr__(self, name: str) -> Any: + return getattr(self._get_handler(), name) + + +class NullKernelHandler(NullHandler): + """ + We need access `V.kernel.removed_buffers` in DeferredLine class when there + is no kernel in the context. This happens when codegening the wrapper. + Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't + need call 'getattr' with default value which is error prone to typo in + attribute name. + """ + + def __init__(self): + super().__init__() + self.removed_buffers = set() + self.inplaced_to_remove = set() + self.index_dtype = "tl.int64" + + +_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler) +_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) +_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler) +_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler) +_kernel: Virtualized[NullKernelHandler] = Virtualized( + "kernel", NullKernelHandler +) # TODO: improve type +_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler) +_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler) +_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler) +_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler) +_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized( + "local_buffer_context", NullHandler +) + + +class OpsValue: + """The return type of most ops calls. + + This exists so we can overload magic methods, and write mathematical + expressions much more fluently. So instead of + + ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1) + + we can write + + (_Ap2 * x - _Ap3) * x * x + _1 + + """ + + value: Any + + def __init__(self, value): + self.value = value + + def __str__(self): + return str(self.value) + + def __repr__(self): + return f"OpsValue({self.value!r})" + + def __add__(self, other): + return ops.add(self, other) + + def __mul__(self, other): + return ops.mul(self, other) + + def __sub__(self, other): + return ops.sub(self, other) + + def __neg__(self): + return ops.neg(self) + + def __truediv__(self, other): + return ops.truediv(self, other) + + def __floordiv__(self, other): + return ops.floordiv(self, other) + + def __mod__(self, other): + return ops.mod(self, other) + + def __pow__(self, other): + return ops.pow(self, other) + + def __lt__(self, other): + return ops.lt(self, other) + + def __le__(self, other): + return ops.le(self, other) + + def __eq__(self, other): + return ops.eq(self, other) + + def __ne__(self, other): + return ops.ne(self, other) + + def __gt__(self, other): + return ops.gt(self, other) + + def __ge__(self, other): + return ops.ge(self, other) + + def __and__(self, other): + return ops.bitwise_and(self, other) + + def __or__(self, other): + return ops.bitwise_or(self, other) + + def __xor__(self, other): + return ops.bitwise_xor(self, other) + + def __invert__(self): + return ops.bitwise_not(self) + + def __rshfit__(self, n): + return ops.bitwise_right_shift(self, n) + + def __lshift__(self, n): + return ops.bitwise_left_shift(self, n) + + +class OpsWrapper: + """This wraps any returned IR values into an `OpsValue` instance, so that we + can overload the magic methods for writing mathematical expressions fluently. + """ + + def __getattr__(self, name): + def inner(*args, **kwargs): + new_args = [OpsWrapper._unwrap(a) for a in args] + new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} + return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) + + return inner + + @staticmethod + def _unwrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsWrapper._unwrap(v) for v in x) + if isinstance(x, OpsValue): + return x.value + return x + + @staticmethod + def _wrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsValue(v) for v in x) + return OpsValue(x) + + @staticmethod + def indirect_indexing(index, size, check=True, wrap_neg=True): + # Returns a sympy value, not IR value + index = OpsWrapper._unwrap(index) + return _ops.indirect_indexing(index, size, check, wrap_neg) + + +ops = OpsWrapper() + + +class _V: + MockHandler = MockHandler + KernelFormatterHandler = KernelFormatterHandler + WrapperHandler = WrapperHandler + + set_ops_handler: Callable[[Any], Any] = _ops._set_handler + get_ops_handler: Callable[[], Any] = _ops._get_handler + set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler + set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler + get_real_inputs: Callable[[], Any] = _real_inputs._get_handler + set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler + get_fake_mode: Callable[[], Any] = _fake_mode._get_handler + set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler + set_debug_handler: Callable[[Any], Any] = _debug._set_handler + set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler + set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler + get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler + set_current_node: Callable[[Any], Any] = _current_node._set_handler + get_current_node: Callable[[], Any] = _current_node._get_handler + set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler + get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler + + @property + def ops(self) -> OpsHandler[Any]: + """The operator handler specific to the current codegen task""" + return _ops._get_handler() + + @property + def graph(self) -> GraphLowering: + """The graph currently being generated""" + return _graph._get_handler() + + @property + def real_inputs(self): + """non-fake example inputs""" + return _real_inputs._get_handler() + + @property + def fake_mode(self): + """The graph currently being generated""" + return _fake_mode._get_handler() + + @property + def kernel(self): + """The kernel currently being generated""" + return _kernel._get_handler() + + @property + def debug(self): + return _debug._get_handler() + + @property + def interpreter(self): + return _interpreter._get_handler() + + @property + def aot_compilation(self): + return _aot_compilation._get_handler() + + @property + def current_node(self): + return _current_node._get_handler() + + @property + def local_buffer_context(self): + return _local_buffer_context._get_handler() + + +V = _V() diff --git a/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py b/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd1f0fc95b7f905eab8457a733d4b036849c443 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py @@ -0,0 +1,315 @@ +# mypy: allow-untyped-defs +import dataclasses +import tempfile +from collections import defaultdict + +import torch +from torch.autograd import DeviceType + +from .runtime.benchmarking import benchmarker +from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes + + +_kernel_category_choices = [ + "foreach", + "persistent_reduction", + "pointwise", + "reduction", + "split_scan", + "template", +] + + +def get_kernel_category_by_source_code(src_code): + """ + Similar to get_kernel_category but use the source code. Call this API + if we have not compile the src_code to module yet. + """ + choices = [ + ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code + ] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_kernel_category(kernel_mod): + """ + Given the module defining a triton kernel, return the category of the kernel. + Category can be one of: + - pointwise + - reduction + - persistent_reduction + + Currently we simply decide the category depending on what decorator is imported + by the kernel. + """ + choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_triton_kernel(mod): + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + + cand_list = [ + v + for k, v in mod.__dict__.items() + if k.startswith("triton_") and isinstance(v, CachingAutotuner) + ] + assert len(cand_list) == 1 + return cand_list[0] + + +def benchmark_all_kernels(benchmark_name, benchmark_all_configs): + """ + An experimental API used only when config.benchmark_kernel is true. + + Run the kernel benchmarks for all the kernels cached in PyCodeCache. + Used in the compiled modules. + + Put this method here rather than codegen it for convenience since its implementation + does not change based on different graph modules being compiled. + """ + from torch._inductor.codecache import PyCodeCache + + nfound = 0 + for kernel_key, kernel_mod in PyCodeCache.cache.items(): + if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"): + continue + + triton_kernel = get_triton_kernel(kernel_mod) + kernel_category = get_kernel_category(kernel_mod) + args = kernel_mod.get_args() + num_in_out_ptrs = len( + [ + arg_name + for arg_name in triton_kernel.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + + def get_info_str(ms, n_regs, n_spills, shared, prefix=""): + if not any(x is None for x in [n_regs, n_spills, shared]): + kernel_detail_str = ( + f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem" + ) + else: + kernel_detail_str = "" + + gb_per_s = num_gb / (ms / 1e3) + return create_bandwidth_info_str( + ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str + ) + + kernel_desc = ( + f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}" + ) + if benchmark_all_configs: + assert hasattr(kernel_mod, "benchmark_all_configs") + bench_result = kernel_mod.benchmark_all_configs(args) + print(kernel_desc) + for launcher, ms in bench_result.items(): + print( + f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}" + ) + else: + ms = benchmarker.benchmark_gpu( + lambda: kernel_mod.call(args), rep=40, fast_flush=True + ) + assert ( + len(triton_kernel.launchers) == 1 + ), "Autotuner should have selected the best config" + launcher = triton_kernel.launchers[0] + print( + get_info_str( + ms, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + prefix=f"{kernel_desc} ", + ) + ) + + nfound += 1 + if nfound == 0: + print( + "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel being True" + ) + + +@dataclasses.dataclass +class ProfileEvent: + category: str + key: str + self_device_time_ms: float + # the benchmark is run multiple times and we average the count across all the + # runs. It should be an integer but define a float just in case. + count: float + + +def parse_profile_event_list( + benchmark_name, event_list, wall_time_ms, nruns, device_name +): + def get_self_device_time(ev): + """ + ev.self_device_time_total is in microsecond. Convert to millisecond. + """ + return ev.self_device_time_total / 1000 / nruns + + all_events = defaultdict(list) + + def add_event(ev, category): + profile_ev = ProfileEvent( + category=category, + key=ev.key, + self_device_time_ms=get_self_device_time(ev), + count=ev.count / nruns, # average across all runs + ) + all_events[category].append(profile_ev) + + for ev in event_list: + assert not ev.is_legacy, "Don't support the legacy profiler" + if ev.device_type == DeviceType.CPU: + # ignore the event on CPU side + continue + + category = "unknown" + if ev.key.startswith("triton_"): + if ev.key.startswith("triton_poi"): + category = "triton_pointwise" + elif ev.key.startswith("triton_red"): + category = "triton_reduction" + elif ev.key.startswith("triton_per"): + category = "triton_persistent_reduction" + else: + category = "triton_unknown" + + add_event(ev, category) + + def report_category(category, profile_events): + from tabulate import tabulate + + profile_events.sort(key=lambda ev: ev.self_device_time_ms, reverse=True) + + rows = [] + total_time = 0.0 + print(f"\n == {category} category kernels == ") + for ev in profile_events: + total_time += ev.self_device_time_ms + percent = f"{ev.self_device_time_ms / wall_time_ms * 100:.2f}%" + rows.append([ev.key[:120], ev.self_device_time_ms, ev.count, percent]) + rows.append( + ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"] + ) + print( + tabulate( + rows, + headers=[ + "Kernel", + f"Self {device_name.upper()} TIME (ms)", + "Count", + "Percent", + ], + ) + ) + return total_time + + def report(): + category_list = [ + "triton_pointwise", + "triton_reduction", + "triton_persistent_reduction", + "triton_unknown", + "unknown", + ] + assert set(all_events.keys()).issubset( + set(category_list) + ), f"{list(all_events.keys())}" + + per_category_wall_time = {} + total_device_ms = 0.0 + for category in category_list: + if category in all_events: + _time = report_category(category, all_events[category]) + per_category_wall_time[category] = _time + total_device_ms += _time + + device_busy_percent = f"{total_device_ms / wall_time_ms * 100:.2f}%" + print( + f"\nPercent of time when {device_name.upper()} is busy: {device_busy_percent}" + ) + print(f"Total wall time {wall_time_ms:.3f} ms") + + # output such a line so we can gather such line from all compiled modules from all + # benchmarks and tabulate it! + # Columns: benchmark_name, pointwise_percent, reduction_percent, persistent_reduction_percent, + # unknown_category_percent, device_busy_percent, wall_time_ms + tabulate_line = f"Output for tabulate: {benchmark_name}" + for category in category_list: + percent = ( + f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%" + ) + tabulate_line += f", {percent}" + tabulate_line += f", {device_busy_percent}, {wall_time_ms:.3f}ms" + + print(tabulate_line) + + report() + + +def compiled_module_main(benchmark_name, benchmark_compiled_module_fn): + """ + This is the function called in __main__ block of a compiled module. + """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark-kernels", + "-k", + action="store_true", + help="Whether to benchmark each individual kernels", + ) + parser.add_argument( + "--benchmark-all-configs", + "-c", + action="store_true", + help="Whether to benchmark each individual config for a kernel", + ) + parser.add_argument( + "--profile", + "-p", + action="store_true", + help="Whether to profile the compiled module", + ) + args = parser.parse_args() + + if args.benchmark_kernels: + benchmark_all_kernels(benchmark_name, args.benchmark_all_configs) + else: + times = 10 + repeat = 10 + wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000 + + if not args.profile: + return + + with torch.profiler.profile(record_shapes=True) as p: + benchmark_compiled_module_fn(times=times, repeat=repeat) + + path = f"{tempfile.gettempdir()}/compiled_module_profile.json" + p.export_chrome_trace(path) + print(f"Profiling result for a compiled module of benchmark {benchmark_name}:") + print(f"Chrome trace for the profile is written to {path}") + event_list = p.key_averages(group_by_input_shape=True) + print(event_list.table(sort_by="self_device_time_total", row_limit=10)) + parse_profile_event_list( + benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device + ) diff --git a/.venv/lib/python3.11/site-packages/torch/mtia/__init__.py b/.venv/lib/python3.11/site-packages/torch/mtia/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa107e973f6d9ce9cbf153d4d544b1dca02e07e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/mtia/__init__.py @@ -0,0 +1,332 @@ +# mypy: allow-untyped-defs +r""" +This package enables an interface for accessing MTIA backend in python +""" + +import threading +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import device as _device, Tensor +from torch._utils import _dummy_type, _LazySeedTracker, classproperty +from torch.types import Device + +from ._utils import _get_device_index + + +_device_t = Union[_device, str, int, None] + +# torch.mtia.Event/Stream is alias of torch.Event/Stream +Event = torch.Event +Stream = torch.Stream + +_initialized = False +_queued_calls: List[ + Tuple[Callable[[], None], List[str]] +] = [] # don't invoke these until initialization occurs +_tls = threading.local() +_initialization_lock = threading.Lock() +_lazy_seed_tracker = _LazySeedTracker() + + +def init(): + _lazy_init() + + +def is_initialized(): + r"""Return whether PyTorch's MTIA state has been initialized.""" + return _initialized and not _is_in_bad_fork() + + +def _is_in_bad_fork() -> bool: + return torch._C._mtia_isInBadFork() + + +def _lazy_init() -> None: + global _initialized, _queued_calls + if is_initialized() or hasattr(_tls, "is_initializing"): + return + with _initialization_lock: + # We be double-checking locking, boys! This is OK because + # the above test was GIL protected anyway. The inner test + # is for when a thread blocked on some other thread which was + # doing the initialization; when they get the lock, they will + # find there is nothing left to do. + if is_initialized(): + return + # It is important to prevent other threads from entering _lazy_init + # immediately, while we are still guaranteed to have the GIL, because some + # of the C calls we make below will release the GIL + if _is_in_bad_fork(): + raise RuntimeError( + "Cannot re-initialize MTIA in forked subprocess. To use MTIA with " + "multiprocessing, you must use the 'spawn' start method" + ) + if not _is_compiled(): + raise AssertionError( + "Torch not compiled with MTIA enabled. " + "Ensure you have `import mtia.host_runtime.torch_mtia` in your python " + "src file and include `//mtia/host_runtime/torch_mtia:torch_mtia` as " + "your target dependency!" + ) + + torch._C._mtia_init() + # Some of the queued calls may reentrantly call _lazy_init(); + # we need to just return without initializing in that case. + # However, we must not let any *other* threads in! + _tls.is_initializing = True + + for calls in _lazy_seed_tracker.get_calls(): + if calls: + _queued_calls.append(calls) + + try: + for queued_call, orig_traceback in _queued_calls: + try: + queued_call() + except Exception as e: + msg = ( + f"MTIA call failed lazily at initialization with error: {str(e)}\n\n" + f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}" + ) + raise DeferredMtiaCallError(msg) from e + finally: + delattr(_tls, "is_initializing") + _initialized = True + + +class DeferredMtiaCallError(Exception): + pass + + +def _is_compiled() -> bool: + r"""Return true if compiled with MTIA support.""" + return torch._C._mtia_isBuilt() + + +def is_available() -> bool: + r"""Return true if MTIA device is available""" + if not _is_compiled(): + return False + # MTIA has to init devices first to know if there is any devices available. + return device_count() > 0 + + +def synchronize(device: Optional[_device_t] = None) -> None: + r"""Waits for all jobs in all streams on a MTIA device to complete.""" + with torch.mtia.device(device): + return torch._C._mtia_deviceSynchronize() + + +def device_count() -> int: + r"""Return the number of MTIA devices available.""" + return torch._C._accelerator_hooks_device_count() + + +def current_device() -> int: + r"""Return the index of a currently selected device.""" + return torch._C._accelerator_hooks_get_current_device() + + +def current_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the currently selected :class:`Stream` for the current device, given + by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None`` + (default). + """ + return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True)) + + +def default_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the default :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the default :class:`Stream` for the current device, given by + :func:`~torch.mtia.current_device`, if :attr:`device` is ``None`` + (default). + """ + return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True)) + + +def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]: + r"""Return a dictionary of MTIA memory allocator statistics for a given device. + + Args: + device (torch.device or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + """ + if not is_initialized(): + return {} + return torch._C._mtia_memoryStats(_get_device_index(device, optional=True)) + + +def set_stream(stream: Stream): + r"""Set the current stream.This is a wrapper API to set the stream. + Usage of this function is discouraged in favor of the ``stream`` + context manager. + + Args: + stream (Stream): selected stream. This function is a no-op + if this argument is ``None``. + """ + if stream is None: + return + torch._C._mtia_setCurrentStream(stream) + + +def set_device(device: _device_t) -> None: + r"""Set the current device. + + Args: + device (torch.device or int): selected device. This function is a no-op + if this argument is negative. + """ + device = _get_device_index(device) + if device >= 0: + torch._C._accelerator_hooks_set_current_device(device) + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.idx = _get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx) + return False + + +class StreamContext: + r"""Context-manager that selects a given stream. + + All MTIA kernels queued within its context will be enqueued on a selected + stream. + + Args: + Stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: Streams are per-device. + """ + + cur_stream: Optional["torch.mtia.Stream"] + + def __init__(self, stream: Optional["torch.mtia.Stream"]): + self.cur_stream = None + self.stream = stream + self.idx = _get_device_index(None, True) + if not torch.jit.is_scripting(): + if self.idx is None: + self.idx = -1 + + self.src_prev_stream = ( + None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) + ) + self.dst_prev_stream = ( + None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) + ) + + def __enter__(self): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # Return if stream is None or MTIA device not available + if cur_stream is None or self.idx == -1: + return + self.src_prev_stream = torch.mtia.current_stream(None) + + # If the stream is not on the current device, then + # set the current stream on the device + if self.src_prev_stream.device != cur_stream.device: + with device(cur_stream.device): + self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device) + torch.mtia.set_stream(cur_stream) + + def __exit__(self, type: Any, value: Any, traceback: Any): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # If stream is None or no MTIA device available, return + if cur_stream is None or self.idx == -1: + return + + # Reset the stream on the original device + # and destination device + if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr] + torch.mtia.set_stream(self.dst_prev_stream) # type: ignore[arg-type] + torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type] + + +def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext: + r"""Wrap around the Context-manager StreamContext that selects a given stream. + + Arguments: + stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + ..Note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream + """ + return StreamContext(stream) + + +def get_rng_state(device: Union[int, str, torch.device] = "mtia") -> Tensor: + r"""Returns the random number generator state as a ByteTensor. + + Args: + device (torch.device or int, optional): The device to return the RNG state of. + Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device). + """ + warnings.warn( + "get_rng_state is not implemented in torch.mtia", + UserWarning, + stacklevel=2, + ) + return torch.zeros([1], dtype=torch.uint8, device=device) + + +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "mtia" +) -> None: + r"""Sets the random number generator state. + + Args: + new_state (torch.ByteTensor): The desired state + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device). + """ + warnings.warn( + "set_rng_state is not implemented in torch.mtia", + UserWarning, + stacklevel=2, + ) + + +__all__ = [ + "init", + "is_available", + "is_initialized", + "synchronize", + "device_count", + "current_device", + "current_stream", + "default_stream", + "memory_stats", + "set_device", + "set_stream", + "stream", + "device", + "set_rng_state", + "get_rng_state", +] diff --git a/.venv/lib/python3.11/site-packages/torch/mtia/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/mtia/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3af364280bdb17e4b0e7626f3c11bbe298489414 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/mtia/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/mtia/__pycache__/_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/mtia/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eed8b3ac0a109d269bef793246fa4e4c1b892ef9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/mtia/__pycache__/_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/mtia/_utils.py b/.venv/lib/python3.11/site-packages/torch/mtia/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..090e26f321232f9687c2b348ac602dbb6699b03f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/mtia/_utils.py @@ -0,0 +1,38 @@ +from typing import Any + +import torch + +# The _get_device_index has been moved to torch.utils._get_device_index +from torch._utils import _get_device_index as _torch_get_device_index + + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: + r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + is a MTIA device. Note that for a MTIA device without a specified index, + i.e., ``torch.device('mtia')``, this will return the current default MTIA + device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default MTIA + device if :attr:`optional` is ``True``. + """ + if isinstance(device, int): + return device + if isinstance(device, str): + device = torch.device(device) + if isinstance(device, torch.device): + if allow_cpu: + if device.type not in ["mtia", "cpu"]: + raise ValueError(f"Expected a mtia or cpu device, but got: {device}") + elif device.type != "mtia": + raise ValueError(f"Expected a mtia device, but got: {device}") + if not torch.jit.is_scripting(): + if isinstance(device, torch.mtia.device): + return device.idx + return _torch_get_device_index(device, optional, allow_cpu) diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..561e6e61172c4a17325ba28338bc978a8c450a47 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36ef983ac07fd33d762eff5866dce81e5da900e8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_numeric_suite.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..095c89f1cdbde7322ff079b054cc12eef6768477 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a85e4da244e79f3fdce2597c6e907847dfa52d8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_quantized_conversions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0b501ab95d639b12c435b0dba289644a152d9bc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b88b0e211acead2b8b2c18f412c50597483bd475 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a7dd8fed5290482a97abeb91cd7b4f5885d228b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/observer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/observer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1316602f30ee97ee0c12c0dca6238fa38025cf3f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/observer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/qconfig.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/qconfig.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79937327e822d526dbb828230008a73a49aeab4c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/qconfig.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quant_type.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quant_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..106af68d5d6bf9d242585751aca13a8a6b4035d2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quant_type.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cfc08189f179408d9f6ea3ce367999819007dbe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantize.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d3f869b99dd784d9517a2155215797bd6ca6d6d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantize.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6006a0707c1de681bc2a18b3c1dde59a3766853 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantize_fx.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a025ea7a12fea2facf8942d2439a2491816e7ed Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/stubs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/stubs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd9d37c429fa69f48a9e3d5f181c3927168779b9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/stubs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ceb04b3fbd186089a723213ef4d37743e3d12ce Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/utils.cpython-311.pyc differ