Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See the raw diff for the full change set.
See raw diff
- .venv/lib/python3.11/site-packages/torch/_inductor/async_compile.py +297 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py +296 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py +321 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py +150 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py +149 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py +109 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic.py +315 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py +339 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py +119 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py +92 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py +876 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/codecache.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py +264 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/comms.py +640 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py +1629 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/config.py +1241 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py +348 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/cpu_vec_isa.py +373 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py +330 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/debug.py +693 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py +980 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py +745 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/exc.py +104 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/extern_node_serializer.py +25 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/freezing.py +269 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py +251 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/graph.py +1930 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/hooks.py +30 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py +373 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py +179 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/ir.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/jagged_lowerings.py +264 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/lowering.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/metrics.py +436 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_ir.py +1881 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_lowerings.py +1087 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/package/__init__.py +1 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/package.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/package/build_package.py +15 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/package/package.py +237 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/package/pt2_archive_constants.py +16 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py +2005 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/quantized_lowerings.py +92 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/remote_cache.py +198 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py +1743 -0
.venv/lib/python3.11/site-packages/torch/_inductor/async_compile.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import functools
|
| 5 |
+
import logging
|
| 6 |
+
import multiprocessing
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
|
| 10 |
+
from concurrent.futures.process import BrokenProcessPool
|
| 11 |
+
from functools import partial
|
| 12 |
+
from time import time
|
| 13 |
+
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch._dynamo.device_interface import get_registered_device_interfaces
|
| 17 |
+
from torch._inductor import config
|
| 18 |
+
from torch._inductor.codecache import (
|
| 19 |
+
CodeCacheFuture,
|
| 20 |
+
CppCodeCache,
|
| 21 |
+
CppPythonBindingsCodeCache,
|
| 22 |
+
CUDACodeCache,
|
| 23 |
+
HalideCodeCache,
|
| 24 |
+
LambdaFuture,
|
| 25 |
+
ROCmCodeCache,
|
| 26 |
+
TritonCodeCache,
|
| 27 |
+
TritonFuture,
|
| 28 |
+
)
|
| 29 |
+
from torch._inductor.compile_worker.subproc_pool import (
|
| 30 |
+
_warm_process_pool,
|
| 31 |
+
AnyPool,
|
| 32 |
+
SubprocPool,
|
| 33 |
+
)
|
| 34 |
+
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
|
| 35 |
+
from torch._inductor.runtime.compile_tasks import (
|
| 36 |
+
_set_triton_ptxas_path,
|
| 37 |
+
_worker_compile_triton,
|
| 38 |
+
)
|
| 39 |
+
from torch.hub import _Faketqdm, tqdm
|
| 40 |
+
from torch.utils._triton import has_triton_package
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if TYPE_CHECKING:
|
| 44 |
+
from torch._inductor.runtime.hints import HalideMeta
|
| 45 |
+
|
| 46 |
+
# timing metrics for time spent in the compilation
|
| 47 |
+
_cumulative_compile_time = 0.0  # total seconds spent compiling so far
_t0: Optional[float] = None  # start time of the currently open timing span, if any

kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def pre_fork_setup():
    """
    Setup that must be done prior to forking with a process pool.
    """
    # Make sure device properties are computed (and cached) before workers
    # are forked, so children inherit them instead of recomputing.
    caching_device_properties()

    # Computing the triton key can be slow; doing it here caches the result
    # for every forked subprocess.
    try:
        from triton.compiler.compiler import triton_key

        triton_key()
    except ImportError:
        # Triton is missing, or too old to expose triton_key.
        pass
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def caching_device_properties():
|
| 73 |
+
for _, device_interface in get_registered_device_interfaces():
|
| 74 |
+
if device_interface.is_available():
|
| 75 |
+
device_interface.Worker.get_device_properties()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _compile_start() -> None:
    """Open a compile-timing span; a no-op if one is already open."""
    global _t0
    if _t0 is None:
        _t0 = time()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _compile_end() -> None:
    """Close the current compile-timing span and add it to the running total."""
    global _cumulative_compile_time, _t0
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
    # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
_IS_WINDOWS = sys.platform == "win32"

log = logging.getLogger(__name__)


# Used to keep track of all process pools invoked so far.
_pool_set: Set[AnyPool] = set()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def shutdown_compile_workers() -> None:
    """Shut down all outstanding compile-worker pools."""
    for pool in _pool_set:
        pool.shutdown()
    # Clear the bookkeeping (and the lru_cache) so pools are recreated
    # lazily on next use.
    after_fork()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def after_fork():
    """Reset pools to initial state without shutting them down"""
    # A forked child must not reuse the parent's pools (their worker
    # processes/threads do not survive the fork), so forget them and let
    # process_pool()'s lru_cache build a fresh pool on next use.
    _pool_set.clear()
    AsyncCompile.process_pool.cache_clear()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# os.register_at_fork does not exist on Windows; only hook the fork path
# where the platform provides it.
if hasattr(os, "register_at_fork"):
    os.register_at_fork(after_in_child=after_fork)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class AsyncCompile:
    """Entry point for compiling Inductor-generated kernels.

    When ``config.compile_threads > 1``, compilation is dispatched to a
    shared thread pool (C++/Halide/CUDA/ROCm) or a process pool (Triton)
    and Future-like objects are returned; callers resolve them via
    :meth:`wait`. Otherwise everything compiles inline.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    @functools.lru_cache(1)
    def pool() -> ThreadPoolExecutor:
        # Lazily-created, shared thread pool for non-Triton backends.
        assert config.compile_threads > 1
        return ThreadPoolExecutor(config.compile_threads)

    @staticmethod
    def _get_ready():
        """No-op function to help mark when the subprocess pool is ready."""
        return "ready"

    @staticmethod
    @functools.lru_cache(1)
    def process_pool() -> AnyPool:
        # Lazily create the worker pool used for Triton compilation.
        assert config.compile_threads > 1
        pool: AnyPool
        if config.worker_start_method == "subprocess":
            # Wrapper around ProcessPoolExecutor forks in a new process we control
            pool = SubprocPool(config.compile_threads)
        else:
            pre_fork_setup()
            ctx = multiprocessing.get_context(config.worker_start_method)
            pool = ProcessPoolExecutor(
                config.compile_threads,
                mp_context=ctx,
                initializer=partial(_async_compile_initializer, os.getpid()),
            )
            # when this pool is created in a subprocess object, the normal exit handler
            # doesn't run, and we need to register our own handler.
            # exitpriority has to be high, because another one of the finalizers will
            # kill the worker thread that sends the shutdown message to the workers...
            multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)

        # Set an attribute we can check to see if the pool is ready.
        pool.ready_future = pool.submit(AsyncCompile._get_ready)  # type: ignore[union-attr]
        _pool_set.add(pool)
        return pool

    @classmethod
    def warm_pool(cls) -> None:
        """Eagerly start the process pool so workers are up before first use."""
        if config.compile_threads <= 1:
            return
        _compile_start()
        _warm_process_pool(cls.process_pool(), config.compile_threads)
        _compile_end()

    @classmethod
    def submit(cls, task: Callable[..., Any]) -> Any:
        # Run inline when parallel compile is disabled; otherwise return a Future.
        if config.compile_threads <= 1:
            return task()
        return cls.pool().submit(task)

    def _use_process_pool(self):
        # Only hand work to the subprocess pool once its warm-up task has
        # finished; before that, compiling inline avoids waiting on startup.
        return (
            config.compile_threads > 1
            and self.process_pool().ready_future.done()  # type: ignore[union-attr]
        )

    def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
        """Compile a Triton kernel; returns the kernel or a TritonFuture."""
        kernel_code_log.info("Triton Kernel:\n%s", source_code)
        _compile_start()
        _set_triton_ptxas_path()

        kernel = TritonCodeCache.load(kernel_name, source_code)
        if self._use_process_pool():
            # We want to support changing these env vars after (and while) the
            # process pool is running, so pass them to the subprocess to reset.
            env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
            extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
            return TritonFuture(
                kernel,
                self.process_pool().submit(
                    _worker_compile_triton,
                    kernel._reload_in_subproc,
                    extra_env,
                ),
            )
        else:
            kernel.precompile()
            return kernel

    def multi_kernel(self, *args, **kwargs) -> Any:
        from torch._inductor.codegen.multi_kernel import MultiKernelCall

        # no need to call this in parallel since the sub-kernels are already parallel tasks
        return MultiKernelCall(*args, **kwargs)

    def cpp(self, source_code: str):
        """Compile a C++ kernel; returns the kernel or a LambdaFuture for it."""
        kernel_code_log.info("CPP Kernel:\n%s", source_code)
        if config.compile_threads <= 1:
            return CppCodeCache.load(source_code).kernel
        else:
            get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
            return LambdaFuture(lambda: get_result().kernel)

    def cpp_pybinding(self, argtypes: List[str], source_code: str):
        """Compile a C++ kernel exposed through Python bindings."""
        kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
        if config.compile_threads <= 1:
            return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
        else:
            get_result = CppPythonBindingsCodeCache.load_pybinding_async(
                argtypes, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def cuda(self, source_code, dst_file_ext):
        """Compile CUDA source; returns the loaded object or a Future for it."""
        kernel_code_log.info("CUDA Kernel:\n%s", source_code)

        def task():
            return CUDACodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def rocm(self, source_code, dst_file_ext):
        """Compile ROCm source; returns the loaded object or a Future for it."""
        kernel_code_log.info("ROCm Kernel:\n%s", source_code)

        def task():
            return ROCmCodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def halide(self, meta: HalideMeta, source_code: str):
        """Compile a Halide kernel, asynchronously when parallelism is enabled."""
        kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
        if config.compile_threads <= 1:
            return HalideCodeCache.generate_halide(meta, source_code)
        else:
            get_result = HalideCodeCache.generate_halide_async(
                meta, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def wait(self, scope: Dict[str, Any]) -> None:
        """Resolve every Future-like value in *scope* in place, with a progress bar."""
        num_kernels = len(
            [
                value
                for key, value in scope.items()
                if isinstance(value, (Future, CodeCacheFuture))
            ]
        )
        pbar = tqdm(
            total=num_kernels,
            desc="Inductor Compilation",
            disable=config.disable_progress,
            delay=0,
        )
        if config.compile_threads > 1:
            for key, result in scope.items():
                if config.verbose_progress and not isinstance(pbar, _Faketqdm):
                    pbar.set_postfix_str(key)
                if isinstance(result, (Future, CodeCacheFuture)):
                    try:
                        # Replacing the value of an existing key while iterating
                        # is safe: the dict's size does not change.
                        scope[key] = result.result()
                    except BrokenProcessPool as e:
                        raise RuntimeError(
                            "A compilation subprocess exited unexpectedly. This "
                            "is likely due to a crash. To facilitate debugging, "
                            "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 "
                            "to cause compilation to occur in the main process."
                        ) from e
                    pbar.update(1)

        _compile_end()
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# Warm the worker pool at import time unless it is explicitly opted out
# (TorchTNT in use, or TORCH_WARM_POOL disabled) or Triton is unavailable —
# the subprocess pool is only used for the Triton backend.
_tnt_in_use = os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
_warm_pool_enabled = os.environ.get("TORCH_WARM_POOL", "1") == "1"
if not _tnt_in_use and _warm_pool_enabled and has_triton_package():
    AsyncCompile.warm_pool()
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (202 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-311.pyc
ADDED
|
Binary file (19 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 13 |
+
LearnedHeuristicDecision,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MMRankingA100(LearnedHeuristicDecision):
|
| 18 |
+
|
| 19 |
+
def __init__(self) -> None:
    # Populate the ranked list of kernel-choice strings for this heuristic.
    self.choices: List[Choice] = []
    self.fill_choices()
|
| 22 |
+
|
| 23 |
+
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
    """Return True only on the exact GPU this heuristic was trained for:
    shared memory of 166912 bytes and compute capability (8, 0) (A100)."""
    return (
        metadata.name == self.get_name()
        and metadata.shared_memory == 166912
        and str(metadata.device_capa) == "(8, 0)"
    )
|
| 29 |
+
|
| 30 |
+
def get_confidence_threshold(self) -> float:
    """Return this heuristic's confidence threshold (always 0.0 here)."""
    return 0.0
|
| 32 |
+
|
| 33 |
+
def get_choice(self, idx: int) -> Optional[str]:
    """Return the ranked choice at *idx*, or None when idx is out of range.

    NOTE(review): a negative idx passes the length check and indexes from
    the end of the list — presumably callers only pass non-negative ranks;
    confirm against the heuristic controller.
    """
    if idx < len(self.choices):
        return self.choices[idx]
    return None
|
| 37 |
+
|
| 38 |
+
def fill_choices(self) -> None:
|
| 39 |
+
self.choices.append('extern_mm')
|
| 40 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
|
| 41 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
|
| 42 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
|
| 43 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
|
| 44 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 45 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
|
| 46 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 47 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
|
| 48 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
|
| 49 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
|
| 50 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
|
| 51 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
|
| 52 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
|
| 53 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
|
| 54 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
|
| 55 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
|
| 56 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
|
| 57 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
|
| 58 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
|
| 59 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
|
| 60 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
|
| 61 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
|
| 62 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
|
| 63 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 64 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
|
| 65 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
|
| 66 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
|
| 67 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 68 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
|
| 69 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 70 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
|
| 71 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 72 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
|
| 73 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
|
| 74 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 75 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
|
| 76 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 77 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
|
| 78 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
|
| 79 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
|
| 80 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
|
| 81 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
|
| 82 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
|
| 83 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
|
| 84 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
|
| 85 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
|
| 86 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
|
| 87 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
|
| 88 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
|
| 89 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
|
| 90 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
|
| 91 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
|
| 92 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 93 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
|
| 94 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
|
| 95 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
|
| 96 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 97 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
|
| 98 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 99 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
|
| 100 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 101 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 102 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 103 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
|
| 104 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 105 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 106 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
|
| 107 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
|
| 108 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
|
| 109 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
|
| 110 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
|
| 111 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
|
| 112 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
|
| 113 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
|
| 114 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
|
| 115 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
|
| 116 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 117 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 118 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
|
| 119 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 120 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
|
| 121 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 122 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
|
| 123 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 124 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
|
| 125 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 126 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 127 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
|
| 128 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 129 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
|
| 130 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
|
| 131 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
|
| 132 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
|
| 133 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
|
| 134 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
|
| 135 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
|
| 136 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
|
| 137 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
|
| 138 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 139 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 140 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 141 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 142 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
|
| 143 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
|
| 144 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 145 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 146 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 147 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 148 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 149 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2')
|
| 150 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 151 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
|
| 152 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 153 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 154 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
|
| 155 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
|
| 156 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 157 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 158 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
|
| 159 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 160 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
|
| 161 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
|
| 162 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
|
| 163 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
|
| 164 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
|
| 165 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 166 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
|
| 167 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
|
| 168 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
|
| 169 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 170 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 171 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 172 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
|
| 173 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 174 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
|
| 175 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
|
| 176 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
|
| 177 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
|
| 178 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
|
| 179 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 180 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 181 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
|
| 182 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
|
| 183 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 184 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 185 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 186 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
|
| 187 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
|
| 188 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
|
| 189 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
|
| 190 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
|
| 191 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
|
| 192 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
|
| 193 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
|
| 194 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
|
| 195 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
|
| 196 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 197 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
|
| 198 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
|
| 199 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 200 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
|
| 201 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 202 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
|
| 203 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 204 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 205 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 206 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
|
| 207 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
|
| 208 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
|
| 209 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
|
| 210 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
|
| 211 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
|
| 212 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
|
| 213 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
|
| 214 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
|
| 215 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
|
| 216 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 217 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
|
| 218 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
|
| 219 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 220 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
|
| 221 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 222 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
|
| 223 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 224 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 225 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
|
| 226 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
|
| 227 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
|
| 228 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
|
| 229 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
|
| 230 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
|
| 231 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
|
| 232 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
|
| 233 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 234 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 235 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
|
| 236 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 237 |
+
|
| 238 |
+
def get_name(self) -> str:
|
| 239 |
+
return 'mm'
|
| 240 |
+
|
| 241 |
+
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
|
| 242 |
+
if context.get_value('arith_intensity') <= 52.6245059967041:
|
| 243 |
+
if context.get_value('n') <= 34.0:
|
| 244 |
+
if context.get_value('n') <= 18.0:
|
| 245 |
+
if context.get_value('k*n') <= 312.0:
|
| 246 |
+
return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)]
|
| 247 |
+
else:
|
| 248 |
+
if context.get_value('k') <= 40.0:
|
| 249 |
+
return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)]
|
| 250 |
+
else:
|
| 251 |
+
return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)]
|
| 252 |
+
else:
|
| 253 |
+
if context.get_value('mat1_stride_0') <= 20.0:
|
| 254 |
+
return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)]
|
| 255 |
+
else:
|
| 256 |
+
if context.get_value('k') <= 68.0:
|
| 257 |
+
return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)]
|
| 258 |
+
else:
|
| 259 |
+
return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)]
|
| 260 |
+
else:
|
| 261 |
+
if context.get_value('k') <= 35.0:
|
| 262 |
+
if context.get_value('k') <= 18.0:
|
| 263 |
+
if context.get_value('m*n') <= 19505152.0:
|
| 264 |
+
return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)]
|
| 265 |
+
else:
|
| 266 |
+
return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)]
|
| 267 |
+
else:
|
| 268 |
+
if context.get_value('n') <= 68.0:
|
| 269 |
+
return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)]
|
| 270 |
+
else:
|
| 271 |
+
return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)]
|
| 272 |
+
else:
|
| 273 |
+
if context.get_value('m*n') <= 309760.0:
|
| 274 |
+
return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)]
|
| 275 |
+
else:
|
| 276 |
+
if context.get_value('n') <= 72.0:
|
| 277 |
+
return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)]
|
| 278 |
+
else:
|
| 279 |
+
return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)]
|
| 280 |
+
else:
|
| 281 |
+
if str(context.get_value('using_tf32')) != 'False':
|
| 282 |
+
if context.get_value('m*n') <= 815360.0:
|
| 283 |
+
if context.get_value('k') <= 1184.0:
|
| 284 |
+
return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)]
|
| 285 |
+
else:
|
| 286 |
+
return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)]
|
| 287 |
+
else:
|
| 288 |
+
if context.get_value('arith_intensity') <= 187.23922729492188:
|
| 289 |
+
if context.get_value('mat1_stride_0') <= 198.0:
|
| 290 |
+
return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)]
|
| 291 |
+
else:
|
| 292 |
+
return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)]
|
| 293 |
+
else:
|
| 294 |
+
return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)]
|
| 295 |
+
else:
|
| 296 |
+
return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)]
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 13 |
+
LearnedHeuristicDecision,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MMRankingH100(LearnedHeuristicDecision):
|
| 18 |
+
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
self.choices: List[Choice] = []
|
| 21 |
+
self.fill_choices()
|
| 22 |
+
|
| 23 |
+
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
|
| 24 |
+
return (
|
| 25 |
+
metadata.name == self.get_name()
|
| 26 |
+
and metadata.shared_memory == 232448
|
| 27 |
+
and str(metadata.device_capa) == "(9, 0)"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def get_confidence_threshold(self) -> float:
|
| 31 |
+
return 0.0
|
| 32 |
+
|
| 33 |
+
def get_choice(self, idx: int) -> Optional[str]:
|
| 34 |
+
if idx < len(self.choices):
|
| 35 |
+
return self.choices[idx]
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
def fill_choices(self) -> None:
|
| 39 |
+
self.choices.append('extern_mm')
|
| 40 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
|
| 41 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
|
| 42 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
|
| 43 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
|
| 44 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 45 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
|
| 46 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 47 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
|
| 48 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
|
| 49 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
|
| 50 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
|
| 51 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
|
| 52 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
|
| 53 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
|
| 54 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
|
| 55 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
|
| 56 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
|
| 57 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
|
| 58 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
|
| 59 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
|
| 60 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
|
| 61 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
|
| 62 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 63 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
|
| 64 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
|
| 65 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 66 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
|
| 67 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 68 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
|
| 69 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 70 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
|
| 71 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
|
| 72 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 73 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
|
| 74 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 75 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
|
| 76 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
|
| 77 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
|
| 78 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
|
| 79 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
|
| 80 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
|
| 81 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
|
| 82 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
|
| 83 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
|
| 84 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
|
| 85 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
|
| 86 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
|
| 87 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
|
| 88 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
|
| 89 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
|
| 90 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
|
| 91 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 92 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
|
| 93 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
|
| 94 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
|
| 95 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 96 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
|
| 97 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 98 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
|
| 99 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 100 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 101 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 102 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
|
| 103 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 104 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 105 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
|
| 106 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
|
| 107 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
|
| 108 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
|
| 109 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
|
| 110 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
|
| 111 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
|
| 112 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
|
| 113 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
|
| 114 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
|
| 115 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 116 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 117 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
|
| 118 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 119 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
|
| 120 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 121 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
|
| 122 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
|
| 123 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 124 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
|
| 125 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 126 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
|
| 127 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 128 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
|
| 129 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 130 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
|
| 131 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
|
| 132 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
|
| 133 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1')
|
| 134 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1')
|
| 135 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
|
| 136 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
|
| 137 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
|
| 138 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
|
| 139 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
|
| 140 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
|
| 141 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
|
| 142 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
|
| 143 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 144 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 145 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 146 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
|
| 147 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 148 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 149 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
|
| 150 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1')
|
| 151 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
|
| 152 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2')
|
| 153 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
|
| 154 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
|
| 155 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 156 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 157 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 158 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
|
| 159 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
|
| 160 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 161 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 162 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
|
| 163 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
|
| 164 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 165 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 166 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
|
| 167 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
|
| 168 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
|
| 169 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
|
| 170 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
|
| 171 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 172 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
|
| 173 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
|
| 174 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
|
| 175 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
|
| 176 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 177 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 178 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2')
|
| 179 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
|
| 180 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 181 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
|
| 182 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
|
| 183 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
|
| 184 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
|
| 185 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
|
| 186 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 187 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
|
| 188 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
|
| 189 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 190 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
|
| 191 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
|
| 192 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
|
| 193 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
|
| 194 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
|
| 195 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
|
| 196 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
|
| 197 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
|
| 198 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
|
| 199 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
|
| 200 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
|
| 201 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
|
| 202 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
|
| 203 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
|
| 204 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
|
| 205 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
|
| 206 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
|
| 207 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
|
| 208 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 209 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
|
| 210 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
|
| 211 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
|
| 212 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
|
| 213 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
|
| 214 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
|
| 215 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
|
| 216 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
|
| 217 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
|
| 218 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
|
| 219 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
|
| 220 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
|
| 221 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
|
| 222 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
|
| 223 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 224 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
|
| 225 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
|
| 226 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
|
| 227 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
|
| 228 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 229 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
|
| 230 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
|
| 231 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
|
| 232 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
|
| 233 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
|
| 234 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
|
| 235 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
|
| 236 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
|
| 237 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 238 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 239 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
|
| 240 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
|
| 241 |
+
|
| 242 |
+
def get_name(self) -> str:
|
| 243 |
+
return 'mm'
|
| 244 |
+
|
| 245 |
+
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
|
| 246 |
+
if context.get_value('arith_intensity') <= 29.89772129058838:
|
| 247 |
+
if context.get_value('n') <= 34.0:
|
| 248 |
+
if context.get_value('n') <= 18.0:
|
| 249 |
+
if context.get_value('k*n') <= 432.0:
|
| 250 |
+
if context.get_value('arith_intensity') <= 7.8700292110443115:
|
| 251 |
+
return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)]
|
| 252 |
+
else:
|
| 253 |
+
return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)]
|
| 254 |
+
else:
|
| 255 |
+
if context.get_value('k') <= 40.0:
|
| 256 |
+
return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)]
|
| 257 |
+
else:
|
| 258 |
+
return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)]
|
| 259 |
+
else:
|
| 260 |
+
if context.get_value('mat1_stride_0') <= 40.0:
|
| 261 |
+
if context.get_value('mat1_stride_0') <= 20.0:
|
| 262 |
+
return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)]
|
| 263 |
+
else:
|
| 264 |
+
return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)]
|
| 265 |
+
else:
|
| 266 |
+
if context.get_value('mat1_stride_0') <= 68.0:
|
| 267 |
+
return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)]
|
| 268 |
+
else:
|
| 269 |
+
return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)]
|
| 270 |
+
else:
|
| 271 |
+
if context.get_value('k') <= 18.0:
|
| 272 |
+
if context.get_value('m*k') <= 528.0:
|
| 273 |
+
return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)]
|
| 274 |
+
else:
|
| 275 |
+
if context.get_value('n') <= 80.0:
|
| 276 |
+
return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)]
|
| 277 |
+
else:
|
| 278 |
+
return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)]
|
| 279 |
+
else:
|
| 280 |
+
if context.get_value('k') <= 36.0:
|
| 281 |
+
if context.get_value('n') <= 68.0:
|
| 282 |
+
return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)]
|
| 283 |
+
else:
|
| 284 |
+
return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)]
|
| 285 |
+
else:
|
| 286 |
+
if context.get_value('mat2_stride_0') <= 384.0:
|
| 287 |
+
return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)]
|
| 288 |
+
else:
|
| 289 |
+
return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)]
|
| 290 |
+
else:
|
| 291 |
+
if context.get_value('arith_intensity') <= 56.995582580566406:
|
| 292 |
+
if context.get_value('n') <= 68.0:
|
| 293 |
+
if context.get_value('k*n') <= 4448.0:
|
| 294 |
+
if context.get_value('m*n') <= 29626368.0:
|
| 295 |
+
return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)]
|
| 296 |
+
else:
|
| 297 |
+
return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)]
|
| 298 |
+
else:
|
| 299 |
+
if context.get_value('k') <= 348.0:
|
| 300 |
+
return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)]
|
| 301 |
+
else:
|
| 302 |
+
return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)]
|
| 303 |
+
else:
|
| 304 |
+
if context.get_value('m') <= 3264.0:
|
| 305 |
+
return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)]
|
| 306 |
+
else:
|
| 307 |
+
if context.get_value('k') <= 62.5:
|
| 308 |
+
return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)]
|
| 309 |
+
else:
|
| 310 |
+
return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)]
|
| 311 |
+
else:
|
| 312 |
+
if context.get_value('m*n') <= 1097728.0:
|
| 313 |
+
return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)]
|
| 314 |
+
else:
|
| 315 |
+
if context.get_value('m*n') <= 3244032.0:
|
| 316 |
+
return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)]
|
| 317 |
+
else:
|
| 318 |
+
if context.get_value('n') <= 136.0:
|
| 319 |
+
return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)]
|
| 320 |
+
else:
|
| 321 |
+
return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)]
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 13 |
+
LearnedHeuristicDecision,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MixedMMA100(LearnedHeuristicDecision):
|
| 18 |
+
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
self.choices: List[Choice] = []
|
| 21 |
+
self.fill_choices()
|
| 22 |
+
|
| 23 |
+
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
|
| 24 |
+
return (
|
| 25 |
+
metadata.name == self.get_name()
|
| 26 |
+
and metadata.shared_memory == 166912
|
| 27 |
+
and str(metadata.device_capa) == "(8, 0)"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def get_confidence_threshold(self) -> float:
|
| 31 |
+
return 0.0
|
| 32 |
+
|
| 33 |
+
def get_choice(self, idx: int) -> Optional[str]:
|
| 34 |
+
if idx < len(self.choices):
|
| 35 |
+
return self.choices[idx]
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
def fill_choices(self) -> None:
|
| 39 |
+
self.choices.append('extern_fallback_mixed_mm')
|
| 40 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 41 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 42 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 43 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
|
| 44 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
|
| 45 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 46 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
|
| 47 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
|
| 48 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 49 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 50 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 51 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
|
| 52 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 53 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 54 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 55 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 56 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 57 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
|
| 58 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 59 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 60 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 61 |
+
|
| 62 |
+
def get_name(self) -> str:
|
| 63 |
+
return 'mixed_mm'
|
| 64 |
+
|
| 65 |
+
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
|
| 66 |
+
if str(context.get_value('1LEQmLEQ16')) != 'True':
|
| 67 |
+
if context.get_value('m') <= 32.5:
|
| 68 |
+
if context.get_value('n') <= 6976.0:
|
| 69 |
+
if context.get_value('n') <= 3520.0:
|
| 70 |
+
if context.get_value('m*n') <= 37632.0:
|
| 71 |
+
return None
|
| 72 |
+
else:
|
| 73 |
+
return [(1.000, 13)]
|
| 74 |
+
else:
|
| 75 |
+
if context.get_value('m*k') <= 452352.0:
|
| 76 |
+
return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)]
|
| 77 |
+
else:
|
| 78 |
+
return [(0.778, 8), (0.222, 13)]
|
| 79 |
+
else:
|
| 80 |
+
if context.get_value('k*n') <= 102776832.0:
|
| 81 |
+
if context.get_value('n') <= 14656.0:
|
| 82 |
+
return [(1.000, 11)]
|
| 83 |
+
else:
|
| 84 |
+
return [(0.889, 11), (0.111, 13)]
|
| 85 |
+
else:
|
| 86 |
+
return [(1.000, 11)]
|
| 87 |
+
else:
|
| 88 |
+
if context.get_value('m*n') <= 446464.0:
|
| 89 |
+
if context.get_value('m*n') <= 223424.0:
|
| 90 |
+
if context.get_value('mat1_stride_0') <= 3968.0:
|
| 91 |
+
return None
|
| 92 |
+
else:
|
| 93 |
+
return None
|
| 94 |
+
else:
|
| 95 |
+
if context.get_value('m*n') <= 346112.0:
|
| 96 |
+
return [(0.960, 16), (0.040, 7)]
|
| 97 |
+
else:
|
| 98 |
+
return [(0.750, 16), (0.136, 14), (0.114, 7)]
|
| 99 |
+
else:
|
| 100 |
+
if str(context.get_value('33LEQmLEQ64')) != 'True':
|
| 101 |
+
if context.get_value('n') <= 6976.0:
|
| 102 |
+
return [(1.000, 14)]
|
| 103 |
+
else:
|
| 104 |
+
return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)]
|
| 105 |
+
else:
|
| 106 |
+
if context.get_value('n') <= 13888.0:
|
| 107 |
+
return [(0.710, 14), (0.275, 21), (0.014, 12)]
|
| 108 |
+
else:
|
| 109 |
+
return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)]
|
| 110 |
+
else:
|
| 111 |
+
if context.get_value('n') <= 3520.0:
|
| 112 |
+
if context.get_value('arith_intensity') <= 3.994754433631897:
|
| 113 |
+
if str(context.get_value('mat2_dtype')) != 'torch.uint8':
|
| 114 |
+
if context.get_value('m*k') <= 18944.0:
|
| 115 |
+
return [(0.577, 5), (0.423, 6)]
|
| 116 |
+
else:
|
| 117 |
+
return [(0.988, 5), (0.012, 6)]
|
| 118 |
+
else:
|
| 119 |
+
if context.get_value('arith_intensity') <= 2.9899919033050537:
|
| 120 |
+
return None
|
| 121 |
+
else:
|
| 122 |
+
return None
|
| 123 |
+
else:
|
| 124 |
+
if context.get_value('arith_intensity') <= 7.956453561782837:
|
| 125 |
+
if context.get_value('k*n') <= 9244032.0:
|
| 126 |
+
return [(0.822, 5), (0.178, 6)]
|
| 127 |
+
else:
|
| 128 |
+
return [(0.977, 5), (0.023, 0)]
|
| 129 |
+
else:
|
| 130 |
+
if context.get_value('m*k') <= 978944.0:
|
| 131 |
+
return [(1.000, 5)]
|
| 132 |
+
else:
|
| 133 |
+
return [(0.971, 5), (0.029, 0)]
|
| 134 |
+
else:
|
| 135 |
+
if context.get_value('n') <= 13632.0:
|
| 136 |
+
if context.get_value('n') <= 6976.0:
|
| 137 |
+
return [(1.000, 6)]
|
| 138 |
+
else:
|
| 139 |
+
if context.get_value('k') <= 3968.0:
|
| 140 |
+
return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)]
|
| 141 |
+
else:
|
| 142 |
+
return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)]
|
| 143 |
+
else:
|
| 144 |
+
if context.get_value('k*n') <= 39518208.0:
|
| 145 |
+
return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)]
|
| 146 |
+
else:
|
| 147 |
+
if context.get_value('n') <= 20800.0:
|
| 148 |
+
return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)]
|
| 149 |
+
else:
|
| 150 |
+
return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)]
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 13 |
+
LearnedHeuristicDecision,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MixedMMH100(LearnedHeuristicDecision):
|
| 18 |
+
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
self.choices: List[Choice] = []
|
| 21 |
+
self.fill_choices()
|
| 22 |
+
|
| 23 |
+
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
|
| 24 |
+
return (
|
| 25 |
+
metadata.name == self.get_name()
|
| 26 |
+
and metadata.shared_memory == 232448
|
| 27 |
+
and str(metadata.device_capa) == "(9, 0)"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def get_confidence_threshold(self) -> float:
|
| 31 |
+
return 0.0
|
| 32 |
+
|
| 33 |
+
def get_choice(self, idx: int) -> Optional[str]:
|
| 34 |
+
if idx < len(self.choices):
|
| 35 |
+
return self.choices[idx]
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
def fill_choices(self) -> None:
|
| 39 |
+
self.choices.append('extern_fallback_mixed_mm')
|
| 40 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
|
| 41 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
|
| 42 |
+
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 43 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 44 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
|
| 45 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
|
| 46 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 47 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
|
| 48 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
|
| 49 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
|
| 50 |
+
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
|
| 51 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 52 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
|
| 53 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 54 |
+
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
|
| 55 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
|
| 56 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
|
| 57 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
|
| 58 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
|
| 59 |
+
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
|
| 60 |
+
|
| 61 |
+
def get_name(self) -> str:
|
| 62 |
+
return 'mixed_mm'
|
| 63 |
+
|
| 64 |
+
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
    """
    Auto-generated decision tree (produced by AutoHeuristic — do not edit the
    thresholds by hand). Returns a ranked list of (probability, choice-index)
    pairs for the mixed_mm kernel choice, or None when the tree is unsure.

    The choice indices refer to positions in self.choices; the probabilities
    are leaf class frequencies from training. Feature names such as 'm*n',
    'k*n' and '17LEQmLEQ32' are derived features stored in the AHContext
    (presumably products of matmul dims and an m-range indicator — confirm
    against the feature-generation code in autoheuristic_utils).
    """
    # Low arithmetic intensity: memory-bound regime.
    if context.get_value('arith_intensity') <= 15.988086223602295:
        if context.get_value('n') <= 25280.0:
            if context.get_value('n') <= 1344.0:
                if context.get_value('mat1_stride_0') <= 7808.0:
                    return [(0.581, 7), (0.419, 6)]
                else:
                    if context.get_value('m*n') <= 7680.0:
                        return [(0.875, 0), (0.125, 6)]
                    else:
                        return [(0.833, 0), (0.167, 7)]
            else:
                if context.get_value('n') <= 8512.0:
                    if str(context.get_value('mat2_dtype')) != 'torch.int8':
                        return [(0.763, 6), (0.237, 7)]
                    else:
                        return [(0.725, 7), (0.275, 6)]
                else:
                    if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
                        return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)]
                    else:
                        return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)]
        else:
            if context.get_value('n') <= 42254.0:
                if context.get_value('n') <= 33856.0:
                    if context.get_value('k*n') <= 68157440.0:
                        return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)]
                    else:
                        return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)]
                else:
                    return [(0.659, 5), (0.341, 6)]
            else:
                if context.get_value('k*n') <= 326052992.0:
                    if context.get_value('n') <= 55232.0:
                        return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)]
                    else:
                        return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)]
                else:
                    if context.get_value('n') <= 57024.0:
                        return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)]
                    else:
                        return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)]
    # High arithmetic intensity: compute-bound regime.
    else:
        if context.get_value('m*n') <= 543936.0:
            if str(context.get_value('17LEQmLEQ32')) != 'True':
                if context.get_value('m*n') <= 262272.0:
                    if context.get_value('n') <= 1592.5:
                        return [(0.860, 0), (0.140, 9)]
                    else:
                        # Tree is not confident enough in this region.
                        return None
                else:
                    if context.get_value('m*k') <= 1294336.0:
                        return [(0.833, 17), (0.150, 18), (0.017, 15)]
                    else:
                        return [(0.917, 17), (0.083, 8)]
            else:
                if context.get_value('n') <= 12416.0:
                    if context.get_value('m*n') <= 43008.0:
                        return None
                    else:
                        return [(0.853, 14), (0.147, 9)]
                else:
                    return [(0.625, 12), (0.375, 14)]
        else:
            if context.get_value('m') <= 32.5:
                if context.get_value('mat2_stride_1') <= 6656.0:
                    if context.get_value('n') <= 69184.0:
                        return [(0.611, 12), (0.361, 14), (0.028, 13)]
                    else:
                        return [(1.000, 12)]
                else:
                    if context.get_value('mat2_stride_1') <= 20864.0:
                        return [(1.000, 12)]
                    else:
                        return [(0.958, 12), (0.042, 9)]
            else:
                if context.get_value('m*n') <= 1085440.0:
                    if context.get_value('n') <= 9152.0:
                        return [(1.000, 18)]
                    else:
                        return [(0.780, 18), (0.160, 16), (0.060, 20)]
                else:
                    if context.get_value('m') <= 67.0:
                        return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)]
                    else:
                        return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)]
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: B950
|
| 2 |
+
# fmt: off
|
| 3 |
+
# This file was generated by AutoHeuristic. Do not modify it manually!
|
| 4 |
+
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
|
| 5 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
|
| 6 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
| 7 |
+
LearnedHeuristicRegression,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PadMMA100(LearnedHeuristicRegression):
    """
    Learned regression heuristic for the pad_mm decision on A100 GPUs.

    The file header states this code was generated by AutoHeuristic; the
    thresholds and leaf values in predict() come from a trained regression
    tree and must not be hand-tuned.
    """

    def __init__(self) -> None:
        # Stateless: all decisions are encoded in predict().
        pass

    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        # Only applicable for the GPU this tree was trained on: shared memory
        # 166912 bytes and compute capability (8, 0) — i.e. an A100.
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == 166912
            and str(metadata.device_capa) == "(8, 0)"
        )

    def get_feedback(self, context: AHContext, choice: Choice) -> float:
        # Inject the candidate choice into the context so predict() can
        # branch on the 'choice' feature.
        context.context_dict[CHOICE_COL] = choice
        return self.predict(context)

    def get_confidence_threshold(self) -> float:
        # Value produced during training; predictions below this confidence
        # are treated as "unsure" by the framework.
        return 1.7025303314066

    def get_name(self) -> str:
        return 'pad_mm'

    def predict(self, context: AHContext) -> float:
        """
        Auto-generated regression tree. Returns the predicted feedback value
        for the choice currently stored in the context (branching first on
        whether the choice is 'pad').
        """
        if str(context.get_value('choice')) != 'pad':
            if str(context.get_value('using_tf32')) != 'False':
                if context.get_value('m*n') <= 4171264.0:
                    if context.get_value('m*k') <= 3999308.0:
                        return 1.8751469764071178
                    else:
                        if str(context.get_value('n_multiple_32')) != 'True':
                            return 0.9117231355626345
                        else:
                            return 1.1607689608873861
                else:
                    if str(context.get_value('n_multiple_2')) != 'True':
                        if str(context.get_value('using_tf32')) != 'True':
                            return 0.7430382200435992
                        else:
                            return 0.8531269794448678
                    else:
                        if str(context.get_value('k_multiple_2')) != 'True':
                            return 0.7577181972719917
                        else:
                            return 0.8977349440424219
            else:
                if context.get_value('m*n') <= 1299712.0:
                    return 1.1669723418995592
                else:
                    if context.get_value('mat2_stride_1') <= 45217.5:
                        if context.get_value('m*n') <= 55884158.0:
                            return 1.0262769936909601
                        else:
                            return 1.0022677428470845
                    else:
                        if context.get_value('m') <= 18478.0:
                            return 1.1127066261894312
                        else:
                            return 1.0337740659894263
        else:
            if str(context.get_value('mat1_dtype')) != 'torch.float32':
                if str(context.get_value('n_multiple_2')) != 'False':
                    if str(context.get_value('k_multiple_2')) != 'True':
                        if context.get_value('mat1_stride_0') <= 561.0:
                            return 1.2900382135142956
                        else:
                            return 1.5761737616057887
                    else:
                        if context.get_value('num_dims_needs_padding') <= 1.5:
                            return 1.0472263310239422
                        else:
                            return 1.1727673465762514
                else:
                    if context.get_value('k') <= 28238.5:
                        if context.get_value('k/(m*n)') <= 0.00026227018679492176:
                            return 1.6770542505397175
                        else:
                            return 1.3974785435105923
                    else:
                        if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
                            return 1.3952699800111992
                        else:
                            return 1.5759286511628336
            else:
                if str(context.get_value('using_tf32')) != 'False':
                    if context.get_value('m*n') <= 14119424.0:
                        return 0.8875772670422478
                    else:
                        if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
                            return 1.1467728924377265
                        else:
                            return 1.215842963532998
                else:
                    if context.get_value('arith_intensity') <= 396.8774871826172:
                        return 0.89940161869551
                    else:
                        if context.get_value('mat2_stride_1') <= 45217.5:
                            return 0.9964328169353532
                        else:
                            return 0.9493479238294826
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
AHOperation,
|
| 11 |
+
Choice,
|
| 12 |
+
CHOICE_COL,
|
| 13 |
+
Feedback,
|
| 14 |
+
FEEDBACK_COL,
|
| 15 |
+
get_metadata_str_from_log,
|
| 16 |
+
)
|
| 17 |
+
from torch._inductor.autoheuristic.learned_heuristic_controller import (
|
| 18 |
+
LearnedHeuristicController,
|
| 19 |
+
)
|
| 20 |
+
from torch._inductor.ir import ChoiceCaller
|
| 21 |
+
from torch._inductor.runtime.runtime_utils import cache_dir
|
| 22 |
+
from torch._inductor.utils import get_gpu_shared_memory
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class LocalFeedback:
    """
    Adapter that lets AutoHeuristic obtain feedback for a choice on demand.

    Wraps a user-supplied callable mapping a Choice to its Feedback (e.g. an
    autotuning timing). Used when feedback can be computed immediately for
    every choice, such as the local autotuning done for pad_mm.
    """

    def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
        # The callable is invoked lazily, once per choice, via __call__.
        self.feedback_fn = feedback_fn

    def __call__(self, choice: Choice) -> Feedback:
        """Compute and return the feedback for *choice*."""
        return self.feedback_fn(choice)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class InconsistentMetadata(Exception):
    """
    Raised when AutoHeuristic tries to append to a log file whose stored
    metadata differs from the metadata it would write for the current run.
    """
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class AutoHeuristic:
    """
    AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and
    generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train
    a heuristic (see torchgen/autoheuristic/).
    """

    # Feedback gathered per choice during data collection; filled by save_data().
    collected_feedback: Dict[Choice, Feedback]

    def __init__(
        self,
        fallback: Callable[[], Choice],
        choices: List[Choice],
        feedback: Optional[LocalFeedback],
        context: AHContext,
        name: str,
        augment_context: Optional[List[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        Initializes an instance of the AutoHeuristic class.

        Args:
            fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or
                AutoHeuristic is in data collection mode.
            choices: A list of possible choices the heuristic can make.
            feedback: An instance of LocalFeedback that provides feedback for a given choice.
            context: Context to store with each choice and feedback.
            name: A string that identifies the heuristic.
            augment_context: An optional list of AHOperation instances that augment the context.
            precondition: A callable that returns a boolean indicating whether AutoHeuristic should run.
        """
        self.fallback = fallback
        self.choices = choices
        self.feedback = feedback
        self.context = context
        self.name = name
        self.collected_feedback = {}
        self.augment_context = augment_context
        # Shared memory size + device capability identify the GPU model the
        # data is collected on (see AHMetadata).
        self.metadata = AHMetadata(
            get_gpu_shared_memory(),
            torch.cuda.get_device_capability(),
            self.choices,
            self.name,
        )
        self.precondition = precondition

        if not self.satisfies_precondition():
            # NOTE(review): log_path stays unset on this path; later logging
            # assumes the precondition was checked first — confirm callers do.
            return

        if torch._inductor.config.autoheuristic_log_path == "DEFAULT":
            self.log_path = self.get_default_log_path()
        else:
            self.log_path = torch._inductor.config.autoheuristic_log_path

        # Data-collection mode with local feedback: evaluate every choice
        # eagerly and persist each result.
        if torch._inductor.config.collect_autoheuristic(self.name):
            if self.feedback is not None:
                for choice in self.choices:
                    feedback_val = self.feedback(choice)
                    self.save_data(choice, feedback_val)

    def satisfies_precondition(self) -> bool:
        # No precondition registered means AutoHeuristic always applies.
        return self.precondition is None or self.precondition(
            self.metadata, self.context
        )

    def get_choice(self) -> Choice:
        """
        Returns the chosen option based on the value of autoheuristic_use.
        If self.name is one of the comma separated strings in autoheuristic_use,
        it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option.
        """

        if not self.satisfies_precondition():
            return self.fallback()

        if torch._inductor.config.use_autoheuristic(self.name):
            if self.augment_context is not None:
                # Derive additional features (e.g. m*n) before querying.
                self.context.apply_operations(self.augment_context)
            controller = LearnedHeuristicController(
                self.metadata,
                self.context,
            )
            decision = controller.get_decision()
            if decision not in self.choices:
                # TODO(AlnisM): We might want to allow this in the future
                return self.fallback()
            if decision is not None:
                return decision
        return self.fallback()

    def get_top_k_choices(
        self, top_k: int, always_included: Optional[List[str]] = None
    ) -> Optional[List[Choice]]:
        # Ranked variant of get_choice(): returns the heuristic's top_k
        # choices, or None when the heuristic cannot/should not decide.
        if not self.satisfies_precondition():
            return None
        if torch._inductor.config.use_autoheuristic(self.name):
            if self.augment_context is not None:
                self.context.apply_operations(self.augment_context)
            controller = LearnedHeuristicController(
                self.metadata,
                self.context,
            )
            choices = controller.get_decisions_ranked(top_k)
            if choices is None:
                return None
            if always_included is not None:
                # Guarantee caller-required choices survive the ranking cutoff.
                for choice in always_included:
                    if choice not in choices:
                        choices.append(choice)
            return choices
        return None

    def get_collected_feedback(self, choice: Choice) -> Any:
        return self.collected_feedback.get(choice, None)

    @staticmethod
    def get_device_identifier() -> str:
        # a heuristic might work well for one GPU, but not for another
        # we store the collected data per GPU model and learn a heuristic per GPU model

        # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names
        device_name = torch.cuda.get_device_name().replace(" ", "_")
        return device_name

    def get_default_log_path(self) -> str:
        # Logs live at <inductor cache dir>/autoheuristic/<gpu>/<name>.txt;
        # the directory is created on demand.
        device_name = self.get_device_identifier()
        path = f"{cache_dir()}/autoheuristic/{device_name}/"
        os.makedirs(path, exist_ok=True)
        path += f"{self.name}.txt"
        return path

    def serialize_metadata(self) -> str:
        # The metadata JSON (first line of every log) also records the feature
        # layout so later runs can detect incompatible logs (see save_data).
        metadata_dict = self.metadata.to_dict()
        (
            num_features,
            cat_features,
        ) = self.context.get_numerical_and_categorical_features()
        metadata_dict["numerical_features"] = num_features
        metadata_dict["categorical_features"] = cat_features
        return json.dumps(metadata_dict)

    def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
        """
        Append one (features, choice, feedback) row to the log file, writing
        the metadata line and CSV header first when the file does not exist.

        Raises:
            InconsistentMetadata: if the existing log was written with
                different metadata (e.g. other features or another GPU).
        """
        self.collected_feedback[choice] = feedback_val
        log_path = self.log_path

        lines = []
        log_exists = os.path.exists(log_path)
        if log_exists:
            # if log already exists, make sure it is consistent
            metadata = self.serialize_metadata()
            existing_metadata = get_metadata_str_from_log(self.log_path)
            if existing_metadata != metadata:
                raise InconsistentMetadata(
                    "Given metadata does not match existing metadata"
                )
        else:
            lines.append(self.serialize_metadata())
            feature_header = self.context.get_feature_names_csv()
            header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL
            lines.append(header)

        line = ""
        feature_values = self.context.get_feature_values_csv()
        line += feature_values + "," + choice + "," + str(feedback_val)
        lines.append(line)

        with open(log_path, "a") as f:
            f.write("\n".join(lines) + "\n")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class AutoHeuristicSelectAlgorithm(AutoHeuristic):
    """
    AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic
    when one wants to use AutoHeuristic for kernel choice selection.
    """

    def __init__(
        self,
        fallback: Callable[[], Optional[ChoiceCaller]],
        choices: List[ChoiceCaller],
        input_nodes: List[Any],
        context: AHContext,
        name: str,
        augment_context: Optional[List[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        The arguments choices, input_nodes and name have to match the ones used in the call to
        autotune_select_algorithm(), e.g. if the following call is made
        autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes
        have to be used here.
        """
        self.input_nodes = input_nodes
        # Map each ChoiceCaller's stable string id back to the caller object so
        # the base class can operate purely on strings.
        self.choicestr2choice: Dict[str, ChoiceCaller] = {}
        for choice in choices:
            self.choicestr2choice[choice.autoheuristic_id()] = choice
        choices_str = list(self.choicestr2choice.keys())

        def fallback_str() -> str:
            fallback_choice = fallback()
            if fallback_choice is None:
                # TODO: Find a nicer way to handle this
                return "unsure"
            return fallback_choice.autoheuristic_id()

        super().__init__(
            fallback_str,
            choices_str,
            None,  # no LocalFeedback: timings arrive via the global hook below
            context,
            name,
            augment_context,
            precondition,
        )

        if (
            torch._inductor.config.collect_autoheuristic(self.name)
            and self.satisfies_precondition()
        ):
            self.register_global_feedback(input_nodes, choices)

    def register_global_feedback(
        self, input_nodes: List[Any], choices: List[ChoiceCaller]
    ) -> None:
        """
        Registers a callback in select_algorithm, which is called with the timing of each choice.
        """

        from torch._inductor.select_algorithm import (
            add_feedback_saver,
            create_inputs_key,
            create_precompile_key,
        )

        def store_global_feedback(
            ah_inputs_key: str,
            ah_precompile_key: str,
            timings: Dict[ChoiceCaller, float],
            name: str,
            input_nodes: List[Any],
            choices: List[ChoiceCaller],
        ) -> None:
            # The saver fires for every autotuning event; only record timings
            # that belong to this exact (inputs, choices) pair.
            current_inputs_key = create_inputs_key(input_nodes)
            if current_inputs_key != ah_inputs_key:
                return
            current_precompile_key = create_precompile_key(
                name, current_inputs_key, choices
            )
            if current_precompile_key != ah_precompile_key:
                return
            for choice, time in timings.items():
                self.save_data(choice.autoheuristic_id(), time)

        inputs_key = create_inputs_key(input_nodes)
        precompile_key = create_precompile_key(self.name, inputs_key, choices)
        # Bind the identifying keys now; the saver is invoked later with the
        # event's own arguments.
        feedback_saver = partial(store_global_feedback, inputs_key, precompile_key)
        add_feedback_saver(feedback_saver)

    def get_choice_caller(self) -> Optional[ChoiceCaller]:
        # None when the decision was the fallback "unsure" string or any id
        # that does not correspond to a known ChoiceCaller.
        choice = self.get_choice()
        return self.choicestr2choice.get(choice, None)

    def get_top_k_choices_caller(
        self, top_k: int, always_included: Optional[List[str]] = None
    ) -> Optional[List[ChoiceCaller]]:
        choices = self.get_top_k_choices(top_k, always_included)
        if choices is None:
            return None
        return [self.choicestr2choice[choice] for choice in choices]
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from typing import Any, Callable, Dict, List, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
Feedback = float
|
| 8 |
+
Choice = str
|
| 9 |
+
Value = Any
|
| 10 |
+
|
| 11 |
+
CHOICE_COL = "choice"
|
| 12 |
+
FEEDBACK_COL = "feedback"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AHFeature:
    """
    One named datum inside an AHContext.

    The categorical flag tells the training pipeline whether the value is a
    discrete label rather than a continuous quantity, which matters when
    fitting a machine-learning model.
    """

    def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None:
        # Plain data holder; no validation is performed on purpose.
        self.name = name
        self.value = value
        self.is_categorical = is_categorical
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AHOperation:
    """
    Derives a new feature from already-collected data.

    Instead of logging redundant features such as m*n or k*n alongside m, k
    and n, an AHOperation recomputes them on demand from the raw data.
    """

    def __init__(
        self, name: str, func: Callable[[Any], Value], is_categorical: bool = False
    ) -> None:
        self.name = name
        self.func = func
        self.is_categorical = is_categorical

    def apply_operation(self, data: Any) -> None:
        """Mutate *data* in place, storing func(data) under self.name."""
        data[self.name] = self.func(data)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AHContext:
    """
    Holds the information AutoHeuristic records with every choice: a list of
    named features (e.g. tensor shapes) plus a dict view of the same values.
    The features become training inputs for the learned heuristic.
    """

    features: List[AHFeature]
    context_dict: Dict[str, Value]

    def __init__(self) -> None:
        self.features = []
        self.context_dict = {}

    def add_feature(
        self, name: str, value: Value, is_categorical: bool = False
    ) -> None:
        # Kept in both representations: ordered list (for CSV output) and
        # dict (for O(1) lookup via get_value).
        self.features.append(AHFeature(name, value, is_categorical=is_categorical))
        self.context_dict[name] = value

    def get_numerical_and_categorical_features(self) -> Tuple[List[str], List[str]]:
        """Split feature names into (numerical, categorical), preserving order."""
        numerical: List[str] = []
        categorical: List[str] = []
        for feat in self.features:
            (categorical if feat.is_categorical else numerical).append(feat.name)
        return numerical, categorical

    def get_feature_names_csv(self) -> str:
        return ",".join(feat.name for feat in self.features)

    def get_feature_values_csv(self) -> str:
        return ",".join(str(feat.value) for feat in self.features)

    def get_value(self, name: str) -> Value:
        # Raises KeyError for unknown names, matching plain dict access.
        return self.context_dict[name]

    def apply_operations(self, operations: List[AHOperation]) -> None:
        """Run each derived-feature operation against the context dict."""
        for operation in operations:
            operation.apply_operation(self.context_dict)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class AHMetadata:
    """
    Identifying information stored alongside collected data: which GPU the
    data came from (shared memory size + device capability) and which
    heuristic and choice set it belongs to.
    """

    def __init__(
        self,
        shared_memory: Any,
        device_capa: Tuple[int, int],
        choices: List[Choice],
        name: str,
    ) -> None:
        # use amount of shared_memory and device_capability to identify GPU
        # TODO(AlnisM): there might be a better way to do this
        self.shared_memory = shared_memory
        self.device_capa = device_capa
        self.choices = choices
        self.name = name

    def to_dict(self) -> Dict[str, Value]:
        # NOTE(review): choices are deliberately not serialized here —
        # presumably they are recorded per data row instead; confirm against
        # the log readers before relying on this.
        return dict(
            shared_memory=self.shared_memory,
            device_capa=self.device_capa,
            name=self.name,
        )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_metadata_str_from_log(log_path: str) -> str:
    """Return the metadata line (the first line, stripped) of an AutoHeuristic log."""
    with open(log_path, newline="") as fh:
        return fh.readline().strip()
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def check_minsize(context: AHContext, minsize: int) -> bool:
    """True iff every matmul dimension (m, k, n) in the context is >= minsize."""
    return all(context.get_value(dim) >= minsize for dim in ("m", "k", "n"))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
    """
    Gate the pad_mm heuristic by GPU: on known GPUs, require a minimum size
    for all matmul dimensions; on unknown GPUs, always allow it.
    """
    gpu = (metadata.shared_memory, metadata.device_capa)
    if gpu == (166912, (8, 0)):
        # A100 precondition
        return check_minsize(context, 512)
    if gpu == (232448, (9, 0)):
        # H100 precondition
        return check_minsize(context, 768)
    return True
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
    """
    Gate the mixed_mm heuristic: require a small m (<= 128), sufficiently
    large k and n (>= 1024), a contiguous mat1 and a non-contiguous mat2.
    """
    if context.get_value("m") > 128:
        return False
    if context.get_value("k") < 1024 or context.get_value("n") < 1024:
        return False
    mat1_iscontig = context.get_value("mat1_iscontig")
    mat2_iscontig = context.get_value("mat2_iscontig")
    return mat1_iscontig and not mat2_iscontig
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def get_mult_dims_ops() -> List[AHOperation]:
|
| 151 |
+
m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
|
| 152 |
+
m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
|
| 153 |
+
k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
|
| 154 |
+
return [m_times_k_op, m_times_n_op, k_times_n_op]
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_arith_intensity(data: Any) -> float:
    """Arithmetic intensity of an (m, k) x (k, n) matmul.

    Ratio of multiply-accumulate work (m*k*n) to the number of elements
    moved (both inputs plus the output: m*k + k*n + m*n). Returns 0.0 for
    degenerate problems where any dim is 0.
    """
    m, k, n = data["m"], data["k"], data["n"]
    if m == 0 or k == 0 or n == 0:
        return 0.0
    return m * k * n / (m * k + k * n + m * n)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def pad_mm_operations() -> List[AHOperation]:
|
| 167 |
+
mult_dims_ops = get_mult_dims_ops()
|
| 168 |
+
k_div_m_times_n_op = AHOperation(
|
| 169 |
+
"k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"])
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
def bfloat_perf_hit(data: Any) -> bool:
|
| 173 |
+
m = data["m"]
|
| 174 |
+
k = data["k"]
|
| 175 |
+
n = data["n"]
|
| 176 |
+
is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16"
|
| 177 |
+
return k > (m * 1024) and k > (n * 1024) and is_bfloat
|
| 178 |
+
|
| 179 |
+
bfloat_perf_hit_op = AHOperation(
|
| 180 |
+
"bfloat_perf_hit", bfloat_perf_hit, is_categorical=True
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
|
| 184 |
+
dims_need_padding_ops = get_dims_need_padding_ops()
|
| 185 |
+
dims_multiple_ops = get_dims_multiple_ops()
|
| 186 |
+
is_contig_ops = get_is_contig_ops()
|
| 187 |
+
|
| 188 |
+
ah_operations = mult_dims_ops + [
|
| 189 |
+
k_div_m_times_n_op,
|
| 190 |
+
bfloat_perf_hit_op,
|
| 191 |
+
arith_intensity_op,
|
| 192 |
+
]
|
| 193 |
+
ah_operations.extend(dims_need_padding_ops)
|
| 194 |
+
ah_operations.extend(dims_multiple_ops)
|
| 195 |
+
ah_operations.extend(is_contig_ops)
|
| 196 |
+
return ah_operations
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
    """True iff data[dim] lies in the inclusive range [lower, upper]."""
    return lower <= data[dim] <= upper
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def between_ops() -> List[AHOperation]:
|
| 204 |
+
dims = ["m", "k", "n"]
|
| 205 |
+
limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
|
| 206 |
+
ah_operations = []
|
| 207 |
+
for dim in dims:
|
| 208 |
+
for lower, upper in limits:
|
| 209 |
+
between_op_fn = functools.partial(
|
| 210 |
+
between_op, dim=dim, lower=lower, upper=upper
|
| 211 |
+
)
|
| 212 |
+
# using 'LEQ' instead of '<=' because '<=' cannot be exported to dot
|
| 213 |
+
between_op_name = f"{lower}LEQ{dim}LEQ{upper}"
|
| 214 |
+
ah_operations.append(
|
| 215 |
+
AHOperation(between_op_name, between_op_fn, is_categorical=True)
|
| 216 |
+
)
|
| 217 |
+
return ah_operations
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def pow2_op(data: Any, dim: str, exponent: int) -> bool:
    """True iff data[dim] is exactly the power of two 2**exponent."""
    return 2**exponent == data[dim]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def mm_operations() -> List[AHOperation]:
|
| 225 |
+
mult_dims_ops = get_mult_dims_ops()
|
| 226 |
+
arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
|
| 227 |
+
return mult_dims_ops + [arith_intensity_op]
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def mixed_mm_operations() -> List[AHOperation]:
|
| 231 |
+
return mm_operations() + between_ops()
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def is_multiple(data: Any, dim: str, mult: int) -> bool:
    """True iff data[dim] is evenly divisible by *mult*."""
    return not data[dim] % mult
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def get_dims_multiple_ops() -> List[AHOperation]:
|
| 239 |
+
multiples = [2, 4, 8, 16, 32]
|
| 240 |
+
dims = ["m", "k", "n"]
|
| 241 |
+
dims_multiple_ops = []
|
| 242 |
+
for dim in dims:
|
| 243 |
+
for mult in multiples:
|
| 244 |
+
is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult)
|
| 245 |
+
dims_multiple_op = AHOperation(
|
| 246 |
+
f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True
|
| 247 |
+
)
|
| 248 |
+
dims_multiple_ops.append(dims_multiple_op)
|
| 249 |
+
return dims_multiple_ops
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def get_dims_need_padding_ops() -> List[AHOperation]:
|
| 253 |
+
def mat1_innermost_needs_padding_fn(data: Any) -> bool:
|
| 254 |
+
mat1_stride_0 = data["mat1_stride_0"]
|
| 255 |
+
mat1_stride_1 = data["mat1_stride_1"]
|
| 256 |
+
m_padded_length = data["m_padded_length"]
|
| 257 |
+
k_padded_length = data["k_padded_length"]
|
| 258 |
+
mat1_innermost_needs_padding = False
|
| 259 |
+
if mat1_stride_0 == 1 and m_padded_length != 0:
|
| 260 |
+
mat1_innermost_needs_padding = True
|
| 261 |
+
if mat1_stride_1 == 1 and k_padded_length != 0:
|
| 262 |
+
mat1_innermost_needs_padding = True
|
| 263 |
+
return mat1_innermost_needs_padding
|
| 264 |
+
|
| 265 |
+
mat1_innermost_op = AHOperation(
|
| 266 |
+
"mat1_innermost_needs_padding",
|
| 267 |
+
mat1_innermost_needs_padding_fn,
|
| 268 |
+
is_categorical=True,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
def mat2_innermost_needs_padding_fn(data: Any) -> bool:
|
| 272 |
+
mat2_stride_0 = data["mat2_stride_0"]
|
| 273 |
+
mat2_stride_1 = data["mat2_stride_1"]
|
| 274 |
+
k_padded_length = data["k_padded_length"]
|
| 275 |
+
n_padded_length = data["n_padded_length"]
|
| 276 |
+
mat2_innermost_needs_padding = False
|
| 277 |
+
if mat2_stride_0 == 1 and k_padded_length != 0:
|
| 278 |
+
mat2_innermost_needs_padding = True
|
| 279 |
+
if mat2_stride_1 == 1 and n_padded_length != 0:
|
| 280 |
+
mat2_innermost_needs_padding = True
|
| 281 |
+
return mat2_innermost_needs_padding
|
| 282 |
+
|
| 283 |
+
mat2_innermost_op = AHOperation(
|
| 284 |
+
"mat2_innermost_needs_padding",
|
| 285 |
+
mat2_innermost_needs_padding_fn,
|
| 286 |
+
is_categorical=True,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
def num_dims_needs_padding_fn(data: Any) -> int:
|
| 290 |
+
m_padded_length = data["m_padded_length"]
|
| 291 |
+
k_padded_length = data["k_padded_length"]
|
| 292 |
+
n_padded_length = data["n_padded_length"]
|
| 293 |
+
num_dims_needs_padding = 0
|
| 294 |
+
if m_padded_length != 0:
|
| 295 |
+
num_dims_needs_padding += 1
|
| 296 |
+
if k_padded_length != 0:
|
| 297 |
+
num_dims_needs_padding += 1
|
| 298 |
+
if n_padded_length != 0:
|
| 299 |
+
num_dims_needs_padding += 1
|
| 300 |
+
return num_dims_needs_padding
|
| 301 |
+
|
| 302 |
+
num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn)
|
| 303 |
+
return [mat1_innermost_op, mat2_innermost_op, num_dims_op]
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def get_is_contig_ops() -> List[AHOperation]:
|
| 307 |
+
def mat1_is_contig_fn(data: Any) -> bool:
|
| 308 |
+
stride_0 = data["mat1_stride_0"]
|
| 309 |
+
stride_1 = data["mat1_stride_1"]
|
| 310 |
+
k = data["k"]
|
| 311 |
+
return stride_0 == k and stride_1 == 1
|
| 312 |
+
|
| 313 |
+
mat1_is_contig_op = AHOperation(
|
| 314 |
+
"mat1_iscontig", mat1_is_contig_fn, is_categorical=True
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
def mat2_is_contig_fn(data: Any) -> bool:
|
| 318 |
+
stride_0 = data["mat2_stride_0"]
|
| 319 |
+
stride_1 = data["mat2_stride_1"]
|
| 320 |
+
n = data["n"]
|
| 321 |
+
return stride_0 == n and stride_1 == 1
|
| 322 |
+
|
| 323 |
+
mat2_is_contig_op = AHOperation(
|
| 324 |
+
"mat2_iscontig", mat2_is_contig_fn, is_categorical=True
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
return [mat1_is_contig_op, mat2_is_contig_op]
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def context_add_strides(context: AHContext, name: str, stride: Tuple[int, ...]) -> None:
|
| 331 |
+
for i, s in enumerate(stride):
|
| 332 |
+
context.add_feature(f"{name}_stride_{i}", s)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None:
|
| 336 |
+
using_tf32 = "not_float_32"
|
| 337 |
+
if dtype == torch.float32:
|
| 338 |
+
using_tf32 = torch.backends.cuda.matmul.allow_tf32
|
| 339 |
+
context.add_feature("using_tf32", using_tf32, is_categorical=True)
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import inspect
|
| 3 |
+
import pkgutil
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 8 |
+
AHContext,
|
| 9 |
+
AHMetadata,
|
| 10 |
+
Choice,
|
| 11 |
+
)
|
| 12 |
+
from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def find_and_instantiate_subclasses(
|
| 16 |
+
package_name: str, base_class: Any
|
| 17 |
+
) -> List[LearnedHeuristic]:
|
| 18 |
+
instances = []
|
| 19 |
+
|
| 20 |
+
package = importlib.import_module(package_name)
|
| 21 |
+
for _, module_name, _ in pkgutil.walk_packages(
|
| 22 |
+
package.__path__, package.__name__ + "."
|
| 23 |
+
):
|
| 24 |
+
try:
|
| 25 |
+
module_basename = module_name.split(".")[-1]
|
| 26 |
+
if not module_basename.startswith("_"):
|
| 27 |
+
# learned heuristics start with an underscore
|
| 28 |
+
continue
|
| 29 |
+
module = importlib.import_module(module_name)
|
| 30 |
+
|
| 31 |
+
# look for classes that are subclasses of base_class
|
| 32 |
+
for name, obj in inspect.getmembers(module):
|
| 33 |
+
if (
|
| 34 |
+
inspect.isclass(obj)
|
| 35 |
+
and issubclass(obj, base_class)
|
| 36 |
+
and obj != base_class
|
| 37 |
+
):
|
| 38 |
+
instance = obj()
|
| 39 |
+
instances.append(instance)
|
| 40 |
+
except Exception as e:
|
| 41 |
+
print(f"Error processing module {module_name}: {e}")
|
| 42 |
+
|
| 43 |
+
return instances
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class LearnedHeuristicController:
|
| 47 |
+
"""
|
| 48 |
+
Class that finds and instantiates all learned heuristics. It also provides
|
| 49 |
+
a way to get the decision of a learned heuristic.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
existing_heuristics: Dict[str, List[LearnedHeuristic]] = defaultdict(list)
|
| 53 |
+
"""
|
| 54 |
+
A dictionary that stores all the learned heuristics for each optimization.
|
| 55 |
+
The key is the optimization name, and the value is a list of LearnedHeuristic objects.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
heuristics_initialized: bool = False
|
| 59 |
+
"""
|
| 60 |
+
A flag that indicates whether the learned heuristics have been initialized.
|
| 61 |
+
Set to true when the get_decision() function is called for the first time.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
metadata: AHMetadata,
|
| 67 |
+
context: AHContext,
|
| 68 |
+
) -> None:
|
| 69 |
+
self.metadata = metadata
|
| 70 |
+
self.context = context
|
| 71 |
+
|
| 72 |
+
def get_heuristics(self, name: str) -> List[LearnedHeuristic]:
|
| 73 |
+
"""
|
| 74 |
+
Returns a list of learned heuristics for the given optimization name.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
if not LearnedHeuristicController.heuristics_initialized:
|
| 78 |
+
# learned heuristics are generated into the following package
|
| 79 |
+
learned_heuristics_package = "torch._inductor.autoheuristic.artifacts"
|
| 80 |
+
|
| 81 |
+
# learned heuristics have to be of type LearnedHeuristic
|
| 82 |
+
base_class = LearnedHeuristic
|
| 83 |
+
found_heuristics = find_and_instantiate_subclasses(
|
| 84 |
+
learned_heuristics_package, base_class
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
for learned_heuristic in found_heuristics:
|
| 88 |
+
opt_name = learned_heuristic.get_name()
|
| 89 |
+
LearnedHeuristicController.existing_heuristics[opt_name].append(
|
| 90 |
+
learned_heuristic
|
| 91 |
+
)
|
| 92 |
+
LearnedHeuristicController.heuristics_initialized = True
|
| 93 |
+
|
| 94 |
+
return LearnedHeuristicController.existing_heuristics[name]
|
| 95 |
+
|
| 96 |
+
def get_decision(self) -> Optional[Choice]:
|
| 97 |
+
"""
|
| 98 |
+
Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure
|
| 99 |
+
which choice to make.
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
heuristics = self.get_heuristics(self.metadata.name)
|
| 103 |
+
for heuristic in heuristics:
|
| 104 |
+
if heuristic.check_precondition(self.metadata, self.context):
|
| 105 |
+
return heuristic.get_decision(self.context, self.metadata.choices)
|
| 106 |
+
return None
|
| 107 |
+
|
| 108 |
+
def get_decisions_ranked(self, top_k: int) -> Optional[List[Choice]]:
|
| 109 |
+
heuristics = self.get_heuristics(self.metadata.name)
|
| 110 |
+
for heuristic in heuristics:
|
| 111 |
+
if heuristic.check_precondition(self.metadata, self.context):
|
| 112 |
+
choices = heuristic.get_decisions_ranked(self.context)
|
| 113 |
+
if choices is None:
|
| 114 |
+
return None
|
| 115 |
+
avail_choices = [
|
| 116 |
+
choice for choice in choices if choice in self.metadata.choices
|
| 117 |
+
]
|
| 118 |
+
return avail_choices[:top_k]
|
| 119 |
+
return None
|
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
| 4 |
+
AHContext,
|
| 5 |
+
AHMetadata,
|
| 6 |
+
Choice,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LearnedHeuristic:
|
| 11 |
+
"""
|
| 12 |
+
LearnedHeuristic is a base class for all learned heuristics.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self) -> None:
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
def check_precondition(
|
| 19 |
+
self,
|
| 20 |
+
metadata: AHMetadata,
|
| 21 |
+
context: AHContext,
|
| 22 |
+
) -> bool:
|
| 23 |
+
return True
|
| 24 |
+
|
| 25 |
+
def get_decision(
|
| 26 |
+
self, context: AHContext, choices: List[Choice]
|
| 27 |
+
) -> Optional[Choice]:
|
| 28 |
+
return None
|
| 29 |
+
|
| 30 |
+
def get_confidence_threshold(self) -> float:
|
| 31 |
+
return 1.0
|
| 32 |
+
|
| 33 |
+
def get_name(self) -> str:
|
| 34 |
+
return ""
|
| 35 |
+
|
| 36 |
+
def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class LearnedHeuristicRegression(LearnedHeuristic):
|
| 41 |
+
def __init__(self) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
|
| 44 |
+
def get_feedback(self, context: AHContext, choice: Choice) -> float:
|
| 45 |
+
return 1.0
|
| 46 |
+
|
| 47 |
+
def get_decision(
|
| 48 |
+
self, context: AHContext, choices: List[Choice]
|
| 49 |
+
) -> Optional[Choice]:
|
| 50 |
+
choice2feedback = {}
|
| 51 |
+
for choice in choices:
|
| 52 |
+
predicted_feedback = self.get_feedback(context, choice)
|
| 53 |
+
choice2feedback[choice] = predicted_feedback
|
| 54 |
+
sorted_choices_feedback = sorted(choice2feedback.items(), key=lambda t: t[1])
|
| 55 |
+
highest_feedback = sorted_choices_feedback[-1][1]
|
| 56 |
+
second_highest_feedback = sorted_choices_feedback[-2][1]
|
| 57 |
+
if highest_feedback / second_highest_feedback > self.get_confidence_threshold():
|
| 58 |
+
return sorted_choices_feedback[-1][0]
|
| 59 |
+
# We are not sure which choice is the best one
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class LearnedHeuristicDecision(LearnedHeuristic):
|
| 64 |
+
def __init__(self) -> None:
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
def get_choice(self, idx: int) -> Optional[str]:
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
+
def get_decision(
|
| 71 |
+
self, context: AHContext, choices: List[Choice]
|
| 72 |
+
) -> Optional[Choice]:
|
| 73 |
+
best_choices = self.get_best_choices(context)
|
| 74 |
+
if not best_choices:
|
| 75 |
+
return None
|
| 76 |
+
(best_choice_proba, best_choice_idx) = best_choices[0]
|
| 77 |
+
if best_choice_proba <= self.get_confidence_threshold():
|
| 78 |
+
return None
|
| 79 |
+
return self.get_choice(best_choice_idx)
|
| 80 |
+
|
| 81 |
+
def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
|
| 82 |
+
feedback_idx_list = self.get_best_choices(context)
|
| 83 |
+
if feedback_idx_list is None:
|
| 84 |
+
return None
|
| 85 |
+
choices = [
|
| 86 |
+
self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list
|
| 87 |
+
]
|
| 88 |
+
choices = [choice for choice in choices if choice is not None]
|
| 89 |
+
return choices
|
| 90 |
+
|
| 91 |
+
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
|
| 92 |
+
return []
|
.venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py
ADDED
|
@@ -0,0 +1,876 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import contextlib
|
| 5 |
+
import ctypes
|
| 6 |
+
import dataclasses
|
| 7 |
+
import functools
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import queue
|
| 11 |
+
import time
|
| 12 |
+
import warnings
|
| 13 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 14 |
+
from ctypes import byref, c_size_t, c_void_p, CDLL
|
| 15 |
+
from typing import (
|
| 16 |
+
Any,
|
| 17 |
+
Callable,
|
| 18 |
+
Dict,
|
| 19 |
+
Iterable,
|
| 20 |
+
List,
|
| 21 |
+
Optional,
|
| 22 |
+
Sequence,
|
| 23 |
+
TYPE_CHECKING,
|
| 24 |
+
Union,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
| 29 |
+
from torch import multiprocessing
|
| 30 |
+
from torch._dynamo.testing import rand_strided
|
| 31 |
+
from torch._inductor import ir
|
| 32 |
+
from torch._inductor.codecache import (
|
| 33 |
+
CppCodeCache,
|
| 34 |
+
CUDACodeCache,
|
| 35 |
+
DLLWrapper,
|
| 36 |
+
get_hash,
|
| 37 |
+
PyCodeCache,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if TYPE_CHECKING:
|
| 42 |
+
from multiprocessing.process import BaseProcess
|
| 43 |
+
from multiprocessing.queues import Queue
|
| 44 |
+
from types import ModuleType
|
| 45 |
+
|
| 46 |
+
from torch._inductor.select_algorithm import TritonTemplateCaller
|
| 47 |
+
|
| 48 |
+
from . import config
|
| 49 |
+
from .runtime.benchmarking import benchmarker
|
| 50 |
+
from .virtualized import V
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
|
| 54 |
+
EXIT_HANDLER_REGISTERED = False
|
| 55 |
+
|
| 56 |
+
log = logging.getLogger(__name__)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Used to synchronize between parent and child processes
class Ping:
    """Request sent by the parent; the child's workloop answers with a Pong."""
    pass
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Pong:
    """Reply put on the response queue by the child when it receives a Ping."""
    pass
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class NonzeroWorkspaceNotSupportedError(Exception):
    """NOTE(review): presumably raised by benchmark paths that cannot allocate
    a nonzero scratch workspace — the raise sites are outside this chunk;
    confirm before relying on this description."""
    pass
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
    """
    Temporarily restrict CUDA_VISIBLE_DEVICES to the given single device,
    restoring (or removing) the previous value on exit. If *device* is None,
    the environment is left untouched.
    """
    if device is None:
        yield
        return

    previous = os.environ.get(CUDA_VISIBLE_DEVICES)
    os.environ[CUDA_VISIBLE_DEVICES] = str(device)
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(CUDA_VISIBLE_DEVICES)
        else:
            os.environ[CUDA_VISIBLE_DEVICES] = previous
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@dataclasses.dataclass
class TuningProcess:
    """
    Abstraction for launching a helper process to benchmark kernels. Spawns
    the parent process and uses multiprocessing queues to send benchmark
    requests and return results.
    """

    # GPU index the child should use (None = don't restrict visibility)
    device: Optional[int] = None
    # Handle to the spawned child; None while uninitialized
    process: Optional[BaseProcess] = None
    # parent -> child work items (Ping, BenchmarkRequest, or None sentinel)
    request_queue: Optional[Queue[Any]] = None
    # child -> parent results (Pong or benchmark output)
    response_queue: Optional[Queue[Any]] = None

    @staticmethod
    def process_main(
        request_queue: Queue[Any],
        response_queue: Queue[Any],
    ) -> None:
        """
        Entry point for the child process.
        """
        log.debug(
            "Entering TuningProcess child. Visible devices = %s",
            os.environ.get(CUDA_VISIBLE_DEVICES),
        )
        try:
            TuningProcess.workloop(request_queue, response_queue)
        except Exception as ex:
            # Log rather than re-raise: the parent detects a dead/unresponsive
            # child separately via is_alive()/exitcode in get().
            log.exception("Exception in TuningProcess")

    @staticmethod
    def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
        """
        Work loop for the benchmarking subprocess: serve requests until the
        None sentinel arrives.
        """
        while True:
            obj = request_queue.get()

            if obj is None:
                break  # None is a sentinel for the child to terminate
            elif isinstance(obj, Ping):
                response_queue.put(Pong())
            elif isinstance(obj, BenchmarkRequest):
                response_queue.put(obj.benchmark())
            else:
                raise RuntimeError(f"Invalid request type {type(obj)}")

    def valid(self) -> bool:
        """
        True if the sub-process has been initialized.
        """
        return (
            self.process is not None
            and self.request_queue is not None
            and self.response_queue is not None
        )

    def clear(self) -> None:
        """
        Reset to an uninitialized state.
        """
        self.process = self.request_queue = self.response_queue = None

    def initialize(self) -> None:
        """
        Create child process, request/response queues, and do the warm up.
        Set the environment to make only the provided GPU device visible
        to the process.
        """
        if self.valid():
            return

        # cuda runtime does not work with "fork", use "spawn" to start processes.
        ctx = multiprocessing.get_context("spawn")
        self.request_queue = ctx.Queue()
        self.response_queue = ctx.Queue()

        self.process = ctx.Process(
            target=self.process_main,
            args=(
                self.request_queue,
                self.response_queue,
            ),
        )
        assert self.process is not None
        # The restricted CUDA_VISIBLE_DEVICES is inherited by the child at
        # start() and restored for the parent right after.
        with set_cuda_visible_device(self.device):
            self.process.start()

    def put(self, obj: Any) -> None:
        """
        Push a work item to the child process.
        """
        # In case of a prior crash, ensure the subprocess is running
        self.initialize()
        assert self.request_queue is not None
        self.request_queue.put(obj)

    def get(
        self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0
    ) -> Any:
        """
        Get a response from the child process. Raises queue.Empty on timeout
        or if the process dies.

        This method is (so far) only used by TuningProcessPool, where torch._inductor.config entries are being used
        to populate the timeouts:

        Arguments:

        @param result_timeout: Timeout in seconds, defaults to 120.0 or to
                               config.max_autotune_subproc_result_timeout_seconds when called by TuningProcessPool
        @param graceful_timeout: Timeout in seconds to allow graceful shutdown (SIGTERM is sent after this time).
                                 Defaults to 3.0 or to config.max_autotune_subproc_graceful_timeout_seconds
        @param terminate_timeout: Timeout in seconds after SIGTERM, until we send SIGKILL if the process
                                  remains alive. Defaults to 1.0 or to
                                  config.max_autotune_subproc_terminate_timeout_seconds.
        Returns:
            A response from the child process (Any type)
        """
        assert self.process is not None
        assert self.response_queue is not None
        while True:
            try:
                remaining_timeout = result_timeout
                res = None
                # Poll in 0.5s slices so a dead child is noticed promptly
                # instead of blocking for the whole result_timeout.
                while remaining_timeout is not None and remaining_timeout >= 1.0:
                    remaining_timeout -= 0.5
                    try:
                        res = self.response_queue.get(timeout=0.5)
                        break
                    except queue.Empty:
                        if not self.process.is_alive():
                            raise  # is being caught a few lines below
                if res is None:
                    # Consume whatever sub-second remainder is left.
                    res = self.response_queue.get(timeout=remaining_timeout)
                return res
            except queue.Empty:
                status = self.process.exitcode
                if status is None:
                    # Child is alive but unresponsive: force it down so a
                    # later put() starts a fresh subprocess.
                    self.kill(
                        graceful_timeout=graceful_timeout,
                        terminate_timeout=terminate_timeout,
                    )
                else:
                    # child process crashed
                    self.clear()
                raise

    def terminate(self) -> None:
        """
        Signal the child process to terminate.
        """
        if self.valid():
            assert self.process is not None
            assert self.request_queue is not None
            # None is the shutdown sentinel understood by workloop().
            self.request_queue.put(None)

    def wait(self) -> None:
        """
        Wait for the child process to exit.
        """
        if self.process is not None:
            self.process.join()
            self.clear()

    def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None:
        # Tries to kill the process, using a graceful_timeout in which the process
        # is allowed to exit gracefully. If the process is still alive,
        # it will be terminated. If that is not sufficient to end it
        # within terminate_timeout seconds, it will be killed.
        if self.process is not None:
            self.terminate()
            self.process.join(timeout=graceful_timeout)
            if self.process.is_alive():
                log.warning(
                    "Sending SIGTERM to process with PID %d",
                    self.process.pid,
                )
                self.process.terminate()
                self.process.join(timeout=terminate_timeout)
                if self.process.is_alive():
                    log.error(
                        "Sending SIGKILL to process with PID %d",
                        self.process.pid,
                    )
                    self.process.kill()  # This should definitely end the process
            self.clear()
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@dataclasses.dataclass
|
| 283 |
+
class TuningProcessPool:
|
| 284 |
+
"""
|
| 285 |
+
Maintains a pool of TuningProcesses to benchmark kernels in parallel
|
| 286 |
+
across devices. By default, we create one TuningProcess per device and
|
| 287 |
+
set the sub-process environment to make only that device visible.
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
processes: Optional[queue.Queue[TuningProcess]] = None
|
| 291 |
+
executor: Optional[ThreadPoolExecutor] = None
|
| 292 |
+
|
| 293 |
+
def initialize(self) -> None:
|
| 294 |
+
"""
|
| 295 |
+
Start the child processes.
|
| 296 |
+
"""
|
| 297 |
+
assert (self.processes is None) == (self.executor is None)
|
| 298 |
+
if self.processes is not None:
|
| 299 |
+
return
|
| 300 |
+
|
| 301 |
+
devices = self.get_device_list()
|
| 302 |
+
log.debug("Sub-process autotune device list: %s", devices)
|
| 303 |
+
|
| 304 |
+
# Launch the child processes and push a msg to "warm up"
|
| 305 |
+
self.processes = queue.Queue()
|
| 306 |
+
for device in devices:
|
| 307 |
+
p = TuningProcess(device=device)
|
| 308 |
+
p.initialize()
|
| 309 |
+
p.put(Ping())
|
| 310 |
+
self.processes.put(p)
|
| 311 |
+
|
| 312 |
+
# Wait for the initialization to finish
|
| 313 |
+
for p in self.processes.queue:
|
| 314 |
+
assert isinstance(p.get(result_timeout=None), Pong)
|
| 315 |
+
|
| 316 |
+
# Use a thread pool to manage distributing work to the subprocesses.
|
| 317 |
+
# Threads block on an available process, so it makes sense to match
|
| 318 |
+
# the number of threads with the number of devices.
|
| 319 |
+
self.executor = ThreadPoolExecutor(max_workers=len(devices))
|
| 320 |
+
|
| 321 |
+
# Register the exit handler for the parent process so it will terminate
|
| 322 |
+
# the child processes.
|
| 323 |
+
global EXIT_HANDLER_REGISTERED
|
| 324 |
+
if not EXIT_HANDLER_REGISTERED:
|
| 325 |
+
EXIT_HANDLER_REGISTERED = True
|
| 326 |
+
import atexit
|
| 327 |
+
|
| 328 |
+
atexit.register(self.terminate)
|
| 329 |
+
|
| 330 |
+
def get_device_list(self) -> Sequence[Optional[int]]:
|
| 331 |
+
"""
|
| 332 |
+
Gather the list of devices to be used in the pool.
|
| 333 |
+
"""
|
| 334 |
+
if not config.autotune_multi_device:
|
| 335 |
+
# Don't use multiple devices
|
| 336 |
+
return [None]
|
| 337 |
+
|
| 338 |
+
count = torch.cuda.device_count()
|
| 339 |
+
|
| 340 |
+
# If the user specified the visible devices in the env, use those.
|
| 341 |
+
if CUDA_VISIBLE_DEVICES in os.environ:
|
| 342 |
+
devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
|
| 343 |
+
assert len(devices) <= count
|
| 344 |
+
return devices
|
| 345 |
+
|
| 346 |
+
return list(range(count))
|
| 347 |
+
|
| 348 |
+
def terminate(self) -> None:
|
| 349 |
+
"""
|
| 350 |
+
Signal all child processes to terminate.
|
| 351 |
+
"""
|
| 352 |
+
if self.executor is not None:
|
| 353 |
+
self.executor.shutdown()
|
| 354 |
+
self.executor = None
|
| 355 |
+
|
| 356 |
+
if self.processes is not None:
|
| 357 |
+
for p in self.processes.queue:
|
| 358 |
+
p.terminate()
|
| 359 |
+
for p in self.processes.queue:
|
| 360 |
+
p.wait()
|
| 361 |
+
self.processes = None
|
| 362 |
+
|
| 363 |
+
def target(self, choice: TritonTemplateCaller) -> float:
|
| 364 |
+
"""
|
| 365 |
+
Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
|
| 366 |
+
remove it from the queue, execute the benchmark in that subprocess, and return
|
| 367 |
+
the TuningProcess to the queue.
|
| 368 |
+
"""
|
| 369 |
+
assert choice.bmreq is not None
|
| 370 |
+
assert self.processes is not None
|
| 371 |
+
|
| 372 |
+
process = self.processes.get()
|
| 373 |
+
process.put(choice.bmreq)
|
| 374 |
+
try:
|
| 375 |
+
return process.get(
|
| 376 |
+
config.max_autotune_subproc_result_timeout_seconds,
|
| 377 |
+
config.max_autotune_subproc_graceful_timeout_seconds,
|
| 378 |
+
config.max_autotune_subproc_terminate_timeout_seconds,
|
| 379 |
+
)
|
| 380 |
+
except queue.Empty:
|
| 381 |
+
warnings.warn(
|
| 382 |
+
f"Failed to benchmark choice '{choice}'. It will be ignored. "
|
| 383 |
+
"Please debug the root cause in case the choice can bring perf gains."
|
| 384 |
+
)
|
| 385 |
+
# set to INF so this choice will be ignored
|
| 386 |
+
return float("inf")
|
| 387 |
+
finally:
|
| 388 |
+
self.processes.put(process)
|
| 389 |
+
|
| 390 |
+
def benchmark(
|
| 391 |
+
self,
|
| 392 |
+
choices: List[TritonTemplateCaller],
|
| 393 |
+
) -> Dict[TritonTemplateCaller, float]:
|
| 394 |
+
"""
|
| 395 |
+
Benchmark each choice in a separate process.
|
| 396 |
+
"""
|
| 397 |
+
assert self.processes is not None, "Tuning process pool is not initialized"
|
| 398 |
+
assert self.executor is not None
|
| 399 |
+
|
| 400 |
+
results = {}
|
| 401 |
+
|
| 402 |
+
# Use a ThreadExecutorPool to spread the work across the subprocesses and
|
| 403 |
+
# to grab subprocesses as soon as they're free.
|
| 404 |
+
for choice, result in zip(choices, self.executor.map(self.target, choices)):
|
| 405 |
+
results[choice] = result
|
| 406 |
+
|
| 407 |
+
return results
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
tuning_pool = TuningProcessPool()
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
@dataclasses.dataclass
|
| 417 |
+
class TensorMeta:
|
| 418 |
+
device: torch.device
|
| 419 |
+
dtype: torch.dtype
|
| 420 |
+
sizes: torch._prims_common.ShapeType
|
| 421 |
+
strides: torch._prims_common.StrideType
|
| 422 |
+
offset: int
|
| 423 |
+
name: Optional[str] = None
|
| 424 |
+
|
| 425 |
+
@classmethod
|
| 426 |
+
def from_irnodes(
|
| 427 |
+
cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
|
| 428 |
+
) -> Union[TensorMeta, List[TensorMeta]]:
|
| 429 |
+
if isinstance(irnodes, Sequence):
|
| 430 |
+
result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
|
| 431 |
+
assert all(isinstance(x, TensorMeta) for x in result)
|
| 432 |
+
return result
|
| 433 |
+
|
| 434 |
+
node = irnodes
|
| 435 |
+
if isinstance(node, ir.Layout):
|
| 436 |
+
node = ir.Buffer("fake", node)
|
| 437 |
+
|
| 438 |
+
dtype = node.get_dtype()
|
| 439 |
+
assert dtype is not None
|
| 440 |
+
|
| 441 |
+
return TensorMeta(
|
| 442 |
+
device=node.get_device(),
|
| 443 |
+
dtype=dtype,
|
| 444 |
+
sizes=V.graph.sizevars.size_hints(
|
| 445 |
+
node.get_size(),
|
| 446 |
+
fallback=config.unbacked_symint_fallback,
|
| 447 |
+
),
|
| 448 |
+
strides=V.graph.sizevars.size_hints(
|
| 449 |
+
node.get_stride(),
|
| 450 |
+
fallback=config.unbacked_symint_fallback,
|
| 451 |
+
),
|
| 452 |
+
offset=V.graph.sizevars.size_hint(
|
| 453 |
+
node.get_layout().offset,
|
| 454 |
+
fallback=config.unbacked_symint_fallback,
|
| 455 |
+
),
|
| 456 |
+
name=node.get_name(),
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
def to_tensor(self) -> torch.Tensor:
|
| 460 |
+
return rand_strided(
|
| 461 |
+
self.sizes,
|
| 462 |
+
self.strides,
|
| 463 |
+
device=self.device,
|
| 464 |
+
dtype=self.dtype,
|
| 465 |
+
extra_size=self.offset,
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
@dataclasses.dataclass
|
| 470 |
+
class BenchmarkRequest:
|
| 471 |
+
"""
|
| 472 |
+
Only handle triton template benchmark for now. The extern kernel benchmark
|
| 473 |
+
can be done inside the same process since they usually don't cause crash.
|
| 474 |
+
|
| 475 |
+
Important: Instances of this class and subclasses have to be serializable
|
| 476 |
+
across process boundaries. Do not put CUDA Tensors in here!
|
| 477 |
+
"""
|
| 478 |
+
|
| 479 |
+
def __init__(
|
| 480 |
+
self,
|
| 481 |
+
kernel_name: str,
|
| 482 |
+
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
| 483 |
+
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
| 484 |
+
extra_args: Iterable[Any],
|
| 485 |
+
) -> None:
|
| 486 |
+
# the kernel name defined in the module
|
| 487 |
+
self.kernel_name = kernel_name
|
| 488 |
+
|
| 489 |
+
if isinstance(input_tensor_meta, TensorMeta):
|
| 490 |
+
input_tensor_meta = [input_tensor_meta]
|
| 491 |
+
self.input_tensor_meta = input_tensor_meta
|
| 492 |
+
|
| 493 |
+
if isinstance(output_tensor_meta, (tuple, list)):
|
| 494 |
+
assert len(output_tensor_meta) == 1
|
| 495 |
+
output_tensor_meta = output_tensor_meta[0]
|
| 496 |
+
self.output_tensor_meta = output_tensor_meta
|
| 497 |
+
|
| 498 |
+
self.extra_args = extra_args
|
| 499 |
+
|
| 500 |
+
def make_run_fn(
|
| 501 |
+
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
|
| 502 |
+
) -> Callable[[], None]:
|
| 503 |
+
raise NotImplementedError
|
| 504 |
+
|
| 505 |
+
def cleanup_run_fn(self) -> None:
|
| 506 |
+
pass
|
| 507 |
+
|
| 508 |
+
def do_bench(
|
| 509 |
+
self,
|
| 510 |
+
fn,
|
| 511 |
+
*input_tensors: torch.Tensor,
|
| 512 |
+
output_tensor: Optional[torch.Tensor] = None,
|
| 513 |
+
) -> float:
|
| 514 |
+
raise NotImplementedError
|
| 515 |
+
|
| 516 |
+
def benchmark(
|
| 517 |
+
self,
|
| 518 |
+
*input_tensors: torch.Tensor,
|
| 519 |
+
output_tensor: Optional[torch.Tensor] = None,
|
| 520 |
+
) -> float:
|
| 521 |
+
debug = log.isEnabledFor(logging.DEBUG)
|
| 522 |
+
if debug:
|
| 523 |
+
start_ts = time.time()
|
| 524 |
+
|
| 525 |
+
# create args and out tensor
|
| 526 |
+
if output_tensor is None:
|
| 527 |
+
assert len(input_tensors) == 0
|
| 528 |
+
input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
|
| 529 |
+
output_tensor = self.output_tensor_meta.to_tensor()
|
| 530 |
+
|
| 531 |
+
if debug:
|
| 532 |
+
create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
|
| 533 |
+
start_ts = time.time()
|
| 534 |
+
try:
|
| 535 |
+
fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
|
| 536 |
+
except NonzeroWorkspaceNotSupportedError:
|
| 537 |
+
# Skipping all ops with nonzero workspace requirements
|
| 538 |
+
log.info("Skipping op due to nonzero workspace requirement")
|
| 539 |
+
return float("inf")
|
| 540 |
+
|
| 541 |
+
if debug:
|
| 542 |
+
load_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
|
| 543 |
+
start_ts = time.time()
|
| 544 |
+
|
| 545 |
+
out = self.do_bench(fn, *input_tensors, output_tensor)
|
| 546 |
+
|
| 547 |
+
if debug:
|
| 548 |
+
bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
|
| 549 |
+
log.debug(
|
| 550 |
+
"InChildProcess %s: load %f, create tensor %f, bench %f",
|
| 551 |
+
str(self),
|
| 552 |
+
load_elapse, # type: ignore[possibly-undefined]
|
| 553 |
+
create_tensor_elapse, # type: ignore[possibly-undefined]
|
| 554 |
+
bench_elapse,
|
| 555 |
+
)
|
| 556 |
+
self.cleanup_run_fn()
|
| 557 |
+
return out
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class TestBenchmarkRequest(BenchmarkRequest):
|
| 561 |
+
"""
|
| 562 |
+
Supports unit testing. Defined in this file so that the TuningProcess
|
| 563 |
+
sub-process knows how to unpickle these objects.
|
| 564 |
+
"""
|
| 565 |
+
|
| 566 |
+
def __init__(self, value: Optional[float] = None) -> None:
|
| 567 |
+
self.value = value
|
| 568 |
+
|
| 569 |
+
def benchmark(
|
| 570 |
+
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
|
| 571 |
+
) -> float:
|
| 572 |
+
if self.value is None:
|
| 573 |
+
raise Exception("Failed to run") # noqa: TRY002
|
| 574 |
+
return self.value
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
class GPUDeviceBenchmarkRequest(BenchmarkRequest):
|
| 578 |
+
def do_bench(
|
| 579 |
+
self,
|
| 580 |
+
fn,
|
| 581 |
+
*input_tensors: torch.Tensor,
|
| 582 |
+
output_tensor: Optional[torch.Tensor] = None,
|
| 583 |
+
) -> float:
|
| 584 |
+
device_idx_set = {
|
| 585 |
+
tensor.device.index
|
| 586 |
+
for tensor in [*input_tensors, output_tensor]
|
| 587 |
+
if isinstance(tensor, torch.Tensor)
|
| 588 |
+
and tensor.is_cuda
|
| 589 |
+
and tensor.device.index is not None
|
| 590 |
+
}
|
| 591 |
+
assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}"
|
| 592 |
+
if len(device_idx_set) == 1:
|
| 593 |
+
device_idx = next(iter(device_idx_set))
|
| 594 |
+
else:
|
| 595 |
+
device_idx = torch.cuda.current_device()
|
| 596 |
+
|
| 597 |
+
with torch.cuda.device(device_idx):
|
| 598 |
+
out = benchmarker.benchmark_gpu(fn)
|
| 599 |
+
torch.cuda.synchronize() # shake out any CUDA errors
|
| 600 |
+
|
| 601 |
+
return out
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest):
|
| 605 |
+
# Important: Instances of this class have to be serializable
|
| 606 |
+
# across process boundaries. Do not put CUDA Tensors in here!
|
| 607 |
+
def __init__(
|
| 608 |
+
self,
|
| 609 |
+
kernel_name: str,
|
| 610 |
+
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
| 611 |
+
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
| 612 |
+
extra_args: Iterable[Any],
|
| 613 |
+
module_path: str, # the path of the module defining the triton kernel
|
| 614 |
+
module_cache_key: str,
|
| 615 |
+
grid: List[int],
|
| 616 |
+
num_stages: int,
|
| 617 |
+
num_warps: int,
|
| 618 |
+
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
|
| 619 |
+
) -> None:
|
| 620 |
+
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
|
| 621 |
+
self.module_path = module_path
|
| 622 |
+
self.module_cache_key = module_cache_key
|
| 623 |
+
self.grid = grid
|
| 624 |
+
self.num_stages = num_stages
|
| 625 |
+
self.num_warps = num_warps
|
| 626 |
+
self.matrix_instr_nonkdim = matrix_instr_nonkdim
|
| 627 |
+
|
| 628 |
+
def make_run_fn(
|
| 629 |
+
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
|
| 630 |
+
) -> Callable[[], None]:
|
| 631 |
+
mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
|
| 632 |
+
log.debug(
|
| 633 |
+
"benchmark module key: %s, path: %s",
|
| 634 |
+
self.module_cache_key,
|
| 635 |
+
self.module_path,
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
run_method = getattr(mod, self.kernel_name).run
|
| 639 |
+
extra_args = list(self.extra_args)
|
| 640 |
+
|
| 641 |
+
# Newer version of triton add warmup argument to JITFunction.run.
|
| 642 |
+
# This code handles backward-compatibility.
|
| 643 |
+
warmup_arg = {}
|
| 644 |
+
import inspect
|
| 645 |
+
|
| 646 |
+
if "warmup" in inspect.signature(run_method).parameters:
|
| 647 |
+
warmup_arg["warmup"] = False
|
| 648 |
+
|
| 649 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 650 |
+
|
| 651 |
+
if torch.version.hip and self.matrix_instr_nonkdim != 0:
|
| 652 |
+
return functools.partial(
|
| 653 |
+
run_method,
|
| 654 |
+
*input_tensors,
|
| 655 |
+
output_tensor,
|
| 656 |
+
*self.extra_args,
|
| 657 |
+
grid=self.grid,
|
| 658 |
+
**warmup_arg,
|
| 659 |
+
stream=get_raw_stream(self.output_tensor_meta.device.index),
|
| 660 |
+
)
|
| 661 |
+
else:
|
| 662 |
+
return functools.partial(
|
| 663 |
+
run_method,
|
| 664 |
+
*input_tensors,
|
| 665 |
+
output_tensor,
|
| 666 |
+
*self.extra_args,
|
| 667 |
+
grid=self.grid,
|
| 668 |
+
**warmup_arg,
|
| 669 |
+
stream=get_raw_stream(self.output_tensor_meta.device.index),
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
def precompile(self):
|
| 673 |
+
mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
|
| 674 |
+
getattr(mod, self.kernel_name).precompile()
|
| 675 |
+
|
| 676 |
+
def __str__(self) -> str:
|
| 677 |
+
return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest):
|
| 681 |
+
# Important: Instances of this class have to be serializable
|
| 682 |
+
# across process boundaries. Do not put CUDA Tensors in here!
|
| 683 |
+
|
| 684 |
+
def __init__(
|
| 685 |
+
self,
|
| 686 |
+
kernel_name: str,
|
| 687 |
+
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
| 688 |
+
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
| 689 |
+
extra_args: Iterable[Any],
|
| 690 |
+
source_code: str,
|
| 691 |
+
) -> None:
|
| 692 |
+
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
|
| 693 |
+
self.source_code = source_code
|
| 694 |
+
self.workspace_size: int = 0
|
| 695 |
+
self.workspace: Optional[torch.Tensor] = None
|
| 696 |
+
self.DLL: Optional[DLLWrapper] = None
|
| 697 |
+
self._workspace_size_updated = False
|
| 698 |
+
self.hash_key: str = ""
|
| 699 |
+
self.source_file: str = ""
|
| 700 |
+
self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")
|
| 701 |
+
|
| 702 |
+
def precompile(self):
|
| 703 |
+
# Prepopulate CUDACodeCache
|
| 704 |
+
# may happen in separate Threadpool
|
| 705 |
+
log.debug("Precompiling %s", self)
|
| 706 |
+
CUDACodeCache.compile(self.source_code, "so")
|
| 707 |
+
log.debug("Done precompiling %s", self)
|
| 708 |
+
|
| 709 |
+
def make_run_fn(
|
| 710 |
+
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
|
| 711 |
+
) -> Callable[[], None]:
|
| 712 |
+
self.ensure_dll_loaded()
|
| 713 |
+
self.update_workspace_size()
|
| 714 |
+
args = [
|
| 715 |
+
c_void_p(tensor.data_ptr())
|
| 716 |
+
for tensor in list(input_tensors) + [output_tensor]
|
| 717 |
+
]
|
| 718 |
+
log.debug(
|
| 719 |
+
"make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
|
| 720 |
+
self.kernel_name,
|
| 721 |
+
self.source_file,
|
| 722 |
+
self.hash_key,
|
| 723 |
+
self.DLL,
|
| 724 |
+
args,
|
| 725 |
+
self.extra_args,
|
| 726 |
+
)
|
| 727 |
+
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
|
| 728 |
+
run_method = getattr(self.DLL, self.kernel_name)
|
| 729 |
+
workspace_ptr = c_void_p(0)
|
| 730 |
+
if self.workspace_size > 0:
|
| 731 |
+
self.workspace = torch.zeros(
|
| 732 |
+
(self.workspace_size + 7) // 8,
|
| 733 |
+
dtype=torch.float64,
|
| 734 |
+
device=output_tensor.device,
|
| 735 |
+
)
|
| 736 |
+
workspace_ptr = c_void_p(self.workspace.data_ptr())
|
| 737 |
+
|
| 738 |
+
# Generate partial function.
|
| 739 |
+
return functools.partial(
|
| 740 |
+
run_method,
|
| 741 |
+
*args,
|
| 742 |
+
*self.extra_args,
|
| 743 |
+
None, # null workspace size ptr
|
| 744 |
+
workspace_ptr, # set workspace ptr,
|
| 745 |
+
stream_ptr,
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
def update_workspace_size(self) -> None:
|
| 749 |
+
if self._workspace_size_updated:
|
| 750 |
+
return
|
| 751 |
+
self.ensure_dll_loaded()
|
| 752 |
+
unique_input_count = len({meta.name for meta in self.input_tensor_meta})
|
| 753 |
+
args = [c_void_p(None) for _ in range(unique_input_count + 1)]
|
| 754 |
+
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
|
| 755 |
+
|
| 756 |
+
run_method = getattr(self.DLL, self.kernel_name)
|
| 757 |
+
# Retrieve workspace_size and initialize workspace.
|
| 758 |
+
c_workspace_size = c_size_t()
|
| 759 |
+
run_method(
|
| 760 |
+
*args, # input ptrs and output ptrs
|
| 761 |
+
*self.extra_args,
|
| 762 |
+
byref(
|
| 763 |
+
c_workspace_size
|
| 764 |
+
), # set workspace size ptr to retrieve workspace size
|
| 765 |
+
None, # null workspace ptr
|
| 766 |
+
stream_ptr,
|
| 767 |
+
)
|
| 768 |
+
torch.cuda.synchronize() # shake out any CUDA errors
|
| 769 |
+
self.workspace_size = c_workspace_size.value
|
| 770 |
+
log.debug(
|
| 771 |
+
"update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950
|
| 772 |
+
self.workspace_size,
|
| 773 |
+
self.kernel_name,
|
| 774 |
+
self.source_file,
|
| 775 |
+
self.hash_key,
|
| 776 |
+
self.DLL,
|
| 777 |
+
args,
|
| 778 |
+
self.extra_args,
|
| 779 |
+
)
|
| 780 |
+
self._workspace_size_updated = True
|
| 781 |
+
|
| 782 |
+
def ensure_dll_loaded(self):
|
| 783 |
+
if self.DLL is None:
|
| 784 |
+
self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
|
| 785 |
+
self.source_code, "so"
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
def cleanup_run_fn(self) -> None:
|
| 789 |
+
if self.DLL is not None:
|
| 790 |
+
self.DLL.close()
|
| 791 |
+
self.workspace = None
|
| 792 |
+
|
| 793 |
+
def __str__(self) -> str:
|
| 794 |
+
return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
class CPUDeviceBenchmarkRequest(BenchmarkRequest):
|
| 798 |
+
def do_bench(
|
| 799 |
+
self,
|
| 800 |
+
fn,
|
| 801 |
+
*input_tensors: torch.Tensor,
|
| 802 |
+
output_tensor: Optional[torch.Tensor] = None,
|
| 803 |
+
) -> float:
|
| 804 |
+
return benchmarker.benchmark_cpu(fn)
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
|
| 808 |
+
# Important: Instances of this class have to be serializable
|
| 809 |
+
# across process boundaries. Do not put Tensors in here!
|
| 810 |
+
|
| 811 |
+
def __init__(
|
| 812 |
+
self,
|
| 813 |
+
kernel_name: str,
|
| 814 |
+
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
| 815 |
+
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
| 816 |
+
extra_args: Iterable[Any],
|
| 817 |
+
source_code: str,
|
| 818 |
+
) -> None:
|
| 819 |
+
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
|
| 820 |
+
self.source_code = source_code
|
| 821 |
+
self.hash_key = get_hash(source_code)
|
| 822 |
+
self.DLL: Optional[Union[CDLL, ModuleType]] = None
|
| 823 |
+
|
| 824 |
+
def precompile(self):
|
| 825 |
+
# Prepopulate CppCodeCache
|
| 826 |
+
# may happen in separate Threadpool
|
| 827 |
+
log.debug("Precompiling %s", self)
|
| 828 |
+
CppCodeCache.load(self.source_code, cuda=False)
|
| 829 |
+
log.debug("Done precompiling %s", self)
|
| 830 |
+
|
| 831 |
+
def make_run_fn(
|
| 832 |
+
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
|
| 833 |
+
) -> Callable[[], None]:
|
| 834 |
+
# TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf
|
| 835 |
+
self.DLL = CppCodeCache.load(self.source_code, cuda=False)
|
| 836 |
+
args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]]
|
| 837 |
+
log.debug(
|
| 838 |
+
"make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
|
| 839 |
+
self.kernel_name,
|
| 840 |
+
self.DLL,
|
| 841 |
+
args,
|
| 842 |
+
self.extra_args,
|
| 843 |
+
)
|
| 844 |
+
run_method = getattr(self.DLL, self.kernel_name)
|
| 845 |
+
# Assume only size with type ctypes.c_ulonglong in extra_args
|
| 846 |
+
assert all(isinstance(arg, ctypes.c_ulonglong) for arg in self.extra_args)
|
| 847 |
+
run_method.argtypes = [ctypes.c_ulonglong] * (
|
| 848 |
+
len(args) + len(list(self.extra_args))
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
# Generate partial function.
|
| 852 |
+
return functools.partial(
|
| 853 |
+
run_method,
|
| 854 |
+
*args,
|
| 855 |
+
*self.extra_args,
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
def cleanup_run_fn(self) -> None:
|
| 859 |
+
if self.DLL is not None:
|
| 860 |
+
"""
|
| 861 |
+
Check close attr due to it crash on Windows.
|
| 862 |
+
"""
|
| 863 |
+
if hasattr(self.DLL, "close"):
|
| 864 |
+
self.DLL.close()
|
| 865 |
+
|
| 866 |
+
def __str__(self) -> str:
|
| 867 |
+
return f"{self.kernel_name=}"
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
def benchmark_in_sub_process(
|
| 871 |
+
choices: List[TritonTemplateCaller],
|
| 872 |
+
) -> Dict[TritonTemplateCaller, float]:
|
| 873 |
+
"""
|
| 874 |
+
Do benchmarking in a subprocess and return the perf number (latency).
|
| 875 |
+
"""
|
| 876 |
+
return tuning_pool.benchmark(choices)
|
.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import math
|
| 3 |
+
from enum import IntEnum
|
| 4 |
+
|
| 5 |
+
import sympy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from . import ir
|
| 10 |
+
from .utils import get_dtype_size, sympy_product
|
| 11 |
+
from .virtualized import V
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class NCCL_COLL(IntEnum):
|
| 15 |
+
ALL_REDUCE = 0
|
| 16 |
+
ALL_GATHER = 1
|
| 17 |
+
REDUCE_SCATTER = 2
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class NVIDIA_GPU_TYPE(IntEnum):
|
| 21 |
+
VOLTA = 0
|
| 22 |
+
AMPERE = 1
|
| 23 |
+
HOPPER = 2
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@functools.lru_cache
|
| 27 |
+
def get_gpu_type() -> NVIDIA_GPU_TYPE:
|
| 28 |
+
gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
|
| 29 |
+
if "V100" in gpu_info:
|
| 30 |
+
return NVIDIA_GPU_TYPE.VOLTA
|
| 31 |
+
elif "A100" in gpu_info:
|
| 32 |
+
return NVIDIA_GPU_TYPE.AMPERE
|
| 33 |
+
elif "H100" in gpu_info:
|
| 34 |
+
return NVIDIA_GPU_TYPE.HOPPER
|
| 35 |
+
else:
|
| 36 |
+
# for other gpu types, assume Ampere
|
| 37 |
+
return NVIDIA_GPU_TYPE.AMPERE
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
|
| 41 |
+
if not isinstance(node, ir._CollectiveKernel):
|
| 42 |
+
raise ValueError(f"node is not a collective kernel: {node}")
|
| 43 |
+
|
| 44 |
+
kernel_name = node.python_kernel_name
|
| 45 |
+
assert kernel_name is not None
|
| 46 |
+
if "all_reduce" in kernel_name:
|
| 47 |
+
return NCCL_COLL.ALL_REDUCE
|
| 48 |
+
elif "all_gather" in kernel_name:
|
| 49 |
+
return NCCL_COLL.ALL_GATHER
|
| 50 |
+
elif "reduce_scatter" in kernel_name:
|
| 51 |
+
return NCCL_COLL.REDUCE_SCATTER
|
| 52 |
+
else:
|
| 53 |
+
raise ValueError(f"Unsupported collective kernel: {kernel_name}")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_collective_input_size_bytes(node: ir.IRNode) -> int:
|
| 57 |
+
sz_bytes = 0
|
| 58 |
+
for inp in node.inputs: # type: ignore[attr-defined]
|
| 59 |
+
numel = sympy_product(inp.layout.size)
|
| 60 |
+
if isinstance(numel, sympy.Integer):
|
| 61 |
+
# For ease of testing
|
| 62 |
+
numel = int(numel)
|
| 63 |
+
else:
|
| 64 |
+
numel = V.graph.sizevars.size_hint(numel, fallback=0)
|
| 65 |
+
sz_bytes += numel * get_dtype_size(inp.layout.dtype)
|
| 66 |
+
return sz_bytes
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_collective_group_size(node: ir.IRNode) -> int:
|
| 70 |
+
if type(node) == ir._CollectiveKernel:
|
| 71 |
+
from torch.distributed.distributed_c10d import _get_group_size_by_name
|
| 72 |
+
|
| 73 |
+
return _get_group_size_by_name(node.constant_args[-1])
|
| 74 |
+
else:
|
| 75 |
+
raise TypeError(f"Unsupported collective type: {node}")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
####################################################################################################################
|
| 79 |
+
# The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
|
| 80 |
+
####################################################################################################################
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class NCCL_HW(IntEnum):
|
| 84 |
+
NVLINK = 0
|
| 85 |
+
PCI = 1
|
| 86 |
+
NET = 2
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class NCCL_ALGO(IntEnum):
|
| 90 |
+
TREE = 0
|
| 91 |
+
RING = 1
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class NCCL_PROTO(IntEnum):
|
| 95 |
+
# The ordering and enum values here matches original in
|
| 96 |
+
# https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28
|
| 97 |
+
# For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990
|
| 98 |
+
LL = 0 # Low-latency
|
| 99 |
+
# LL128 = 1 # Low-latency 128-byte
|
| 100 |
+
# SIMPLE = 2
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# Latencies in us
|
| 104 |
+
# len(NCCL_ALGO) x len(NCCL_PROTO)
|
| 105 |
+
# NOTE: use array instead of tensor to prevent incompatibility with fake mode
|
| 106 |
+
baseLat = [
|
| 107 |
+
# Tree
|
| 108 |
+
[
|
| 109 |
+
6.8, # LL
|
| 110 |
+
],
|
| 111 |
+
# Ring
|
| 112 |
+
[
|
| 113 |
+
6.6, # LL
|
| 114 |
+
],
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
# Latencies in us
|
| 118 |
+
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
|
| 119 |
+
hwLat = [
|
| 120 |
+
# NVLINK
|
| 121 |
+
[
|
| 122 |
+
[0.6], # Tree (LL)
|
| 123 |
+
[0.6], # Ring (LL)
|
| 124 |
+
],
|
| 125 |
+
# PCI
|
| 126 |
+
[
|
| 127 |
+
[1.0], # Tree (LL)
|
| 128 |
+
[1.0], # Ring (LL)
|
| 129 |
+
],
|
| 130 |
+
# NET
|
| 131 |
+
[
|
| 132 |
+
[5.0], # Tree (LL)
|
| 133 |
+
[2.7], # Ring (LL)
|
| 134 |
+
],
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# LL128 max BW per channel
|
| 139 |
+
llMaxBws = [
|
| 140 |
+
# Volta-N1/Intel-N2/Intel-N4
|
| 141 |
+
[
|
| 142 |
+
39.0,
|
| 143 |
+
39.0,
|
| 144 |
+
20.4,
|
| 145 |
+
],
|
| 146 |
+
# Ampere-N1/AMD-N2/AMD-N4
|
| 147 |
+
[
|
| 148 |
+
87.7,
|
| 149 |
+
22.5, # avg of ring & tree
|
| 150 |
+
19.0,
|
| 151 |
+
],
|
| 152 |
+
# Hopper-N1/AMD-N2/AMD-N4
|
| 153 |
+
[
|
| 154 |
+
87.7,
|
| 155 |
+
22.5, # avg of ring & tree
|
| 156 |
+
19.0,
|
| 157 |
+
],
|
| 158 |
+
]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
|
| 162 |
+
"""
|
| 163 |
+
Returns estimated NCCL collective runtime in nanoseconds (ns).
|
| 164 |
+
|
| 165 |
+
The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
|
| 166 |
+
We aim to estimate the runtime as accurately as possible.
|
| 167 |
+
|
| 168 |
+
Assumptions:
|
| 169 |
+
- only ring algorithm (NCCL_ALGO_RING) is used
|
| 170 |
+
- only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
|
| 171 |
+
- 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
|
| 172 |
+
- collective is one of: allreduce, reducescatter, allgather
|
| 173 |
+
"""
|
| 174 |
+
tensor_storage_size_bytes = get_collective_input_size_bytes(node)
|
| 175 |
+
# Convert bytes to GB
|
| 176 |
+
tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024
|
| 177 |
+
|
| 178 |
+
# Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
|
| 179 |
+
# TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
|
| 180 |
+
num_gpus_per_node = 8
|
| 181 |
+
group_size = get_collective_group_size(node)
|
| 182 |
+
nNodes = math.ceil(group_size / num_gpus_per_node)
|
| 183 |
+
nRanks = group_size # this is total # of gpus globally that participate in this collective op
|
| 184 |
+
|
| 185 |
+
if nRanks <= 1:
|
| 186 |
+
return 0
|
| 187 |
+
|
| 188 |
+
# Assumes ring algorithm
|
| 189 |
+
nccl_algo = NCCL_ALGO.RING
|
| 190 |
+
nccl_proto = NCCL_PROTO.LL
|
| 191 |
+
coll = get_collective_type(node)
|
| 192 |
+
|
| 193 |
+
# =============== bandwidth computation ===============
|
| 194 |
+
# First compute bandwidth in GB/s; then at the end, convert it to GB/ns
|
| 195 |
+
|
| 196 |
+
bwIntra = torch._inductor.config.intra_node_bw
|
| 197 |
+
bwInter = torch._inductor.config.inter_node_bw
|
| 198 |
+
|
| 199 |
+
compCapIndex = get_gpu_type()
|
| 200 |
+
index2 = nNodes - 1 if nNodes <= 2 else 2
|
| 201 |
+
# LL: for single node, we look at GPU type; for multi-node, we look at CPU type
|
| 202 |
+
index1 = compCapIndex if nNodes == 1 else 0
|
| 203 |
+
llMaxBw = llMaxBws[index1][index2]
|
| 204 |
+
|
| 205 |
+
# NOTE: each step of ring algorithm is synchronized,
|
| 206 |
+
# and is bottlenecked by the slowest link which is the inter-node interconnect.
|
| 207 |
+
# hence when nNodes >= 2, bw is inter-node bandwidth.
|
| 208 |
+
# NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
|
| 209 |
+
# have this as `if nNodes <= 2` which seems wrong. Corrected it here.
|
| 210 |
+
bw = bwIntra if nNodes == 1 else bwInter
|
| 211 |
+
nChannels = 2 # Assume # channels is 2
|
| 212 |
+
busBw = nChannels * bw
|
| 213 |
+
|
| 214 |
+
# Various model refinements
|
| 215 |
+
busBw = min(
|
| 216 |
+
llMaxBw,
|
| 217 |
+
busBw
|
| 218 |
+
* (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0),
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
if coll == NCCL_COLL.ALL_REDUCE:
|
| 222 |
+
nsteps = 2 * (nRanks - 1)
|
| 223 |
+
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
|
| 224 |
+
nsteps = nRanks - 1
|
| 225 |
+
|
| 226 |
+
# Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
|
| 227 |
+
ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined]
|
| 228 |
+
bandwidth = busBw * ratio
|
| 229 |
+
# Convert GB/s to GB/ns
|
| 230 |
+
bandwidth_GB_per_ns = bandwidth / 1e9
|
| 231 |
+
|
| 232 |
+
# =============== latency computation ===============
|
| 233 |
+
intraHw = NCCL_HW.NVLINK
|
| 234 |
+
|
| 235 |
+
if coll == NCCL_COLL.ALL_REDUCE:
|
| 236 |
+
if nNodes > 1:
|
| 237 |
+
nInterSteps = 2 * nNodes
|
| 238 |
+
else:
|
| 239 |
+
nInterSteps = 0
|
| 240 |
+
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
|
| 241 |
+
nInterSteps = nNodes - 1
|
| 242 |
+
|
| 243 |
+
# First compute latency in us; then at the end, convert it to ns
|
| 244 |
+
latency = baseLat[nccl_algo][nccl_proto]
|
| 245 |
+
intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
|
| 246 |
+
interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]
|
| 247 |
+
|
| 248 |
+
# Inter-node rings still have to launch nsteps * net overhead.
|
| 249 |
+
netOverhead = 0.0
|
| 250 |
+
if nNodes > 1:
|
| 251 |
+
netOverhead = 1.0 # getNetOverhead(comm);
|
| 252 |
+
intraLat = max(intraLat, netOverhead)
|
| 253 |
+
latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined]
|
| 254 |
+
# Convert us to ns
|
| 255 |
+
latency_ns = latency * 1e3
|
| 256 |
+
|
| 257 |
+
# =============== final result ===============
|
| 258 |
+
transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
|
| 259 |
+
return transport_ns + latency_ns
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
################################################################################################################
|
| 263 |
+
# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
|
| 264 |
+
################################################################################################################
|
.venv/lib/python3.11/site-packages/torch/_inductor/comms.py
ADDED
|
@@ -0,0 +1,640 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# pyre-strict
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import heapq
|
| 6 |
+
import operator
|
| 7 |
+
import sys
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from typing import Dict, List, Set, TYPE_CHECKING
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from . import config, ir
|
| 14 |
+
from .dependencies import WeakDep
|
| 15 |
+
from .utils import (
|
| 16 |
+
contains_collective,
|
| 17 |
+
contains_wait,
|
| 18 |
+
find_recursive_deps_of_node,
|
| 19 |
+
find_recursive_users_of_node,
|
| 20 |
+
is_collective,
|
| 21 |
+
is_fallback_op,
|
| 22 |
+
is_wait,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from .scheduler import BaseSchedulerNode
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
|
| 33 |
+
"""
|
| 34 |
+
Greedily schedules waits as late as possible.
|
| 35 |
+
"""
|
| 36 |
+
return _schedule_for_comm(
|
| 37 |
+
snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
|
| 42 |
+
"""
|
| 43 |
+
Greedily schedules comms as early as possible.
|
| 44 |
+
"""
|
| 45 |
+
return _schedule_for_comm(
|
| 46 |
+
snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def reorder_compute_for_overlap(
|
| 51 |
+
snodes: List[BaseSchedulerNode],
|
| 52 |
+
) -> List[BaseSchedulerNode]:
|
| 53 |
+
"""
|
| 54 |
+
This achieves the following overall scheduling procedure:
|
| 55 |
+
Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
|
| 56 |
+
that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N.
|
| 57 |
+
Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
|
| 58 |
+
Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
|
| 59 |
+
We prioritize compute nodes that are needed sooner.
|
| 60 |
+
Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
|
| 61 |
+
Step 4: We schedule comm N + 1.
|
| 62 |
+
Repeat this for subsequent comm nodes.
|
| 63 |
+
"""
|
| 64 |
+
return _schedule_for_comm(
|
| 65 |
+
snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _schedule_for_comm(
|
| 70 |
+
snodes: List[BaseSchedulerNode],
|
| 71 |
+
raise_comms: bool,
|
| 72 |
+
sink_waits: bool,
|
| 73 |
+
reorder_for_overlap: bool,
|
| 74 |
+
) -> List[BaseSchedulerNode]:
|
| 75 |
+
"""
|
| 76 |
+
Schedule `snodes` for various comm optimization objectives.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
snodes: the nodes to be scheduled.
|
| 80 |
+
raise_comms: whether to greedily schedule collectives as early as possible
|
| 81 |
+
sink_wait: whether to greedily schedule waits as late as possible
|
| 82 |
+
reorder_compute_for_overlap: whether to reorder compute nodes to
|
| 83 |
+
optimize for compute/communication overlapping.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
The new schedule order.
|
| 87 |
+
|
| 88 |
+
Some notes on the synergy between different options:
|
| 89 |
+
- `raise_comms` provides more overlapping oppurtunies for `reorder_compute_for_overlap`.
|
| 90 |
+
- When both `raise_comms` and `sink_waits` is `True`, `raise_comms` is prioritized.
|
| 91 |
+
"""
|
| 92 |
+
# We assign each node a tuple of scores (score_0, score_1, score_2),
|
| 93 |
+
# decreasing in importance, with a lower value indicating a higher ranking:
|
| 94 |
+
#
|
| 95 |
+
# - score_0: the lowest comm_idx among the comm nodes that the node blocks.
|
| 96 |
+
# If a node doesn't block any comm nodes, its score_0 is set to
|
| 97 |
+
# sys.maxsize. This score ensures that comm nodes get scheduled as early as
|
| 98 |
+
# possible.
|
| 99 |
+
# - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures
|
| 100 |
+
# that wait nodes are deferred as late as possible.
|
| 101 |
+
# - score_2: the index of the node in the original topological order. This
|
| 102 |
+
# score provides stability in case of ties.
|
| 103 |
+
#
|
| 104 |
+
# When only raise_comms is True, only score_0 and score_2 are considered.
|
| 105 |
+
# When only sink_waits is True, only score_1 and score_2 are considered.
|
| 106 |
+
# When neither is True, the original order is yielded.
|
| 107 |
+
buf_name_to_snode = {}
|
| 108 |
+
name_to_fused_node = {}
|
| 109 |
+
scores_0, scores_1, scores_2 = {}, {}, {}
|
| 110 |
+
for idx, snode in enumerate(snodes):
|
| 111 |
+
for buf_name in snode.get_buffer_names():
|
| 112 |
+
buf_name_to_snode[buf_name] = snode
|
| 113 |
+
|
| 114 |
+
for op_name in snode.get_operation_names():
|
| 115 |
+
name_to_fused_node[op_name] = snode
|
| 116 |
+
name_to_fused_node[snode.get_name()] = snode
|
| 117 |
+
|
| 118 |
+
node_name = snode.get_name()
|
| 119 |
+
scores_0[node_name] = sys.maxsize
|
| 120 |
+
scores_1[node_name] = 0
|
| 121 |
+
scores_2[node_name] = idx
|
| 122 |
+
|
| 123 |
+
comm_idx = 0
|
| 124 |
+
for snode in snodes:
|
| 125 |
+
if raise_comms and contains_collective(snode):
|
| 126 |
+
scores_0[snode.get_name()] = comm_idx
|
| 127 |
+
for anc in snode.ancestors:
|
| 128 |
+
anc_fused_name = name_to_fused_node[anc].get_name()
|
| 129 |
+
scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx)
|
| 130 |
+
comm_idx += 1
|
| 131 |
+
elif sink_waits and contains_wait(snode):
|
| 132 |
+
scores_1[snode.get_name()] = 1
|
| 133 |
+
|
| 134 |
+
class Runnable:
|
| 135 |
+
def __init__(self, snode) -> None:
|
| 136 |
+
self.snode = snode
|
| 137 |
+
name = next(iter(snode.get_operation_names()))
|
| 138 |
+
fused_name = name_to_fused_node[name].get_name()
|
| 139 |
+
self.score = (
|
| 140 |
+
scores_0[fused_name],
|
| 141 |
+
scores_1[fused_name],
|
| 142 |
+
scores_2[fused_name],
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def __lt__(self, other):
|
| 146 |
+
return self.score < other.score
|
| 147 |
+
|
| 148 |
+
unmet_deps: Dict[BaseSchedulerNode, Set[str]] = {
|
| 149 |
+
snode: {dep.name for dep in snode.unmet_dependencies} for snode in snodes
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
ready: List[Runnable] = []
|
| 153 |
+
buffer_users: Dict[str, Set[BaseSchedulerNode]] = defaultdict(set)
|
| 154 |
+
snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}
|
| 155 |
+
|
| 156 |
+
for snode, deps in unmet_deps.items():
|
| 157 |
+
if len(deps) == 0:
|
| 158 |
+
heapq.heappush(ready, Runnable(snode))
|
| 159 |
+
for dep in deps:
|
| 160 |
+
buffer_users[dep].add(snode)
|
| 161 |
+
|
| 162 |
+
scheduled = []
|
| 163 |
+
|
| 164 |
+
def schedule(snode):
|
| 165 |
+
"""
|
| 166 |
+
Schedules `snode` and put all unblocked nodes onto the ready queue.
|
| 167 |
+
"""
|
| 168 |
+
scheduled.append(snode)
|
| 169 |
+
for buf_name in snode.get_buffer_names():
|
| 170 |
+
for snode in buffer_users[buf_name]:
|
| 171 |
+
unmet_deps[snode].remove(buf_name)
|
| 172 |
+
if len(unmet_deps[snode]) == 0:
|
| 173 |
+
heapq.heappush(ready, Runnable(snode))
|
| 174 |
+
|
| 175 |
+
def get_overlapping_candidate():
|
| 176 |
+
"""
|
| 177 |
+
Return the next node in the ready queue that's neither a collective or
|
| 178 |
+
a wait.
|
| 179 |
+
"""
|
| 180 |
+
candidates = [
|
| 181 |
+
x
|
| 182 |
+
for x in ready
|
| 183 |
+
if not contains_collective(x.snode) and not contains_wait(x.snode)
|
| 184 |
+
]
|
| 185 |
+
if len(candidates) == 0:
|
| 186 |
+
return None
|
| 187 |
+
return min(candidates, key=lambda x: x.score)
|
| 188 |
+
|
| 189 |
+
def schedule_collective_for_overlap(snode):
|
| 190 |
+
"""
|
| 191 |
+
Schedules collective node `snode`, along with one or more compute nodes
|
| 192 |
+
to overlap with it. The strategy is described in the comment of
|
| 193 |
+
`reorder_compute_for_overlap`.
|
| 194 |
+
"""
|
| 195 |
+
assert contains_collective(snode)
|
| 196 |
+
schedule(snode)
|
| 197 |
+
|
| 198 |
+
collective_cost = snode_to_cost[snode]
|
| 199 |
+
while (
|
| 200 |
+
collective_cost > 0
|
| 201 |
+
and (candidate := get_overlapping_candidate()) is not None
|
| 202 |
+
):
|
| 203 |
+
ready.remove(candidate)
|
| 204 |
+
schedule(candidate.snode)
|
| 205 |
+
collective_cost -= snode_to_cost[candidate.snode]
|
| 206 |
+
heapq.heapify(ready)
|
| 207 |
+
|
| 208 |
+
while len(ready):
|
| 209 |
+
snode = heapq.heappop(ready).snode
|
| 210 |
+
if reorder_for_overlap and contains_collective(snode):
|
| 211 |
+
schedule_collective_for_overlap(snode)
|
| 212 |
+
else:
|
| 213 |
+
schedule(snode)
|
| 214 |
+
|
| 215 |
+
for snode, deps in unmet_deps.items():
|
| 216 |
+
assert len(deps) == 0, (
|
| 217 |
+
"Detected unscheduled nodes. "
|
| 218 |
+
f"Nodes with unmet dependencies: {unmet_deps}"
|
| 219 |
+
)
|
| 220 |
+
return scheduled
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def decide_global_ordering_of_comms(
|
| 224 |
+
nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node
|
| 225 |
+
) -> List[BaseSchedulerNode]:
|
| 226 |
+
"""
|
| 227 |
+
Decide global ordering of comms, by just enforcing the ordering that's in the input graph
|
| 228 |
+
(might not be the same ordering as the eager mode program).
|
| 229 |
+
TODO: Come up with a better approach
|
| 230 |
+
"""
|
| 231 |
+
# If FSDP2 is used, we apply FSDP-specific passes.
|
| 232 |
+
if any(
|
| 233 |
+
is_fallback_op(
|
| 234 |
+
x.node,
|
| 235 |
+
{
|
| 236 |
+
torch.ops.fsdp.all_gather_copy_in.default,
|
| 237 |
+
torch.ops.fsdp.chunk_cat.default,
|
| 238 |
+
},
|
| 239 |
+
)
|
| 240 |
+
for x in nodes
|
| 241 |
+
):
|
| 242 |
+
nodes = enforce_comm_ordering_for_fsdp(nodes, name_to_buf, name_to_fused_node)
|
| 243 |
+
|
| 244 |
+
comm_nodes = [n for n in nodes if contains_collective(n)]
|
| 245 |
+
|
| 246 |
+
for i in range(1, len(comm_nodes)):
|
| 247 |
+
# Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
|
| 248 |
+
mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
|
| 249 |
+
for buf in comm_nodes[i - 1].get_buffer_names():
|
| 250 |
+
comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))
|
| 251 |
+
|
| 252 |
+
return nodes
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
|
| 256 |
+
"""
|
| 257 |
+
Returns estimated op runtime in nanoseconds (ns)
|
| 258 |
+
"""
|
| 259 |
+
if config.estimate_op_runtime == "default":
|
| 260 |
+
runtime = snode.get_estimated_runtime()
|
| 261 |
+
else:
|
| 262 |
+
assert callable(config.estimate_op_runtime)
|
| 263 |
+
runtime = config.estimate_op_runtime(snode)
|
| 264 |
+
return runtime
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def node_summary(snode):
|
| 268 |
+
detail = ""
|
| 269 |
+
if isinstance(snode.node, ir.ExternKernelOut):
|
| 270 |
+
detail = f" ({snode.node.python_kernel_name})"
|
| 271 |
+
out_tensor_info = ""
|
| 272 |
+
if (
|
| 273 |
+
hasattr(snode.node, "layout")
|
| 274 |
+
and hasattr(snode.node.layout, "size")
|
| 275 |
+
and hasattr(snode.node.layout, "stride")
|
| 276 |
+
):
|
| 277 |
+
out_tensor_info = (
|
| 278 |
+
f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})"
|
| 279 |
+
)
|
| 280 |
+
node_name = ""
|
| 281 |
+
if hasattr(snode.node, "name"):
|
| 282 |
+
node_name = snode.node.name
|
| 283 |
+
return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def visualize_overlap(order):
|
| 287 |
+
total_est_runtime: float = 0.0
|
| 288 |
+
cur_comm_node = None
|
| 289 |
+
for snode in order:
|
| 290 |
+
if cur_comm_node is None:
|
| 291 |
+
if contains_collective(snode):
|
| 292 |
+
total_est_runtime += estimate_op_runtime(snode)
|
| 293 |
+
cur_comm_node = snode.node
|
| 294 |
+
elif is_wait(snode.node):
|
| 295 |
+
raise AssertionError(
|
| 296 |
+
"Wait is not expected when there is no collective running"
|
| 297 |
+
)
|
| 298 |
+
else: # exposed compute op
|
| 299 |
+
total_est_runtime += estimate_op_runtime(snode)
|
| 300 |
+
overlap_log.debug(f"{node_summary(snode)}") # noqa: G004
|
| 301 |
+
else: # cur_comm_node is not None
|
| 302 |
+
if contains_collective(snode):
|
| 303 |
+
raise AssertionError(
|
| 304 |
+
"Found two collectives running at the same time. "
|
| 305 |
+
"`visualize_overlap` needs to be updated to handle this case"
|
| 306 |
+
)
|
| 307 |
+
elif is_wait(snode.node): # end of this comm op
|
| 308 |
+
overlap_log.debug(f"{node_summary(snode)}") # noqa: G004
|
| 309 |
+
cur_comm_node = None
|
| 310 |
+
else: # overlapped compute op
|
| 311 |
+
overlap_log.debug(f"| {node_summary(snode)}") # noqa: G004
|
| 312 |
+
overlap_log.debug(
|
| 313 |
+
f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def reorder_compute_and_comm_for_overlap(
|
| 318 |
+
snodes: List[BaseSchedulerNode],
|
| 319 |
+
) -> List[BaseSchedulerNode]:
|
| 320 |
+
order = snodes
|
| 321 |
+
|
| 322 |
+
for p in config.reorder_for_compute_comm_overlap_passes:
|
| 323 |
+
if isinstance(p, str) and p in globals():
|
| 324 |
+
p = globals()[p] # it is a builtin pass
|
| 325 |
+
if torch.distributed.get_rank() == 0:
|
| 326 |
+
overlap_log.debug(
|
| 327 |
+
f"==== Visualize overlap before reordering pass {p} ====" # noqa: G004
|
| 328 |
+
)
|
| 329 |
+
try:
|
| 330 |
+
visualize_overlap(order)
|
| 331 |
+
except Exception as e:
|
| 332 |
+
overlap_log.debug(str(e))
|
| 333 |
+
order = p(order) # type: ignore[operator]
|
| 334 |
+
if torch.distributed.get_rank() == 0:
|
| 335 |
+
overlap_log.debug(
|
| 336 |
+
f"==== Visualize overlap after reordering pass {p} ====" # noqa: G004
|
| 337 |
+
)
|
| 338 |
+
try:
|
| 339 |
+
visualize_overlap(order)
|
| 340 |
+
except Exception as e:
|
| 341 |
+
overlap_log.debug(str(e))
|
| 342 |
+
return order
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
|
| 346 |
+
try:
|
| 347 |
+
import torch.distributed._composable.fsdp._fsdp_collectives
|
| 348 |
+
|
| 349 |
+
assert torch.distributed.is_available()
|
| 350 |
+
# Assert existence of these ops
|
| 351 |
+
assert (
|
| 352 |
+
torch.ops._c10d_functional.all_gather_into_tensor
|
| 353 |
+
and torch.ops._c10d_functional.all_gather_into_tensor_out
|
| 354 |
+
)
|
| 355 |
+
except (ImportError, AttributeError, AssertionError):
|
| 356 |
+
return
|
| 357 |
+
|
| 358 |
+
from .pattern_matcher import (
|
| 359 |
+
CallFunction,
|
| 360 |
+
KeywordArg,
|
| 361 |
+
Match,
|
| 362 |
+
PatternMatcherPass,
|
| 363 |
+
register_graph_pattern,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
"""
|
| 367 |
+
all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
|
| 368 |
+
getitem = all_gather_copy_in[0];
|
| 369 |
+
(getitem_1 = all_gather_copy_in[1];) # optional
|
| 370 |
+
|
| 371 |
+
all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...);
|
| 372 |
+
|
| 373 |
+
->
|
| 374 |
+
|
| 375 |
+
all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
|
| 376 |
+
getitem = all_gather_copy_in[0];
|
| 377 |
+
getitem_1 = all_gather_copy_in[1];
|
| 378 |
+
|
| 379 |
+
all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1);
|
| 380 |
+
"""
|
| 381 |
+
|
| 382 |
+
def remove_unused_getitem(g):
|
| 383 |
+
# Remove `getitem_X = all_gather_copy_in[1]` which is never used.
|
| 384 |
+
node_list = list(g.nodes)
|
| 385 |
+
for n in node_list:
|
| 386 |
+
if (
|
| 387 |
+
n.target == operator.getitem
|
| 388 |
+
and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default
|
| 389 |
+
and n.args[1] == 1
|
| 390 |
+
):
|
| 391 |
+
g.erase_node(n)
|
| 392 |
+
|
| 393 |
+
graph_pass = PatternMatcherPass()
|
| 394 |
+
|
| 395 |
+
@register_graph_pattern(
|
| 396 |
+
CallFunction(
|
| 397 |
+
torch.ops._c10d_functional.all_gather_into_tensor.default,
|
| 398 |
+
CallFunction(
|
| 399 |
+
operator.getitem,
|
| 400 |
+
CallFunction(
|
| 401 |
+
torch.ops.fsdp.all_gather_copy_in.default,
|
| 402 |
+
KeywordArg("all_gather_inputs"),
|
| 403 |
+
KeywordArg("inp_split_sizes"),
|
| 404 |
+
KeywordArg("all_gather_input_numel"),
|
| 405 |
+
KeywordArg("world_size"),
|
| 406 |
+
KeywordArg("rank"),
|
| 407 |
+
KeywordArg("dtype"),
|
| 408 |
+
KeywordArg("device"),
|
| 409 |
+
),
|
| 410 |
+
KeywordArg("item_idx"),
|
| 411 |
+
),
|
| 412 |
+
KeywordArg("group_size"),
|
| 413 |
+
KeywordArg("group_name"),
|
| 414 |
+
),
|
| 415 |
+
pass_dict=graph_pass,
|
| 416 |
+
extra_check=lambda match: match.kwargs["item_idx"] == 0,
|
| 417 |
+
)
|
| 418 |
+
def reinplace_all_gather(match: Match, *args, **kwargs):
|
| 419 |
+
def repl(
|
| 420 |
+
*args,
|
| 421 |
+
):
|
| 422 |
+
copy_in_args = args[:-2]
|
| 423 |
+
group_size = args[-2]
|
| 424 |
+
group_name = args[-1]
|
| 425 |
+
all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(
|
| 426 |
+
*copy_in_args
|
| 427 |
+
)
|
| 428 |
+
getitem = all_gather_copy_in[0]
|
| 429 |
+
getitem_1 = all_gather_copy_in[1]
|
| 430 |
+
all_gather_into_tensor = (
|
| 431 |
+
torch.ops._c10d_functional.all_gather_into_tensor_out.default(
|
| 432 |
+
getitem, group_size, group_name, out=getitem_1
|
| 433 |
+
)
|
| 434 |
+
)
|
| 435 |
+
return all_gather_into_tensor
|
| 436 |
+
|
| 437 |
+
match.replace_by_example(
|
| 438 |
+
repl,
|
| 439 |
+
[
|
| 440 |
+
kwargs["all_gather_inputs"],
|
| 441 |
+
kwargs["inp_split_sizes"],
|
| 442 |
+
kwargs["all_gather_input_numel"],
|
| 443 |
+
kwargs["world_size"],
|
| 444 |
+
kwargs["rank"],
|
| 445 |
+
kwargs["dtype"],
|
| 446 |
+
kwargs["device"],
|
| 447 |
+
kwargs["group_size"],
|
| 448 |
+
kwargs["group_name"],
|
| 449 |
+
],
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
remove_unused_getitem(graph)
|
| 453 |
+
graph_pass.apply(graph) # type: ignore[arg-type]
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def get_op_idx(snode):
|
| 457 |
+
assert not isinstance(
|
| 458 |
+
snode,
|
| 459 |
+
(
|
| 460 |
+
torch._inductor.scheduler.FusedSchedulerNode,
|
| 461 |
+
torch._inductor.scheduler.GroupedSchedulerNode,
|
| 462 |
+
),
|
| 463 |
+
)
|
| 464 |
+
return int(snode.get_name()[2:])
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def enforce_comm_ordering_for_fsdp(
|
| 468 |
+
snodes: List[torch._inductor.scheduler.BaseSchedulerNode],
|
| 469 |
+
name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer],
|
| 470 |
+
name_to_fused_node: Dict[str, BaseSchedulerNode],
|
| 471 |
+
) -> List[torch._inductor.scheduler.BaseSchedulerNode]:
|
| 472 |
+
from . import scheduler
|
| 473 |
+
|
| 474 |
+
new_order: list[BaseSchedulerNode] = []
|
| 475 |
+
scheduled = set()
|
| 476 |
+
ag_exists = False
|
| 477 |
+
rs_exists = False
|
| 478 |
+
ag_grouped_node_to_wait_grouped_node = {}
|
| 479 |
+
rs_grouped_node_to_wait_grouped_node = {}
|
| 480 |
+
snode_name_to_final_snode = {}
|
| 481 |
+
|
| 482 |
+
def _create_group_node(snodes_to_group):
|
| 483 |
+
group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group)
|
| 484 |
+
for snode in snodes_to_group:
|
| 485 |
+
snode_name_to_final_snode[snode.get_name()] = group_node
|
| 486 |
+
snode_name_to_final_snode[group_node.get_name()] = group_node
|
| 487 |
+
return group_node
|
| 488 |
+
|
| 489 |
+
# Create grouped nodes for specific sets of ops
|
| 490 |
+
for snode in snodes:
|
| 491 |
+
# Case 1: Handle AllGather
|
| 492 |
+
if is_collective(
|
| 493 |
+
snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default
|
| 494 |
+
) and any(
|
| 495 |
+
is_fallback_op(
|
| 496 |
+
name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default
|
| 497 |
+
)
|
| 498 |
+
for x in snode.ancestors
|
| 499 |
+
):
|
| 500 |
+
ag_exists = True
|
| 501 |
+
ag_snode = snode
|
| 502 |
+
ag_related_snode_set: set[scheduler.BaseSchedulerNode] = set()
|
| 503 |
+
|
| 504 |
+
# Find the "cast + copy_in + getitem + all_gather" code block
|
| 505 |
+
find_recursive_deps_of_node(
|
| 506 |
+
ag_snode,
|
| 507 |
+
ag_related_snode_set,
|
| 508 |
+
name_to_buf,
|
| 509 |
+
name_to_fused_node,
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
# Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block
|
| 513 |
+
allowed_ops = {
|
| 514 |
+
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
|
| 515 |
+
torch.ops._c10d_functional.wait_tensor.default,
|
| 516 |
+
torch.ops.fsdp.split_with_sizes_copy.default,
|
| 517 |
+
torch.ops.aten.set_.source_Tensor,
|
| 518 |
+
}
|
| 519 |
+
find_recursive_users_of_node(
|
| 520 |
+
ag_snode,
|
| 521 |
+
ag_related_snode_set,
|
| 522 |
+
name_to_buf,
|
| 523 |
+
name_to_fused_node,
|
| 524 |
+
criteria_cb=lambda x: not (
|
| 525 |
+
isinstance(x, scheduler.NopKernelSchedulerNode)
|
| 526 |
+
or (
|
| 527 |
+
isinstance(x, scheduler.ExternKernelSchedulerNode)
|
| 528 |
+
and x.node.op_overload in allowed_ops # type: ignore[union-attr]
|
| 529 |
+
)
|
| 530 |
+
),
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
# sort nodes by original operation order
|
| 534 |
+
ag_related_snodes = sorted(
|
| 535 |
+
ag_related_snode_set, key=lambda x: get_op_idx(x)
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
# In the "reuse layer" case, some ops in the 2nd all-gather code block could also
|
| 539 |
+
# depend on ops in the 1st all-gather code block, and we don't want to group them together.
|
| 540 |
+
end_idx_of_current_ag_block = len(ag_related_snodes)
|
| 541 |
+
copy_out_count = 0
|
| 542 |
+
for i in range(len(ag_related_snodes)):
|
| 543 |
+
cur_snode = ag_related_snodes[i]
|
| 544 |
+
if is_fallback_op(
|
| 545 |
+
cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default
|
| 546 |
+
):
|
| 547 |
+
copy_out_count += 1
|
| 548 |
+
if copy_out_count > 1:
|
| 549 |
+
end_idx_of_current_ag_block = i
|
| 550 |
+
break
|
| 551 |
+
|
| 552 |
+
ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block]
|
| 553 |
+
|
| 554 |
+
# Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode
|
| 555 |
+
wait_node_idx = None
|
| 556 |
+
for i in range(len(ag_related_snodes) - 1):
|
| 557 |
+
if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel):
|
| 558 |
+
wait_node_idx = i + 1
|
| 559 |
+
break
|
| 560 |
+
assert wait_node_idx is not None
|
| 561 |
+
ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])
|
| 562 |
+
|
| 563 |
+
# Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode
|
| 564 |
+
ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])
|
| 565 |
+
|
| 566 |
+
ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node
|
| 567 |
+
|
| 568 |
+
# Case 2: Handle ReduceScatter
|
| 569 |
+
elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default):
|
| 570 |
+
rs_exists = True
|
| 571 |
+
rs_snode = snode
|
| 572 |
+
|
| 573 |
+
# Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block
|
| 574 |
+
rs_related_snode_set: set[scheduler.BaseSchedulerNode] = set()
|
| 575 |
+
find_recursive_users_of_node(
|
| 576 |
+
rs_snode,
|
| 577 |
+
rs_related_snode_set,
|
| 578 |
+
name_to_buf,
|
| 579 |
+
name_to_fused_node,
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
# sort nodes by original operation order
|
| 583 |
+
rs_related_snodes = sorted(
|
| 584 |
+
rs_related_snode_set, key=lambda x: get_op_idx(x)
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
# Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode
|
| 588 |
+
wait_node_idx = None
|
| 589 |
+
for i in range(len(rs_related_snodes) - 1):
|
| 590 |
+
if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel):
|
| 591 |
+
wait_node_idx = i + 1
|
| 592 |
+
break
|
| 593 |
+
assert wait_node_idx is not None
|
| 594 |
+
rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx])
|
| 595 |
+
|
| 596 |
+
# Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode
|
| 597 |
+
rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:])
|
| 598 |
+
|
| 599 |
+
rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node
|
| 600 |
+
|
| 601 |
+
assert len(snode_name_to_final_snode) > 0
|
| 602 |
+
if ag_exists:
|
| 603 |
+
assert len(ag_grouped_node_to_wait_grouped_node) > 0
|
| 604 |
+
if rs_exists:
|
| 605 |
+
assert len(rs_grouped_node_to_wait_grouped_node) > 0
|
| 606 |
+
|
| 607 |
+
# Build the new node schedule, taking GroupedSchedulerNode into account
|
| 608 |
+
for snode in snodes:
|
| 609 |
+
if snode.get_name() in snode_name_to_final_snode:
|
| 610 |
+
snode = snode_name_to_final_snode[snode.get_name()]
|
| 611 |
+
if snode in scheduled:
|
| 612 |
+
continue
|
| 613 |
+
new_order.append(snode)
|
| 614 |
+
scheduled.add(snode)
|
| 615 |
+
|
| 616 |
+
# Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run
|
| 617 |
+
# before next AllGather's "copy_in then AG" group node
|
| 618 |
+
prev_ag_wait = None
|
| 619 |
+
for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items():
|
| 620 |
+
if prev_ag_wait is not None:
|
| 621 |
+
mutating_buf = next(iter(ag_group_node.get_buffer_names()))
|
| 622 |
+
for o in prev_ag_wait.get_outputs():
|
| 623 |
+
ag_group_node.add_fake_dep(
|
| 624 |
+
WeakDep(o.get_name(), mutating_buf=mutating_buf)
|
| 625 |
+
)
|
| 626 |
+
prev_ag_wait = wait_group_node
|
| 627 |
+
|
| 628 |
+
# Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run
|
| 629 |
+
# before next ReduceScatter's "copy_in then RS" group node
|
| 630 |
+
prev_rs_wait = None
|
| 631 |
+
for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items():
|
| 632 |
+
if prev_rs_wait is not None:
|
| 633 |
+
mutating_buf = next(iter(rs_group_node.get_buffer_names()))
|
| 634 |
+
for o in prev_rs_wait.get_outputs():
|
| 635 |
+
rs_group_node.add_fake_dep(
|
| 636 |
+
WeakDep(o.get_name(), mutating_buf=mutating_buf)
|
| 637 |
+
)
|
| 638 |
+
prev_rs_wait = wait_group_node
|
| 639 |
+
|
| 640 |
+
return new_order # type: ignore[return-value]
|
.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py
ADDED
|
@@ -0,0 +1,1629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
import contextlib
|
| 4 |
+
import functools
|
| 5 |
+
import io
|
| 6 |
+
import itertools
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
import warnings
|
| 12 |
+
from itertools import count
|
| 13 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
| 14 |
+
from unittest import mock
|
| 15 |
+
|
| 16 |
+
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
| 17 |
+
import torch.fx
|
| 18 |
+
import torch.utils._pytree as pytree
|
| 19 |
+
from functorch.compile import min_cut_rematerialization_partition
|
| 20 |
+
from torch._dynamo import (
|
| 21 |
+
compiled_autograd,
|
| 22 |
+
config as dynamo_config,
|
| 23 |
+
logging as dynamo_logging,
|
| 24 |
+
utils as dynamo_utils,
|
| 25 |
+
)
|
| 26 |
+
from torch._dynamo.device_interface import get_interface_for_device
|
| 27 |
+
from torch._dynamo.repro.after_aot import wrap_compiler_debug
|
| 28 |
+
from torch._dynamo.utils import (
|
| 29 |
+
counters,
|
| 30 |
+
detect_fake_mode,
|
| 31 |
+
flatten_graph_inputs,
|
| 32 |
+
lazy_format_graph_code,
|
| 33 |
+
)
|
| 34 |
+
from torch._functorch import config as functorch_config
|
| 35 |
+
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
|
| 36 |
+
from torch._inductor.codecache import (
|
| 37 |
+
_StrideExprStr,
|
| 38 |
+
code_hash,
|
| 39 |
+
CompiledFxGraph,
|
| 40 |
+
FxGraphCache,
|
| 41 |
+
)
|
| 42 |
+
from torch._inductor.cudagraph_utils import (
|
| 43 |
+
BoxedDeviceIndex,
|
| 44 |
+
CudagraphCachedInfo,
|
| 45 |
+
get_placeholder_info,
|
| 46 |
+
log_cudagraph_skip_and_bump_counter,
|
| 47 |
+
PlaceholderInfo,
|
| 48 |
+
)
|
| 49 |
+
from torch._inductor.debug import save_args_for_compile_fx_inner
|
| 50 |
+
from torch._inductor.runtime.runtime_utils import cache_dir
|
| 51 |
+
from torch._inductor.utils import (
|
| 52 |
+
BoxedBool,
|
| 53 |
+
count_tangents,
|
| 54 |
+
fresh_inductor_cache,
|
| 55 |
+
InputType,
|
| 56 |
+
is_gpu,
|
| 57 |
+
should_assume_input_aligned,
|
| 58 |
+
tensor_is_aligned,
|
| 59 |
+
)
|
| 60 |
+
from torch._logging import trace_structured
|
| 61 |
+
from torch._ops import OpOverload
|
| 62 |
+
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter
|
| 63 |
+
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
| 64 |
+
from torch.monitor import _WaitCounter
|
| 65 |
+
|
| 66 |
+
from .._dynamo.backends.common import aot_autograd
|
| 67 |
+
from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined]
|
| 68 |
+
from ..fx.graph import _PyTreeCodeGen
|
| 69 |
+
from . import config, metrics
|
| 70 |
+
from .debug import DebugContext
|
| 71 |
+
from .decomposition import select_decomp_table
|
| 72 |
+
from .fx_passes.joint_graph import joint_graph_passes
|
| 73 |
+
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
|
| 74 |
+
from .fx_passes.pre_grad import pre_grad_passes
|
| 75 |
+
from .graph import GraphLowering
|
| 76 |
+
from .ir import ExternKernelNode
|
| 77 |
+
from .utils import (
|
| 78 |
+
align_inputs_from_check_idxs,
|
| 79 |
+
clone_preserve_strides,
|
| 80 |
+
copy_misaligned_inputs,
|
| 81 |
+
get_cloned_parameter_buffer_name,
|
| 82 |
+
has_incompatible_cudagraph_ops,
|
| 83 |
+
maybe_get_suppress_shape_guards_ctx,
|
| 84 |
+
output_node,
|
| 85 |
+
remove_unaligned_input_idxs,
|
| 86 |
+
shape_env_from_inputs,
|
| 87 |
+
)
|
| 88 |
+
from .virtualized import V
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# fbcode has a real scuba-backed timing decorator; in OSS builds we substitute
# a no-op so call sites can use ``@time_and_log(...)`` unconditionally.
if config.is_fbcode():
    from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
else:
    # no-op decorator
    def time_and_log(attr: str):
        # Returning the identity function makes the decorator application a
        # no-op: the wrapped function is used unchanged.
        return dynamo_utils.identity
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Module loggers: a general logger plus artifact loggers so TORCH_LOGS can
# route specific categories (perf hints, post-grad graphs, cudagraph static
# input decisions) independently of the main log level.
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# copy_ fails when trying to write to tensors with memory overlap,
|
| 108 |
+
# for expanded dimensions (a dimension which used to have size 1 -> ?)
|
| 109 |
+
# we can select one element from that dimension and write to it
|
| 110 |
+
# to achieve writing to all values of that dimension of the input tensor
|
| 111 |
+
# copy_ fails when trying to write to tensors with memory overlap,
# for expanded dimensions (a dimension which used to have size 1 -> ?)
# we can select one element from that dimension and write to it
# to achieve writing to all values of that dimension of the input tensor
def get_expanded_dims(t):
    """Return the indices of dimensions of *t* that were broadcast-expanded
    (stride 0 with size > 1), or None if *t* is not a tensor."""
    if not isinstance(t, torch.Tensor):
        return None
    return [
        dim
        for dim, (stride, size) in enumerate(zip(t.stride(), t.shape))
        if stride == 0 and size != 1
    ]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
    """Select a single element (index 0) along each dimension listed in
    *expanded_dims*, collapsing broadcast-expanded dimensions to size 1."""
    return functools.reduce(
        lambda acc, dim: torch.ops.aten.slice(acc, dim, 0, 1),
        expanded_dims,
        t,
    )
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def complex_memory_overlap(t: torch.Tensor) -> bool:
    """Return True if *t* appears to have genuine internal memory overlap
    (two distinct index tuples mapping to the same storage element).

    ``torch._debug_has_internal_overlap`` can report "maybe", so when it
    flags the tensor we re-derive the answer from strides and sizes.
    """
    # if torch._debug_has_internal_overlap thinks this tensor potentially has
    # memory overlap internally, let's dig deeper to find out whether it's true.
    #
    # Call squeeze() so that dimension with size 1 does not cause false positive.
    t = index_expanded_dims(t, get_expanded_dims(t)).squeeze()
    if torch._debug_has_internal_overlap(t) != 0:
        strides = t.stride()
        sizes = t.shape
        indices = list(range(len(strides)))
        # Visit dimensions in order of increasing stride.
        indices = [x for _, x in sorted(zip(strides, indices))]
        for i in range(len(strides)):
            prev_stride = 1 if i == 0 else strides[indices[i - 1]]
            prev_size = 1 if i == 0 else sizes[indices[i - 1]]
            # If this dimension's stride is smaller than the span covered by
            # the previous (smaller-stride) dimension, elements alias.
            if strides[indices[i]] < prev_stride * prev_size:
                return True
    return False
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_static_input_idxs(num_fixed):
|
| 143 |
+
# If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
|
| 144 |
+
# of cudagraphs. Rather than copying these into cudagraph-owned memory
|
| 145 |
+
# like we do for normal inputs on each run, we will re-record a cudagraph if these
|
| 146 |
+
# parameter locations change.
|
| 147 |
+
context = torch._guards.TracingContext.try_get()
|
| 148 |
+
fixed = list(range(num_fixed))
|
| 149 |
+
if not context or not context.fw_metadata:
|
| 150 |
+
return fixed
|
| 151 |
+
|
| 152 |
+
return fixed + context.fw_metadata.static_input_indices
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@functools.lru_cache(None)
def _step_logger():
    # Cached so the step logger is constructed only once per process.
    return dynamo_logging.get_step_logger(log)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@functools.lru_cache(None)
|
| 161 |
+
def _warn_tf32_disabled():
|
| 162 |
+
if (
|
| 163 |
+
torch.cuda.is_available()
|
| 164 |
+
and not torch.backends.cuda.matmul.allow_tf32
|
| 165 |
+
and torch.cuda.get_device_capability() >= (8, 0)
|
| 166 |
+
):
|
| 167 |
+
warnings.warn(
|
| 168 |
+
"TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
|
| 169 |
+
"Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _unlift_graph(mod, gm, graph_signature):
    """Convert a lifted (export-style) graph module back into one that reads
    parameters/buffers as module attributes instead of graph inputs.

    ``mod`` supplies the parameter/buffer values, ``gm`` is the lifted graph,
    and ``graph_signature`` maps placeholders/outputs to the state they
    correspond to. Returns the unlifted GraphModule.
    """
    from torch.export.unflatten import _assign_attr, _AttrKind

    # Re-attach every parameter and buffer of the eager module onto gm so the
    # unlifted graph can reference them by attribute access.
    state_dict = {}
    for name, param in mod.named_parameters(remove_duplicate=False):
        state_dict[name] = param
        _assign_attr(
            param,
            gm,
            name,
            attr_kind=_AttrKind.PARAMETER,
        )
    for name, buffer in mod.named_buffers(remove_duplicate=False):
        state_dict[name] = buffer
        _assign_attr(
            buffer,
            gm,
            name,
            attr_kind=_AttrKind.BUFFER,
        )

    placeholder_nodes = gm.graph.find_nodes(op="placeholder")
    lifted_inputs = []

    # In AOTI, module parameters and buffers are not lifted as graph inputs.
    # As a result, mutation to buffers has side effect which makes their initial
    # values different from Eager. So we clone them here as a copy.
    # We are not cloning for parameters, although it will be needed if we want to
    # support training.
    for node in placeholder_nodes:
        node_name = node.name
        if node_name in graph_signature.inputs_to_parameters:
            parameter_name = graph_signature.inputs_to_parameters[node_name]
            lifted_inputs.append(parameter_name)
        elif node_name in graph_signature.inputs_to_buffers:
            buffer_name = graph_signature.inputs_to_buffers[node_name]
            lifted_inputs.append(buffer_name)
            gm.meta[
                get_cloned_parameter_buffer_name(buffer_name)
            ] = clone_preserve_strides(state_dict[buffer_name])
        else:
            # Plain user input: nothing was lifted for this placeholder.
            assert node_name in graph_signature.user_inputs
            lifted_inputs.append(None)

    from torch.export._unlift import _unlift

    # For each graph output that is a mutation marker, record the name of the
    # buffer/user input it mutates; real outputs get None.
    outputs = list(gm.graph.nodes)[-1].args[0]
    mutated_outputs = []
    buffer_mutations = graph_signature.buffers_to_mutate
    user_input_mutations = graph_signature.user_inputs_to_mutate
    output_tokens = graph_signature.output_tokens
    for idx, out in enumerate(outputs):
        value = None

        # Only the leading outputs can be mutation markers: buffer mutations,
        # user-input mutations and output tokens precede the real outputs.
        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
            if out.name in buffer_mutations:
                value = buffer_mutations[out.name]
            elif out.name in user_input_mutations:
                value = user_input_mutations[out.name]

        mutated_outputs.append(value)

    unlifted_gm = _unlift(
        gm,
        lifted_inputs,
        mutated_outputs,
        pytree.LeafSpec(),
        None,
        state_dict,
        {},
    )
    return unlifted_gm
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _get_subgraph_names(gm):
    """Yield the attribute names of subgraph modules referenced by
    higher-order control-flow ops (``cond`` and ``while_loop``) in *gm*."""
    for node in sorted(
        itertools.chain(
            gm.graph.find_nodes(op="call_function", target=torch.ops.higher_order.cond),
            gm.graph.find_nodes(
                op="call_function", target=torch.ops.higher_order.while_loop
            ),
        )
    ):
        if node.target == torch.ops.higher_order.cond:
            # cond(pred, true_fn, false_fn, operands): branches at args 1 and 2.
            true_subgraph_name = node.args[1].name
            false_subgraph_name = node.args[2].name
            yield true_subgraph_name
            yield false_subgraph_name
        elif node.target == torch.ops.higher_order.while_loop:
            # while_loop(cond_fn, body_fn, operands): subgraphs at args 0 and 1.
            cond_subgraph_name = node.args[0].name
            body_subgraph_name = node.args[1].name
            yield cond_subgraph_name
            yield body_subgraph_name
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _recursive_pre_grad_passes(gm, example_inputs):
    """Run pre-grad passes on *gm*, first recursing into every control-flow
    subgraph and re-attaching the rewritten result."""
    for name in _get_subgraph_names(gm):
        # Subgraphs have no example inputs of their own, so pass None down.
        rewritten = _recursive_pre_grad_passes(getattr(gm, name), example_inputs=None)
        setattr(gm, name, rewritten)
    return pre_grad_passes(gm, example_inputs)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _recursive_joint_graph_passes(gm):
    """Apply joint-graph passes depth-first: each control-flow subgraph of
    *gm* first, then *gm* itself (in-place)."""
    for name in _get_subgraph_names(gm):
        _recursive_joint_graph_passes(getattr(gm, name))
    joint_graph_passes(gm)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _recursive_post_grad_passes(gm, is_inference: bool = False):
    """Apply post-grad passes depth-first: each control-flow subgraph of
    *gm* first, then *gm* itself (in-place)."""
    for name in _get_subgraph_names(gm):
        _recursive_post_grad_passes(getattr(gm, name), is_inference)
    post_grad_passes(gm, is_inference)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def split_const_gm(
    gm: torch.fx.GraphModule,
    lifted_constants: Optional[Dict[str, Any]] = None,
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> Tuple[torch.fx.GraphModule, Dict[str, int]]:
    """
    This function takes an GraphModule input "gm".
    The gm will be split into 2 components,
      1) const_gm, which consists the subgraph of gm that can be constant folded.
      2) gm (being inplace modified,) which returns the graph after constant folding.

    If an additional "lifted_constants" argument is passed in, we will assume the gm has
    been lifted and run the transformation accordingly.

    When a "skip_folding_node_fn" callback is passed, we will skip constant folding on
    the nodes for which the callback returns True.

    const_output_index is a mapping of corresponding node name from gm to the
    output index of const_gm.
    Returns (const_gm, const_output_index)
    """
    from torch._inductor.constant_folding import (
        CONST_MODULE_TAG,
        META_TAG,
        MODULE_TAG,
        replace_node_with_constant,
        run_and_get_constant_graph,
    )

    # Extract the constant-foldable subgraph and evaluate it once.
    const_gm, const_result = run_and_get_constant_graph(
        gm, lifted_constants, skip_folding_node_fn
    )

    # Map each folded value's producing node name to its output slot in const_gm.
    const_outputs = {
        x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
    }

    to_erase_node = []
    to_replace_node = []
    const_output_index = {}
    for node in gm.graph.nodes:
        if node.name in const_outputs:
            # This node's value is produced by the constant graph: replace it
            # with the precomputed result.
            to_replace_node.append(node)
        elif node.meta[META_TAG] == CONST_MODULE_TAG and node.op != "placeholder":
            # Node lives entirely in the constant module: erase it from gm.
            to_erase_node.append(node)

    for node in to_replace_node:
        new_const_name = "_FOLDED_CONST_" + node.name
        replace_node_with_constant(
            gm,
            node,
            const_result[const_outputs[node.name]],
            new_const_name,
        )
        const_output_index[new_const_name] = const_outputs[node.name]
    # Erase in reverse order so users are removed before their producers;
    # any node that still has users must only be used by module-tagged nodes.
    for node in to_erase_node[::-1]:
        if node.users:
            for n in node.users:
                assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty."
        else:
            gm.graph.erase_node(node)
    gm.recompile()

    return const_gm, const_output_index
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
|
| 358 |
+
aten = torch.ops.aten
|
| 359 |
+
tf32_ops = {
|
| 360 |
+
aten.mm.default,
|
| 361 |
+
aten.addmm.default,
|
| 362 |
+
aten.bmm.default,
|
| 363 |
+
aten.baddbmm.default,
|
| 364 |
+
}
|
| 365 |
+
for target in tf32_ops:
|
| 366 |
+
for node in gm.graph.find_nodes(op="call_function", target=target):
|
| 367 |
+
if (
|
| 368 |
+
isinstance(node.meta.get("val", None), torch.Tensor)
|
| 369 |
+
and node.meta["val"].dtype == torch.float32
|
| 370 |
+
and node.meta["val"].device.type == "cuda"
|
| 371 |
+
):
|
| 372 |
+
return True
|
| 373 |
+
return False
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]):
    """
    For CPU backend, enable comprehensive padding causes some unit tests
    fail due to changing number of generated kernels. Skip for now.

    Returns a context manager: a config patch disabling comprehensive padding
    for CPU-only inputs, or a null context otherwise.
    """
    gpu_present = any(
        isinstance(t, torch.Tensor) and is_gpu(t.device.type) for t in example_inputs
    )

    if gpu_present or not (config.disable_padding_cpu and config.comprehensive_padding):
        return contextlib.nullcontext()
    perf_hint_log.info("Skip comprehensive padding on CPU")
    return config.patch(comprehensive_padding=False)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def fake_tensor_prop(
|
| 393 |
+
gm: torch.fx.GraphModule,
|
| 394 |
+
example_inputs: List[torch.Tensor],
|
| 395 |
+
force_allow_non_fake_inputs: bool = False,
|
| 396 |
+
):
|
| 397 |
+
"""
|
| 398 |
+
If we can not detect fake mode from the context of inputs, create one.
|
| 399 |
+
|
| 400 |
+
The created fake mode will be returned.
|
| 401 |
+
"""
|
| 402 |
+
fake_mode = detect_fake_mode(example_inputs)
|
| 403 |
+
if not fake_mode:
|
| 404 |
+
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
|
| 405 |
+
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
|
| 406 |
+
else:
|
| 407 |
+
ctx = (
|
| 408 |
+
contextlib.nullcontext()
|
| 409 |
+
if not force_allow_non_fake_inputs
|
| 410 |
+
else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
|
| 411 |
+
)
|
| 412 |
+
with ctx: # type: ignore[attr-defined]
|
| 413 |
+
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
|
| 414 |
+
*example_inputs
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
return fake_mode
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def should_use_remote_fx_graph_cache():
    """Decide whether the remote (memcache-backed) FX graph cache is enabled.

    An explicit ``config.fx_graph_remote_cache`` setting wins; otherwise the
    feature is fbcode-only and gated by a justknobs version check.
    """
    if config.fx_graph_remote_cache is not None:
        return config.fx_graph_remote_cache
    if not config.is_fbcode():
        return False

    if torch._utils_internal.is_fb_unit_test():
        return False

    try:
        from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
    except ModuleNotFoundError:
        return False

    jk_name = "pytorch/remote_cache:fx_graph_memcache_version"
    if torch.version.hip is not None:
        # AMD/ROCm builds are gated by a separate knob.
        jk_name = "pytorch/remote_cache:fx_graph_memcache_version_amd"

    return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(jk_name)
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
# pass config dict back to user
|
| 442 |
+
def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
    """Return a copy of the inductor config dict with *config_patches*
    applied (the patch is scoped to this call only)."""
    with config.patch(config_patches):
        return config.get_config_copy()
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
@contextlib.contextmanager
def with_fresh_cache_if_config():
    """Run the wrapped region inside a fresh inductor cache directory when
    caching is force-disabled; otherwise the context is a no-op."""
    if not config.force_disable_caches:
        yield
        return
    # Don't delete the cache dir because it has to survive beyond the
    # compile_fx call. Let's put the temp dirs under the default cache
    # dir so they're easier to locate.
    with fresh_inductor_cache(dir=cache_dir(), delete=False):
        yield
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def compile_fx_inner(*args, **kwargs):
    """Public entry point for compiling a single graph: sets up the required
    contexts (dispatch-mode disabling, lazy graph modules, timing, fresh
    cache, debug) and delegates to ``_compile_fx_inner``."""
    # Need with_fresh_cache_if_config for compile_fx_inner even if we already have one for
    # compile_fx. The reason is the compilation for backward graph may happen after
    # compile_fx return and we may want to use the _LazyGraphModule for compiling
    # the backward graph as well.
    with contextlib.ExitStack() as stack:
        stack.enter_context(torch.utils._python_dispatch._disable_current_modes())
        stack.enter_context(_use_lazy_graph_module(dynamo_config.use_lazy_graph_module))
        stack.enter_context(
            dynamo_utils.dynamo_timed(
                "compile_fx_inner", phase_name="inductor_compile", fwd_only=False
            )
        )
        stack.enter_context(with_fresh_cache_if_config())
        stack.enter_context(DebugContext())

        # wrap_compiler_debug adds minifier/repro support around the real impl.
        return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
            *args, **kwargs
        )
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
@time_and_log(attr="compilation time (in seconds)")
def _compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    static_input_idxs: Optional[List[int]] = None,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    """
    Inductor API that compiles a single graph.

    If you change the argument list for this function, make sure you
    also update the call to save_args_for_compile_fx_inner below accordingly.

    Returns a CompiledFxGraph with ``_boxed_call`` set (AOT autograd passes
    inputs as a list), or — in AOT mode — a string, which per the comments in
    ``codegen_and_compile`` below is the filename of the compiled code.
    """
    if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
        # Graph contains no calls: nothing for inductor to compile.
        # trigger the real recompilation for _LazyGraphModule before returning
        # the forward method.
        from torch.fx._lazy_graph_module import _LazyGraphModule

        _LazyGraphModule.force_recompile(gm)
        return make_boxed_func(gm.forward)

    if static_input_idxs is None:
        static_input_idxs = []

    static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)

    assert isinstance(
        next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
    ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"

    if config.save_args:
        # Persist the full argument list for offline repro/debugging.
        save_args_for_compile_fx_inner(
            gm,
            example_inputs,
            cudagraphs=cudagraphs,
            static_input_idxs=static_input_idxs,
            is_backward=is_backward,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            is_inference=is_inference,
            boxed_forward_device_index=boxed_forward_device_index,
            user_visible_outputs=user_visible_outputs,
            layout_opt=layout_opt,
        )

    if cudagraphs is None:
        cudagraphs = BoxedBool(config.triton.cudagraphs)

    # Inputs to fx_codegen_and_compile
    # Anything that affects codegen should go here, so if the signature
    # of fx_codegen_and_compile changes, the dict should be updated accordingly
    graph_kwargs = {
        "cudagraphs": cudagraphs,
        "static_input_idxs": static_input_idxs,
        "is_backward": is_backward,
        "graph_id": graph_id,
        "cpp_wrapper": cpp_wrapper,
        "aot_mode": aot_mode,
        "is_inference": is_inference,
        "user_visible_outputs": user_visible_outputs,
        "layout_opt": layout_opt,
        "extern_node_serializer": extern_node_serializer,
    }

    start = time.time()

    fx_graph_remote_cache = should_use_remote_fx_graph_cache()

    inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)  # type: ignore[arg-type]

    def codegen_and_compile(
        gm,
        example_inputs,
        inputs_to_check,
        fx_kwargs,
    ):
        """
        This function calls fx_codegen_and_compile and also adds some extra metadata to the resulting
        compiled fx graph. The metadata is saved to FXGraphCache.
        """
        compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
        if isinstance(compiled_graph, str):
            # We only return a string in aot mode, in which case we don't
            # need to do any post-compilation steps: we just return the string,
            # which is the filename of the compiled code.
            return compiled_graph
        cudagraph_info = None
        if cudagraphs:
            # check cudagraph disabling reasons from inductor lowering
            if compiled_graph.disabled_cudagraphs_reason:
                if "cuda" in compiled_graph.device_types:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
                    )
                else:
                    counters["inductor"]["cudagraph_skips"] += 1
                BoxedBool.disable(cudagraphs)
            else:
                # Inputs with overlapping memory cannot be safely captured.
                complex_memory_overlap_inputs = any(
                    complex_memory_overlap(t)
                    for t in example_inputs
                    if isinstance(t, torch.Tensor)
                )

                if not config.triton.cudagraph_support_input_mutation:
                    # Skip supports for cudagraph-managed tensors
                    from torch._inductor.cudagraph_utils import (
                        check_for_mutation_ignore_cuda_graph_managed_tensor,
                    )

                    has_mutation_str = (
                        check_for_mutation_ignore_cuda_graph_managed_tensor(
                            gm,
                            compiled_graph,
                            static_input_idxs,  # type:ignore[arg-type]
                        )
                    )
                    has_mutation = has_mutation_str is not None

                    if has_mutation:
                        compiled_graph.disabled_cudagraphs_reason = has_mutation_str
                else:
                    # Check mutation later to support cudagraph-managed tensors
                    has_mutation = None

                # Each entry: (condition that must hold, reason logged when it doesn't).
                cudagraph_tests = [
                    (not has_mutation, "mutated inputs"),
                    (not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
                    (not complex_memory_overlap_inputs, "complex memory overlap"),
                    (
                        all(
                            isinstance(t, (torch.Tensor, torch.SymInt))
                            for t in example_inputs
                        ),
                        "non-Tensor inputs",
                    ),
                ]
                output = output_node(gm)
                # output args are tuple of first argument
                assert len(output.args) == 1
                stack_traces = [
                    (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
                    for arg in output.args[0]
                ]
                cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
                placeholders = tuple(get_placeholder_info(gm.graph))
                cudagraph_info = CudagraphCachedInfo(
                    placeholders, stack_traces, cudagraph_fail_reasons
                )

        # Attach metadata that FXGraphCache serializes alongside the graph.
        compiled_graph.cudagraph_info = cudagraph_info
        compiled_graph.inputs_to_check = inputs_to_check
        compiled_graph.fx_kwargs = fx_kwargs
        # TODO: should this be part of fx_kwargs
        compiled_graph.boxed_forward_device_index = boxed_forward_device_index
        return compiled_graph

    with _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _:
        if (
            not config.force_disable_caches
            and (config.fx_graph_cache or fx_graph_remote_cache)
            and not aot_mode
        ):
            for i, input in enumerate(example_inputs):
                if (
                    isinstance(input, torch.Tensor)
                    and input.device.type == "cuda"
                    and i in static_input_idxs
                ):
                    # Mark static cuda inputs so downstream consumers can tell
                    # them apart — NOTE(review): attribute is set here but its
                    # reader is not visible in this file; confirm consumers.
                    input._is_inductor_static = True  # type: ignore[attr-defined]

            compiled_graph = FxGraphCache.load(
                codegen_and_compile,
                gm,
                example_inputs,
                graph_kwargs,
                inputs_to_check,
                local=config.fx_graph_cache,
                remote=fx_graph_remote_cache,
            )
        else:
            compiled_graph = codegen_and_compile(
                gm, example_inputs, inputs_to_check, graph_kwargs  # type: ignore[arg-type]
            )
            if aot_mode:
                # AOT mode is special because codegen_and_compile returns a string.
                # In that case, we don't need to run all post compilation steps, we just need
                # to return the string directly.
                return compiled_graph
            compiled_graph = FxGraphCache.post_compile(
                compiled_graph, example_inputs, cudagraphs
            )

    log.debug("FX codegen and compilation took %.3fs", time.time() - start)

    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )
    # aot autograd needs to know to pass in inputs as a list
    compiled_graph._boxed_call = True
    return compiled_graph
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def fx_codegen_and_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    static_input_idxs: Optional[List[int]] = None,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    # Use a dict with None value rather than a set for deterministic
    # iteration order just in case.
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    """
    Lower a single FX graph to compiled code.

    Runs post-grad FX passes, lowers the graph with GraphLowering, and wraps
    the compiled callable (plus output-stride / cudagraph metadata) in a
    CompiledFxGraph. When ``V.aot_compilation`` is True the raw result of
    ``graph.compile_to_fn()`` is returned instead — per the annotated return
    type this can be a string in that path.
    """
    if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
        import time

        log.warning("Sleeping for %s since sleep_sec_TESTING_ONLY is set", sleep_sec)
        time.sleep(sleep_sec)

    with dynamo_utils.preserve_rng_state():
        if is_tf32_warning_applicable(gm):
            _warn_tf32_disabled()

        # Snapshot counters so the delta attributable to this compile can be
        # recorded on the CompiledFxGraph at the end.
        inductor_counters = counters["inductor"].copy()

        # lift the maximum depth of the Python interpreter stack
        # to adapt large/deep models
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

        _step_logger()(
            logging.INFO,
            "torchinductor compiling "
            f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
            f"graph {graph_id}",
        )

        def log_graph_runnable():
            # Render a self-contained repro script for this graph as a string.
            fd = io.StringIO()
            torch._dynamo.repro.after_aot.save_graph_repro(
                fd, gm, example_inputs, "inductor", save_dir=None
            )
            return fd.getvalue()

        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_graph_runnable",
                "encoding": "string",
            },
            payload_fn=lambda: log_graph_runnable(),
        )

        V.debug.fx_graph(gm, example_inputs)
        # TODO: Should we actually dump this? It should be redundant with the aot
        # structured logs...
        # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))

        shape_env = shape_env_from_inputs(example_inputs)

        # Convert view to reshape in the graph. This is necessary primarily for
        # layout optimization. Do it unconditionally for uniformity.
        #
        # It's needed because when we do layout optimization, a contiguous tensor
        # in eager mode may become a channels-last tensor. A view op that previously
        # could be applied to the contiguous tensor may not be applicable
        # to the channels-last tensor any more. An error like
        # RuntimeError: view size is not compatible with input tensor's size and stride
        # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
        # will be printed.
        #
        # Replace view op to reshape op in this case.
        # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this.
        #
        # Also this has to be done before FakeTensorProp below to avoid the failed
        # .view() call.
        view_to_reshape(gm)

        # It is safe to run FakeTensorProp under no_grad because by the time
        # we're in inductor, we assume that AOTAutograd has already "taken care"
        # of autograd, so there should be no more autograd-related API's in the
        # graph.
        with torch.no_grad():
            fake_mode = fake_tensor_prop(gm, example_inputs)

        # pattern matcher passes might not preserve striding information
        # on node.meta["val"]. if in the future we rely on these being
        # correct we will need to fix.

        with V.set_fake_mode(fake_mode):
            # has some issues with memory in training
            _recursive_post_grad_passes(gm, is_inference=is_inference)
            V.debug.fx_graph_transformed(gm, example_inputs)
            post_grad_graphs_log.debug(
                "%s",
                lazy_format_graph_code(
                    "AFTER POST GRAD",
                    gm,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
            trace_structured(
                "inductor_post_grad_graph",
                payload_fn=lambda: gm.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )
            if config.is_fbcode():
                log_optimus_to_scuba(
                    extra_logging={"pt2_configs": str(get_patched_config_dict())}
                )

        with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding(
            example_inputs
        ):
            const_output_index = None
            const_graph = None
            const_code = None

            if aot_mode and config.aot_inductor.use_runtime_constant_folding:
                # Split out the constant-only subgraph and lower it separately
                # so its results can be folded at runtime.
                const_gm, const_output_index = split_const_gm(gm)

                const_graph = GraphLowering(
                    const_gm,
                    example_inputs=[],
                    shape_env=shape_env,
                    graph_id=graph_id,
                    cpp_wrapper=cpp_wrapper,
                    aot_mode=aot_mode,
                    user_visible_outputs=user_visible_outputs,
                    extern_node_serializer=extern_node_serializer,
                    is_inference=is_inference,
                    is_const_graph=True,
                )
                with V.set_graph_handler(const_graph):
                    assert cpp_wrapper, "AOT mode only supports C++ wrapper"
                    const_graph.run()

                    const_code, _ = const_graph.codegen_with_cpp_wrapper()

            graph = GraphLowering(
                gm,
                # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
                # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
                # we currently use fake tensors and defake them later.
                example_inputs=example_inputs,
                shape_env=shape_env,
                graph_id=graph_id,
                cpp_wrapper=cpp_wrapper,
                aot_mode=aot_mode,
                user_visible_outputs=user_visible_outputs,
                extern_node_serializer=extern_node_serializer,
                is_inference=is_inference,
                const_output_index=const_output_index,
                const_code=const_code,
                const_module=const_graph,
            )
            metrics_helper = metrics.CachedMetricsHelper()
            with V.set_graph_handler(graph):
                graph.run(*example_inputs)
                output_strides: List[Optional[Tuple[_StrideExprStr, ...]]] = []
                if graph.graph_outputs is not None:
                    # We'll put the output strides in the compiled graph so we
                    # can later return them to the caller via TracingContext
                    p = SymExprPrinter()
                    for out in graph.graph_outputs:
                        if (
                            hasattr(out, "layout")
                            and len(free_unbacked_symbols(out.layout.stride)) == 0
                        ):
                            # Convert to string for eval on the load path
                            output_strides.append(
                                tuple(p.doprint(s) for s in out.layout.stride)
                            )
                        else:
                            output_strides.append(None)

                _check_triton_bf16_support(graph)
                compiled_fn = graph.compile_to_fn()
                num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
                metrics.num_bytes_accessed += num_bytes
                metrics.node_runtimes += node_runtimes
                metrics.nodes_num_elem += nodes_num_elem

                if (
                    cudagraphs
                    and config.triton.cudagraph_skip_dynamic_graphs
                    and not V.graph.disable_cudagraphs_reason
                    and torch._inductor.utils.any_is_symbolic(*example_inputs)
                ):
                    # Dynamic shapes + cudagraph_skip_dynamic_graphs: disable
                    # cudagraphs, and try to point the user at the first node
                    # with a symbolic value that has a stack trace.
                    stack_trace = None
                    for node in gm.graph.nodes:
                        meta_val = node.meta.get("val", None)
                        if (
                            node.op == "placeholder"
                            or not isinstance(meta_val, torch.Tensor)
                            or not torch._inductor.utils.any_is_symbolic(meta_val)
                        ):
                            continue

                        if stack_trace := node.meta.get("stack_trace", None):
                            break
                    disable = "graph with symbolic shapes inputs and config.triton.cudagraph_skip_dynamic_graphs=True."
                    if stack_trace:
                        disable = f"{disable} Found from {stack_trace}\n"
                    else:
                        disable = f"{disable}\n"
                    V.graph.disable_cudagraphs_reason = disable

                if V.aot_compilation is True:
                    return compiled_fn

                if cudagraphs and not V.graph.disable_cudagraphs_reason:
                    from torch._inductor.cudagraph_utils import (
                        check_lowering_disable_cudagraph,
                    )

                    V.graph.disable_cudagraphs_reason = (
                        check_lowering_disable_cudagraph(V.graph.device_node_mapping)
                    )

                compiled_graph = CompiledFxGraph(
                    compiled_fn,
                    graph,
                    output_strides,
                    V.graph.disable_cudagraphs_reason,
                    metrics_helper.get_deltas(),
                    counters["inductor"] - inductor_counters,
                )

    return compiled_graph
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
def get_input_idxs_to_check(
    inputs: List[InputType],
    static_input_idxs: Sequence[int],
) -> Sequence[int]:
    """
    Runs at compile time and produces the indices of inputs whose alignment
    cannot be guaranteed ahead of time, i.e. the ones that may require a
    runtime copy to satisfy alignment requirements.
    """

    def needs_runtime_check(idx: int, candidate) -> bool:
        # Non-tensors don't need alignment; right now we only care about GPU
        # tensors.
        if not isinstance(candidate, torch.Tensor):
            return False
        if not is_gpu(candidate.device.type):
            return False
        # Suppress guards so that tensor_is_aligned and
        # should_assume_input_aligned do not add guards on the input's
        # storage offset.
        with maybe_get_suppress_shape_guards_ctx():
            if idx in static_input_idxs and tensor_is_aligned(candidate):
                return False
            if not should_assume_input_aligned(candidate):
                return False
        # If we get here, then
        # (a) our triton code assumes that the input is aligned
        # (b) we can't be sure ahead of time that the input will actually be aligned.
        # Therefore, at runtime, we'll need to check that the input is aligned
        # (and if not, clone it to make it aligned).
        return True

    return [
        idx
        for idx, candidate in enumerate(inputs)
        if needs_runtime_check(idx, candidate)
    ]
|
| 966 |
+
|
| 967 |
+
|
| 968 |
+
def cudagraphify(
    model: Callable[..., Any],
    static_input_idxs: Sequence[int] = (),
    *,
    device_index: int,
    stack_traces: List[Optional[str]],
    is_backward: bool,
    is_inference: bool,
    constants: Tuple[torch.Tensor, ...] = (),
    placeholders: Sequence[PlaceholderInfo] = (),
    mutated_input_idxs: Tuple[int, ...] = (),
) -> Callable[..., Any]:
    """
    Wrap ``model`` so that its first invocation builds a cudagraph-backed
    callable and every later invocation reuses it.

    Depending on ``config.triton.cudagraph_trees``, the tree-based
    implementation or the legacy ``cudagraphify_impl`` is used; either way the
    actual graph construction is deferred until real inputs are available.
    """
    from torch._inductor.cudagraph_trees import (
        cudagraphify_impl as new_cudagraphify_impl,
    )

    cudagraphify_fn: Callable[..., Any]
    if not config.triton.cudagraph_trees:
        cudagraphify_fn = cudagraphify_impl
    else:
        cudagraphify_fn = functools.partial(
            new_cudagraphify_impl,
            device_index=device_index,
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=constants,
            placeholders=placeholders,
            mutated_input_idxs=mutated_input_idxs,
        )

    # One-element cache holding the lazily-built callable.
    cache: List[Callable[..., Any]] = []

    def run(new_inputs):
        if not cache:
            with dynamo_utils.dynamo_timed(
                "cudagraphify"
            ), dynamo_utils.preserve_rng_state():
                cache.append(cudagraphify_fn(model, new_inputs, static_input_idxs))
        return cache[0](new_inputs)

    return run
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
def static_input(x: torch.Tensor) -> torch.Tensor:
    """
    Allocate an uninitialized tensor with the same size, strides, dtype and
    device as ``x``.

    Note: despite the name, no data is copied — this only reserves a static
    staging buffer with a matching layout; callers copy values in separately
    (see index_expanded_dims_and_copy_ usage in cudagraphify_impl).
    """
    return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
def index_expanded_dims_and_copy_(
    dst: torch.Tensor,
    src: torch.Tensor,
    expanded_dims: List[int],
):
    """In-place copy of ``src`` into ``dst`` after indexing the expanded
    dimensions of both tensors."""
    dst_view = index_expanded_dims(dst, expanded_dims)
    src_view = index_expanded_dims(src, expanded_dims)
    dst_view.copy_(src_view)
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
def cudagraphify_impl(
    model: Callable[..., Any],
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
):
    """
    Record ``model`` into a CUDA graph and return a replay function.

    Assumes inputs[static_input_idxs[i]] are always the same memory address.
    Non-static tensor inputs get persistent staging buffers; every call copies
    the new values into those buffers, replays the captured graph, and returns
    the captured outputs.
    """
    check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)  # type: ignore[arg-type]
    static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)  # type: ignore[arg-type]
    copy_misaligned_inputs(inputs, check_input_idxs)  # type: ignore[arg-type]

    assert isinstance(inputs, list)

    # Expanded (stride-0) dims are tracked only for non-static inputs; static
    # inputs never need to be copied into.
    inps_expanded_dims = [
        get_expanded_dims(x) if idx not in static_input_idxs else []
        for idx, x in enumerate(inputs)
    ]

    # allocate static tensor inputs
    static_inputs = [
        x
        if not isinstance(x, torch.Tensor)
        else static_input(x)
        if idx not in static_input_idxs
        else x.detach()
        for idx, x in enumerate(inputs)
    ]

    # copy over input values for fresh allocations
    for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
        if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
            index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)

    # warmup: run the model once on a side stream before capture
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    # copy static_inputs because it will be cleared in model
    with torch.cuda.stream(stream):
        model(list(static_inputs))
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
        static_outputs = model(list(static_inputs))
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    if config.size_asserts:

        def run(new_inputs):
            # Debug variant: verify static inputs really kept their address.
            assert len(static_inputs) == len(new_inputs)
            for idx, (dst, src, expanded_dims) in enumerate(
                zip(static_inputs, new_inputs, inps_expanded_dims)
            ):
                if not isinstance(dst, torch.Tensor):
                    pass
                elif idx in static_input_idxs:
                    assert dst.data_ptr() == src.data_ptr()
                else:
                    # TODO - could make one single op of multiple slices
                    # and avoid dispatch.
                    # Could also pre-index the `dst` tensors
                    index_expanded_dims_and_copy_(dst, src, expanded_dims)
            new_inputs.clear()
            graph.replay()
            return static_outputs

    else:
        # Fast variant: precompute which indices need a copy.
        copy_indices = [
            idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
        ]

        def run(new_inputs):
            for idx in copy_indices:
                expanded_dims = inps_expanded_dims[idx]
                index_expanded_dims_and_copy_(
                    static_inputs[idx], new_inputs[idx], expanded_dims
                )
            new_inputs.clear()
            graph.replay()
            return static_outputs

    return align_inputs_from_check_idxs(run, check_input_idxs)
|
| 1119 |
+
|
| 1120 |
+
|
| 1121 |
+
def compile_fx_aot(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile: Callable[..., Any] = compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
):
    """
    Compile an FX graph ahead-of-time via AOTInductor.

    Forces ``cpp_wrapper=True``, defaults ``aot_inductor.output_path`` to a
    hash of the model's code when neither the patches nor the global config
    provide one, and returns the path of the compiled library.
    """
    config_patches: Dict[str, Any] = (
        {"cpp_wrapper": True}
        if config_patches is None
        else {**config_patches, "cpp_wrapper": True}
    )
    if (
        "aot_inductor.output_path" not in config_patches
        and not config.aot_inductor.output_path
    ):
        config_patches = {
            **config_patches,
            "aot_inductor.output_path": code_hash(model_.code),
        }

    # Not a real config key: pop it out of the patches and route it to
    # inner_compile instead.
    extern_node_serializer = config_patches.pop("extern_node_serializer", None)
    with V.set_aot_compilation(True):
        compiled_lib_path = compile_fx(
            model_,
            example_inputs_,
            inner_compile=functools.partial(
                inner_compile,
                aot_mode=True,
                extern_node_serializer=extern_node_serializer,
            ),
            config_patches=config_patches,
        )
        assert os.path.exists(
            compiled_lib_path
        ), f"AOTInductor compiled library does not exist at {compiled_lib_path}"
        return compiled_lib_path
|
| 1157 |
+
|
| 1158 |
+
|
| 1159 |
+
# Monotonically increasing counter — presumably used to hand out unique
# graph ids to successive compiles; verify at the call sites (not visible
# in this chunk).
_graph_counter = count(0)
|
| 1160 |
+
|
| 1161 |
+
|
| 1162 |
+
def fw_compiler_freezing(
    aot_autograd_model: torch.fx.GraphModule,
    aot_example_inputs: List[torch.Tensor],
    dynamo_model: torch.fx.GraphModule,
    num_example_inputs: int,
    inner_compile: Callable[..., Any],
    cudagraphs: BoxedBool,
    graph_id: int,
    forward_device: BoxedDeviceIndex,
):
    """
    Forward-compiler path used when freezing is enabled.

    Freezes ``dynamo_model`` against ``aot_autograd_model`` (dropping
    non-preserved arguments), optionally converts conv weights to
    channels-last when layout optimization is chosen, compiles the frozen
    module with ``inner_compile``, and returns a boxed wrapper that strips
    the constant-ified arguments before calling the optimized function.
    """
    from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

    # partition_fn won't be called
    _recursive_joint_graph_passes(aot_autograd_model)

    layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True)
    if layout_opt:
        # make sure meta['val'] is properly setup
        fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
        convert_conv_weights_to_channels_last(aot_autograd_model)

    opt_model, preserved_arg_indices = freeze(
        dynamo_model,
        aot_autograd_model,
        aot_example_inputs,  # type: ignore[arg-type]
    )

    # Keep only the inputs that survived freezing.
    aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
    num_fixed = len(preserved_arg_indices) - num_example_inputs

    fake_mode = detect_fake_mode(aot_example_inputs)

    # for freezing, all graph outputs should be user visible
    *_, model_outputs_node = opt_model.graph.nodes
    model_outputs = model_outputs_node.args[0]
    user_visible_outputs = dict.fromkeys(
        n.name for n in model_outputs if isinstance(n, torch.fx.Node)
    )

    static_input_idxs = list(range(num_fixed))
    # constant params will be real tensors, not fake
    tracing_context = torch._guards.TracingContext.try_get()
    if tracing_context is not None:
        params_flat = tracing_context.params_flat
        assert params_flat is not None
        # Null out params that were folded into constants so they can be
        # released.
        for i in range(len(params_flat)):
            if i not in preserved_arg_indices:
                params_flat[i] = None

        if tracing_context.fw_metadata:
            static_input_idxs += tracing_context.fw_metadata.static_input_indices

    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
        optimized_function = inner_compile(
            opt_model,
            aot_example_inputs,
            static_input_idxs=static_input_idxs,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=True,
            boxed_forward_device_index=forward_device,
            layout_opt=layout_opt,
            user_visible_outputs=user_visible_outputs,
        )

    # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
    # that drops constant-ified params
    if V.aot_compilation is True:
        return optimized_function

    def wrapper(args):
        # Drop the constant-ified arguments and free the caller's list.
        args_new = [args[i] for i in preserved_arg_indices]
        args.clear()
        return optimized_function(args_new)

    wrapper._boxed_call = True  # type: ignore[attr-defined]

    return wrapper
|
| 1240 |
+
|
| 1241 |
+
|
| 1242 |
+
def compile_fx(
|
| 1243 |
+
model_: torch.fx.GraphModule,
|
| 1244 |
+
example_inputs_: List[torch.Tensor],
|
| 1245 |
+
inner_compile: Callable[..., Any] = compile_fx_inner,
|
| 1246 |
+
config_patches: Optional[Dict[str, Any]] = None,
|
| 1247 |
+
decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
|
| 1248 |
+
):
|
| 1249 |
+
with _use_lazy_graph_module(dynamo_config.use_lazy_graph_module):
|
| 1250 |
+
"""Main entrypoint to a compile given FX graph"""
|
| 1251 |
+
if config_patches:
|
| 1252 |
+
with config.patch(config_patches):
|
| 1253 |
+
return compile_fx(
|
| 1254 |
+
model_,
|
| 1255 |
+
example_inputs_,
|
| 1256 |
+
# need extra layer of patching as backwards is compiled out of scope
|
| 1257 |
+
inner_compile=config.patch(config_patches)(inner_compile),
|
| 1258 |
+
decompositions=decompositions,
|
| 1259 |
+
)
|
| 1260 |
+
|
| 1261 |
+
if config.cpp_wrapper:
|
| 1262 |
+
with config.patch(
|
| 1263 |
+
{
|
| 1264 |
+
"cpp_wrapper": False,
|
| 1265 |
+
# For triton.autotune_at_compile_time, disable by default for
|
| 1266 |
+
# FBCode, but enabled by default for OSS.
|
| 1267 |
+
"triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
|
| 1268 |
+
if config.is_fbcode()
|
| 1269 |
+
else os.environ.get(
|
| 1270 |
+
"TORCHINDUCTOR_TRITON_AUTOTUNE_AT_COMPILE_TIME", "1"
|
| 1271 |
+
)
|
| 1272 |
+
== "1",
|
| 1273 |
+
"triton.autotune_cublasLt": False,
|
| 1274 |
+
"triton.cudagraphs": False,
|
| 1275 |
+
"triton.store_cubin": True,
|
| 1276 |
+
}
|
| 1277 |
+
), V.set_real_inputs(example_inputs_):
|
| 1278 |
+
inputs_ = example_inputs_
|
| 1279 |
+
if isinstance(model_, torch.fx.GraphModule):
|
| 1280 |
+
fake_inputs = [
|
| 1281 |
+
node.meta.get("val")
|
| 1282 |
+
for node in model_.graph.nodes
|
| 1283 |
+
if node.op == "placeholder"
|
| 1284 |
+
]
|
| 1285 |
+
if all(v is not None for v in fake_inputs):
|
| 1286 |
+
# Validate devices before switching to fake tensors.
|
| 1287 |
+
for idx, fi, i in zip(count(), fake_inputs, inputs_):
|
| 1288 |
+
if fi.device != i.device:
|
| 1289 |
+
raise ValueError(
|
| 1290 |
+
f"Device mismatch between fake input and example input at position #{idx}: "
|
| 1291 |
+
f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
|
| 1292 |
+
"make sure torch.export() and torch.aot_compile() run on the same device."
|
| 1293 |
+
)
|
| 1294 |
+
inputs_ = fake_inputs
|
| 1295 |
+
return compile_fx(
|
| 1296 |
+
model_,
|
| 1297 |
+
inputs_,
|
| 1298 |
+
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
|
| 1299 |
+
decompositions=decompositions,
|
| 1300 |
+
)
|
| 1301 |
+
|
| 1302 |
+
recursive_compile_fx = functools.partial(
|
| 1303 |
+
compile_fx,
|
| 1304 |
+
inner_compile=inner_compile,
|
| 1305 |
+
decompositions=decompositions,
|
| 1306 |
+
)
|
| 1307 |
+
|
| 1308 |
+
if not graph_returns_tuple(model_):
|
| 1309 |
+
return make_graph_return_tuple(
|
| 1310 |
+
model_,
|
| 1311 |
+
example_inputs_,
|
| 1312 |
+
recursive_compile_fx,
|
| 1313 |
+
)
|
| 1314 |
+
|
| 1315 |
+
if isinstance(model_, torch.fx.GraphModule):
|
| 1316 |
+
if isinstance(model_.graph._codegen, _PyTreeCodeGen):
|
| 1317 |
+
# this graph is the result of dynamo.export()
|
| 1318 |
+
return handle_dynamo_export_graph(
|
| 1319 |
+
model_,
|
| 1320 |
+
example_inputs_,
|
| 1321 |
+
recursive_compile_fx,
|
| 1322 |
+
)
|
| 1323 |
+
|
| 1324 |
+
model_ = _recursive_pre_grad_passes(model_, example_inputs_)
|
| 1325 |
+
|
| 1326 |
+
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
|
| 1327 |
+
return flatten_graph_inputs(
|
| 1328 |
+
model_,
|
| 1329 |
+
example_inputs_,
|
| 1330 |
+
recursive_compile_fx,
|
| 1331 |
+
)
|
| 1332 |
+
|
| 1333 |
+
assert not config._raise_error_for_testing
|
| 1334 |
+
num_example_inputs = len(example_inputs_)
|
| 1335 |
+
cudagraphs = BoxedBool(config.triton.cudagraphs)
|
| 1336 |
+
forward_device = BoxedDeviceIndex(None)
|
| 1337 |
+
|
| 1338 |
+
graph_id = next(_graph_counter)
|
| 1339 |
+
|
| 1340 |
+
decompositions = (
|
| 1341 |
+
decompositions if decompositions is not None else select_decomp_table()
|
| 1342 |
+
)
|
| 1343 |
+
|
| 1344 |
+
def fw_compiler_base(
|
| 1345 |
+
model: torch.fx.GraphModule,
|
| 1346 |
+
example_inputs: List[torch.Tensor],
|
| 1347 |
+
is_inference: bool,
|
| 1348 |
+
):
|
| 1349 |
+
with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
|
| 1350 |
+
return _fw_compiler_base(model, example_inputs, is_inference)
|
| 1351 |
+
|
| 1352 |
+
def _fw_compiler_base(
|
| 1353 |
+
model: torch.fx.GraphModule,
|
| 1354 |
+
example_inputs: List[torch.Tensor],
|
| 1355 |
+
is_inference: bool,
|
| 1356 |
+
):
|
| 1357 |
+
if is_inference:
|
| 1358 |
+
# partition_fn won't be called
|
| 1359 |
+
_recursive_joint_graph_passes(model)
|
| 1360 |
+
|
| 1361 |
+
fixed = torch._inductor.utils.num_fw_fixed_arguments(
|
| 1362 |
+
num_example_inputs, len(example_inputs)
|
| 1363 |
+
)
|
| 1364 |
+
|
| 1365 |
+
user_visible_outputs = {}
|
| 1366 |
+
|
| 1367 |
+
if config.keep_output_stride:
|
| 1368 |
+
model_outputs_node = output_node(model)
|
| 1369 |
+
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
|
| 1370 |
+
num_model_outputs = len(model_outputs)
|
| 1371 |
+
|
| 1372 |
+
context = torch._guards.TracingContext.try_get()
|
| 1373 |
+
# See Note [User Outputs in the inductor graph]
|
| 1374 |
+
if context is not None and context.fw_metadata and not is_inference:
|
| 1375 |
+
original_output_start_index = (
|
| 1376 |
+
context.fw_metadata.num_mutated_inp_runtime_indices
|
| 1377 |
+
)
|
| 1378 |
+
else:
|
| 1379 |
+
original_output_start_index = 0
|
| 1380 |
+
|
| 1381 |
+
if isinstance(model_, torch.fx.GraphModule):
|
| 1382 |
+
*_, orig_model_outputs_node = model_.graph.nodes
|
| 1383 |
+
assert orig_model_outputs_node.op == "output"
|
| 1384 |
+
orig_model_outputs, _ = pytree.tree_flatten(
|
| 1385 |
+
orig_model_outputs_node.args
|
| 1386 |
+
)
|
| 1387 |
+
num_orig_model_outputs = len(orig_model_outputs)
|
| 1388 |
+
else:
|
| 1389 |
+
num_orig_model_outputs = num_model_outputs
|
| 1390 |
+
|
| 1391 |
+
assert num_orig_model_outputs <= num_model_outputs
|
| 1392 |
+
|
| 1393 |
+
# Note [User Outputs in the inductor graph]
|
| 1394 |
+
# We makes the following assumption
|
| 1395 |
+
# For inference
|
| 1396 |
+
# len(orig_model_outputs) == len(model_outputs)
|
| 1397 |
+
# For training
|
| 1398 |
+
# len(orig_model_outputs) <= len(model_outputs)
|
| 1399 |
+
# During training, most of the time the model_outputs starts with
|
| 1400 |
+
# original module's outputs followed by saved activations.
|
| 1401 |
+
# But this can be not true if the model have inplace updated tensors.
|
| 1402 |
+
# AOTAutograd will make those tensors being returned before the original
|
| 1403 |
+
# module's output.
|
| 1404 |
+
# To make things safe, we'll use original_output_start_index field
|
| 1405 |
+
# set by AOTAutograd to decide where the original module outputs start.
|
| 1406 |
+
orig_output_end_idx = (
|
| 1407 |
+
original_output_start_index + num_orig_model_outputs
|
| 1408 |
+
)
|
| 1409 |
+
# Sanity chec: we are about to splice out the "user" outputs from the full set
|
| 1410 |
+
# of "graph" outputs. Make sure we're within bounds.
|
| 1411 |
+
assert orig_output_end_idx <= num_model_outputs
|
| 1412 |
+
|
| 1413 |
+
user_visible_outputs = dict.fromkeys(
|
| 1414 |
+
n.name
|
| 1415 |
+
for n in model_outputs[
|
| 1416 |
+
original_output_start_index:orig_output_end_idx
|
| 1417 |
+
]
|
| 1418 |
+
if isinstance(n, torch.fx.Node)
|
| 1419 |
+
)
|
| 1420 |
+
|
| 1421 |
+
return inner_compile(
|
| 1422 |
+
model,
|
| 1423 |
+
example_inputs,
|
| 1424 |
+
static_input_idxs=get_static_input_idxs(fixed),
|
| 1425 |
+
cudagraphs=cudagraphs,
|
| 1426 |
+
graph_id=graph_id,
|
| 1427 |
+
is_inference=is_inference,
|
| 1428 |
+
boxed_forward_device_index=forward_device,
|
| 1429 |
+
user_visible_outputs=user_visible_outputs,
|
| 1430 |
+
)
|
| 1431 |
+
|
| 1432 |
+
fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
|
| 1433 |
+
|
| 1434 |
+
if config.freezing and not torch.is_grad_enabled():
|
| 1435 |
+
inference_compiler = functools.partial(
|
| 1436 |
+
fw_compiler_freezing,
|
| 1437 |
+
dynamo_model=model_,
|
| 1438 |
+
num_example_inputs=num_example_inputs,
|
| 1439 |
+
inner_compile=inner_compile,
|
| 1440 |
+
cudagraphs=cudagraphs,
|
| 1441 |
+
graph_id=graph_id,
|
| 1442 |
+
forward_device=forward_device,
|
| 1443 |
+
)
|
| 1444 |
+
else:
|
| 1445 |
+
inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
|
| 1446 |
+
|
| 1447 |
+
def partition_fn(graph, joint_inputs, **kwargs):
|
| 1448 |
+
_recursive_joint_graph_passes(graph)
|
| 1449 |
+
return min_cut_rematerialization_partition(
|
| 1450 |
+
graph, joint_inputs, **kwargs, compiler="inductor"
|
| 1451 |
+
)
|
| 1452 |
+
|
| 1453 |
+
def bw_compiler(
|
| 1454 |
+
model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
| 1455 |
+
):
|
| 1456 |
+
with dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"):
|
| 1457 |
+
user_visible_outputs = {}
|
| 1458 |
+
|
| 1459 |
+
if config.bw_outputs_user_visible:
|
| 1460 |
+
model_outputs_node = output_node(model)
|
| 1461 |
+
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
|
| 1462 |
+
user_visible_outputs = dict.fromkeys(
|
| 1463 |
+
n.name for n in model_outputs if isinstance(n, torch.fx.Node)
|
| 1464 |
+
)
|
| 1465 |
+
fixed = count_tangents(model)
|
| 1466 |
+
return inner_compile(
|
| 1467 |
+
model,
|
| 1468 |
+
example_inputs,
|
| 1469 |
+
static_input_idxs=list(range(fixed)),
|
| 1470 |
+
cudagraphs=cudagraphs,
|
| 1471 |
+
is_backward=True,
|
| 1472 |
+
graph_id=graph_id,
|
| 1473 |
+
boxed_forward_device_index=forward_device,
|
| 1474 |
+
user_visible_outputs=user_visible_outputs,
|
| 1475 |
+
)
|
| 1476 |
+
|
| 1477 |
+
# TODO: can add logging before/after the call to create_aot_dispatcher_function
|
| 1478 |
+
# in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
|
| 1479 |
+
# once torchdynamo is merged into pytorch
|
| 1480 |
+
|
| 1481 |
+
fake_mode = detect_fake_mode(
|
| 1482 |
+
example_inputs_
|
| 1483 |
+
) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
|
| 1484 |
+
tracing_context = (
|
| 1485 |
+
torch._guards.TracingContext.try_get()
|
| 1486 |
+
or torch._guards.TracingContext(fake_mode)
|
| 1487 |
+
)
|
| 1488 |
+
|
| 1489 |
+
if V.aot_compilation is True:
|
| 1490 |
+
with functorch_config.patch(unlift_effect_tokens=True):
|
| 1491 |
+
gm, graph_signature = aot_export_module(
|
| 1492 |
+
model_,
|
| 1493 |
+
example_inputs_,
|
| 1494 |
+
trace_joint=False,
|
| 1495 |
+
decompositions=decompositions,
|
| 1496 |
+
)
|
| 1497 |
+
unlifted_gm = _unlift_graph(model_, gm, graph_signature)
|
| 1498 |
+
if "dynamo_flat_name_to_original_fqn" in model_.meta:
|
| 1499 |
+
unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
|
| 1500 |
+
"dynamo_flat_name_to_original_fqn"
|
| 1501 |
+
]
|
| 1502 |
+
|
| 1503 |
+
# Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515)
|
| 1504 |
+
# In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into
|
| 1505 |
+
# _sfdp_init() to register patterns.
|
| 1506 |
+
# When fallback_random is set to True, the sdpa patterns will be traced during runtime.
|
| 1507 |
+
# If amp is turned on, the traced FP32 patterns will have prims.convert_element_type which
|
| 1508 |
+
# will be the same as the generated FP16 patterns.
|
| 1509 |
+
disable_amp = torch._C._is_any_autocast_enabled()
|
| 1510 |
+
context = (
|
| 1511 |
+
torch._C._DisableAutocast if disable_amp else contextlib.nullcontext
|
| 1512 |
+
)
|
| 1513 |
+
with V.set_fake_mode(fake_mode), compiled_autograd.disable(), context():
|
| 1514 |
+
return inference_compiler(unlifted_gm, example_inputs_)
|
| 1515 |
+
|
| 1516 |
+
with V.set_fake_mode(fake_mode), torch._guards.tracing(
|
| 1517 |
+
tracing_context
|
| 1518 |
+
), compiled_autograd.disable(), functorch_config.patch(
|
| 1519 |
+
unlift_effect_tokens=True
|
| 1520 |
+
):
|
| 1521 |
+
return aot_autograd(
|
| 1522 |
+
fw_compiler=fw_compiler,
|
| 1523 |
+
bw_compiler=bw_compiler,
|
| 1524 |
+
inference_compiler=inference_compiler,
|
| 1525 |
+
decompositions=decompositions,
|
| 1526 |
+
partition_fn=partition_fn,
|
| 1527 |
+
keep_inference_input_mutations=True,
|
| 1528 |
+
cudagraphs=cudagraphs,
|
| 1529 |
+
)(model_, example_inputs_)
|
| 1530 |
+
|
| 1531 |
+
|
| 1532 |
+
def graph_returns_tuple(gm: torch.fx.GraphModule):
|
| 1533 |
+
"""True if a FX graph returns a tuple"""
|
| 1534 |
+
if not isinstance(gm, torch.fx.GraphModule):
|
| 1535 |
+
return True # can't check this, assume true
|
| 1536 |
+
(rv,) = output_node(gm).args
|
| 1537 |
+
if isinstance(rv, (list, tuple)):
|
| 1538 |
+
return True
|
| 1539 |
+
if (
|
| 1540 |
+
isinstance(rv, torch.fx.node.Node)
|
| 1541 |
+
and hasattr(rv.target, "_schema")
|
| 1542 |
+
and len(rv.target._schema.returns) > 1
|
| 1543 |
+
and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
|
| 1544 |
+
):
|
| 1545 |
+
# for graphs whose result is one node with multiple outputs
|
| 1546 |
+
return True
|
| 1547 |
+
return False
|
| 1548 |
+
|
| 1549 |
+
|
| 1550 |
+
def make_graph_return_tuple(
|
| 1551 |
+
gm: torch.fx.GraphModule,
|
| 1552 |
+
inputs: List[torch.Tensor],
|
| 1553 |
+
compile_gm: Callable[..., Any],
|
| 1554 |
+
):
|
| 1555 |
+
"""
|
| 1556 |
+
Mutate gm so it returns a tuple. This is only needed for graphs
|
| 1557 |
+
not created by torchdynamo that return non-tuples.
|
| 1558 |
+
"""
|
| 1559 |
+
node = output_node(gm)
|
| 1560 |
+
(rv,) = node.args
|
| 1561 |
+
rv, spec = pytree.tree_flatten(rv)
|
| 1562 |
+
with gm.graph.inserting_before(node):
|
| 1563 |
+
gm.graph.output(rv)
|
| 1564 |
+
gm.graph.erase_node(node)
|
| 1565 |
+
assert graph_returns_tuple(gm)
|
| 1566 |
+
|
| 1567 |
+
compiled_fn = compile_gm(gm, inputs)
|
| 1568 |
+
|
| 1569 |
+
@functools.wraps(compiled_fn)
|
| 1570 |
+
def wrapper(*args, **kwargs):
|
| 1571 |
+
return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)
|
| 1572 |
+
|
| 1573 |
+
return wrapper
|
| 1574 |
+
|
| 1575 |
+
|
| 1576 |
+
def handle_dynamo_export_graph(
|
| 1577 |
+
gm: torch.fx.GraphModule,
|
| 1578 |
+
inputs: List[torch.Tensor],
|
| 1579 |
+
compile_gm: Callable[..., Any],
|
| 1580 |
+
):
|
| 1581 |
+
"""
|
| 1582 |
+
`torch._dynamo.export` embeds pytrees in the FX graph codegen object,
|
| 1583 |
+
convert that to a normal FX graph so inductor can compile it.
|
| 1584 |
+
"""
|
| 1585 |
+
codegen = gm.graph._codegen
|
| 1586 |
+
gm.graph._codegen = torch.fx.graph.CodeGen()
|
| 1587 |
+
gm.recompile()
|
| 1588 |
+
|
| 1589 |
+
compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs))
|
| 1590 |
+
|
| 1591 |
+
@functools.wraps(compiled_fn)
|
| 1592 |
+
def wrapper(*args):
|
| 1593 |
+
return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))
|
| 1594 |
+
|
| 1595 |
+
return wrapper
|
| 1596 |
+
|
| 1597 |
+
|
| 1598 |
+
def _check_triton_bf16_support(graph: GraphLowering) -> None:
|
| 1599 |
+
def warn_and_skip(device) -> None:
|
| 1600 |
+
from torch._dynamo.exc import SkipFrame
|
| 1601 |
+
|
| 1602 |
+
device_interface = get_interface_for_device(device.type)
|
| 1603 |
+
device_props = device_interface.get_device_properties(device)
|
| 1604 |
+
warnings.warn(
|
| 1605 |
+
f"{device_props.name} does not support bfloat16 compilation natively, skipping"
|
| 1606 |
+
)
|
| 1607 |
+
raise SkipFrame("BF16 is not supported")
|
| 1608 |
+
|
| 1609 |
+
for inp in graph.graph_inputs.values():
|
| 1610 |
+
device = getattr(inp, "get_device", lambda: torch.device("meta"))()
|
| 1611 |
+
if (not is_gpu(device.type)) or inp.get_dtype() != torch.bfloat16:
|
| 1612 |
+
continue
|
| 1613 |
+
# Print warning and skip frame if attempting to compile for bfloat16
|
| 1614 |
+
# on device without hardware support for dtype
|
| 1615 |
+
device_interface = get_interface_for_device(device.type)
|
| 1616 |
+
if device_interface.is_bf16_supported(including_emulation=False):
|
| 1617 |
+
return
|
| 1618 |
+
warn_and_skip(device)
|
| 1619 |
+
|
| 1620 |
+
for out in graph.graph_outputs:
|
| 1621 |
+
device = getattr(out, "get_device", lambda: torch.device("meta"))()
|
| 1622 |
+
if (not is_gpu(device.type)) or out.get_dtype() != torch.bfloat16:
|
| 1623 |
+
continue
|
| 1624 |
+
# Print warning and skip frame if attempting to compile for bfloat16
|
| 1625 |
+
# on device without hardware support for dtype
|
| 1626 |
+
device_interface = get_interface_for_device(device.type)
|
| 1627 |
+
if device_interface.is_bf16_supported(including_emulation=False):
|
| 1628 |
+
return
|
| 1629 |
+
warn_and_skip(device)
|
.venv/lib/python3.11/site-packages/torch/_inductor/config.py
ADDED
|
@@ -0,0 +1,1241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os # noqa: C101
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def is_fbcode() -> bool:
|
| 9 |
+
return not hasattr(torch.version, "git_version")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def fx_graph_remote_cache_default() -> Optional[bool]:
|
| 13 |
+
if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "1":
|
| 14 |
+
return True
|
| 15 |
+
if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "0":
|
| 16 |
+
return False
|
| 17 |
+
return None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def autotune_remote_cache_default() -> Optional[bool]:
    """Tri-state default for the remote autotune cache.

    Returns:
        True if TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE is "1", False if it is
        "0", and None when it is unset or has any other value (leaving the
        decision to downstream logic).
    """
    # Read the environment variable once instead of performing two lookups.
    value = os.environ.get("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE")
    if value == "1":
        return True
    if value == "0":
        return False
    return None


# Enable auto_functionalized_v2 (enabled by default)
enable_auto_functionalized_v2 = (
    os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "0") == "1"
)

# add some debug printouts
debug = False

# Whether to disable a progress bar for autotuning
disable_progress = True

# Whether to enable printing the source code for each future
verbose_progress = False

# use fx aot graph codegen cache
fx_graph_cache = (
    os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE", "0" if is_fbcode() else "1") == "1"
)

# use remote fx aot graph codegen cache
# False: Disables the cache
# True: Enables the cache
# None: Not set -- Off for OSS, JustKnobs based for internal
fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()

# enable autotune local cache
autotune_local_cache = True

# enable autotune remote cache
# False: Disables the cache
# True: Enables the cache
# None: Not set -- Off for OSS, JustKnobs based for internal
autotune_remote_cache: Optional[bool] = autotune_remote_cache_default()

# Force disabled all inductor level caching -- This will override any other caching flag
force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1"

# sleep in inductor for testing
sleep_sec_TESTING_ONLY: Optional[int] = None

# The default layout constraint for custom operators.
# This must be the name of one of the layout constraint tags
# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}),
# If the custom op does not have a layout constraint tag already
# then we assume the following applies.
custom_op_default_layout_constraint = "flexible_layout"

# use cpp wrapper instead of python wrapper
cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"

# codegen cpp wrapper code in an ABI compatible mode
abi_compatible = (
    os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1"
)

c_shim_version = os.environ.get("TORCHINDUCTOR_C_SHIM_VERSION", "2")

# dead code elimination
dce = False

# assume weight tensors are fixed size
static_weight_shapes = True

# put correctness assertions in generated code
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"

# enable loop reordering based on input orders
pick_loop_orders = True

# reuse a kernel input as the output
inplace_buffers = True

# reuse a buffer for an unrelated purpose
allow_buffer_reuse = True

# Enable pooled allocations for non-output tensors
memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1"

# How to organize memory under memory_planning=True:
# - "none": do not try to pool storage, just reuse
# - "intermediates": all non-outputs share storage, outputs each get unique storage
# - "outputs": two pools, one for intermediates (freed on return) and one for outputs
# - "combined": a single pool for both intermediates and outputs
memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates")

# codegen benchmark harness
benchmark_harness = True

# fuse pointwise into templates
epilogue_fusion = True

# do epilogue fusions before other fusions
epilogue_fusion_first = False

# enable pattern match+replace optimizations
pattern_matcher = True

# set to True to enable the back-to-back GEMM pass
b2b_gemm_pass = False

# register custom graph optimization pass hook. so far, pre/post passes are
# only applied before/after pattern_matcher in post_grad_passes.
#
# def my_custom_pre_pass(graph: torch.fx.graph.Graph):
#     # my custom graph optimization pass
#     ...
#
# def my_custom_post_pass(graph: torch.fx.graph.Graph):
#     # my custom graph optimization pass
#     ...
#
# torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass
# torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass
post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None

# Registers a custom joint graph pass.
joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None
joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None

# Registers a custom pregrad pass. Note that the pre-grad IR is 1.
# non-functional, 2. non-normalized, and 3. prone to change. Ideally we should
# use post-grad passes.
pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None

# Registers a custom pass to be run right before fusion in Inductor scheduler.
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
# hence custom IR passes built on top of it might break in the future.
_pre_fusion_custom_pass: Optional[
    Callable[
        [List["torch._inductor.scheduler.BaseSchedulerNode"]],
        List["torch._inductor.scheduler.BaseSchedulerNode"],
    ]
] = None

# Deprecated
split_cat_fx_passes = True

# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability.
efficient_conv_bn_eval_fx_passes = False

# Enable predispatch aten IR for export
is_predispatch = False

# Deprecated
group_fusion = False

# Deprecated
batch_fusion = True

# Pre grad fusion and options in order, set to empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions.
# batch fusion options:
# batch_linear
# batch_linear_lhs
# batch_layernorm
# batch_tanh
# batch_relu
# batch_sigmoid

# split cat fusion options:
# normalization_pass
# remove_split_with_size_one_pass
# merge_getitem_cat_pass
# merge_stack_tahn_unbind
# merge_splits_pass
# mutate_cat_pass
# split_cat_pass
pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {
    "batch_linear": {},
    "batch_linear_lhs": {},
    "batch_layernorm": {},
    "batch_tanh": {},
    "batch_relu": {},
    "batch_sigmoid": {},
}

# Post grad fusion and options, set to empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}

# enable reordering pass for improving memory locality
reorder_for_locality = True

# Scale down RBLOCK for better occupancy
dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1"

# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32
# but the mul gets fused with other pointwise ops instead.
force_fuse_int_mm_with_mul = False

# for pattern torch.mm(a, b.to(dtype)) with cuda tensors,
# enable torch._inductor.kernel.mm.tuned_mixed_mm fused kernel.
# Autotune will compare perf with normal cast->then->mm option
use_mixed_mm = True

# enable runtime numeric check for pre/post grad fx passes
# floating point provides limited accuracy (about 7 decimal digits for single precision
# floating point numbers,about 16 decimal digits for double precision floating point numbers)
# according to PyTorch documentation.
# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
fx_passes_numeric_check: Dict[str, Any] = {
    "pre_grad": False,
    "precision": 1e-4,
    "num_iterations": 1,
    "requires_optimizer": True,
}

# mixed_mm_choice can be used to control the behaviour for pattern torch.mm(a, b.to(dtype)) with cuda tensors.
# The fallback aten implementation is normal cast->then->mm option.
# If mixed_mm_choice is "default": this flag will be ignored.
# If mixed_mm_choice is "triton":
# - Always use torch._inductor.kernel.mm.tuned_mixed_mm's fused kernel.
# - Autotune will not compare with fallback.
# If mixed_mm_choice is "aten": always use the fallback aten implementation.
# If mixed_mm_choice is "heuristic":
# - Enables the heuristic.
# - If the heuristic decides to add a config, it will add the config as the first choice.
# - If autotune is disabled, this config will always be chosen.
# - If autotune is enabled, it will also compare with fallback aten implementation and fused kernel.
# The use_mixed_mm flag will be ignored if mixed_mm_choice != "default".
mixed_mm_choice = "heuristic"

# enable reordering pass for increasing overlap between compute and communication
reorder_for_compute_comm_overlap = False

# passes (in execution order) for increasing overlap between compute and communication
# for built-in passes, use string name; for user-defined passes, pass in the function handle
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
# hence custom IR passes built on top of it might break in the future.
reorder_for_compute_comm_overlap_passes = [
    "reorder_compute_for_overlap",
    "sink_waits",
    "raise_comms",
]

# runtime estimation function for ops
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default"

# unit: GB/s, uni-directional P2P bandwidth per card
# default value is NVLink
intra_node_bw = 300

# unit: GB/s, uni-directional P2P bandwidth per node
# default value is InfiniBand
inter_node_bw = 25

# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"

# enable slow autotuning passes to select pointwise/reductions algorithms
max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"

# enable slow autotuning passes to select gemm algorithms
max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"

# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations
# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations
# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure
# that triton does not use TF32 wherever cublas would not use TF32
force_same_precision = (
    True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1"
)

# Specify candidate backends for gemm autotune.
# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CPP.
# ATen: default Pytorch ATen kernels.
# Triton: Triton templates defined in torch inductor (AMD and NVidia GPUs).
# CUTLASS: Cutlass templates and kernels (NVidia GPUs only).
# CK: Composable Kernel templates and kernels (AMD Instinct GPUs only).
# CPP: CPP templates and kernels for CPU.
max_autotune_gemm_backends = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
).upper()

# As above, specify candidate backends for conv autotune.
# NB: in some cases for 1x1 convs we emit as matmul,
# which will use the backends of `max_autotune_gemm_backends`
max_autotune_conv_backends = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_CONV_BACKENDS", "ATEN,TRITON"
).upper()


# Specify the size of the search space for GEMM autotuning.
# DEFAULT - balance between compile time overhead and performance
# EXHAUSTIVE - maximize performance
max_autotune_gemm_search_space = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
).upper()

# Whether we fall back to ATen or hard error when no matches are found during autotuning
autotune_fallback_to_aten = (
    os.environ.get("TORCHINDUCTOR_AUTOTUNE_FALLBACK_TO_ATEN", "1") == "1"
)

# the value used as a fallback for the unbacked SymInts
# that can appear in the input shapes (e.g., in autotuning)
unbacked_symint_fallback = 8192

# DEPRECATED, DO NOT USE
search_autotune_cache = False

save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"

# We will disable creating subprocess for autotuning if this is False
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"

# The following three timeouts are applicable if autotune_in_subproc is True:

# Max time that a valid benchmark result may take during autotuning
max_autotune_subproc_result_timeout_seconds = 60.0
# Additional time we allow subprocesses to terminate gracefully after the timeout until we send a SIGTERM
max_autotune_subproc_graceful_timeout_seconds = 1.0
# Additional time that we grant after a SIGTERM until we do a hard SIGKILL of subprocesses
max_autotune_subproc_terminate_timeout_seconds = 2.0

# If autotuning in subprocess, whether to use multiple devices
autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"

coordinate_descent_tuning = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
)
coordinate_descent_check_all_directions = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1"
)
coordinate_descent_search_radius = int(
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1")
)

# AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and
# generate the learned heuristic to code which is shipped with the compiler
# Specify a list of comma separated optimizations to collect data for
autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "")
# Specify a list of comma separated optimizations to use learned heuristics for
autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm")
def run_autoheuristic(name: str) -> bool:
    """Return True when AutoHeuristic is active for optimization ``name``.

    Active means we are either collecting autotuning data for ``name`` or
    using a previously learned heuristic for it.
    """
    if collect_autoheuristic(name):
        return True
    return use_autoheuristic(name)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def collect_autoheuristic(name: str) -> bool:
|
| 372 |
+
return name in torch._inductor.config.autoheuristic_collect.split(",")
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def use_autoheuristic(name: str) -> bool:
|
| 376 |
+
return name in torch._inductor.config.autoheuristic_use.split(",")


# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py.
# If set to another path, autoheuristic will instead log results to the given path.
autoheuristic_log_path = os.environ.get(
    "TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH", "DEFAULT"
)

# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions
layout_opt_default = "1" if not torch.version.hip else "0"
layout_optimization = (
    os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1"
)

force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1"


# Whether to keep the output strides the same as eager after layout optimization.
keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"

# Enabling this will let compiler print warning messages if a generated triton
# kernel has inputs with mixed layouts. This is helpful for perf debugging
# since kernel with mixed layout inputs may run much slower than one whose inputs
# have uniform layouts.
warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"

# control store vs recompute heuristic
# For fanouts, rematerialization can lead to exponential blowup. So, have
# smaller threshold
realize_reads_threshold = 4
realize_opcount_threshold = 30

# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold = 8

# fallback to eager for random/dropout, this is slow but useful for debugging
fallback_random = False

# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks = True

# fuse even in cases without common reads
aggressive_fusion = False

# For each fused kernel in the wrapper, comment with the nodes that get fused.
# Useful for debugging fusion.
debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")
loop_ordering_after_fusion = (
    os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1"
)

# For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel
benchmark_epilogue_fusion = (
    os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1"
)

# Take how many of the top triton kernels to benchmark epilogue
max_epilogue_benchmarked_choices = 1

# how many nodes to allow into a single fusion
max_fusion_size = 64

# max number of inputs to generate cat as a pointwise op with masked loads
max_pointwise_cat_inputs = 8

# replace small reductions with pointwise, disable with `= 1`
unroll_reductions_threshold = 8

# Add extra comments to output code (causes compile cache misses)
comment_origin = False

# Convert 1x1 convs into matmuls
conv_1x1_as_mm = False

# Enable split reductions for better utilization when the dimension
# being reduced over is large (by splitting it)
split_reductions = True

benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"

# Enable constant and index_expr folding
constant_and_index_propagation = True

# we always add constants into graph.constants without
# performing any constant-inlining optimization
always_keep_tensor_constants = False

# assert that indirect indexing does not read / write out of bounds
assert_indirect_indexing = True

# compute CSE bounds on variables that do not appear in the FX graph
compute_all_bounds = False

# enable the combo kernel that combines data-independent kernels (additional
# to foreach kernels) into a single one (Experimental)
combo_kernels = False
# benchmark combo kernels and only allow ones with perf gains
benchmark_combo_kernel = False
# combo_kernel autotuning options: 0 - disable, 1 - enable except for foreach,
# 2 - enable for all
combo_kernels_autotune = 1
# Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable
# for all except for foreach, 2 - enable for all
combo_kernel_allow_mixed_sizes = 1
# Enable dynamic shapes for foreach kernels
combo_kernel_foreach_dynamic_shapes = False

# constant folding on the joint graph
joint_graph_constant_folding = True

# Enable indirect_indexing asserts for decompositions and lowerings
debug_index_asserts = False

# Mode to emulate pytorch eager numerics for lower precision (fp16, bf16)
# Pytorch eager computes bf16/fp16 by upcasting inputs to fp32 and downcasting after
# For multiple, fused pointwise nodes, inductor will elide the intermediary upcasts and downcasts
# Typically this should be closer to fp64 ref numerics. However, it can be useful for debugging
# to emulate the eager numerics.
emulate_precision_casts = False

# warnings intended for PyTorch developers, disable for point releases
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings = is_fbcode() or is_nightly_or_source

# This pattern matches a special usage of scatter
# 1. It's applied to a constant tensor
# 2. The index tensor has size 1 in the scatter dimension
# Such pattern generates a sparse matrix when the const tensor is all-zero.
# We can lower this pattern to a pointwise kernel for more fusion opportunities
# and saving memory footprint.
optimize_scatter_upon_const_tensor = (
    os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1"
)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
# The multiprocessing start method to use for inductor workers in the codecache.
# Can be "subprocess" or "fork".
def decide_worker_start_method() -> str:
    """Return the multiprocessing start method for codecache workers.

    Honors the TORCHINDUCTOR_WORKER_START environment variable; otherwise
    defaults to "fork" for internal (fbcode) builds and "subprocess" for OSS.
    """
    # Inlined is_fbcode(): internal builds lack torch.version.git_version.
    internal_build = not hasattr(torch.version, "git_version")
    default_method = "fork" if internal_build else "subprocess"
    start_method = os.environ.get("TORCHINDUCTOR_WORKER_START", default_method)
    valid_methods = ("subprocess", "fork")
    assert start_method in valid_methods, f"Invalid start method: {start_method}"
    return start_method


worker_start_method = decide_worker_start_method()

# Flags to turn on all_reduce fusion. These 2 flags should be automatically turned
# on by DDP and should not be set by the users.
_fuse_ddp_communication = False
_fuse_ddp_bucket_size = 25

# Flag to control which fusion passes to apply. Functions in the list will
# be applied in order. There are two different fusion passes
# --"fuse_ddp_with_concat_op" and "fuse_ddp_with_coalesced_op". The default
# one is "fuse_ddp_with_concat_op". Users can also change this to a customized
# fusion function.
#
# The fusion currently does not support multiple DDP with different PG or
# data type. This feature will be added in the future PRs.
#
# "schedule_comm_wait" is used to delay the wait ops to maximize comm/comp
# overlapping. At this moment, this pass performs better than
# reorder_for_compute_comm_overlap_passes but we will add the logic of
# "schedule_comm_wait" in the future and remove the one here.
_fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [
    "fuse_ddp_with_concat_op",
    "schedule_comm_wait",
]

_micro_pipeline_tp: bool = False
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def decide_compile_threads() -> int:
    """Decide how many compile workers to use.

    Precedence:
    1. TORCHINDUCTOR_COMPILE_THREADS overrides everything. One may want to
       set this to 1 to disable async compiling and make pdb happy.
    2. win32 and fbcode builds are pinned to a single thread.
    3. Otherwise, use the number of available CPU cores, capped at 32.
    """
    env_threads = os.environ.get("TORCHINDUCTOR_COMPILE_THREADS")
    if env_threads is not None:
        return int(env_threads)
    # Inlined is_fbcode(): internal builds lack torch.version.git_version.
    running_internally = not hasattr(torch.version, "git_version")
    if sys.platform == "win32" or running_internally:
        return 1
    # sched_getaffinity respects CPU masks (e.g. containers) where available.
    if hasattr(os, "sched_getaffinity"):
        cpu_count = len(os.sched_getaffinity(0))
    else:
        cpu_count = os.cpu_count()
    assert cpu_count
    return min(32, cpu_count)


compile_threads = decide_compile_threads()

# gemm autotuning global cache dir
if is_fbcode():
    try:
        from libfb.py import parutil

        if __package__:
            global_cache_dir = parutil.get_dir_path(
                os.path.join(__package__.replace(".", os.sep), "fb/cache")
            )
        else:
            global_cache_dir = parutil.get_dir_path("fb/cache")
    except (ValueError, ModuleNotFoundError):
        global_cache_dir = None

else:
    global_cache_dir = None

# If kernel is fused, the name is generated from the origin node op names
# for larger kernels limit this
kernel_name_max_ops = 10

# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"

# Control if we will do padding for pointwise/reductions
comprehensive_padding = (
    os.environ.get("TORCHINDUCTOR_COMPREHENSIVE_PADDING", "1") == "1"
)
pad_channels_last = False

# Disable comprehensive padding on the CPU
disable_padding_cpu = True

# The width of comprehensive padding, in bytes.
# CUDA max memory transaction size is 128 bytes for a warp.
padding_alignment_bytes = 128

# Threshold on the minimum stride that will be padded.
#
# Don't align a too small stride since that causes too much memory increase.
# Pad too small stride may also cause perf loss. We may result in many tiny data blocks
# with gaps in between. That causes less coalesced GPU memory access!
#
# Initially we pick 320 as the threshold since for alignment=16,
# that results in at most 5% memory cost.
#
# But later on we raise the threshold to 1024 to avoid interfering with persistent reduction.
# Let's say an inner reduction has a row size 513. Inductor will generate
# persistent reduction code.
# If we do padding, the strides are not contiguous any more. Inductor
# uses a much smaller threshold for persistent reduction in this case and
# generates potentially worse non-persistent reduction code.
#
# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x.
# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms)
padding_stride_threshold = 1024

# Enable padding outputs, even if they would not be padded in eager mode.
# By default, we use the same strides as eager mode.
pad_outputs = False

# Whether to treat output of the backward graph as user visible.
# For user visible outputs, inductor will make sure the stride matches with eager.
bw_outputs_user_visible = True

# Whether to always use shape padding if it is enabled and possible
force_shape_pad: bool = False

# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"

# Mark the wrapper call in PyTorch profiler
profiler_mark_wrapper_call = False

# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for
# every intermediate for which we can correlate it with an intermediate
# from the original FX graph
generate_intermediate_hooks = False

# Populate traceback field on IRNode; good for debugging why origin_node is
# not populated, or finding out where an IRNode was constructed
debug_ir_traceback = False

# used for debugging to make sure config is properly set
_raise_error_for_testing = False

_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth = _profile_var != ""
profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var
# Specify a file where we print out the profiling results.
# None means we do not dump results to a file.
profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None)
# Switch to do_bench_using_profiling to exclude the CPU overheads
profile_bandwidth_with_do_bench_using_profiling = (
    os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1"
)


# TODO: remove later
disable_cpp_codegen = False


# Freezing will attempt to inline weights as constants in optimization
# and run constant folding and other optimizations on them. After freezing, weights
# can no longer be updated.
freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"

# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
# of potentially keeping multiple copies of weights.
freezing_discard_parameters: bool = False

# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
# should be run with this flag both on and off to make sure we have coverage.
allow_stack_allocation: bool = (
    os.environ.get("TORCHINDUCTOR_STACK_ALLOCATION", "1" if is_fbcode() else "0") == "1"
)

# Enables an alternate DSO interface (the "minimal ArrayRef interface") intended
# to maximize performance for use cases that it can accommodate at the expense of
# generality. In brief:
# - inputs and outputs are ArrayRefTensor<T> (note that strides are required, but the
#   tensor must be contiguous)
# - constant handling is unchanged because it is not a per-inference-iteration bottleneck
#
# When the DSO is generated in this mode, the usual interface will also be supported,
# but performance for that interface may be degraded.
use_minimal_arrayref_interface: bool = False

# decompose some memory bound matmul/bmm to mul
decompose_mem_bound_mm: bool = False

# assume_aligned_inputs means that we assume that inputs will be aligned; we generate
# code using this assumption, and clone tensors before use if they aren't aligned.
# In the common case, most inputs will be aligned.
assume_aligned_inputs: bool = False

# For the user-written Triton kernels compiled with the model, ignore the unsupported
# arguments passed to the @triton.autotune in the user's code; this is unsafe, as
# ignoring the unsupported args may lead to unexpected autotuning behavior: don't
# set unless you know what you're doing.
unsafe_ignore_unsupported_triton_autotune_args: bool = False

# When True, we will check in scheduler.py _codegen that there are no "loops"
# in the call stack; that is to say, the same frame multiple times. This
# ensures that a cProfile trace to this frame will be a straight line without
# any cycles.
check_stack_no_cycles_TESTING_ONLY: bool = False
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
# config specific to codegen/cpp.py
|
| 731 |
+
class cpp:
|
| 732 |
+
# set to torch.get_num_threads()
|
| 733 |
+
threads = -1
|
| 734 |
+
|
| 735 |
+
# Do not generate loops when the condition doesn't hold, like:
|
| 736 |
+
# for(long i0=4096; i0<4096; i0+=1)
|
| 737 |
+
no_redundant_loops = (
|
| 738 |
+
os.environ.get("TORCHINDUCTOR_CPP_NO_REDUNDANT_LOOPS", "1") == "1"
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
# Assume number of threads is dynamic, don't specialize thread number.
|
| 742 |
+
# Kernels don't recompile on thread number changes with this flag on.
|
| 743 |
+
# For single-threaded workload, turning it on would incur a slight
|
| 744 |
+
# performance degradation.
|
| 745 |
+
dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1"
|
| 746 |
+
|
| 747 |
+
simdlen: Optional[int] = None
|
| 748 |
+
min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096"))
|
| 749 |
+
cxx = (
|
| 750 |
+
None, # download gcc12 from conda-forge if conda is installed
|
| 751 |
+
# "g++-12",
|
| 752 |
+
# "g++-11",
|
| 753 |
+
# "g++-10",
|
| 754 |
+
# "clang++",
|
| 755 |
+
os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
|
| 756 |
+
# "g++.par",
|
| 757 |
+
)
|
| 758 |
+
# Allow kernel performance profiling via PyTorch profiler
|
| 759 |
+
enable_kernel_profile = (
|
| 760 |
+
os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1"
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
# enable weight prepacking to get a better performance; may lead to large memory footprint
|
| 764 |
+
weight_prepack = os.environ.get("TORCHINDUCTOR_CPP_WEIGHT_PREPACK", "1") == "1"
|
| 765 |
+
|
| 766 |
+
# Inject a bug into our relu implementation; useful for testing our repro
|
| 767 |
+
# extraction and minification functionality.
|
| 768 |
+
# Valid values: "compile_error", "runtime_error", "accuracy"
|
| 769 |
+
inject_relu_bug_TESTING_ONLY: Optional[str] = None
|
| 770 |
+
inject_log1p_bug_TESTING_ONLY: Optional[str] = None
|
| 771 |
+
|
| 772 |
+
# If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise,
|
| 773 |
+
# force usage as specified, without testing.
|
| 774 |
+
vec_isa_ok: Optional[bool] = None
|
| 775 |
+
|
| 776 |
+
# similar to config.triton.descriptive_names
|
| 777 |
+
descriptive_names = "original_aten"
|
| 778 |
+
|
| 779 |
+
# how many nodes to allow into a single horizontal fusion
|
| 780 |
+
max_horizontal_fusion_size = int(
|
| 781 |
+
os.environ.get("TORCHINDUCTOR_CPP_MAX_HORIZONTAL_FUSION_SIZE", "16")
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
# Make scatter_reduce fallback when reduce is sum to avoid performance regression
|
| 785 |
+
# using atomic_add.
|
| 786 |
+
fallback_scatter_reduce_sum = (
|
| 787 |
+
os.environ.get("TORCHINDUCTOR_CPP_FALLBACK_SCATTER_REDUCE_SUM", "1") == "1"
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
# Use funsafe-math-optimizations when compiling
|
| 791 |
+
enable_unsafe_math_opt_flag = (
|
| 792 |
+
os.environ.get("TORCHINDUCTOR_CPP_ENABLE_UNSAFE_MATH_OPT_FLAG", "0") == "1"
|
| 793 |
+
)
|
| 794 |
+
|
| 795 |
+
# Use ffp-contract when compiling
|
| 796 |
+
enable_floating_point_contract_flag = (
|
| 797 |
+
os.environ.get("TORCHINDUCTOR_CPP_ENABLE_FLOATING_POINT_CONTRACT_FLAG", "0")
|
| 798 |
+
== "1"
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
# Disable the tiling select heuristic
|
| 802 |
+
enable_tiling_heuristics = (
|
| 803 |
+
os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1"
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
# Maximal allowed number of slices on K-dim for a GEMM kernel. This controls
|
| 807 |
+
# the maximal parallelism of K-slicing. Since K-slicing requires extra thread
|
| 808 |
+
# synchronization and buffers, the maximal number of slices is limited to
|
| 809 |
+
# mitigate the sync overhead and memory usage.
|
| 810 |
+
# When set to 0, the number of slices is unlimited.
|
| 811 |
+
gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1"))
|
| 812 |
+
|
| 813 |
+
# For perf tuning and debugging purpose, configure the pre-defined cache blocking for
|
| 814 |
+
# MxNxK dims respectively. The blockings are separated by comma and the unit is
|
| 815 |
+
# the number of register blocks.
|
| 816 |
+
# For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively.
|
| 817 |
+
gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None)
|
| 818 |
+
|
| 819 |
+
# For perf tuning and debugging purpose, configure the pre-defined thread blocking factors for
|
| 820 |
+
# MxNxK dims respectively. The factors are separated by comma and their product
|
| 821 |
+
# should be the same as the total number of threads.
|
| 822 |
+
# For example, if the total number of threads is 56, "7,4,2" means the work is
|
| 823 |
+
# decomposed into 7x4x2 thread blocks along MxNxK of a GEMM.
|
| 824 |
+
gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None)
|
| 825 |
+
|
| 826 |
+
# Whether to enable masked vectorization for the tail_loop.
|
| 827 |
+
enable_loop_tail_vec = True
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
# config specific to codegen/triton.py
|
| 831 |
+
class triton:
|
| 832 |
+
# Use cudagraphs on output code
|
| 833 |
+
cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1"
|
| 834 |
+
|
| 835 |
+
# Use cudagraph trees for memory pooling if `cudagraphs` is True
|
| 836 |
+
cudagraph_trees = True
|
| 837 |
+
|
| 838 |
+
# Should we skip cudagraphing graphs with dynamic shape inputs
|
| 839 |
+
# If False, we will re-record a graph for each unique set of shape inputs
|
| 840 |
+
cudagraph_skip_dynamic_graphs = False
|
| 841 |
+
|
| 842 |
+
# assertions not on the fast path, steady state
|
| 843 |
+
slow_path_cudagraph_asserts = True
|
| 844 |
+
|
| 845 |
+
# TODO - need to debug why this prevents cleanup
|
| 846 |
+
cudagraph_trees_history_recording = False
|
| 847 |
+
|
| 848 |
+
# Enable cudagraph support for mutated inputs from prior cudagraph pool
|
| 849 |
+
cudagraph_support_input_mutation = False if is_fbcode() else True
|
| 850 |
+
|
| 851 |
+
# Maximal number of allowed cudagraph re-record for a function and
|
| 852 |
+
# a cudagraph node due to static input tensor address changes or
|
| 853 |
+
# cudagraph managed tensor data pointer changed.
|
| 854 |
+
# i.e., allow num_recording <= cudagraph_unexpected_rerecord_limit
|
| 855 |
+
# note: we are conservative here and choose a large limit.
|
| 856 |
+
cudagraph_unexpected_rerecord_limit = 128
|
| 857 |
+
|
| 858 |
+
# Warn loudly when the number of cudagraphs due to dynamic shape
|
| 859 |
+
# exceeds this limit
|
| 860 |
+
cudagraph_dynamic_shape_warn_limit: Optional[int] = 50
|
| 861 |
+
|
| 862 |
+
# synchronize after cudagraph invocation
|
| 863 |
+
force_cudagraph_sync = False
|
| 864 |
+
|
| 865 |
+
# always run cudagraphs in the eager warmup stage
|
| 866 |
+
# instead of recording and executing cudagraphs
|
| 867 |
+
force_cudagraphs_warmup = False
|
| 868 |
+
|
| 869 |
+
# assertions on the fast path
|
| 870 |
+
fast_path_cudagraph_asserts = False
|
| 871 |
+
|
| 872 |
+
# skip warmup for cudagraph trees
|
| 873 |
+
skip_cudagraph_warmup = False
|
| 874 |
+
|
| 875 |
+
# Synchronize before and after every compiled graph.
|
| 876 |
+
debug_sync_graph = False
|
| 877 |
+
|
| 878 |
+
# Synchronize after every kernel launch, to help pinpoint bugs
|
| 879 |
+
debug_sync_kernel = False
|
| 880 |
+
|
| 881 |
+
# Always load full blocks (rather than broadcasting inside the block)
|
| 882 |
+
dense_indexing = False
|
| 883 |
+
|
| 884 |
+
# limit tiling dimensions
|
| 885 |
+
max_tiles = 2
|
| 886 |
+
|
| 887 |
+
# Prefer higher dimensional tilings. This simplifies indexing expressions, making
|
| 888 |
+
# it easier to identify block pointers.
|
| 889 |
+
prefer_nd_tiling: bool = False
|
| 890 |
+
|
| 891 |
+
# use triton.autotune for pointwise ops with complex layouts
|
| 892 |
+
# this should only be disabled for debugging/testing
|
| 893 |
+
autotune_pointwise = True
|
| 894 |
+
|
| 895 |
+
# max autotune gemm with cublasLt
|
| 896 |
+
autotune_cublasLt = True
|
| 897 |
+
|
| 898 |
+
# Tune the generated Triton kernels at compile time instead of first time they run
|
| 899 |
+
autotune_at_compile_time = False
|
| 900 |
+
|
| 901 |
+
# should we stop a fusion to allow better tiling?
|
| 902 |
+
tiling_prevents_pointwise_fusion = True
|
| 903 |
+
tiling_prevents_reduction_fusion = True
|
| 904 |
+
|
| 905 |
+
# should we give different names to kernels
|
| 906 |
+
# Note: This is orthogonal to descriptive_names - this is deciding whether
|
| 907 |
+
# our triton kernel names should all be `triton_` (to maximize caching) or
|
| 908 |
+
# whether they should be unique.
|
| 909 |
+
unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"
|
| 910 |
+
|
| 911 |
+
# should we put op names in kernel names
|
| 912 |
+
# False: No special names (just triton__1, triton__2, etc.)
|
| 913 |
+
# "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
|
| 914 |
+
# "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
|
| 915 |
+
# "inductor_node": Maps to the node name in the FX graph passed to Inductor
|
| 916 |
+
descriptive_names = "original_aten"
|
| 917 |
+
|
| 918 |
+
# use alternate codegen for smaller reductions
|
| 919 |
+
persistent_reductions = (
|
| 920 |
+
os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
# 0/False: disable
|
| 924 |
+
# 1/True: enable, use tuning to pick between different subkernels
|
| 925 |
+
# 2: enable, force using persistent reduction (for debugging)
|
| 926 |
+
# 3: enable, force using non-persistent reduction (for debugging)
|
| 927 |
+
multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0"))
|
| 928 |
+
|
| 929 |
+
# hint to Triton when arguments are divisible by 16
|
| 930 |
+
divisible_by_16 = True
|
| 931 |
+
|
| 932 |
+
# Minimum RBLOCK to be used for a TritonSplitScanKernel
|
| 933 |
+
# NOTE: This also indirectly controls the size of workspace buffer required
|
| 934 |
+
min_split_scan_rblock = 256
|
| 935 |
+
|
| 936 |
+
# Store the generated cubin files for cpp wrapper code to load
|
| 937 |
+
store_cubin = False
|
| 938 |
+
|
| 939 |
+
# the max number of spills we allow for the configs we benchmark.
|
| 940 |
+
# Setting this to 0 means we skip a config if it spills even a single
|
| 941 |
+
# register.
|
| 942 |
+
# Setting it to a larger value allows a config spilling a small amount
|
| 943 |
+
# of registers being benchmarked.
|
| 944 |
+
#
|
| 945 |
+
# NOTE: triton will always report >0 register spills for kernels using sin/cos.
|
| 946 |
+
# (check this issue https://github.com/openai/triton/issues/1756 )
|
| 947 |
+
# So far we see a fixed 8 spilled registers for kernels using sin/cos.
|
| 948 |
+
# Raise the threshold to 16 to be safe.
|
| 949 |
+
# We should revisit this once we understand more of the source of register spills.
|
| 950 |
+
spill_threshold: int = 16
|
| 951 |
+
|
| 952 |
+
# Generate code containing the newer tl.make_block_ptr() API for loads/store
|
| 953 |
+
use_block_ptr = False
|
| 954 |
+
|
| 955 |
+
# Inject a bug into our relu implementation; useful for testing our repro
|
| 956 |
+
# extraction and minification functionality.
|
| 957 |
+
# Valid values: "compile_error", "runtime_error", "accuracy"
|
| 958 |
+
inject_relu_bug_TESTING_ONLY: Optional[str] = None
|
| 959 |
+
|
| 960 |
+
# Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
|
| 961 |
+
codegen_upcast_to_fp32 = True
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
class aot_inductor:
|
| 965 |
+
# AOTInductor output path
|
| 966 |
+
# If an absolute path is specified, the generated lib files will be stored under the directory;
|
| 967 |
+
# If a relative path is specified, it will be used as a subdirectory under the default caching path;
|
| 968 |
+
# If not specified, a temp directory will be created under the default caching path.
|
| 969 |
+
# If the specified path contains something like "model.so", the sub-string will be used
|
| 970 |
+
# to name the generated library.
|
| 971 |
+
output_path = ""
|
| 972 |
+
|
| 973 |
+
debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1"
|
| 974 |
+
|
| 975 |
+
debug_dump_consts_bin: bool = (
|
| 976 |
+
os.environ.get("AOT_INDUCTOR_DEBUG_DUMP_CONSTS_BIN", "0") == "1"
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
# option for debug printing/saving for intermediate tensor values for aot inductor
|
| 980 |
+
# 0: disable debug dumping
|
| 981 |
+
# 1: enable saving intermediate tensor values
|
| 982 |
+
# 2: enable printing intermediate tensor values
|
| 983 |
+
debug_intermediate_value_printer = os.environ.get(
|
| 984 |
+
"AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0"
|
| 985 |
+
)
|
| 986 |
+
|
| 987 |
+
# filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2
|
| 988 |
+
filtered_kernel_names = os.environ.get(
|
| 989 |
+
"AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None
|
| 990 |
+
)
|
| 991 |
+
|
| 992 |
+
# Serialized tree spec for flattening inputs
|
| 993 |
+
serialized_in_spec = ""
|
| 994 |
+
|
| 995 |
+
# Serialized tree spec for flattening outputs
|
| 996 |
+
serialized_out_spec = ""
|
| 997 |
+
|
| 998 |
+
# flag to decide whether to create a submodule for constant graph.
|
| 999 |
+
use_runtime_constant_folding: bool = False
|
| 1000 |
+
|
| 1001 |
+
# flag to force weight to be appened to the shared library and mmaped by the runtime
|
| 1002 |
+
# rather than embedded into the data section. Needed to support 1B+ parameter models
|
| 1003 |
+
force_mmap_weights: bool = False
|
| 1004 |
+
|
| 1005 |
+
package: bool = False
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
class cuda:
|
| 1009 |
+
# CUDA arch to use for CUDA template kernel compilation.
|
| 1010 |
+
# e.g. "70", "75", "80", "90", etc.
|
| 1011 |
+
# When arch is None, Inductor uses torch.cuda.get_device_capability(0).
|
| 1012 |
+
arch: Optional[str] = None
|
| 1013 |
+
|
| 1014 |
+
# CUDA version to use for CUDA template kernel compilation.
|
| 1015 |
+
# e.g. "11.4", "12.1", etc.
|
| 1016 |
+
# When version is None, Inductor uses torch.version.cuda.
|
| 1017 |
+
version: Optional[str] = None
|
| 1018 |
+
|
| 1019 |
+
# Optimization level for the host compiler.
|
| 1020 |
+
compile_opt_level = "-O1"
|
| 1021 |
+
|
| 1022 |
+
# Whether to enable device LTO (link-time-optimization).
|
| 1023 |
+
enable_cuda_lto = False
|
| 1024 |
+
|
| 1025 |
+
# Whether to keep intermediate files dring compilation.
|
| 1026 |
+
enable_ptxas_info = False
|
| 1027 |
+
|
| 1028 |
+
# Whether to enable debug info, e.g. line number, cutlass debug info.
|
| 1029 |
+
enable_debug_info = False
|
| 1030 |
+
|
| 1031 |
+
# Whether to use fast math.
|
| 1032 |
+
use_fast_math = False
|
| 1033 |
+
|
| 1034 |
+
# Path to the CUTLASS repo root directory.
|
| 1035 |
+
# The default path only works under PyTorch local development environment.
|
| 1036 |
+
cutlass_dir = os.environ.get(
|
| 1037 |
+
"TORCHINDUCTOR_CUTLASS_DIR",
|
| 1038 |
+
os.path.abspath(
|
| 1039 |
+
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
|
| 1040 |
+
),
|
| 1041 |
+
)
|
| 1042 |
+
|
| 1043 |
+
# Configures the maximum number of CUTLASS configs to profile in max_autotune.
|
| 1044 |
+
# By default it's None, so that all CUTLASS configs are tuned.
|
| 1045 |
+
# This is mainly used to reduce test time in CI.
|
| 1046 |
+
cutlass_max_profiling_configs: Optional[int] = None
|
| 1047 |
+
|
| 1048 |
+
# Path to CUDA NVCC.
|
| 1049 |
+
# NVCC search order:
|
| 1050 |
+
# 1) cuda_cxx set in this config
|
| 1051 |
+
# 2) CUDACXX environment variable
|
| 1052 |
+
# 3) CUDA_HOME environment variable
|
| 1053 |
+
# 4) default system search PATH.
|
| 1054 |
+
cuda_cxx: Optional[str] = None
|
| 1055 |
+
|
| 1056 |
+
# Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops.
|
| 1057 |
+
cutlass_backend_min_gemm_size: int = 1
|
| 1058 |
+
|
| 1059 |
+
# enable generation of inline standalone runner in CUDA CPP generated code
|
| 1060 |
+
# which allows to compile the generated code into a standalone executable.
|
| 1061 |
+
generate_test_runner: bool = (
|
| 1062 |
+
os.environ.get("INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE", "1") == "1"
|
| 1063 |
+
)
|
| 1064 |
+
|
| 1065 |
+
# Keep only Cutlass op configs which contain this regular expression pattern
|
| 1066 |
+
# Set this to "warpspecialized_cooperative_epi_tma" to enable only SM90 TMA Cutlass Kernels for large GEMMs
|
| 1067 |
+
cutlass_op_allowlist_regex: Optional[str] = None
|
| 1068 |
+
|
| 1069 |
+
# Note: Names of Cutlass ops names can be obtained by calling
|
| 1070 |
+
# op.configuration_name() on a Cutlass op instance, for example those
|
| 1071 |
+
# returned from cutlass_utils.gen_ops() or the op argument passed to
|
| 1072 |
+
# CUTLASSGemmTemplate.render(...)
|
| 1073 |
+
|
| 1074 |
+
# Filter Cutlass configs which contain this regular expression pattern
|
| 1075 |
+
# Set this to "pingpong" to avoid numerical issues
|
| 1076 |
+
# caused by the op ordering of the "pingpong" memory access
|
| 1077 |
+
# pattern used by some Cutlass Kernels.
|
| 1078 |
+
cutlass_op_denylist_regex: Optional[str] = "pingpong"
|
| 1079 |
+
|
| 1080 |
+
|
| 1081 |
+
class rocm:
|
| 1082 |
+
# Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"].
|
| 1083 |
+
# If empty, the `native` arch is used
|
| 1084 |
+
arch: List[str] = []
|
| 1085 |
+
|
| 1086 |
+
# Enable the CK backend for CDNA2 and CDNA3 only (for now)
|
| 1087 |
+
# Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
|
| 1088 |
+
ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
|
| 1089 |
+
|
| 1090 |
+
# Optimization level, use to balance compilation speed and runtime performance
|
| 1091 |
+
compile_opt_level = "-O2"
|
| 1092 |
+
|
| 1093 |
+
# Flag to keep debug information in compiled objects
|
| 1094 |
+
is_debug = False
|
| 1095 |
+
|
| 1096 |
+
# Flag to keep intermediate files (assembly listings, preprocessed sources, etc.)
|
| 1097 |
+
save_temps = False
|
| 1098 |
+
|
| 1099 |
+
# Flag to add `-ffast-math`` to compile flags
|
| 1100 |
+
use_fast_math = True
|
| 1101 |
+
|
| 1102 |
+
# Flag to add `-fgpu-flush-denormals-to-zero` to compile flags
|
| 1103 |
+
flush_denormals = True
|
| 1104 |
+
|
| 1105 |
+
# Flag to print register and LDS usage during compilation
|
| 1106 |
+
print_kernel_resource_usage = False
|
| 1107 |
+
|
| 1108 |
+
# Path to ROCm installation, if None, use env variable ROCM_HOME
|
| 1109 |
+
rocm_home: Optional[str] = None
|
| 1110 |
+
|
| 1111 |
+
# Path to Composable Kernel library.
|
| 1112 |
+
# Install with `pip install git+https://github.com/rocm/composable_kernel@develop`.
|
| 1113 |
+
ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR")
|
| 1114 |
+
|
| 1115 |
+
# Number of op instance choices to trade off between runtime perf and compilation time
|
| 1116 |
+
n_max_profiling_configs: Optional[int] = None
|
| 1117 |
+
|
| 1118 |
+
# Flag to use a short list of CK instances which perform well across a variety of shapes.
|
| 1119 |
+
# Currently RCR and F16 only
|
| 1120 |
+
use_preselected_instances: bool = False
|
| 1121 |
+
|
| 1122 |
+
|
| 1123 |
+
# Backend to use for CPU codegen either "cpp" or "halide" (experimental)
|
| 1124 |
+
cpu_backend = "cpp"
|
| 1125 |
+
|
| 1126 |
+
# Backend to use for CUDA codegen either "triton" or "halide" (experimental)
|
| 1127 |
+
cuda_backend = "triton"
|
| 1128 |
+
|
| 1129 |
+
|
| 1130 |
+
class halide:
|
| 1131 |
+
# Base halide target to use for CPU devices
|
| 1132 |
+
cpu_target = "host"
|
| 1133 |
+
|
| 1134 |
+
# Base halide target to use for CUDA devices
|
| 1135 |
+
gpu_target = "host-cuda"
|
| 1136 |
+
|
| 1137 |
+
# Halide autoscheduler to use, choices are:
|
| 1138 |
+
# "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
|
| 1139 |
+
scheduler_cuda = "Anderson2021"
|
| 1140 |
+
scheduler_cpu = "Adams2019"
|
| 1141 |
+
|
| 1142 |
+
# Controls `no_asserts` flag passed to Halide target (warning: can false positive)
|
| 1143 |
+
asserts = False
|
| 1144 |
+
|
| 1145 |
+
# Controls `debug` flag passed to Halide target
|
| 1146 |
+
debug = False
|
| 1147 |
+
|
| 1148 |
+
# Enable (or fallback on) scan kernels such as cumsum
|
| 1149 |
+
# Halide autoschedulers struggle with these kernels
|
| 1150 |
+
scan_kernels = False
|
| 1151 |
+
|
| 1152 |
+
|
| 1153 |
+
# create a directory containing lots of debug information
|
| 1154 |
+
class trace:
|
| 1155 |
+
# master switch for all debugging flags below
|
| 1156 |
+
enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
|
| 1157 |
+
|
| 1158 |
+
# Save debug information to a temporary directory
|
| 1159 |
+
# If not specified, a temp directory will be created by system
|
| 1160 |
+
debug_dir: Optional[str] = None
|
| 1161 |
+
|
| 1162 |
+
# Save python logger call >=logging.DEBUG
|
| 1163 |
+
debug_log = False
|
| 1164 |
+
|
| 1165 |
+
# Save python logger call >=logging.INFO
|
| 1166 |
+
info_log = False
|
| 1167 |
+
|
| 1168 |
+
# Save input FX graph (post decomps, pre optimization)
|
| 1169 |
+
fx_graph = True
|
| 1170 |
+
|
| 1171 |
+
# Save FX graph after transformations
|
| 1172 |
+
fx_graph_transformed = True
|
| 1173 |
+
|
| 1174 |
+
# Save TorchInductor IR before fusion pass
|
| 1175 |
+
ir_pre_fusion = True
|
| 1176 |
+
|
| 1177 |
+
# Save TorchInductor IR after fusion pass
|
| 1178 |
+
ir_post_fusion = True
|
| 1179 |
+
|
| 1180 |
+
# Copy generated code to trace dir
|
| 1181 |
+
output_code = True
|
| 1182 |
+
|
| 1183 |
+
# SVG figure showing post-fusion graph
|
| 1184 |
+
graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1"
|
| 1185 |
+
|
| 1186 |
+
# SVG figure showing fx with fusion
|
| 1187 |
+
draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1"
|
| 1188 |
+
|
| 1189 |
+
# We draw our fx graphs with the "record" shape attribute by default.
|
| 1190 |
+
# Sometimes, when the graph is very complex, we may hit dot errors like below:
|
| 1191 |
+
# "flat edge between adjacent nodes one of which has a record shape -
|
| 1192 |
+
# replace records with HTML-like labels"
|
| 1193 |
+
# and thus fail to generate a graph. So, let's give the user an option
|
| 1194 |
+
# to specify the shape attribute for the dot graph. For example, passing
|
| 1195 |
+
# INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like lables
|
| 1196 |
+
# to workaround the above failure.
|
| 1197 |
+
dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None)
|
| 1198 |
+
|
| 1199 |
+
# If not None, this is the URL that saves the SVG files of the input/output
|
| 1200 |
+
# graph of each pass that changed the graph
|
| 1201 |
+
# The nodes that are being transformed in each pass will be colored in yellow
|
| 1202 |
+
# URL only supports local directory for now
|
| 1203 |
+
log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None)
|
| 1204 |
+
|
| 1205 |
+
# Store cProfile (see snakeviz to view)
|
| 1206 |
+
compile_profile = False
|
| 1207 |
+
|
| 1208 |
+
# Upload the .tar.gz file
|
| 1209 |
+
# Needs to be overriden based on specific environment needs
|
| 1210 |
+
upload_tar: Optional[Callable[[str], None]] = None
|
| 1211 |
+
|
| 1212 |
+
log_autotuning_results: bool = False
|
| 1213 |
+
|
| 1214 |
+
|
| 1215 |
+
_save_config_ignore = [
|
| 1216 |
+
# workaround: "Can't pickle <function ...>"
|
| 1217 |
+
"trace.upload_tar",
|
| 1218 |
+
"post_grad_custom_post_pass",
|
| 1219 |
+
"post_grad_custom_pre_pass",
|
| 1220 |
+
"joint_custom_pre_pass",
|
| 1221 |
+
"joint_custom_post_pass",
|
| 1222 |
+
"pre_grad_custom_pass",
|
| 1223 |
+
]
|
| 1224 |
+
|
| 1225 |
+
_cache_config_ignore_prefix = [
|
| 1226 |
+
# trace functions are not relevant to config caching
|
| 1227 |
+
"trace",
|
| 1228 |
+
# uses absolute path
|
| 1229 |
+
"cuda.cutlass_dir",
|
| 1230 |
+
# not relevant
|
| 1231 |
+
"compile_threads",
|
| 1232 |
+
]
|
| 1233 |
+
|
| 1234 |
+
if TYPE_CHECKING:
|
| 1235 |
+
from torch.utils._config_typing import * # noqa: F401, F403
|
| 1236 |
+
|
| 1237 |
+
from torch.utils._config_module import install_config_module
|
| 1238 |
+
|
| 1239 |
+
|
| 1240 |
+
# adds patch, save_config, etc
|
| 1241 |
+
install_config_module(sys.modules[__name__])
|
.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils._pytree as pytree
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
aten = torch.ops.aten
|
| 9 |
+
|
| 10 |
+
# We would like to split modules into two subgraphs for runtime weight updates to work correctly.
|
| 11 |
+
# The use case and more information could be found at:
|
| 12 |
+
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
|
| 13 |
+
META_TAG = "MODULE_TYPE"
|
| 14 |
+
MODULE_TAG = "_MAIN_MODULE"
|
| 15 |
+
CONST_MODULE_TAG = "_CONST_MODULE"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def replace_node_with_constant(
|
| 19 |
+
gm: torch.fx.GraphModule,
|
| 20 |
+
node: torch.fx.Node,
|
| 21 |
+
constant: torch.Tensor,
|
| 22 |
+
name: Optional[str] = None,
|
| 23 |
+
) -> None:
|
| 24 |
+
g = gm.graph
|
| 25 |
+
|
| 26 |
+
if name:
|
| 27 |
+
qualname = name
|
| 28 |
+
else:
|
| 29 |
+
if not hasattr(gm, "_frozen_param_count"):
|
| 30 |
+
gm._frozen_param_count = 0 # type: ignore[assignment]
|
| 31 |
+
i = gm._frozen_param_count
|
| 32 |
+
|
| 33 |
+
while True:
|
| 34 |
+
qualname = f"_frozen_param{i}"
|
| 35 |
+
if not hasattr(gm, qualname):
|
| 36 |
+
break
|
| 37 |
+
i += 1
|
| 38 |
+
|
| 39 |
+
gm._frozen_param_count = i + 1
|
| 40 |
+
|
| 41 |
+
with g.inserting_before(node):
|
| 42 |
+
new_input_node = g.create_node("get_attr", qualname, (), {})
|
| 43 |
+
node.replace_all_uses_with(new_input_node)
|
| 44 |
+
new_input_node.meta.update(node.meta)
|
| 45 |
+
g.erase_node(node)
|
| 46 |
+
|
| 47 |
+
# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
|
| 48 |
+
gm.register_buffer(qualname, constant)
|
| 49 |
+
setattr(gm, qualname, constant)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def is_const_source(
|
| 53 |
+
node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]
|
| 54 |
+
) -> bool:
|
| 55 |
+
return node.op == "get_attr" or (
|
| 56 |
+
node.op == "placeholder"
|
| 57 |
+
and lifted_constants is not None
|
| 58 |
+
and node.name in lifted_constants
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ConstantFolder(torch.fx.Interpreter):
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
gm: torch.fx.GraphModule,
|
| 66 |
+
skip_constructors: bool = False,
|
| 67 |
+
lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
|
| 68 |
+
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
|
| 69 |
+
) -> None:
|
| 70 |
+
super().__init__(gm)
|
| 71 |
+
self.node_replacements: Dict[torch.fx.Node, Any] = {}
|
| 72 |
+
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
|
| 73 |
+
self.unknown_value = object()
|
| 74 |
+
self.skip_constructors: bool = skip_constructors
|
| 75 |
+
|
| 76 |
+
# overwrite this to deallocate env values if their only remaining use
|
| 77 |
+
# is the output
|
| 78 |
+
self.user_to_last_uses = self.node_to_last_non_output_use()
|
| 79 |
+
self.lifted_constants = lifted_constants
|
| 80 |
+
|
| 81 |
+
def _support_dynamic_shape(self) -> bool:
|
| 82 |
+
# ConstantFolder not support dynamic shape now
|
| 83 |
+
return False
|
| 84 |
+
|
| 85 |
+
def _deduce_value(self, node: torch.fx.Node) -> Any:
|
| 86 |
+
return super().run_node(node)
|
| 87 |
+
|
| 88 |
+
def is_impure(self, node: torch.fx.node.Node) -> bool:
|
| 89 |
+
if (
|
| 90 |
+
node.target == torch.ops.prims.convert_element_type.default
|
| 91 |
+
and is_const_source(node.args[0], self.lifted_constants) # type: ignore[arg-type]
|
| 92 |
+
and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr]
|
| 93 |
+
and node.args[1] == torch.bfloat16
|
| 94 |
+
):
|
| 95 |
+
# For int8_weight -> dq -> bf16_weight
|
| 96 |
+
return True
|
| 97 |
+
if node.target in [
|
| 98 |
+
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
| 99 |
+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
| 100 |
+
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
| 101 |
+
]:
|
| 102 |
+
# For the pattern fp32_weight -> q -> dq
|
| 103 |
+
# We only folding fp32_weight -> q
|
| 104 |
+
# int8_weight and leave dq in graph to be fused
|
| 105 |
+
return True
|
| 106 |
+
return False
|
| 107 |
+
|
| 108 |
+
def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
|
| 109 |
+
last_non_output_use = collections.defaultdict(list)
|
| 110 |
+
seen_uses = set()
|
| 111 |
+
output_node = next(iter(reversed(self.module.graph.nodes)))
|
| 112 |
+
|
| 113 |
+
for node in reversed(self.module.graph.nodes):
|
| 114 |
+
if node.target == "output":
|
| 115 |
+
continue
|
| 116 |
+
|
| 117 |
+
def add_use(inp: torch.fx.Node) -> None:
|
| 118 |
+
if inp in seen_uses:
|
| 119 |
+
return
|
| 120 |
+
|
| 121 |
+
seen_uses.add(inp)
|
| 122 |
+
last_non_output_use[node].append(inp)
|
| 123 |
+
|
| 124 |
+
# In-place is fine since we don't mutate
|
| 125 |
+
pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))
|
| 126 |
+
|
| 127 |
+
# if this node is only used in output, we want to gc it right away
|
| 128 |
+
if len(node.users) == 1 and output_node in node.users:
|
| 129 |
+
last_non_output_use[node].append(node)
|
| 130 |
+
|
| 131 |
+
return last_non_output_use
|
| 132 |
+
|
| 133 |
+
def run_node(self, node: torch.fx.Node) -> Any:
|
| 134 |
+
if node.target == "output":
|
| 135 |
+
# because we remove nodes from env on last non output use,
|
| 136 |
+
# re-define them now or we'll get error in interpreter
|
| 137 |
+
def set_env(arg: torch.fx.Node) -> None:
|
| 138 |
+
self.env[arg] = self.unknown_value
|
| 139 |
+
|
| 140 |
+
# In-place is fine since we don't mutate
|
| 141 |
+
pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
|
| 142 |
+
return super().run_node(node)
|
| 143 |
+
|
| 144 |
+
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
| 145 |
+
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
|
| 146 |
+
|
| 147 |
+
# We need to do this weird thing because in cases where flattened_inputs
|
| 148 |
+
# contains a ScriptObject, equality checking results in a type error if
|
| 149 |
+
# the types are different.
|
| 150 |
+
if any(
|
| 151 |
+
type(self.unknown_value) == type(input_) and self.unknown_value == input_
|
| 152 |
+
for input_ in flattened_inputs
|
| 153 |
+
):
|
| 154 |
+
return self.unknown_value
|
| 155 |
+
|
| 156 |
+
# TODO - fix errors with this
|
| 157 |
+
if (
|
| 158 |
+
node.op == "call_function"
|
| 159 |
+
and node.target == aten._efficientzerotensor.default
|
| 160 |
+
):
|
| 161 |
+
return self.unknown_value
|
| 162 |
+
|
| 163 |
+
# TODO - constant folding triton kernel returns the inputs -- fix this
|
| 164 |
+
if (
|
| 165 |
+
node.op == "call_function"
|
| 166 |
+
and node.name == "triton_kernel_wrapper_functional_proxy"
|
| 167 |
+
):
|
| 168 |
+
return self.unknown_value
|
| 169 |
+
|
| 170 |
+
# skip constructors, since inductor generates optimal code for them already
|
| 171 |
+
# and turning into tensor would result in an additional global memory read
|
| 172 |
+
# TODO - more complicated strategy
|
| 173 |
+
if (
|
| 174 |
+
self.skip_constructors
|
| 175 |
+
and not is_const_source(node, self.lifted_constants)
|
| 176 |
+
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
|
| 177 |
+
):
|
| 178 |
+
return self.unknown_value
|
| 179 |
+
|
| 180 |
+
# All mutations should either be removed or on inputs which we did not make constant
|
| 181 |
+
if (
|
| 182 |
+
isinstance(node.target, torch._ops.OpOverload)
|
| 183 |
+
and torch.Tag.nondeterministic_seeded in node.target.tags
|
| 184 |
+
):
|
| 185 |
+
return self.unknown_value
|
| 186 |
+
|
| 187 |
+
out = self._deduce_value(node)
|
| 188 |
+
if out == self.unknown_value:
|
| 189 |
+
return self.unknown_value
|
| 190 |
+
|
| 191 |
+
if not is_const_source(node, self.lifted_constants) and isinstance(
|
| 192 |
+
out, torch.Tensor
|
| 193 |
+
):
|
| 194 |
+
if out.device.type == "meta":
|
| 195 |
+
return out
|
| 196 |
+
|
| 197 |
+
if not self.insertable_tensor_check(out):
|
| 198 |
+
return out
|
| 199 |
+
|
| 200 |
+
if self.is_impure(node):
|
| 201 |
+
return self.unknown_value
|
| 202 |
+
|
| 203 |
+
self.add_node_replacement(node, out)
|
| 204 |
+
|
| 205 |
+
flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
|
| 206 |
+
|
| 207 |
+
for n in flattened_node_inps:
|
| 208 |
+
if not isinstance(n, torch.fx.Node):
|
| 209 |
+
continue
|
| 210 |
+
|
| 211 |
+
self.replaced_uses[n] += 1
|
| 212 |
+
|
| 213 |
+
for to_delete in self.user_to_last_uses.get(node, []):
|
| 214 |
+
if self.replaced_uses[to_delete] == len(to_delete.users):
|
| 215 |
+
self.node_replacements.pop(to_delete, None)
|
| 216 |
+
|
| 217 |
+
return out
|
| 218 |
+
|
| 219 |
+
def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
|
| 220 |
+
return True
|
| 221 |
+
|
| 222 |
+
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
|
| 223 |
+
self.node_replacements[node] = tensor
|
| 224 |
+
|
| 225 |
+
def run(self) -> Any: # type: ignore[override]
|
| 226 |
+
env: Dict[torch.fx.Node, Any] = {}
|
| 227 |
+
self.insert_placerholder_values(env)
|
| 228 |
+
return super().run(initial_env=env)
|
| 229 |
+
|
| 230 |
+
def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
|
| 231 |
+
for n in self.module.graph.find_nodes(op="placeholder"):
|
| 232 |
+
if self.lifted_constants is not None and n.name in self.lifted_constants:
|
| 233 |
+
env[n] = self.lifted_constants[n.name]
|
| 234 |
+
else:
|
| 235 |
+
env[n] = self.unknown_value # type: ignore[assignment]
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def constant_fold(
|
| 239 |
+
gm: torch.fx.GraphModule,
|
| 240 |
+
constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
|
| 241 |
+
) -> None:
|
| 242 |
+
with torch.utils._python_dispatch._disable_current_modes():
|
| 243 |
+
cf = ConstantFolder(gm, skip_constructors=True)
|
| 244 |
+
cf.run()
|
| 245 |
+
|
| 246 |
+
for node, constant in cf.node_replacements.items():
|
| 247 |
+
if constraint_fn is not None and not constraint_fn(node):
|
| 248 |
+
continue
|
| 249 |
+
replace_node_with_constant(gm, node, constant)
|
| 250 |
+
|
| 251 |
+
erased_params = []
|
| 252 |
+
for node in gm.graph.find_nodes(op="get_attr"):
|
| 253 |
+
if len(node.users) == 0:
|
| 254 |
+
if hasattr(gm, node.target):
|
| 255 |
+
delattr(gm, node.target)
|
| 256 |
+
erased_params.append(node)
|
| 257 |
+
|
| 258 |
+
for node in erased_params:
|
| 259 |
+
gm.graph.erase_node(node)
|
| 260 |
+
|
| 261 |
+
gm.graph.eliminate_dead_code()
|
| 262 |
+
gm.graph.lint()
|
| 263 |
+
gm.recompile()
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def constant_graph_tag(
|
| 267 |
+
gm: torch.fx.GraphModule,
|
| 268 |
+
lifted_constants: Optional[Dict[str, Any]],
|
| 269 |
+
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
|
| 270 |
+
) -> None:
|
| 271 |
+
with torch.utils._python_dispatch._disable_current_modes():
|
| 272 |
+
cf = ConstantFolder(
|
| 273 |
+
gm, skip_constructors=True, lifted_constants=lifted_constants
|
| 274 |
+
)
|
| 275 |
+
cf.run()
|
| 276 |
+
|
| 277 |
+
for node in gm.graph.nodes:
|
| 278 |
+
if skip_folding_node_fn is not None and skip_folding_node_fn(node):
|
| 279 |
+
node.meta[META_TAG] = MODULE_TAG
|
| 280 |
+
continue
|
| 281 |
+
if (
|
| 282 |
+
is_const_source(node, lifted_constants)
|
| 283 |
+
or node in cf.node_replacements
|
| 284 |
+
or node in cf.replaced_uses
|
| 285 |
+
):
|
| 286 |
+
node.meta[META_TAG] = CONST_MODULE_TAG
|
| 287 |
+
else:
|
| 288 |
+
node.meta[META_TAG] = MODULE_TAG
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def run_and_get_constant_graph(
|
| 292 |
+
gm: torch.fx.GraphModule,
|
| 293 |
+
lifted_constants: Optional[Dict[str, Any]],
|
| 294 |
+
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
|
| 295 |
+
) -> Tuple[torch.fx.GraphModule, Tuple[torch.Tensor, ...]]:
|
| 296 |
+
"""
|
| 297 |
+
Construct a GraphModule which corresponds to the part which could be
|
| 298 |
+
constant folded in provided gm.
|
| 299 |
+
"""
|
| 300 |
+
|
| 301 |
+
constant_graph_tag(gm, lifted_constants, skip_folding_node_fn)
|
| 302 |
+
|
| 303 |
+
def untag(node: torch.fx.Node) -> bool:
|
| 304 |
+
used_to_fold = False
|
| 305 |
+
for u in node.users:
|
| 306 |
+
if u.meta[META_TAG] == CONST_MODULE_TAG:
|
| 307 |
+
used_to_fold = True
|
| 308 |
+
break
|
| 309 |
+
if not used_to_fold:
|
| 310 |
+
node.meta[META_TAG] = MODULE_TAG
|
| 311 |
+
return used_to_fold
|
| 312 |
+
|
| 313 |
+
const_args = []
|
| 314 |
+
if lifted_constants is not None:
|
| 315 |
+
placeholders = list(gm.graph.find_nodes(op="placeholder"))
|
| 316 |
+
for node in placeholders:
|
| 317 |
+
if node.meta[META_TAG] == MODULE_TAG:
|
| 318 |
+
continue
|
| 319 |
+
if untag(node):
|
| 320 |
+
const_args.append(lifted_constants[node.name])
|
| 321 |
+
|
| 322 |
+
# We rewrite the tags, if it's a constant being directly consumed, without
|
| 323 |
+
# any folding opportunity, we keep it in main gm.
|
| 324 |
+
for node in gm.graph.find_nodes(op="get_attr"):
|
| 325 |
+
untag(node)
|
| 326 |
+
|
| 327 |
+
new_graph = torch.fx.Graph()
|
| 328 |
+
|
| 329 |
+
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
|
| 330 |
+
output_nodes = []
|
| 331 |
+
for node in gm.graph.nodes:
|
| 332 |
+
if node.meta[META_TAG] == MODULE_TAG:
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
|
| 336 |
+
node_remapping[node] = new_node
|
| 337 |
+
|
| 338 |
+
for user in node.users:
|
| 339 |
+
if user.meta[META_TAG] == MODULE_TAG:
|
| 340 |
+
output_nodes.append(new_node)
|
| 341 |
+
break
|
| 342 |
+
|
| 343 |
+
new_graph.output(tuple(output_nodes))
|
| 344 |
+
new_graph.lint()
|
| 345 |
+
new_gm = torch.fx.GraphModule(gm, new_graph)
|
| 346 |
+
|
| 347 |
+
const_result = new_gm(*const_args)
|
| 348 |
+
return new_gm, const_result
|
.venv/lib/python3.11/site-packages/torch/_inductor/cpu_vec_isa.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import dataclasses
|
| 3 |
+
import functools
|
| 4 |
+
import os
|
| 5 |
+
import platform
|
| 6 |
+
import re
|
| 7 |
+
import subprocess
|
| 8 |
+
import sys
|
| 9 |
+
from typing import Any, Callable, Dict, List
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch._inductor import config
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_IS_WINDOWS = sys.platform == "win32"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
|
| 19 |
+
# ISA dry compile will cost about 1 sec time each startup time.
|
| 20 |
+
# Please check the issue: https://github.com/pytorch/pytorch/issues/100378
|
| 21 |
+
# Actually, dry compile is checking compile capability for ISA.
|
| 22 |
+
# We just record the compiler version, isa options and pytorch version info,
|
| 23 |
+
# and generated them to output binary hash path.
|
| 24 |
+
# It would optimize and skip compile existing binary.
|
| 25 |
+
from torch._inductor.cpp_builder import get_compiler_version_info, get_cpp_compiler
|
| 26 |
+
|
| 27 |
+
compiler_info = get_compiler_version_info(get_cpp_compiler())
|
| 28 |
+
torch_version = torch.__version__
|
| 29 |
+
fingerprint = f"{compiler_info}={isa_flags}={torch_version}"
|
| 30 |
+
return fingerprint
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class VecISA:
|
| 34 |
+
_bit_width: int
|
| 35 |
+
_macro: List[str]
|
| 36 |
+
_arch_flags: str
|
| 37 |
+
_dtype_nelements: Dict[torch.dtype, int]
|
| 38 |
+
|
| 39 |
+
# Note [Checking for Vectorized Support in Inductor]
|
| 40 |
+
# TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
|
| 41 |
+
# Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions
|
| 42 |
+
# like exp, pow, sin, cos and etc.
|
| 43 |
+
# But PyTorch and TorchInductor might use different compilers to build code. If
|
| 44 |
+
# PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so
|
| 45 |
+
# will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass
|
| 46 |
+
# avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest
|
| 47 |
+
# gcc/g++ compiler by default while it could support the AVX512 compilation.
|
| 48 |
+
# Therefore, there would be a conflict sleef version between PyTorch and
|
| 49 |
+
# TorchInductor. Hence, we dry-compile the following code to check whether current
|
| 50 |
+
# HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM
|
| 51 |
+
# also needs the logic
|
| 52 |
+
# In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
|
| 53 |
+
# making the runtime check unnecessary.
|
| 54 |
+
_avx_code = """
|
| 55 |
+
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
|
| 56 |
+
#include <ATen/cpu/vec/functional.h>
|
| 57 |
+
#include <ATen/cpu/vec/vec.h>
|
| 58 |
+
#endif
|
| 59 |
+
|
| 60 |
+
alignas(64) float in_out_ptr0[16] = {0.0};
|
| 61 |
+
|
| 62 |
+
extern "C" void __avx_chk_kernel() {
|
| 63 |
+
auto tmp0 = at::vec::Vectorized<float>(1);
|
| 64 |
+
auto tmp1 = tmp0.exp();
|
| 65 |
+
tmp1.store(in_out_ptr0);
|
| 66 |
+
}
|
| 67 |
+
""" # noqa: B950
|
| 68 |
+
|
| 69 |
+
_avx_py_load = """
|
| 70 |
+
import torch
|
| 71 |
+
from ctypes import cdll
|
| 72 |
+
cdll.LoadLibrary("__lib_path__")
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def bit_width(self) -> int:
|
| 76 |
+
return self._bit_width
|
| 77 |
+
|
| 78 |
+
def nelements(self, dtype: torch.dtype = torch.float) -> int:
|
| 79 |
+
return self._dtype_nelements[dtype]
|
| 80 |
+
|
| 81 |
+
def build_macro(self) -> List[str]:
|
| 82 |
+
return self._macro
|
| 83 |
+
|
| 84 |
+
def build_arch_flags(self) -> str:
|
| 85 |
+
return self._arch_flags
|
| 86 |
+
|
| 87 |
+
def __hash__(self) -> int:
|
| 88 |
+
return hash(str(self))
|
| 89 |
+
|
| 90 |
+
def check_build(self, code: str) -> bool:
|
| 91 |
+
from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write
|
| 92 |
+
from torch._inductor.cpp_builder import (
|
| 93 |
+
CppBuilder,
|
| 94 |
+
CppTorchOptions,
|
| 95 |
+
normalize_path_separator,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
key, input_path = write(
|
| 99 |
+
code,
|
| 100 |
+
"cpp",
|
| 101 |
+
extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
|
| 102 |
+
)
|
| 103 |
+
from filelock import FileLock
|
| 104 |
+
|
| 105 |
+
lock_dir = get_lock_dir()
|
| 106 |
+
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
|
| 107 |
+
with lock:
|
| 108 |
+
output_dir = os.path.dirname(input_path)
|
| 109 |
+
buid_options = CppTorchOptions(vec_isa=self, warning_all=False)
|
| 110 |
+
x86_isa_help_builder = CppBuilder(
|
| 111 |
+
key,
|
| 112 |
+
[input_path],
|
| 113 |
+
buid_options,
|
| 114 |
+
output_dir,
|
| 115 |
+
)
|
| 116 |
+
try:
|
| 117 |
+
# Check if the output file exist, and compile when not.
|
| 118 |
+
output_path = normalize_path_separator(
|
| 119 |
+
x86_isa_help_builder.get_target_file_path()
|
| 120 |
+
)
|
| 121 |
+
if not os.path.isfile(output_path):
|
| 122 |
+
status, target_file = x86_isa_help_builder.build()
|
| 123 |
+
|
| 124 |
+
# Check build result
|
| 125 |
+
subprocess.check_call(
|
| 126 |
+
[
|
| 127 |
+
sys.executable,
|
| 128 |
+
"-c",
|
| 129 |
+
VecISA._avx_py_load.replace("__lib_path__", output_path),
|
| 130 |
+
],
|
| 131 |
+
cwd=output_dir,
|
| 132 |
+
stderr=subprocess.DEVNULL,
|
| 133 |
+
env={**os.environ, "PYTHONPATH": ":".join(sys.path)},
|
| 134 |
+
)
|
| 135 |
+
except Exception as e:
|
| 136 |
+
return False
|
| 137 |
+
|
| 138 |
+
return True
|
| 139 |
+
|
| 140 |
+
@functools.lru_cache(None) # noqa: B019
|
| 141 |
+
def __bool__(self) -> bool:
|
| 142 |
+
if config.cpp.vec_isa_ok is not None:
|
| 143 |
+
return config.cpp.vec_isa_ok
|
| 144 |
+
|
| 145 |
+
if config.is_fbcode():
|
| 146 |
+
return True
|
| 147 |
+
|
| 148 |
+
return self.check_build(VecISA._avx_code)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@dataclasses.dataclass
|
| 152 |
+
class VecNEON(VecISA):
|
| 153 |
+
_bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
|
| 154 |
+
_macro = ["CPU_CAPABILITY_NEON"]
|
| 155 |
+
if sys.platform == "darwin" and platform.processor() == "arm":
|
| 156 |
+
_macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF")
|
| 157 |
+
_arch_flags = "" # Unused
|
| 158 |
+
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
| 159 |
+
|
| 160 |
+
def __str__(self) -> str:
|
| 161 |
+
return "asimd" # detects the presence of advanced SIMD on armv8-a kernels
|
| 162 |
+
|
| 163 |
+
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dataclasses.dataclass
|
| 167 |
+
class VecAVX512(VecISA):
|
| 168 |
+
_bit_width = 512
|
| 169 |
+
_macro = ["CPU_CAPABILITY_AVX512"]
|
| 170 |
+
_arch_flags = (
|
| 171 |
+
"-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
|
| 172 |
+
if not _IS_WINDOWS
|
| 173 |
+
else "/arch:AVX512"
|
| 174 |
+
) # TODO: use cflags
|
| 175 |
+
_dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}
|
| 176 |
+
|
| 177 |
+
def __str__(self) -> str:
|
| 178 |
+
return "avx512"
|
| 179 |
+
|
| 180 |
+
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@dataclasses.dataclass
|
| 184 |
+
class VecAMX(VecAVX512):
|
| 185 |
+
_arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"
|
| 186 |
+
|
| 187 |
+
def __str__(self) -> str:
|
| 188 |
+
return super().__str__() + " amx_tile"
|
| 189 |
+
|
| 190 |
+
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
| 191 |
+
|
| 192 |
+
_amx_code = """
|
| 193 |
+
#include <cstdint>
|
| 194 |
+
#include <immintrin.h>
|
| 195 |
+
|
| 196 |
+
struct amx_tilecfg {
|
| 197 |
+
uint8_t palette_id;
|
| 198 |
+
uint8_t start_row;
|
| 199 |
+
uint8_t reserved_0[14];
|
| 200 |
+
uint16_t colsb[16];
|
| 201 |
+
uint8_t rows[16];
|
| 202 |
+
};
|
| 203 |
+
|
| 204 |
+
extern "C" void __amx_chk_kernel() {
|
| 205 |
+
amx_tilecfg cfg = {0};
|
| 206 |
+
_tile_loadconfig(&cfg);
|
| 207 |
+
_tile_zero(0);
|
| 208 |
+
_tile_dpbf16ps(0, 1, 2);
|
| 209 |
+
_tile_dpbusd(0, 1, 2);
|
| 210 |
+
}
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
@functools.lru_cache(None) # noqa: B019
|
| 214 |
+
def __bool__(self) -> bool:
|
| 215 |
+
if super().__bool__():
|
| 216 |
+
if config.is_fbcode():
|
| 217 |
+
return False
|
| 218 |
+
if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx():
|
| 219 |
+
return True
|
| 220 |
+
return False
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@dataclasses.dataclass
|
| 224 |
+
class VecAVX2(VecISA):
|
| 225 |
+
_bit_width = 256
|
| 226 |
+
_macro = ["CPU_CAPABILITY_AVX2"]
|
| 227 |
+
_arch_flags = (
|
| 228 |
+
"-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2"
|
| 229 |
+
) # TODO: use cflags
|
| 230 |
+
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
| 231 |
+
|
| 232 |
+
def __str__(self) -> str:
|
| 233 |
+
return "avx2"
|
| 234 |
+
|
| 235 |
+
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@dataclasses.dataclass
|
| 239 |
+
class VecZVECTOR(VecISA):
|
| 240 |
+
_bit_width = 256
|
| 241 |
+
_macro = [
|
| 242 |
+
"CPU_CAPABILITY_ZVECTOR",
|
| 243 |
+
"CPU_CAPABILITY=ZVECTOR",
|
| 244 |
+
"HAVE_ZVECTOR_CPU_DEFINITION",
|
| 245 |
+
]
|
| 246 |
+
_arch_flags = "-mvx -mzvector"
|
| 247 |
+
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
| 248 |
+
|
| 249 |
+
def __str__(self) -> str:
|
| 250 |
+
return "zvector"
|
| 251 |
+
|
| 252 |
+
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@dataclasses.dataclass
|
| 256 |
+
class VecVSX(VecISA):
|
| 257 |
+
_bit_width = 256 # VSX simd supports 128 bit_width, but aten is emulating it as 256
|
| 258 |
+
_macro = ["CPU_CAPABILITY_VSX"]
|
| 259 |
+
_arch_flags = "-mvsx"
|
| 260 |
+
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
| 261 |
+
|
| 262 |
+
def __str__(self) -> str:
|
| 263 |
+
return "vsx"
|
| 264 |
+
|
| 265 |
+
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class InvalidVecISA(VecISA):
|
| 269 |
+
_bit_width = 0
|
| 270 |
+
_macro = [""]
|
| 271 |
+
_arch_flags = ""
|
| 272 |
+
_dtype_nelements = {}
|
| 273 |
+
|
| 274 |
+
def __str__(self) -> str:
|
| 275 |
+
return "INVALID_VEC_ISA"
|
| 276 |
+
|
| 277 |
+
def __bool__(self) -> bool: # type: ignore[override]
|
| 278 |
+
return False
|
| 279 |
+
|
| 280 |
+
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def x86_isa_checker() -> List[str]:
|
| 284 |
+
supported_isa: List[str] = []
|
| 285 |
+
|
| 286 |
+
def _check_and_append_supported_isa(
|
| 287 |
+
dest: List[str], isa_supported: bool, isa_name: str
|
| 288 |
+
) -> None:
|
| 289 |
+
if isa_supported:
|
| 290 |
+
dest.append(isa_name)
|
| 291 |
+
|
| 292 |
+
Arch = platform.machine()
|
| 293 |
+
"""
|
| 294 |
+
Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
|
| 295 |
+
"""
|
| 296 |
+
if Arch != "x86_64" and Arch != "AMD64":
|
| 297 |
+
return supported_isa
|
| 298 |
+
|
| 299 |
+
avx2 = torch.cpu._is_avx2_supported()
|
| 300 |
+
avx512 = torch.cpu._is_avx512_supported()
|
| 301 |
+
amx_tile = torch.cpu._is_amx_tile_supported()
|
| 302 |
+
|
| 303 |
+
_check_and_append_supported_isa(supported_isa, avx2, "avx2")
|
| 304 |
+
_check_and_append_supported_isa(supported_isa, avx512, "avx512")
|
| 305 |
+
_check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile")
|
| 306 |
+
|
| 307 |
+
return supported_isa
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
invalid_vec_isa = InvalidVecISA()
|
| 311 |
+
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()]
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
|
| 315 |
+
# might have too much redundant content that is useless for ISA check. Hence,
|
| 316 |
+
# we only cache some key isa information.
|
| 317 |
+
@functools.lru_cache(None)
|
| 318 |
+
def valid_vec_isa_list() -> List[VecISA]:
|
| 319 |
+
isa_list: List[VecISA] = []
|
| 320 |
+
if sys.platform == "darwin" and platform.processor() == "arm":
|
| 321 |
+
isa_list.append(VecNEON())
|
| 322 |
+
|
| 323 |
+
if sys.platform not in ["linux", "win32"]:
|
| 324 |
+
return isa_list
|
| 325 |
+
|
| 326 |
+
arch = platform.machine()
|
| 327 |
+
if arch == "s390x":
|
| 328 |
+
with open("/proc/cpuinfo") as _cpu_info:
|
| 329 |
+
while True:
|
| 330 |
+
line = _cpu_info.readline()
|
| 331 |
+
if not line:
|
| 332 |
+
break
|
| 333 |
+
# process line
|
| 334 |
+
featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
|
| 335 |
+
if featuresmatch:
|
| 336 |
+
for group in featuresmatch.groups():
|
| 337 |
+
if re.search(r"[\^ ]+vxe[\$ ]+", group):
|
| 338 |
+
isa_list.append(VecZVECTOR())
|
| 339 |
+
break
|
| 340 |
+
elif arch == "ppc64le":
|
| 341 |
+
isa_list.append(VecVSX())
|
| 342 |
+
elif arch == "aarch64":
|
| 343 |
+
isa_list.append(VecNEON())
|
| 344 |
+
elif arch in ["x86_64", "AMD64"]:
|
| 345 |
+
"""
|
| 346 |
+
arch value is x86_64 on Linux, and the value is AMD64 on Windows.
|
| 347 |
+
"""
|
| 348 |
+
_cpu_supported_x86_isa = x86_isa_checker()
|
| 349 |
+
for isa in supported_vec_isa_list:
|
| 350 |
+
if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
|
| 351 |
+
isa_list.append(isa)
|
| 352 |
+
|
| 353 |
+
return isa_list
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def pick_vec_isa() -> VecISA:
|
| 357 |
+
if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]):
|
| 358 |
+
return VecAVX2()
|
| 359 |
+
|
| 360 |
+
_valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
|
| 361 |
+
if not _valid_vec_isa_list:
|
| 362 |
+
return invalid_vec_isa
|
| 363 |
+
|
| 364 |
+
# If the simdlen is None, it indicates determine the vectorization length automatically
|
| 365 |
+
if config.cpp.simdlen is None:
|
| 366 |
+
assert _valid_vec_isa_list
|
| 367 |
+
return _valid_vec_isa_list[0]
|
| 368 |
+
|
| 369 |
+
for isa in _valid_vec_isa_list:
|
| 370 |
+
if config.cpp.simdlen == isa.bit_width():
|
| 371 |
+
return isa
|
| 372 |
+
|
| 373 |
+
return invalid_vec_isa
|
.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import dataclasses
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch._dynamo.utils import counters
|
| 10 |
+
from torch._inductor.utils import InputType
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
|
| 14 |
+
static_inputs_log = torch._logging.getArtifactLogger(
|
| 15 |
+
__name__, "cudagraph_static_inputs"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
OutputType = List[Optional[Union[int, torch.Tensor]]]
|
| 20 |
+
ModelType = Callable[[List[InputType]], OutputType]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclasses.dataclass(frozen=True)
|
| 24 |
+
class FunctionID:
|
| 25 |
+
"Unique counter of a function wrapped in cudagraphify_impl"
|
| 26 |
+
id: int
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclasses.dataclass(frozen=True)
|
| 30 |
+
class PlaceholderInfo:
|
| 31 |
+
"""
|
| 32 |
+
A serializable version of torch.fx.Node that contains information
|
| 33 |
+
pertinent to placeholder stack traces. We use these in logging and error messages
|
| 34 |
+
related to cudagraphs, and will cache these results.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
name: str
|
| 38 |
+
stack_trace: Optional[str]
|
| 39 |
+
# This field is recursive, but never cyclic (since a node never uses itself)
|
| 40 |
+
users: List[PlaceholderInfo]
|
| 41 |
+
mutating_use_stack_trace: Optional[str]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclasses.dataclass(frozen=True)
|
| 45 |
+
class WrappedFunction:
|
| 46 |
+
"""
|
| 47 |
+
Represents a function that you want to record for CUDA graph replay,
|
| 48 |
+
with a little more metadata so we can identify if we have an applicable
|
| 49 |
+
CUDA graph in our CUDA graph tree for it.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
model: Callable[..., Any]
|
| 53 |
+
static_input_idxs: Sequence[int]
|
| 54 |
+
id: FunctionID
|
| 55 |
+
constants: Tuple[torch.Tensor, ...]
|
| 56 |
+
placeholders: Sequence[PlaceholderInfo]
|
| 57 |
+
mutated_input_idxs: Sequence[int]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_mutating_use_stack_trace_from_node(
|
| 61 |
+
placeholder_node: torch.fx.Node,
|
| 62 |
+
) -> Optional[str]:
|
| 63 |
+
# reinplaced uses might have a single, non-copy_ use
|
| 64 |
+
if len(placeholder_node.users) == 1:
|
| 65 |
+
return next(iter(placeholder_node.users)).meta.get("stack_trace", None)
|
| 66 |
+
|
| 67 |
+
for use in placeholder_node.users:
|
| 68 |
+
if use.target == torch.ops.aten.copy_.default:
|
| 69 |
+
if stack_trace := use.meta.get("stack_trace", None):
|
| 70 |
+
return stack_trace
|
| 71 |
+
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]:
|
| 76 |
+
return placeholder_info.mutating_use_stack_trace
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
|
| 80 |
+
name = placeholder_node.name
|
| 81 |
+
stack_trace = placeholder_node.meta.get("stack_trace", None)
|
| 82 |
+
users = []
|
| 83 |
+
mutating_use_stack_trace = None
|
| 84 |
+
# Only recurse to users once, since we only care about user's stack traces
|
| 85 |
+
if placeholder_node.op == "placeholder":
|
| 86 |
+
users = [to_placeholder_info(i) for i in placeholder_node.users]
|
| 87 |
+
mutating_use_stack_trace = get_mutating_use_stack_trace_from_node(
|
| 88 |
+
placeholder_node
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
|
| 95 |
+
return [
|
| 96 |
+
to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
|
| 97 |
+
]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def format_default_skip_message(reason: str) -> str:
|
| 101 |
+
return f"skipping cudagraphs due to {reason}"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_mutation_stack_trace(
|
| 105 |
+
placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
|
| 106 |
+
) -> str:
|
| 107 |
+
stack_trace: Optional[str] = ""
|
| 108 |
+
|
| 109 |
+
for idx in mutation_indices:
|
| 110 |
+
placeholder = placeholders[idx]
|
| 111 |
+
if stack_trace := get_mutating_use_stack_trace(placeholder):
|
| 112 |
+
break
|
| 113 |
+
|
| 114 |
+
msg = format_default_skip_message(
|
| 115 |
+
f"mutated inputs ({len(mutation_indices)} instances)"
|
| 116 |
+
)
|
| 117 |
+
if stack_trace:
|
| 118 |
+
return f"{msg}. Found from : \n {stack_trace}"
|
| 119 |
+
|
| 120 |
+
return msg
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def check_for_mutation(
    func: WrappedFunction,
    inputs: List[InputType],
    is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]:
    """Return a skip message if *func* mutates inputs cudagraphs cannot handle, else None.

    With cudagraph trees, mutation of static inputs (parameters) or of tensors
    we recorded ourselves is tolerated; without trees, any input mutation
    disqualifies the graph because the warmup run would apply mutation twice.
    """
    if not torch._inductor.config.triton.cudagraph_trees:
        # doesnt work for non-trees because the warmup run would apply mutation twice
        mutation_indices: Sequence[int] = func.mutated_input_idxs
    else:
        # checking if mutation is only on parameters/static inputs
        def _allowed(idx: int) -> bool:
            if idx in func.static_input_idxs:
                return True
            return is_cuda_graph_recorded_tensor(inputs[idx])  # type: ignore[arg-type]

        mutation_indices = [
            idx for idx in func.mutated_input_idxs if not _allowed(idx)
        ]

    static_inputs_log.debug(
        "check mutation static input indices: %s", func.static_input_idxs
    )
    static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices)

    if not mutation_indices:
        return None
    return get_mutation_stack_trace(func.placeholders, mutation_indices)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _get_use_stack_trace(node) -> Optional[str]:
|
| 155 |
+
for use in node.users:
|
| 156 |
+
if stack_trace := use.meta.get("stack_trace", None):
|
| 157 |
+
return stack_trace
|
| 158 |
+
return None
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def check_multiple_devices_or_any_cpu_nodes(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
) -> Optional[str]:
    """Return a skip message when the graph touches CPU or spans multiple devices.

    Returns None only for the supported case: exactly one device and it is cuda.
    """
    cpu_node = device_node_mapping.get(torch.device("cpu"))
    if cpu_node:
        msg = f"cpu device ({cpu_node.name})"
        stack_trace = _get_use_stack_trace(cpu_node)
        if stack_trace:
            return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}")
        return format_default_skip_message(msg)

    if len(device_node_mapping) == 1:
        (only_device,) = device_node_mapping.keys()
        if only_device.type == "cuda":
            return None

    keys_repr = (repr(key) for key in device_node_mapping.keys())
    return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def check_lowering_disable_cudagraph(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
):
    """Lowering-time cudagraph-disable check; delegates to the device/CPU-node check."""
    return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def log_cudagraph_skip_and_bump_counter(msg):
    """Warn that cudagraphs were skipped (with *msg*) and bump the inductor skip counter."""
    perf_hint_log.warning(msg)
    counters["inductor"]["cudagraph_skips"] += 1
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@dataclasses.dataclass
class BoxedDeviceIndex:
    """A mutable box holding an optional device index, shared by reference."""

    # The boxed device index; None means "not yet determined".
    value: Optional[int]

    def set(self, device_idx: Optional[int]):
        """Store *device_idx*, which must be None or a plain int."""
        assert isinstance(device_idx, (int, type(None)))
        self.value = device_idx
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def check_for_mutation_ignore_cuda_graph_managed_tensor(
    gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: Sequence[int]
) -> Optional[str]:
    """Return a skip message if *compiled_graph* mutates non-static inputs, else None.

    Mirrors check_for_mutation, but works from the compiled graph's recorded
    mutation indices rather than a WrappedFunction.
    """
    if not torch._inductor.config.triton.cudagraph_trees:
        # doesnt work for non-trees because the warmup run would apply mutation twice
        if compiled_graph.mutated_inputs:
            return format_default_skip_message("mutated inputs")
        return None

    # checking if mutation is only on parameters/static inputs
    static = set(static_input_idxs)
    mutation_indices = [
        idx for idx in compiled_graph.mutated_input_idxs if idx not in static
    ]
    if not mutation_indices:
        return None
    return get_mutation_stack_trace(get_placeholder_info(gm.graph), mutation_indices)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]:
    """
    Gets the first non-empty stack trace of a placeholder or its users.
    """
    # Check the placeholder itself first, then each of its users, in order.
    for candidate in (placeholder, *placeholder.users):
        if candidate.stack_trace:
            return candidate.stack_trace
    return None
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class CheckInvariantStatus(Enum):
    """Outcome of the cudagraph invariant check before replay."""

    # Check invariant succeeded
    SUCCESS = 1

    # Previously managed data pointers are not stable
    CudagraphManagedIdxMismatch = 2

    # Static tensor input addresses are not stable
    StaticInputIdxMismatch = 3

    # Expected dead indices before graph are live
    ExpectedDeadIndicesBeforeGraphMismatch = 4

    def __str__(self) -> str:
        """Human-readable description of the mismatch (or name/value fallback)."""
        descriptions = {
            "CudagraphManagedIdxMismatch": "cudagraph managed tensor data pointer changed",
            "StaticInputIdxMismatch": "static input data pointer changed",
            "ExpectedDeadIndicesBeforeGraphMismatch": "expected dead indices before graph are live",
        }
        return descriptions.get(self.name, f"{self.name}: {self.value}")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def log_data_ptr_mismatch(
    placeholders: Sequence[PlaceholderInfo],
    inputs: List[InputType],
    recorded_data_ptr: Sequence[Optional[int]],
    target_idxs: Sequence[int],
    mismatch: CheckInvariantStatus,
) -> str:
    """
    Logs the mismatch between input data pointers and recorded data pointers.
    This checks only idxs in target_idxs.
    """
    assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(
        placeholders
    ), "length mismatch between inputs, recorded_data_ptr, and placeholders"

    # Build the message incrementally; one line per changed pointer.
    parts = [f"{mismatch}.\n"]
    for index in target_idxs:
        tensor = inputs[index]
        assert isinstance(tensor, torch.Tensor)
        data_ptr = recorded_data_ptr[index]
        if tensor.data_ptr() != data_ptr:
            placeholder = placeholders[index]
            parts.append(
                f"input name: {placeholder.name}. "
                f"data pointer changed from {data_ptr} to {tensor.data_ptr()}. "
                f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n"
            )
    return "".join(parts)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def maybe_warning_due_to_dynamic_shape(
|
| 294 |
+
fn_cache: Dict[Tuple[int, ...], Callable[..., Any]],
|
| 295 |
+
new_int_key: Any,
|
| 296 |
+
) -> bool:
|
| 297 |
+
num_cudagraphs = len(fn_cache.keys()) + 1
|
| 298 |
+
|
| 299 |
+
def warn_msg():
|
| 300 |
+
return (
|
| 301 |
+
"CUDAGraph supports dynamic shapes by recording a new graph for each "
|
| 302 |
+
"distinct input size. Recording too many CUDAGraphs may lead to "
|
| 303 |
+
f"extra overhead. We have observed {num_cudagraphs} distinct sizes. "
|
| 304 |
+
"Please consider the following options for better performance: "
|
| 305 |
+
"a) padding inputs to a few fixed number of shapes; or b) set "
|
| 306 |
+
"torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
|
| 307 |
+
"Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
|
| 308 |
+
"to silence this warning."
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if (
|
| 312 |
+
torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit
|
| 313 |
+
and num_cudagraphs
|
| 314 |
+
> torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit
|
| 315 |
+
):
|
| 316 |
+
perf_hint_log.warning(warn_msg())
|
| 317 |
+
return True
|
| 318 |
+
|
| 319 |
+
return False
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@dataclasses.dataclass(frozen=True)
class CudagraphCachedInfo:
    """
    Info needed to realign inputs
    """

    # Snapshots of the graph's placeholder nodes (names, users, stack traces).
    placeholders: Sequence[PlaceholderInfo]
    # Per-output stack traces (None when unavailable).
    stack_traces: List[Optional[str]]
    # Reasons cudagraphs were (or would be) skipped for this graph.
    cudagraph_fail_reasons: List[str]
|
.venv/lib/python3.11/site-packages/torch/_inductor/debug.py
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import contextlib
|
| 3 |
+
import dataclasses
|
| 4 |
+
import functools
|
| 5 |
+
import itertools
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import os.path
|
| 9 |
+
import pickle
|
| 10 |
+
import pstats
|
| 11 |
+
import shutil
|
| 12 |
+
import subprocess
|
| 13 |
+
from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union
|
| 14 |
+
from unittest.mock import patch
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
|
| 18 |
+
from torch import fx as fx
|
| 19 |
+
from torch._dynamo.repro.after_aot import save_graph_repro
|
| 20 |
+
from torch._dynamo.utils import get_debug_dir
|
| 21 |
+
from torch.fx.graph_module import GraphModule
|
| 22 |
+
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
|
| 23 |
+
from torch.fx.passes.tools_common import legalize_graph
|
| 24 |
+
from torch.utils._pytree import tree_map
|
| 25 |
+
|
| 26 |
+
from . import config, ir # noqa: F811, this is needed
|
| 27 |
+
from .scheduler import (
|
| 28 |
+
BaseSchedulerNode,
|
| 29 |
+
FusedSchedulerNode,
|
| 30 |
+
NopKernelSchedulerNode,
|
| 31 |
+
OutputNode,
|
| 32 |
+
SchedulerNode,
|
| 33 |
+
)
|
| 34 |
+
from .virtualized import V
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
log = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
SchedulerNodeList = List[Any]
|
| 40 |
+
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
|
| 41 |
+
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@functools.lru_cache(None)
def has_dot() -> bool:
    """Return True iff the graphviz `dot` executable is on PATH (result is cached)."""
    try:
        subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
    except subprocess.SubprocessError:
        return False
    return True
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def draw_buffers(
    nodes: List[BaseSchedulerNode],
    print_graph: bool = False,
    fname: Optional[str] = None,
) -> None:
    """
    Draw a graph in fname.svg.

    Converts the scheduler nodes into a synthetic fx graph (via
    create_fx_from_snodes), attaches TensorMetadata so draw_graph can label
    nodes, then renders with graphviz. No-op (with a warning) when `dot` is
    not installed.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        # Default to the name of the graph currently being compiled.
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        # Normalize the group to something TensorMetadata accepts as a "shape".
        if isinstance(group, tuple):
            if isinstance(group[1], int):
                group = (group[1],)
            else:
                group = group[1]

        # gather meta data
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        metadata = TensorMetadata(group, dtype, None, None, None, None, None)  # type: ignore[arg-type]
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    # legalize_graph topologically re-sorts; lint validates before drawing.
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(
        gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape
    )
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.

    Each scheduler node becomes a call_function node invoking a fake
    zero-returning function (named after the node's type and fused kernel
    name); data dependencies become edges, with placeholders inserted for
    buffers produced outside this node list.
    """

    def get_fake_func(name: str) -> Callable[..., int]:
        # Stand-in callable: fx requires a target, but we only care about its name.
        def func1(*args: Any) -> int:
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

    buf_to_fx_node = {}
    node_to_fx_node = {}
    graph = torch.fx.Graph()
    # First node created; new placeholders are inserted before it.
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")

        fused_name = torch._inductor.utils.get_fused_kernel_name(
            snode.get_nodes(), "original_aten"
        )
        func_name = f"{node_type}: {fused_name}"
        node_func = get_fake_func(func_name)
        kwargs = {}
        if hasattr(snode, "get_device"):
            kwargs = {"device": snode.get_device()}
        fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)  # type: ignore[arg-type]

        def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
            # True when any output buffer of (any sub-node of) snode feeds an OutputNode.
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(
                isinstance(user.node, OutputNode)
                for buf in snode.get_outputs()
                for user in buf.users
            )

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

        node_to_fx_node[name] = fx_node
        for buf in snode.get_outputs():
            buf_to_fx_node[buf.get_name()] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = node_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                # Buffer produced outside this node list: model it as a placeholder.
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            if dep_node == fx_node:  # to avoid cycles
                continue
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def update_orig_fx_node_name_to_buf_name(
    nodes: Optional[SchedulerNodeList],
    node_name_to_buf_name: Dict[str, str],
    parent_buf_name: Optional[str] = None,
    n_origins: int = 0,
) -> None:
    """Populate *node_name_to_buf_name* mapping original fx node names to buffer names.

    Recurses into fused nodes; children are attributed to the outermost
    (fused) buffer via *parent_buf_name*. Mutates the dict in place.
    NOTE(review): n_origins is accepted but never used here — presumably kept
    for interface compatibility; confirm before removing.
    """
    if nodes is None:
        return
    for node in nodes:
        # for FusedSchedulerNode, traverse recursively into get_nodes()
        buf_name = node.get_name()
        children_nodes = node.get_nodes()
        if children_nodes is not None and len(children_nodes) > 1:
            update_orig_fx_node_name_to_buf_name(
                children_nodes,
                node_name_to_buf_name,
                # Keep the outermost fused buffer name for all descendants.
                buf_name if parent_buf_name is None else parent_buf_name,
            )
            continue
        else:
            assert len(children_nodes) == 1 and children_nodes[0] == node

        ir_node = node.node
        if ir_node is None or ir_node.origins is None:
            continue
        for origin in ir_node.origins:
            node_name = origin.name
            # when buf1 and buf2 both have origin=node1
            # we draw node1 according to buf1
            if node_name not in node_name_to_buf_name:
                node_name_to_buf_name[node_name] = (
                    buf_name if parent_buf_name is None else parent_buf_name
                )
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def get_node_name_to_buf_meta(
    node_name_to_buf_name: Dict[str, str]
) -> Dict[str, BufMeta]:
    """Map each fx node name to BufMeta(buf_name, n_origin).

    n_origin is the number of distinct node names that map to the same
    buffer, i.e. how many original fx nodes a buffer was fused from.
    """
    # Group node names by buffer, then count each group once.
    buf_name_to_nodes: Dict[str, set] = collections.defaultdict(set)
    for node_name, buf_name in node_name_to_buf_name.items():
        buf_name_to_nodes[buf_name].add(node_name)

    return {
        node_name: BufMeta(buf_name, len(buf_name_to_nodes[buf_name]))
        for node_name, buf_name in node_name_to_buf_name.items()
    }
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def annotate_orig_fx_with_snodes(
    gm: torch.fx.GraphModule,
    snodes: SchedulerNodeList,
) -> None:
    """
    Annotate the nodes of the original fx graph with buffer metadata
    (BufMeta) derived from the scheduler nodes, for graph drawing.

    (The previous docstring — "Creates a FX Graph..." — was copy-pasted from
    create_fx_from_snodes; this function only annotates gm in place.)
    """
    node_name_to_buf_name: Dict[str, str] = {}
    update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
    # Removed dead `if node_name_to_buf_name is None` guard: the dict literal
    # above is never None (update_* mutates it in place and returns None).
    node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
    for node in gm.graph.nodes:
        if node.name in node_name_to_buf_meta:
            node.meta["buf_meta"] = node_name_to_buf_meta[node.name]
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
@contextlib.contextmanager
def enable_aot_logging() -> Iterator[None]:
    """Context manager that, under TORCH_COMPILE_DEBUG=1, routes aot_autograd
    debug logs to a per-graph file in the torchinductor debug directory.

    When TORCH_COMPILE_DEBUG is not set, it is a no-op passthrough.
    """
    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    if not compile_debug:
        # Fast path: nothing to set up, but keep the try/finally so the stack
        # is closed even if the wrapped compilation raises.
        try:
            yield
        finally:
            stack.close()
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))

    path = os.path.join(get_debug_dir(), "torchinductor")
    os.makedirs(path, exist_ok=True)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        # Detach the file handler before closing patched state.
        log.removeHandler(fh)
        stack.close()
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class DebugContext:
    """Per-compilation debug context.

    When config.trace.enabled, creates a unique debug directory and captures
    inductor logs into it; attribute access for trace hooks (fx_graph,
    output_code, ...) is dispatched dynamically via __getattr__ to a
    DebugFormatter, or to a no-op when tracing is off.
    """

    # Process-wide counter used to pick a fresh ".N" suffix for debug dirs.
    _counter = itertools.count()

    @staticmethod
    def create_debug_dir(folder_name: str) -> Optional[str]:
        """Create and return a fresh `<debug_dir>/torchinductor/<folder_name>.N` dir."""
        debug_dir = config.trace.debug_dir or get_debug_dir()
        for n in DebugContext._counter:
            dirname = os.path.join(
                debug_dir,
                "torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname
        return None

    def __init__(self) -> None:
        # Optional profiler; None unless enabled elsewhere.
        self._prof = None
        # Debug directory path; None until __enter__ with tracing enabled.
        self._path = None
        self._stack = contextlib.ExitStack()

    def copy(self, new_path: str) -> None:
        """Best-effort copy of the debug directory to *new_path* (a ".debug" dir)."""
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        from filelock import FileLock

        try:
            # Lock so concurrent compiles don't clobber each other's copies.
            with FileLock(f"{new_path}.lock"):
                if os.path.exists(new_path):
                    shutil.rmtree(new_path)
                shutil.copytree(self._path, new_path)
        except OSError:
            log.warning(
                "Failed to copy debug files from %s to %s", self._path, new_path
            )

    def fopen(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> IO[Any]:
        """Open *filename* inside the debug directory (caller closes it)."""
        assert self._path
        return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)

    @contextlib.contextmanager
    def fopen_context(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> Iterator[IO[Any]]:
        """Context-managed variant of fopen()."""
        assert self._path
        with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
            yield f

    def filename(self, suffix: str) -> str:
        """Return the absolute path of *suffix* inside the debug directory."""
        assert self._path
        return os.path.join(self._path, suffix)

    def upload_tar(self) -> None:
        """If configured, tar the debug dir and hand it to config.trace.upload_tar."""
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self) -> None:
        if config.debug:
            # Temporarily force torch._dynamo logging to DEBUG; restore on exit.
            log = logging.getLogger("torch._dynamo")
            prev_level = log.level
            log.setLevel(logging.DEBUG)

            def reset_log_level(level: Any) -> None:
                log.setLevel(level)

            self._stack.callback(reset_log_level, prev_level)

        # Register this context as the active debug handler for virtualized V.
        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())  # type: ignore[assignment]

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)

    def _setup_log_capture(
        self,
        filename: str,
        level: int,
    ) -> None:
        """Attach a stream handler writing torch._inductor logs at *level* to *filename*."""
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        # Lower the logger level if needed so the handler actually receives records.
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self) -> None:
        """Dump cProfile stats (raw + cumtime/tottime summaries) into the debug dir."""
        assert self._prof
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name: str) -> Optional[Callable[..., None]]:
        # Dispatch trace hooks: real DebugFormatter method when the matching
        # config.trace flag is on; otherwise a do-nothing callable. Exceptions
        # in debug code must never break compilation.
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
                return None
        else:

            def ignored(*args: Any, **kwargs: Any) -> None:
                pass

            return ignored
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class DebugFormatter:
|
| 464 |
+
    def __init__(self, handler: DebugContext) -> None:
        """Bind the file helpers of *handler* (the active DebugContext)."""
        self.fopen = handler.fopen
        self.fopen_context = handler.fopen_context
        self.filename = handler.filename
        self.handler = handler
|
| 469 |
+
|
| 470 |
+
    def fx_graph(
        self,
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor],
    ) -> None:
        """Save the input fx graph: a runnable repro script and a readable dump."""
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))
|
| 480 |
+
|
| 481 |
+
    def fx_graph_transformed(
        self,
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor],
    ) -> None:
        """Save a readable dump of the post-transformation fx graph."""
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))
|
| 488 |
+
|
| 489 |
+
    def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None:
        """Write the scheduler IR before fusion to ir_pre_fusion.txt."""
        self._write_ir("ir_pre_fusion.txt", nodes)
|
| 491 |
+
|
| 492 |
+
    def ir_post_fusion(self, nodes: SchedulerNodeList) -> None:
        """Write the scheduler IR after fusion to ir_post_fusion.txt."""
        self._write_ir("ir_post_fusion.txt", nodes)
|
| 494 |
+
|
| 495 |
+
    def _write_ir(
        self,
        filename: str,
        nodes: SchedulerNodeList,
    ) -> None:
        """Write each node's debug_str(), blank-line separated, to *filename*."""
        with self.fopen(filename) as fd:
            log.info("Writing debug ir to %s", fd.name)
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")
|
| 505 |
+
|
| 506 |
+
    def graph_diagram(self, nodes: SchedulerNodeList) -> None:
        """Render the scheduler nodes as graph_diagram.svg via draw_buffers."""
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))
|
| 508 |
+
|
| 509 |
+
    def draw_orig_fx_graph(
        self,
        gm: torch.fx.GraphModule,
        nodes: SchedulerNodeList,
    ) -> None:
        """Annotate *gm* with scheduler buffer metadata and render it as an SVG."""
        annotate_orig_fx_with_snodes(gm, nodes)
        draw_graph(
            gm,
            fname=self.filename("orig_fx_graph_diagram.svg"),
            clear_meta=False,
            # Scalable graphviz settings so very large graphs still render.
            prog=GRAPHVIZ_COMMAND_SCALABLE,
            parse_stack_trace=True,
            dot_graph_shape=config.trace.dot_graph_shape,
        )
|
| 523 |
+
|
| 524 |
+
    def output_code(self, filename: str) -> None:
        """Copy the generated output code file into the debug dir as output_code.py."""
        shutil.copy(filename, self.filename("output_code.py"))
|
| 526 |
+
|
| 527 |
+
def log_autotuning_results(
    self,
    name: str,
    input_nodes: List[ir.IRNode],
    timings: Dict["ChoiceCaller", float],  # type: ignore[name-defined] # noqa: F821
    elapse: float,
    precompile_elapse: float,
) -> None:
    """Append one JSON line per autotuning candidate to the debug directory.

    Each line in ``autotuning_result_json_list.txt`` merges the candidate's
    own ``info_dict()`` with general properties of the op being tuned
    (device info, serialized input-node metadata, timing totals) and the
    measured ``benchmark_result``.

    Args:
        name: name of the op being autotuned.
        input_nodes: IR nodes feeding the op; serialized best-effort.
        timings: benchmark time (seconds) per candidate kernel.
        elapse: total autotuning wall time.
        precompile_elapse: wall time spent precompiling candidates.
    """
    import json

    from .ir import FixedLayout

    def build_node_info(node: ir.IRNode) -> Dict[str, str]:
        # Best-effort metadata extraction: every property lookup is guarded
        # because not all IR node types implement all accessors, and we never
        # want logging to break compilation.
        if hasattr(node, "name"):
            node_name = node.name
        else:
            node_name = ""
        node_info = {
            "name": node_name,
            "type": type(node).__name__,
        }
        try:
            layout = node.get_layout()
            if isinstance(layout, FixedLayout):
                offset = 0
                try:
                    offset = int(layout.offset)
                except Exception:
                    try:
                        offset = V.graph.sizevars.size_hint(
                            layout.offset, fallback=0
                        )
                    except Exception:
                        pass
                # Re-materialize the layout with concrete size hints so the
                # serialized form contains static integers, not sympy exprs.
                static_layout = FixedLayout(
                    layout.device,
                    dtype=layout.dtype,
                    size=list(V.graph.sizevars.size_hints(layout.size)),
                    stride=list(V.graph.sizevars.size_hints(layout.stride)),
                    offset=offset,
                )
                node_info["layout"] = str(static_layout)
            else:
                node_info["layout"] = str(node.get_layout())
        except Exception:
            pass
        try:
            node_info["dtype"] = str(node.get_dtype())
        except Exception:
            pass
        try:
            node_info["device"] = str(node.get_device())
        except Exception:
            pass
        try:
            node_info["stride"] = str(
                V.graph.sizevars.size_hints(node.get_stride())
            )
        except Exception:
            pass
        try:
            node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size()))
        except Exception:
            pass
        try:
            node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel()))
        except Exception:
            pass
        # Recurse into wrapped IR nodes so the full chain is captured.
        if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
            node_info["data"] = build_node_info(node.data)
        return node_info

    general_properties = {
        "op_name": name,
        "cuda_device_name": torch.cuda.get_device_name(),
        "cuda_device_count": torch.cuda.device_count(),
        "input_nodes": [build_node_info(node) for node in input_nodes],
        "autotuning_time": elapse,
        "precompile_time": precompile_elapse,
    }
    # Append mode: accumulate results across multiple autotuning runs.
    with self.fopen_context(
        "autotuning_result_json_list.txt", "at", encoding="utf-8"
    ) as fd:
        for caller, time in timings.items():
            info_dict = dict(caller.info_dict())
            info_dict.update(general_properties)
            info_dict["benchmark_result"] = time
            json.dump(info_dict, fd)
            fd.write("\n")
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
@dataclasses.dataclass
class TensorMetadataHolder:
    """Picklable stand-in for a tensor: its metadata plus original device.

    Used by ``save_args_for_compile_fx_inner`` because FakeTensors cannot be
    pickled directly.
    """

    # Extracted shape/stride/dtype metadata of the original tensor.
    tensor_metadata: TensorMetadata
    # Device the original tensor lived on, so it can be recreated there.
    device: torch.device
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
# Monotonic counter used to give each saved-arguments pickle a unique filename.
save_args_cnt = itertools.count()
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
    """
    This function is used to save arguments for a compile_fx_inner function call
    to the file system. Later on one can replay the compile_fx_inner call
    with the saved arguments using load_args_and_run_compile_fx_inner.
    """

    folder = "/tmp/inductor_saved_args"
    # makedirs(..., exist_ok=True) avoids the check-then-create race that
    # `if not os.path.exists(folder): os.mkdir(folder)` has when multiple
    # compile processes start concurrently.
    os.makedirs(folder, exist_ok=True)

    def handle_tensor(x: Any) -> Any:
        """
        Pickle FakeTensor will result in error:
        AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

        Convert all Tensor to metadata. This may also makes pickle faster.
        """
        if isinstance(x, torch.Tensor):
            return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
        else:
            return x

    args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

    fn_name = "compile_fx_inner"
    path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
    with open(path, "wb") as f:
        pickle.dump((args_to_save, kwargs_to_save), f)

    if log.isEnabledFor(logging.DEBUG):
        message = f"""
Arguments for a compile_fx_inner call is saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
"""
        # call print rather than log.debug. log.debug will print message
        # prefix for each line which makes the code snippet harder to be
        # copied.
        # Not a big deal since the code is already been guarded by checking
        # the log level.
        print(message)
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def load_args_and_run_compile_fx_inner(path: str) -> Any:
    """Replay a ``compile_fx_inner`` call from arguments saved on disk.

    Loads the pickled ``(args, kwargs)`` written by
    ``save_args_for_compile_fx_inner``, rebuilds tensors from their saved
    metadata as random strided tensors under a FakeTensorMode, and invokes
    ``compile_fx_inner`` with them.
    """
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def handle_tensor(x: Any) -> Any:
        # Reverse of save-time conversion: metadata holder -> concrete tensor
        # with the same shape/stride/dtype on the original device.
        if isinstance(x, TensorMetadataHolder):
            return torch._dynamo.testing.rand_strided(
                x.tensor_metadata.shape,
                x.tensor_metadata.stride,
                x.tensor_metadata.dtype,
                x.device,
            )
        else:
            return x

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    # Disable save_args during replay to avoid re-saving (infinite loop).
    with fake_mode, config.patch("save_args", False):
        args, kwargs = tree_map(handle_tensor, (args, kwargs))
        return compile_fx_inner(*args, **kwargs)
|
.venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py
ADDED
|
@@ -0,0 +1,980 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
import functools
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import sys
|
| 6 |
+
import typing
|
| 7 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._decomp as decomp
|
| 11 |
+
import torch._prims_common as utils
|
| 12 |
+
import torch.ao.quantization.fx._decomposed
|
| 13 |
+
from torch._decomp import (
|
| 14 |
+
core_aten_decompositions,
|
| 15 |
+
get_decompositions,
|
| 16 |
+
remove_decompositions,
|
| 17 |
+
)
|
| 18 |
+
from torch._decomp.decompositions import (
|
| 19 |
+
_grid_sampler_2d as decomp_grid_sampler_2d,
|
| 20 |
+
pw_cast_for_opmath,
|
| 21 |
+
)
|
| 22 |
+
from torch._decomp.decompositions_for_rng import extra_random_decomps
|
| 23 |
+
from torch._dynamo.utils import counters
|
| 24 |
+
from torch._higher_order_ops.out_dtype import out_dtype
|
| 25 |
+
from torch._inductor.utils import pad_listlike
|
| 26 |
+
from torch._prims_common import (
|
| 27 |
+
elementwise_dtypes,
|
| 28 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
| 29 |
+
type_to_dtype,
|
| 30 |
+
)
|
| 31 |
+
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_size_oblivious
|
| 32 |
+
|
| 33 |
+
from . import config, inductor_prims
|
| 34 |
+
from .utils import (
|
| 35 |
+
is_gpu,
|
| 36 |
+
needs_fallback_due_to_atomic_add_limitations,
|
| 37 |
+
use_scatter_fallback,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
log = logging.getLogger(__name__)
# Short aliases for the operator namespaces used throughout this module.
aten = torch.ops.aten
prims = torch.ops.prims
quantized = torch.ops.quantized
_quantized = torch.ops._quantized
quantized_decomposed = torch.ops.quantized_decomposed
|
| 47 |
+
|
| 48 |
+
# Decompositions Inductor wants in addition to the core ATen set.
inductor_decompositions = get_decompositions(
    [
        aten._adaptive_avg_pool2d_backward,
        aten.addmv,
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.dist,
        aten.empty_like,
        aten.flip,
        aten.gelu,
        aten.hardtanh,
        aten.index_select,
        aten.lcm,
        aten.leaky_relu,
        aten.linalg_vector_norm,
        aten._log_softmax,
        aten.max_pool2d_with_indices_backward,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten._batch_norm_with_update,
        aten._batch_norm_with_update_functional,
        aten._batch_norm_no_update,
        aten.batch_norm_backward,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
        aten.nll_loss2d_backward,
        aten._softmax,
        aten.sin_,
        aten.sqrt_,
        out_dtype,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.upsample_bilinear2d.vec,
        quantized.linear_dynamic_fp16_unpacked_weight,
        _quantized.wrapped_quantized_linear,
    ]
)
# Inductor's decomps take precedence over core ATen ones on key collisions.
decompositions = {**core_aten_decompositions(), **inductor_decompositions}

# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude = [
    aten._unsafe_index,
    aten._unsafe_masked_index,
    aten._unsafe_masked_index_put_accumulate,
    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
    aten._softmax_backward_data,
    aten.clamp_max,
    aten.clamp_min,
    aten.glu,  # inductor lowers this directly
    aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.slice_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.split.Tensor,  # inductor lowers this directly
    aten.squeeze,  # inductor lowers this directly
    aten.sum,  # inductor lowers this directly
    aten.unbind,  # inductor lowers this directly
]

remove_decompositions(decompositions, decomps_to_exclude)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def register_decomposition(
    # NOTE: despite the list annotation, a single callable op is also
    # accepted (see the `callable(ops)` check below).
    ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
) -> Callable[..., Any]:
    """Decorator registering a function in this module's decomposition table.

    Warns when an op already has a decomposition registered (the new one
    silently wins), then delegates to ``torch._decomp.register_decomposition``.
    """
    for op in [ops] if callable(ops) else ops:  # type: ignore[attr-defined]
        if op in decompositions:
            log.warning("duplicate decomp: %s", ops)
    return decomp.register_decomposition(ops, decompositions)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# TODO: for now, inductor doesn't handle asserts
# because the condition is symbol -> tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    """Decompose async asserts to a no-op (Inductor cannot lower them)."""
    return
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Following `assert_async_msg_decomp` and implement as non-op.
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    """No-op decomposition of the functional async assert."""
    return
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(
    symbol: torch.SymInt,
    *,
    min: Optional[torch.types.Number] = None,
    max: Optional[torch.types.Number] = None,
) -> None:
    """No-op decomposition: range constraints carry no runtime behavior here."""
    return
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(
    x: torch.Tensor,
    min: Optional[torch.types.Number] = None,
    max: Optional[torch.types.Number] = None,
) -> torch.Tensor:
    """Decompose clamp into clamp_min / clamp_max, skipping absent bounds."""
    result = x if min is None else x.clamp_min(min)
    if max is not None:
        result = result.clamp_max(max)
    return result
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@register_decomposition([aten.full])
def full(
    size: List[Union[int, torch.SymInt]],
    fill_value: torch.types.Number,
    **kwargs: Any,
) -> torch.Tensor:
    """Infer dtype from the fill value when the caller didn't supply one.

    With an explicit dtype there is nothing to do, so NotImplemented is
    returned and the original aten.full runs unchanged.
    """
    dtype = kwargs.get("dtype")
    if dtype is None:
        kwargs["dtype"] = type_to_dtype(type(fill_value))
        return torch.full(size, fill_value, **kwargs)
    return NotImplemented
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Not really sure how to put this into the main library. PrimTorch wants
# empty_permuted to go to the prim, and typically users don't really want
# to decompose to empty_strided (but inductor is OK with it, because we are
# cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(
    size: List[Union[int, torch.SymInt]],
    physical_layout: List[int],
    **kwargs: Any,
) -> torch.Tensor:
    """Allocate in physical-layout order, then permute back to logical order."""
    # Invert physical_layout: perm[l] is where logical dim l lives physically.
    perm = [0] * len(size)
    for p, l in enumerate(physical_layout):
        perm[l] = p
    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias_sizes: List[int],
    stride: Union[int, List[int]],
    padding: Union[int, List[int]],
    dilation: Union[int, List[int]],
    transposed: bool,
    output_padding: List[int],
    groups: int,
    output_mask: List[bool],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split the bias gradient out of convolution_backward on GPU.

    Computing grad_bias as a plain sum lets Inductor fuse it, while the
    input/weight gradients are re-dispatched with the bias entry of
    output_mask cleared. Falls back (NotImplemented) when no bias gradient
    is requested or the tensor is not on a GPU.
    """
    if not output_mask[2] or not is_gpu(grad_output.device.type):
        return NotImplemented
    # Bias gradient: sum over batch and all spatial dims (keep channel dim 1).
    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        # Bias already handled above, so mask it off in the re-dispatch.
        [output_mask[0], output_mask[1], False],
    )
    return (grad_inp, grad_weight, grad_bias)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@register_decomposition([aten.round.decimals])
def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor:
    """Round to `decimals` places: scale up, round, then scale back down."""
    scale = 10.0**decimals
    # Multiply by the reciprocal rather than dividing.
    return aten.round(x * scale) * (1.0 / scale)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@register_decomposition([aten.bmm])
@pw_cast_for_opmath
def bmm(
    self: torch.Tensor,
    batch2: torch.Tensor,
) -> torch.Tensor:
    """Decompose degenerate batched matmuls into broadcast-multiply + sum.

    Only fires for matrix-vector-like shapes (an inner dim of size 1) under
    coordinate-descent tuning, or on CPU; otherwise returns NotImplemented
    so the regular bmm lowering runs.
    """
    if config.coordinate_descent_tuning:
        # (B, 1, K) @ (B, K, N) or (B, M, K) @ (B, K, 1): effectively a mv.
        if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious(
            batch2.shape[2] == 1
        ):
            out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
            return out
    if self.device.type == "cpu":
        # (B, 1, K) @ (B, K, 1): a batch of dot products.
        if guard_size_oblivious(self.size(1) == 1) and guard_size_oblivious(
            batch2.size(-1) == 1
        ):
            counters["inductor"]["decompose_bmm"] += 1
            return torch.sum(
                self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
            ).unsqueeze(1)
    return NotImplemented
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@register_decomposition([aten.addmm])
@pw_cast_for_opmath
def addmm(
    self: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    beta: torch.types.Number = 1,
    alpha: torch.types.Number = 1,
) -> torch.Tensor:
    """Decompose small/degenerate CPU addmm into elementwise ops.

    Handles two CPU cases: (1, K) @ (K, 1) dot products, and a single-row
    mat1 against a tiny (<=16x16) mat2. Everything else falls through to
    the regular addmm lowering via NotImplemented.
    """
    if self.device.type == "cpu":
        # (1, K) @ (K, 1): a dot product plus bias.
        if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious(
            mat2.size(-1) == 1
        ):
            counters["inductor"]["decompose_addmm"] += 1
            out = torch.sum(
                mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
            return alpha * out + beta * self
        # Single-row mat1 with a small mat2: broadcast-multiply and reduce.
        if (
            guard_size_oblivious(mat1.size(0) == 1)
            and definitely_true(mat2.size(0) <= 16)
            and definitely_true(mat2.size(1) <= 16)
        ):
            counters["inductor"]["decompose_addmm"] += 1
            out = (mat1.T * mat2).sum(dim=0, keepdim=True)
            return alpha * out + beta * self
    return NotImplemented
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@register_decomposition([aten.mm])
@pw_cast_for_opmath
def mm(
    self: torch.Tensor,
    input2: torch.Tensor,
) -> torch.Tensor:
    """Decompose matrix-vector-like mm shapes into elementwise multiply + sum.

    Returns NotImplemented for ordinary shapes so the standard mm lowering
    runs instead.
    """
    # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
    # todo: Look into why and fix it (hopefully)
    if config.coordinate_descent_tuning:
        # (1, K) @ (K, N) or (M, K) @ (K, 1): effectively an mv.
        if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious(
            input2.shape[1] == 1
        ):
            return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
    if self.device.type == "cpu":
        # Tiny (M, 1) @ (1, N) outer product: unroll into a cat of rows.
        if (
            guard_size_oblivious(self.size(-1) == 1)
            and guard_size_oblivious(self.size(0) > 0)
            and guard_size_oblivious(input2.size(0) == 1)
            and (self.dtype == input2.dtype)
            and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
        ):
            counters["inductor"]["decompose_mm"] += 1
            return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
        # (1, K) @ (K, 1): a dot product.
        if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
            input2.size(-1) == 1
        ):
            counters["inductor"]["decompose_mm"] += 1
            return torch.sum(
                self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
    return NotImplemented
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# This pass does two things:
# - Eliminate cat when there is only one tensor input
# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
#   don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default])
def cat(
    tensors: List[torch.Tensor],
    dim: int = 0,
) -> torch.Tensor:
    """Normalize cat: drop empty inputs, shortcut single/repeated inputs."""
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def non_empty_tensor(x: torch.Tensor) -> bool:
        # For better or worse, this is a valid cat:
        #
        #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
        #
        # We'd like to eliminate naughtiness like this for downstream passes
        # like split_cat.  The easiest way is to just drop such inputs
        # (guarding that they are non-zero).
        #
        # Is it permissible for this filtering to be size-oblivious?  A case
        # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0
        # happened to be zero, we would have liked to have filtered it out.
        # But actually, the ONLY way this could have passed is if u0 == 0,
        # so by the time we get here we have already installed a deferred
        # runtime assert forcing u0 to be zero.  So if this hasn't happened,
        # we know that the unbacked SymInt has appropriate size and there are
        # no problems.
        if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0):
            return False

        if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0):
            return False

        return True

    filtered_tensors = list(filter(non_empty_tensor, tensors))

    if len(filtered_tensors) == 1:
        # Single effective input: cat is just a copy.
        return filtered_tensors[0].clone()
    elif 1 < len(filtered_tensors) < len(tensors):
        # on the first call, when we remove empty tensors, we redispatch recursively
        return aten.cat.default(filtered_tensors, dim)

    # optimization, avoid concat for single, repeated input
    if len(filtered_tensors) > 1 and all(
        t is filtered_tensors[0] for t in filtered_tensors
    ):
        inp = filtered_tensors[0]
        shape = list(inp.shape)
        # Normalize a negative dim before inserting the repeat dimension.
        dim = dim + len(inp.shape) if dim < 0 else dim
        shape.insert(dim, len(filtered_tensors))
        return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone()

    # when no 'filtering' has occurred, we raise to prevent infinite recursion (no more decomposition needed)
    return NotImplemented
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@register_decomposition([aten.angle])
def angle(x: torch.Tensor) -> torch.Tensor:
    """Decompose angle via atan2 for complex inputs, 0/pi for real inputs."""
    if x.is_complex():
        # NaN real parts propagate NaN; otherwise angle = atan2(imag, real).
        return torch.where(
            torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
        )

    # when x is real number
    #   if x >= 0, return 0
    #   if x < 0, return pi
    #   if x is nan, return nan
    _, dtype = elementwise_dtypes(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
    ret = torch.where(x < 0, pi, 0.0)
    return torch.where(torch.isnan(x), float("nan"), ret)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@register_decomposition([aten.add])
def add(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    alpha: Optional[torch.types.Number] = None,
) -> torch.Tensor:
    """Decompose complex addition into real-valued addition on a (..., 2) view.

    Only applies when *both* operands are complex tensors; otherwise returns
    NotImplemented and the regular add runs.
    """
    # Require both x and y to be complex tensors.
    x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
    y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
    if not x_is_complex_tensor or not y_is_complex_tensor:
        return NotImplemented
    z = y
    if alpha is not None:
        z = alpha * y
    complex_type = torch.promote_types(x.dtype, y.dtype)

    # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problem
    # when broadcasting the add.
    def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor:
        """Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]"""
        # Get the current shape of the tensor
        *initial_dims, last_dim = tensor.shape

        # Check if the last dimension is even. We should never reach here since `x.view(x.real.dtype)`
        # doubles the last dimension for complex numbers.
        if last_dim % 2 != 0:
            raise AssertionError(
                "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]"
            )

        # Reshape the tensor
        new_shape = (*initial_dims, last_dim // 2, 2)
        reshaped_tensor = tensor.view(new_shape)
        return reshaped_tensor

    # View both operands as real (..., 2) pairs, add, then view back complex.
    x_reshaped = reshape_tensor_complex(x.view(x.real.dtype))
    z_reshaped = reshape_tensor_complex(z.view(y.real.dtype))
    result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type)
    return result
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
@register_decomposition([aten.conj_physical])
def conj_physical(self: torch.Tensor) -> torch.Tensor:
    """Identity for real tensors; complex conjugation is not yet supported."""
    assert not self.is_complex(), "TODO: implement this"
    return self
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
@register_decomposition([aten.lift, aten.detach_])
def lift(self: torch.Tensor) -> torch.Tensor:
    """Treat lift and in-place detach as the identity under Inductor."""
    return self
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
@register_decomposition([aten.bernoulli.default])
def bernoulli(
    self: torch.Tensor,
    *,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Sample Bernoulli via a uniform draw compared against the probabilities.

    Explicit generators are not supported by this decomposition.
    """
    assert generator is None
    # rand_like < p yields True with probability p; cast back to input dtype.
    return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
@register_decomposition([aten.fmin, prims.fmin])
def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """NaN-ignoring minimum: keep `self` when `other` is NaN or greater."""
    keep_self = torch.isnan(other) | (other > self)
    return torch.where(keep_self, self, other)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
@register_decomposition([aten.fmax, prims.fmax])
def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """NaN-ignoring maximum: keep `self` when `other` is NaN or smaller."""
    keep_self = torch.isnan(other) | (other < self)
    return torch.where(keep_self, self, other)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
@register_decomposition(aten.amax)
def amax(
    self: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    """For bool tensors, amax is equivalent to any(); otherwise fall through."""
    if self.dtype == torch.bool:
        return torch.any(self, dim=dim, keepdim=keepdim)
    return NotImplemented
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
@register_decomposition(aten.amin)
def amin(
    self: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    """For boolean inputs, amin is logical ``all``; other dtypes defer."""
    if self.dtype != torch.bool:
        # Non-bool dtypes fall through to the default lowering.
        return NotImplemented
    return torch.all(self, dim=dim, keepdim=keepdim)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
@register_decomposition([aten.narrow_copy])
def narrow_copy(
    self: torch.Tensor,
    dim: int,
    start: int,
    length: int,
) -> torch.Tensor:
    """narrow_copy == narrow followed by a materializing clone."""
    narrowed = torch.narrow(self, dim, start, length)
    return narrowed.clone()
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
@register_decomposition([aten.view_copy.default])
def view_copy_default(
    self: torch.Tensor,
    size: List[Union[int, torch.SymInt]],
) -> torch.Tensor:
    """view_copy == view followed by a materializing clone."""
    viewed = aten.view(self, size)
    return viewed.clone()
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(
    self: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dtype overload of view_copy: convert to ``dtype`` and clone into fresh storage."""
    converted = self.to(dtype)
    return converted.clone()
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def get_like_layout(
    tensor: torch.Tensor,
    memory_format: Optional[torch.memory_format] = None,
) -> torch.memory_format:
    """Resolve the memory format a ``*_like`` op should produce.

    ``None`` or ``torch.preserve_format`` means "match the input", resolved
    via ``utils.suggest_memory_format``; any other format is used as given.
    """
    # TODO: _to_copy tensor to stride permutation
    if memory_format is None or memory_format is torch.preserve_format:
        return utils.suggest_memory_format(tensor)
    return memory_format
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
@register_decomposition(aten.rand_like)
def rand_like(
    self: torch.Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Decompose rand_like into a plain rand plus a layout-matching copy."""
    sampled = torch.rand(
        list(self.size()),
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    return sampled.to(memory_format=get_like_layout(self, memory_format))
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
@register_decomposition(aten.randn_like)
def randn_like(
    self: torch.Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Decompose randn_like into a plain randn plus a layout-matching copy."""
    sampled = torch.randn(
        list(self.size()),
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    return sampled.to(memory_format=get_like_layout(self, memory_format))
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
@register_decomposition(aten.full_like)
def full_like(
    self: torch.Tensor,
    fill_value: Union[int, float],
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor:
    """Decompose full_like into a plain full plus a layout-matching copy.

    NOTE(review): ``pin_memory`` is accepted but not forwarded to
    ``torch.full`` — this mirrors the original implementation.
    """
    filled = torch.full(
        list(self.size()),
        fill_value,
        dtype=dtype or self.dtype,
        layout=layout or self.layout,
        device=device or self.device,
        requires_grad=requires_grad,
    )
    return filled.to(memory_format=get_like_layout(self, memory_format))
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
@register_decomposition(aten.randint_like.default)
def randint_like(
    self: torch.Tensor,
    high: int,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Decompose randint_like into randint over [0, high) plus a layout copy."""
    sampled = aten.randint.low(
        0,
        high,
        list(self.size()),
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    return sampled.to(memory_format=get_like_layout(self, memory_format))
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
    self: torch.Tensor,
    low: int,
    high: int,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Decompose randint_like (low/high form) into randint plus a layout copy."""
    sampled = aten.randint.low(
        low,
        high,
        list(self.size()),
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    return sampled.to(memory_format=get_like_layout(self, memory_format))
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
@register_decomposition(aten.randint.default)
def randint(
    high: int,
    size: List[Union[int, torch.SymInt]],
    **kwargs: Any,
) -> torch.Tensor:
    """randint(high, size) is randint.low with an implicit lower bound of 0."""
    lower_bound = 0
    return aten.randint.low(lower_bound, high, size, **kwargs)
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
def linear_dynamic_fp16_unpacked_weight(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Pack the fp16 weight via fbgemm, then run the packed linear kernel."""
    packed = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
    # The kernel also needs the output feature count (rows of ``weight``).
    out_features = weight.size()[0]
    return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
        input, packed, bias, out_features
    )
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
@register_decomposition(_quantized.wrapped_quantized_linear.default)
def wrapped_quantized_linear(
    input: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    bias: torch.Tensor,
    out_scale: torch.Tensor,
    out_zero_point: torch.Tensor,
    out_channel: int,
) -> torch.Tensor:
    """Split the fused quantized linear op into prepack + prepacked matmul."""
    # Step 1: fold weight quantization params and bias into a packed object.
    prepacked = torch.ops._quantized._wrapped_linear_prepack(
        weight, weight_scale, weight_zero_point, bias
    )
    # Step 2: run the prepacked quantized linear kernel.
    return torch.ops._quantized._wrapped_quantized_linear_prepacked(
        input,
        input_scale,
        input_zero_point,
        prepacked,
        out_scale,
        out_zero_point,
        out_channel,
    )
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
    """Dequantize a byte-packed embedding bag.

    The last 8 bytes of each row hold two float32 values (scale, then
    offset); the remaining bytes are the quantized payload, dequantized as
    ``payload * scale + offset``.
    """

    def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
        # Reassemble 4 consecutive bytes into a single int32, honoring host
        # byte order, then bitcast to float32 and keep a trailing unit dim.
        x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
        if sys.byteorder == "little":
            return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
        else:
            return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]

    scales = bitcast_u8_to_f32(packed[..., -8:-4])
    offsets = bitcast_u8_to_f32(packed[..., -4:])
    return packed[..., :-8].to(torch.float32) * scales + offsets
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
@register_decomposition([aten.grid_sampler_2d])
@pw_cast_for_opmath
def grid_sampler_2d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> torch.Tensor:
    """Lower grid_sampler_2d onto the shared decomposition helper.

    We do not expand the grid (_expand_grid=False) on cpu for performance
    reasons. Experimenting locally it was found that compiled CUDA code is
    accelerated by ~5x and CPU code by ~2x on bicubic mode, if we expand the
    grid from (N, H, W, 2) into (N, C, H, W, 2). However, this leads to a
    slowdown around ~0.8x on CPU bilinear mode, channels first. Thus we apply
    this hack to not expand the grid for this case.
    """
    skip_grid_expansion = (
        a.device == torch.device("cpu")
        and interpolation_mode == 0
        and a.is_contiguous(memory_format=torch.contiguous_format)
    )
    return decomp_grid_sampler_2d(
        a,
        grid=grid,
        interpolation_mode=interpolation_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        _expand_grid=not skip_grid_expansion,
    )
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(
    self: List[torch.Tensor],
    left_tensors: List[torch.Tensor],
    right_tensors: List[torch.Tensor],
    scalar: float = 1,
) -> List[torch.Tensor]:
    """foreach addcmul == self + scalar * (left * right), per list element."""
    products = aten._foreach_mul.List(left_tensors, right_tensors)
    return aten._foreach_add.List(self, products, alpha=scalar)
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(
    self: List[torch.Tensor],
    left_tensors: List[torch.Tensor],
    right_tensors: List[torch.Tensor],
    scalar: float = 1,
) -> List[torch.Tensor]:
    """foreach addcdiv == self + scalar * (left / right), per list element."""
    quotients = aten._foreach_div.List(left_tensors, right_tensors)
    return aten._foreach_add.List(self, quotients, alpha=scalar)
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(
    start_tensors: List[torch.Tensor],
    end_tensors: List[torch.Tensor],
    weight: torch.types.Number,
) -> List[torch.Tensor]:
    """foreach lerp == start + weight * (end - start), per list element."""
    deltas = aten._foreach_sub.List(end_tensors, start_tensors)
    scaled = aten._foreach_mul.Scalar(deltas, weight)
    return aten._foreach_add.List(start_tensors, scaled)
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    running_mean: typing.Optional[torch.Tensor],
    running_var: typing.Optional[torch.Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Lower miopen_batch_norm onto native_batch_norm.

    In training mode the saved mean/invstd from native_batch_norm are
    passed through; in eval mode they are replaced with empty tensors.
    """
    out, saved_mean, saved_invstd = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )

    if not training:
        # Two distinct empty tensors, matching the original's aliasing.
        return (out, weight.new_zeros((0,)), weight.new_zeros((0,)))
    return (out, saved_mean, saved_invstd)
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
@functools.lru_cache(None)
def fast_random_decomps() -> Dict[Any, Callable[..., Any]]:
    """Decomposition table with the extra random decomps layered on top."""
    table: Dict[Any, Callable[..., Any]] = dict(decompositions)
    table.update(extra_random_decomps)
    return table
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
# TODO(aakhundov): replace this (and the above) Any by more
# specific type and fix all the cascading mypy errors
def select_decomp_table() -> Dict[Any, Callable[..., Any]]:
    """decomps can change based on config"""
    if not config.fallback_random:
        return fast_random_decomps()
    # fallback_random: use only the base table, without random decomps.
    return decompositions
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
@register_decomposition(aten.masked_scatter)
def masked_scatter(
    self: torch.Tensor,
    mask: torch.Tensor,
    source: torch.Tensor,
) -> torch.Tensor:
    """Scatter ``source`` values into ``self`` where ``mask`` is True.

    Only decomposed when the backend supports MASKED_SCATTER_WITH_INDEX;
    otherwise returns NotImplemented so the default handling applies.
    """
    from .codegen.common import BackendFeature, has_backend_feature

    if has_backend_feature(self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX):
        # This two-step algorithm is the same as eager CUDA, for eager CPU we
        # use a 1-shot serial iteration.
        self, mask = aten.broadcast_tensors([self, mask])
        # cumsum - 1 maps each True position in the flattened mask to its
        # ordinal, i.e. the index into the flattened ``source`` to place there.
        source_idx = mask.reshape(-1).cumsum(0) - 1
        self_flat, mask_flat, source_flat = (x.flatten() for x in (self, mask, source))
        # _unsafe_masked_index yields 0 where the mask is False; those slots
        # are replaced by the original values via torch.where below.
        result = aten._unsafe_masked_index(source_flat, mask_flat, [source_idx], 0)
        return torch.where(mask_flat, result, self_flat).view(self.shape)
    return NotImplemented
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
    input: torch.Tensor,
    quant_min: int,
    quant_max: int,
    eps: float,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute affine quantization (scale, zero_point) from a tensor's min/max.

    NOTE(review): ``dtype`` is accepted but unused here — mirrors the op
    schema; verify against the reference implementation if this changes.
    """
    min_val, max_val = torch.aminmax(input)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    # Clamp the scale away from zero so the division below stays finite.
    scale = torch.max(scale, torch.Tensor([eps]))
    zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
    # Keep the zero point representable within the quantized range.
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale.to(torch.float64), zero_point.to(torch.int64)
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
@register_decomposition(aten.put)
def put(
    self: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    accumulate: bool = False,
) -> torch.Tensor:
    """aten.put treats ``self`` as 1-D: index_put on the flattened tensor."""
    flat = torch.index_put(
        self.flatten(), [index], source.reshape(index.shape), accumulate
    )
    return flat.reshape(self.shape)
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
@register_decomposition(aten.put_)
def put_(
    self: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    accumulate: bool = False,
) -> torch.Tensor:
    """In-place put: compute out-of-place, then copy back into ``self``."""
    result = aten.put(self, index, source, accumulate=accumulate)
    return self.copy_(result)
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
@register_decomposition(aten._softmax_backward_data.default)
@pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: torch.Tensor,
    output: torch.Tensor,
    dim: int,
    input_dtype: torch.dtype,
) -> torch.Tensor:
    """Softmax backward: grad_in = (grad_out - sum(grad_out * out, dim)) * out."""
    new_grad_output = grad_output * output
    sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True)
    # grad_input = new_grad_output - output * sum_new_grad
    # Expressed as a fused multiply-add: fma(-output, sum_new_grad, new_grad_output).
    grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output)

    # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
    # if grad_output.device == torch.device("cpu"):
    #     return grad_input.contiguous()

    # Cast back to the forward input's dtype when it differs.
    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    return grad_input.contiguous()
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
@register_decomposition(aten.index_reduce)
def index_reduce(
    self: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src: torch.Tensor,
    reduction_type: str,
    *,
    include_self: bool = True,
) -> torch.Tensor:
    """Decompose index_reduce onto index_add / scatter_reduce.

    "mean" is emulated with an index_add plus an element count; other
    reductions are rewritten as scatter_reduce over an expanded index.
    Returns NotImplemented when a fallback kernel must be used.
    """
    if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations(
        self.dtype
    ):
        # Integer dtypes use floor division to keep an integer mean.
        true_division = self.dtype.is_floating_point or self.dtype.is_complex
        ones = torch.ones_like(src)
        if include_self:
            out = self
            counts = torch.ones_like(self).index_add(dim, index, ones)
        else:
            out = self.index_fill(dim, index, 0)
            counts = torch.zeros_like(self).index_add(dim, index, ones)
        # Avoid division by zero for slots no index touched.
        counts = counts.masked_fill(counts < 1, 1)
        out = out.index_add(dim, index, src)
        return out / counts if true_division else out // counts

    if use_scatter_fallback(
        aten.scatter_reduce_.two,
        reduction_type,
        self.dtype,
        src.dtype,
        src.device.type,
        True,
    ):
        return NotImplemented

    # scatter_reduce needs an index tensor shaped like ``src``; broadcast
    # ``index`` across all non-``dim`` dimensions via
    # repeat_interleave + reshape + permute.
    repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel()
    index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim])
    perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim))
    scatter_index = (
        index.to(torch.int64)
        .repeat_interleave(repeats)
        .reshape(index_shape)
        .permute(perm)
    )
    return self.scatter_reduce(
        dim,
        scatter_index,
        src,
        reduction_type,
        include_self=include_self,
    )
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
    x: torch.Tensor,
    kernel_size: List[int],
    stride: Optional[Union[int, List[int]]] = None,
    padding: Union[int, List[int]] = 0,
    dilation: Union[int, List[int]] = 1,
    ceil_mode: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Lower max_pool2d_with_indices onto the low-memory offset-based prims.

    Values are computed together with compact in-window offsets, which are
    then expanded into full indices. Returns NotImplemented when a fallback
    is required or the window does not fit in int8 offsets.
    """
    # Normalize scalar arguments to per-dimension pairs.
    if dilation == 1:
        dilation = [1, 1]

    if padding == 0:
        padding = [0, 0]

    # Default stride equals the kernel size (standard pooling convention).
    if not stride:
        stride = kernel_size

    kernel_size = pad_listlike(kernel_size, 2)
    dilation = pad_listlike(dilation, 2)
    padding = pad_listlike(padding, 2)
    stride = pad_listlike(stride, 2)

    window_size = kernel_size[0] * kernel_size[1]
    # We fallback when using non-default dilation or when the window size is too large
    # (offsets are stored compactly, so the window must fit in int8 range).
    if (
        torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
            kernel_size, dilation
        )
        or window_size > torch.iinfo(torch.int8).max
    ):
        return NotImplemented

    vals, offsets = prims._low_memory_max_pool2d_with_offsets(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )
    # Expand the compact window offsets into standard max-pool indices.
    indices = prims._low_memory_max_pool2d_offsets_to_indices(
        offsets,
        kernel_size[1],
        x.size(-1),
        stride,
        padding,
    )
    return vals, indices
|
.venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py
ADDED
|
@@ -0,0 +1,745 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import abc
|
| 3 |
+
import dataclasses
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
import typing
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
| 9 |
+
from unittest.mock import patch
|
| 10 |
+
|
| 11 |
+
import sympy
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
| 15 |
+
from torch.utils._ordered_set import OrderedSet
|
| 16 |
+
|
| 17 |
+
from .codegen.common import index_prevent_reordering
|
| 18 |
+
from .utils import (
|
| 19 |
+
get_dtype_size,
|
| 20 |
+
reduction_num_outputs,
|
| 21 |
+
sympy_index_symbol,
|
| 22 |
+
sympy_str,
|
| 23 |
+
sympy_subs,
|
| 24 |
+
VarRanges,
|
| 25 |
+
)
|
| 26 |
+
from .virtualized import OpsHandler, ReductionType, V
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
log = logging.getLogger(__name__)
# Matches index expressions that involve an indirect load: either an
# explicit "indirect" symbol or a "tmp" variable name.
is_indirect = re.compile(r"indirect|tmp").search
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Dep(abc.ABC):
    """Abstract base class for a read/write dependency on a named buffer."""

    # Name of the buffer this dependency refers to.
    name: str
    # Symbolic index expression into that buffer.
    index: sympy.Expr

    @abc.abstractmethod
    def rename(self, renames: Dict[str, str]) -> "Dep":
        """Return a dep with ``name`` mapped through ``renames``."""
        pass

    @abc.abstractmethod
    def get_numel(self) -> sympy.Expr:
        """Number of elements covered by this dependency."""
        pass

    @abc.abstractmethod
    def numbytes_hint(self):
        """Estimated number of bytes touched by this dependency."""
        pass

    @abc.abstractmethod
    def has_unbacked_symbols(self) -> bool:
        """Whether the numel expression contains unbacked symbols."""
        pass

    @abc.abstractmethod
    def is_contiguous(self) -> bool:
        """Whether the access pattern is contiguous."""
        pass

    def normalize_with_stride_order(self, prefix="t"):
        # Default is the identity; MemoryDep overrides with real normalization.
        return self
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclasses.dataclass(frozen=True)
|
| 62 |
+
class MemoryDep(Dep):
|
| 63 |
+
name: str
|
| 64 |
+
index: sympy.Expr
|
| 65 |
+
var_names: Tuple[sympy.Symbol, ...]
|
| 66 |
+
size: Tuple[sympy.Expr, ...]
|
| 67 |
+
mode: Optional[str] = None
|
| 68 |
+
|
| 69 |
+
    def __repr__(self) -> str:
        # ``self.ranges`` renders var_names zipped with sizes as a dict.
        return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}, {self.mode})"
|
| 71 |
+
|
| 72 |
+
    @property
    def num_vars(self):
        # Number of loop variables this dependency is indexed by.
        return len(self.var_names)
|
| 75 |
+
|
| 76 |
+
    def decide_loop_order_to_match(self, other):
        """
        Decide a loop-order permutation mapping ``other``'s strides onto
        ``self``'s (matched by equal stride values).

        Can return None if not able to decide loop orders.
        """
        assert self.num_vars == other.num_vars

        # ignore broadcast for now since broadcast causes extra 0 strides
        # which makes it hard to decide the correct loop orders.
        if self.num_vars != len(self.index.free_symbols):
            return None
        if other.num_vars != len(other.index.free_symbols):
            return None

        # bail out if any size is 0 or 1
        # For size == 0, it's an empty tensor, any strides for that dimension
        # are equivalent. Skip for simplicity and it may not matter that much.
        #
        # For size == 1, it can cause ties between strides of different dimensions.
        # Also when we first create LoopBody in ComputedBuffer.simplify_and_reorder
        # we call dependencies.index_vars_squeeze which should already squeeze
        # the size == 1 dimensions.
        if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)):
            return None

        # Extract strides for both expressions
        self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names)
        other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names)

        # Even if the shape contains no 0/1, some complex index expression may
        # still have duplicate stride values. Here is an example:
        # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129
        # We don't reorder the loop for these cases for now, but in theory
        # we could improve the algorithm to detect the correct loop orders.
        if len(set(self_strides)) != len(self_strides) or len(
            set(other_strides)
        ) != len(other_strides):
            log.debug(
                "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s",
                self,
                other,
                self_strides,
                other_strides,
            )
            return None

        # May happen if self and other are as follows
        # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None)
        # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None)
        if set(self_strides) != set(other_strides):
            return None

        # Map each stride value to its dimension position in ``self``.
        stride_to_index = {s: i for i, s in enumerate(self_strides)}
        order = []
        for s in other_strides:
            order.append(stride_to_index[s])

        # The result must be a permutation of all dimensions.
        assert set(order) == set(range(0, self.num_vars))
        return order
|
| 134 |
+
|
| 135 |
+
    def get_offset(self):
        """
        Return the offset by setting every variable to be 0.
        """
        return sympy_subs(self.index, dict.fromkeys(self.var_names, 0))
|
| 140 |
+
|
| 141 |
+
    def normalize(self) -> "MemoryDep":
        """
        Normalize by merging loops. The difference to normalize_with_stride_order
        is, this method does not reorder loops while normalize_with_stride_order
        reorders loops based on stride order.
        """
        return MemoryDep(
            self.name,
            *_RecordLoadStoreInner._normalize(self.index, self.ranges),  # type: ignore[arg-type]
            self.mode,
        )
|
| 152 |
+
|
| 153 |
+
    def normalize_with_stride_order(self, prefix="t"):
        r"""
        Used to decide if two MemoryDep does not equal due to different loop orders.
        More specifically, when dep1 and dep2 are not equal, we can normalize
        both and check if they are equal after that. If yes, then the mismatch is
        caused by different loop orders.
        """
        # import here to avoid circular import
        from torch._inductor import ir

        strides = V.graph.sizevars.stride_hints(self.index, self.var_names)

        # pick a loop order with stride ordered decreasingly
        order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
        stride_reorder = ir.same_reorder(order)
        sizes = self.size
        var_names = self.var_names

        new_reordered_sizes = stride_reorder(sizes)
        new_reordered_var_names = stride_reorder(var_names)

        # Merge loops where the index expression allows it.
        new_simplified_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            new_reordered_var_names,
            new_reordered_sizes,
            index_prevent_reordering(
                [self.index], new_reordered_var_names, new_reordered_sizes
            ),
        )

        # now let's create new symbols with the passed in prefix
        var_ranges, add_var = var_builder(prefix)
        replacement = dict(
            zip(
                new_reordered_var_names,
                reindex([add_var(x) for x in new_simplified_sizes]),
            )
        )
        new_index = sympy_subs(sympy.expand(self.index), replacement)  # type: ignore[arg-type]  # next PR

        out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values()))  # type: ignore[arg-type]
        return out
|
| 194 |
+
|
| 195 |
+
    @property
    def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
        """{c0: 128, c1: 512, ...}"""
        # Pair each loop variable with its extent.
        return dict(zip(self.var_names, self.size))
|
| 199 |
+
|
| 200 |
+
    def get_numel(self) -> sympy.Expr:
        """Number of elements accessed: product of sizes of variables used in the index."""
        if self.is_indirect():
            # Indirect indexing: fall back to the whole buffer's numel.
            numel = V.graph.get_numel(self.name)
        else:
            vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols)
            numel = sympy.Integer(1)
            for var, size in zip(self.var_names, self.size):
                # Only count dimensions whose variable actually appears
                # in the index expression.
                if var in vars:
                    numel = numel * size
        return numel  # type: ignore[return-value]
|
| 210 |
+
|
| 211 |
+
def rename(self, renames: Dict[str, str]) -> "MemoryDep":
|
| 212 |
+
if self.name in renames:
|
| 213 |
+
return MemoryDep(
|
| 214 |
+
renames[self.name],
|
| 215 |
+
self.index,
|
| 216 |
+
var_names=self.var_names,
|
| 217 |
+
size=self.size,
|
| 218 |
+
mode=self.mode,
|
| 219 |
+
)
|
| 220 |
+
return self
|
| 221 |
+
|
| 222 |
+
def numbytes_hint(self):
|
| 223 |
+
return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
|
| 224 |
+
V.graph.get_dtype(self.name)
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
def has_unbacked_symbols(self):
|
| 228 |
+
return len(free_unbacked_symbols(self.get_numel())) > 0
|
| 229 |
+
|
| 230 |
+
def is_contiguous(self) -> bool:
|
| 231 |
+
return isinstance(self.index, sympy.Symbol) and self.index in self.var_names
|
| 232 |
+
|
| 233 |
+
def stride1_for_last_dim(self, result_for_complex_expression=True) -> bool:
|
| 234 |
+
"""
|
| 235 |
+
Whether the stride for the last dimension is 1.
|
| 236 |
+
"""
|
| 237 |
+
# python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16
|
| 238 |
+
# will exercise thru this corner case.
|
| 239 |
+
if len(self.var_names) == 0:
|
| 240 |
+
return True
|
| 241 |
+
|
| 242 |
+
terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index]
|
| 243 |
+
|
| 244 |
+
last_sym = self.var_names[-1]
|
| 245 |
+
for term in terms:
|
| 246 |
+
if term is last_sym:
|
| 247 |
+
return True
|
| 248 |
+
|
| 249 |
+
# Having a >1 stride for the last dimension is bad for perf
|
| 250 |
+
# return False.
|
| 251 |
+
if (
|
| 252 |
+
isinstance(term, sympy.Mul)
|
| 253 |
+
and len(term.args) == 2
|
| 254 |
+
and term.args[1] is last_sym
|
| 255 |
+
and isinstance(term.args[0], (int, sympy.Integer))
|
| 256 |
+
and term.args[0] > 1
|
| 257 |
+
):
|
| 258 |
+
return False
|
| 259 |
+
|
| 260 |
+
return result_for_complex_expression
|
| 261 |
+
|
| 262 |
+
def is_scalar(self) -> bool:
|
| 263 |
+
if isinstance(self.index, sympy.Symbol):
|
| 264 |
+
return self.index not in self.var_names and not self.is_indirect()
|
| 265 |
+
return isinstance(self.index, (int, sympy.Integer))
|
| 266 |
+
|
| 267 |
+
def is_indirect(self) -> bool:
|
| 268 |
+
return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined]
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
@dataclasses.dataclass(frozen=True)
|
| 272 |
+
class StarDep(Dep):
|
| 273 |
+
name: str
|
| 274 |
+
mode: Optional[str] = None
|
| 275 |
+
|
| 276 |
+
# depends on the entire buffer
|
| 277 |
+
@property
|
| 278 |
+
def index(self):
|
| 279 |
+
raise NotImplementedError("StarDep does not have an index")
|
| 280 |
+
|
| 281 |
+
def get_numel(self) -> sympy.Expr:
|
| 282 |
+
return V.graph.get_numel(self.name) # type: ignore[return-value]
|
| 283 |
+
|
| 284 |
+
def rename(self, renames: Dict[str, str]) -> "StarDep":
|
| 285 |
+
if self.name in renames:
|
| 286 |
+
return StarDep(renames[self.name], self.mode)
|
| 287 |
+
return self
|
| 288 |
+
|
| 289 |
+
def numbytes_hint(self):
|
| 290 |
+
return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
|
| 291 |
+
V.graph.get_dtype(self.name)
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
def has_unbacked_symbols(self):
|
| 295 |
+
return len(free_unbacked_symbols(self.get_numel())) > 0
|
| 296 |
+
|
| 297 |
+
def is_contiguous(self) -> bool:
|
| 298 |
+
return False
|
| 299 |
+
|
| 300 |
+
def is_scalar(self) -> bool:
|
| 301 |
+
return False
|
| 302 |
+
|
| 303 |
+
def is_indirect(self) -> bool:
|
| 304 |
+
return False
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# Used for tracking mutation ordering
|
| 308 |
+
# if A reads a buffer and B mutates it
|
| 309 |
+
# B must be ordered after A
|
| 310 |
+
#
|
| 311 |
+
# This is useful for a variety of reasons.
|
| 312 |
+
# For example, if A's read is never actually used, we can eliminate it.
|
| 313 |
+
# Another case is if A's buffer ends up being fused away, we never need to
|
| 314 |
+
# materialize that buffer
|
| 315 |
+
@dataclasses.dataclass(frozen=True)
|
| 316 |
+
class WeakDep(Dep):
|
| 317 |
+
# Fake dependency on unused buffer
|
| 318 |
+
name: str
|
| 319 |
+
# Buffer that is doing the mutation
|
| 320 |
+
mutating_buf: str
|
| 321 |
+
|
| 322 |
+
@property
|
| 323 |
+
def index(self):
|
| 324 |
+
raise NotImplementedError("WeakDep does not have an index")
|
| 325 |
+
|
| 326 |
+
def get_numel(self) -> sympy.Expr:
|
| 327 |
+
return sympy.Integer(1)
|
| 328 |
+
|
| 329 |
+
def rename(self, renames: Dict[str, str]) -> "WeakDep":
|
| 330 |
+
if self.name in renames:
|
| 331 |
+
return WeakDep(renames[self.name], self.mutating_buf)
|
| 332 |
+
return self
|
| 333 |
+
|
| 334 |
+
def numbytes_hint(self):
|
| 335 |
+
return 1 # Purely inserted for ordering, not an actual dep
|
| 336 |
+
|
| 337 |
+
def has_unbacked_symbols(self):
|
| 338 |
+
return False
|
| 339 |
+
|
| 340 |
+
def is_contiguous(self) -> bool:
|
| 341 |
+
return False
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@dataclasses.dataclass(frozen=True)
|
| 345 |
+
class IndexExprDep:
|
| 346 |
+
index: sympy.Expr # type: ignore[assignment]
|
| 347 |
+
var_names: Tuple[sympy.Symbol, ...]
|
| 348 |
+
size: Tuple[sympy.Expr, ...]
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
@dataclasses.dataclass
|
| 352 |
+
class ReadWrites:
|
| 353 |
+
reads: OrderedSet[Dep]
|
| 354 |
+
writes: OrderedSet[Dep]
|
| 355 |
+
index_exprs: OrderedSet[IndexExprDep]
|
| 356 |
+
range_vars: Optional[List[sympy.Expr]] = None
|
| 357 |
+
var_ranges: Optional[VarRanges] = None
|
| 358 |
+
|
| 359 |
+
def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
|
| 360 |
+
return ReadWrites(
|
| 361 |
+
OrderedSet(dep.rename(renames) for dep in self.reads),
|
| 362 |
+
OrderedSet(dep.rename(renames) for dep in self.writes),
|
| 363 |
+
self.index_exprs,
|
| 364 |
+
self.range_vars,
|
| 365 |
+
self.var_ranges,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites":
|
| 369 |
+
assert isinstance(dep, (WeakDep, StarDep, set))
|
| 370 |
+
if not isinstance(dep, set):
|
| 371 |
+
dep = {dep}
|
| 372 |
+
return ReadWrites(
|
| 373 |
+
OrderedSet.union(self.reads, dep),
|
| 374 |
+
self.writes,
|
| 375 |
+
self.index_exprs,
|
| 376 |
+
self.range_vars,
|
| 377 |
+
self.var_ranges,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
def merge(self, other: "ReadWrites"):
|
| 381 |
+
reads = OrderedSet.union(self.reads, other.reads)
|
| 382 |
+
writes = OrderedSet.union(self.writes, other.writes)
|
| 383 |
+
index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs)
|
| 384 |
+
return ReadWrites(reads - writes, writes, index_exprs)
|
| 385 |
+
|
| 386 |
+
@staticmethod
|
| 387 |
+
def merge_list(read_writes: List["ReadWrites"]):
|
| 388 |
+
all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
|
| 389 |
+
all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
|
| 390 |
+
all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
|
| 391 |
+
return ReadWrites(all_reads, all_writes, all_index_exprs)
|
| 392 |
+
|
| 393 |
+
def remove_reads(self, rem_reads):
|
| 394 |
+
return ReadWrites(
|
| 395 |
+
self.reads - rem_reads,
|
| 396 |
+
self.writes,
|
| 397 |
+
self.index_exprs,
|
| 398 |
+
self.range_vars,
|
| 399 |
+
self.var_ranges,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
def reads_and_writes(self):
|
| 403 |
+
return itertools.chain(self.reads, self.writes)
|
| 404 |
+
|
| 405 |
+
def buffer_names(self, ignore_integer_index=True):
|
| 406 |
+
"""
|
| 407 |
+
Integer index is used for load_seed.
|
| 408 |
+
"""
|
| 409 |
+
names: OrderedSet[str] = OrderedSet()
|
| 410 |
+
for dep in self.reads_and_writes():
|
| 411 |
+
if not isinstance(dep, MemoryDep):
|
| 412 |
+
continue
|
| 413 |
+
if not ignore_integer_index or not isinstance(
|
| 414 |
+
dep.index, (int, sympy.Integer)
|
| 415 |
+
):
|
| 416 |
+
names.add(dep.name)
|
| 417 |
+
return names
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
|
| 421 |
+
def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
|
| 422 |
+
super().__init__()
|
| 423 |
+
self._reads: OrderedSet[Dep] = OrderedSet()
|
| 424 |
+
self._writes: OrderedSet[MemoryDep] = OrderedSet()
|
| 425 |
+
self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet()
|
| 426 |
+
self._var_ranges: VarRanges = var_ranges
|
| 427 |
+
self._should_normalize: bool = normalize
|
| 428 |
+
|
| 429 |
+
@staticmethod
|
| 430 |
+
def drop_unused_symbols(index, var_names, sizes):
|
| 431 |
+
"""
|
| 432 |
+
Reduction has last (reduced) dim in its sizes, but
|
| 433 |
+
downstream users won't. Normalize this away.
|
| 434 |
+
"""
|
| 435 |
+
if not isinstance(index, sympy.Expr):
|
| 436 |
+
# index can be an int
|
| 437 |
+
return
|
| 438 |
+
free_symbols = index.free_symbols
|
| 439 |
+
while var_names and var_names[-1] not in free_symbols:
|
| 440 |
+
var_names.pop()
|
| 441 |
+
sizes.pop()
|
| 442 |
+
|
| 443 |
+
@classmethod
|
| 444 |
+
def _normalize(
|
| 445 |
+
cls, index: sympy.Expr, var_ranges: VarRanges
|
| 446 |
+
) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
|
| 447 |
+
# Try to further simplify the indexes even if simplify_loops didn't
|
| 448 |
+
# convert it to the simplest form because of the interference from
|
| 449 |
+
# different indexing formulas.
|
| 450 |
+
index_vars = [*var_ranges.keys()]
|
| 451 |
+
sizes = tuple(var_ranges.values()) # type: ignore[assignment]
|
| 452 |
+
new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
|
| 453 |
+
index_vars,
|
| 454 |
+
sizes,
|
| 455 |
+
index_prevent_reordering([index], index_vars, sizes),
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
# assign new variables each dimension to deal with numbering mismatches
|
| 459 |
+
# d0, d1, d2 could become d0, d2 -- which won't match d0, d1
|
| 460 |
+
new_vars, add_var = var_builder(canonicalization_prefix())
|
| 461 |
+
replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
|
| 462 |
+
index = sympy_subs(sympy.expand(index), replacement)
|
| 463 |
+
|
| 464 |
+
new_vars = [*new_vars.keys()]
|
| 465 |
+
new_sizes = [*new_sizes]
|
| 466 |
+
cls.drop_unused_symbols(index, new_vars, new_sizes)
|
| 467 |
+
return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type]
|
| 468 |
+
|
| 469 |
+
def canonicalize(
|
| 470 |
+
self, index: sympy.Expr
|
| 471 |
+
) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
|
| 472 |
+
if not self._should_normalize:
|
| 473 |
+
sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
|
| 474 |
+
var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1]
|
| 475 |
+
sizes = [v for v in sizes if v != 1]
|
| 476 |
+
|
| 477 |
+
self.drop_unused_symbols(index, var_names, sizes)
|
| 478 |
+
|
| 479 |
+
return index, tuple(var_names), tuple(sizes) # type: ignore[return-value, arg-type]
|
| 480 |
+
var_ranges = {
|
| 481 |
+
k: V.graph.sizevars.simplify(v)
|
| 482 |
+
for k, v in self._var_ranges.items()
|
| 483 |
+
# TODO(jansel): explore this further normalization
|
| 484 |
+
# if k in free_symbols
|
| 485 |
+
}
|
| 486 |
+
return self._normalize(index, var_ranges)
|
| 487 |
+
|
| 488 |
+
def load(self, name: str, index: sympy.Expr) -> str:
|
| 489 |
+
self._reads.add(MemoryDep(name, *self.canonicalize(index)))
|
| 490 |
+
return f"load({name}, {sympy_str(index)})"
|
| 491 |
+
|
| 492 |
+
def load_seed(self, name: str, index: int):
|
| 493 |
+
assert isinstance(index, int)
|
| 494 |
+
return self.load(name, sympy.Integer(index))
|
| 495 |
+
|
| 496 |
+
def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
|
| 497 |
+
self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode))
|
| 498 |
+
return f"store({name}, {sympy_str(index)}, {value}, {mode})"
|
| 499 |
+
|
| 500 |
+
def store_reduction(self, name: str, index, value) -> str:
|
| 501 |
+
return self.store(name, index, f"store_reduction({value})")
|
| 502 |
+
|
| 503 |
+
def index_expr(self, index: sympy.Expr, dtype) -> str:
|
| 504 |
+
self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
|
| 505 |
+
return f"index_expr({sympy_str(index)}, {dtype})"
|
| 506 |
+
|
| 507 |
+
def bucketize(
|
| 508 |
+
self,
|
| 509 |
+
values,
|
| 510 |
+
offsets_name: str,
|
| 511 |
+
offsets_size: sympy.Expr,
|
| 512 |
+
indexing_dtype: torch.dtype,
|
| 513 |
+
right: bool,
|
| 514 |
+
):
|
| 515 |
+
self._reads.add(StarDep(offsets_name))
|
| 516 |
+
return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined]
|
| 520 |
+
def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
|
| 521 |
+
parent_handler = _RecordLoadStoreInner(
|
| 522 |
+
var_ranges=var_ranges, normalize=normalize
|
| 523 |
+
)
|
| 524 |
+
super().__init__(parent_handler=parent_handler)
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# TODO: check call sites
|
| 528 |
+
def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
|
| 529 |
+
cnt = itertools.count()
|
| 530 |
+
var_ranges: VarRanges = {}
|
| 531 |
+
|
| 532 |
+
def add_var(length: sympy.Expr) -> sympy.Symbol:
|
| 533 |
+
v = sympy_index_symbol(f"{prefix}{next(cnt)}")
|
| 534 |
+
var_ranges[v] = length
|
| 535 |
+
return v
|
| 536 |
+
|
| 537 |
+
return var_ranges, add_var
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
|
| 541 |
+
var_ranges, add_var = var_builder(prefix)
|
| 542 |
+
args: List[List[sympy.Symbol]] = []
|
| 543 |
+
for size in argsizes:
|
| 544 |
+
args.append(list(map(add_var, size)))
|
| 545 |
+
return args, var_ranges
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
|
| 549 |
+
from .ir import SqueezeView
|
| 550 |
+
|
| 551 |
+
var_ranges, add_var = var_builder(prefix)
|
| 552 |
+
args: List[List[sympy.Expr]] = []
|
| 553 |
+
new_sizes: List[List[sympy.Expr]] = []
|
| 554 |
+
for size in argsizes:
|
| 555 |
+
new_size, reindex = SqueezeView.squeezer(size)
|
| 556 |
+
new_sizes.append(new_size)
|
| 557 |
+
args.append(reindex(list(map(add_var, new_size))))
|
| 558 |
+
return args, var_ranges
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def extract_read_writes(
|
| 562 |
+
fn: Callable[..., Any],
|
| 563 |
+
*argsizes: Tuple[sympy.Expr, ...],
|
| 564 |
+
normalize: bool = False,
|
| 565 |
+
prefix: str = "d",
|
| 566 |
+
hidden_args=(),
|
| 567 |
+
):
|
| 568 |
+
args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
|
| 569 |
+
|
| 570 |
+
from .loop_body import LoopBody, MemoryUsageType
|
| 571 |
+
|
| 572 |
+
if isinstance(fn, LoopBody):
|
| 573 |
+
# Fast path to avoid tracing when we already have a LoopBody
|
| 574 |
+
inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize)
|
| 575 |
+
name_to_index = fn.indexing_from_args([*args, *hidden_args])
|
| 576 |
+
if fn.indirect_vars:
|
| 577 |
+
# mimic the `tmpX` naming tracing gives us
|
| 578 |
+
repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)}
|
| 579 |
+
name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()}
|
| 580 |
+
for entry in fn.memory_usage[MemoryUsageType.LOAD]:
|
| 581 |
+
inner.load(entry.buffer_name, name_to_index[entry.index_name])
|
| 582 |
+
for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]:
|
| 583 |
+
inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name]))
|
| 584 |
+
for entry in fn.memory_usage[MemoryUsageType.STORE]:
|
| 585 |
+
inner.store(
|
| 586 |
+
entry.buffer_name, name_to_index[entry.index_name], None, entry.mode
|
| 587 |
+
)
|
| 588 |
+
for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
|
| 589 |
+
inner.store_reduction(
|
| 590 |
+
entry.buffer_name, name_to_index[entry.index_name], None
|
| 591 |
+
)
|
| 592 |
+
for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
|
| 593 |
+
inner.index_expr(name_to_index[entry.index_name], None)
|
| 594 |
+
for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]:
|
| 595 |
+
inner.bucketize(
|
| 596 |
+
None, entry.buffer_name, name_to_index[entry.index_name], None, None
|
| 597 |
+
)
|
| 598 |
+
# fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
|
| 599 |
+
else:
|
| 600 |
+
# Slow path tracing the function
|
| 601 |
+
rw = RecordLoadStore(var_ranges, normalize=normalize)
|
| 602 |
+
with V.set_ops_handler(rw):
|
| 603 |
+
fn(*args, *hidden_args)
|
| 604 |
+
inner = rw.parent_handler
|
| 605 |
+
|
| 606 |
+
if normalize:
|
| 607 |
+
range_vars = [] # Number of vars could differ due to normalization
|
| 608 |
+
else:
|
| 609 |
+
range_vars = [*itertools.chain.from_iterable(args)]
|
| 610 |
+
|
| 611 |
+
return ReadWrites(
|
| 612 |
+
OrderedSet(inner._reads),
|
| 613 |
+
OrderedSet(inner._writes),
|
| 614 |
+
inner._index_exprs,
|
| 615 |
+
range_vars,
|
| 616 |
+
var_ranges,
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def extract_input_node_reduction_ranges(
|
| 621 |
+
input_node: "torch._inductor.ir.TensorBox",
|
| 622 |
+
) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
|
| 623 |
+
"""
|
| 624 |
+
Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
|
| 625 |
+
It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
|
| 626 |
+
In this case, reduction_sizes of the Reduction nodes need to be the same.
|
| 627 |
+
Otherwise returns (None, None).
|
| 628 |
+
"""
|
| 629 |
+
|
| 630 |
+
from .ir import ComputedBuffer, Loops
|
| 631 |
+
|
| 632 |
+
if isinstance(input_node.data, ComputedBuffer):
|
| 633 |
+
# Input node has already been realized. Return its size and reduction_size.
|
| 634 |
+
size = input_node.get_size()
|
| 635 |
+
reduction_size = input_node.get_reduction_size()
|
| 636 |
+
if len(reduction_size) > 0:
|
| 637 |
+
return (size, reduction_size)
|
| 638 |
+
else:
|
| 639 |
+
return (None, None)
|
| 640 |
+
|
| 641 |
+
if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined]
|
| 642 |
+
# Other IRNodes do not have reduction_ranges.
|
| 643 |
+
return (None, None)
|
| 644 |
+
|
| 645 |
+
# There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
|
| 646 |
+
# The current method still uses reduction ranges from the dependent realized node, which is not ideal.
|
| 647 |
+
# Is there a way to check whether there are permutations inbetween?
|
| 648 |
+
reads = input_node.get_reads()
|
| 649 |
+
reduction_size = None
|
| 650 |
+
size = None
|
| 651 |
+
while reduction_size is None and len(reads) > 0:
|
| 652 |
+
seen: OrderedSet[str] = OrderedSet()
|
| 653 |
+
new_reads = []
|
| 654 |
+
for read in reads:
|
| 655 |
+
if not isinstance(read, MemoryDep):
|
| 656 |
+
continue
|
| 657 |
+
if read.name in seen:
|
| 658 |
+
continue
|
| 659 |
+
seen.add(read.name)
|
| 660 |
+
buffer = V.graph.try_get_buffer(read.name)
|
| 661 |
+
if buffer is None:
|
| 662 |
+
continue
|
| 663 |
+
op = buffer.get_defining_op()
|
| 664 |
+
if op is None:
|
| 665 |
+
continue
|
| 666 |
+
|
| 667 |
+
if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0:
|
| 668 |
+
if reduction_size is None:
|
| 669 |
+
reduction_size = op.get_reduction_size()
|
| 670 |
+
size = op.get_size()
|
| 671 |
+
elif reduction_size != op.get_reduction_size() or size != op.get_size():
|
| 672 |
+
return (None, None)
|
| 673 |
+
else:
|
| 674 |
+
new_reads.extend(op.get_reads())
|
| 675 |
+
if reads == new_reads:
|
| 676 |
+
return (size, reduction_size)
|
| 677 |
+
else:
|
| 678 |
+
reads = new_reads
|
| 679 |
+
return (size, reduction_size)
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def canonicalization_prefix():
|
| 683 |
+
return "c"
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
# ops handler which computes all the free unbacked symbols for an IR
|
| 687 |
+
class FreeUnbackedSymbolsOpsHandler:
|
| 688 |
+
symbols: OrderedSet[sympy.Symbol]
|
| 689 |
+
|
| 690 |
+
def __init__(self) -> None:
|
| 691 |
+
self.symbols = OrderedSet()
|
| 692 |
+
|
| 693 |
+
def __getattr__(self, name: str) -> Callable[..., Any]:
|
| 694 |
+
def inner(*args, **kwargs):
|
| 695 |
+
for a in itertools.chain(args, kwargs.values()):
|
| 696 |
+
if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
|
| 697 |
+
self.symbols |= free_unbacked_symbols(a)
|
| 698 |
+
|
| 699 |
+
return inner
|
| 700 |
+
|
| 701 |
+
def indirect_indexing(
|
| 702 |
+
self, index_var, size, check=True, wrap_neg=True
|
| 703 |
+
) -> sympy.Symbol:
|
| 704 |
+
assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
|
| 705 |
+
self.symbols |= free_unbacked_symbols(size)
|
| 706 |
+
return sympy_index_symbol(f"({str(index_var)})")
|
| 707 |
+
|
| 708 |
+
def frexp(self, x):
|
| 709 |
+
return (None,) * 2
|
| 710 |
+
|
| 711 |
+
def scan(self, dtypes, combine_fn, values):
|
| 712 |
+
return (None,) * len(values)
|
| 713 |
+
|
| 714 |
+
def sort(self, dtypes, values, stable, descending):
|
| 715 |
+
return (None,) * len(values)
|
| 716 |
+
|
| 717 |
+
def reduction(
|
| 718 |
+
self,
|
| 719 |
+
dtype: torch.dtype,
|
| 720 |
+
src_dtype: torch.dtype,
|
| 721 |
+
reduction_type: ReductionType,
|
| 722 |
+
value: Union[None, Tuple[None, ...]],
|
| 723 |
+
) -> Union[None, Tuple[None, ...]]:
|
| 724 |
+
num_values = reduction_num_outputs(reduction_type)
|
| 725 |
+
return (None,) * num_values if num_values > 1 else None
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def _typecheck_FreeUnbackedSymbolsOpsHandler(
|
| 729 |
+
h: FreeUnbackedSymbolsOpsHandler,
|
| 730 |
+
) -> OpsHandler[None]:
|
| 731 |
+
return h
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None):
|
| 735 |
+
from .ir import FlexibleLayout
|
| 736 |
+
|
| 737 |
+
args = [index, rindex] if rindex is not None else [index]
|
| 738 |
+
handler = FreeUnbackedSymbolsOpsHandler()
|
| 739 |
+
# NB: I cargo culted the allow_indexing patch here, I don't understand why
|
| 740 |
+
# people do this all over
|
| 741 |
+
with V.set_ops_handler(handler), patch.object(
|
| 742 |
+
FlexibleLayout, "allow_indexing", True
|
| 743 |
+
):
|
| 744 |
+
fn(*args)
|
| 745 |
+
return handler.symbols
|
.venv/lib/python3.11/site-packages/torch/_inductor/exc.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
import textwrap
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
|
| 11 |
+
|
| 12 |
+
@lru_cache(None)
|
| 13 |
+
def _record_missing_op(target):
|
| 14 |
+
with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd:
|
| 15 |
+
fd.write(str(target) + "\n")
|
| 16 |
+
|
| 17 |
+
else:
|
| 18 |
+
|
| 19 |
+
def _record_missing_op(target): # type: ignore[misc]
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class OperatorIssue(RuntimeError):
|
| 24 |
+
@staticmethod
|
| 25 |
+
def operator_str(target, args, kwargs):
|
| 26 |
+
lines = [f"target: {target}"] + [
|
| 27 |
+
f"args[{i}]: {arg}" for i, arg in enumerate(args)
|
| 28 |
+
]
|
| 29 |
+
if kwargs:
|
| 30 |
+
lines.append(f"kwargs: {kwargs}")
|
| 31 |
+
return textwrap.indent("\n".join(lines), " ")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class MissingOperatorWithoutDecomp(OperatorIssue):
|
| 35 |
+
def __init__(self, target, args, kwargs) -> None:
|
| 36 |
+
_record_missing_op(target)
|
| 37 |
+
super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MissingOperatorWithDecomp(OperatorIssue):
|
| 41 |
+
def __init__(self, target, args, kwargs) -> None:
|
| 42 |
+
_record_missing_op(target)
|
| 43 |
+
super().__init__(
|
| 44 |
+
f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
|
| 45 |
+
+ textwrap.dedent(
|
| 46 |
+
f"""
|
| 47 |
+
|
| 48 |
+
There is a decomposition available for {target} in
|
| 49 |
+
torch._decomp.get_decompositions(). Please add this operator to the
|
| 50 |
+
`decompositions` list in torch._inductor.decomposition
|
| 51 |
+
"""
|
| 52 |
+
)
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class LoweringException(OperatorIssue):
|
| 57 |
+
def __init__(self, exc: Exception, target, args, kwargs) -> None:
|
| 58 |
+
super().__init__(
|
| 59 |
+
f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class SubgraphLoweringException(RuntimeError):
|
| 64 |
+
pass
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class InvalidCxxCompiler(RuntimeError):
|
| 68 |
+
def __init__(self) -> None:
|
| 69 |
+
from . import config
|
| 70 |
+
|
| 71 |
+
super().__init__(
|
| 72 |
+
f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class CppWrapperCodeGenError(RuntimeError):
|
| 77 |
+
def __init__(self, msg: str) -> None:
|
| 78 |
+
super().__init__(f"C++ wrapper codegen error: {msg}")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class CppCompileError(RuntimeError):
|
| 82 |
+
def __init__(self, cmd: list[str], output: str) -> None:
|
| 83 |
+
if isinstance(output, bytes):
|
| 84 |
+
output = output.decode("utf-8")
|
| 85 |
+
|
| 86 |
+
super().__init__(
|
| 87 |
+
textwrap.dedent(
|
| 88 |
+
"""
|
| 89 |
+
C++ compile error
|
| 90 |
+
|
| 91 |
+
Command:
|
| 92 |
+
{cmd}
|
| 93 |
+
|
| 94 |
+
Output:
|
| 95 |
+
{output}
|
| 96 |
+
"""
|
| 97 |
+
)
|
| 98 |
+
.strip()
|
| 99 |
+
.format(cmd=" ".join(cmd), output=output)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class CUDACompileError(CppCompileError):
|
| 104 |
+
pass
|
.venv/lib/python3.11/site-packages/torch/_inductor/extern_node_serializer.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
|
| 5 |
+
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder
|
| 6 |
+
from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def serialize_extern_kernel_node(
|
| 10 |
+
extern_kernel_node: inductor_ExternKernelNode,
|
| 11 |
+
) -> ExternKernelNode:
|
| 12 |
+
assert isinstance(extern_kernel_node.node, Node)
|
| 13 |
+
return ExternKernelNode(
|
| 14 |
+
name=extern_kernel_node.name,
|
| 15 |
+
node=extern_kernel_node.node,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def extern_node_json_serializer(
|
| 20 |
+
extern_kernel_nodes: List[inductor_ExternKernelNode],
|
| 21 |
+
) -> str:
|
| 22 |
+
serialized_nodes = ExternKernelNodes(
|
| 23 |
+
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes]
|
| 24 |
+
)
|
| 25 |
+
return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder)
|
.venv/lib/python3.11/site-packages/torch/_inductor/freezing.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import weakref
|
| 7 |
+
from typing import Any, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.utils._pytree as pytree
|
| 11 |
+
from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
|
| 12 |
+
from torch._functorch.aot_autograd import MutationType
|
| 13 |
+
from torch._functorch.compile_utils import fx_graph_cse
|
| 14 |
+
from torch._inductor.constant_folding import constant_fold, replace_node_with_constant
|
| 15 |
+
from torch._inductor.fx_passes.freezing_patterns import freezing_passes
|
| 16 |
+
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
| 17 |
+
|
| 18 |
+
from . import config
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
aten = torch.ops.aten
|
| 22 |
+
prims = torch.ops.prims
|
| 23 |
+
|
| 24 |
+
log = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def replace_params_with_constants(
|
| 28 |
+
gm: torch.fx.GraphModule,
|
| 29 |
+
flat_params: list[Any],
|
| 30 |
+
fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
|
| 31 |
+
) -> List[int]:
|
| 32 |
+
"""
|
| 33 |
+
Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
|
| 34 |
+
Returns a list of indices representing the input parameters that were not converted to constants.
|
| 35 |
+
"""
|
| 36 |
+
params = gm.graph.find_nodes(op="placeholder")
|
| 37 |
+
fake_inp_nodes = params[: len(params)]
|
| 38 |
+
preserved_arg_indices = []
|
| 39 |
+
aliased_input_args = [
|
| 40 |
+
out_info.base_idx
|
| 41 |
+
for out_info in fw_metadata.output_info
|
| 42 |
+
if out_info.base_idx is not None
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
# TODO (tmanlaibaatar) figure out why this is different
|
| 46 |
+
# from mutated_inp_runtime_indices
|
| 47 |
+
mutated_inps = [
|
| 48 |
+
i
|
| 49 |
+
for i, m in enumerate(fw_metadata.input_info)
|
| 50 |
+
if m.mutation_type
|
| 51 |
+
in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
|
| 55 |
+
if i in mutated_inps or i in aliased_input_args:
|
| 56 |
+
preserved_arg_indices.append(i)
|
| 57 |
+
continue
|
| 58 |
+
replace_node_with_constant(gm, node, real_input)
|
| 59 |
+
# add on non param inputs
|
| 60 |
+
preserved_arg_indices.extend(range(len(flat_params), len(params)))
|
| 61 |
+
# is this necessary ?
|
| 62 |
+
gm.recompile()
|
| 63 |
+
return preserved_arg_indices
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def freeze(
    dynamo_gm: torch.fx.GraphModule,
    aot_autograd_gm: torch.fx.GraphModule,
    example_inputs: List[torch._subclasses.FakeTensor],
) -> Tuple[torch.fx.GraphModule, List[int]]:
    """
    Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation
    and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency.

    Assumes that this function is run in dynamo tracing post aot_autograd.

    Args:
        dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule.
        aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen.
        example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process.

    Returns:
        Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices
        of the inputs that were preserved (not turned into constants).
    """
    # We have convert conv's weight to channels last which may meet error for .view
    # when doing fake_tensor_prop. So we need to convert view to reshape first.
    # See the details in fx_codegen_and_compile of compile_fx.py.
    view_to_reshape(aot_autograd_gm)

    if tracing_context := torch._guards.TracingContext.try_get():
        # Normal compile path: AOTAutograd metadata tells us which inputs are
        # mutated/aliased and must therefore stay as runtime inputs.
        fw_metadata = tracing_context.fw_metadata
        params_flat = tracing_context.params_flat
        assert fw_metadata is not None and params_flat is not None

        preserved_arg_indices = replace_params_with_constants(
            aot_autograd_gm, params_flat, fw_metadata
        )
    else:
        # No tracing context available: conservatively preserve every input.
        inputs = aot_autograd_gm.graph.find_nodes(op="placeholder")
        preserved_arg_indices = list(range(len(inputs)))

    # TODO - further restrict cse ? right now needed to dedup aliasing ops
    cse_graph = fx_graph_cse(aot_autograd_gm.graph)
    aot_autograd_gm.graph = cse_graph
    aot_autograd_gm.recompile()

    # Only the preserved inputs still exist after param inlining.
    aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
    freezing_passes(aot_autograd_gm, aot_example_inputs)

    constant_fold(aot_autograd_gm)
    # invalidate nn Modules
    if config.freezing_discard_parameters:
        # Free the original parameter storages; eager use of the traced
        # modules will now raise (see ErasedTensor).
        invalidate_eager_modules()
        discard_traced_gm_params(dynamo_gm)

    log.debug(
        "%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm, colored=True)
    )

    return aot_autograd_gm, preserved_arg_indices
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class ErasedTensor(torch.Tensor):
    """Meta-device sentinel that replaces a discarded parameter/buffer.

    Any eager op touching one of these raises a RuntimeError that names the
    erased attribute and its owning module, instead of silently computing
    with a meta tensor.
    """

    @staticmethod
    def __new__(cls, elem, name, owning_mod):
        # Construct the subclass on the meta device so the original storage
        # can be freed.
        return super().__new__(cls, elem.to(device="meta"))

    def __init__(self, elem, name: Optional[str], mod) -> None:
        # Remember where the tensor came from for the error message; hold the
        # module weakly so the sentinel does not keep it alive.
        self.erased_name = name
        self.owning_mod_ref = weakref.ref(mod)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Dispatch reaches here for any op whose inputs contain an
        # ErasedTensor; report the first one found.
        erased_tensors = [
            e
            for e in pytree.arg_tree_leaves(*args, **kwargs)
            if isinstance(e, ErasedTensor)
        ]
        assert len(erased_tensors) > 0
        e = erased_tensors[0]

        raise RuntimeError(
            f"Trying to run Pytorch Eager Module after Dynamo Freezing. "
            "The original parameters have been discarded for memory efficiency. "
            f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}"
        )
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def invalidate_eager_modules():
    """Swap every param/buffer of the traced eager modules for an ErasedTensor.

    After this, running the original nn.Modules eagerly raises instead of
    silently using freed parameter storage.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        tracked = torch._guards.TracingContext.get().module_context.nn_modules.values()
        for module in tracked:
            if not isinstance(module, torch.nn.Module):
                continue

            named_tensors = itertools.chain(
                module.named_parameters(recurse=False),
                module.named_buffers(recurse=False),
            )
            # list(...) snapshots the iterator: we mutate attributes below.
            for name, tensor in list(named_tensors):
                with torch._dispatch.python.no_python_dispatcher():
                    sentinel = ErasedTensor(tensor, name, module)
                if isinstance(tensor, torch.nn.Parameter):
                    # Preserve parameter-ness on the sentinel.
                    sentinel.requires_grad_(True)
                    sentinel._is_param = True  # type: ignore[attr-defined]
                setattr(module, name, sentinel)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def discard_traced_gm_params(mod: torch.fx.GraphModule):
    """Replace the GraphModule's own params/buffers with ErasedTensor sentinels."""
    with torch.utils._python_dispatch._disable_current_modes():
        named_tensors = itertools.chain(
            mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
        )
        # list(...) snapshots the iterator: we reassign attributes below.
        for name, tensor in list(named_tensors):
            with torch._dispatch.python.no_python_dispatcher():
                sentinel = ErasedTensor(tensor, name, mod)
            if isinstance(tensor, torch.nn.Parameter):
                # Preserve parameter-ness on the sentinel.
                sentinel.requires_grad_(True)
                sentinel._is_param = True  # type: ignore[attr-defined]
            setattr(mod, name, sentinel)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def enforce_output_layout(gm: torch.fx.GraphModule):
    """
    Make sure the output node's layout does not change due to compiler optimizations
    by adding aten.as_strided nodes with the expected strides.

    Only used for inference so we can assume all graph outputs are model outputs.
    """
    *_, output_node = gm.graph.nodes
    graph_outputs = output_node.args[0]
    with gm.graph.inserting_before(output_node):
        for out in graph_outputs:
            val = out.meta["val"]
            eligible = isinstance(
                val, torch.Tensor
            ) and torch._prims_common.is_non_overlapping_and_dense(val)
            if not eligible:
                continue

            # add a node to enforce eager layout
            forced = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (out, val.stride())
            )

            # can not call
            # out.replace_all_uses_with(forced)
            # since it will replace the usage of out in forced itself.
            output_node.replace_input_with(out, forced)

    gm.graph.lint()
    gm.recompile()
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def enforce_as_strided_input_layout(gm: torch.fx.GraphModule):
    """
    Make sure the as_strided node's input's layout does not change due to compiler
    optimizations, because the as_strided strides info depends on input tensor stride info.
    """
    as_strided_ops = (
        torch.ops.aten.as_strided.default,
        torch.ops.aten.as_strided_.default,
        torch.ops.aten.as_strided_scatter.default,
    )
    for node in [n for n in gm.graph.nodes if n.target in as_strided_ops]:
        src = node.args[0]
        fake_val = src.meta["val"]
        with gm.graph.inserting_before(node):
            # add a node pinning the input to its eager stride order
            pinned = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (src, fake_val.stride())
            )
            node.replace_input_with(src, pinned)

    gm.graph.lint()
    gm.recompile()
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule):
    """
    Convert 4d convolution weight tensor to channels last format.

    This pass is performed before freezing so the added nodes can be constant
    folded by freezing.
    """
    with dynamo_timed("convert_conv_weights_to_channels_last"):
        conv_nodes = [
            node for node in gm.graph.nodes if node.target == aten.convolution.default
        ]
        for conv_node in conv_nodes:
            weight = conv_node.args[1]
            weight_val = weight.meta["val"]
            if len(weight_val.size()) != 4 or weight_val.is_contiguous(
                memory_format=torch.channels_last
            ):
                # not a 4d tensor or already channels last, skip
                continue

            with gm.graph.inserting_before(conv_node):
                # Clone into channels_last; freezing will constant-fold this.
                cloned = gm.graph.call_function(
                    aten.clone.default,
                    (weight,),
                    {"memory_format": torch.channels_last},
                )
                conv_node.replace_input_with(weight, cloned)

        enforce_as_strided_input_layout(gm)
        enforce_output_layout(gm)
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import operator
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type
|
| 5 |
+
|
| 6 |
+
import sympy
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.fx
|
| 10 |
+
from torch.fx.experimental.symbolic_shapes import (
|
| 11 |
+
compute_unbacked_bindings,
|
| 12 |
+
rebind_unbacked,
|
| 13 |
+
statically_known_true,
|
| 14 |
+
sym_eq,
|
| 15 |
+
)
|
| 16 |
+
from torch.utils import _pytree as pytree
|
| 17 |
+
from torch.utils._pytree import tree_map
|
| 18 |
+
|
| 19 |
+
from .virtualized import V
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
|
| 23 |
+
# Works for length 2 patterns with 1 module and 1 function/method.
|
| 24 |
+
def matches_module_function_pattern(
    pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: Dict[str, torch.nn.modules.Module],
) -> bool:
    """Check the pattern (nn.module, F.function/torch.Tensor.method).

    Works for length-2 patterns with 1 module and 1 function/method:
    ``node.args[0]`` must be a call_module of type ``pattern[0]`` whose sole
    consumer is ``node``, and ``node`` must call ``pattern[1]``.
    """
    module_type, fn = pattern
    if len(node.args) == 0:
        return False
    producer = node.args[0]
    if not isinstance(producer, torch.fx.Node) or not isinstance(node, torch.fx.Node):
        return False
    # the first node is call_module
    if producer.op != "call_module":
        return False
    if not isinstance(producer.target, str):
        return False
    if producer.target not in modules:
        return False
    if type(modules[producer.target]) is not module_type:
        return False
    # the second node is call_function or call_method
    if node.op not in ("call_function", "call_method"):
        return False
    if node.target != fn:
        return False
    # make sure the module output is only used by the current node.
    if len(producer.users) > 1:
        return False
    return True
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.

    In order to detect which nodes have changed, we first hash its node,
    target, and argument lists (which are immutable in FX).

    Then, whenever we call incremental_update, we check which FX nodes have a
    new hash, and recompute the faketensor metadata for that node. Then, we
    continue to recursively compute the faketensors for all users until the
    fake tensors stop changing.
    """

    def __init__(self, graph: torch.fx.Graph) -> None:
        # Hashes of nodes whose fake-tensor metadata is currently up to date.
        self.processed_hashes = set()
        self.graph = graph

        # All nodes are assumed accurate at construction time.
        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node):
        # todo(chilli): Not a great hash function
        # id() of args/kwargs catches rebinding of the (immutable) tuples,
        # which is how FX changes a node's inputs.
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
        # NOTE(review): `processed` appears unused below — presumably vestigial.
        processed = set()
        # Count how many nodes reference each storage, so we can recognize a
        # recomputed fake tensor that landed in a genuinely fresh storage.
        existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new, old):
            # Symbolic comparison of int lists (shapes / strides).
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old):
            # True if the freshly computed fake value matches the cached one
            # closely enough that node.meta does not need refreshing.
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                # Only symbolic scalars are expected here.
                assert isinstance(
                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
                ), f"Unknown type {type(new)} in {self.graph}"
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False

            if new.device != old.device:
                return False

            # Same storage implies the same aliasing relationship as before.
            if get_storage(new) == get_storage(old):
                return True

            # This is the case where it returns a completely fresh storage that's used nowhere else.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
            ):
                return True
            return False

        def should_process_node(node):
            # node.target for nodes returning true from this function
            # are called under fake mode and does not work for inductor
            # lowerings. We check if the node.target is an aten operator
            # or operator.getitem which is used when returning multiple
            # tensors from an op.
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target == operator.getitem
            )

        to_process = set()
        for node in self.graph.nodes:
            # Skip unchanged nodes unless a predecessor's meta was refreshed
            # (users of changed nodes are queued in to_process by id).
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            # Re-run the op under fake mode to get fresh metadata.
            with V.fake_mode:
                new_fake_tensor = node.target(*args, **kwargs)
            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"]
            ):
                continue

            rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor)

            node.meta["val"] = new_fake_tensor
            if (shape_env := V.fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1

            # This node's fake tensor changed, so every user must be revisited.
            to_process.update([id(user) for user in node.users])

            self.processed_hashes.add(self.hash_node(node))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_storage(t: torch.Tensor) -> int:
    """Return an integer identifying the untyped storage backing ``t``."""
    storage = t.untyped_storage()
    return storage._cdata
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    """Return the storage id of a node's fake value, or None if the node has
    no ``val`` meta, the value is not a tensor, or the tensor has no storage."""
    val = node.meta.get("val")
    if not isinstance(val, torch.Tensor):
        return None
    if not torch._C._has_storage(val):
        return None
    return val.untyped_storage()._cdata
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def get_fake(x):
    """Map an fx Node to its fake value (``meta['val']``); anything else, or a
    Node without ``val`` meta, passes through unchanged."""
    if not isinstance(x, torch.fx.Node):
        return x
    return x.meta.get("val", x)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
|
| 211 |
+
"""
|
| 212 |
+
First value returns a boolean if any of the input nodes don't have a faketensor.
|
| 213 |
+
"""
|
| 214 |
+
args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
|
| 215 |
+
if any(
|
| 216 |
+
isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
|
| 217 |
+
):
|
| 218 |
+
return False, args, kwargs
|
| 219 |
+
return True, args, kwargs
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns true if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives. e.g. it doesn't
    handle buffers realized heuristically during lowering, or
    buffers realized indirectly through view ops.
    """
    from torch._inductor.lowering import fallbacks, needs_realized_inputs

    def _is_buffer(n: torch.fx.Node) -> bool:
        if n.op == "call_function" and n.target is operator.getitem:
            # For nodes with multiple outputs, we get the fx graph:
            #   foo = torch.ops.aten.foo(...)
            #   getitem = foo[0]
            #   getitem_1 = foo[1]
            # where we need to check if foo is a fallback kernel
            return _is_buffer(n.args[0])  # type: ignore[arg-type]
        return n.op in ("placeholder", "output") or n.target in fallbacks

    if _is_buffer(node):
        return True

    def _realizes_inputs(n: torch.fx.Node) -> bool:
        return n.op == "output" or n.target in needs_realized_inputs

    # Realized if any consumer forces its inputs to be realized.
    if any(_realizes_inputs(user) for user in node.users):
        return True

    # Otherwise, assume node isn't realized
    return False
|
.venv/lib/python3.11/site-packages/torch/_inductor/graph.py
ADDED
|
@@ -0,0 +1,1930 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import itertools
|
| 3 |
+
import logging
|
| 4 |
+
import operator
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from contextlib import contextmanager
|
| 11 |
+
from types import ModuleType
|
| 12 |
+
from typing import (
|
| 13 |
+
Any,
|
| 14 |
+
Callable,
|
| 15 |
+
DefaultDict,
|
| 16 |
+
Dict,
|
| 17 |
+
Iterable,
|
| 18 |
+
List,
|
| 19 |
+
NoReturn,
|
| 20 |
+
Optional,
|
| 21 |
+
Sequence,
|
| 22 |
+
Tuple,
|
| 23 |
+
TYPE_CHECKING,
|
| 24 |
+
Union,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
import sympy
|
| 28 |
+
from sympy import Expr
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch._logging
|
| 32 |
+
import torch.fx
|
| 33 |
+
from torch import device, Tensor
|
| 34 |
+
from torch._decomp import get_decompositions
|
| 35 |
+
from torch._dynamo.utils import defake, dynamo_timed
|
| 36 |
+
from torch._logging import LazyString, trace_structured
|
| 37 |
+
from torch._prims_common import make_channels_last_strides_for
|
| 38 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 39 |
+
from torch.fx import GraphModule
|
| 40 |
+
from torch.fx.experimental._backward_state import BackwardState
|
| 41 |
+
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
|
| 42 |
+
from torch.fx.experimental.symbolic_shapes import (
|
| 43 |
+
free_unbacked_symbols,
|
| 44 |
+
has_free_symbols,
|
| 45 |
+
resolve_unbacked_bindings,
|
| 46 |
+
RuntimeAssert,
|
| 47 |
+
ShapeEnv,
|
| 48 |
+
SymTypes,
|
| 49 |
+
)
|
| 50 |
+
from torch.fx.graph import Graph
|
| 51 |
+
from torch.fx.node import Node
|
| 52 |
+
from torch.utils._mode_utils import no_dispatch
|
| 53 |
+
from torch.utils._ordered_set import OrderedSet
|
| 54 |
+
from torch.utils._sympy.numbers import int_oo
|
| 55 |
+
|
| 56 |
+
from . import config, ir
|
| 57 |
+
from .codegen.common import (
|
| 58 |
+
BackendFeature,
|
| 59 |
+
DeviceOpOverrides,
|
| 60 |
+
get_backend_features,
|
| 61 |
+
get_device_op_overrides,
|
| 62 |
+
get_wrapper_codegen_for_device,
|
| 63 |
+
init_backend_registration,
|
| 64 |
+
)
|
| 65 |
+
from .exc import (
|
| 66 |
+
CppWrapperCodeGenError,
|
| 67 |
+
LoweringException,
|
| 68 |
+
MissingOperatorWithDecomp,
|
| 69 |
+
MissingOperatorWithoutDecomp,
|
| 70 |
+
)
|
| 71 |
+
from .ir import (
|
| 72 |
+
Constant,
|
| 73 |
+
FixedLayout,
|
| 74 |
+
get_device_type,
|
| 75 |
+
InputBuffer,
|
| 76 |
+
Pointwise,
|
| 77 |
+
Reduction,
|
| 78 |
+
StorageBox,
|
| 79 |
+
TensorBox,
|
| 80 |
+
TorchBindObject,
|
| 81 |
+
)
|
| 82 |
+
from .lowering import (
|
| 83 |
+
FALLBACK_ALLOW_LIST,
|
| 84 |
+
fallback_handler,
|
| 85 |
+
fallback_node_due_to_unsupported_type,
|
| 86 |
+
lowerings,
|
| 87 |
+
make_fallback,
|
| 88 |
+
maybe_layout_constraints,
|
| 89 |
+
needs_realized_inputs,
|
| 90 |
+
unsupported_output_tensor,
|
| 91 |
+
)
|
| 92 |
+
from .scheduler import BaseSchedulerNode
|
| 93 |
+
from .sizevars import SizeVarAllocator
|
| 94 |
+
from .utils import (
|
| 95 |
+
convert_shape_to_inductor,
|
| 96 |
+
gather_origins,
|
| 97 |
+
get_cloned_parameter_buffer_name,
|
| 98 |
+
get_sympy_Expr_dtype,
|
| 99 |
+
maybe_get_suppress_shape_guards_ctx,
|
| 100 |
+
should_assume_input_aligned,
|
| 101 |
+
)
|
| 102 |
+
from .virtualized import NullHandler, V
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
if TYPE_CHECKING:
|
| 106 |
+
from torch._higher_order_ops.effects import _EffectType
|
| 107 |
+
from .codegen.wrapper import WrapperCodeGen
|
| 108 |
+
|
| 109 |
+
from torch._inductor.codecache import output_code_log
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Module-level loggers and counters shared by the lowering pass.
log = logging.getLogger(__name__)
# Artifact logger: emits performance hints only when the "perf_hints"
# logging artifact is enabled.
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")

aten = torch.ops.aten

# Monotonic id assigned to each GraphLowering created after post-grad passes.
_post_grad_graph_counter = itertools.count()

if config.is_fbcode():
    from torch._inductor.fb.utils import log_module_code
else:

    def log_module_code(*args: Any, **kwargs: Any) -> None:
        # No-op stand-in outside fbcode so call sites don't need to branch.
        pass
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool:
|
| 128 |
+
supported_dtype = {
|
| 129 |
+
torch.float32,
|
| 130 |
+
torch.float64,
|
| 131 |
+
torch.int64,
|
| 132 |
+
torch.int32,
|
| 133 |
+
torch.int16,
|
| 134 |
+
torch.int8,
|
| 135 |
+
torch.uint8,
|
| 136 |
+
torch.bool,
|
| 137 |
+
torch.bfloat16,
|
| 138 |
+
torch.complex32,
|
| 139 |
+
torch.complex64,
|
| 140 |
+
torch.complex128,
|
| 141 |
+
torch.float16,
|
| 142 |
+
}
|
| 143 |
+
if cuda:
|
| 144 |
+
supported_dtype.add(torch.float8_e4m3fn)
|
| 145 |
+
supported_dtype.add(torch.float8_e5m2)
|
| 146 |
+
supported_dtype.add(torch.float8_e4m3fnuz)
|
| 147 |
+
supported_dtype.add(torch.float8_e5m2fnuz)
|
| 148 |
+
|
| 149 |
+
return dtype in supported_dtype
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
    """Map a sympy scalar expression to the torch dtype its constant buffer
    should use; returns None when no dtype can be determined."""
    assert isinstance(
        constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
    ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
    # Concrete integers always become int64 buffers.
    if isinstance(constant_buffer, sympy.core.numbers.Integer):
        return torch.int64

    if isinstance(constant_buffer, sympy.Expr):
        return get_sympy_Expr_dtype(constant_buffer)

    # NOTE(review): sympy.Symbol and sympy.Integer are both subclasses of
    # sympy.Expr, so the branches below appear unreachable — confirm before
    # relying on them.
    if constant_buffer.is_integer:
        return torch.int64
    elif constant_buffer.is_float:
        return torch.float32
    else:
        return None
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def is_magic_method(op: Any) -> bool:
|
| 171 |
+
magic_ops = {method_to_operator(m) for m in magic_methods}
|
| 172 |
+
return op in magic_ops
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def getattr_recursive(
    obj: GraphModule, target: str
) -> Union[Tensor, torch._C.ScriptObject, GraphModule]:
    """Walk a dotted attribute path (e.g. ``"sub.mod.weight"``) from ``obj``.

    Raises:
        RuntimeError: naming the longest prefix that did resolve, when an
            attribute along the path does not exist.
    """
    target_atoms = target.split(".")
    node = obj
    for i, atom in enumerate(target_atoms):
        if hasattr(node, atom):
            node = getattr(node, atom)
        else:
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
    return node
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def mark_nodes_dislike_padding(
    g: Graph, user_visible_outputs: Optional[Dict[str, None]]
) -> None:
    """
    Nodes like convolution/convolution_backward want its input to be dense.
    If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction.

    The pass finds nodes that dislike padding. These are nodes that can be reached
    from a convolution/convolution_backward in the backward direction without
    going thru a reduction.

    Marks each such node with ``node.meta["dislike_padding"] = True``.
    """
    if not config.comprehensive_padding:
        return
    ops_dislike_padding = {
        aten.convolution,
        aten.convolution_backward,
    }
    # what's a better way to collect the reduction ops?
    ops_like_padding = {
        aten.var_mean,
        aten.sum,
        aten.mean,
        aten.prod,
        aten.any,
        aten.amin,
        aten.amax,
        aten.min,
        aten.max,
        aten.argmin,
        aten.argmax,
        aten.scatter_reduce,
    }

    def _get_overload_packet(
        node: torch.fx.Node,
    ) -> Optional[torch._ops.OpOverloadPacket]:
        # Returns the op's overload packet for call_function nodes, else None.
        return (
            node.target._overloadpacket
            if node.op == "call_function"
            # hasattr on OpOverloadPacket is slow, do isinstance first
            and isinstance(node.target, torch._ops.OpOverload)
            and hasattr(node.target, "_overloadpacket")
            else None
        )

    # Reverse traversal: a node is processed after all of its users, so the
    # "dislike_padding" flag propagates from consumers back to producers.
    for cur in reversed(g.nodes):
        op = _get_overload_packet(cur)
        if not op:
            continue
        if op in ops_dislike_padding:
            cur.meta["dislike_padding"] = True

        if cur.meta.get("dislike_padding", False):
            # propagate
            for prior in cur.all_input_nodes:
                prior_op = _get_overload_packet(prior)
                if not prior_op:
                    continue
                # Reductions act as a barrier: propagation stops there.
                if prior_op not in ops_like_padding:
                    prior.meta["dislike_padding"] = True
        # We only want to mark output nodes. So, move it after the above prior nodes process.
        if (
            not config.pad_outputs
            and user_visible_outputs
            and cur.name in user_visible_outputs
        ):
            cur.meta["dislike_padding"] = True
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class GraphLowering(torch.fx.Interpreter):
|
| 259 |
+
graph_outputs: List[ir.IRNode]
|
| 260 |
+
|
| 261 |
+
def symbolic_sizes_strides(
    self, ex: torch.Tensor
) -> Tuple[Union[List[int], List[Expr]], Union[List[int], List[Expr]]]:
    """
    Support dynamic shapes and dynamic strides by assigning variables
    to each dimension.  We duck-shape tensors, so if two tensors
    have the same size they get assigned the same symbolic variable.

    Returns a (sizes, strides) pair whose entries are plain ints or sympy
    expressions.
    """
    if self.reuse_shape_env:
        # The caller's ShapeEnv already owns symbols for this tensor;
        # just translate its (possibly symbolic) sizes/strides.
        return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
            ex.stride()
        )
    else:
        from torch._dynamo.source import ConstantSource

        # TODO: this should not be needed once #93059 lands
        # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
        # TODO: make a dedicated UnknownSource for this?
        # NB: This is using the legacy default behavior from
        # create_symbolic_sizes_strides_storage_offset but we hope we can
        # just delete this entirely
        source = ConstantSource(
            f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
        )
        (
            size,
            stride,
            _,
        ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(
            ex,
            source,
        )

        # Unwrap SymInt into the underlying sympy expression; plain ints
        # pass through unchanged.
        size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
        stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
        return size, stride
|
| 297 |
+
|
| 298 |
+
def static_sizes_strides(
    self, ex: torch.Tensor
) -> Tuple[List[sympy.Expr], List[sympy.Expr]]:
    """Return ``ex``'s concrete sizes and strides as ``sympy.Integer``s.

    Primarily used for weights, whose shapes are static.
    """
    static_size = [sympy.Integer(dim) for dim in ex.size()]
    static_stride = [sympy.Integer(st) for st in ex.stride()]
    return static_size, static_stride
|
| 307 |
+
|
| 308 |
+
def __init__(
    self,
    gm: torch.fx.GraphModule,
    example_inputs: Optional[List[torch.Tensor]] = None,
    shape_env: Optional[ShapeEnv] = None,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[
        Callable[[List[ir.ExternKernelNode]], Any]
    ] = None,
    is_inference: bool = False,
    is_const_graph: bool = False,
    const_output_index: Optional[Dict[str, int]] = None,
    const_code: Optional[str] = None,
    const_module: Optional["GraphLowering"] = None,
    name: Optional[str] = None,
) -> None:
    """Set up all lowering state for ``gm``.

    When ``const_module`` is given, device/constant bookkeeping is shared
    with that (already lowered) constant graph instead of starting fresh.
    When ``shape_env`` is given it is reused (``reuse_shape_env=True``);
    otherwise a private ShapeEnv is created.
    """
    super().__init__(gm)
    self.example_inputs = example_inputs
    # Layout optimization may be forced by the caller; otherwise decided
    # heuristically from the graph.
    self.layout_opt = (
        layout_opt
        if layout_opt is not None
        else self.decide_layout_opt(gm, is_inference=is_inference)
    )
    self.num_channels_last_conv = 0
    self.is_inference = is_inference
    self.is_const_graph = is_const_graph
    self.const_code = const_code
    self.const_module = const_module

    self.extra_traceback = False  # we do our own error wrapping
    if shape_env is None:
        shape_env = ShapeEnv()
        self.reuse_shape_env = False
    else:
        # NOTE(review): this assignment is repeated unconditionally just
        # below — the duplicate looks redundant; confirm before removing.
        self._shape_env = shape_env
        self.reuse_shape_env = True
    self._shape_env = shape_env
    # We are going to start code generating runtime asserts, so make sure
    # you don't start adding new ones in the lowering process
    shape_env.freeze_runtime_asserts()
    # We're going to mutate ras_by_symbol as we finish generating them
    self.ras_by_symbol: Dict[
        sympy.Symbol, List[RuntimeAssert]
    ] = shape_env.deferred_runtime_asserts.copy()
    self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
    self.sizevars = SizeVarAllocator(shape_env)
    self.graph_input_names: List[str] = []
    self.graph_inputs: Dict[str, TensorBox] = {}
    self.graph_inputs_original: Dict[str, InputBuffer] = {}
    self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet()
    # Device bookkeeping is shared with the const module when present.
    self.device_types: OrderedSet[str] = (
        const_module.device_types if const_module else OrderedSet()
    )
    self.device_idxs: OrderedSet[int] = (
        const_module.device_idxs if const_module else OrderedSet()
    )
    self.cuda = False
    self.buffers: List[ir.Buffer] = []
    self.operations: List[ir.Operation] = []
    self.const_output_index: Dict[str, int] = (
        const_output_index if const_output_index else {}
    )
    self.folded_constants: OrderedSet[str] = (
        OrderedSet(const_output_index.keys())
        if const_output_index
        else OrderedSet()
    )
    self.constants: Dict[str, torch.Tensor] = (
        const_module.constants if const_module else {}
    )
    self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
    self.constant_reprs: Dict[str, str] = {}
    self.removed_operations: OrderedSet[str] = OrderedSet()
    self.removed_buffers: OrderedSet[str] = OrderedSet()
    self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
    self.mutated_buffers: OrderedSet[str] = OrderedSet()
    self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
    self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
    self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
    self.wrapper_code: WrapperCodeGen = None  # type: ignore[assignment]
    # See `ProxyExecutor Design Note` in ir.py for more details
    self.extern_kernel_nodes: List[ir.ExternKernelNode] = []

    from torch._inductor.extern_node_serializer import extern_node_json_serializer

    # Custom serializer only honored in fbcode; JSON serializer otherwise.
    self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = (
        extern_node_serializer
        if config.is_fbcode() and extern_node_serializer
        else extern_node_json_serializer
    )

    self.current_node: torch.fx.Node = None  # type: ignore[assignment]
    self.lists: Dict[str, List[str]] = {}
    self.mutated_inputs: OrderedSet[str] = OrderedSet()
    self.mutated_input_idxs: List[int] = []
    self.name_to_buffer: Dict[str, ir.Buffer] = {}
    self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
    self.name_to_op: Dict[str, ir.Operation] = {}
    self.creation_time = time.time()
    self.name = name  # type: ignore[assignment]
    self.cpp_wrapper = cpp_wrapper

    # record multi_kernel choice for cpp_wrapper so the second pass knows
    # which sub-kernel is picked. Copy cpp_wrapper to another variable
    # since cpp_wrapper flag is OrderedSet to false for the first pass of codegen.
    self.record_multi_kernel_choice = cpp_wrapper
    self.multi_kernel_to_choice: Dict[str, int] = {}

    self.aot_mode = aot_mode
    self.graph_id = graph_id
    self.post_grad_graph_id = next(_post_grad_graph_counter)
    self.scheduler: torch._inductor.scheduler.Scheduler = None  # type: ignore[assignment]
    self.nodes_prefer_channels_last = (
        self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet()
    )
    self._warned_fallback = {"aten.convolution_backward"}
    self.user_visible_outputs = (
        user_visible_outputs if user_visible_outputs is not None else {}
    )
    mark_nodes_dislike_padding(gm.graph, user_visible_outputs)
    self.cache_key: str = ""  # This is the cache key for the compiled artifact
    self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
    self.cache_linemap: List[
        Tuple[int, str]
    ] = (
        []
    )  # This is the linemap used by the profiler to mark custom compiled kernels getting run
    # Used if lowering encounters cases where cudagraphs are not supported
    self.disable_cudagraphs_reason: Optional[str] = None

    # only keeping one node per device for stack trace purposes
    self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
    self.orig_gm: torch.fx.GraphModule = gm.__copy__()
    self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
        "dynamo_flat_name_to_original_fqn", {}
    )
    self.allocated_constant_name: Dict[str, str] = (
        const_module.allocated_constant_name if const_module is not None else {}
    )
    init_backend_registration()
    # Memoize backend-feature lookups per device type.
    self.get_backend_features = functools.lru_cache(None)(get_backend_features)

    self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}
    self.aligned_inputs: OrderedSet[str] = OrderedSet()
    self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet()

    # Below field is related to printing debug intermediate tensor values info for debugging
    self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet()
|
| 460 |
+
|
| 461 |
+
def has_feature(
    self, device: Union[torch._inductor.ir.IRNode, device], feature: BackendFeature
) -> bool:
    """Check whether the backend for ``device`` supports ``feature``."""
    assert isinstance(feature, BackendFeature), feature
    backend_features = self.get_backend_features(get_device_type(device))
    return feature in backend_features
|
| 466 |
+
|
| 467 |
+
@staticmethod
def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool:
    """
    Decide if we should enable layout optimization for this graph based on
    heuristics.

    Returns True to prefer channels-last layouts for convolutions.
    """
    if not config.layout_optimization:
        return False

    if config.force_layout_optimization:
        return True

    conv_nodes = [
        n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
    ]
    nconv = len(conv_nodes)

    if nconv == 0:
        return False

    # For cpu backend and mkldnn enabled, we always use channels_last for better performance.
    if (
        torch.backends.mkldnn.enabled
        and torch.backends.mkldnn.is_available()
        and all(
            n.args[idx].meta["val"].device == torch.device("cpu")
            for n in conv_nodes
            for idx in [0, 1]
        )
    ):
        return True

    # Following models are skipped due to this:
    # jx_nest_base
    # volo_d1_224
    if len(list(gm.graph.nodes)) >= 300 * nconv:
        log.debug("Skipped layout opt because only a few conv")
        return False

    if any(
        has_free_symbols(n.args[idx].meta["val"])
        for n in conv_nodes
        for idx in [0, 1]
    ):
        log.debug(
            "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
        )
        return False

    def is_grouped(n: Any) -> bool:
        # groups (last conv arg) > 1 with in-channels-per-group > 1.
        meta_val = n.args[1].meta["val"]  # type: ignore[union-attr, operator]
        assert isinstance(meta_val, torch.Tensor)
        return n.args[-1] > 1 and meta_val.size(1) > 1  # type: ignore[union-attr, operator]

    def is_in_out_channel(n: torch.fx.Node) -> bool:
        # out-channels at most half of in-channels, with a spatial kernel.
        return (
            n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)  # type: ignore[union-attr, operator]
            and n.args[1].meta["val"].size(2) > 1  # type: ignore[union-attr, operator]
        )

    def is_small_channel(n: torch.fx.Node) -> bool:
        # both in- and out-channels <= 64.
        return (
            n.args[1].meta["val"].size(0) <= 64  # type: ignore[union-attr, operator]
            and n.args[1].meta["val"].size(1) <= 64  # type: ignore[union-attr, operator]
        )

    # only grouped convolutions benchmarked as slower in conv samples for inference only
    if is_inference:
        from torch.utils.flop_counter import FlopCounterMode

        flop_counts: Dict[str, float] = defaultdict(float)
        for node in conv_nodes:
            success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
                node
            )

            if success:
                with FlopCounterMode(display=False) as flop_counter_mode:
                    with V.fake_mode:
                        node.target(*args, **kwargs)

                counted_flops = flop_counter_mode.get_total_flops()
                if is_grouped(node):
                    node_type = "grouped"
                elif is_small_channel(node):
                    node_type = "small"
                elif is_in_out_channel(node):
                    node_type = "in_out"
                else:
                    node_type = "default"

                flop_counts[node_type] += counted_flops
            else:
                log.debug("Conv inputs meta not found")

        # average benchmarked channels last speedup / slowdown, < 1 is speedup.
        # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
        # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
        GROUPED_MULTIPLIER = 1.358
        DEFAULT_MULTIPLIER = 0.823
        IN_OUT_MULTIPLIER = 0.725
        SMALL_MULTIPLIER = 0.783

        total_flops = sum(flop_counts.values())
        # TODO - get different values per hardware
        weighted_flops = (
            flop_counts["grouped"] * GROUPED_MULTIPLIER
            + flop_counts["small"] * SMALL_MULTIPLIER
            + flop_counts["in_out"] * IN_OUT_MULTIPLIER
            + flop_counts["default"] * DEFAULT_MULTIPLIER
        )
        # Weighted estimate below the unweighted total means channels-last
        # is predicted to be a net win.
        do_layout_opt = weighted_flops <= total_flops
        if not do_layout_opt:
            log.debug(
                "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
                total_flops,
                weighted_flops,
            )
        return do_layout_opt

    # Channels last layout can dramatically hurt grouped conv perf. E.g.
    # Conv with arguments like
    # {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
    # "stride": [2, 2], "padding": [1, 1], "groups": 2}
    # slows down 31x using channels last..

    # But a lot of timm models use depthwise separable convolution which will
    # result in grouped convolution with in-channel size == 1.
    # For those grouped convolution, channels last still helps a lot.
    # E.g.
    # Conv with arguments
    # {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
    # "stride": [2, 2], "padding": [1, 1], "groups": 58}
    # get 1.86x speedup with channels last layout.
    #
    # The following heuristics skip using channels-last if the model contains
    # grouped convolution with in-channels > 1.
    if any(map(is_grouped, conv_nodes)):
        log.debug(
            "Skip layout opt because found grouped convolution with >1 in_channels!"
        )
        return False

    # For some models that contain convolution with larger in-channel than out-channel, applying
    # channels last hurts performance.
    # Following models are skipped due to this:
    # - pytorch_unet
    # - phlippe_densenet (slightly worse)
    # - Background_Matting (1.22x -> 0.821x)
    # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
    if any(map(is_in_out_channel, conv_nodes)):
        log.debug(
            "Skip layout opt because some convolutions have smaller out_channel"
        )
        return False

    # Following models are skipped due to this:
    # - functorch_maml_omniglot
    if all(map(is_small_channel, conv_nodes)):
        log.debug("Skip layout opt because all convolution channels are too small")
        return False

    return True
|
| 630 |
+
|
| 631 |
+
def qualify_name(self, name: str) -> str:
    """Prepend the given name with the graph name if any."""
    return name if self.name is None else f"{self.name}_{name}"
|
| 636 |
+
|
| 637 |
+
def make_subgraph(
    self,
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    subgraph_name: str,
) -> "GraphLowering":
    """
    Make a subgraph of the current graph with all inherited
    parts, except the graph module (`gm`) and `example_inputs`.
    The subgraphs are lowered separately, but intended to be
    inlined in the parent graph's codegening. Hence the need
    for maintaining the same `shape_env` and other properties.
    The subgraph name is qualified by the parent graph's name.
    """
    # Note: only the listed settings are inherited; everything else falls
    # back to GraphLowering.__init__ defaults.
    return GraphLowering(
        gm=gm,
        example_inputs=example_inputs,
        shape_env=self._shape_env,
        cpp_wrapper=self.cpp_wrapper,
        aot_mode=self.aot_mode,
        extern_node_serializer=self.extern_node_serializer,
        is_inference=self.is_inference,
        name=self.qualify_name(subgraph_name),
    )
|
| 661 |
+
|
| 662 |
+
def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]:
    """
    The rule to decide if an node prefer channels last is simple.
    1. if it's input/output of a convolution
    2. if one of its user prefers channels last

    We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs;
    Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers
    channels last.

    Consider the scenario: conv -> batch-norm -> relu -> conv
    Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies:
    1. the output of batch-norm should be channels last initially since its input is a conv's output.
       Forcing the batch-norm's output to be contiguous results in the first copy
    2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output.
       We need convert it to channels last layout which results in the second copy.
    With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies
    can be saved.
    """
    output_set: OrderedSet[Node] = OrderedSet()
    # Reverse traversal ensures every user is classified before the node
    # itself, so rule 2 resolves in one pass.
    for n in reversed(self.module.graph.nodes):
        if n.target == torch.ops.aten.convolution.default:
            output_set.add(n)
            continue

        for user in n.users:
            if user in output_set:
                output_set.add(n)
                break

    # need a second pass to add downstream nodes of those channel last nodes to the sets.
    # This pass is especially needed to avoid mix-layout kernel inputs in backward pass.
    #
    # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned
    # from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
    # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last
    # tensors and passed to a kernel.
    #
    # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x.
    # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x .
    # This also helps the following models:
    # - res2net101_26w_4s
    # - res2net50_14w_8s
    # - sebotnet33ts_256
    for n in self.module.graph.nodes:
        if n in output_set:
            output_set.update(n.users)

    return output_set
|
| 711 |
+
|
| 712 |
+
def warn_fallback(self, name: str) -> None:
    """Emit a one-time perf hint that ``name`` is lowered via FallbackKernel."""
    if name in self._warned_fallback:
        return
    self._warned_fallback.add(name)
    perf_hint_log.info("Using FallbackKernel: %s", name)
|
| 716 |
+
|
| 717 |
+
def add_device_info(self, device: torch.device) -> None:
    """Record that ``device`` is used by this graph.

    Tracks the device type and (when present) index, and remembers the
    first FX node seen per device for stack-trace purposes.
    """
    self.device_types.add(device.type)
    if device.index is not None:
        self.device_idxs.add(device.index)
    # Only the first node observed for each device is kept.
    if V.graph.current_node and device not in self.device_node_mapping:
        self.device_node_mapping[device] = V.graph.current_node
|
| 723 |
+
|
| 724 |
+
@property
def fake_mode(self) -> torch._subclasses.fake_tensor.FakeTensorMode:
    # The active FakeTensorMode is owned by the virtualized context (V),
    # not stored on the graph itself.
    return V.fake_mode
|
| 727 |
+
|
| 728 |
+
def try_get_buffer(
    self, buffer_name: str
) -> Optional[Union[ir.TensorBox, ir.Buffer]]:
    """Resolve ``buffer_name`` against registered buffers, then graph
    inputs, then graph constants; return None when not found."""
    if buffer_name in self.name_to_buffer:
        return self.name_to_buffer[buffer_name]
    if buffer_name in self.graph_inputs:
        return self.graph_inputs[buffer_name]
    if buffer_name in self.constants:
        data = V.graph.constants[buffer_name]
        # Constants are wrapped in a fresh ConstantBuffer whose layout is
        # fixed from the concrete tensor's sizes/strides.
        return ir.ConstantBuffer(
            buffer_name,
            ir.FixedLayout(
                data.device, data.dtype, *V.graph.static_sizes_strides(data)
            ),
        )

    return None
|
| 745 |
+
|
| 746 |
+
def get_buffer(self, buffer_name: str) -> Union[ir.TensorBox, ir.Buffer]:
    """Look up a buffer by name; raise RuntimeError when it does not exist."""
    found = self.try_get_buffer(buffer_name)
    if found is None:
        raise RuntimeError(f"Failed to find buffer matching name {buffer_name}")
    return found
|
| 751 |
+
|
| 752 |
+
def get_dtype(self, buffer_name: str) -> torch.dtype:
|
| 753 |
+
if buffer_name in self.constants:
|
| 754 |
+
return self.constants[buffer_name].dtype
|
| 755 |
+
if buffer_name in self.name_to_buffer:
|
| 756 |
+
return self.name_to_buffer[buffer_name].get_dtype()
|
| 757 |
+
if buffer_name in self.graph_inputs:
|
| 758 |
+
return self.graph_inputs[buffer_name].get_dtype()
|
| 759 |
+
m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
|
| 760 |
+
if m:
|
| 761 |
+
return self.get_dtype(m.group(1))
|
| 762 |
+
raise KeyError(f"could not find {buffer_name}")
|
| 763 |
+
|
| 764 |
+
def get_numel(self, buffer_name: str) -> Union[int, Expr]:
|
| 765 |
+
from .ir import MultiOutputLayout
|
| 766 |
+
|
| 767 |
+
if buffer_name in self.constants:
|
| 768 |
+
return self.constants[buffer_name].numel()
|
| 769 |
+
if buffer_name in self.name_to_buffer:
|
| 770 |
+
buf = self.name_to_buffer[buffer_name]
|
| 771 |
+
if isinstance(getattr(buf, "layout", None), MultiOutputLayout):
|
| 772 |
+
return 1
|
| 773 |
+
return buf.get_numel()
|
| 774 |
+
if buffer_name in self.graph_inputs:
|
| 775 |
+
return self.graph_inputs[buffer_name].get_numel()
|
| 776 |
+
raise KeyError(f"could not find {buffer_name}")
|
| 777 |
+
|
| 778 |
+
def run(self, *args: Any) -> Any: # type: ignore[override]
|
| 779 |
+
with dynamo_timed("GraphLowering.run"):
|
| 780 |
+
return super().run(*args)
|
| 781 |
+
|
| 782 |
+
def register_operation(self, op: ir.Operation) -> str:
|
| 783 |
+
assert op.operation_name is None, f"Operation registered twice: {op}"
|
| 784 |
+
assert isinstance(op, ir.Operation)
|
| 785 |
+
name = self.qualify_name(f"op{len(self.operations)}")
|
| 786 |
+
self.operations.append(op)
|
| 787 |
+
self.name_to_op[name] = op
|
| 788 |
+
op.operation_name = name
|
| 789 |
+
return name
|
| 790 |
+
|
| 791 |
+
def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
|
| 792 |
+
name = self.qualify_name(f"buf{len(self.buffers)}")
|
| 793 |
+
self.buffers.append(buffer)
|
| 794 |
+
self.name_to_buffer[name] = buffer
|
| 795 |
+
if (
|
| 796 |
+
# Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
|
| 797 |
+
not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements())
|
| 798 |
+
and buffer.get_device() is not None
|
| 799 |
+
):
|
| 800 |
+
self.add_device_info(buffer.get_device())
|
| 801 |
+
|
| 802 |
+
if set_name:
|
| 803 |
+
buffer.name = name
|
| 804 |
+
return name
|
| 805 |
+
|
| 806 |
+
def register_operation_list(self, operation_names: List[str]) -> str:
|
| 807 |
+
name = self.qualify_name("list_" + "_".join(operation_names))
|
| 808 |
+
self.lists[name] = operation_names
|
| 809 |
+
return name
|
| 810 |
+
|
| 811 |
+
def register_users_of(
|
| 812 |
+
self, node_output: Union[Iterable[ir.IRNode], ir.IRNode]
|
| 813 |
+
) -> None:
|
| 814 |
+
def register(value: Union[Iterable[ir.IRNode], ir.IRNode]) -> None:
|
| 815 |
+
if isinstance(value, (list, tuple)):
|
| 816 |
+
for x in value:
|
| 817 |
+
register(x)
|
| 818 |
+
if isinstance(value, ir.TensorBox):
|
| 819 |
+
for read_name in value.get_read_names():
|
| 820 |
+
self.name_to_users[read_name].append(value)
|
| 821 |
+
|
| 822 |
+
register(node_output)
|
| 823 |
+
|
| 824 |
+
def mark_buffer_mutated(self, name: str) -> None:
|
| 825 |
+
"""
|
| 826 |
+
When a buffer is mutated we need to make sure all the reads to
|
| 827 |
+
the old version are realized before the mutation happens.
|
| 828 |
+
"""
|
| 829 |
+
assert isinstance(name, str)
|
| 830 |
+
self.mutated_buffers.add(name)
|
| 831 |
+
|
| 832 |
+
if name not in self.name_to_users:
|
| 833 |
+
return
|
| 834 |
+
|
| 835 |
+
for user in self.name_to_users[name]:
|
| 836 |
+
user.realize()
|
| 837 |
+
|
| 838 |
+
def get_original_value_of_constant(self, name: str) -> torch.Tensor:
|
| 839 |
+
"""
|
| 840 |
+
In AOTI, module buffers may have been mutated during the tracing and compilation.
|
| 841 |
+
Thus we need to read from previously stored original buffers, to make sure the
|
| 842 |
+
generated model.so uses correct initial values.
|
| 843 |
+
"""
|
| 844 |
+
assert name in self.allocated_constant_name and name in self.constants, (
|
| 845 |
+
"Can not find the original value for " + name
|
| 846 |
+
)
|
| 847 |
+
orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name])
|
| 848 |
+
return (
|
| 849 |
+
self.module.meta[orig_name]
|
| 850 |
+
if orig_name in self.module.meta
|
| 851 |
+
else self.constants[name]
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
def allocate_non_dup_const_name(
|
| 855 |
+
self, name: Optional[str], data: Union[Tensor]
|
| 856 |
+
) -> str:
|
| 857 |
+
orig_name = name
|
| 858 |
+
if not config.aot_inductor.use_runtime_constant_folding:
|
| 859 |
+
for constant_name, value in self.constants.items():
|
| 860 |
+
if (
|
| 861 |
+
not data.is_mkldnn
|
| 862 |
+
and data.size() == value.size()
|
| 863 |
+
and data.stride() == value.stride()
|
| 864 |
+
and data.dtype == value.dtype
|
| 865 |
+
and data.device == value.device
|
| 866 |
+
and data.untyped_storage().data_ptr()
|
| 867 |
+
== value.untyped_storage().data_ptr()
|
| 868 |
+
and data.storage_offset() == value.storage_offset()
|
| 869 |
+
):
|
| 870 |
+
return constant_name
|
| 871 |
+
|
| 872 |
+
if name is None:
|
| 873 |
+
name = f"constant{len(self.constants)}"
|
| 874 |
+
assert name is not None
|
| 875 |
+
if name[0].isdigit():
|
| 876 |
+
name = f"constant_{name}"
|
| 877 |
+
name = self.qualify_name(name)
|
| 878 |
+
# We may generate a var name for each constant in the codegen.
|
| 879 |
+
# Let's only keep sane characters.
|
| 880 |
+
prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
| 881 |
+
name = prefix
|
| 882 |
+
cnt = 0
|
| 883 |
+
while name in self.constants:
|
| 884 |
+
name = f"{prefix}_{cnt}"
|
| 885 |
+
cnt += 1
|
| 886 |
+
self.constants[name] = data
|
| 887 |
+
self.constant_reprs[name] = (
|
| 888 |
+
f"{data.device!r} {data.dtype!r} "
|
| 889 |
+
f"{tuple(data.size())!r} {tuple(data.stride())!r} "
|
| 890 |
+
f"{hash(data):x}"
|
| 891 |
+
)
|
| 892 |
+
self.allocated_constant_name[name] = orig_name # type: ignore[assignment]
|
| 893 |
+
return name
|
| 894 |
+
|
| 895 |
+
def add_tensor_constant(
|
| 896 |
+
self, data: Tensor, name: Optional[str] = None
|
| 897 |
+
) -> TensorBox:
|
| 898 |
+
new_name = self.allocate_non_dup_const_name(name, data)
|
| 899 |
+
return TensorBox.create(
|
| 900 |
+
ir.ConstantBuffer(
|
| 901 |
+
new_name,
|
| 902 |
+
FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
|
| 903 |
+
)
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
def constant_name(self, name: str, device_override: Optional[torch.device]) -> str:
|
| 907 |
+
"""
|
| 908 |
+
We AOT copy constants to the devices they are needed on.
|
| 909 |
+
If device_override doesn't match the constant's device, then
|
| 910 |
+
copy it and return a different name.
|
| 911 |
+
"""
|
| 912 |
+
if self.constants[name].device == device_override or device_override is None:
|
| 913 |
+
return name
|
| 914 |
+
with torch.utils._python_dispatch._disable_current_modes():
|
| 915 |
+
# caller might have OrderedSet fake tensor mode which will create a fake tensor
|
| 916 |
+
# when calling .to, so unset modes here
|
| 917 |
+
return self.allocate_non_dup_const_name(
|
| 918 |
+
f"{name}_{device_override.type}{device_override.index or 0}",
|
| 919 |
+
self.constants[name].to(device_override),
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
def placeholder(
|
| 923 |
+
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
|
| 924 |
+
) -> Union[Expr, TensorBox, None]:
|
| 925 |
+
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
| 926 |
+
self.graph_input_names.append(target)
|
| 927 |
+
if isinstance(example, SymTypes):
|
| 928 |
+
expr = example.node.expr
|
| 929 |
+
self.graph_inputs[target] = expr
|
| 930 |
+
return expr
|
| 931 |
+
elif isinstance(example, (int, bool, float)):
|
| 932 |
+
expr = sympy.sympify(example)
|
| 933 |
+
self.graph_inputs[target] = expr
|
| 934 |
+
return expr
|
| 935 |
+
elif example is None:
|
| 936 |
+
return None
|
| 937 |
+
if isinstance(example, BackwardState):
|
| 938 |
+
# Ignored arg, must be unused
|
| 939 |
+
# Alternately we could filter this out in AotAutograd
|
| 940 |
+
return None
|
| 941 |
+
assert isinstance(example, torch.Tensor), example
|
| 942 |
+
# todo(chilli): We can remove the last check once we turn buffers into
|
| 943 |
+
# static shape tensors. That's a hack to workaround Inductor believing
|
| 944 |
+
# the buffer should be static but us passing in a fake tensor with
|
| 945 |
+
# symbolic shapes.
|
| 946 |
+
if not example._has_symbolic_sizes_strides:
|
| 947 |
+
# the first N inputs are weights
|
| 948 |
+
sizes, strides = self.static_sizes_strides(example)
|
| 949 |
+
else:
|
| 950 |
+
sizes, strides = self.symbolic_sizes_strides(example) # type: ignore[assignment]
|
| 951 |
+
# TODO(jansel): handle input aliasing
|
| 952 |
+
target = self.qualify_name(target)
|
| 953 |
+
tensor = TensorBox.create(
|
| 954 |
+
InputBuffer(
|
| 955 |
+
target,
|
| 956 |
+
FixedLayout(example.device, example.dtype, sizes, strides),
|
| 957 |
+
)
|
| 958 |
+
)
|
| 959 |
+
self.graph_inputs[target] = tensor
|
| 960 |
+
self.graph_inputs_original[target] = tensor.data.data
|
| 961 |
+
if self.current_node.users: # cudagraphs should work with an unused CPU input
|
| 962 |
+
self.add_device_info(example.device)
|
| 963 |
+
|
| 964 |
+
# Note: [Input Alignment handling in Inductor]
|
| 965 |
+
# Alignment matters for generating efficient code. Some operations,
|
| 966 |
+
# e.g. vectorized loads, can only be performed on aligned inputs.
|
| 967 |
+
#
|
| 968 |
+
# But if we codegen assuming aligned inputs and then get unaligned
|
| 969 |
+
# inputs at runtime, then we are forced to clone - which is bad for
|
| 970 |
+
# both perf and memory usage.
|
| 971 |
+
#
|
| 972 |
+
# One option would be to guard on storage_offset%ALIGNMENT, and then
|
| 973 |
+
# codegen based on this. But storage_offset guards turned out to be
|
| 974 |
+
# expensive and cause recompiles; Instead, we're generating code
|
| 975 |
+
# based on the alignment of the example input without guarding.
|
| 976 |
+
with maybe_get_suppress_shape_guards_ctx():
|
| 977 |
+
if should_assume_input_aligned(example):
|
| 978 |
+
self.aligned_inputs.add(target)
|
| 979 |
+
return tensor
|
| 980 |
+
|
| 981 |
+
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override]
|
| 982 |
+
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
|
| 983 |
+
return super().call_function(target, args, kwargs)
|
| 984 |
+
|
| 985 |
+
# hasattr on OpOverloadPacket is slow, check isinstance first
|
| 986 |
+
if not isinstance(target, torch._ops.OpOverloadPacket) and hasattr(
|
| 987 |
+
target, "_inductor_lowering_function"
|
| 988 |
+
):
|
| 989 |
+
# passthrough lowerings from .pattern_matcher
|
| 990 |
+
return target(*args, **kwargs)
|
| 991 |
+
|
| 992 |
+
if target not in lowerings:
|
| 993 |
+
assert isinstance(
|
| 994 |
+
target, torch._ops.OpOverload
|
| 995 |
+
), f"{target} is not an OpOverload"
|
| 996 |
+
base_name = target.name().split(".")[0]
|
| 997 |
+
if base_name in FALLBACK_ALLOW_LIST:
|
| 998 |
+
make_fallback(target)
|
| 999 |
+
elif config.implicit_fallbacks:
|
| 1000 |
+
error = (
|
| 1001 |
+
MissingOperatorWithDecomp
|
| 1002 |
+
if get_decompositions([target])
|
| 1003 |
+
else MissingOperatorWithoutDecomp
|
| 1004 |
+
)
|
| 1005 |
+
log.info(
|
| 1006 |
+
"Creating implicit fallback for:\n%s",
|
| 1007 |
+
error.operator_str(target, args, kwargs),
|
| 1008 |
+
)
|
| 1009 |
+
make_fallback(target)
|
| 1010 |
+
|
| 1011 |
+
elif get_decompositions([target]):
|
| 1012 |
+
# There isn't a good way to dynamically patch this in
|
| 1013 |
+
# since AOT Autograd already ran. The error message tells
|
| 1014 |
+
# the user how to fix it.
|
| 1015 |
+
raise MissingOperatorWithDecomp(target, args, kwargs)
|
| 1016 |
+
else:
|
| 1017 |
+
raise MissingOperatorWithoutDecomp(target, args, kwargs)
|
| 1018 |
+
|
| 1019 |
+
try:
|
| 1020 |
+
log.debug(" via %s", lowerings[target]) # type: ignore[index]
|
| 1021 |
+
out = lowerings[target](*args, **kwargs) # type: ignore[index]
|
| 1022 |
+
return out
|
| 1023 |
+
except Exception as e:
|
| 1024 |
+
raise LoweringException(e, target, args, kwargs).with_traceback(
|
| 1025 |
+
e.__traceback__
|
| 1026 |
+
) from None
|
| 1027 |
+
|
| 1028 |
+
@staticmethod
|
| 1029 |
+
def can_inline_constant(t: torch.Tensor) -> bool:
|
| 1030 |
+
"""
|
| 1031 |
+
True if this is a small constant attr that will be inlined.
|
| 1032 |
+
"""
|
| 1033 |
+
return len(t.shape) == 1 and t.shape[0] <= 8
|
| 1034 |
+
|
| 1035 |
+
def get_attr(
|
| 1036 |
+
self, target: str, args: Tuple[()], kwargs: Dict[str, object] # type: ignore[override]
|
| 1037 |
+
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
|
| 1038 |
+
# this is a constant
|
| 1039 |
+
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
|
| 1040 |
+
|
| 1041 |
+
if isinstance(value, torch.fx.GraphModule):
|
| 1042 |
+
return ir.Subgraph(name=target, graph_module=value)
|
| 1043 |
+
|
| 1044 |
+
if isinstance(value, torch._C.ScriptObject):
|
| 1045 |
+
self.torchbind_constants[target] = value
|
| 1046 |
+
self.constant_reprs[target] = ""
|
| 1047 |
+
return TorchBindObject(target, value)
|
| 1048 |
+
|
| 1049 |
+
assert isinstance(value, torch.Tensor)
|
| 1050 |
+
if (
|
| 1051 |
+
config.aot_inductor.use_runtime_constant_folding
|
| 1052 |
+
or config.always_keep_tensor_constants
|
| 1053 |
+
or unsupported_output_tensor(value)
|
| 1054 |
+
):
|
| 1055 |
+
return self.add_tensor_constant(value, target)
|
| 1056 |
+
|
| 1057 |
+
with no_dispatch():
|
| 1058 |
+
if value.shape == ():
|
| 1059 |
+
return Constant(value.item(), value.dtype, value.device)
|
| 1060 |
+
if self.can_inline_constant(value):
|
| 1061 |
+
log.debug("Inlining constant: %s ", str(target))
|
| 1062 |
+
# tensor lowering has constant inlining logic
|
| 1063 |
+
from .lowering import tensor
|
| 1064 |
+
|
| 1065 |
+
return tensor(value.tolist(), dtype=value.dtype, device=value.device)
|
| 1066 |
+
|
| 1067 |
+
return self.add_tensor_constant(value, target)
|
| 1068 |
+
|
| 1069 |
+
def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
|
| 1070 |
+
raise AssertionError
|
| 1071 |
+
|
| 1072 |
+
def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
|
| 1073 |
+
raise AssertionError
|
| 1074 |
+
|
| 1075 |
+
def output(
|
| 1076 |
+
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
|
| 1077 |
+
) -> None:
|
| 1078 |
+
result = super().output(target, args, kwargs) # type: ignore[arg-type]
|
| 1079 |
+
if not isinstance(result, (tuple, list)):
|
| 1080 |
+
# nested subgraphs can have singleton outputs
|
| 1081 |
+
result = (result,)
|
| 1082 |
+
assert isinstance(result, (tuple, list)), type(result)
|
| 1083 |
+
assert all(
|
| 1084 |
+
isinstance(
|
| 1085 |
+
x,
|
| 1086 |
+
(
|
| 1087 |
+
TensorBox,
|
| 1088 |
+
ir.Constant,
|
| 1089 |
+
type(None),
|
| 1090 |
+
ir.ConstantBuffer,
|
| 1091 |
+
sympy.Expr,
|
| 1092 |
+
sympy.logic.boolalg.Boolean,
|
| 1093 |
+
int,
|
| 1094 |
+
ir.EffectfulKernel,
|
| 1095 |
+
),
|
| 1096 |
+
)
|
| 1097 |
+
for x in result
|
| 1098 |
+
), result
|
| 1099 |
+
|
| 1100 |
+
fx_node_args = V.graph.current_node.args[0] # type: ignore[arg-type]
|
| 1101 |
+
if not isinstance(fx_node_args, (tuple, list)):
|
| 1102 |
+
# nested subgraphs can have singleton outputs
|
| 1103 |
+
fx_node_args = (fx_node_args,)
|
| 1104 |
+
result = [ir.ExternKernel.realize_input(x) for x in result]
|
| 1105 |
+
result_correct_strides = []
|
| 1106 |
+
|
| 1107 |
+
assert len(fx_node_args) == len(result)
|
| 1108 |
+
for r, fx_node in zip(result, fx_node_args):
|
| 1109 |
+
if not isinstance(r, (ir.TensorBox, ir.BaseView)):
|
| 1110 |
+
result_correct_strides.append(r)
|
| 1111 |
+
else:
|
| 1112 |
+
# AOT Autograd tries to detect stride divergence of inductor from output metadata.
|
| 1113 |
+
# Here, we try to avoid spurious divergence by matching insignificant strides such as
|
| 1114 |
+
result_correct_strides.append(
|
| 1115 |
+
self.try_match_insignificant_strides(
|
| 1116 |
+
r, fx_node.meta["val"].stride()
|
| 1117 |
+
)
|
| 1118 |
+
)
|
| 1119 |
+
|
| 1120 |
+
self.graph_outputs = result_correct_strides
|
| 1121 |
+
value: ir.IRNode
|
| 1122 |
+
for name, value in self.graph_inputs.items():
|
| 1123 |
+
assert isinstance(
|
| 1124 |
+
value, (TensorBox, sympy.Expr)
|
| 1125 |
+
), f"Unsupported inductor graph input type: {type(value)}"
|
| 1126 |
+
if not isinstance(value, TensorBox):
|
| 1127 |
+
continue
|
| 1128 |
+
value.realize()
|
| 1129 |
+
assert isinstance(value, TensorBox)
|
| 1130 |
+
value = value.data
|
| 1131 |
+
assert isinstance(value, ir.StorageBox)
|
| 1132 |
+
value_storage_box = value
|
| 1133 |
+
value = value.data
|
| 1134 |
+
if not isinstance(value, InputBuffer) or value.get_name() != name:
|
| 1135 |
+
# one of our inputs was mutated, need to turn that into a copy
|
| 1136 |
+
ir.MutationLayoutSHOULDREMOVE.realize_into(
|
| 1137 |
+
value, self.graph_inputs_original[name]
|
| 1138 |
+
)
|
| 1139 |
+
# replace output with mutated input
|
| 1140 |
+
try:
|
| 1141 |
+
ind = self.graph_outputs.index(value_storage_box)
|
| 1142 |
+
self.graph_outputs[ind] = self.graph_inputs_original[name]
|
| 1143 |
+
except ValueError:
|
| 1144 |
+
pass
|
| 1145 |
+
|
| 1146 |
+
self.finalize()
|
| 1147 |
+
log.debug(
|
| 1148 |
+
"Force channels last inputs for %d conv for the current graph with id %d",
|
| 1149 |
+
self.num_channels_last_conv,
|
| 1150 |
+
self.graph_id if self.graph_id is not None else -1,
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
def finalize(self) -> None:
|
| 1154 |
+
for buf in self.buffers:
|
| 1155 |
+
buf.decide_layout()
|
| 1156 |
+
|
| 1157 |
+
@contextmanager
|
| 1158 |
+
def set_current_node(self, node: torch.fx.Node): # type: ignore[no-untyped-def]
|
| 1159 |
+
old = self.current_node
|
| 1160 |
+
try:
|
| 1161 |
+
self.current_node = node
|
| 1162 |
+
yield
|
| 1163 |
+
finally:
|
| 1164 |
+
self.current_node = old
|
| 1165 |
+
|
| 1166 |
+
def try_match_insignificant_strides(
|
| 1167 |
+
self,
|
| 1168 |
+
tensor: Union[ir.TensorBox, ir.BaseView],
|
| 1169 |
+
meta_strides_inp: Tuple[Union[int, torch.SymInt], ...],
|
| 1170 |
+
) -> Union[ir.TensorBox, ir.BaseView]:
|
| 1171 |
+
"""
|
| 1172 |
+
Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant
|
| 1173 |
+
dimensions - size 0 or 1 - will be updated.
|
| 1174 |
+
|
| 1175 |
+
If there are real stride differences (NHWC vs NCHW) then the input will be returned.
|
| 1176 |
+
"""
|
| 1177 |
+
|
| 1178 |
+
# should have already been realized
|
| 1179 |
+
assert torch._inductor.ir.is_storage_and_layout(tensor)
|
| 1180 |
+
|
| 1181 |
+
meta_strides = [
|
| 1182 |
+
s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_strides_inp
|
| 1183 |
+
]
|
| 1184 |
+
|
| 1185 |
+
if all(
|
| 1186 |
+
self.sizevars.statically_known_equals(s1, s2)
|
| 1187 |
+
for s1, s2 in zip(meta_strides, tensor.get_stride())
|
| 1188 |
+
):
|
| 1189 |
+
return tensor # type: ignore[arg-type]
|
| 1190 |
+
|
| 1191 |
+
def significant_strides_equal(
|
| 1192 |
+
shape: Sequence[Union[Expr, int]],
|
| 1193 |
+
meta_strides: Sequence[Union[Expr, int]],
|
| 1194 |
+
tensor_strides: Sequence[Union[Expr, int]],
|
| 1195 |
+
) -> bool:
|
| 1196 |
+
for dim, s1, s2 in zip(shape, meta_strides, tensor_strides):
|
| 1197 |
+
if self.sizevars.statically_known_leq(dim, 1): # type: ignore[arg-type]
|
| 1198 |
+
continue
|
| 1199 |
+
|
| 1200 |
+
if not self.sizevars.statically_known_equals(s1, s2):
|
| 1201 |
+
return False
|
| 1202 |
+
|
| 1203 |
+
return True
|
| 1204 |
+
|
| 1205 |
+
if not significant_strides_equal(
|
| 1206 |
+
tensor.get_size(), meta_strides, tensor.get_stride()
|
| 1207 |
+
):
|
| 1208 |
+
return tensor
|
| 1209 |
+
|
| 1210 |
+
storage, old_layout = torch._inductor.ir.as_storage_and_layout(tensor)
|
| 1211 |
+
new_stride = list(old_layout.stride)
|
| 1212 |
+
for i, s in enumerate(tensor.get_size()):
|
| 1213 |
+
if self.sizevars.statically_known_leq(s, 1): # type: ignore[arg-type]
|
| 1214 |
+
new_stride[i] = meta_strides[i]
|
| 1215 |
+
|
| 1216 |
+
new_layout = torch._inductor.ir.FixedLayout(
|
| 1217 |
+
old_layout.device,
|
| 1218 |
+
old_layout.dtype,
|
| 1219 |
+
old_layout.size,
|
| 1220 |
+
new_stride,
|
| 1221 |
+
old_layout.offset,
|
| 1222 |
+
)
|
| 1223 |
+
return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout))
|
| 1224 |
+
|
| 1225 |
+
def propagate_mutation(
|
| 1226 |
+
self,
|
| 1227 |
+
fx_node: torch.fx.Node,
|
| 1228 |
+
old_args: Tuple[Any],
|
| 1229 |
+
old_kwargs: Dict[str, Any],
|
| 1230 |
+
new_args: Tuple[Any],
|
| 1231 |
+
new_kwargs: Dict[str, Any],
|
| 1232 |
+
) -> None:
|
| 1233 |
+
"""Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs.
|
| 1234 |
+
|
| 1235 |
+
Assumes we may have cloned old_args/old_kwargs into new_args/new_kwargs
|
| 1236 |
+
and then called fx_node(*new_args, **new_kwargs).
|
| 1237 |
+
|
| 1238 |
+
If fx_node mutates any of new_args/new_kwargs, and they are different from
|
| 1239 |
+
old_args/old_kwargs, then we need to update the original tensor.
|
| 1240 |
+
"""
|
| 1241 |
+
assert isinstance(fx_node.target, torch._ops.OpOverload)
|
| 1242 |
+
assert len(old_args) == len(new_args)
|
| 1243 |
+
assert len(old_kwargs) == len(new_kwargs)
|
| 1244 |
+
|
| 1245 |
+
def maybe_propagate(
|
| 1246 |
+
schema_arg: torch._C.Argument, old_arg: ir.IRNode, new_arg: ir.IRNode
|
| 1247 |
+
) -> None:
|
| 1248 |
+
if old_arg is new_arg:
|
| 1249 |
+
return
|
| 1250 |
+
if schema_arg.alias_info is not None and schema_arg.alias_info.is_write:
|
| 1251 |
+
# The lowering for copy_ is smart enough to "replace" old_arg with
|
| 1252 |
+
# new_arg in all future uses so a copy_ kernel never gets emitted.
|
| 1253 |
+
self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {})
|
| 1254 |
+
|
| 1255 |
+
schema = fx_node.target._schema
|
| 1256 |
+
for idx, (old_arg, new_arg) in enumerate(zip(old_args, new_args)):
|
| 1257 |
+
schema_arg = schema.arguments[idx]
|
| 1258 |
+
maybe_propagate(schema_arg, old_arg, new_arg)
|
| 1259 |
+
|
| 1260 |
+
schema_kwargs = {arg.name: arg for arg in schema.arguments}
|
| 1261 |
+
|
| 1262 |
+
for key in old_kwargs.keys():
|
| 1263 |
+
old_arg = old_kwargs[key]
|
| 1264 |
+
new_arg = new_kwargs[key]
|
| 1265 |
+
schema_arg = schema_kwargs[key]
|
| 1266 |
+
maybe_propagate(schema_arg, old_arg, new_arg)
|
| 1267 |
+
|
| 1268 |
+
def run_node(self, n: torch.fx.Node) -> object:
|
| 1269 |
+
def debug(msg: str) -> None:
|
| 1270 |
+
log.debug("lowering %s %s", LazyString(n.format_node), msg)
|
| 1271 |
+
|
| 1272 |
+
buffer_watermark = len(self.buffers)
|
| 1273 |
+
operation_watermark = len(self.operations)
|
| 1274 |
+
|
| 1275 |
+
origins = {n}
|
| 1276 |
+
is_call_function = n.op == "call_function"
|
| 1277 |
+
if is_call_function:
|
| 1278 |
+
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
| 1279 |
+
origins |= gather_origins(args, kwargs)
|
| 1280 |
+
with ir.IRNode.current_origins(origins), self.set_current_node( # type: ignore[arg-type]
|
| 1281 |
+
n
|
| 1282 |
+
), V.set_current_node(
|
| 1283 |
+
n
|
| 1284 |
+
):
|
| 1285 |
+
if (
|
| 1286 |
+
n.op == "call_function"
|
| 1287 |
+
and n.target is not operator.getitem
|
| 1288 |
+
and fallback_node_due_to_unsupported_type(n)
|
| 1289 |
+
):
|
| 1290 |
+
debug("fallback_handler")
|
| 1291 |
+
result = fallback_handler(n.target, add_to_fallback_set=False)(
|
| 1292 |
+
*args, **kwargs # type: ignore[possibly-undefined]
|
| 1293 |
+
)
|
| 1294 |
+
elif n.op == "call_function" and (
|
| 1295 |
+
layout_constraints := maybe_layout_constraints(n.target) # type: ignore[arg-type]
|
| 1296 |
+
):
|
| 1297 |
+
debug("layout_constraints")
|
| 1298 |
+
old_args = args # type: ignore[possibly-undefined]
|
| 1299 |
+
old_kwargs = kwargs # type: ignore[possibly-undefined]
|
| 1300 |
+
args, kwargs = layout_constraints(n, *args, **kwargs) # type: ignore[index]
|
| 1301 |
+
result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type]
|
| 1302 |
+
# layout_constraints are allowed to make new copies of the inputs.
|
| 1303 |
+
# if they do, and if the target is mutable, then we need to
|
| 1304 |
+
# write the new values back into the original inputs.
|
| 1305 |
+
self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined]
|
| 1306 |
+
elif is_magic_method(n.target):
|
| 1307 |
+
# TODO: this is sus, it probably should be handled in the
|
| 1308 |
+
# lowerings themselves similarly to sym_size/sym-stride
|
| 1309 |
+
# https://github.com/pytorch/pytorch/issues/127789
|
| 1310 |
+
debug("is_magic_method")
|
| 1311 |
+
if isinstance(
|
| 1312 |
+
n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
|
| 1313 |
+
):
|
| 1314 |
+
result = n.meta["val"].node.expr
|
| 1315 |
+
else:
|
| 1316 |
+
result = super().run_node(n)
|
| 1317 |
+
else:
|
| 1318 |
+
debug("")
|
| 1319 |
+
result = super().run_node(n)
|
| 1320 |
+
|
| 1321 |
+
# require the same stride order for dense outputs,
|
| 1322 |
+
# 1. user-land view() will not throw because inductor
|
| 1323 |
+
# output different strides than eager
|
| 1324 |
+
# long term the solution is to make view() always succeed
|
| 1325 |
+
# with infallible strides.
|
| 1326 |
+
# 2: as_strided ops, we need make sure its input has same size/stride with
|
| 1327 |
+
# eager model to align with eager behavior.
|
| 1328 |
+
as_strided_ops = [
|
| 1329 |
+
torch.ops.aten.as_strided.default,
|
| 1330 |
+
torch.ops.aten.as_strided_.default,
|
| 1331 |
+
torch.ops.aten.as_strided_scatter.default,
|
| 1332 |
+
torch.ops.aten.resize.default,
|
| 1333 |
+
torch.ops.aten.resize_as.default,
|
| 1334 |
+
]
|
| 1335 |
+
is_output = any(user.op == "output" for user in n.users)
|
| 1336 |
+
is_input_for_as_strided = any(
|
| 1337 |
+
user.target in as_strided_ops for user in n.users
|
| 1338 |
+
)
|
| 1339 |
+
|
| 1340 |
+
if n.meta.get("inductor_realize_to_strides", False) and isinstance(
|
| 1341 |
+
result, TensorBox
|
| 1342 |
+
):
|
| 1343 |
+
result.realize()
|
| 1344 |
+
strides = n.meta["val"].stride()
|
| 1345 |
+
sym_strides = torch._inductor.utils.any_is_symbolic(*strides)
|
| 1346 |
+
if (
|
| 1347 |
+
not hasattr(result, "get_stride")
|
| 1348 |
+
or result.get_stride() != strides
|
| 1349 |
+
and not sym_strides
|
| 1350 |
+
):
|
| 1351 |
+
stride_order = ir.get_stride_order(strides)
|
| 1352 |
+
result = ir.ExternKernel.require_stride_order(result, stride_order)
|
| 1353 |
+
if (
|
| 1354 |
+
is_output
|
| 1355 |
+
and isinstance(result, TensorBox)
|
| 1356 |
+
and isinstance(result.data, ir.BaseView)
|
| 1357 |
+
):
|
| 1358 |
+
# Realize so that outputs are correctly aliased
|
| 1359 |
+
result.realize()
|
| 1360 |
+
|
| 1361 |
+
if (is_output or is_input_for_as_strided) and isinstance(
|
| 1362 |
+
n.meta["val"], torch.Tensor
|
| 1363 |
+
):
|
| 1364 |
+
strides = n.meta["val"].stride()
|
| 1365 |
+
if len(strides):
|
| 1366 |
+
allow_padding = (
|
| 1367 |
+
config.pad_outputs or n.name not in self.user_visible_outputs
|
| 1368 |
+
) and not is_input_for_as_strided
|
| 1369 |
+
dense = torch._prims_common.is_non_overlapping_and_dense(
|
| 1370 |
+
n.meta["val"]
|
| 1371 |
+
)
|
| 1372 |
+
unbacked_symbols_in_strides = (
|
| 1373 |
+
len(free_unbacked_symbols(strides)) > 0
|
| 1374 |
+
)
|
| 1375 |
+
if (
|
| 1376 |
+
not unbacked_symbols_in_strides
|
| 1377 |
+
and dense
|
| 1378 |
+
and len(result.get_size()) == 4
|
| 1379 |
+
and n in self.nodes_prefer_channels_last
|
| 1380 |
+
and n.name not in self.user_visible_outputs
|
| 1381 |
+
and not is_input_for_as_strided
|
| 1382 |
+
):
|
| 1383 |
+
strides = ir.FlexibleLayout.stride_ordered_for_memory_format(
|
| 1384 |
+
result.get_size(), torch.channels_last
|
| 1385 |
+
)
|
| 1386 |
+
if not unbacked_symbols_in_strides and len(strides):
|
| 1387 |
+
# To avoid converting possible view ops to a copy kernel, we use the previous
|
| 1388 |
+
# require_exact_strides to handle views. But ultimately it's better to require
|
| 1389 |
+
# the right strides at the tensor definition.
|
| 1390 |
+
if n.meta["val"]._is_view() or isinstance(
|
| 1391 |
+
result.data, ir.BaseView
|
| 1392 |
+
):
|
| 1393 |
+
result = ir.ExternKernel.require_stride_order(
|
| 1394 |
+
result,
|
| 1395 |
+
ir.get_stride_order(strides),
|
| 1396 |
+
allow_padding=allow_padding,
|
| 1397 |
+
)
|
| 1398 |
+
else:
|
| 1399 |
+
strides = [
|
| 1400 |
+
s.node.expr if isinstance(s, torch.SymInt) else s
|
| 1401 |
+
for s in strides
|
| 1402 |
+
]
|
| 1403 |
+
result = ir.ExternKernel.require_exact_strides(
|
| 1404 |
+
result, strides, allow_padding=allow_padding
|
| 1405 |
+
)
|
| 1406 |
+
|
| 1407 |
+
# Realize if (1) any user need inputs realized, or (2) there is
|
| 1408 |
+
# already too many reads and rematerializing can be bad.
|
| 1409 |
+
num_users = len(OrderedSet(n.users))
|
| 1410 |
+
if num_users > 1 and isinstance(result, TensorBox):
|
| 1411 |
+
for user in n.users:
|
| 1412 |
+
if user.target in needs_realized_inputs:
|
| 1413 |
+
result.realize_hint()
|
| 1414 |
+
# This inclusion is somewhat controversial (from
|
| 1415 |
+
# discussion between Horace, Natalia, and Elias).
|
| 1416 |
+
# Currently, it's not very clear why this is helpful.
|
| 1417 |
+
# The general idea here is that even though a node may
|
| 1418 |
+
# have FlexibleLayout, we still often *treat* it as if
|
| 1419 |
+
# it was contiguous. This appears to sometimes result in
|
| 1420 |
+
# suboptimal behavior.
|
| 1421 |
+
#
|
| 1422 |
+
# When we do a better job selecting layout, we should
|
| 1423 |
+
# revisit this.
|
| 1424 |
+
need_fixed_layout = [
|
| 1425 |
+
torch.ops.aten.convolution_backward.default,
|
| 1426 |
+
torch.ops.aten.mm.default,
|
| 1427 |
+
torch.ops.aten._int_mm.default,
|
| 1428 |
+
]
|
| 1429 |
+
need_fixed_channels_last_layout = []
|
| 1430 |
+
if not self.layout_opt:
|
| 1431 |
+
need_fixed_layout.append(torch.ops.aten.convolution.default)
|
| 1432 |
+
if torch._C._has_mkldnn:
|
| 1433 |
+
need_fixed_layout += [
|
| 1434 |
+
torch.ops.mkldnn._linear_pointwise.default,
|
| 1435 |
+
torch.ops.mkldnn._linear_pointwise.binary,
|
| 1436 |
+
torch.ops.aten.mkldnn_rnn_layer.default,
|
| 1437 |
+
torch.ops.onednn.qlinear_pointwise.default,
|
| 1438 |
+
torch.ops.onednn.qlinear_pointwise.tensor,
|
| 1439 |
+
torch.ops.onednn.qlinear_pointwise.binary,
|
| 1440 |
+
torch.ops.onednn.qlinear_pointwise.binary_tensor,
|
| 1441 |
+
]
|
| 1442 |
+
need_fixed_channels_last_layout += [
|
| 1443 |
+
torch.ops.mkldnn._convolution_pointwise.default,
|
| 1444 |
+
torch.ops.mkldnn._convolution_pointwise.binary,
|
| 1445 |
+
torch.ops.mkldnn._convolution_pointwise_.binary,
|
| 1446 |
+
torch.ops.mkldnn._convolution_transpose_pointwise.default,
|
| 1447 |
+
torch.ops.onednn.qconv2d_pointwise.default,
|
| 1448 |
+
torch.ops.onednn.qconv2d_pointwise.binary,
|
| 1449 |
+
]
|
| 1450 |
+
if torch._C.has_mkl:
|
| 1451 |
+
need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
|
| 1452 |
+
if user.target in need_fixed_layout:
|
| 1453 |
+
result = ir.ExternKernel.require_stride_order(
|
| 1454 |
+
result,
|
| 1455 |
+
ir.get_stride_order(n.meta["val"].stride()),
|
| 1456 |
+
allow_padding=True,
|
| 1457 |
+
)
|
| 1458 |
+
if (
|
| 1459 |
+
user.target in need_fixed_channels_last_layout
|
| 1460 |
+
and n is user.args[0]
|
| 1461 |
+
):
|
| 1462 |
+
result = ir.ExternKernel.require_stride_order(
|
| 1463 |
+
result,
|
| 1464 |
+
ir.get_stride_order(
|
| 1465 |
+
make_channels_last_strides_for(n.meta["val"].shape)
|
| 1466 |
+
),
|
| 1467 |
+
)
|
| 1468 |
+
if user.op == "output":
|
| 1469 |
+
if isinstance(result.data.data, (Pointwise, Reduction)):
|
| 1470 |
+
result.realize()
|
| 1471 |
+
|
| 1472 |
+
# TODO(jansel): introduce a store vs inline choice
|
| 1473 |
+
result.mark_reuse(len(n.users))
|
| 1474 |
+
|
| 1475 |
+
# Realize if the IRNode already has accumulated lots of reads
|
| 1476 |
+
if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
|
| 1477 |
+
# Prevent excessive accumulation in a computed buffer, when
|
| 1478 |
+
# there are multiple branches each with small number of memory
|
| 1479 |
+
# reads, but they converge to a user.
|
| 1480 |
+
result.realize_hint()
|
| 1481 |
+
|
| 1482 |
+
# Realize if a Pointwise has too much stuff to be inlined.
|
| 1483 |
+
# As this may cause RecursionError during Inductor's evaluation.
|
| 1484 |
+
if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
|
| 1485 |
+
curr = result.data.data
|
| 1486 |
+
if isinstance(curr, Pointwise):
|
| 1487 |
+
# Use inner fn as a rough proxy. Good enough.
|
| 1488 |
+
if curr.has_large_inner_fn():
|
| 1489 |
+
result.realize()
|
| 1490 |
+
|
| 1491 |
+
# This is not complete, but it doesn't have to be: origin_node
|
| 1492 |
+
# tracking is best effort. The logic here critically relies on direct
|
| 1493 |
+
# TensorBox -> StorageBox denoting a non-view; we don't bother trying
|
| 1494 |
+
# to get views to work. Feel free to add any extra cases as needed.
|
| 1495 |
+
#
|
| 1496 |
+
# Note: we can't YOLO tree_map over this result, because if there are
|
| 1497 |
+
# buffers or a view involved, we might not be able to validly assign
|
| 1498 |
+
# the origin_node here.
|
| 1499 |
+
if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
|
| 1500 |
+
if isinstance(result.data.data, ir.Loops):
|
| 1501 |
+
result.data.data.origin_node = n
|
| 1502 |
+
elif isinstance(result.data.data, ir.Buffer):
|
| 1503 |
+
result.data.data.origin_node = n
|
| 1504 |
+
if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
|
| 1505 |
+
result.data.data.data, ir.Loops
|
| 1506 |
+
):
|
| 1507 |
+
result.data.data.data.origin_node = n
|
| 1508 |
+
# Not really multi-output, can straightforwardly recurse in
|
| 1509 |
+
elif (
|
| 1510 |
+
isinstance(result.data.data, ir.MultiOutput)
|
| 1511 |
+
and not result.data.data.indices
|
| 1512 |
+
):
|
| 1513 |
+
if isinstance(result.data.data.inputs[0], ir.Buffer):
|
| 1514 |
+
result.data.data.inputs[0].origin_node = n
|
| 1515 |
+
|
| 1516 |
+
self.register_users_of(result)
|
| 1517 |
+
|
| 1518 |
+
new_unbacked_defs: OrderedSet[sympy.Symbol] = OrderedSet()
|
| 1519 |
+
for buf in self.buffers[buffer_watermark:]:
|
| 1520 |
+
new_unbacked_defs |= buf.get_unbacked_symbol_defs()
|
| 1521 |
+
for op in self.operations[operation_watermark:]:
|
| 1522 |
+
new_unbacked_defs |= op.get_unbacked_symbol_defs()
|
| 1523 |
+
|
| 1524 |
+
def format_new_defs() -> str:
|
| 1525 |
+
r = []
|
| 1526 |
+
for buf in self.buffers[buffer_watermark:]:
|
| 1527 |
+
r.append(
|
| 1528 |
+
f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
|
| 1529 |
+
)
|
| 1530 |
+
for op in self.operations[operation_watermark:]:
|
| 1531 |
+
r.append(
|
| 1532 |
+
f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
|
| 1533 |
+
)
|
| 1534 |
+
return "***\n".join(r)
|
| 1535 |
+
|
| 1536 |
+
if n.op != "placeholder":
|
| 1537 |
+
# Note [Backwards runtime asserts]
|
| 1538 |
+
# Backwards poses an interesting problem for deferred runtime
|
| 1539 |
+
# asserts. In the easy case, we may solely close over data
|
| 1540 |
+
# dependent sized tensors, and there are no binding sites for
|
| 1541 |
+
# unbacked SymInts. In this case, we can just drop all the
|
| 1542 |
+
# runtime asserts on the floor: no non-placeholder bindings, no
|
| 1543 |
+
# problem.
|
| 1544 |
+
#
|
| 1545 |
+
# However, it is *possible* for a fresh runtime assert to show up
|
| 1546 |
+
# between forwards and backwards. Right now, the freezing process
|
| 1547 |
+
# that happens when we lower forwards means that we will freeze
|
| 1548 |
+
# runtime asserts, and then the moment the backwards lowering
|
| 1549 |
+
# process attempts to add a new deferred runtime assert, we will
|
| 1550 |
+
# fail. Let's say you remove that assert. Now when we get here,
|
| 1551 |
+
# we need to make sure we actually emit these asserts (because we
|
| 1552 |
+
# can't emit them in forwards, we already compiled it). So we
|
| 1553 |
+
# have to do something here. But we don't want to reemit ALL
|
| 1554 |
+
# deferred runtime asserts, we only want to emit the NEW ones.
|
| 1555 |
+
# Therefore needing some sort of stratification in the ShapeEnv.
|
| 1556 |
+
# This is all doable, it just hasn't been done yet.
|
| 1557 |
+
shape_env = V.graph.sizevars.shape_env
|
| 1558 |
+
|
| 1559 |
+
def make_assert(expr: Expr, msg: str) -> None:
|
| 1560 |
+
assert_op = ir.AssertScalar(expr, msg)
|
| 1561 |
+
self.register_buffer(assert_op, set_name=True)
|
| 1562 |
+
self.register_operation(assert_op)
|
| 1563 |
+
|
| 1564 |
+
for i0 in new_unbacked_defs:
|
| 1565 |
+
ras = self.ras_by_symbol.pop(i0, [])
|
| 1566 |
+
# NB: size-like not needed, we won't retrace
|
| 1567 |
+
vr = shape_env.var_to_range[i0]
|
| 1568 |
+
if not shape_env._default_unspecified_value_range().issubset(vr):
|
| 1569 |
+
|
| 1570 |
+
def is_convertible(s: Expr) -> bool:
|
| 1571 |
+
if s in (int_oo, -int_oo):
|
| 1572 |
+
return False
|
| 1573 |
+
try:
|
| 1574 |
+
int(s)
|
| 1575 |
+
return True
|
| 1576 |
+
except TypeError:
|
| 1577 |
+
return False
|
| 1578 |
+
|
| 1579 |
+
if is_convertible(vr.lower):
|
| 1580 |
+
make_assert(i0 >= vr.lower, f"{i0} >= {vr.lower}")
|
| 1581 |
+
if is_convertible(vr.upper):
|
| 1582 |
+
make_assert(i0 <= vr.upper, f"{i0} <= {vr.upper}")
|
| 1583 |
+
|
| 1584 |
+
for ra in ras:
|
| 1585 |
+
fvs = free_unbacked_symbols(ra.expr)
|
| 1586 |
+
missing = fvs - self.bound_unbacked_symbols
|
| 1587 |
+
if missing:
|
| 1588 |
+
i1 = min(missing, key=str)
|
| 1589 |
+
self.ras_by_symbol.setdefault(i1, []).append(ra)
|
| 1590 |
+
else:
|
| 1591 |
+
make_assert(ra.expr, f"{ra.expr}")
|
| 1592 |
+
|
| 1593 |
+
self.bound_unbacked_symbols |= new_unbacked_defs
|
| 1594 |
+
|
| 1595 |
+
unbacked_bindings = resolve_unbacked_bindings(
|
| 1596 |
+
V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
|
| 1597 |
+
)
|
| 1598 |
+
# When we do lowering, it is possible we reallocate unbacked SymInts.
|
| 1599 |
+
# So we need to line up the unbacked SymInts when performing the test
|
| 1600 |
+
# here
|
| 1601 |
+
#
|
| 1602 |
+
# In principle, we could permit lowering to introduce MORE unbacked
|
| 1603 |
+
# SymInts: as long as all the old unbacked ones are accounted for,
|
| 1604 |
+
# it's fine for inductor to introduce extra calls to item()/unbacked()
|
| 1605 |
+
# whatever. This actually happens in practice when an unbacked SymInt
|
| 1606 |
+
# gets memoized away; naively, when Inductor reprocesses a kernel, it
|
| 1607 |
+
# doesn't know that the memo still applies, and ends up allocating a
|
| 1608 |
+
# new symbol. However, this is generally a bad thing: we may still
|
| 1609 |
+
# end up needing to test equalities on the symbols, and a fresh
|
| 1610 |
+
# symbol is likely to hit lots of GuardOnDataDependent errors that
|
| 1611 |
+
# we already know facts for.
|
| 1612 |
+
renamed_unbacked_bindings = OrderedSet(
|
| 1613 |
+
V.fake_mode.shape_env.unbacked_renamings.get(s, s)
|
| 1614 |
+
for s in unbacked_bindings.keys()
|
| 1615 |
+
)
|
| 1616 |
+
assert new_unbacked_defs >= renamed_unbacked_bindings, (
|
| 1617 |
+
f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
|
| 1618 |
+
f"fx node is: {n.format_node()}\n"
|
| 1619 |
+
f"new operations are:\n\n{format_new_defs()}"
|
| 1620 |
+
)
|
| 1621 |
+
|
| 1622 |
+
return result
|
| 1623 |
+
|
| 1624 |
+
def validate_can_generate_cpp_wrapper(self) -> None:
|
| 1625 |
+
if config.disable_cpp_codegen:
|
| 1626 |
+
raise CppWrapperCodeGenError("C++ codegen is disabled")
|
| 1627 |
+
|
| 1628 |
+
if sys.platform not in ["linux", "darwin", "win32"]:
|
| 1629 |
+
raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}")
|
| 1630 |
+
|
| 1631 |
+
for value in self.graph_inputs.values():
|
| 1632 |
+
dtype = None
|
| 1633 |
+
if isinstance(value, TensorBox):
|
| 1634 |
+
dtype = value.get_dtype()
|
| 1635 |
+
elif isinstance(
|
| 1636 |
+
value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
|
| 1637 |
+
):
|
| 1638 |
+
dtype = may_get_constant_buffer_dtype(value)
|
| 1639 |
+
|
| 1640 |
+
if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
|
| 1641 |
+
raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")
|
| 1642 |
+
|
| 1643 |
+
    def init_wrapper_code(self) -> None:
        """Instantiate ``self.wrapper_code`` for the graph's single device.

        Also sets ``self.cuda`` and ``self.device_ops``, and validates cpp
        wrapper prerequisites when ``self.cpp_wrapper`` is enabled.
        """
        self.cuda = "cuda" in self.device_types
        if self.cpp_wrapper:
            self.validate_can_generate_cpp_wrapper()

        # "cpu" and "meta" are allowed to mix with at most one real device;
        # after discarding them, more than one remaining device is an error.
        device_types = self.device_types.copy()
        device_types.discard("cpu")
        device_types.discard("meta")
        # TODO(Eikan): Only support mixing cpu and other device now.
        assert len(device_types) <= 1, "Does not support mixing {}".format(
            "+".join(device_types)
        )
        only_cpu = len(device_types) == 0
        device_type = "cpu" if only_cpu else device_types.pop()

        self.device_ops = get_device_op_overrides(device_type)
        wrapper_code_gen_cls = get_wrapper_codegen_for_device(
            device_type, self.cpp_wrapper
        )
        assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
        self.wrapper_code = wrapper_code_gen_cls()

        if self.const_module:
            # If we have const module, we could reuse the kernels
            # This could avoid duplication and save time on doing recompilation (if Triton.)
            self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
            self.wrapper_code.src_to_kernel = (
                self.const_module.wrapper_code.src_to_kernel
            )
|
| 1672 |
+
|
| 1673 |
+
    def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]:
        """
        For CPU, the cpp wrapper codegen is done in one pass.
        For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
        wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
        generate cpp wrapper code and compile it to a dynamic library in the second pass.
        """
        if "cuda" in self.device_types:
            # first pass
            self.cpp_wrapper = False
            # Although triton.store_cubin was OrderedSet in compile_fx, the backward pass didn't pick
            # that up. In theory it should work by only setting triton.store_cubin to True here,
            # but that will cause a problem when use_runtime_constant_folding is OrderedSet.
            with config.patch({"triton.store_cubin": True}):
                compiled = self.compile_to_module().call

            # Unless autotuning happens at compile time, run the first-pass
            # module once with concrete inputs so kernels get autotuned.
            if not config.triton.autotune_at_compile_time:

                def materialize(
                    x: Union[torch.SymInt, torch.SymFloat, torch.Tensor]
                ) -> Union[int, float, torch.Tensor]:
                    # Convert symbolic/fake values into concrete runnable ones.
                    if x is None:
                        return None
                    elif isinstance(x, (torch.SymInt, torch.SymFloat)):
                        # Need concrete value to run dynamic shapes and tune the result
                        return x.node.hint
                    elif isinstance(x, FakeTensor):
                        return defake(x)
                    else:
                        assert isinstance(
                            x, torch.Tensor
                        ), "Unknown type when creating real inputs" + str(type(x))
                        return x

                tracing_context = torch._guards.TracingContext.try_get()
                if tracing_context is not None and not isinstance(
                    V.real_inputs, NullHandler
                ):
                    if tracing_context.output_strides:
                        tracing_context.output_strides.clear()

                    params_flat = [
                        param
                        for param in tracing_context.params_flat  # type: ignore[union-attr]
                        if param is not None
                    ]
                    real_inputs = [
                        materialize(x)
                        for x in itertools.chain(params_flat, V.real_inputs)
                    ]
                else:
                    # In the backward pass, V.real_inputs is not OrderedSet.
                    # Generating random inputs based on self.example_inputs sometimes can be problematic,
                    # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
                    real_inputs = [
                        materialize(x)
                        for x in (
                            self.example_inputs
                            if isinstance(V.real_inputs, NullHandler)
                            else V.real_inputs
                        )
                    ]

                if self.mutated_inputs:
                    from .compile_fx import clone_preserve_strides

                    mutated_input_idxs = [
                        idx
                        for idx, name in enumerate(self.graph_inputs)
                        if name in self.mutated_inputs
                        and isinstance(real_inputs[idx], torch.Tensor)
                    ]
                    for idx in mutated_input_idxs:
                        # clone mutated Tensor inputs to avoid mutating them in
                        # the first pass of the CPP wrapper-based compilation, as
                        # this will lead to a side effect on the example inputs:
                        # e.g. if torch.compile(f)(x) is called on input-mutating
                        # f, the inputs x will be mutated twice in the process:
                        # once here, and again when running the compiled model;
                        # this will also lead to a numerically incorrect output
                        mutated_inp = real_inputs[idx]
                        assert isinstance(mutated_inp, torch.Tensor)
                        real_inputs[idx] = clone_preserve_strides(mutated_inp)
                        del mutated_inp

                with torch.utils._python_dispatch._disable_current_modes():
                    compiled(real_inputs)
                del real_inputs

            # second pass
            self.cpp_wrapper = True
            # Reset codegen state that the first pass populated so the second
            # pass starts from a clean slate.
            self.removed_buffers.clear()
            self.removed_operations.clear()
            self.inplaced_to_remove.clear()
            V.graph.sizevars.precomputed_replacements.clear()
            V.graph.sizevars.inv_precomputed_replacements.clear()
            with config.patch({"triton.autotune_at_compile_time": False}):
                return self.codegen()
        else:
            # cpu
            return self.codegen()
|
| 1774 |
+
|
| 1775 |
+
def codegen(self) -> Tuple[str, List[Tuple[int, Node]]]:
|
| 1776 |
+
from .scheduler import Scheduler
|
| 1777 |
+
|
| 1778 |
+
self.init_wrapper_code()
|
| 1779 |
+
|
| 1780 |
+
self.scheduler = Scheduler(self.operations)
|
| 1781 |
+
V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
|
| 1782 |
+
|
| 1783 |
+
self.wrapper_code.push_codegened_graph(self)
|
| 1784 |
+
self.scheduler.codegen()
|
| 1785 |
+
|
| 1786 |
+
log.debug(
|
| 1787 |
+
"Finished codegen for all nodes. The list of kernel names available: %s",
|
| 1788 |
+
V.graph.all_codegen_kernel_names,
|
| 1789 |
+
)
|
| 1790 |
+
|
| 1791 |
+
result = self.wrapper_code.generate(self.is_inference)
|
| 1792 |
+
self.wrapper_code.pop_codegened_graph()
|
| 1793 |
+
return result
|
| 1794 |
+
|
| 1795 |
+
def codegen_subgraph(self, parent_graph: "GraphLowering") -> None:
|
| 1796 |
+
"""
|
| 1797 |
+
This is a more compact version of the `codegen()` above
|
| 1798 |
+
where we codegen this graph as a subgraph of some parent
|
| 1799 |
+
graph. The parent graph is passed as an argument: the
|
| 1800 |
+
intention is to inline codegening of the subgraph in
|
| 1801 |
+
the parent graph's wrapper code (including the generated
|
| 1802 |
+
kerenls). The wrapper code is not finalized (via `.generate()`
|
| 1803 |
+
call), as this will be done in the parent graph's `codegen()`.
|
| 1804 |
+
"""
|
| 1805 |
+
from .scheduler import Scheduler
|
| 1806 |
+
|
| 1807 |
+
self.wrapper_code = parent_graph.wrapper_code
|
| 1808 |
+
self.device_ops = parent_graph.device_ops
|
| 1809 |
+
self.cpp_wrapper = parent_graph.cpp_wrapper
|
| 1810 |
+
|
| 1811 |
+
self.scheduler = Scheduler(self.operations)
|
| 1812 |
+
self.scheduler.codegen()
|
| 1813 |
+
|
| 1814 |
+
def count_bytes(
|
| 1815 |
+
self,
|
| 1816 |
+
) -> Tuple[
|
| 1817 |
+
int, List[Tuple[BaseSchedulerNode, int]], List[Tuple[BaseSchedulerNode, float]]
|
| 1818 |
+
]:
|
| 1819 |
+
total_bytes = 0
|
| 1820 |
+
node_counts = []
|
| 1821 |
+
node_runtimes = []
|
| 1822 |
+
for node in self.scheduler.nodes:
|
| 1823 |
+
num_bytes = node.get_read_write_buffers_sizes()
|
| 1824 |
+
total_bytes += num_bytes
|
| 1825 |
+
node_counts.append((node, num_bytes // 4))
|
| 1826 |
+
node_runtimes.append((node, node.get_estimated_runtime()))
|
| 1827 |
+
|
| 1828 |
+
return total_bytes, node_counts, node_runtimes
|
| 1829 |
+
|
| 1830 |
+
    @staticmethod
    def save_output_code(code: str) -> None:
        """Observe the final generated source; intentionally does nothing.

        Exists solely so unit tests can patch it (on the class) to capture
        the generated code.
        """
        # No-op to be patched for unit tests
        pass
|
| 1834 |
+
|
| 1835 |
+
def compile_to_module(self) -> ModuleType:
|
| 1836 |
+
with dynamo_timed(
|
| 1837 |
+
"GraphLowering.compile_to_module", phase_name="code_gen", fwd_only=False
|
| 1838 |
+
):
|
| 1839 |
+
return self._compile_to_module()
|
| 1840 |
+
|
| 1841 |
+
    def _compile_to_module(self) -> ModuleType:
        """Generate source for this graph, write it to the code cache, and
        import it as a module.

        Records ``cache_key``/``cache_path``/``cache_linemap`` on ``self``
        and emits the generated code through structured tracing and logs.
        """
        from .codecache import PyCodeCache

        # cpp_wrapper may require two codegen passes (see
        # codegen_with_cpp_wrapper); otherwise a single pass suffices.
        code, linemap = (
            self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
        )

        GraphLowering.save_output_code(code)
        output_code_log.debug("Output code: \n%s", code)
        try:
            linemap = [(line_no, node.stack_trace) for line_no, node in linemap]  # type: ignore[misc]
            key, path = PyCodeCache.write(code)
        except Exception:
            trace_structured(
                "inductor_output_code",
                # Just omit the filename, I still want the code though!
                payload_fn=lambda: code,
            )
            raise
        else:
            trace_structured(
                "inductor_output_code",
                lambda: {"filename": path},
                payload_fn=lambda: code,
            )

        mod = PyCodeCache.load_by_key_path(
            key,
            path,
            linemap=linemap,  # type: ignore[arg-type]
            attrs={**self.constants, **self.torchbind_constants},
        )
        self.cache_key = key
        self.cache_path = path
        self.cache_linemap = linemap  # type: ignore[assignment]

        # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
        # TODO. Revisit this once the logging API is more mature
        assert mod.__file__ is not None

        log_module_code(mod.__file__)
        log.debug("Output code written to: %s", mod.__file__)
        output_code_log.info("Output code written to: %s", mod.__file__)
        if config.benchmark_kernel:
            print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
        V.debug.output_code(mod.__file__)
        V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
        return mod
|
| 1889 |
+
|
| 1890 |
+
    def compile_to_fn(self) -> Any:
        """Compile the graph to a callable (JIT) or library path (AOT).

        In AOT mode the C++ wrapper is required and the return value is
        whatever ``AotCodeCompiler.compile`` produces (the compiled
        artifact's path); otherwise it is the ``call`` entry point of the
        compiled Python module.
        """
        if self.aot_mode:
            from .codecache import AotCodeCompiler

            assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
            code, linemap = self.codegen_with_cpp_wrapper()
            output_code_log.debug("Output code: \n%s", code)

            # Extern kernel nodes must be serialized alongside the compiled
            # code so the AOT artifact can invoke them at runtime.
            serialized_extern_kernel_nodes = None
            if self.extern_kernel_nodes:
                serialized_extern_kernel_nodes = self.extern_node_serializer(
                    self.extern_kernel_nodes
                )
                output_code_log.debug(
                    "Serialized Extern Kernel Nodes: \n%s",
                    serialized_extern_kernel_nodes,
                )

            # Directly return the file path with the compiled code
            return AotCodeCompiler.compile(
                self, code, serialized_extern_kernel_nodes, cuda=self.cuda
            )
        else:
            return self.compile_to_module().call
|
| 1914 |
+
|
| 1915 |
+
def get_output_names(self) -> List[str]:
|
| 1916 |
+
return [
|
| 1917 |
+
node.get_name()
|
| 1918 |
+
for node in self.graph_outputs
|
| 1919 |
+
if not isinstance(node, ir.NoneAsConstantBuffer)
|
| 1920 |
+
and not isinstance(node, ir.ShapeAsConstantBuffer)
|
| 1921 |
+
]
|
| 1922 |
+
|
| 1923 |
+
def is_unspec_arg(self, name: str) -> bool:
|
| 1924 |
+
# dynamo wraps unspec variable as 0d CPU tensor,
|
| 1925 |
+
# need to convert to scalar during codegen (triton only)
|
| 1926 |
+
return (
|
| 1927 |
+
name in self.graph_inputs.keys()
|
| 1928 |
+
and self.graph_inputs[name].get_numel() == 1
|
| 1929 |
+
and self.graph_inputs[name].get_device().type == "cpu"
|
| 1930 |
+
) or name in self.zero_dim_cpu_tensor_list
|
.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
import contextlib
from typing import Callable, List, TYPE_CHECKING


if TYPE_CHECKING:
    import torch

# Executed in the order they're registered
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []


@contextlib.contextmanager
def intermediate_hook(fn):
    """Register ``fn`` as an intermediate-value hook for the dynamic extent
    of the ``with`` block."""
    INTERMEDIATE_HOOKS.append(fn)
    try:
        yield
    finally:
        # Contexts unwind LIFO, so popping the tail removes this entry.
        INTERMEDIATE_HOOKS.pop()


def run_intermediate_hooks(name, val):
    """Invoke every registered hook with ``(name, val)``.

    The registry is swapped out while the hooks run, so a hook that itself
    triggers dispatch does not recurse into the hooks; the registry is
    restored afterwards even if a hook raises.
    """
    global INTERMEDIATE_HOOKS
    registered = INTERMEDIATE_HOOKS
    INTERMEDIATE_HOOKS = []
    try:
        for callback in registered:
            callback(name, val)
    finally:
        INTERMEDIATE_HOOKS = registered
|
.venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""This file implements the IndexPropagation ops handler, which wraps an
|
| 3 |
+
underlying handler to add a limited form of constant propagation, as well as
|
| 4 |
+
propagation of sympy expressions downstream of ops.index_expr calls.
|
| 5 |
+
|
| 6 |
+
For example, say we have the IR:
|
| 7 |
+
|
| 8 |
+
tmp0 = ops.index_expr(x, torch.int32)
|
| 9 |
+
tmp1 = ops.constant(2, torch.int32)
|
| 10 |
+
tmp2 = ops.mul(tmp0, tmp1)
|
| 11 |
+
tmp3 = ops.indirect_indexing(tmp2, x_size)
|
| 12 |
+
tmp4 = ops.load("buf0", tmp3)
|
| 13 |
+
|
| 14 |
+
The underlying handler would just see:
|
| 15 |
+
|
| 16 |
+
ops.load("buf0", x * 2)
|
| 17 |
+
|
| 18 |
+
This is limited by the set of operators handled in the sympy expression
|
| 19 |
+
printers. So simple operations like minimum and maximum cannot be translated to
|
| 20 |
+
SymPy expressions yet, despite sympy.Min and sympy.Max existing.
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
import itertools
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union
|
| 26 |
+
from typing_extensions import TypeAlias
|
| 27 |
+
|
| 28 |
+
import sympy
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
from torch._prims_common import dtype_to_type, is_integer_dtype
|
| 32 |
+
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
|
| 33 |
+
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
|
| 34 |
+
|
| 35 |
+
from .sizevars import evaluate_expr
|
| 36 |
+
from .utils import generate_assert
|
| 37 |
+
from .virtualized import V
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
_ExprType = Union[sympy.Expr, float, int, bool]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _is_constant(val: _ExprType):
    """Return True when ``val`` is a concrete number.

    Sympy objects count when ``is_number`` holds; otherwise only plain
    python scalars (int/float/bool) do.
    """
    if not isinstance(val, sympy.Basic):
        return isinstance(val, (int, float, bool))
    return val.is_number
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def upper_bound(val: _ExprType):
    """Upper end of ``val``'s value range; a constant is its own bound."""
    if isinstance(val, sympy.Expr):
        return bound_sympy(val).upper
    return val
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
class TypedExpr:
    """A SymPy (or python scalar) expression paired with its torch dtype."""

    expr: _ExprType
    dtype: torch.dtype

    def __post_init__(self):
        # Normalize constants to the python type corresponding to ``dtype``
        # so a constant-valued expr is stored as a plain scalar.
        if _is_constant(self.expr):
            caster = dtype_to_type(self.dtype)
            self.expr = caster(self.expr)

    def is_constant(self):
        """Whether the wrapped expression is a concrete number."""
        return _is_constant(self.expr)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class SymPyOps:
    """An ops handler whose IR values are sympy expressions (``TypedExpr``).

    Operations that cannot be represented symbolically either have no
    method here or return ``NotImplemented``, in which case the caller
    falls back to the wrapped handler.
    """

    @staticmethod
    def _promote(x: TypedExpr, y: TypedExpr) -> torch.dtype:
        # Binary ops follow torch's type-promotion rules.
        return torch.promote_types(x.dtype, y.dtype)

    @staticmethod
    def identity(value: Any) -> Any:
        return value

    @staticmethod
    def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def index_expr(value: Union[sympy.Expr, int], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def to_dtype(
        value: TypedExpr,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = False,
    ) -> TypedExpr:
        # Changing the dtype never alters the symbolic expression itself.
        return TypedExpr(value.expr, dtype)

    @staticmethod
    def abs(x: TypedExpr) -> TypedExpr:
        return TypedExpr(abs(x.expr), x.dtype)  # type: ignore[arg-type]

    @staticmethod
    def square(x: TypedExpr) -> TypedExpr:
        return TypedExpr(x.expr * x.expr, x.dtype)

    @staticmethod
    def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        return TypedExpr(x.expr + y.expr, SymPyOps._promote(x, y))

    @staticmethod
    def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        return TypedExpr(x.expr - y.expr, SymPyOps._promote(x, y))

    @staticmethod
    def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        return TypedExpr(x.expr * y.expr, SymPyOps._promote(x, y))

    @staticmethod
    def neg(x: TypedExpr) -> TypedExpr:
        return TypedExpr(-x.expr, x.dtype)

    @staticmethod
    def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        # Only integer floor-division can be expressed with FloorDiv.
        result_type = SymPyOps._promote(x, y)
        if not is_integer_dtype(result_type):
            return NotImplemented

        return TypedExpr(FloorDiv(x.expr, y.expr), result_type)

    @staticmethod
    def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        result_type = SymPyOps._promote(x, y)
        if not is_integer_dtype(result_type):
            return NotImplemented

        result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
        return TypedExpr(result_expr, result_type)

    @staticmethod
    def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        result_type = SymPyOps._promote(x, y)
        if not is_integer_dtype(result_type):
            return NotImplemented

        x_expr = sympy.sympify(x.expr)
        y_expr = sympy.sympify(y.expr)
        # In these cases, remainder in Python == remainder in C++, so this transformation
        # is sound
        if (
            x_expr.is_nonnegative is not None
            and x_expr.is_nonnegative == y_expr.is_positive
        ):
            result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
            return TypedExpr(result_expr, result_type)
        return NotImplemented

    @staticmethod
    def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        return TypedExpr(sympy.Min(x.expr, y.expr), SymPyOps._promote(x, y))

    @staticmethod
    def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        return TypedExpr(sympy.Max(x.expr, y.expr), SymPyOps._promote(x, y))
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@dataclass
class IndexPropVar:
    """Holds either a plain IR value or, when ``is_symbolic``, a TypedExpr."""

    value: Any  # Either an IR value, or TypedExpr if is_symbolic is true
    is_symbolic: bool = False

    def __post_init__(self):
        if self.is_symbolic:
            assert isinstance(
                self.value, TypedExpr
            ), "Symbolic IndexPropVar must contain a TypedExpr"

    @staticmethod
    def new_symbolic(expr: TypedExpr) -> "IndexPropVar":
        """Wrap ``expr`` as a symbolic variable."""
        return IndexPropVar(expr, is_symbolic=True)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]]
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class IndexPropagation:
|
| 189 |
+
"""Ops wrapper that tries to propagate constant and index_expr values through the computation.
|
| 190 |
+
|
| 191 |
+
This aims to maximize the compile time simplification possible, and convert
|
| 192 |
+
indirect indexing from arange into normal static indexing.
|
| 193 |
+
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
def __init__(
|
| 197 |
+
self,
|
| 198 |
+
inner: Any,
|
| 199 |
+
iter_ranges: Dict[sympy.Symbol, sympy.Expr],
|
| 200 |
+
indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr],
|
| 201 |
+
) -> None:
|
| 202 |
+
self._inner = inner
|
| 203 |
+
self.shape_env = V.graph.sizevars.shape_env
|
| 204 |
+
|
| 205 |
+
var_to_range = {
|
| 206 |
+
k: ValueRanges(0, upper_bound(v) - 1) for k, v in iter_ranges.items()
|
| 207 |
+
}
|
| 208 |
+
self.var_to_range = tuple(
|
| 209 |
+
itertools.chain(self.shape_env.var_to_range.items(), var_to_range.items())
|
| 210 |
+
)
|
| 211 |
+
# NOTE: this is intentionally kept as a reference so the caller can
|
| 212 |
+
# update it in-place
|
| 213 |
+
self.indirect_var_ranges = indirect_var_ranges
|
| 214 |
+
|
| 215 |
+
axioms = []
|
| 216 |
+
for x, s in iter_ranges.items():
|
| 217 |
+
axioms.append(0 <= x)
|
| 218 |
+
axioms.append(x < s)
|
| 219 |
+
self.axioms = tuple(axioms) + self.shape_env.get_axioms()
|
| 220 |
+
|
| 221 |
+
def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
|
| 222 |
+
# Construct a new constant/index_expr from the SymPy expression
|
| 223 |
+
if _is_constant(expr):
|
| 224 |
+
val = dtype_to_type(dtype)(expr)
|
| 225 |
+
return self._inner.constant(val, dtype)
|
| 226 |
+
return self._inner.index_expr(expr, dtype)
|
| 227 |
+
|
| 228 |
+
def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
|
| 229 |
+
if isinstance(a, (list, tuple)):
|
| 230 |
+
return tuple(self.unwrap(v) for v in a)
|
| 231 |
+
|
| 232 |
+
if not isinstance(a, IndexPropVar):
|
| 233 |
+
return a
|
| 234 |
+
|
| 235 |
+
# Prefer the sympy representation if possible
|
| 236 |
+
if a.is_symbolic:
|
| 237 |
+
return self.materialize_expr(a.value.expr, a.value.dtype)
|
| 238 |
+
|
| 239 |
+
return a.value
|
| 240 |
+
|
| 241 |
+
def wrap(self, a) -> IndexPropResult:
|
| 242 |
+
if isinstance(a, (list, tuple)):
|
| 243 |
+
return tuple(self.wrap(v) for v in a)
|
| 244 |
+
return IndexPropVar(a)
|
| 245 |
+
|
| 246 |
+
@overload
|
| 247 |
+
def fallback(
|
| 248 |
+
self,
|
| 249 |
+
name: Literal["indirect_indexing"],
|
| 250 |
+
args: Tuple[Any, ...],
|
| 251 |
+
kwargs: Dict[str, Any],
|
| 252 |
+
) -> IndexPropVar:
|
| 253 |
+
...
|
| 254 |
+
|
| 255 |
+
@overload
|
| 256 |
+
def fallback(
|
| 257 |
+
self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
|
| 258 |
+
) -> IndexPropResult:
|
| 259 |
+
...
|
| 260 |
+
|
| 261 |
+
def fallback(
|
| 262 |
+
self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
|
| 263 |
+
) -> IndexPropResult:
|
| 264 |
+
# Fallback to the wrapped handler
|
| 265 |
+
new_args = [self.unwrap(a) for a in args]
|
| 266 |
+
new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()}
|
| 267 |
+
return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))
|
| 268 |
+
|
| 269 |
+
def propagate_sympy(
|
| 270 |
+
self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
|
| 271 |
+
) -> IndexPropResult:
|
| 272 |
+
# Build a new SymPy expression from this ops call
|
| 273 |
+
def unwrap(a: Union[Any, IndexPropVar]) -> Any:
|
| 274 |
+
if not isinstance(a, IndexPropVar):
|
| 275 |
+
return a
|
| 276 |
+
return a.value
|
| 277 |
+
|
| 278 |
+
new_args = [unwrap(a) for a in args]
|
| 279 |
+
new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
|
| 280 |
+
new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs)
|
| 281 |
+
is_valid_expr = new_expr is not NotImplemented and (
|
| 282 |
+
# Inductor doesn't expect floating point in sympy expressions, but
|
| 283 |
+
# allow floating point constants to be propagated
|
| 284 |
+
new_expr.is_constant()
|
| 285 |
+
or new_expr.expr.is_integer
|
| 286 |
+
)
|
| 287 |
+
if not is_valid_expr:
|
| 288 |
+
return self.fallback(name, args, kwargs)
|
| 289 |
+
return IndexPropVar.new_symbolic(new_expr)
|
| 290 |
+
|
| 291 |
+
def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
|
| 292 |
+
def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
|
| 293 |
+
if not hasattr(SymPyOps, name):
|
| 294 |
+
return self.fallback(name, args, kwargs)
|
| 295 |
+
|
| 296 |
+
var_arguments = [
|
| 297 |
+
a
|
| 298 |
+
for a in itertools.chain(args, kwargs.values())
|
| 299 |
+
if isinstance(a, IndexPropVar)
|
| 300 |
+
]
|
| 301 |
+
if not all(v.is_symbolic for v in var_arguments):
|
| 302 |
+
return self.fallback(name, args, kwargs)
|
| 303 |
+
|
| 304 |
+
return self.propagate_sympy(name, args, kwargs)
|
| 305 |
+
|
| 306 |
+
return inner
|
| 307 |
+
|
| 308 |
+
def statically_true(self, e):
|
| 309 |
+
"""
|
| 310 |
+
Given some iter_ranges, return a function that given an expression, returns whether
|
| 311 |
+
it is true or false using value ranges, guard knowledge and runtime_asserts.
|
| 312 |
+
|
| 313 |
+
FIXME I think this may not be entirely right, as we may not be able to use all runtime_asserts
|
| 314 |
+
If this is an issue, just use guards in `self.axioms`.
|
| 315 |
+
|
| 316 |
+
The proper way of handling this would be to have a global shape_env that adds
|
| 317 |
+
runtime_asserts as they happen in the code. Then, it shuld be used in SimplifyIndexing
|
| 318 |
+
to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bounds also
|
| 319 |
+
for indirect_indexing
|
| 320 |
+
"""
|
| 321 |
+
var_to_range = (
|
| 322 |
+
*self.var_to_range,
|
| 323 |
+
*(
|
| 324 |
+
(k, ValueRanges(0, upper_bound(v) - 1))
|
| 325 |
+
for k, v in self.indirect_var_ranges.items()
|
| 326 |
+
),
|
| 327 |
+
)
|
| 328 |
+
return evaluate_expr(self.shape_env, e, self.axioms, var_to_range)
|
| 329 |
+
|
| 330 |
+
def indirect_indexing(
|
| 331 |
+
self,
|
| 332 |
+
index: Union[Any, IndexPropVar],
|
| 333 |
+
size: Any,
|
| 334 |
+
check: bool = True,
|
| 335 |
+
wrap_neg=True,
|
| 336 |
+
) -> Any:
|
| 337 |
+
if isinstance(index, IndexPropVar) and index.is_symbolic:
|
| 338 |
+
# If we find something we can convert into a direct indexing we do so
|
| 339 |
+
# We still need to (perhaps) wrap the expression and add bound checks
|
| 340 |
+
# We want to do this "constant folding", as we don't allow to fuse
|
| 341 |
+
# kernels into indirect indexing
|
| 342 |
+
|
| 343 |
+
expr = sympy.sympify(index.value.expr)
|
| 344 |
+
|
| 345 |
+
# TODO Perhaps move this logic to the simplify indexing pass
|
| 346 |
+
def wrap_expr(expr):
|
| 347 |
+
# Positive, negative, mixed
|
| 348 |
+
if self.statically_true(0 <= expr):
|
| 349 |
+
return expr
|
| 350 |
+
elif self.statically_true(expr < 0):
|
| 351 |
+
return expr + size
|
| 352 |
+
else:
|
| 353 |
+
return Where(expr < 0, expr + size, expr)
|
| 354 |
+
|
| 355 |
+
# Sometimes it's easier to prove 0 <= expr than the weaker -size <= expr
|
| 356 |
+
can_prove_lower = self.statically_true(0 <= expr) or self.statically_true(
|
| 357 |
+
-size <= expr
|
| 358 |
+
)
|
| 359 |
+
can_prove_upper = self.statically_true(expr < size)
|
| 360 |
+
if wrap_neg:
|
| 361 |
+
expr = wrap_expr(expr)
|
| 362 |
+
if generate_assert(check):
|
| 363 |
+
self.fallback(
|
| 364 |
+
"check_bounds",
|
| 365 |
+
(expr, size),
|
| 366 |
+
dict(lower=not can_prove_lower, upper=not can_prove_upper),
|
| 367 |
+
)
|
| 368 |
+
return expr
|
| 369 |
+
|
| 370 |
+
indirect_var = self.fallback(
|
| 371 |
+
"indirect_indexing", (index, size, check, wrap_neg), {}
|
| 372 |
+
).value
|
| 373 |
+
return indirect_var
|
.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Optional, Sequence
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import _prims, Tensor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
log = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def make_prim(
|
| 15 |
+
schema: str,
|
| 16 |
+
impl_aten,
|
| 17 |
+
return_type=_prims.RETURN_TYPE.NEW,
|
| 18 |
+
doc: str = "",
|
| 19 |
+
tags: Optional[Sequence[torch.Tag]] = None,
|
| 20 |
+
):
|
| 21 |
+
if isinstance(return_type, tuple):
|
| 22 |
+
|
| 23 |
+
def meta(*args, **kwargs):
|
| 24 |
+
return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs))
|
| 25 |
+
|
| 26 |
+
else:
|
| 27 |
+
|
| 28 |
+
def meta(*args, **kwargs):
|
| 29 |
+
return _prims.TensorMeta(impl_aten(*args, **kwargs))
|
| 30 |
+
|
| 31 |
+
return _prims._make_prim(
|
| 32 |
+
schema=schema,
|
| 33 |
+
return_type=return_type,
|
| 34 |
+
meta=meta,
|
| 35 |
+
impl_aten=impl_aten,
|
| 36 |
+
doc=doc,
|
| 37 |
+
tags=tags,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def eager_force_stride(input_tensor: Tensor, stride) -> Tensor:
|
| 42 |
+
if input_tensor.stride() == stride:
|
| 43 |
+
return input_tensor
|
| 44 |
+
new_tensor = input_tensor.clone().as_strided(
|
| 45 |
+
input_tensor.shape,
|
| 46 |
+
stride,
|
| 47 |
+
)
|
| 48 |
+
new_tensor.copy_(input_tensor)
|
| 49 |
+
return new_tensor
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Custom prims used for handling randomness
|
| 53 |
+
seed = make_prim(
|
| 54 |
+
"inductor_seed(Device device) -> Tensor",
|
| 55 |
+
lambda device: torch.randint(2**63 - 1, [], device=device),
|
| 56 |
+
doc="create a fresh seed (one per call) for use with inductor_rand",
|
| 57 |
+
tags=(torch.Tag.nondeterministic_seeded,),
|
| 58 |
+
)
|
| 59 |
+
seeds = make_prim(
|
| 60 |
+
"inductor_seeds(int count, Device device) -> Tensor",
|
| 61 |
+
lambda count, device: torch.randint(2**63 - 1, [count], device=device),
|
| 62 |
+
doc="Horizontal fusion of many inductor_seed() calls",
|
| 63 |
+
tags=(torch.Tag.nondeterministic_seeded,),
|
| 64 |
+
)
|
| 65 |
+
lookup_seed = make_prim(
|
| 66 |
+
# if inductor_lookup_seed changes, update partitioners.py
|
| 67 |
+
"inductor_lookup_seed(Tensor seeds, int index) -> Tensor",
|
| 68 |
+
lambda seeds, index: seeds[index],
|
| 69 |
+
doc="Extract a single seed from the result of inductor_seeds()",
|
| 70 |
+
)
|
| 71 |
+
random = make_prim(
|
| 72 |
+
"inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor",
|
| 73 |
+
lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device),
|
| 74 |
+
doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused",
|
| 75 |
+
)
|
| 76 |
+
randint = make_prim(
|
| 77 |
+
"inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor",
|
| 78 |
+
lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device),
|
| 79 |
+
doc="torch.randint() using backend-specific RNG that can be fused",
|
| 80 |
+
)
|
| 81 |
+
force_stride_order = make_prim(
|
| 82 |
+
"inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor",
|
| 83 |
+
eager_force_stride,
|
| 84 |
+
doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise",
|
| 85 |
+
)
|
| 86 |
+
_unsafe_index_put_ = make_prim(
|
| 87 |
+
"_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)",
|
| 88 |
+
lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_(
|
| 89 |
+
self, indices, values, accumulate
|
| 90 |
+
),
|
| 91 |
+
doc="Unsafe index_put_ (doesn't issue device asserts)",
|
| 92 |
+
)
|
| 93 |
+
fma = make_prim(
|
| 94 |
+
"fma(Tensor a, Tensor b, Tensor c) -> Tensor",
|
| 95 |
+
lambda a, b, c: (a * b) + c,
|
| 96 |
+
doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication",
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _low_memory_max_pool2d_with_offsets_aten(
|
| 101 |
+
self,
|
| 102 |
+
kernel_size,
|
| 103 |
+
stride,
|
| 104 |
+
padding,
|
| 105 |
+
dilation,
|
| 106 |
+
ceil_mode,
|
| 107 |
+
):
|
| 108 |
+
vals, indices = torch.ops.aten.max_pool2d_with_indices(
|
| 109 |
+
self, kernel_size, stride, padding, dilation, ceil_mode
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
input_width = self.shape[-1]
|
| 113 |
+
kernel_width = kernel_size[1]
|
| 114 |
+
|
| 115 |
+
bh_shape = [1] * self.ndim
|
| 116 |
+
bh_shape[-2] = -1
|
| 117 |
+
bh = torch.arange(indices.shape[-2], dtype=torch.int64, device=self.device).view(
|
| 118 |
+
bh_shape
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
bw_shape = [1] * self.ndim
|
| 122 |
+
bw_shape[-1] = -1
|
| 123 |
+
bw = torch.arange(indices.shape[-1], dtype=torch.int64, device=self.device).view(
|
| 124 |
+
bw_shape
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
hbase = bh * stride[0] - padding[0]
|
| 128 |
+
wbase = bw * stride[1] - padding[1]
|
| 129 |
+
|
| 130 |
+
ih = indices // input_width
|
| 131 |
+
iw = indices - (ih * input_width)
|
| 132 |
+
|
| 133 |
+
h_inc = ih - hbase
|
| 134 |
+
w_inc = iw - wbase
|
| 135 |
+
|
| 136 |
+
offsets = h_inc * kernel_width + w_inc
|
| 137 |
+
|
| 138 |
+
return vals, offsets.to(torch.int8)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _low_memory_max_pool2d_offsets_to_indices_aten(
|
| 142 |
+
offsets, kernel_width, input_width, stride, padding
|
| 143 |
+
):
|
| 144 |
+
offsets = offsets.to(torch.int64)
|
| 145 |
+
h_inc = offsets // kernel_width
|
| 146 |
+
w_inc = offsets - (h_inc * kernel_width)
|
| 147 |
+
|
| 148 |
+
bh_shape = [1] * offsets.ndim
|
| 149 |
+
bh_shape[-2] = -1
|
| 150 |
+
bh = torch.arange(offsets.shape[-2], dtype=torch.int64, device=offsets.device).view(
|
| 151 |
+
bh_shape
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
bw_shape = [1] * offsets.ndim
|
| 155 |
+
bw_shape[-1] = -1
|
| 156 |
+
bw = torch.arange(offsets.shape[-1], dtype=torch.int64, device=offsets.device).view(
|
| 157 |
+
bw_shape
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
hbase = bh * stride[0] - padding[0]
|
| 161 |
+
wbase = bw * stride[1] - padding[1]
|
| 162 |
+
|
| 163 |
+
ih = hbase + h_inc
|
| 164 |
+
iw = wbase + w_inc
|
| 165 |
+
return ih * input_width + iw
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
_low_memory_max_pool2d_with_offsets = make_prim(
|
| 169 |
+
"_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950
|
| 170 |
+
_low_memory_max_pool2d_with_offsets_aten,
|
| 171 |
+
return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW),
|
| 172 |
+
doc="Instead of returning indices, returns indices offsets.",
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
_low_memory_max_pool2d_offsets_to_indices = make_prim(
|
| 176 |
+
"_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor", # noqa: B950
|
| 177 |
+
_low_memory_max_pool2d_offsets_to_indices_aten,
|
| 178 |
+
doc="Convert small int offsets to regular indices.",
|
| 179 |
+
)
|
.venv/lib/python3.11/site-packages/torch/_inductor/ir.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/jagged_lowerings.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
from typing import List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import sympy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .ir import Pointwise, TensorBox
|
| 10 |
+
from .lowering import fallback_handler, is_integer_type, register_lowering
|
| 11 |
+
from .virtualized import ops
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# pyre-ignore[2,3]
|
| 15 |
+
def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len):
|
| 16 |
+
# jagged_len + 1 is used as the upper bound,
|
| 17 |
+
# because the last sequence length may be zero
|
| 18 |
+
begin_idx = ops.indirect_indexing(
|
| 19 |
+
offsets_loader([batch_idx]),
|
| 20 |
+
jagged_len + 1,
|
| 21 |
+
)
|
| 22 |
+
end_idx = offsets_loader([batch_idx + 1])
|
| 23 |
+
jagged_idx = begin_idx + seq_idx
|
| 24 |
+
return jagged_idx, end_idx
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_inverse_offsets(
|
| 28 |
+
offsets: TensorBox,
|
| 29 |
+
jagged_len: Union[int, sympy.Expr],
|
| 30 |
+
realize: bool = True,
|
| 31 |
+
) -> TensorBox:
|
| 32 |
+
"""
|
| 33 |
+
Returns "inverse_offsets" - the inverse of the offsets array.
|
| 34 |
+
offsets maps batch index (dense) to jagged index (i.e. offset into jagged tensor).
|
| 35 |
+
inverse_offsets maps jagged index to batch index.
|
| 36 |
+
|
| 37 |
+
e.g. for offsets [0, 3, 4, 9, 10] this will return
|
| 38 |
+
inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3]
|
| 39 |
+
|
| 40 |
+
For the given offsets, the computed inverse_offsets are cached
|
| 41 |
+
on the first call and reused in the further calls.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
if hasattr(offsets, "inverse_offsets"):
|
| 45 |
+
# inverse_offsets are already computed
|
| 46 |
+
# for these offsets: can reuse
|
| 47 |
+
return offsets.inverse_offsets
|
| 48 |
+
|
| 49 |
+
# ops.bucketize takes offsets.get_name() which doesn't exist on Pointwise
|
| 50 |
+
# kernels, i.e. we need to realize it before using. In other words, we need
|
| 51 |
+
# offsets to be in global memory so that we can binary search over the
|
| 52 |
+
# entire tensor
|
| 53 |
+
offsets.realize()
|
| 54 |
+
device: torch.device = offsets.get_device()
|
| 55 |
+
dtype: torch.dtype = offsets.get_dtype()
|
| 56 |
+
|
| 57 |
+
# pyre-ignore[2,3]
|
| 58 |
+
def inner_fn(index):
|
| 59 |
+
idx = index[0]
|
| 60 |
+
bucket = ops.bucketize(
|
| 61 |
+
values=ops.index_expr(idx, dtype),
|
| 62 |
+
offsets_name=offsets.get_name(),
|
| 63 |
+
offsets_size=offsets.get_size()[0],
|
| 64 |
+
indexing_dtype=dtype,
|
| 65 |
+
right=True,
|
| 66 |
+
)
|
| 67 |
+
# ops.bucketize above returns 1-based bucket indices,
|
| 68 |
+
# but we need 0-based, hence we subtract 1 from batch
|
| 69 |
+
return bucket - 1
|
| 70 |
+
|
| 71 |
+
inverse_offsets = Pointwise.create(
|
| 72 |
+
device=device,
|
| 73 |
+
dtype=dtype,
|
| 74 |
+
inner_fn=inner_fn,
|
| 75 |
+
ranges=[jagged_len],
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
if realize:
|
| 79 |
+
# "freeze" the node so that it doesn't get inlined downstream.
|
| 80 |
+
inverse_offsets.realize()
|
| 81 |
+
|
| 82 |
+
# cache inverse_offsets for further reuse
|
| 83 |
+
offsets.inverse_offsets = inverse_offsets # type: ignore[attr-defined]
|
| 84 |
+
|
| 85 |
+
return inverse_offsets
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def jagged_idx_to_dense_idx(
|
| 89 |
+
jagged_idx, # pyre-ignore[2]
|
| 90 |
+
inverse_offsets_loader, # pyre-ignore[2]
|
| 91 |
+
offsets_loader, # pyre-ignore[2]
|
| 92 |
+
batch_size: Union[int, sympy.Expr],
|
| 93 |
+
max_seq_len: Union[int, sympy.Expr],
|
| 94 |
+
offsets_dtype: torch.dtype,
|
| 95 |
+
) -> Tuple[sympy.Expr, sympy.Expr]:
|
| 96 |
+
batch_idx = ops.indirect_indexing(
|
| 97 |
+
inverse_offsets_loader([jagged_idx]),
|
| 98 |
+
batch_size + 1,
|
| 99 |
+
)
|
| 100 |
+
batch_start = offsets_loader([batch_idx])
|
| 101 |
+
seq = ops.index_expr(jagged_idx, offsets_dtype) - batch_start
|
| 102 |
+
# check=False because there may be sequences longer than max_seq_len
|
| 103 |
+
seq_idx = ops.indirect_indexing(seq, max_seq_len, check=False)
|
| 104 |
+
return batch_idx, seq_idx
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def register_jagged_ops():
|
| 108 |
+
# pyre-ignore[56]
|
| 109 |
+
@register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
|
| 110 |
+
def _jagged_to_padded_dense_forward(
|
| 111 |
+
jagged_values: TensorBox,
|
| 112 |
+
jagged_offsets: List[TensorBox],
|
| 113 |
+
max_lengths: List[int], # list of ints/SymInts
|
| 114 |
+
padding_value: float = 0.0,
|
| 115 |
+
) -> TensorBox:
|
| 116 |
+
device = jagged_values.get_device()
|
| 117 |
+
dtype = jagged_values.get_dtype()
|
| 118 |
+
|
| 119 |
+
jagged_values_size = jagged_values.get_size()
|
| 120 |
+
|
| 121 |
+
# only handle the common case of a single jagged dimension
|
| 122 |
+
if (
|
| 123 |
+
len(jagged_offsets) != 1
|
| 124 |
+
or device.type != "cuda"
|
| 125 |
+
or device != jagged_offsets[0].get_device()
|
| 126 |
+
or len(jagged_values_size) != 2
|
| 127 |
+
or len(jagged_offsets[0].get_size()) != 1
|
| 128 |
+
or len(max_lengths) != len(jagged_offsets)
|
| 129 |
+
or not is_integer_type(jagged_offsets[0])
|
| 130 |
+
):
|
| 131 |
+
return fallback_handler(
|
| 132 |
+
torch.ops.aten._jagged_to_padded_dense_forward.default,
|
| 133 |
+
add_to_fallback_set=False,
|
| 134 |
+
)(
|
| 135 |
+
jagged_values,
|
| 136 |
+
jagged_offsets,
|
| 137 |
+
max_lengths,
|
| 138 |
+
padding_value,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
offsets: TensorBox = jagged_offsets[0]
|
| 142 |
+
offsets_len = offsets.get_size()[0]
|
| 143 |
+
offsets_dtype = offsets.get_dtype()
|
| 144 |
+
batch_size = offsets_len - 1
|
| 145 |
+
max_seq_len = max_lengths[0]
|
| 146 |
+
embedding_len = jagged_values_size[1]
|
| 147 |
+
jagged_len = jagged_values_size[0]
|
| 148 |
+
|
| 149 |
+
output_size = [batch_size, max_seq_len, embedding_len]
|
| 150 |
+
|
| 151 |
+
values_loader = jagged_values.make_loader()
|
| 152 |
+
offsets_loader = offsets.make_loader()
|
| 153 |
+
|
| 154 |
+
# pyre-ignore[2,3,53]
|
| 155 |
+
def inner_fn(index):
|
| 156 |
+
# dense tensor size: [B, N, D]
|
| 157 |
+
batch_idx, seq_idx, emb_idx = index
|
| 158 |
+
jagged_idx, end_idx = dense_idx_to_jagged_idx(
|
| 159 |
+
batch_idx=batch_idx,
|
| 160 |
+
seq_idx=seq_idx,
|
| 161 |
+
offsets_loader=offsets_loader,
|
| 162 |
+
jagged_len=jagged_len,
|
| 163 |
+
)
|
| 164 |
+
return ops.masked(
|
| 165 |
+
ops.lt(
|
| 166 |
+
ops.index_expr(jagged_idx, offsets_dtype),
|
| 167 |
+
end_idx,
|
| 168 |
+
),
|
| 169 |
+
lambda: values_loader([jagged_idx, emb_idx]),
|
| 170 |
+
padding_value,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return Pointwise.create(
|
| 174 |
+
device=device,
|
| 175 |
+
dtype=dtype,
|
| 176 |
+
inner_fn=inner_fn,
|
| 177 |
+
ranges=output_size,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
def _dense_to_jagged_forward_impl(
|
| 181 |
+
fallback_op, # pyre-ignore[2]
|
| 182 |
+
dense: TensorBox,
|
| 183 |
+
jagged_offsets: List[TensorBox],
|
| 184 |
+
jagged_len: Optional[int] = None,
|
| 185 |
+
) -> TensorBox:
|
| 186 |
+
device = dense.get_device()
|
| 187 |
+
dtype = dense.get_dtype()
|
| 188 |
+
|
| 189 |
+
dense_size = dense.get_size()
|
| 190 |
+
|
| 191 |
+
# only handle the common case of a single jagged dimension
|
| 192 |
+
if (
|
| 193 |
+
len(jagged_offsets) != 1
|
| 194 |
+
or device.type != "cuda"
|
| 195 |
+
or device != jagged_offsets[0].get_device()
|
| 196 |
+
or len(jagged_offsets[0].get_size()) != 1
|
| 197 |
+
or len(dense_size) != 3
|
| 198 |
+
or jagged_len is None
|
| 199 |
+
or not is_integer_type(jagged_offsets[0])
|
| 200 |
+
):
|
| 201 |
+
return fallback_handler(fallback_op, add_to_fallback_set=False)(
|
| 202 |
+
dense,
|
| 203 |
+
jagged_offsets,
|
| 204 |
+
jagged_len,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
offsets: TensorBox = jagged_offsets[0]
|
| 208 |
+
offsets_dtype = offsets.get_dtype()
|
| 209 |
+
batch_size = dense_size[0]
|
| 210 |
+
max_seq_len = dense_size[1]
|
| 211 |
+
embedding_len = dense_size[-1]
|
| 212 |
+
|
| 213 |
+
output_size = [jagged_len, embedding_len]
|
| 214 |
+
|
| 215 |
+
dense_loader = dense.make_loader()
|
| 216 |
+
offsets_loader = offsets.make_loader()
|
| 217 |
+
|
| 218 |
+
inverse_offsets = get_inverse_offsets(
|
| 219 |
+
offsets=offsets,
|
| 220 |
+
jagged_len=jagged_len,
|
| 221 |
+
)
|
| 222 |
+
inverse_offsets_loader = inverse_offsets.make_loader()
|
| 223 |
+
|
| 224 |
+
# pyre-ignore[2,3,53]
|
| 225 |
+
def inner_fn(index):
|
| 226 |
+
# jagged tensor size: [sum_B(N_B), D]
|
| 227 |
+
jagged_idx, emb_idx = index
|
| 228 |
+
batch_idx, seq_idx = jagged_idx_to_dense_idx(
|
| 229 |
+
jagged_idx=jagged_idx,
|
| 230 |
+
offsets_loader=offsets_loader,
|
| 231 |
+
inverse_offsets_loader=inverse_offsets_loader,
|
| 232 |
+
batch_size=batch_size,
|
| 233 |
+
max_seq_len=max_seq_len,
|
| 234 |
+
offsets_dtype=offsets_dtype,
|
| 235 |
+
)
|
| 236 |
+
return ops.masked(
|
| 237 |
+
ops.lt(
|
| 238 |
+
ops.index_expr(seq_idx, offsets_dtype),
|
| 239 |
+
ops.index_expr(max_seq_len, offsets_dtype),
|
| 240 |
+
),
|
| 241 |
+
lambda: dense_loader([batch_idx, seq_idx, emb_idx]),
|
| 242 |
+
0.0, # jagged sequence longer than max_seq_len
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
return Pointwise.create(
|
| 246 |
+
device=device,
|
| 247 |
+
dtype=dtype,
|
| 248 |
+
inner_fn=inner_fn,
|
| 249 |
+
ranges=output_size,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# pyre-ignore[56]
|
| 253 |
+
@register_lowering(torch.ops.aten._padded_dense_to_jagged_forward)
|
| 254 |
+
def _dense_to_jagged_forward(
|
| 255 |
+
dense: TensorBox,
|
| 256 |
+
jagged_offsets: List[TensorBox],
|
| 257 |
+
jagged_len: Optional[int] = None,
|
| 258 |
+
) -> TensorBox:
|
| 259 |
+
return _dense_to_jagged_forward_impl(
|
| 260 |
+
fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default,
|
| 261 |
+
dense=dense,
|
| 262 |
+
jagged_offsets=jagged_offsets,
|
| 263 |
+
jagged_len=jagged_len,
|
| 264 |
+
)
|
.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/metrics.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import csv
|
| 5 |
+
import dataclasses
|
| 6 |
+
import inspect
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from functools import lru_cache
|
| 11 |
+
from typing import Dict, List, Set, Tuple, TYPE_CHECKING
|
| 12 |
+
|
| 13 |
+
from torch._inductor import config
|
| 14 |
+
from torch._inductor.utils import get_benchmark_name
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Prevent circular import
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
from torch._inductor.scheduler import BaseSchedulerNode
|
| 20 |
+
|
| 21 |
+
# counter for tracking how many kernels have been generated
|
| 22 |
+
generated_kernel_count = 0
|
| 23 |
+
generated_cpp_vec_kernel_count = 0
|
| 24 |
+
num_bytes_accessed = 0
|
| 25 |
+
nodes_num_elem: List[
|
| 26 |
+
Tuple[
|
| 27 |
+
BaseSchedulerNode,
|
| 28 |
+
int,
|
| 29 |
+
]
|
| 30 |
+
] = []
|
| 31 |
+
node_runtimes: List[Tuple[BaseSchedulerNode, float]] = []
|
| 32 |
+
|
| 33 |
+
# counters for tracking fusions
|
| 34 |
+
ir_nodes_pre_fusion = 0
|
| 35 |
+
|
| 36 |
+
# counters for tracking to_dtype inserted
|
| 37 |
+
cpp_to_dtype_count = 0
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclasses.dataclass
|
| 41 |
+
class CppOuterLoopFusedCount:
|
| 42 |
+
inner_kernel_number: int
|
| 43 |
+
local_buffer_number: int = 0
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# The length counts the number of outer loop fusions.
|
| 47 |
+
cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = []
|
| 48 |
+
|
| 49 |
+
num_comprehensive_padding = 0
|
| 50 |
+
num_matches_for_scatter_upon_const_tensor = 0
|
| 51 |
+
|
| 52 |
+
num_loop_reordering = 0
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# reset all counters
|
| 56 |
+
def reset():
|
| 57 |
+
global generated_kernel_count
|
| 58 |
+
global generated_cpp_vec_kernel_count
|
| 59 |
+
global num_bytes_accessed, nodes_num_elem
|
| 60 |
+
global ir_nodes_pre_fusion
|
| 61 |
+
global cpp_to_dtype_count
|
| 62 |
+
global cpp_outer_loop_fused_inner_counts
|
| 63 |
+
global num_comprehensive_padding
|
| 64 |
+
global num_matches_for_scatter_upon_const_tensor
|
| 65 |
+
global num_loop_reordering
|
| 66 |
+
|
| 67 |
+
generated_kernel_count = 0
|
| 68 |
+
generated_cpp_vec_kernel_count = 0
|
| 69 |
+
num_bytes_accessed = 0
|
| 70 |
+
nodes_num_elem.clear()
|
| 71 |
+
node_runtimes.clear()
|
| 72 |
+
ir_nodes_pre_fusion = 0
|
| 73 |
+
cpp_to_dtype_count = 0
|
| 74 |
+
cpp_outer_loop_fused_inner_counts.clear()
|
| 75 |
+
num_comprehensive_padding = 0
|
| 76 |
+
num_matches_for_scatter_upon_const_tensor = 0
|
| 77 |
+
num_loop_reordering = 0
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclass
|
| 81 |
+
class CachedMetricsDeltas:
|
| 82 |
+
"""
|
| 83 |
+
The subset of metrics we want update across cache hits, e.g., the
|
| 84 |
+
FxGraphCache.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
generated_kernel_count: int
|
| 88 |
+
generated_cpp_vec_kernel_count: int
|
| 89 |
+
ir_nodes_pre_fusion: int
|
| 90 |
+
cpp_to_dtype_count: int
|
| 91 |
+
num_bytes_accessed: int
|
| 92 |
+
num_matches_for_scatter_upon_const_tensor: int
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_metric_fields():
|
| 96 |
+
return [field.name for field in dataclasses.fields(CachedMetricsDeltas)]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class CachedMetricsHelper:
|
| 100 |
+
"""
|
| 101 |
+
A helper class to help calculate and apply counter deltas for those
|
| 102 |
+
metrics we want to save with cache entries (e.g., FxGraphCache) and
|
| 103 |
+
apply on a cache hit.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(self) -> None:
|
| 107 |
+
self.cached_metrics = {}
|
| 108 |
+
for metric in get_metric_fields():
|
| 109 |
+
self.cached_metrics[metric] = globals()[metric]
|
| 110 |
+
|
| 111 |
+
def get_deltas(self) -> CachedMetricsDeltas:
|
| 112 |
+
delta_metrics = {}
|
| 113 |
+
for metric in get_metric_fields():
|
| 114 |
+
delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric]
|
| 115 |
+
|
| 116 |
+
return CachedMetricsDeltas(**delta_metrics)
|
| 117 |
+
|
| 118 |
+
@staticmethod
|
| 119 |
+
def apply_deltas(delta: CachedMetricsDeltas):
|
| 120 |
+
for metric in get_metric_fields():
|
| 121 |
+
globals()[metric] += getattr(delta, metric)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@dataclass
|
| 128 |
+
class MetricTable:
|
| 129 |
+
table_name: str
|
| 130 |
+
column_names: List[str]
|
| 131 |
+
|
| 132 |
+
num_rows_added: int = 0
|
| 133 |
+
|
| 134 |
+
def add_row(self, row_fn):
|
| 135 |
+
if self.table_name not in enabled_metric_tables():
|
| 136 |
+
return
|
| 137 |
+
|
| 138 |
+
row_dict = row_fn()
|
| 139 |
+
assert len(self.column_names) == len(
|
| 140 |
+
row_dict
|
| 141 |
+
), f"{len(self.column_names)} v.s. {len(row_dict)}"
|
| 142 |
+
assert set(self.column_names) == set(
|
| 143 |
+
row_dict.keys()
|
| 144 |
+
), f"{set(self.column_names)} v.s. {set(row_dict.keys())}"
|
| 145 |
+
|
| 146 |
+
row = [
|
| 147 |
+
get_benchmark_name(),
|
| 148 |
+
]
|
| 149 |
+
row += [row_dict[column_name] for column_name in self.column_names]
|
| 150 |
+
self._write_row(row)
|
| 151 |
+
|
| 152 |
+
def output_filename(self):
|
| 153 |
+
return f"metric_table_{self.table_name}.csv"
|
| 154 |
+
|
| 155 |
+
def write_header(self):
|
| 156 |
+
filename = self.output_filename()
|
| 157 |
+
with open(filename, "w") as fd:
|
| 158 |
+
writer = csv.writer(fd, lineterminator="\n")
|
| 159 |
+
writer.writerow(["model_name"] + self.column_names)
|
| 160 |
+
|
| 161 |
+
def _write_row(self, row):
|
| 162 |
+
filename = self.output_filename()
|
| 163 |
+
if self.num_rows_added == 0 and not os.path.exists(filename):
|
| 164 |
+
self.write_header()
|
| 165 |
+
|
| 166 |
+
self.num_rows_added += 1
|
| 167 |
+
|
| 168 |
+
for idx, orig_val in enumerate(row):
|
| 169 |
+
if isinstance(orig_val, float):
|
| 170 |
+
new_val = f"{orig_val:.6f}"
|
| 171 |
+
elif orig_val is None:
|
| 172 |
+
new_val = ""
|
| 173 |
+
else:
|
| 174 |
+
new_val = orig_val
|
| 175 |
+
row[idx] = new_val
|
| 176 |
+
|
| 177 |
+
with open(filename, "a") as fd:
|
| 178 |
+
writer = csv.writer(fd, lineterminator="\n")
|
| 179 |
+
writer.writerow(row)
|
| 180 |
+
|
| 181 |
+
@staticmethod
|
| 182 |
+
def register_table(name, column_names):
|
| 183 |
+
table = MetricTable(name, column_names)
|
| 184 |
+
REGISTERED_METRIC_TABLES[name] = table
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
MetricTable.register_table(
|
| 188 |
+
"slow_fusion",
|
| 189 |
+
[
|
| 190 |
+
"kernel1_path",
|
| 191 |
+
"kernel1_latency",
|
| 192 |
+
"kernel2_path",
|
| 193 |
+
"kernel2_latency",
|
| 194 |
+
"fused_kernel_path",
|
| 195 |
+
"fused_kernel_latency",
|
| 196 |
+
"slow_down_ratio",
|
| 197 |
+
],
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# track the fusion statistics for each graph
|
| 201 |
+
MetricTable.register_table(
|
| 202 |
+
"graph_stats",
|
| 203 |
+
[
|
| 204 |
+
"graph_id",
|
| 205 |
+
"num_nodes_before_fusion",
|
| 206 |
+
"num_nodes_after_fusion",
|
| 207 |
+
],
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# track the perf difference between persistent reduction and non-persistent
|
| 211 |
+
# reductions
|
| 212 |
+
MetricTable.register_table(
|
| 213 |
+
"persistent_red_perf",
|
| 214 |
+
[
|
| 215 |
+
"kernel1_name",
|
| 216 |
+
"kernel2_name",
|
| 217 |
+
"kernel1_latency",
|
| 218 |
+
"kernel2_latency",
|
| 219 |
+
"size_hints",
|
| 220 |
+
"reduction_hint",
|
| 221 |
+
"speedup",
|
| 222 |
+
],
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# Log the fusion failures due to indexing mismatch
|
| 226 |
+
MetricTable.register_table(
|
| 227 |
+
"fusion_failure_due_to_indexing_mismatch",
|
| 228 |
+
[
|
| 229 |
+
"pre_grad_graph_id",
|
| 230 |
+
"post_grad_graph_id",
|
| 231 |
+
"node1_name",
|
| 232 |
+
"node2_name",
|
| 233 |
+
"node1_debug_str",
|
| 234 |
+
"node2_debug_str",
|
| 235 |
+
"common_buffer_names",
|
| 236 |
+
"failure_reason",
|
| 237 |
+
],
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint
|
| 241 |
+
MetricTable.register_table(
|
| 242 |
+
"kernel_metadata",
|
| 243 |
+
[
|
| 244 |
+
"kernel_name",
|
| 245 |
+
"kernel_path",
|
| 246 |
+
"kernel_category", # pointwise/reduction/foreach etc.
|
| 247 |
+
"size_hints",
|
| 248 |
+
"reduction_hint",
|
| 249 |
+
"line_of_code",
|
| 250 |
+
"num_load",
|
| 251 |
+
"num_store",
|
| 252 |
+
"num_for_loop",
|
| 253 |
+
"num_atomic_add",
|
| 254 |
+
"num_args",
|
| 255 |
+
# xyz numel can be different to size_hints since size_hints are rounded
|
| 256 |
+
# up to the nearest power of 2.
|
| 257 |
+
# Inductor kernel will burn in the xyz numel in kernel code for static
|
| 258 |
+
# shape kernels.
|
| 259 |
+
# Logging them will be helpful to find unaligned shape for reduction
|
| 260 |
+
"xnumel",
|
| 261 |
+
"ynumel",
|
| 262 |
+
"rnumel",
|
| 263 |
+
"kernel_args_num_gb",
|
| 264 |
+
],
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _parse_kernel_fn_code(kernel_module_code):
|
| 269 |
+
"""
|
| 270 |
+
The kernel_module_code is the python module that contains kernel function code.
|
| 271 |
+
kernel function is the proper triton kernel function annotated with
|
| 272 |
+
@triton.jit
|
| 273 |
+
"""
|
| 274 |
+
from .codecache import PyCodeCache
|
| 275 |
+
from .wrapper_benchmark import get_triton_kernel
|
| 276 |
+
|
| 277 |
+
mod = PyCodeCache.load(kernel_module_code)
|
| 278 |
+
kernel = get_triton_kernel(mod)
|
| 279 |
+
# kernel is a CachingAutotune; kernel.fn is the JITFunction;
|
| 280 |
+
# kernel.fn.fn is the function being decorate by triton.jit
|
| 281 |
+
return inspect.getsource(kernel.fn.fn)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _parse_kernel_line_of_code(proper_kernel_fn_code):
|
| 285 |
+
"""
|
| 286 |
+
Return the line of code for the kernel excluding the decorators.
|
| 287 |
+
"""
|
| 288 |
+
return len(proper_kernel_fn_code.splitlines())
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def _parse_size_hints(kernel_module_code, kernel_category):
|
| 292 |
+
if kernel_category == "foreach":
|
| 293 |
+
# foreach kernel does not have size_hints
|
| 294 |
+
return None
|
| 295 |
+
m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code)
|
| 296 |
+
assert m, "size_hints missing!"
|
| 297 |
+
return m.group(1)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def _parse_reduction_hint(kernel_category, kernel_module_code):
|
| 301 |
+
if kernel_category not in ("reduction", "persistent_reduction"):
|
| 302 |
+
return None
|
| 303 |
+
m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code)
|
| 304 |
+
assert m, "reduction_hint not found in kernel source code!"
|
| 305 |
+
return m.group(1)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _count_pattern(proper_kernel_fn_code, pattern):
|
| 309 |
+
return proper_kernel_fn_code.count(pattern)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def _count_args(proper_kernel_fn_code):
|
| 313 |
+
def_line = proper_kernel_fn_code.splitlines()[0]
|
| 314 |
+
assert def_line.startswith("def ")
|
| 315 |
+
start_idx = def_line.index("(")
|
| 316 |
+
end_idx = def_line.index("):")
|
| 317 |
+
decl_csv = def_line[start_idx + 1 : end_idx]
|
| 318 |
+
comps = decl_csv.split(",")
|
| 319 |
+
return len(comps)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def _parse_proper_kernel_fn_code(kernel_fn_code):
|
| 323 |
+
"""
|
| 324 |
+
Skip decorators.
|
| 325 |
+
"""
|
| 326 |
+
start_pos = kernel_fn_code.index("def ")
|
| 327 |
+
return kernel_fn_code[start_pos:]
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def _parse_numel(proper_kernel_fn_code, numel_arg_name):
|
| 331 |
+
m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code)
|
| 332 |
+
if m:
|
| 333 |
+
return int(m.group(1))
|
| 334 |
+
else:
|
| 335 |
+
return None
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def _parse_kernel_args_num_gb(kernel_fn_code, kernel_category):
|
| 339 |
+
"""
|
| 340 |
+
inductor meta looks like:
|
| 341 |
+
inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0},
|
| 342 |
+
"""
|
| 343 |
+
m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code)
|
| 344 |
+
if m:
|
| 345 |
+
return float(m.group(1))
|
| 346 |
+
else:
|
| 347 |
+
"""
|
| 348 |
+
There are a few cases that kernel_num_gdb field can be missing:
|
| 349 |
+
1. the field will be missing if config.benchmark_kernel and
|
| 350 |
+
config.profile_bandwidth are false
|
| 351 |
+
2. even if config.benchmark_kernel or config.profile_bandwidth is true.
|
| 352 |
+
foreach kernel does not have kernel_num_gb field in the metadata
|
| 353 |
+
"""
|
| 354 |
+
return None
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def log_kernel_metadata(kernel_name, kernel_path, kernel_module_code):
|
| 358 |
+
"""
|
| 359 |
+
An utility to log kernel metadata. We may parse metadata from kernel source code here.
|
| 360 |
+
|
| 361 |
+
It's fine to parse the generated kernel code here since the logging is
|
| 362 |
+
disabled by default. It would hurt compilation time.
|
| 363 |
+
"""
|
| 364 |
+
from .wrapper_benchmark import get_kernel_category_by_source_code
|
| 365 |
+
|
| 366 |
+
kernel_category = get_kernel_category_by_source_code(kernel_module_code)
|
| 367 |
+
reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code)
|
| 368 |
+
size_hints = _parse_size_hints(kernel_module_code, kernel_category)
|
| 369 |
+
kernel_fn_code = _parse_kernel_fn_code(kernel_module_code)
|
| 370 |
+
|
| 371 |
+
proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code)
|
| 372 |
+
|
| 373 |
+
# the line of code excluding the decortors
|
| 374 |
+
kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code)
|
| 375 |
+
|
| 376 |
+
get_metric_table("kernel_metadata").add_row(
|
| 377 |
+
lambda: {
|
| 378 |
+
"kernel_name": kernel_name,
|
| 379 |
+
"kernel_path": kernel_path,
|
| 380 |
+
"kernel_category": kernel_category,
|
| 381 |
+
"size_hints": size_hints,
|
| 382 |
+
"reduction_hint": reduction_hint,
|
| 383 |
+
"line_of_code": kernel_line_of_code,
|
| 384 |
+
"num_load": _count_pattern(proper_kernel_fn_code, "tl.load"),
|
| 385 |
+
"num_store": _count_pattern(proper_kernel_fn_code, "tl.store"),
|
| 386 |
+
"num_for_loop": _count_pattern(proper_kernel_fn_code, "for "),
|
| 387 |
+
"num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"),
|
| 388 |
+
"num_args": _count_args(proper_kernel_fn_code),
|
| 389 |
+
"xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"),
|
| 390 |
+
"ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"),
|
| 391 |
+
"rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"),
|
| 392 |
+
"kernel_args_num_gb": _parse_kernel_args_num_gb(
|
| 393 |
+
kernel_fn_code, kernel_category
|
| 394 |
+
),
|
| 395 |
+
}
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def purge_old_log_files():
|
| 400 |
+
"""
|
| 401 |
+
Purge the old log file at the beginning when the benchmark script runs.
|
| 402 |
+
Should do it in the parent process rather than the child processes running
|
| 403 |
+
each individual model.
|
| 404 |
+
"""
|
| 405 |
+
for name, table in REGISTERED_METRIC_TABLES.items():
|
| 406 |
+
if name in enabled_metric_tables():
|
| 407 |
+
filename = table.output_filename()
|
| 408 |
+
if os.path.exists(filename):
|
| 409 |
+
os.unlink(filename)
|
| 410 |
+
|
| 411 |
+
table.write_header()
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
@lru_cache
|
| 415 |
+
def enabled_metric_tables() -> Set[str]:
|
| 416 |
+
config_str = config.enabled_metric_tables
|
| 417 |
+
|
| 418 |
+
enabled = set()
|
| 419 |
+
for name in config_str.split(","):
|
| 420 |
+
name = name.strip()
|
| 421 |
+
if not name:
|
| 422 |
+
continue
|
| 423 |
+
assert (
|
| 424 |
+
name in REGISTERED_METRIC_TABLES
|
| 425 |
+
), f"Metric table name {name} is not registered"
|
| 426 |
+
enabled.add(name)
|
| 427 |
+
return enabled
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def is_metric_table_enabled(name):
|
| 431 |
+
return name in enabled_metric_tables()
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def get_metric_table(name):
|
| 435 |
+
assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
|
| 436 |
+
return REGISTERED_METRIC_TABLES[name]
|
.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_ir.py
ADDED
|
@@ -0,0 +1,1881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from typing import Any, List, Optional
|
| 3 |
+
|
| 4 |
+
import sympy
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._prims_common import make_channels_last_strides_for
|
| 8 |
+
from torch.utils._ordered_set import OrderedSet
|
| 9 |
+
|
| 10 |
+
from .ir import (
|
| 11 |
+
ExternKernelAlloc,
|
| 12 |
+
FixedLayout,
|
| 13 |
+
FlexibleLayout,
|
| 14 |
+
ir_node_to_tensor,
|
| 15 |
+
IRNode,
|
| 16 |
+
is_contiguous_storage_and_layout,
|
| 17 |
+
Layout,
|
| 18 |
+
may_convert_to_optional,
|
| 19 |
+
MultiOutput,
|
| 20 |
+
MultiOutputLayout,
|
| 21 |
+
MutationOutput,
|
| 22 |
+
NoneLayout,
|
| 23 |
+
TensorBox,
|
| 24 |
+
)
|
| 25 |
+
from .utils import convert_shape_to_inductor, pad_listlike
|
| 26 |
+
from .virtualized import V
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _prepare_convolution_fusion_create(
|
| 30 |
+
cls,
|
| 31 |
+
x: "TensorBox",
|
| 32 |
+
weight: "TensorBox",
|
| 33 |
+
bias: "TensorBox",
|
| 34 |
+
padding: List[int],
|
| 35 |
+
stride: List[int],
|
| 36 |
+
dilation: List[int],
|
| 37 |
+
groups: int,
|
| 38 |
+
transposed: bool = False,
|
| 39 |
+
output_padding: Optional[List[int]] = None,
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
This function is a helper function to prepare inputs, layout and constant args
|
| 43 |
+
for convolution post-op fusion's create function, including deciding the output
|
| 44 |
+
layout (channels first or channels last), realizing inputs and make them etc. The
|
| 45 |
+
function only supports the CPU device since conv post-op fusion kernel is only
|
| 46 |
+
supported on CPU right now.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
# Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
|
| 50 |
+
def _conv_input_size(
|
| 51 |
+
output_size, weight_size, padding, output_padding, stride, dilation, groups
|
| 52 |
+
):
|
| 53 |
+
assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
|
| 54 |
+
dim = len(output_size)
|
| 55 |
+
assert dim > 2, "Expect input dim > 2"
|
| 56 |
+
|
| 57 |
+
BATCH_DIM = 0
|
| 58 |
+
WEIGHT_INPUT_CHANNELS_DIM = 1
|
| 59 |
+
input_size = []
|
| 60 |
+
input_size.append(output_size[BATCH_DIM])
|
| 61 |
+
input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups)
|
| 62 |
+
for d in range(2, dim):
|
| 63 |
+
kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
|
| 64 |
+
input_size_d = (
|
| 65 |
+
(output_size[d] - 1) * stride[d - 2]
|
| 66 |
+
- (padding[d - 2] * 2)
|
| 67 |
+
+ kernel
|
| 68 |
+
+ output_padding[d - 2]
|
| 69 |
+
)
|
| 70 |
+
input_size.append(input_size_d)
|
| 71 |
+
return list(map(int, input_size))
|
| 72 |
+
|
| 73 |
+
# The size of prepacked_weight is the prepacked weight size of deconv:
|
| 74 |
+
# Groups > 1: [g*o, i/g, ...]
|
| 75 |
+
# Groups == 1: [o, i, ...]
|
| 76 |
+
# Returns original weight size in [i, o, ...]
|
| 77 |
+
def _original_deconv_weight_size(
|
| 78 |
+
prepacked_weight,
|
| 79 |
+
groups,
|
| 80 |
+
):
|
| 81 |
+
prepacked_weight_size = prepacked_weight.size()
|
| 82 |
+
dim = len(prepacked_weight_size)
|
| 83 |
+
assert dim > 2, "Expect weight dim > 2"
|
| 84 |
+
if groups > 1:
|
| 85 |
+
weight_size = []
|
| 86 |
+
weight_size.append(prepacked_weight_size[1] * groups)
|
| 87 |
+
weight_size.append(prepacked_weight_size[0] / groups)
|
| 88 |
+
for d in range(2, dim):
|
| 89 |
+
weight_size.append(prepacked_weight_size[d])
|
| 90 |
+
else:
|
| 91 |
+
weight_size = prepacked_weight.transpose(0, 1).size()
|
| 92 |
+
return weight_size
|
| 93 |
+
|
| 94 |
+
x.realize()
|
| 95 |
+
weight.realize()
|
| 96 |
+
if bias is not None:
|
| 97 |
+
bias.realize()
|
| 98 |
+
with V.graph.fake_mode:
|
| 99 |
+
# TODO <Leslie> cleaned up the fake_tensor trace as Linear implementation
|
| 100 |
+
x_fake = ir_node_to_tensor(x, guard_shape=True)
|
| 101 |
+
weight_fake = ir_node_to_tensor(weight, guard_shape=True)
|
| 102 |
+
dims = len(x_fake.size()) - 2
|
| 103 |
+
assert 0 < len(padding) <= dims
|
| 104 |
+
assert 0 < len(dilation) <= dims
|
| 105 |
+
assert 0 < len(stride) <= dims
|
| 106 |
+
padding = pad_listlike(padding, dims)
|
| 107 |
+
dilation = pad_listlike(dilation, dims)
|
| 108 |
+
stride = pad_listlike(stride, dims)
|
| 109 |
+
if output_padding is None:
|
| 110 |
+
output_padding = pad_listlike([0], dims)
|
| 111 |
+
else:
|
| 112 |
+
assert 0 < len(output_padding) <= dims
|
| 113 |
+
output_padding = pad_listlike(output_padding, dims)
|
| 114 |
+
assert isinstance(groups, (int, sympy.core.numbers.Integer))
|
| 115 |
+
if transposed:
|
| 116 |
+
# When transposed, the size of the prepacked oneDNN weight is different
|
| 117 |
+
# from the PyTorch weight. We're not able to run aten conv with such
|
| 118 |
+
# size. We infer the output size from the input params here:
|
| 119 |
+
weight_size = _original_deconv_weight_size(weight_fake, groups)
|
| 120 |
+
input_size = x_fake.size()
|
| 121 |
+
output_size = _conv_input_size(
|
| 122 |
+
input_size,
|
| 123 |
+
weight_size,
|
| 124 |
+
padding,
|
| 125 |
+
output_padding,
|
| 126 |
+
stride,
|
| 127 |
+
dilation,
|
| 128 |
+
groups,
|
| 129 |
+
)
|
| 130 |
+
else:
|
| 131 |
+
bias_fake = (
|
| 132 |
+
ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
|
| 133 |
+
)
|
| 134 |
+
output = torch.ops.aten.convolution(
|
| 135 |
+
x_fake,
|
| 136 |
+
weight_fake,
|
| 137 |
+
bias_fake,
|
| 138 |
+
stride,
|
| 139 |
+
padding,
|
| 140 |
+
dilation,
|
| 141 |
+
transposed,
|
| 142 |
+
output_padding,
|
| 143 |
+
groups,
|
| 144 |
+
)
|
| 145 |
+
output_size = output.size()
|
| 146 |
+
|
| 147 |
+
req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
|
| 148 |
+
req_stride_order = [len(req_stride_order)] + req_stride_order
|
| 149 |
+
|
| 150 |
+
x = cls.require_stride_order(x, req_stride_order)
|
| 151 |
+
|
| 152 |
+
# We won't do weight prepack for Conv if dynamic_shapes.
|
| 153 |
+
# In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel.
|
| 154 |
+
# In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1),
|
| 155 |
+
# x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order
|
| 156 |
+
# won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel,
|
| 157 |
+
# this tensor is considered as channels first and the output will be in contiguous format.
|
| 158 |
+
# To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last.
|
| 159 |
+
dynamic_shapes = not all(isinstance(i, int) for i in (output_size))
|
| 160 |
+
if dynamic_shapes and is_contiguous_storage_and_layout(x):
|
| 161 |
+
output_stride = FlexibleLayout.contiguous_strides(output_size)
|
| 162 |
+
else:
|
| 163 |
+
output_stride = make_channels_last_strides_for(output_size)
|
| 164 |
+
|
| 165 |
+
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
|
| 166 |
+
inputs = [x, weight]
|
| 167 |
+
|
| 168 |
+
kernel_layout = FixedLayout(
|
| 169 |
+
x.get_device(),
|
| 170 |
+
x.get_dtype(),
|
| 171 |
+
convert_shape_to_inductor(output_size),
|
| 172 |
+
convert_shape_to_inductor(output_stride),
|
| 173 |
+
)
|
| 174 |
+
constant_args = [padding, stride, dilation, groups]
|
| 175 |
+
if transposed:
|
| 176 |
+
constant_args.insert(1, output_padding)
|
| 177 |
+
|
| 178 |
+
if bias is not None:
|
| 179 |
+
inputs.append(bias)
|
| 180 |
+
else:
|
| 181 |
+
constant_args.insert(0, bias)
|
| 182 |
+
return inputs, constant_args, kernel_layout, req_stride_order
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _prepare_linear_fusion_create(
    cls,
    x: "TensorBox",
    weight: "TensorBox",
    bias: "TensorBox",
):
    """
    Prepare inputs, output layout and constant args for a fused
    linear + post-op kernel's create() classmethod. CPU-only, since the
    fused linear kernels are only implemented for CPU.
    """
    # Materialize every tensor argument before packing it as a kernel input.
    x.realize()
    weight.realize()
    if bias is not None:
        bias.realize()

    *batch_dims, _ = x.get_size()
    # The weight has been transposed during the qlinear weight prepack process,
    # so its second dimension holds the output-channel count.
    # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/
    # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291
    _, oc = weight.get_size()
    output_size = [*batch_dims, oc]
    req_stride_order = list(reversed(range(len(x.get_size()))))

    x = cls.require_stride_order(x, req_stride_order)
    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
    inputs = [x, weight]

    # Output is plain contiguous.
    output_stride = FlexibleLayout.contiguous_strides(output_size)
    kernel_layout = FixedLayout(
        x.get_device(),
        x.get_dtype(),
        output_size,
        output_stride,
    )
    constant_args: List[Any] = []

    if bias is None:
        # A missing bias travels as a constant (None) argument instead of a
        # tensor input.
        constant_args.append(bias)
    else:
        inputs.append(bias)
    return inputs, constant_args, kernel_layout, req_stride_order
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class ConvolutionUnary(ExternKernelAlloc):
    """Convolution fused with a unary post-op.

    Lowers to ``torch.ops.mkldnn._convolution_pointwise.default``. Inputs and
    output layout are produced by ``_prepare_convolution_fusion_create``,
    which asserts the CPU device (the fused kernel is CPU-only).
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_pointwise.default,
        )
        # C++ signature used by the cpp wrapper codegen to declare the kernel.
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& input_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                at::IntArrayRef padding,
                at::IntArrayRef stride,
                at::IntArrayRef dilation,
                int64_t groups,
                c10::string_view attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        # Emit the extern kernel allocation/call through the wrapper.
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
            op_overload=self.op_overload,
            raw_args=[*self.inputs, *self.constant_args],
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        attr,
        scalars: Optional[List[Any]],
        algorithm,
    ):
        (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        # The fused unary op description is appended after the conv params.
        constant_args = constant_args + [
            attr,
            may_convert_to_optional(scalars),
            algorithm,
        ]
        return ConvolutionUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
        )
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class ConvolutionBinary(ExternKernelAlloc):
    """Convolution fused with a binary post-op (and an optional trailing
    unary post-op).

    Lowers to ``torch.ops.mkldnn._convolution_pointwise.binary``. The second
    (``other``) operand is required to match the conv input's stride order.
    CPU-only (asserted inside ``_prepare_convolution_fusion_create``).
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        cpp_constant_args=(),
    ) -> None:
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_pointwise.binary,
        )
        # C++ signature used by the cpp wrapper codegen to declare the kernel.
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& input_t,
                const at::Tensor& other_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                at::IntArrayRef padding,
                at::IntArrayRef stride,
                at::IntArrayRef dilation,
                int64_t groups,
                c10::string_view binary_attr,
                std::optional<at::Scalar> alpha,
                std::optional<c10::string_view> unary_attr,
                torch::List<std::optional<at::Scalar>> unary_scalars,
                std::optional<c10::string_view> unary_algorithm)"""
        self.cpp_constant_args = cpp_constant_args

    def codegen(self, wrapper):
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
            self.cpp_kernel_overload_name,
            self.op_overload,
            [*self.inputs, *self.constant_args],
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[List[Any]],
        unary_algorithm: Optional[str],
    ):
        (
            inputs,
            constant_args,
            kernel_layout,
            req_stride_order,
        ) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        # `other` must follow the same stride order as the conv input and sits
        # right after it in the inputs list (matching the op schema above).
        other = cls.require_stride_order(other, req_stride_order)
        inputs.insert(1, other)
        constant_args = constant_args + [
            binary_attr,
            binary_alpha,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
        return ConvolutionBinary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
        )
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class ConvolutionBinaryInplace(ExternKernelAlloc):
    """In-place variant of the fused conv + binary post-op.

    Lowers to ``torch.ops.mkldnn._convolution_pointwise_.binary``, which
    writes the result into the ``other`` tensor; the node therefore reports
    mutation outputs instead of allocating a fresh buffer.
    """

    def __init__(
        self,
        kernel_layout,
        inputs,
        constant_args=(),
    ) -> None:
        # Due to constraint of op.call, other (Tensor&) should be at input[0]
        reordered_inputs = [inputs[1], inputs[0]] + inputs[2:]

        super().__init__(
            kernel_layout,
            reordered_inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_pointwise_.binary,
        )
        # TODO: op.call: input[0] should be at::Tensor&
        self.cpp_op_schema = """
            at::Tensor&(
                at::Tensor& other_t,
                const at::Tensor& input_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                at::IntArrayRef padding,
                at::IntArrayRef stride,
                at::IntArrayRef dilation,
                int64_t groups,
                c10::string_view binary_attr,
                std::optional<at::Scalar> alpha,
                std::optional<c10::string_view> unary_attr,
                torch::List<std::optional<at::Scalar>> unary_scalars,
                std::optional<c10::string_view> unary_algorithm)"""

        # Register the first two tensor arguments (pre-reorder: x and other)
        # as mutation outputs of this node.
        self.mutation_outputs = [
            MutationOutput(NoneLayout(inputs[0].get_device()), inputs[0], self),
            MutationOutput(NoneLayout(inputs[1].get_device()), inputs[1], self),
        ]

    def codegen(self, wrapper):
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
            self.cpp_kernel_overload_name,
            self.op_overload,
            [*self.inputs, *self.constant_args],
        )

    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        # This node defines no unbacked symbols.
        return OrderedSet()

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[List[Any]],
        unary_algorithm: Optional[str],
    ):
        (
            inputs,
            constant_args,
            _,
            req_stride_order,
        ) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        other = cls.require_stride_order(other, req_stride_order)
        inputs.insert(1, other)
        constant_args = constant_args + [
            binary_attr,
            binary_alpha,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
        packed = ConvolutionBinaryInplace(
            kernel_layout=NoneLayout(inputs[1].get_device()),  # type: ignore[arg-type]
            inputs=inputs,
            constant_args=constant_args,
        )
        # This op mutates in place which means that the result is not the
        # target but rather the input that is being mutated
        # init reorders the inputs, so inputs[1] becomes packed.inputs[0]
        return packed.inputs[0]
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class ConvolutionTransposeUnary(ExternKernelAlloc):
    """Transposed convolution fused with a unary post-op.

    Lowers to ``torch.ops.mkldnn._convolution_transpose_pointwise.default``.
    Same preparation path as the forward conv fusions, with
    ``transposed=True`` and an extra ``output_padding`` constant arg.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default,
        )
        # C++ signature used by the cpp wrapper codegen to declare the kernel.
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& input_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                at::IntArrayRef padding,
                at::IntArrayRef output_padding,
                at::IntArrayRef stride,
                at::IntArrayRef dilation,
                int64_t groups,
                c10::string_view attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
        )

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        output_padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups_: int,
        attr,
        scalars: Optional[List[Any]],
        algorithm,
    ):
        transposed = True
        (
            inputs,
            constant_args,
            kernel_layout,
            _,
        ) = _prepare_convolution_fusion_create(
            cls,
            x,
            weight,
            bias,
            padding_,
            stride_,
            dilation_,
            groups_,
            transposed,
            output_padding_,
        )
        # The fused unary op description is appended after the conv params.
        constant_args = constant_args + [
            attr,
            may_convert_to_optional(scalars),
            algorithm,
        ]
        return ConvolutionTransposeUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
        )
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
class QConvPointWisePT2E(ExternKernelAlloc):
    """Quantized conv2d fused with a unary post-op
    (``torch.ops.onednn.qconv2d_pointwise.default``)."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """
        if bias is not None
            - inputs = [x, w, b, weight_scale, weight_zp]
            - const_args is: [x_scale, x_zp, stride, padding, dilation, groups,
              o_scale, o_zp, output_dtype, unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, w, weight_scale, weight_zp]
            - const_args is: [x_scale, x_zp, bias, stride, padding, dilation, groups,
              o_scale, o_zp, output_dtype, unary_attr, unary_scalars, unary_algorithm]
        """
        # With a bias the inputs are [x, w, b, w_scale, w_zp]; without it the
        # bias travels in constant_args instead.
        self.has_bias = len(inputs) == 5
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.onednn.qconv2d_pointwise.default,
        )
        # C++ signature used by the cpp wrapper codegen to declare the kernel.
        self.cpp_op_schema = """
            at::Tensor(
                at::Tensor act,
                double act_scale,
                int64_t act_zero_point,
                at::Tensor weight,
                at::Tensor weight_scales,
                at::Tensor weight_zero_points,
                std::optional<at::Tensor> bias,
                torch::List<int64_t> stride,
                torch::List<int64_t> padding,
                torch::List<int64_t> dilation,
                int64_t groups,
                double output_scale,
                int64_t output_zero_point,
                std::optional<c10::ScalarType> output_dtype,
                c10::string_view attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        # Parse the inputs and constants back into named arguments.
        # The raw_args setup can be skipped if there is a C shim implementation
        args = [x.codegen_reference() for x in self.inputs]
        const_arg_names = [
            "x_scale",
            "x_zero_point",
            "stride",
            "padding",
            "dilation",
            "groups",
            "output_scale",
            "output_zero_point",
            "output_dtype",
            "attr",
            "scalars",
            "algorithm",
        ]
        if not self.has_bias:
            # Bias lives among the constant args when it was None at creation.
            const_arg_names.insert(2, "bias")
        const_args = list(self.codegen_const_args(const_arg_names))

        x = args[0]
        x_raw = self.inputs[0]
        packed_weight = args[1]
        packed_weight_raw = self.inputs[1]
        bias = args[2] if self.has_bias else const_args[2]
        bias_raw = self.inputs[2] if self.has_bias else self.constant_args[2]
        # Weight qparams were appended as the last two tensor inputs.
        w_scale, w_zp = args[-2], args[-1]
        w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1]
        (
            x_scale,
            x_zp,
        ) = const_args[:2]
        (
            x_scale_raw,
            x_zp_raw,
        ) = self.constant_args[:2]
        (
            stride,
            padding,
            dilation,
            groups,
            o_scale,
            o_zp,
            output_dtype,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ) = const_args[-10:]
        (
            stride_raw,
            padding_raw,
            dilation_raw,
            groups_raw,
            o_scale_raw,
            o_zp_raw,
            output_dtype_raw,
            unary_attr_raw,
            unary_scalars_raw,
            unary_algorithm_raw,
        ) = self.constant_args[-10:]
        # Reassembled in the op schema's argument order.
        codegen_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            bias,
            stride,
            padding,
            dilation,
            groups,
            o_scale,
            o_zp,
            output_dtype,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        )
        raw_args = (
            x_raw,
            x_scale_raw,
            x_zp_raw,
            packed_weight_raw,
            w_scale_raw,
            w_zp_raw,
            bias_raw,
            stride_raw,
            padding_raw,
            dilation_raw,
            groups_raw,
            o_scale_raw,
            o_zp_raw,
            output_dtype_raw,
            unary_attr_raw,
            unary_scalars_raw,
            unary_algorithm_raw,
        )
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            codegen_args,
            self.cpp_op_schema,
            self.cpp_kernel_key,
            op_overload=self.op_overload,
            raw_args=raw_args,
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale: float,
        x_zero_point: int,
        qw: "TensorBox",  # qw
        w_scale: "TensorBox",
        w_zero_point: "TensorBox",
        bias: "TensorBox",
        stride: List[int],
        padding: List[int],
        dilation: List[int],
        groups: int,
        output_scale: float,
        output_zero_point: int,
        output_dtype,
        attr,
        scalars,
        algorithm,
    ):
        transposed = False
        output_padding = None
        (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
            cls,
            qx,
            qw,
            bias,
            padding,
            stride,
            dilation,
            groups,
            transposed,
            output_padding,
        )
        # swap padding and stride to align with functional conv arg order
        if bias is None:
            constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
        else:
            constant_args[0], constant_args[1] = constant_args[1], constant_args[0]

        # Weight qparams are tensors, so they become real kernel inputs.
        w_scale.realize()
        w_zero_point.realize()
        inputs = inputs + [w_scale, w_zero_point]

        constant_args = (
            [
                x_scale,
                x_zero_point,
            ]
            + constant_args
            + [
                output_scale,
                output_zero_point,
                output_dtype,
                attr,
                may_convert_to_optional(scalars),
                algorithm,
            ]
        )

        assert output_dtype is not None
        if output_dtype in [torch.float32, torch.bfloat16]:
            # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout
            # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8.
            kernel_layout.dtype = output_dtype

        return QConvPointWisePT2E(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
        )
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
    """Quantized conv2d fused with a binary post-op (sum) and an optional
    unary post-op (``torch.ops.onednn.qconv2d_pointwise.binary``).

    The ``accum`` operand is mutated in place, so ``create`` returns that
    input rather than a freshly allocated buffer.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """
        Needs input/weight/output qparams
        if bias is not None
            - inputs = [x, w, b, accum, w_scale, w_zp]
            - const_args = [x_scale, x_zp, accum_scale, accum_zp, stride, padding,
              dilation, groups, o_scale, o_zp, output_dtype, binary_attr, alpha,
              unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, w, accum, w_scale, w_zp]
            - const_args = [x_scale, x_zp, accum_scale, accum_zp, bias, stride,
              padding, dilation, groups, o_scale, o_zp, output_dtype, binary_attr,
              alpha, unary_attr, unary_scalars, unary_algorithm]
        """
        self.has_bias = len(inputs) == 6
        # Position of the in-place-summed accumulator within inputs.
        self.idx_for_inplace_sum = 3 if self.has_bias else 2
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.onednn.qconv2d_pointwise.binary,
        )
        # C++ signature used by the cpp wrapper codegen to declare the kernel.
        self.cpp_op_schema = """
            at::Tensor(
                at::Tensor act,
                double act_scale,
                int64_t act_zero_point,
                at::Tensor accum,
                double accum_scale,
                int64_t accum_zero_point,
                at::Tensor weight,
                at::Tensor weight_scales,
                at::Tensor weight_zero_points,
                std::optional<at::Tensor> bias,
                torch::List<int64_t> stride,
                torch::List<int64_t> padding,
                torch::List<int64_t> dilation,
                int64_t groups,
                double output_scale,
                int64_t output_zero_point,
                std::optional<c10::ScalarType> output_dtype,
                c10::string_view binary_attr,
                std::optional<at::Scalar> alpha,
                std::optional<c10::string_view> attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        # Parse the inputs and constants back into named arguments.
        # The raw_args setup can be skipped if there is a C shim implementation
        args = [x.codegen_reference() for x in self.inputs]
        const_arg_names = [
            "x_scale",
            "x_zero_point",
            "accum_scale",
            "accum_zero_point",
            "stride",
            "padding",
            "dilation",
            "groups",
            "output_scale",
            "output_zero_point",
            "output_dtype",
            "binary_attr",
            "alpha",
            "unary_attr",
            "unary_scalars",
            "unary_algorithm",
        ]
        if not self.has_bias:
            # Bias lives among the constant args when it was None at creation.
            const_arg_names.insert(4, "bias")
        const_args = list(self.codegen_const_args(const_arg_names))

        x = args[0]
        x_raw = self.inputs[0]
        packed_weight = args[1]
        packed_weight_raw = self.inputs[1]
        bias = args[2] if self.has_bias else const_args[4]
        bias_raw = self.inputs[2] if self.has_bias else self.constant_args[4]
        # accum plus the weight qparams are the last three tensor inputs.
        accum, w_scale, w_zp = args[-3], args[-2], args[-1]
        accum_raw, w_scale_raw, w_zp_raw = (
            self.inputs[-3],
            self.inputs[-2],
            self.inputs[-1],
        )
        (
            x_scale,
            x_zp,
            accum_scale,
            accum_zp,
        ) = const_args[:4]
        (
            x_scale_raw,
            x_zp_raw,
            accum_scale_raw,
            accum_zp_raw,
        ) = self.constant_args[:4]
        (
            stride,
            padding,
            dilation,
            groups,
            o_scale,
            o_zp,
            output_dtype,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ) = const_args[-12:]
        (
            stride_raw,
            padding_raw,
            dilation_raw,
            groups_raw,
            o_scale_raw,
            o_zp_raw,
            output_dtype_raw,
            binary_attr_raw,
            alpha_raw,
            unary_attr_raw,
            unary_scalars_raw,
            unary_algorithm_raw,
        ) = self.constant_args[-12:]
        # Reassembled in the op schema's argument order.
        conv_args = (
            x,
            x_scale,
            x_zp,
            accum,
            accum_scale,
            accum_zp,
            packed_weight,
            w_scale,
            w_zp,
            bias,
            stride,
            padding,
            dilation,
            groups,
            o_scale,
            o_zp,
            output_dtype,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        )
        raw_args = (
            x_raw,
            x_scale_raw,
            x_zp_raw,
            accum_raw,
            accum_scale_raw,
            accum_zp_raw,
            packed_weight_raw,
            w_scale_raw,
            w_zp_raw,
            bias_raw,
            stride_raw,
            padding_raw,
            dilation_raw,
            groups_raw,
            o_scale_raw,
            o_zp_raw,
            output_dtype_raw,
            binary_attr_raw,
            alpha_raw,
            unary_attr_raw,
            unary_scalars_raw,
            unary_algorithm_raw,
        )
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            conv_args,
            self.cpp_op_schema,
            self.cpp_kernel_key,
            self.cpp_kernel_overload_name,
            op_overload=self.op_overload,
            raw_args=raw_args,
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    def get_mutation_names(self):
        # The accumulator input is mutated in place by the fused sum.
        return [self.inputs[self.idx_for_inplace_sum].get_name()]

    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        # This node defines no unbacked symbols.
        return OrderedSet()

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale,
        x_zero_point,
        qaccum: "TensorBox",
        accum_scale,
        accum_zero_point,
        qw: "TensorBox",  # packed_weight
        w_scale,
        w_zero_point,
        bias: "TensorBox",
        stride: List[int],
        padding: List[int],
        dilation: List[int],
        groups: int,
        output_scale: "TensorBox",
        output_zero_point: "TensorBox",
        output_dtype,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        transposed = False
        output_padding = None
        (
            inputs,
            constant_args,
            kernel_layout,
            req_stride_order,
        ) = _prepare_convolution_fusion_create(
            cls,
            qx,
            qw,
            bias,
            padding,
            stride,
            dilation,
            groups,
            transposed,
            output_padding,
        )

        # The accumulator must match the conv input's stride order since it is
        # summed into in place.
        qaccum = cls.require_stride_order(qaccum, req_stride_order)
        inputs.append(qaccum)

        # swap padding and stride to align with functional conv arg order
        if bias is None:
            constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
        else:
            constant_args[0], constant_args[1] = constant_args[1], constant_args[0]

        # Weight qparams are tensors, so they become real kernel inputs.
        w_scale.realize()
        w_zero_point.realize()
        inputs = inputs + [w_scale, w_zero_point]
        constant_args = (
            [
                x_scale,
                x_zero_point,
                accum_scale,
                accum_zero_point,
            ]
            + constant_args
            + [
                output_scale,
                output_zero_point,
                output_dtype,
                binary_attr,
                alpha,
                unary_attr,
                may_convert_to_optional(unary_scalars),
                unary_algorithm,
            ]
        )

        assert (
            binary_attr == "sum"
        ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."

        V.graph.mark_buffer_mutated(qaccum.get_name())
        packed = QConvPointWiseBinaryPT2E(
            layout=NoneLayout(qaccum.get_device()),
            inputs=inputs,
            constant_args=constant_args,
        )

        # Return accum since it has been inplace changed.
        return packed.inputs[packed.idx_for_inplace_sum]
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
class MKLPackedLinear(ExternKernelAlloc):
    """Linear using an MKL pre-packed weight
    (``torch.ops.mkl._mkl_linear.default``).

    Carries both the packed and the original weight, plus the batch size the
    weight was prepacked for.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkl._mkl_linear.default,
        )
        # C++ signature used by the cpp wrapper codegen to declare the kernel.
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& self,
                const at::Tensor& mkl_weight_t,
                const at::Tensor& origin_weight_t,
                const std::optional<at::Tensor>& bias_opt,
                const int64_t prepack_batch_size)"""

    def codegen(self, wrapper):
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
        )

    @classmethod
    def create(cls, x, packed_w, orig_w, B, batch_size):
        """Build an MKLPackedLinear node.

        x: input; packed_w: MKL-prepacked weight; orig_w: original weight
        (used for its (oc, ic) size); B: optional bias; batch_size: batch
        size the weight was prepacked for.
        """
        x = cls.require_stride1(cls.realize_input(x))
        orig_w = cls.require_stride1(cls.realize_input(orig_w))
        # Output shape is x's leading dims plus the weight's out-channels.
        *m, _ = x.get_size()
        oc, _ = orig_w.get_size()
        output_size = list(m) + [oc]
        output_stride = FlexibleLayout.contiguous_strides(output_size)
        inputs = [x, packed_w, orig_w]
        constant_args = [batch_size]
        if B is not None:
            inputs += [B]
        else:
            # A missing bias travels as a constant (None) argument.
            constant_args.insert(0, None)

        return MKLPackedLinear(
            layout=FixedLayout(
                x.get_device(), x.get_dtype(), output_size, output_stride
            ),
            inputs=inputs,
            constant_args=constant_args,
        )
|
| 1147 |
+
|
| 1148 |
+
|
| 1149 |
+
class LinearUnary(ExternKernelAlloc):
    """Linear fused with a unary post-op
    (``torch.ops.mkldnn._linear_pointwise.default``).

    ``attr``/``scalars``/``algorithm`` describe the fused unary op and are
    passed as constant args.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._linear_pointwise.default,
        )
        self.cpp_kernel_key = "linear_pointwise"
        # C++ signature used by the cpp wrapper codegen to declare the kernel.
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& input_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                c10::string_view attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
        )

    @classmethod
    def create(cls, x, w, B, attr, scalars, algorithm):
        """Build a LinearUnary node.

        x: input (..., ic); w: weight (oc, ic); B: optional bias.
        The output shape is x's leading dims plus ``oc``.
        """
        x = cls.require_contiguous(cls.realize_input(x))
        w = cls.require_contiguous(cls.realize_input(w))

        # Only x's leading dims and w's out-channels determine the output
        # shape; the inner-channel values were previously unpacked into a
        # local that was immediately shadowed and never used, so discard
        # them explicitly.
        *m, _ = x.get_size()
        oc, _ = w.get_size()
        inputs = [x, w]
        # [-1] stands in for "no scalars" — presumably what the fused kernel
        # expects for an empty scalar list (TODO confirm against the op).
        constant_args = [attr, scalars if scalars else [-1], algorithm]
        if B is not None:
            B = cls.require_contiguous(cls.realize_input(B))
            inputs.append(B)
        else:
            # A missing bias travels as a constant (None) argument.
            constant_args.insert(0, None)

        return LinearUnary(
            layout=FlexibleLayout(
                device=x.get_device(),
                dtype=x.get_dtype(),
                size=list(m) + [oc],
            ),
            inputs=inputs,
            constant_args=constant_args,
        )

    def apply_constraint(self):
        # No additional layout constraints beyond the contiguity enforced in
        # create().
        pass
|
| 1210 |
+
|
| 1211 |
+
|
| 1212 |
+
class LinearBinary(ExternKernelAlloc):
|
| 1213 |
+
kernel = "torch.ops.mkldnn._linear_pointwise.binary"
|
| 1214 |
+
|
| 1215 |
+
def __init__(
|
| 1216 |
+
self,
|
| 1217 |
+
layout,
|
| 1218 |
+
inputs,
|
| 1219 |
+
constant_args=(),
|
| 1220 |
+
) -> None:
|
| 1221 |
+
super().__init__(
|
| 1222 |
+
layout,
|
| 1223 |
+
inputs,
|
| 1224 |
+
constant_args,
|
| 1225 |
+
None,
|
| 1226 |
+
op_overload=torch.ops.mkldnn._linear_pointwise.binary,
|
| 1227 |
+
)
|
| 1228 |
+
self.cpp_op_schema = """
|
| 1229 |
+
at::Tensor(
|
| 1230 |
+
const at::Tensor& input_t,
|
| 1231 |
+
const at::Tensor& other_t,
|
| 1232 |
+
const at::Tensor& weight_t,
|
| 1233 |
+
const std::optional<at::Tensor>& bias_opt,
|
| 1234 |
+
c10::string_view attr)
|
| 1235 |
+
"""
|
| 1236 |
+
|
| 1237 |
+
def codegen(self, wrapper):
|
| 1238 |
+
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
|
| 1239 |
+
self.get_name(),
|
| 1240 |
+
self.python_kernel_name,
|
| 1241 |
+
self.cpp_kernel_name,
|
| 1242 |
+
self.codegen_args(),
|
| 1243 |
+
self.cpp_op_schema,
|
| 1244 |
+
self.cpp_kernel_key,
|
| 1245 |
+
self.cpp_kernel_overload_name,
|
| 1246 |
+
)
|
| 1247 |
+
|
| 1248 |
+
@classmethod
|
| 1249 |
+
def create(cls, x, y, w, B, attr):
|
| 1250 |
+
x = cls.require_contiguous(cls.realize_input(x))
|
| 1251 |
+
y = cls.require_contiguous(cls.realize_input(y))
|
| 1252 |
+
w = cls.require_contiguous(cls.realize_input(w))
|
| 1253 |
+
|
| 1254 |
+
*m, ic = x.get_size()
|
| 1255 |
+
oc, ic = w.get_size()
|
| 1256 |
+
|
| 1257 |
+
inputs = [x, y, w]
|
| 1258 |
+
constant_args = [attr]
|
| 1259 |
+
if B is not None:
|
| 1260 |
+
B = cls.require_contiguous(cls.realize_input(B))
|
| 1261 |
+
inputs.append(B)
|
| 1262 |
+
else:
|
| 1263 |
+
constant_args.insert(0, B)
|
| 1264 |
+
|
| 1265 |
+
return LinearBinary(
|
| 1266 |
+
layout=FlexibleLayout(
|
| 1267 |
+
device=x.get_device(),
|
| 1268 |
+
dtype=x.get_dtype(),
|
| 1269 |
+
size=list(m) + [oc],
|
| 1270 |
+
),
|
| 1271 |
+
inputs=inputs,
|
| 1272 |
+
constant_args=constant_args,
|
| 1273 |
+
)
|
| 1274 |
+
|
| 1275 |
+
def apply_constraint(self):
|
| 1276 |
+
pass
|
| 1277 |
+
|
| 1278 |
+
|
| 1279 |
+
class QLinearPointwisePT2E(ExternKernelAlloc):
|
| 1280 |
+
def __init__(
|
| 1281 |
+
self,
|
| 1282 |
+
layout,
|
| 1283 |
+
inputs,
|
| 1284 |
+
constant_args=(),
|
| 1285 |
+
has_bias=True,
|
| 1286 |
+
x_scale_zp_are_tensors=False,
|
| 1287 |
+
) -> None:
|
| 1288 |
+
"""
|
| 1289 |
+
if bias is not None
|
| 1290 |
+
- inputs = [x, w, b, weight_scale, weight_zp]
|
| 1291 |
+
- const_args is: [x_scale, x_zp, o_scale, o_zp,
|
| 1292 |
+
fp32_output, unary_attr, unary_scalars, unary_algorithm]
|
| 1293 |
+
else
|
| 1294 |
+
- inputs = [x, w, weight_scale, weight_zp]
|
| 1295 |
+
- const_args is: [bias, x_scale, x_zp, o_scale, o_zp,
|
| 1296 |
+
fp32_output, unary_attr, unary_scalars, unary_algorithm]
|
| 1297 |
+
"""
|
| 1298 |
+
self.has_bias = has_bias
|
| 1299 |
+
self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
|
| 1300 |
+
super().__init__(
|
| 1301 |
+
layout,
|
| 1302 |
+
inputs,
|
| 1303 |
+
constant_args,
|
| 1304 |
+
None,
|
| 1305 |
+
op_overload=torch.ops.onednn.qlinear_pointwise.tensor
|
| 1306 |
+
if x_scale_zp_are_tensors
|
| 1307 |
+
else torch.ops.onednn.qlinear_pointwise.default,
|
| 1308 |
+
)
|
| 1309 |
+
x_scale_type_str, x_zp_type_str = (
|
| 1310 |
+
("at::Tensor", "at::Tensor")
|
| 1311 |
+
if x_scale_zp_are_tensors
|
| 1312 |
+
else ("double", "int64_t")
|
| 1313 |
+
)
|
| 1314 |
+
self.cpp_op_schema = f"""
|
| 1315 |
+
at::Tensor(
|
| 1316 |
+
at::Tensor act,
|
| 1317 |
+
{x_scale_type_str} act_scale,
|
| 1318 |
+
{x_zp_type_str} act_zero_point,
|
| 1319 |
+
at::Tensor weight,
|
| 1320 |
+
at::Tensor weight_scales,
|
| 1321 |
+
at::Tensor weight_zero_points,
|
| 1322 |
+
std::optional<at::Tensor> bias,
|
| 1323 |
+
double output_scale,
|
| 1324 |
+
int64_t output_zero_point,
|
| 1325 |
+
std::optional<c10::ScalarType> output_dtype,
|
| 1326 |
+
c10::string_view post_op_name,
|
| 1327 |
+
torch::List<std::optional<at::Scalar>> post_op_args,
|
| 1328 |
+
c10::string_view post_op_algorithm)"""
|
| 1329 |
+
|
| 1330 |
+
def codegen(self, wrapper):
|
| 1331 |
+
# Parser the inputs and constant
|
| 1332 |
+
# The raw_args setup can be skipped if there is a C shim implementation
|
| 1333 |
+
args = [x.codegen_reference() for x in self.inputs]
|
| 1334 |
+
const_args = []
|
| 1335 |
+
const_args.extend(self.codegen_const_args())
|
| 1336 |
+
|
| 1337 |
+
x = args[0]
|
| 1338 |
+
x_raw = self.inputs[0]
|
| 1339 |
+
packed_weight = args[1]
|
| 1340 |
+
packed_weight_raw = self.inputs[1]
|
| 1341 |
+
bias = args[2] if self.has_bias else const_args[0]
|
| 1342 |
+
bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0]
|
| 1343 |
+
w_scale, w_zp = args[-2], args[-1]
|
| 1344 |
+
w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1]
|
| 1345 |
+
if self.x_scale_zp_are_tensors:
|
| 1346 |
+
assert len(args) >= 4
|
| 1347 |
+
x_scale, x_zp = args[-4], args[-3]
|
| 1348 |
+
x_scale_raw, x_zp_raw = self.inputs[-4], self.inputs[-3]
|
| 1349 |
+
(
|
| 1350 |
+
o_scale,
|
| 1351 |
+
o_zp,
|
| 1352 |
+
output_dtype,
|
| 1353 |
+
unary_attr,
|
| 1354 |
+
unary_scalars,
|
| 1355 |
+
unary_algorithm,
|
| 1356 |
+
) = const_args[-6:]
|
| 1357 |
+
(
|
| 1358 |
+
o_scale_raw,
|
| 1359 |
+
o_zp_raw,
|
| 1360 |
+
output_dtype_raw,
|
| 1361 |
+
unary_attr_raw,
|
| 1362 |
+
unary_scalars_raw,
|
| 1363 |
+
unary_algorithm_raw,
|
| 1364 |
+
) = self.constant_args[-6:]
|
| 1365 |
+
else:
|
| 1366 |
+
assert len(const_args) >= 8
|
| 1367 |
+
(
|
| 1368 |
+
x_scale,
|
| 1369 |
+
x_zp,
|
| 1370 |
+
o_scale,
|
| 1371 |
+
o_zp,
|
| 1372 |
+
output_dtype,
|
| 1373 |
+
unary_attr,
|
| 1374 |
+
unary_scalars,
|
| 1375 |
+
unary_algorithm,
|
| 1376 |
+
) = const_args[-8:]
|
| 1377 |
+
(
|
| 1378 |
+
x_scale_raw,
|
| 1379 |
+
x_zp_raw,
|
| 1380 |
+
o_scale_raw,
|
| 1381 |
+
o_zp_raw,
|
| 1382 |
+
output_dtype_raw,
|
| 1383 |
+
unary_attr_raw,
|
| 1384 |
+
unary_scalars_raw,
|
| 1385 |
+
unary_algorithm_raw,
|
| 1386 |
+
) = self.constant_args[-8:]
|
| 1387 |
+
|
| 1388 |
+
codegen_args = (
|
| 1389 |
+
x,
|
| 1390 |
+
x_scale,
|
| 1391 |
+
x_zp,
|
| 1392 |
+
packed_weight,
|
| 1393 |
+
w_scale,
|
| 1394 |
+
w_zp,
|
| 1395 |
+
bias,
|
| 1396 |
+
o_scale,
|
| 1397 |
+
o_zp,
|
| 1398 |
+
output_dtype,
|
| 1399 |
+
unary_attr,
|
| 1400 |
+
unary_scalars,
|
| 1401 |
+
unary_algorithm,
|
| 1402 |
+
)
|
| 1403 |
+
raw_args = (
|
| 1404 |
+
x_raw,
|
| 1405 |
+
x_scale_raw,
|
| 1406 |
+
x_zp_raw,
|
| 1407 |
+
packed_weight_raw,
|
| 1408 |
+
w_scale_raw,
|
| 1409 |
+
w_zp_raw,
|
| 1410 |
+
bias_raw,
|
| 1411 |
+
o_scale_raw,
|
| 1412 |
+
o_zp_raw,
|
| 1413 |
+
output_dtype_raw,
|
| 1414 |
+
unary_attr_raw,
|
| 1415 |
+
unary_scalars_raw,
|
| 1416 |
+
unary_algorithm_raw,
|
| 1417 |
+
)
|
| 1418 |
+
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
|
| 1419 |
+
self.get_name(),
|
| 1420 |
+
self.python_kernel_name,
|
| 1421 |
+
self.cpp_kernel_name,
|
| 1422 |
+
codegen_args,
|
| 1423 |
+
self.cpp_op_schema,
|
| 1424 |
+
self.cpp_kernel_key,
|
| 1425 |
+
self.cpp_kernel_overload_name,
|
| 1426 |
+
self.op_overload,
|
| 1427 |
+
raw_args,
|
| 1428 |
+
)
|
| 1429 |
+
if isinstance(self.layout, Layout):
|
| 1430 |
+
self.codegen_size_asserts(wrapper)
|
| 1431 |
+
|
| 1432 |
+
@classmethod
|
| 1433 |
+
def create(
|
| 1434 |
+
cls,
|
| 1435 |
+
qx: "TensorBox",
|
| 1436 |
+
x_scale: float,
|
| 1437 |
+
x_zero_point: int,
|
| 1438 |
+
qw: "TensorBox", # packed_weight
|
| 1439 |
+
w_scale: "TensorBox",
|
| 1440 |
+
w_zero_point: "TensorBox",
|
| 1441 |
+
bias: "TensorBox",
|
| 1442 |
+
output_scale: float,
|
| 1443 |
+
output_zero_point: int,
|
| 1444 |
+
output_dtype,
|
| 1445 |
+
post_op_name,
|
| 1446 |
+
post_op_args,
|
| 1447 |
+
post_op_algorithm,
|
| 1448 |
+
):
|
| 1449 |
+
(inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create(
|
| 1450 |
+
cls,
|
| 1451 |
+
qx,
|
| 1452 |
+
qw,
|
| 1453 |
+
bias,
|
| 1454 |
+
)
|
| 1455 |
+
|
| 1456 |
+
if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox):
|
| 1457 |
+
x_scale.realize()
|
| 1458 |
+
x_zero_point.realize()
|
| 1459 |
+
inputs = inputs + [x_scale, x_zero_point]
|
| 1460 |
+
x_scale_zp_are_tensors = True
|
| 1461 |
+
else:
|
| 1462 |
+
assert isinstance(x_scale, float) and isinstance(x_zero_point, int)
|
| 1463 |
+
constant_args = constant_args + [x_scale, x_zero_point]
|
| 1464 |
+
x_scale_zp_are_tensors = False
|
| 1465 |
+
w_scale.realize()
|
| 1466 |
+
w_zero_point.realize()
|
| 1467 |
+
inputs = inputs + [w_scale, w_zero_point]
|
| 1468 |
+
constant_args = constant_args + [
|
| 1469 |
+
output_scale,
|
| 1470 |
+
output_zero_point,
|
| 1471 |
+
output_dtype,
|
| 1472 |
+
post_op_name,
|
| 1473 |
+
may_convert_to_optional(post_op_args),
|
| 1474 |
+
post_op_algorithm,
|
| 1475 |
+
]
|
| 1476 |
+
|
| 1477 |
+
assert output_dtype is not None
|
| 1478 |
+
if output_dtype in [torch.float32, torch.bfloat16]:
|
| 1479 |
+
# in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout
|
| 1480 |
+
# if we set fp32_output, the output buf should be dtype float32 instead of uint8.
|
| 1481 |
+
kernel_layout.dtype = output_dtype
|
| 1482 |
+
|
| 1483 |
+
return QLinearPointwisePT2E(
|
| 1484 |
+
layout=kernel_layout,
|
| 1485 |
+
inputs=inputs,
|
| 1486 |
+
constant_args=constant_args,
|
| 1487 |
+
has_bias=(bias is not None),
|
| 1488 |
+
x_scale_zp_are_tensors=x_scale_zp_are_tensors,
|
| 1489 |
+
)
|
| 1490 |
+
|
| 1491 |
+
|
| 1492 |
+
class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
|
| 1493 |
+
def __init__(
|
| 1494 |
+
self,
|
| 1495 |
+
layout,
|
| 1496 |
+
inputs,
|
| 1497 |
+
constant_args=(),
|
| 1498 |
+
has_bias=True,
|
| 1499 |
+
x_scale_zp_are_tensors=False,
|
| 1500 |
+
) -> None:
|
| 1501 |
+
"""
|
| 1502 |
+
if bias is not None
|
| 1503 |
+
- inputs = [x, w, b, weight_scale, weight_zp, x2]
|
| 1504 |
+
- const_args is: [x_scale, x_zp, o_scale, o_zp,
|
| 1505 |
+
fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm]
|
| 1506 |
+
else
|
| 1507 |
+
- inputs = [x, w, weight_scale, weight_zp, x2]
|
| 1508 |
+
- const_args is: [bias, x_scale, x_zp, o_scale, o_zp,
|
| 1509 |
+
fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm]
|
| 1510 |
+
"""
|
| 1511 |
+
self.has_bias = has_bias
|
| 1512 |
+
self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
|
| 1513 |
+
super().__init__(
|
| 1514 |
+
layout,
|
| 1515 |
+
inputs,
|
| 1516 |
+
constant_args,
|
| 1517 |
+
None,
|
| 1518 |
+
op_overload=torch.ops.onednn.qlinear_pointwise.binary_tensor
|
| 1519 |
+
if x_scale_zp_are_tensors
|
| 1520 |
+
else torch.ops.onednn.qlinear_pointwise.binary,
|
| 1521 |
+
)
|
| 1522 |
+
x_scale_type_str, x_zp_type_str = (
|
| 1523 |
+
("at::Tensor", "at::Tensor")
|
| 1524 |
+
if x_scale_zp_are_tensors
|
| 1525 |
+
else ("double", "int64_t")
|
| 1526 |
+
)
|
| 1527 |
+
self.cpp_op_schema = f"""
|
| 1528 |
+
at::Tensor(
|
| 1529 |
+
at::Tensor act,
|
| 1530 |
+
{x_scale_type_str} act_scale,
|
| 1531 |
+
{x_zp_type_str} act_zero_point,
|
| 1532 |
+
at::Tensor weight,
|
| 1533 |
+
at::Tensor weight_scales,
|
| 1534 |
+
at::Tensor weight_zero_points,
|
| 1535 |
+
std::optional<at::Tensor> other,
|
| 1536 |
+
std::optional<at::Tensor> bias,
|
| 1537 |
+
double inv_output_scale,
|
| 1538 |
+
int64_t output_zero_point,
|
| 1539 |
+
std::optional<c10::ScalarType> output_dtype,
|
| 1540 |
+
double other_scale,
|
| 1541 |
+
int64_t other_zero_point,
|
| 1542 |
+
c10::string_view binary_post_op,
|
| 1543 |
+
double binary_alpha,
|
| 1544 |
+
c10::string_view unary_post_op,
|
| 1545 |
+
torch::List<std::optional<at::Scalar>> unary_post_op_args,
|
| 1546 |
+
c10::string_view unary_post_op_algorithm)"""
|
| 1547 |
+
|
| 1548 |
+
def codegen(self, wrapper):
|
| 1549 |
+
# Parser the inputs and constant
|
| 1550 |
+
# The raw_args setup can be skipped if there is a C shim implementation
|
| 1551 |
+
args = [x.codegen_reference() for x in self.inputs]
|
| 1552 |
+
const_args = []
|
| 1553 |
+
const_args.extend(self.codegen_const_args())
|
| 1554 |
+
|
| 1555 |
+
x = args[0]
|
| 1556 |
+
x_raw = self.inputs[0]
|
| 1557 |
+
packed_weight = args[1]
|
| 1558 |
+
packed_weight_raw = self.inputs[1]
|
| 1559 |
+
bias = args[2] if self.has_bias else const_args[0]
|
| 1560 |
+
bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0]
|
| 1561 |
+
w_scale, w_zp, other = args[-3], args[-2], args[-1]
|
| 1562 |
+
w_scale_raw, w_zp_raw, other_raw = (
|
| 1563 |
+
self.inputs[-3],
|
| 1564 |
+
self.inputs[-2],
|
| 1565 |
+
self.inputs[-1],
|
| 1566 |
+
)
|
| 1567 |
+
if self.x_scale_zp_are_tensors:
|
| 1568 |
+
assert len(args) >= 5
|
| 1569 |
+
x_scale, x_zp = args[-5], args[-4]
|
| 1570 |
+
x_scale_raw, x_zp_raw = self.inputs[-5], self.inputs[-4]
|
| 1571 |
+
(
|
| 1572 |
+
o_scale,
|
| 1573 |
+
o_zp,
|
| 1574 |
+
output_dtype,
|
| 1575 |
+
other_scale,
|
| 1576 |
+
other_zp,
|
| 1577 |
+
binary_attr,
|
| 1578 |
+
alpha,
|
| 1579 |
+
unary_attr,
|
| 1580 |
+
unary_scalars,
|
| 1581 |
+
unary_algorithm,
|
| 1582 |
+
) = const_args[-10:]
|
| 1583 |
+
(
|
| 1584 |
+
o_scale_raw,
|
| 1585 |
+
o_zp_raw,
|
| 1586 |
+
output_dtype_raw,
|
| 1587 |
+
other_scale_raw,
|
| 1588 |
+
other_zp_raw,
|
| 1589 |
+
binary_attr_raw,
|
| 1590 |
+
alpha_raw,
|
| 1591 |
+
unary_attr_raw,
|
| 1592 |
+
unary_scalars_raw,
|
| 1593 |
+
unary_algorithm_raw,
|
| 1594 |
+
) = self.constant_args[-10:]
|
| 1595 |
+
else:
|
| 1596 |
+
assert len(const_args) >= 8
|
| 1597 |
+
(
|
| 1598 |
+
x_scale,
|
| 1599 |
+
x_zp,
|
| 1600 |
+
o_scale,
|
| 1601 |
+
o_zp,
|
| 1602 |
+
output_dtype,
|
| 1603 |
+
other_scale,
|
| 1604 |
+
other_zp,
|
| 1605 |
+
binary_attr,
|
| 1606 |
+
alpha,
|
| 1607 |
+
unary_attr,
|
| 1608 |
+
unary_scalars,
|
| 1609 |
+
unary_algorithm,
|
| 1610 |
+
) = const_args[-12:]
|
| 1611 |
+
(
|
| 1612 |
+
x_scale_raw,
|
| 1613 |
+
x_zp_raw,
|
| 1614 |
+
o_scale_raw,
|
| 1615 |
+
o_zp_raw,
|
| 1616 |
+
output_dtype_raw,
|
| 1617 |
+
other_scale_raw,
|
| 1618 |
+
other_zp_raw,
|
| 1619 |
+
binary_attr_raw,
|
| 1620 |
+
alpha_raw,
|
| 1621 |
+
unary_attr_raw,
|
| 1622 |
+
unary_scalars_raw,
|
| 1623 |
+
unary_algorithm_raw,
|
| 1624 |
+
) = self.constant_args[-12:]
|
| 1625 |
+
|
| 1626 |
+
codegen_args = (
|
| 1627 |
+
x,
|
| 1628 |
+
x_scale,
|
| 1629 |
+
x_zp,
|
| 1630 |
+
packed_weight,
|
| 1631 |
+
w_scale,
|
| 1632 |
+
w_zp,
|
| 1633 |
+
other,
|
| 1634 |
+
bias,
|
| 1635 |
+
o_scale,
|
| 1636 |
+
o_zp,
|
| 1637 |
+
output_dtype,
|
| 1638 |
+
other_scale,
|
| 1639 |
+
other_zp,
|
| 1640 |
+
binary_attr,
|
| 1641 |
+
alpha,
|
| 1642 |
+
unary_attr,
|
| 1643 |
+
unary_scalars,
|
| 1644 |
+
unary_algorithm,
|
| 1645 |
+
)
|
| 1646 |
+
raw_args = (
|
| 1647 |
+
x_raw,
|
| 1648 |
+
x_scale_raw,
|
| 1649 |
+
x_zp_raw,
|
| 1650 |
+
packed_weight_raw,
|
| 1651 |
+
w_scale_raw,
|
| 1652 |
+
w_zp_raw,
|
| 1653 |
+
other_raw,
|
| 1654 |
+
bias_raw,
|
| 1655 |
+
o_scale_raw,
|
| 1656 |
+
o_zp_raw,
|
| 1657 |
+
output_dtype_raw,
|
| 1658 |
+
other_scale_raw,
|
| 1659 |
+
other_zp_raw,
|
| 1660 |
+
binary_attr_raw,
|
| 1661 |
+
alpha_raw,
|
| 1662 |
+
unary_attr_raw,
|
| 1663 |
+
unary_scalars_raw,
|
| 1664 |
+
unary_algorithm_raw,
|
| 1665 |
+
)
|
| 1666 |
+
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
|
| 1667 |
+
self.get_name(),
|
| 1668 |
+
self.python_kernel_name,
|
| 1669 |
+
self.cpp_kernel_name,
|
| 1670 |
+
codegen_args,
|
| 1671 |
+
self.cpp_op_schema,
|
| 1672 |
+
self.cpp_kernel_key,
|
| 1673 |
+
self.cpp_kernel_overload_name,
|
| 1674 |
+
self.op_overload,
|
| 1675 |
+
raw_args,
|
| 1676 |
+
)
|
| 1677 |
+
if isinstance(self.layout, Layout):
|
| 1678 |
+
self.codegen_size_asserts(wrapper)
|
| 1679 |
+
|
| 1680 |
+
def get_mutation_names(self):
|
| 1681 |
+
binary_post_op = self.constant_args[-5]
|
| 1682 |
+
if binary_post_op == "sum":
|
| 1683 |
+
return [self.inputs[-1].get_name()]
|
| 1684 |
+
else:
|
| 1685 |
+
return []
|
| 1686 |
+
|
| 1687 |
+
@classmethod
|
| 1688 |
+
def create(
|
| 1689 |
+
cls,
|
| 1690 |
+
qx: "TensorBox",
|
| 1691 |
+
x_scale: float,
|
| 1692 |
+
x_zero_point: int,
|
| 1693 |
+
qw: "TensorBox", # packed_weight
|
| 1694 |
+
w_scale: "TensorBox",
|
| 1695 |
+
w_zero_point: "TensorBox",
|
| 1696 |
+
other: "TensorBox",
|
| 1697 |
+
bias: "TensorBox",
|
| 1698 |
+
output_scale: float,
|
| 1699 |
+
output_zero_point: int,
|
| 1700 |
+
output_dtype,
|
| 1701 |
+
other_scale,
|
| 1702 |
+
other_zp,
|
| 1703 |
+
binary_post_op,
|
| 1704 |
+
binary_alpha,
|
| 1705 |
+
unary_post_op,
|
| 1706 |
+
unary_post_op_args,
|
| 1707 |
+
unary_post_op_algorithm,
|
| 1708 |
+
):
|
| 1709 |
+
(
|
| 1710 |
+
inputs,
|
| 1711 |
+
constant_args,
|
| 1712 |
+
kernel_layout,
|
| 1713 |
+
req_stride_order,
|
| 1714 |
+
) = _prepare_linear_fusion_create(
|
| 1715 |
+
cls,
|
| 1716 |
+
qx,
|
| 1717 |
+
qw,
|
| 1718 |
+
bias,
|
| 1719 |
+
)
|
| 1720 |
+
|
| 1721 |
+
if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox):
|
| 1722 |
+
x_scale.realize()
|
| 1723 |
+
x_zero_point.realize()
|
| 1724 |
+
inputs = inputs + [x_scale, x_zero_point]
|
| 1725 |
+
x_scale_zp_are_tensors = True
|
| 1726 |
+
else:
|
| 1727 |
+
assert isinstance(x_scale, float) and isinstance(x_zero_point, int)
|
| 1728 |
+
constant_args = constant_args + [x_scale, x_zero_point]
|
| 1729 |
+
x_scale_zp_are_tensors = False
|
| 1730 |
+
w_scale.realize()
|
| 1731 |
+
w_zero_point.realize()
|
| 1732 |
+
inputs = inputs + [w_scale, w_zero_point]
|
| 1733 |
+
if binary_post_op == "sum":
|
| 1734 |
+
other = cls.require_stride_order(other, req_stride_order)
|
| 1735 |
+
inputs.append(other)
|
| 1736 |
+
constant_args = constant_args + [
|
| 1737 |
+
output_scale,
|
| 1738 |
+
output_zero_point,
|
| 1739 |
+
output_dtype,
|
| 1740 |
+
other_scale,
|
| 1741 |
+
other_zp,
|
| 1742 |
+
binary_post_op,
|
| 1743 |
+
binary_alpha,
|
| 1744 |
+
unary_post_op,
|
| 1745 |
+
may_convert_to_optional(unary_post_op_args),
|
| 1746 |
+
unary_post_op_algorithm,
|
| 1747 |
+
]
|
| 1748 |
+
|
| 1749 |
+
if binary_post_op == "sum":
|
| 1750 |
+
V.graph.mark_buffer_mutated(other.get_name())
|
| 1751 |
+
packed = QLinearPointwiseBinaryPT2E(
|
| 1752 |
+
layout=NoneLayout(other.get_device()),
|
| 1753 |
+
inputs=inputs,
|
| 1754 |
+
constant_args=constant_args,
|
| 1755 |
+
has_bias=(bias is not None),
|
| 1756 |
+
x_scale_zp_are_tensors=x_scale_zp_are_tensors,
|
| 1757 |
+
)
|
| 1758 |
+
# Return other since it has been inplace changed.
|
| 1759 |
+
return packed.inputs[-1]
|
| 1760 |
+
|
| 1761 |
+
assert output_dtype is not None
|
| 1762 |
+
if output_dtype in [torch.float32, torch.bfloat16]:
|
| 1763 |
+
# in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout
|
| 1764 |
+
# if we set fp32_output, the output buf should be dtype float32 instead of uint8.
|
| 1765 |
+
kernel_layout.dtype = output_dtype
|
| 1766 |
+
|
| 1767 |
+
return QLinearPointwiseBinaryPT2E(
|
| 1768 |
+
layout=kernel_layout,
|
| 1769 |
+
inputs=inputs,
|
| 1770 |
+
constant_args=constant_args,
|
| 1771 |
+
has_bias=(bias is not None),
|
| 1772 |
+
x_scale_zp_are_tensors=x_scale_zp_are_tensors,
|
| 1773 |
+
)
|
| 1774 |
+
|
| 1775 |
+
|
| 1776 |
+
class MkldnnRnnLayer(ExternKernelAlloc):
|
| 1777 |
+
def __init__(
|
| 1778 |
+
self,
|
| 1779 |
+
layout,
|
| 1780 |
+
inputs,
|
| 1781 |
+
constant_args=(),
|
| 1782 |
+
) -> None:
|
| 1783 |
+
super().__init__(
|
| 1784 |
+
layout,
|
| 1785 |
+
inputs,
|
| 1786 |
+
constant_args,
|
| 1787 |
+
None,
|
| 1788 |
+
op_overload=torch.ops.aten.mkldnn_rnn_layer.default,
|
| 1789 |
+
)
|
| 1790 |
+
|
| 1791 |
+
@classmethod
|
| 1792 |
+
def create(
|
| 1793 |
+
cls,
|
| 1794 |
+
x: "TensorBox",
|
| 1795 |
+
w0: "TensorBox",
|
| 1796 |
+
w1: "TensorBox",
|
| 1797 |
+
w2: "TensorBox",
|
| 1798 |
+
w3: "TensorBox",
|
| 1799 |
+
hx: "TensorBox",
|
| 1800 |
+
cx: "TensorBox",
|
| 1801 |
+
reverse: bool,
|
| 1802 |
+
batch_sizes: List[int],
|
| 1803 |
+
mode: int,
|
| 1804 |
+
hidden_size: int,
|
| 1805 |
+
num_layers: int,
|
| 1806 |
+
has_biases: bool,
|
| 1807 |
+
bidirectional: bool,
|
| 1808 |
+
batch_first: bool,
|
| 1809 |
+
train: bool,
|
| 1810 |
+
):
|
| 1811 |
+
x = cls.require_stride1(cls.realize_input(x))
|
| 1812 |
+
# If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer.
|
| 1813 |
+
# Make sure x is contiguous in batch_first case.
|
| 1814 |
+
x.freeze_layout()
|
| 1815 |
+
w0 = cls.require_stride1(cls.realize_input(w0))
|
| 1816 |
+
w1 = cls.require_stride1(cls.realize_input(w1))
|
| 1817 |
+
w2 = cls.require_stride1(cls.realize_input(w2))
|
| 1818 |
+
w3 = cls.require_stride1(cls.realize_input(w3))
|
| 1819 |
+
hx = cls.require_stride1(cls.realize_input(hx))
|
| 1820 |
+
hx.freeze_layout()
|
| 1821 |
+
cx = cls.require_stride1(cls.realize_input(cx))
|
| 1822 |
+
cx.freeze_layout()
|
| 1823 |
+
|
| 1824 |
+
input_size = x.get_size()
|
| 1825 |
+
assert len(input_size) == 3, "Expect lstm input to be 3D"
|
| 1826 |
+
# batch_first is handled in the lstm OP. When entering
|
| 1827 |
+
# rnn_layer here, we'll always have batch_first = False
|
| 1828 |
+
seq_length, mini_batch, input_size = input_size
|
| 1829 |
+
output_shape = [seq_length, mini_batch, hidden_size]
|
| 1830 |
+
|
| 1831 |
+
hy_shape = hx.get_size()
|
| 1832 |
+
cy_shape = cx.get_size()
|
| 1833 |
+
|
| 1834 |
+
res: List[IRNode] = []
|
| 1835 |
+
|
| 1836 |
+
inputs = [x, w0, w1, w2, w3, hx, cx]
|
| 1837 |
+
constant_args = [
|
| 1838 |
+
reverse,
|
| 1839 |
+
batch_sizes,
|
| 1840 |
+
mode,
|
| 1841 |
+
hidden_size,
|
| 1842 |
+
num_layers,
|
| 1843 |
+
has_biases,
|
| 1844 |
+
bidirectional,
|
| 1845 |
+
batch_first,
|
| 1846 |
+
train,
|
| 1847 |
+
]
|
| 1848 |
+
|
| 1849 |
+
packed = MkldnnRnnLayer(
|
| 1850 |
+
MultiOutputLayout(x.get_device()),
|
| 1851 |
+
inputs=inputs,
|
| 1852 |
+
constant_args=constant_args,
|
| 1853 |
+
)
|
| 1854 |
+
|
| 1855 |
+
def get_strides_of_lstm_output(output_shape, batch_first):
|
| 1856 |
+
assert len(output_shape) == 3, "Expect output_shape to be 3D"
|
| 1857 |
+
return FlexibleLayout.contiguous_strides(output_shape)
|
| 1858 |
+
|
| 1859 |
+
output_sizes = [output_shape, hy_shape, cy_shape]
|
| 1860 |
+
output_strides = [
|
| 1861 |
+
get_strides_of_lstm_output(output_shape, batch_first),
|
| 1862 |
+
FlexibleLayout.contiguous_strides(hy_shape),
|
| 1863 |
+
FlexibleLayout.contiguous_strides(cy_shape),
|
| 1864 |
+
]
|
| 1865 |
+
output_ir = [
|
| 1866 |
+
MultiOutput(
|
| 1867 |
+
FixedLayout(
|
| 1868 |
+
x.get_device(),
|
| 1869 |
+
x.get_dtype(),
|
| 1870 |
+
output_size,
|
| 1871 |
+
output_stride,
|
| 1872 |
+
),
|
| 1873 |
+
packed,
|
| 1874 |
+
[(tuple, i)],
|
| 1875 |
+
)
|
| 1876 |
+
for i, (output_size, output_stride) in enumerate(
|
| 1877 |
+
zip(output_sizes, output_strides)
|
| 1878 |
+
)
|
| 1879 |
+
]
|
| 1880 |
+
|
| 1881 |
+
return output_ir
|
.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_lowerings.py
ADDED
|
@@ -0,0 +1,1087 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
import functools
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.utils._pytree as pytree
|
| 8 |
+
from torch._inductor.kernel.mm_common import mm_args
|
| 9 |
+
|
| 10 |
+
from . import ir
|
| 11 |
+
from .codegen.cpp_gemm_template import CppPackedGemmTemplate
|
| 12 |
+
from .codegen.cpp_utils import create_epilogue_with_attr
|
| 13 |
+
from .ir import TensorBox
|
| 14 |
+
from .lowering import (
|
| 15 |
+
add,
|
| 16 |
+
add_needs_realized_inputs,
|
| 17 |
+
aten,
|
| 18 |
+
permute,
|
| 19 |
+
register_lowering,
|
| 20 |
+
to_dtype,
|
| 21 |
+
view,
|
| 22 |
+
)
|
| 23 |
+
from .select_algorithm import (
|
| 24 |
+
autotune_select_algorithm,
|
| 25 |
+
ChoiceCaller,
|
| 26 |
+
ExternKernelChoice,
|
| 27 |
+
)
|
| 28 |
+
from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune
|
| 29 |
+
from .virtualized import ops, V
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def register_onednn_fusion_ops():
|
| 33 |
+
if torch._C._has_mkldnn:
|
| 34 |
+
from . import mkldnn_ir
|
| 35 |
+
|
| 36 |
+
aten_mkldnn_linear_unary = ExternKernelChoice(
|
| 37 |
+
torch.ops.mkldnn._linear_pointwise,
|
| 38 |
+
"mkldnn::_linear_pointwise",
|
| 39 |
+
has_out_variant=False,
|
| 40 |
+
kernel_creator=mkldnn_ir.LinearUnary.create,
|
| 41 |
+
)
|
| 42 |
+
aten_mkldnn_linear_binary = ExternKernelChoice(
|
| 43 |
+
torch.ops.mkldnn._linear_pointwise.binary,
|
| 44 |
+
"mkldnn::_linear_pointwise",
|
| 45 |
+
has_out_variant=False,
|
| 46 |
+
kernel_creator=mkldnn_ir.LinearBinary.create,
|
| 47 |
+
)
|
| 48 |
+
aten_mkldnn_qlinear_unary = ExternKernelChoice(
|
| 49 |
+
torch.ops.onednn.qlinear_pointwise,
|
| 50 |
+
"onednn::qlinear_pointwise",
|
| 51 |
+
has_out_variant=False,
|
| 52 |
+
kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create,
|
| 53 |
+
)
|
| 54 |
+
aten_mkldnn_qlinear_binary = ExternKernelChoice(
|
| 55 |
+
torch.ops.onednn.qlinear_pointwise.binary,
|
| 56 |
+
"onednn::qlinear_pointwise",
|
| 57 |
+
has_out_variant=False,
|
| 58 |
+
kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create,
|
| 59 |
+
)
|
| 60 |
+
cpu_needs_realized_inputs = [
|
| 61 |
+
torch.ops.mkldnn._convolution_pointwise,
|
| 62 |
+
torch.ops.mkldnn._convolution_pointwise_,
|
| 63 |
+
torch.ops.mkldnn._convolution_transpose_pointwise,
|
| 64 |
+
torch.ops.mkldnn._linear_pointwise,
|
| 65 |
+
aten.mkldnn_rnn_layer.default,
|
| 66 |
+
torch.ops.onednn.qconv2d_pointwise,
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
@register_lowering(torch.ops.mkldnn._convolution_pointwise)
def convolution_unary(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    padding,
    stride,
    dilation,
    groups,
    attr,
    scalars,
    algorithm,
):
    """Lower ``mkldnn._convolution_pointwise``: a convolution fused with a
    unary post-op (``attr``/``scalars``/``algorithm``) into a single IR node.
    """
    fused_node = mkldnn_ir.ConvolutionUnary.create(
        x,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    )
    # Wrap the kernel IR node so downstream lowerings see a TensorBox.
    return TensorBox.create(fused_node)
|
| 96 |
+
|
| 97 |
+
@register_lowering(torch.ops.mkldnn._convolution_pointwise.binary)
def convolution_binary(
    x: TensorBox,
    other: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    padding,
    stride,
    dilation,
    groups,
    binary_attr,
    binary_alpha,
    unary_attr,
    unary_scalars,
    unary_algorithm,
):
    """Lower ``mkldnn._convolution_pointwise.binary``: convolution fused with
    a binary op against ``other`` plus an optional unary post-op.
    """
    fused_node = mkldnn_ir.ConvolutionBinary.create(
        x,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        binary_alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    )
    # Wrap the kernel IR node so downstream lowerings see a TensorBox.
    return TensorBox.create(fused_node)
|
| 130 |
+
|
| 131 |
+
@register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary)
def convolution_binary_inplace(
    x: TensorBox,
    other: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    padding,
    stride,
    dilation,
    groups,
    binary_attr,
    binary_alpha,
    unary_attr,
    unary_scalars,
    unary_algorithm,
):
    """Lower the in-place variant of the fused conv + binary + unary op.

    ``other`` is updated in place by the kernel (note the trailing underscore
    in the op name).
    """
    fused_node = mkldnn_ir.ConvolutionBinaryInplace.create(
        x,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        binary_alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    )
    # Wrap the kernel IR node so downstream lowerings see a TensorBox.
    return TensorBox.create(fused_node)
|
| 164 |
+
|
| 165 |
+
@register_lowering(torch.ops.mkldnn._linear_pointwise)
def linear_unary(
    x: TensorBox,
    w: TensorBox,
    b: TensorBox,
    attr,
    scalars,
    algorithm,
    layout=None,
):
    """Lower ``mkldnn._linear_pointwise`` (linear + unary post-op).

    With max-autotune enabled, a CPP packed GEMM template choice is added and
    autotuned against the ATen mkldnn extern kernel; otherwise only the extern
    kernel is used. Inputs with >2 dims are flattened to 2D for the GEMM and
    the result is reshaped back.
    """
    x_size = x.get_size()
    if len(x_size) > 2:
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
    if b is not None:
        b = ir.ExternKernel.realize_input(b)
    choices: List[ChoiceCaller] = []
    if use_max_autotune():
        # mm_args expects (M, K) @ (K, N); the mkldnn weight is stored
        # transposed, so flip it before shape/layout inference. Note that
        # mm_args may replace x/layout with realized/normalized versions.
        transposed_w = permute(w, [1, 0])
        *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
        if use_cpp_packed_gemm_template(layout, x, transposed_w):

            def epilogue_creator(buf):
                # Fuse the unary post-op into the GEMM template epilogue.
                return create_epilogue_with_attr(
                    buf, attr, scalars=scalars, algorithm=algorithm
                )

            kwargs = dict(
                has_bias=b is not None,
                trans_w=True,
                epilogue_creator=None if attr == "none" else epilogue_creator,
            )
            if b is not None:
                # Template expects bias first: reorder (x, w, b) -> (b, x, w).
                kwargs["input_indices"] = [2, 0, 1]  # type: ignore[assignment]
            CppPackedGemmTemplate.add_choices(
                choices,
                layout,
                [x, w] if b is None else [x, w, b],
                **kwargs,  # type: ignore[arg-type]
            )
    if len(choices) == 0 or use_aten_gemm_kernels():
        # Fallback (and autotune competitor): the ATen mkldnn extern kernel.
        kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm)
        if b is None:
            kwargs["B"] = None
        choices.append(
            aten_mkldnn_linear_unary.bind(
                [x, w] if b is None else [x, w, b],
                layout,
                **kwargs,
            )
        )
    # The packed weight must be a graph constant so benchmarking can feed the
    # real tensor to each candidate kernel.
    assert w.get_name() in V.graph.constants
    input_gen_fns = {
        1: lambda x: V.graph.constants[x.get_name()],
    }
    result = autotune_select_algorithm(
        "linear_unary",
        choices,
        [x, w] if b is None else [x, w, b],
        layout,
        input_gen_fns=input_gen_fns,
    )
    if len(x_size) > 2:
        # Restore the original leading dimensions that were flattened above.
        result = view(result, (*x_size[:-1], result.get_size()[-1]))
    return result
|
| 230 |
+
|
| 231 |
+
@register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
def linear_binary(
    x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None
):
    """Lower ``mkldnn._linear_pointwise.binary`` (linear + binary post-op
    against ``y``).

    Mirrors ``linear_unary``: optionally autotunes a CPP packed GEMM template
    (with the binary op fused into its epilogue) against the ATen mkldnn
    extern kernel. Inputs with >2 dims are flattened to 2D and reshaped back.
    """
    x_size = x.get_size()
    if len(x_size) > 2:
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
    y_size = y.get_size()
    if len(y_size) > 2:
        y = view(y, [-1, y_size[-1]])
    if b is not None:
        b = ir.ExternKernel.realize_input(b)
    choices: List[ChoiceCaller] = []
    if use_max_autotune():
        # mkldnn stores the weight transposed; flip it for mm_args' (M,K)@(K,N)
        # shape inference. mm_args may replace x/y/layout with realized forms.
        transposed_w = permute(w, [1, 0])
        *_, layout, x, transposed_w, y = mm_args(
            x, transposed_w, y, layout=layout
        )
        if use_cpp_packed_gemm_template(layout, x, transposed_w):

            def epilogue_creator(buf):
                # Fuse the binary op with `y` into the GEMM epilogue.
                return create_epilogue_with_attr(buf, attr, other=y)

            kwargs = dict(
                has_bias=b is not None,
                trans_w=True,
                epilogue_creator=epilogue_creator,
            )
            # Template input order: (b,) x, w, y — see input reorderings.
            kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
            CppPackedGemmTemplate.add_choices(
                choices,
                layout,
                [x, y, w] if b is None else [x, y, w, b],
                **kwargs,  # type: ignore[arg-type]
            )
    if len(choices) == 0 or use_aten_gemm_kernels():
        # Fallback (and autotune competitor): the ATen mkldnn extern kernel.
        kwargs = dict(attr=attr)
        if b is None:
            kwargs["B"] = None
        choices.append(
            aten_mkldnn_linear_binary.bind(
                [x, y, w] if b is None else [x, y, w, b],
                layout,
                **kwargs,
            )
        )
    # Weight must be a graph constant so benchmarking can feed the real tensor.
    assert w.get_name() in V.graph.constants
    input_gen_fns = {
        2: lambda x: V.graph.constants[x.get_name()],
    }
    result = autotune_select_algorithm(
        "linear_binary",
        choices,
        [x, y, w] if b is None else [x, y, w, b],
        layout,
        input_gen_fns=input_gen_fns,
    )
    if len(x_size) > 2:
        # Restore the original leading dimensions that were flattened above.
        result = view(result, (*x_size[:-1], result.get_size()[-1]))
    return result
|
| 292 |
+
|
| 293 |
+
@register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
def convolution_transpose_unary(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    padding,
    output_padding,
    stride,
    dilation,
    groups,
    attr,
    scalars,
    algorithm,
):
    """Lower ``mkldnn._convolution_transpose_pointwise``: a transposed
    convolution fused with a unary post-op into a single IR node.
    """
    fused_node = mkldnn_ir.ConvolutionTransposeUnary.create(
        x,
        weight,
        bias,
        padding,
        output_padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    )
    # Wrap the kernel IR node so downstream lowerings see a TensorBox.
    return TensorBox.create(fused_node)
|
| 322 |
+
|
| 323 |
+
@register_lowering(aten.mkldnn_rnn_layer.default)
def mkldnn_rnn_layer(
    x: TensorBox,
    w0: TensorBox,
    w1: TensorBox,
    w2: TensorBox,
    w3: TensorBox,
    hx: TensorBox,
    cx: TensorBox,
    reverse: bool,
    batch_sizes: List[int],
    mode: int,
    hidden_size: int,
    num_layers: int,
    has_biases: bool,
    bidirectional: bool,
    batch_first: bool,
    train: bool,
):
    """Lower ``aten.mkldnn_rnn_layer`` to the mkldnn RNN-layer IR node.

    The IR node produces multiple outputs; each one is wrapped in a TensorBox
    via ``pytree.tree_map``.
    """
    outputs = mkldnn_ir.MkldnnRnnLayer.create(
        x,
        w0,
        w1,
        w2,
        w3,
        hx,
        cx,
        reverse,
        batch_sizes,
        mode,
        hidden_size,
        num_layers,
        has_biases,
        bidirectional,
        batch_first,
        train,
    )
    return pytree.tree_map(TensorBox.create, outputs)
|
| 363 |
+
|
| 364 |
+
@register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None)
def qconvolution_unary(
    x: TensorBox,
    x_scale,
    x_zp,
    packed_weight: TensorBox,
    w_scale: TensorBox,
    w_zp: TensorBox,
    bias: TensorBox,
    stride,
    padding,
    dilation,
    groups,
    o_inv_scale,
    o_zero_point,
    output_dtype,
    attr,
    scalars,
    algorithm,
):
    """Lower ``onednn.qconv2d_pointwise``: quantized convolution fused with a
    unary post-op into a single PT2E IR node.
    """
    fused_node = mkldnn_ir.QConvPointWisePT2E.create(
        x,
        x_scale,
        x_zp,
        packed_weight,
        w_scale,
        w_zp,
        bias,
        stride,
        padding,
        dilation,
        groups,
        o_inv_scale,
        o_zero_point,
        output_dtype,
        attr,
        scalars,
        algorithm,
    )
    # Wrap the kernel IR node so downstream lowerings see a TensorBox.
    return TensorBox.create(fused_node)
|
| 405 |
+
|
| 406 |
+
@register_lowering(
    torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None
)
def qconvolution_binary(
    x: TensorBox,
    x_scale,
    x_zp,
    accum: TensorBox,
    accum_scale,
    accum_zp,
    packed_weight: TensorBox,
    w_scale: TensorBox,
    w_zp: TensorBox,
    bias: TensorBox,
    stride,
    padding,
    dilation,
    groups,
    o_inv_scale,
    o_zero_point,
    output_dtype,
    binary_attr,
    alpha,
    unary_attr,
    unary_scalars,
    unary_algorithmm,
):
    """Lower ``onednn.qconv2d_pointwise.binary``: quantized convolution fused
    with a binary op on ``accum`` and an optional unary post-op.
    """
    float_dtypes = (torch.float32, torch.bfloat16)
    accum_dtype = accum.get_dtype()
    if (
        binary_attr == "sum"
        and output_dtype in float_dtypes
        and accum_dtype in float_dtypes
        and accum_dtype != output_dtype
    ):
        # For int8-mixed-bf16 quantization and inplace add, the accum may be
        # fp32 while the output is bf16. Post-op "sum" updates accum in
        # place, so convert accum to the output dtype up front.
        accum = to_dtype(accum, output_dtype)
    fused_node = mkldnn_ir.QConvPointWiseBinaryPT2E.create(
        x,
        x_scale,
        x_zp,
        accum,
        accum_scale,
        accum_zp,
        packed_weight,
        w_scale,
        w_zp,
        bias,
        stride,
        padding,
        dilation,
        groups,
        o_inv_scale,
        o_zero_point,
        output_dtype,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithmm,
    )
    return TensorBox.create(fused_node)
|
| 470 |
+
|
| 471 |
+
@register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None)
def qlinear_unary(
    x: TensorBox,
    x_scale,
    x_zp,
    packed_weight: TensorBox,
    w_scale: TensorBox,
    w_zp: TensorBox,
    bias: TensorBox,
    o_scale,
    o_zero_point,
    output_dtype,
    attr,
    scalars,
    algorithm,
    layout=None,
):
    """Lower ``onednn.qlinear_pointwise`` (quantized linear + unary post-op).

    With max-autotune, a CPP packed GEMM template is added whose epilogue
    performs the int32 -> fp32 dequant compensation, bias add, unary post-op,
    and final cast/requant; it is autotuned against the ATen onednn extern
    kernel. Inputs with >2 dims are flattened to 2D and reshaped back.
    """
    x_size = x.get_size()
    if len(x_size) > 2:
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
    # Normalize scalar scale/zero-point into graph constants so the epilogue
    # can load them like any other tensor input.
    if not isinstance(x_scale, ir.TensorBox):
        assert type(x_scale) == float
        x_scale = V.graph.add_tensor_constant(
            torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
        )
    else:
        x_scale.realize()
    if not isinstance(x_zp, ir.TensorBox):
        assert type(x_zp) == int
        x_zp = V.graph.add_tensor_constant(
            torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
        )
    else:
        x_zp.realize()

    # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer
    # Refer to https://github.com/pytorch/pytorch/blob
    # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
    w_scale.realize()
    w_zp.realize()
    if w_zp.get_dtype() != torch.int32 and isinstance(
        ir.InputsKernel.unwrap_storage_for_input(w_zp),
        ir.ConstantBuffer,
    ):
        # W_zp might be a ConstantBuffer with int64, convert it to int32
        w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
        w_zp = V.graph.add_tensor_constant(
            torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
        )

    bias_dtype = None if bias is None else bias.get_dtype()

    choices: List[ChoiceCaller] = []
    if use_max_autotune():
        # mm_args may replace x/layout with realized/normalized versions.
        *_, layout, x, packed_weight = mm_args(
            x, packed_weight, layout=layout, out_dtype=output_dtype
        )
        if (
            isinstance(
                ir.InputsKernel.unwrap_storage_for_input(x_zp),
                ir.ConstantBuffer,
            )
            and len(x_zp.get_layout().size) == 0  # Per tensor quant of act
            and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            )
            and torch.equal(
                torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                V.graph.constants[w_zp.get_name()],
            )  # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
            and use_cpp_packed_gemm_template(layout, x, packed_weight)
        ):
            # Per-column sum of B, used for the x_zp compensation term below.
            W_tensor = V.graph.constants[packed_weight.get_name()].to_dense()
            weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
            weight_compens = V.graph.add_tensor_constant(
                weight_compens_tensor,
                name=packed_weight.get_name() + "_BMatrixCompens",
            )

            def epilogue_creator(input_buffer):
                # Epilogue to convert from s32 to f32 for u8s8f32
                assert output_dtype in [
                    torch.float32,
                    torch.bfloat16,
                    torch.uint8,
                ]
                input_loader = input_buffer.make_loader()
                weight_compens_loader = weight_compens.make_loader()
                x_scale_loader = x_scale.make_loader()
                w_scale_loader = w_scale.make_loader()
                x_zp_loader = x_zp.make_loader()
                nonlocal bias
                bias_loader = None
                if bias is not None:
                    bias_loader = bias.make_loader()

                def inner_fn(index):
                    nonlocal bias
                    input = input_loader(index)
                    # MicroKernel Output is with int32
                    # cvt to FP32 before doing compensation
                    input = ops.to_dtype(input, torch.float32)
                    # Per-output-channel index (last dim) for w_scale/compens.
                    weight_compens_index = (index[-1],)
                    _x_scale = x_scale_loader(())
                    _x_zp = x_zp_loader(())
                    _w_scale = w_scale_loader(weight_compens_index)
                    _weight_compo = weight_compens_loader(weight_compens_index)
                    # Step 1: Doing compensation to cvt fp32
                    # out_f32 = s32 * x_scale * w_scale
                    #           - x_scale * w_scale * x_zp * sum(B[:, n])
                    temp = ops.mul(
                        ops.mul(
                            input,
                            _x_scale,
                        ),
                        _w_scale,
                    )
                    temp = ops.sub(
                        temp,
                        ops.mul(
                            ops.mul(
                                ops.mul(
                                    _x_scale,
                                    _w_scale,
                                ),
                                _x_zp,
                            ),
                            _weight_compo,
                        ),
                    )
                    # Step 2: add Bias if applicable
                    if bias is not None:
                        _bias = bias_loader(weight_compens_index)
                        nonlocal bias_dtype
                        assert bias_dtype in [torch.float32, torch.bfloat16]
                        if bias_dtype == torch.bfloat16:
                            _bias = ops.to_dtype(_bias, torch.float32)
                        temp = ops.add(temp, _bias)

                    return temp

                output_buf = ir.Pointwise(
                    device=input_buffer.get_device(),
                    dtype=torch.float32,  # Hardcode to FP32 for u8s8f32
                    inner_fn=inner_fn,
                    ranges=input_buffer.get_size(),
                )

                # Step 3: Doing the unary post op fusion
                if attr != "none":
                    output_buf = create_epilogue_with_attr(
                        output_buf, attr, scalars=scalars, algorithm=algorithm
                    )

                # Step 4: Cast output to Target Dtype
                if output_dtype == torch.bfloat16:
                    output_cast_loader = output_buf.make_loader()

                    def inner_fn_cast_output_to_bf16(index):
                        input = output_cast_loader(index)
                        return ops.to_dtype(input, output_dtype)

                    output_buf = ir.Pointwise(
                        device=output_buf.get_device(),
                        dtype=output_dtype,
                        inner_fn=inner_fn_cast_output_to_bf16,
                        ranges=output_buf.get_size(),
                    )
                elif output_dtype == torch.uint8:
                    from .lowering import _create_constants

                    requant_input_loader = output_buf.make_loader()

                    # Requantize fp32 -> u8: round(x / scale) + zp, clamped
                    # to the uint8 range [0, 255].
                    def inner_fn_requant(index, scale, zero_point):
                        input = requant_input_loader(index)
                        inv_scale, zero_point = _create_constants(
                            1.0 / scale, zero_point, dtype=torch.float32
                        )
                        val = ops.round(input * inv_scale) + zero_point
                        qmin, qmax = _create_constants(
                            0, 255, dtype=torch.float32
                        )
                        clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                        return ops.to_dtype(clamped, torch.uint8)

                    output_buf = ir.Pointwise(
                        device=output_buf.get_device(),
                        dtype=output_dtype,
                        inner_fn=functools.partial(
                            inner_fn_requant,
                            scale=float(o_scale),
                            zero_point=int(o_zero_point),
                        ),
                        ranges=output_buf.get_size(),
                    )

                return output_buf

            assert x.get_dtype() == torch.uint8
            CppPackedGemmTemplate.add_choices(
                choices,
                layout,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                has_bias=bias is not None,
                epilogue_creator=epilogue_creator,
                input_indices=[0, 3, 1, 2, 4, 5]
                if bias is None
                else [6, 0, 3, 1, 2, 4, 5],
            )
    if len(choices) == 0 or use_aten_gemm_kernels():
        # Fallback (and autotune competitor): the ATen onednn extern kernel.
        kwargs = dict(
            output_scale=o_scale,
            output_zero_point=o_zero_point,
            output_dtype=output_dtype,
            post_op_name=attr,
            post_op_args=scalars,
            post_op_algorithm=algorithm,
        )
        if bias is None:
            kwargs["bias"] = None
        choices.append(
            aten_mkldnn_qlinear_unary.bind(
                (x, x_scale, x_zp, packed_weight, w_scale, w_zp)
                if bias is None
                else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias),
                layout,
                **kwargs,
            )
        )
    # Weight must be a graph constant so benchmarking can feed real tensors
    # for the packed weight, scales, zero points, and bias.
    assert packed_weight.get_name() in V.graph.constants
    input_gen_fns = {
        3: lambda x: V.graph.constants[x.get_name()],
        4: lambda x: V.graph.constants[x.get_name()],
        5: lambda x: V.graph.constants[x.get_name()],
        6: lambda x: V.graph.constants[x.get_name()],  # For bias
    }
    result = autotune_select_algorithm(
        "qlinear_unary",
        choices,
        [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
        if bias is None
        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
        layout,
        input_gen_fns=input_gen_fns,
    )
    if len(x_size) > 2:
        # Restore the original leading dimensions that were flattened above.
        result = view(result, (*x_size[:-1], result.get_size()[-1]))
    return result
|
| 721 |
+
|
| 722 |
+
@register_lowering(
|
| 723 |
+
torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None
|
| 724 |
+
)
|
| 725 |
+
@register_lowering(
|
| 726 |
+
torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None
|
| 727 |
+
)
|
| 728 |
+
def qlinear_binary(
|
| 729 |
+
x: TensorBox,
|
| 730 |
+
x_scale,
|
| 731 |
+
x_zp,
|
| 732 |
+
packed_weight: TensorBox,
|
| 733 |
+
w_scale: TensorBox,
|
| 734 |
+
w_zp: TensorBox,
|
| 735 |
+
x2: TensorBox,
|
| 736 |
+
bias: TensorBox,
|
| 737 |
+
o_scale,
|
| 738 |
+
o_zero_point,
|
| 739 |
+
output_dtype,
|
| 740 |
+
x2_scale,
|
| 741 |
+
x2_zp,
|
| 742 |
+
binary_attr,
|
| 743 |
+
alpha,
|
| 744 |
+
unary_attr,
|
| 745 |
+
unary_scalars,
|
| 746 |
+
unary_algorithmm,
|
| 747 |
+
layout=None,
|
| 748 |
+
):
|
| 749 |
+
x_size = x.get_size()
|
| 750 |
+
x2_size = x2.get_size()
|
| 751 |
+
assert len(x_size) == len(x2_size)
|
| 752 |
+
if len(x_size) > 2 and binary_attr == "add":
|
| 753 |
+
# GEMM template needs 2D input, normalize input shape here
|
| 754 |
+
x = view(x, [-1, x_size[-1]])
|
| 755 |
+
x2 = view(x2, [-1, x2_size[-1]])
|
| 756 |
+
if not isinstance(x_scale, ir.TensorBox):
|
| 757 |
+
assert type(x_scale) == float
|
| 758 |
+
x_scale = V.graph.add_tensor_constant(
|
| 759 |
+
torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
|
| 760 |
+
)
|
| 761 |
+
else:
|
| 762 |
+
x_scale.realize()
|
| 763 |
+
if not isinstance(x_zp, ir.TensorBox):
|
| 764 |
+
assert type(x_zp) == int
|
| 765 |
+
x_zp = V.graph.add_tensor_constant(
|
| 766 |
+
torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
|
| 767 |
+
)
|
| 768 |
+
else:
|
| 769 |
+
x_zp.realize()
|
| 770 |
+
|
| 771 |
+
# When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer
|
| 772 |
+
# Refer to https://github.com/pytorch/pytorch/blob
|
| 773 |
+
# /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
|
| 774 |
+
w_scale.realize()
|
| 775 |
+
w_zp.realize()
|
| 776 |
+
if w_zp.get_dtype() != torch.int32 and isinstance(
|
| 777 |
+
ir.InputsKernel.unwrap_storage_for_input(w_zp),
|
| 778 |
+
ir.ConstantBuffer,
|
| 779 |
+
):
|
| 780 |
+
w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
|
| 781 |
+
w_zp = V.graph.add_tensor_constant(
|
| 782 |
+
torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
|
| 783 |
+
)
|
| 784 |
+
if binary_attr == "sum":
|
| 785 |
+
if output_dtype in [
|
| 786 |
+
torch.float32,
|
| 787 |
+
torch.bfloat16,
|
| 788 |
+
] and x2.get_dtype() in [torch.float32, torch.bfloat16]:
|
| 789 |
+
if x2.get_dtype() != output_dtype:
|
| 790 |
+
# For int8-mixed-bf16 quantization and inplace add,
|
| 791 |
+
# there is case when accum dtype is float32 but output dtype is bfloat16.
|
| 792 |
+
# Since the accum will be inplaced changed with post op sum,
|
| 793 |
+
# we will do accum dtype convertion here.
|
| 794 |
+
x2 = to_dtype(x2, output_dtype)
|
| 795 |
+
else:
|
| 796 |
+
assert (
|
| 797 |
+
x2.get_dtype() == output_dtype
|
| 798 |
+
), "dtype of accum for qlinear post op sum should be the same as output"
|
| 799 |
+
x2_dtype = x2.get_dtype()
|
| 800 |
+
bias_dtype = bias.get_dtype() if bias is not None else None
|
| 801 |
+
choices: List[ChoiceCaller] = []
|
| 802 |
+
if (
|
| 803 |
+
use_max_autotune() and binary_attr == "add"
|
| 804 |
+
): # <TODO> Support inplace sum fusion
|
| 805 |
+
*_, layout, x, packed_weight, x2 = mm_args(
|
| 806 |
+
x, packed_weight, x2, layout=layout, out_dtype=output_dtype
|
| 807 |
+
)
|
| 808 |
+
if (
|
| 809 |
+
isinstance(
|
| 810 |
+
ir.InputsKernel.unwrap_storage_for_input(x_zp),
|
| 811 |
+
ir.ConstantBuffer,
|
| 812 |
+
)
|
| 813 |
+
and len(x_zp.get_layout().size) == 0 # Per tensor quant of act
|
| 814 |
+
and isinstance(
|
| 815 |
+
ir.InputsKernel.unwrap_storage_for_input(w_zp),
|
| 816 |
+
ir.ConstantBuffer,
|
| 817 |
+
)
|
| 818 |
+
and torch.equal(
|
| 819 |
+
torch.zeros_like(V.graph.constants[w_zp.get_name()]),
|
| 820 |
+
V.graph.constants[w_zp.get_name()],
|
| 821 |
+
) # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
|
| 822 |
+
and use_cpp_packed_gemm_template(layout, x, packed_weight)
|
| 823 |
+
):
|
| 824 |
+
W_tensor = V.graph.constants[packed_weight.get_name()]
|
| 825 |
+
W_tensor = W_tensor.to_dense()
|
| 826 |
+
weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
|
| 827 |
+
weight_compens = V.graph.add_tensor_constant(
|
| 828 |
+
weight_compens_tensor,
|
| 829 |
+
name=packed_weight.get_name() + "_BMatrixCompens",
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
def epilogue_creator(input_buffer):
|
| 833 |
+
# Epilogue to convert from s32 to f32 for u8s8f32
|
| 834 |
+
assert output_dtype in [
|
| 835 |
+
torch.float32,
|
| 836 |
+
torch.bfloat16,
|
| 837 |
+
torch.uint8,
|
| 838 |
+
]
|
| 839 |
+
|
| 840 |
+
input_loader = input_buffer.make_loader()
|
| 841 |
+
x2_loader = x2.make_loader()
|
| 842 |
+
weight_compens_loader = weight_compens.make_loader()
|
| 843 |
+
x_scale_loader = x_scale.make_loader()
|
| 844 |
+
w_scale_loader = w_scale.make_loader()
|
| 845 |
+
x_zp_loader = x_zp.make_loader()
|
| 846 |
+
nonlocal bias
|
| 847 |
+
bias_loader = None
|
| 848 |
+
if bias is not None:
|
| 849 |
+
bias_loader = bias.make_loader()
|
| 850 |
+
|
| 851 |
+
def inner_fn(index):
|
| 852 |
+
nonlocal bias
|
| 853 |
+
input = input_loader(index)
|
| 854 |
+
_x2 = x2_loader(index)
|
| 855 |
+
_x_scale = x_scale_loader(())
|
| 856 |
+
_x_zp = x_zp_loader(())
|
| 857 |
+
|
| 858 |
+
# MicroKernel Output is with int32
|
| 859 |
+
# cvt to FP32 before doing compensation
|
| 860 |
+
input = ops.to_dtype(input, torch.float32)
|
| 861 |
+
weight_compens_index = (index[-1],)
|
| 862 |
+
_w_scale = w_scale_loader(weight_compens_index)
|
| 863 |
+
_weight_compens = weight_compens_loader(
|
| 864 |
+
weight_compens_index
|
| 865 |
+
)
|
| 866 |
+
# Step 1: Doing compensation to cvt fp32
|
| 867 |
+
temp = ops.mul(
|
| 868 |
+
ops.mul(
|
| 869 |
+
input,
|
| 870 |
+
_x_scale,
|
| 871 |
+
),
|
| 872 |
+
_w_scale,
|
| 873 |
+
)
|
| 874 |
+
temp = ops.sub(
|
| 875 |
+
temp,
|
| 876 |
+
ops.mul(
|
| 877 |
+
ops.mul(
|
| 878 |
+
ops.mul(
|
| 879 |
+
_x_scale,
|
| 880 |
+
_w_scale,
|
| 881 |
+
),
|
| 882 |
+
_x_zp,
|
| 883 |
+
),
|
| 884 |
+
_weight_compens,
|
| 885 |
+
),
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
# Step 2: add Bias if applicable
|
| 889 |
+
if bias is not None:
|
| 890 |
+
_bias = bias_loader(weight_compens_index)
|
| 891 |
+
nonlocal bias_dtype
|
| 892 |
+
assert bias_dtype in [torch.float32, torch.bfloat16]
|
| 893 |
+
if bias_dtype == torch.bfloat16:
|
| 894 |
+
_bias = ops.to_dtype(_bias, torch.float32)
|
| 895 |
+
temp = ops.add(temp, _bias)
|
| 896 |
+
|
| 897 |
+
# Step 3: Binary add
|
| 898 |
+
nonlocal x2_dtype
|
| 899 |
+
assert x2_dtype in [torch.float32, torch.bfloat16]
|
| 900 |
+
if x2_dtype == torch.bfloat16:
|
| 901 |
+
_x2 = ops.to_dtype(_x2, torch.float32)
|
| 902 |
+
temp = ops.add(temp, _x2)
|
| 903 |
+
|
| 904 |
+
return temp
|
| 905 |
+
|
| 906 |
+
output_buf = ir.Pointwise(
|
| 907 |
+
device=input_buffer.get_device(),
|
| 908 |
+
dtype=torch.float32, # Hardcode to FP32 for u8s8f32
|
| 909 |
+
inner_fn=inner_fn,
|
| 910 |
+
ranges=input_buffer.get_size(),
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
# Step 4: Unary post op if has
|
| 914 |
+
if unary_attr != "none":
|
| 915 |
+
output_buf = create_epilogue_with_attr(
|
| 916 |
+
output_buf,
|
| 917 |
+
unary_attr,
|
| 918 |
+
scalars=unary_scalars,
|
| 919 |
+
algorithm=unary_algorithmm,
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
# Step 5: Cast output to Target Dtype
|
| 923 |
+
if output_dtype == torch.bfloat16:
|
| 924 |
+
output_cast_loader = output_buf.make_loader()
|
| 925 |
+
|
| 926 |
+
def inner_fn_cast_output_to_bf16(index):
|
| 927 |
+
input = output_cast_loader(index)
|
| 928 |
+
return ops.to_dtype(input, output_dtype)
|
| 929 |
+
|
| 930 |
+
output_buf = ir.Pointwise(
|
| 931 |
+
device=output_buf.get_device(),
|
| 932 |
+
dtype=output_dtype,
|
| 933 |
+
inner_fn=inner_fn_cast_output_to_bf16,
|
| 934 |
+
ranges=output_buf.get_size(),
|
| 935 |
+
)
|
| 936 |
+
elif output_dtype == torch.uint8:
|
| 937 |
+
from .lowering import _create_constants
|
| 938 |
+
|
| 939 |
+
requant_input_loader = output_buf.make_loader()
|
| 940 |
+
|
| 941 |
+
def inner_fn_requant(index, scale, zero_point):
|
| 942 |
+
input = requant_input_loader(index)
|
| 943 |
+
inv_scale, zero_point = _create_constants(
|
| 944 |
+
1.0 / scale, zero_point, dtype=torch.float32
|
| 945 |
+
)
|
| 946 |
+
val = ops.round(input * inv_scale) + zero_point
|
| 947 |
+
qmin, qmax = _create_constants(
|
| 948 |
+
0, 255, dtype=torch.float32
|
| 949 |
+
)
|
| 950 |
+
clamped = ops.minimum(ops.maximum(val, qmin), qmax)
|
| 951 |
+
return ops.to_dtype(clamped, torch.uint8)
|
| 952 |
+
|
| 953 |
+
output_buf = ir.Pointwise(
|
| 954 |
+
device=output_buf.get_device(),
|
| 955 |
+
dtype=torch.uint8,
|
| 956 |
+
inner_fn=functools.partial(
|
| 957 |
+
inner_fn_requant,
|
| 958 |
+
scale=float(o_scale),
|
| 959 |
+
zero_point=int(o_zero_point),
|
| 960 |
+
),
|
| 961 |
+
ranges=output_buf.get_size(),
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
return output_buf
|
| 965 |
+
|
| 966 |
+
CppPackedGemmTemplate.add_choices(
|
| 967 |
+
choices,
|
| 968 |
+
layout,
|
| 969 |
+
[x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
|
| 970 |
+
if bias is None
|
| 971 |
+
else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
|
| 972 |
+
has_bias=bias is not None,
|
| 973 |
+
epilogue_creator=epilogue_creator,
|
| 974 |
+
# Reorder bias and x2
|
| 975 |
+
input_indices=[0, 3, 1, 2, 4, 5, 6]
|
| 976 |
+
if bias is None
|
| 977 |
+
else [7, 0, 3, 1, 2, 4, 5, 6],
|
| 978 |
+
)
|
| 979 |
+
|
| 980 |
+
if len(choices) == 0 or use_aten_gemm_kernels():
|
| 981 |
+
kwargs = dict(
|
| 982 |
+
output_scale=o_scale,
|
| 983 |
+
output_zero_point=o_zero_point,
|
| 984 |
+
output_dtype=output_dtype,
|
| 985 |
+
other_scale=x2_scale,
|
| 986 |
+
other_zp=x2_zp,
|
| 987 |
+
binary_post_op=binary_attr,
|
| 988 |
+
binary_alpha=alpha,
|
| 989 |
+
unary_post_op=unary_attr,
|
| 990 |
+
unary_post_op_args=unary_scalars,
|
| 991 |
+
unary_post_op_algorithm=unary_algorithmm,
|
| 992 |
+
)
|
| 993 |
+
if bias is None:
|
| 994 |
+
kwargs["bias"] = None
|
| 995 |
+
choices.append(
|
| 996 |
+
aten_mkldnn_qlinear_binary.bind(
|
| 997 |
+
(x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2)
|
| 998 |
+
if bias is None
|
| 999 |
+
else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias),
|
| 1000 |
+
layout,
|
| 1001 |
+
**kwargs,
|
| 1002 |
+
)
|
| 1003 |
+
)
|
| 1004 |
+
assert packed_weight.get_name() in V.graph.constants
|
| 1005 |
+
input_gen_fns = {
|
| 1006 |
+
3: lambda x: V.graph.constants[x.get_name()],
|
| 1007 |
+
4: lambda x: V.graph.constants[x.get_name()],
|
| 1008 |
+
5: lambda x: V.graph.constants[x.get_name()],
|
| 1009 |
+
}
|
| 1010 |
+
if bias is not None:
|
| 1011 |
+
input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()] # For bias
|
| 1012 |
+
result = autotune_select_algorithm(
|
| 1013 |
+
"qlinear_binary",
|
| 1014 |
+
choices,
|
| 1015 |
+
[x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
|
| 1016 |
+
if bias is None
|
| 1017 |
+
else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
|
| 1018 |
+
layout,
|
| 1019 |
+
input_gen_fns=input_gen_fns,
|
| 1020 |
+
)
|
| 1021 |
+
if len(x_size) > 2 and binary_attr == "add":
|
| 1022 |
+
result = view(result, (*x_size[:-1], result.get_size()[-1]))
|
| 1023 |
+
return result
|
| 1024 |
+
|
| 1025 |
+
if torch._C.has_mkl:
|
| 1026 |
+
aten_mkl_linear = ExternKernelChoice(
|
| 1027 |
+
torch.ops.mkl._mkl_linear,
|
| 1028 |
+
"mkl::_mkl_linear",
|
| 1029 |
+
has_out_variant=False,
|
| 1030 |
+
kernel_creator=mkldnn_ir.MKLPackedLinear.create,
|
| 1031 |
+
)
|
| 1032 |
+
cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear)
|
| 1033 |
+
|
| 1034 |
+
@register_lowering(torch.ops.mkl._mkl_linear)
def mkl_packed_linear(
    x: TensorBox,
    packed_w: TensorBox,
    orig_w: TensorBox,
    b: Optional[TensorBox],
    batch_size,
    *,
    layout=None,
):
    """Lowering for ``torch.ops.mkl._mkl_linear``.

    Autotunes between the ATen MKL packed-linear extern kernel and (under
    max-autotune) the C++ packed GEMM template, then applies the bias ``b``
    as a separate add if one was provided.
    """
    choices: List[ChoiceCaller] = []
    if use_max_autotune():
        # The GEMM template consumes the plain weight (transposed), not the
        # MKL-packed representation.
        transposed_w = permute(orig_w, [1, 0])
        *_, layout, x, transposed_w = mm_args(
            x, transposed_w, layout=layout
        )
        if use_cpp_packed_gemm_template(layout, x, transposed_w):
            CppPackedGemmTemplate.add_choices(
                choices,
                layout,
                [x, packed_w, orig_w],
                trans_w=True,
                # The template only reads inputs 0 (x) and 2 (orig_w).
                input_indices=[0, 2],
            )

    # Fall back to (or additionally consider) the ATen kernel.
    if len(choices) == 0 or use_aten_gemm_kernels():
        choices.append(
            aten_mkl_linear.bind(
                (x, packed_w, orig_w), layout, B=None, batch_size=batch_size
            )
        )

    assert packed_w.get_name() in V.graph.constants
    assert orig_w.get_name() in V.graph.constants
    # packed_w is a mkldnn tensor which we can't generate directly
    # so we use the weights from the original tensor in autotune.
    input_gen_fns = {
        1: lambda x: V.graph.constants[x.get_name()],
        2: lambda x: V.graph.constants[x.get_name()],
    }
    result: TensorBox = autotune_select_algorithm(
        "packed_linear",
        choices,
        [x, packed_w, orig_w],
        layout,
        input_gen_fns=input_gen_fns,
    )
    if b is not None:
        # Bias is added outside of the autotuned kernel choices.
        result = add(result, b)
    return result
|
| 1084 |
+
|
| 1085 |
+
add_needs_realized_inputs(cpu_needs_realized_inputs)
|
| 1086 |
+
else:
|
| 1087 |
+
pass
|
.venv/lib/python3.11/site-packages/torch/_inductor/package/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .package import load_package, package_aoti
|
.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-311.pyc
ADDED
|
Binary file (532 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/package.cpython-311.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_inductor/package/build_package.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build_package_contents = """
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from torch._inductor.package.package import compile_so
|
| 6 |
+
|
| 7 |
+
curr_dir = Path(__file__).parent
|
| 8 |
+
aoti_files = [
|
| 9 |
+
os.path.join(root, file)
|
| 10 |
+
for root, dirs, files in os.walk(curr_dir)
|
| 11 |
+
for file in files
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
output_so = compile_so(curr_dir, aoti_files, curr_dir)
|
| 15 |
+
"""
|
.venv/lib/python3.11/site-packages/torch/_inductor/package/package.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import shlex
|
| 5 |
+
import subprocess
|
| 6 |
+
import tempfile
|
| 7 |
+
import zipfile
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Callable, List, Optional, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch._inductor
|
| 13 |
+
import torch.utils._pytree as pytree
|
| 14 |
+
from torch._inductor import config, exc
|
| 15 |
+
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
|
| 16 |
+
from torch.export._tree_utils import reorder_kwargs
|
| 17 |
+
|
| 18 |
+
from .build_package import build_package_contents
|
| 19 |
+
from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class PT2ArchiveWriter:
    """Context manager that creates a ".pt2" package: a stored
    (uncompressed) zip archive whose first two entries record the archive
    version and the "pt2" format marker.
    """

    def __init__(self, archive_path: str) -> None:
        # Destination path of the archive on disk.
        self.archive_path: str = archive_path
        # Open handle; non-None only while inside the context manager.
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> "PT2ArchiveWriter":
        # Re-entering an already-open writer is a programming error.
        assert self.archive_file is None
        self.archive_file = zipfile.ZipFile(
            self.archive_path, "w", compression=zipfile.ZIP_STORED
        )
        # Every archive begins with its version and format marker.
        self.writestr("version", str(ARCHIVE_VERSION))
        self.writestr("archive_format", "pt2")
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        assert self.archive_file is not None
        self.archive_file.close()
        self.archive_file = None
        return None

    def writestr(self, name: str, data: Union[bytes, str]) -> None:
        """Write *data* into the archive under entry *name*."""
        assert self.archive_file is not None
        self.archive_file.writestr(name, data)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert Path(file_path).is_file(), f"{file_path} is not a valid file path"
        assert self.archive_file is not None
        self.archive_file.write(file_path, arcname=name)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class PT2ArchiveReader:
    """Read-side counterpart of PT2ArchiveWriter: a context-manager wrapper
    around a stored (uncompressed) ".pt2" zip archive.
    """

    def __init__(self, archive_path: str) -> None:
        # Path of the archive; the zipfile handle is opened lazily on enter.
        self.archive_path: str = archive_path
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> "PT2ArchiveReader":
        self.archive_file = zipfile.ZipFile(
            self.archive_path, "r", compression=zipfile.ZIP_STORED
        )
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        if self.archive_file is not None:
            self.archive_file.close()
        return None

    def _zip(self) -> zipfile.ZipFile:
        # All accessors require the archive to have been entered first.
        assert self.archive_file is not None
        return self.archive_file

    def read(self, name: str) -> bytes:
        """Return the raw bytes of the archive member *name*."""
        return self._zip().read(name)

    def extract_to_path(self, member: str, path: str) -> str:
        """Extract *member* under *path* and return the extracted file path."""
        return self._zip().extract(member, path)

    def extractall(self, path: str) -> None:
        """Extract every member of the archive under *path*."""
        self._zip().extractall(path)

    def get_file_names(self) -> List[str]:
        """Return all member names, in archive order."""
        return self._zip().namelist()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _run_command_and_check(cmd: str) -> None:
|
| 91 |
+
cmd = shlex.split(cmd)
|
| 92 |
+
try:
|
| 93 |
+
subprocess.run(cmd, check=True)
|
| 94 |
+
except subprocess.CalledProcessError as e:
|
| 95 |
+
raise exc.CppCompileError(cmd, e.output) from e
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:
    """Compile the AOTInductor artifacts in *aoti_dir* into a shared library.

    Picks the first ".cpp" (wrapper) and first ".o" (constants) entries from
    *aoti_files*, reads the compile/link flags from the JSON files generated
    next to the wrapper, builds the object file, links the .so, and — if a
    serialized-weights blob is listed — appends it page-aligned at the end
    of the .so.  Returns the path of the produced shared library.
    """

    def get_aoti_file_with_suffix(suffix: str) -> str:
        # First match wins; missing artifacts are a hard error.
        for file in aoti_files:
            if file.endswith(suffix):
                return file
        raise RuntimeError(f"Unable to find file with suffix {suffix}")

    # Compile all the files into a .so
    cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp"))
    consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o"))

    # Base name (without extension) shared by the wrapper and its flag files.
    file_name = os.path.splitext(cpp_file)[0]

    # Parse compile flags and build the .o file
    with open(file_name + "_compile_flags.json") as f:
        compile_flags = json.load(f)

    compile_options = BuildOptionsBase(**compile_flags)
    object_builder = CppBuilder(
        name=file_name,
        sources=cpp_file,
        BuildOption=compile_options,
    )
    compile_cmd = object_builder.get_command_line()
    output_o = object_builder.get_target_file_path()

    _run_command_and_check(compile_cmd)

    # Parse linker flags and build the .so file
    with open(file_name + "_linker_flags.json") as f:
        linker_flags = json.load(f)

    linker_options = BuildOptionsBase(**linker_flags)
    so_builder = CppBuilder(
        name=os.path.split(so_path)[-1],
        sources=[output_o, consts_o],
        BuildOption=linker_options,
        output_dir=so_path,
    )
    link_cmd = so_builder.get_command_line()
    output_so = so_builder.get_target_file_path()

    _run_command_and_check(link_cmd)

    # mmapped weights
    serialized_weights_filename = file_name + "_serialized_weights.bin"
    if serialized_weights_filename in aoti_files:
        with open(serialized_weights_filename, "rb") as f_weights:
            serialized_weights = f_weights.read()

        with open(output_so, "a+b") as f_so:
            so_size = f_so.tell()
            # Page align the weights
            f_so.write(b" " * (16384 - so_size % 16384))
            f_so.write(serialized_weights)

    return output_so
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def package_aoti(aoti_output_dir: str) -> str:
    """
    Saves the AOTInductor generated files to the PT2Archive format.

    Returns the package location: the ".pt2" archive path when
    config.aot_inductor.output_path ends in ".pt2", otherwise the output
    directory itself.  Raises RuntimeError for a ".so" output path, which
    this packager does not support.
    """

    # Add a makefile and python script so the package can later be built
    # into a .so by running `make` (or the script) inside the directory.
    build_package_filename = "build_package.py"
    with open(os.path.join(aoti_output_dir, build_package_filename), "w") as f:
        f.write(build_package_contents)

    with open(os.path.join(aoti_output_dir, "Makefile"), "w") as f:
        f.write(f"all:\n\tpython3 {build_package_filename}\n")

    if config.aot_inductor.output_path.endswith(".so"):
        raise RuntimeError(
            "Unable to save package as a .so. It should be a .pt2 format or a directory."
        )
    elif config.aot_inductor.output_path.endswith(".pt2"):
        # Save using the PT2 packaging format
        # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
        archive_path = config.aot_inductor.output_path

        with PT2ArchiveWriter(archive_path) as archive_writer:
            package_files = glob.glob(f"{aoti_output_dir}/*")

            for path in package_files:
                filename = os.path.basename(path)
                # Fix: archive each file under its own basename.  The entry
                # name previously interpolated a fixed string instead of
                # `filename`, so every file clobbered the same archive entry.
                archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path)

        return archive_path

    else:
        # Directly put the files into the directory, without any archiving
        return aoti_output_dir
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def load_package(path: str, device: str) -> Callable:  # type: ignore[type-arg]
    """Load an AOTInductor package, compile it to a shared library, and
    return a callable that runs the model.

    Args:
        path: A ".pt2" archive or a directory of AOTInductor output files
            (a ".so" path is rejected).
        device: "cpu", "cuda", or "cuda:<index>".

    Raises:
        RuntimeError: for a ".so" path or an unsupported device string.
    """
    if path.endswith(".so"):
        raise RuntimeError(
            "Unable to load .so. It should be a .pt2 format or a directory."
        )

    elif path.endswith(".pt2"):
        # Build next to the archive, named after it without the extension.
        so_path = os.path.splitext(path)[0]
        with PT2ArchiveReader(path) as archive_reader:
            with tempfile.TemporaryDirectory() as tmp_dir:
                archive_reader.extractall(tmp_dir)
                # (A redundant pre-extract get_file_names() call was removed:
                # its result was immediately overwritten here.)
                file_names = archive_reader.get_file_names()
                # Only the AOTInductor artifacts participate in the build.
                aoti_files = [
                    file for file in file_names if file.startswith(AOTINDUCTOR_DIR)
                ]

                so_path = compile_so(tmp_dir, aoti_files, so_path)

    else:
        assert os.path.isdir(path), "Must specify a directory or a .pt2 file"
        aoti_files = [
            os.path.join(root, file)
            for root, dirs, files in os.walk(path)
            for file in files
        ]
        so_path = compile_so(path, aoti_files, path)

    if device == "cpu":
        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
    elif device == "cuda" or device.startswith("cuda:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
    else:
        raise RuntimeError("Unsupported device " + device)

    def optimized(*args, **kwargs):  # type: ignore[no-untyped-def]
        # Flatten the caller's (args, kwargs) against the model's input
        # spec, run the compiled container, then rebuild the output pytree.
        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return optimized
|
.venv/lib/python3.11/site-packages/torch/_inductor/package/pt2_archive_constants.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Canonical entry paths and filename prefixes used inside a ".pt2" zip
# package.  Writers and readers of the archive (see package.py) must agree
# on these strings exactly.
ARCHIVE_ROOT_NAME = "package"
ARCHIVE_FORMAT_PATH = "archive_format"
MODELS_DIR = "models/"
MODELS_FILENAME_FORMAT = "models/{}.json"  # formatted with a model name
# AOTInductor build artifacts live here; package.py filters on this prefix.
AOTINDUCTOR_DIR = "data/aotinductor/"
WEIGHTS_DIR = "data/weights/"
WEIGHT_FILENAME_PREFIX = "weight_"
CONSTANTS_DIR = "data/constants/"
TENSOR_CONSTANT_FILENAME_PREFIX = "tensor_"
CUSTOM_OBJ_FILENAME_PREFIX = "custom_obj_"
SAMPLE_INPUTS_DIR = "data/sample_inputs/"
SAMPLE_INPUTS_FILENAME_FORMAT = "data/sample_inputs/{}.pt"  # formatted with a model name
EXTRA_DIR = "extra/"
MODULE_INFO_PATH = "extra/module_info.json"

# Written into each archive's "version" entry by PT2ArchiveWriter.
ARCHIVE_VERSION = 0
|
.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py
ADDED
|
@@ -0,0 +1,2005 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
"""
|
| 3 |
+
# Inductor Pattern Matcher
|
| 4 |
+
|
| 5 |
+
The pattern matcher enables search/replace within an FX graph.
|
| 6 |
+
|
| 7 |
+
The main entrypoint to the pattern matcher is register_replacement(). Given a
|
| 8 |
+
search function and a replacement function this will register a replacement with
|
| 9 |
+
a pass (such as torch._inductor.fx_passes.joint_graph.patterns).
|
| 10 |
+
|
| 11 |
+
Internally the pattern matcher represents patterns as a graph (a DAG). Creating
|
| 12 |
+
new patterns manually as a graph is cumbersome and error-prone so the standard
|
| 13 |
+
way to create patterns (using register_replacement()) is to provide a search
|
| 14 |
+
function and a replacement function which is traced and converted into a graph.
|
| 15 |
+
|
| 16 |
+
Because the search functions are built somewhat generic (they tend to ignore
|
| 17 |
+
tensor sizes, for example) register_replacement() allows you to specify an
|
| 18 |
+
`extra_check` function which performs additional checks to verify that the
|
| 19 |
+
matched pattern fully matches before returning it.
|
| 20 |
+
|
| 21 |
+
## Precompiled Patterns
|
| 22 |
+
|
| 23 |
+
New patterns are added using register_replacement(). Patterns added in this way
|
| 24 |
+
can have a compile-time overhead because they need to be traced before
|
| 25 |
+
use. Patterns can be precompiled and added using gen_register_replacement()
|
| 26 |
+
instead. To do this you call gen_register_replacement() instead of
|
| 27 |
+
register_replacement(). The arguments are the same except for an additional
|
| 28 |
+
unique name which is used as a lookup key.
|
| 29 |
+
|
| 30 |
+
## Internals
|
| 31 |
+
|
| 32 |
+
The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr
|
| 33 |
+
implements a `_match` method which returns either a `Match` object for a
|
| 34 |
+
successful match or a `FailedMatch` object for a failure to match.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from __future__ import annotations
|
| 38 |
+
|
| 39 |
+
import contextlib
|
| 40 |
+
import dataclasses
|
| 41 |
+
import functools
|
| 42 |
+
import importlib
|
| 43 |
+
import inspect
|
| 44 |
+
import itertools
|
| 45 |
+
import logging
|
| 46 |
+
import operator
|
| 47 |
+
import os
|
| 48 |
+
import re
|
| 49 |
+
import textwrap
|
| 50 |
+
import typing
|
| 51 |
+
from abc import ABC, abstractmethod
|
| 52 |
+
from collections import defaultdict
|
| 53 |
+
from pathlib import Path
|
| 54 |
+
from typing import (
|
| 55 |
+
Any,
|
| 56 |
+
Callable,
|
| 57 |
+
DefaultDict,
|
| 58 |
+
Dict,
|
| 59 |
+
Generator,
|
| 60 |
+
Iterable,
|
| 61 |
+
List,
|
| 62 |
+
Mapping,
|
| 63 |
+
NoReturn,
|
| 64 |
+
Optional,
|
| 65 |
+
Protocol,
|
| 66 |
+
Sequence,
|
| 67 |
+
Set,
|
| 68 |
+
Tuple,
|
| 69 |
+
Type,
|
| 70 |
+
TypeVar,
|
| 71 |
+
Union,
|
| 72 |
+
)
|
| 73 |
+
from typing_extensions import Self, TypeGuard
|
| 74 |
+
|
| 75 |
+
import torch
|
| 76 |
+
import torch._guards
|
| 77 |
+
import torch.fx
|
| 78 |
+
import torch.utils._pytree as pytree
|
| 79 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 80 |
+
from torch._dynamo.utils import counters
|
| 81 |
+
from torch._inductor.config import trace as trace_config
|
| 82 |
+
from torch._prims_common import is_integer_dtype
|
| 83 |
+
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
| 84 |
+
from torch.fx.experimental.proxy_tensor import make_fx
|
| 85 |
+
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
| 86 |
+
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
| 87 |
+
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
|
| 88 |
+
|
| 89 |
+
from .._functorch import config as functorch_config
|
| 90 |
+
from .._functorch.aot_autograd import aot_function, make_boxed_func
|
| 91 |
+
from .._functorch.partitioners import default_partition
|
| 92 |
+
from .._subclasses import FakeTensor, FakeTensorMode
|
| 93 |
+
from ..fx import Transformer
|
| 94 |
+
from . import config
|
| 95 |
+
from .decomposition import select_decomp_table
|
| 96 |
+
from .lowering import fallback_node_due_to_unsupported_type
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
log = logging.getLogger(__name__)
|
| 100 |
+
aten = torch.ops.aten
|
| 101 |
+
prims = torch.ops.prims
|
| 102 |
+
|
| 103 |
+
Constant = Any
|
| 104 |
+
NodeOrConstant = Union[Constant, torch.fx.Node]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class SearchFn(Protocol):
    """Callable protocol for the search function given to register_replacement().

    Must expose ``__name__`` (used elsewhere for pattern naming/lookup).
    """

    __name__: str

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        ...
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class ReplaceFn(Protocol):
    """Callable protocol for the replacement function given to register_replacement()."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        ...
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TraceFn(Protocol):
    """Callable protocol for tracing a search/replace function into an FX GraphModule."""

    def __call__(
        self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
    ) -> torch.fx.GraphModule:
        ...
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
T = TypeVar("T")
|
| 127 |
+
|
| 128 |
+
# What's a better name for this?
|
| 129 |
+
FnsType = Union[torch.fx.node.Target, str]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class Multiple:
    """Singleton sentinel type; the one instance is the module-level MULTIPLE."""

    def __init__(self) -> None:
        # Ensure we're really a singleton: after MULTIPLE is created, any
        # further construction attempt must be that same object.
        assert "MULTIPLE" not in globals() or self is MULTIPLE
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Sentinel indicating multiple quantities can be matched
|
| 139 |
+
MULTIPLE = Multiple()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class Match:
    """
    Represents a successfully matched pattern.

    The `Match` object is returned to represent a successfully matched
    pattern. Included in the Match are the pattern that was matched, the graph
    nodes matched, and any args that were used during the matching.

    The args and kwargs are specific to the type of pattern that was matched and
    provide hints about what was matched.
    """

    # Pattern that produced this match.
    pattern: PatternExpr
    # Positionally captured values (see Arg / bundle()).
    args: List[Any]
    # Keyword-captured values (see KeywordArg / ExclusiveKeywordArg).
    kwargs: Dict[str, Any]
    # Graph nodes consumed by this match.
    nodes: List[torch.fx.Node]
    # Maps each matched _TargetExpr to the concrete node.target it matched.
    targets: Dict[_TargetExpr, torch.fx.node.Target]
    ctx: MatchContext
    replacement_graph: Optional[torch.fx.Graph]

    def __init__(
        self,
        ctx: MatchContext,
        pattern: PatternExpr,
        args: Optional[Sequence[Any]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.pattern = pattern
        # The input nodes that must be passed in to the result
        self.args = list(args or [])
        self.kwargs = kwargs or {}
        # The nodes matched in this expression
        self.nodes = []
        # Mapping CallFunction to the node.target
        self.targets = {}
        self.ctx = ctx
        self.replacement_graph = None

    @property
    def graph(self) -> torch.fx.Graph:
        # The FX graph the matched nodes live in (owned by the context).
        return self.ctx.graph

    def extend(self, other: Match) -> None:
        """Merge a child match into this one.

        Raises FailedMatch (caught by the matching machinery) if the same
        kwarg was captured with two different values.
        """
        if self.kwargs:
            for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
                if self.kwargs[key] != other.kwargs[key]:
                    raise FailedMatch("kwarg mismatch: {}", key)
        self.args.extend(other.args)
        self.nodes.extend(other.nodes)
        self.kwargs.update(other.kwargs)
        self.targets.update(other.targets)

    def bundle(self) -> Match:
        # Wrap args in an extra list
        self.args = [tuple(self.args)] if self.args else []
        return self

    def __repr__(self) -> str:
        return f"Match(..., {self.args}, {self.kwargs})"

    def erase_nodes(self) -> None:
        # Erase matched nodes in reverse order so consumers go before
        # producers; nodes that still have users (or were already erased)
        # are left alone.
        graph = self.graph
        for n in reversed(self.nodes):
            if not n._erased and not n.users:
                graph.erase_node(n)

    def output_nodes(self) -> List[Optional[torch.fx.Node]]:
        # One entry per pattern output, None where the output pattern was None
        # or did not bind a node.
        return [
            (self.ctx.pattern_to_node[p] if p is not None else None)
            for p in self.ctx.outputs
        ]

    def output_node(self) -> torch.fx.Node:
        # First truthy output node; raises StopIteration if there is none.
        return next(p for p in self.output_nodes() if p)

    def replace_with_graph(
        self, replacement_graph: torch.fx.Graph, args: Sequence[Any]
    ) -> None:
        """Splice `replacement_graph` into the matched region (delegates to
        ReplacementPatternEntry, defined elsewhere in this file)."""
        ReplacementPatternEntry.replace_with_graph(
            self, self.ctx.graph, replacement_graph, args
        )

    def replace_by_example(
        self,
        replacement_fn: ReplaceFn,
        args: Sequence[Any],
        trace_fn: Optional[TraceFn] = None,
        run_functional_passes: bool = True,
    ) -> None:
        """Replace with a graph generated by tracing the replacement_fn.

        Args:
            run_functional_passes (bool). If we should run passes that
                assume functional IR (like DCE, remove_noop_ops), on the
                replacement graph.

        """
        from torch._inductor.virtualized import NullHandler, V

        # NOTE(review): when V.fake_mode is None this expression selects None
        # as the context manager, which would fail in the `with` below; the
        # `V.fake_mode is None` clause looks like it belongs to the `else`
        # branch — confirm intent upstream before changing.
        context = (
            V.fake_mode
            if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None))
            else contextlib.nullcontext()
        )

        with context:
            if trace_fn is None:
                trace_fn = functools.partial(
                    fwd_only, run_functional_passes=run_functional_passes
                )
            # Trace the replacement using the fake ("val") metadata of the
            # matched input nodes as example inputs.
            replacement = trace_fn(
                replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])  # type: ignore[arg-type]
            )
            ReplacementPatternEntry.replace_with_graph(
                self,
                self.ctx.graph,
                replacement,
                args,
            )
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class FailedMatch(RuntimeError):
    """
    Represents an unsuccessful match.

    The `FailedMatch` object is returned to represent a failure to match a
    pattern. It is falsy, so `if match:` distinguishes it from a `Match`.
    """

    format_string: str

    def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None:
        # Message construction is deferred to __str__: building failure
        # strings eagerly on every failed match measurably hurts compile time.
        if len(format_string) > 200:
            raise RuntimeError(
                f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}"
            )
        self.format_string = format_string
        self.args = args
        self.kwargs = kwargs

    def __str__(self) -> str:
        # Render lazily, only when someone actually looks at the message.
        return self.format_string.format(*self.args, **self.kwargs)

    def __bool__(self) -> bool:
        # Always falsy so callers can treat MatchResult as a truthy Match
        # or a falsy FailedMatch.
        return False
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
MatchResult = Union[Match, FailedMatch]
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def is_match(m: MatchResult) -> TypeGuard[Match]:
    """
    TypeGuards cannot act on `self`. Thus this function exists to let mypy
    recognize FailedMatch.__bool__ as a TypeGuard.
    """
    # Match is truthy; FailedMatch.__bool__ returns False.
    return bool(m)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class MatchContext:
    """
    Internal state needed while running PatternExpr._match().
    """

    # The output patterns of the (possibly multi-output) pattern being matched.
    outputs: List[Optional[PatternExpr]]
    # Memo of pattern -> node bindings; None records a failed attempt.
    pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]]
    graph: torch.fx.Graph
    # Nodes claimed by ExclusiveKeywordArg; each may be captured only once.
    exclusive_node_set: List[NodeOrConstant]

    def __init__(
        self,
        outputs: List[Optional[PatternExpr]],
        pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None,
        *,
        graph: torch.fx.Graph,
    ) -> None:
        self.outputs = outputs
        # Copy so child contexts can diverge without mutating the caller's map.
        self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node)
        self.graph = graph
        self.exclusive_node_set = []

    def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult:
        """wrapper to check reused nodes in patterns"""
        if pattern in self.pattern_to_node:
            if self.pattern_to_node[pattern] == node:
                return Match(self, pattern)  # already checked this node
            else:
                # Same sub-pattern bound to a different node: inconsistent.
                return FailedMatch("repeated pattern differs")
        m = pattern._match(node, self)
        assert pattern not in self.pattern_to_node
        # Record the outcome (None on failure) so a reused sub-pattern is
        # checked for consistency instead of being re-matched.
        self.pattern_to_node[pattern] = node if m else None
        return m

    def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]:
        # Keep only successful bindings for patterns that allow multiple
        # users; these are propagated to sibling contexts (e.g. by ListOf).
        return {
            pattern: node
            for pattern, node in self.pattern_to_node.items()
            if pattern.has_multiple_users() and node is not None
        }
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class PatternExpr(ABC):
    """
    Base class for types of patterns.

    Subclasses implement `_match`, returning a truthy Match on success or a
    falsy FailedMatch on failure (see MatchResult / is_match()).
    """

    @abstractmethod
    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        ...

    def match(self, node: torch.fx.Node) -> MatchResult:
        """Match this pattern against `node` in a fresh MatchContext."""
        try:
            return MatchContext([self], graph=node.graph).match(self, node)
        except FailedMatch as e:
            # FailedMatch may be raised (e.g. Match.extend) as well as
            # returned; normalize to a return value.
            return e

    def has_multiple_users(self) -> bool:
        # Overridden by _TargetExpr; by default a pattern matches one user.
        return False

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def find_anchor_nodes(
        self, ctx: MatchContext, searched: Set[torch.fx.Node]
    ) -> Generator[Optional[torch.fx.Node], None, None]:
        # Default: only a node already bound to this pattern is an anchor.
        if self in ctx.pattern_to_node:
            yield ctx.pattern_to_node[self]

    def pattern_eq(self, other: Any) -> bool:
        """
        Compare two `PatternExpr`s and return true if they are the
        same. Note this is NOT matching a pattern - it is comparing the pattern
        structures (for debugging).
        """
        return isinstance(other, self.__class__)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class Arg(PatternExpr):
    """
    Capture an arg which will become an input to the handler. Args are
    passed in depth first order.
    """

    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
        # Any node or constant matches; it is captured positionally.
        return Match(ctx, self, args=[node])
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class Ignored(PatternExpr):
    """
    Match an arg, but don't pass it to handler
    """

    def __repr__(self) -> str:
        return "*"

    def pretty_print(self, pp: PatternPrettyPrinter) -> str:
        return "Ignored()"

    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
        # Anything matches; nothing is captured.
        return Match(ctx, self)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
class KeywordArg(PatternExpr):
    """
    Capture a kwarg which will become an input to the handler.
    """

    def __init__(self, name: str) -> None:
        super().__init__()
        # Keyword under which the matched node is exposed to the handler.
        self.name = name

    def __repr__(self) -> str:
        return f"KeywordArg({self.name!r})"

    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
        # Any node or constant matches; it is captured under self.name.
        return Match(ctx, self, kwargs={self.name: node})

    def pattern_eq(self, other: Any) -> bool:
        if not super().pattern_eq(other):
            return False
        other = typing.cast(Self, other)  # super makes sure this is true
        return self.name == other.name
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class ExclusiveKeywordArg(PatternExpr):
    """
    Capture a kwarg which will become an input to the handler.

    Unlike KeywordArg, a given node may be claimed by at most one
    ExclusiveKeywordArg within a single match.
    """

    name: str

    def __init__(self, name: str) -> None:
        super().__init__()
        self.name = name

    def __repr__(self) -> str:
        return f"ExclusiveKeywordArg({self.name!r})"

    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
        if node in ctx.exclusive_node_set:
            return FailedMatch("exclusive arg appears twice")
        # Claim the node so no sibling exclusive arg can capture it again.
        ctx.exclusive_node_set.append(node)
        return Match(ctx, self, kwargs={self.name: node})

    def pattern_eq(self, other: Any) -> bool:
        if not super().pattern_eq(other):
            return False
        other = typing.cast(Self, other)  # super makes sure this is true
        return self.name == other.name
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
class _TargetExpr(PatternExpr):
    """
    Base class for filtering match by node.target

    `fns` is the set of acceptable targets; `users` constrains how many users
    the matched node may have (or MULTIPLE for "any number").
    """

    fns: List[FnsType]
    fns_set: Set[FnsType]
    users: Union[Multiple, int]

    def __init__(
        self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1
    ) -> None:
        super().__init__()
        fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
        # Deliberately extends `fns` while iterating: appended overloads are
        # OpOverloads (not packets), so they are visited but not re-expanded.
        for fn in fns:
            if isinstance(fn, torch._ops.OpOverloadPacket):
                fns.extend(getattr(fn, overload) for overload in fn.overloads())

        self.fns = fns
        self.fns_set = set(fns)
        self.users = users

    @property
    @abstractmethod
    def op(self) -> str:
        ...

    def fns_repr(self) -> str:
        """Short human-readable representation of the target function set."""
        first_repr = self.fns[0]
        if not isinstance(first_repr, str):
            first_repr = first_repr.__name__

        if len(self.fns) > 1:
            return f"[{first_repr}, ...]"
        elif self.fns[0] is getattr(torch, first_repr, None):
            return f"torch.{first_repr}"
        elif isinstance(self.fns[0], torch._ops.OpOverload):
            return str(self.fns[0])
        else:
            return first_repr

    def __repr__(self) -> str:
        if self.users is MULTIPLE:
            comma_users = ", MULTIPLE"
        elif self.users != 1:
            # BUG FIX: previously f", {self.users})" — the stray ")" produced
            # a doubled close paren, e.g. "CallFunction(foo, 2))".
            comma_users = f", {self.users}"
        else:
            comma_users = ""
        return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})"

    def has_multiple_users(self) -> bool:
        return isinstance(self.users, Multiple) or self.users > 1

    def find_anchor_nodes(
        self, ctx: MatchContext, searched: Set[torch.fx.Node]
    ) -> Generator[Optional[torch.fx.Node], None, None]:
        # Implemented by _TargetArgsExpr; plain target filters cannot anchor.
        raise NotImplementedError

    def _match_fns(self, node: torch.fx.Node) -> bool:
        # Node must be a real FX node with the right opcode and an allowed target.
        return (
            isinstance(node, torch.fx.Node)
            and node.op == self.op
            and extract_target(node) in self.fns_set
        )

    def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool:
        # Output patterns are exempt from the user-count constraint.
        return (
            self in ctx.outputs
            or self.users is MULTIPLE
            or len(node.users) == self.users
        )

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        return (
            super().pattern_eq(other)
            and self.op == other.op
            and self.fns == other.fns
            and self.users == other.users
        )
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
_SimpleSpec = Tuple[Any, ...]
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class _TargetArgsExpr(_TargetExpr):
    """
    Base class for filtering match by node.{target,args,kwargs}

    Pattern args/kwargs are flattened once at construction; candidate nodes'
    args/kwargs are flattened the same way at match time and compared
    element-wise (sub-patterns recurse, constants compare by equality).
    """

    def __init__(
        self,
        fns: Union[torch.fx.node.Target, str, Sequence[Any]],
        *args: Any,
        _users: Union[int, Multiple] = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(fns, _users)
        self.args = tuple(args)
        self.kwargs = dict(kwargs)
        # Container arguments need a full pytree flatten; otherwise the much
        # cheaper simple flatten suffices.
        if any(
            isinstance(x, (dict, list, tuple))
            for x in itertools.chain(args, kwargs.values())
        ):
            self.flatten = self.pytree_flatten
        else:
            self.flatten = self.simple_flatten
        self.flat_args_kwargs = self.flatten(self.args, self.kwargs)

    @staticmethod
    def simple_flatten(
        args: Sequence[Any], kwargs: Mapping[Any, Any]
    ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
        """Flatten non-nested args/kwargs; spec is (len(args), *kwarg keys)."""
        values = (*args, *kwargs.values())
        spec = (len(args), *kwargs.keys())
        return values, spec

    @staticmethod
    def pytree_flatten(
        args: Sequence[Any], kwargs: Mapping[Any, Any]
    ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
        """Flatten nested args/kwargs, normalizing container types so that
        e.g. tuple vs immutable_list differences don't break spec equality."""

        def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
            if s.type is None:
                return s
            mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
            return pytree.TreeSpec(
                mapping.get(s.type, s.type),
                s.context,
                list(map(norm_spec, s.children_specs)),
            )

        flat, spec = pytree.tree_flatten([args, kwargs])
        spec = norm_spec(spec)
        return flat, spec

    def __repr__(self) -> str:
        args = [
            self.fns_repr(),
            *map(repr, self.args),
            *[f"{k}={v}" for k, v in self.kwargs.items()],
        ]
        if self.users is MULTIPLE:
            args.append("_users=MULTIPLE")
        elif self.users != 1:
            args.append(f"_users={self.users}")
        return f"{self.__class__.__name__}({', '.join(args)})"

    def pretty_print(self, pp: PatternPrettyPrinter) -> str:
        args = [
            self.fns_repr(),
            *(pp.pretty_print(x) for x in self.args),
            *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
        ]
        if self.users is MULTIPLE:
            args.append("_users=MULTIPLE")
        elif self.users != 1:
            args.append(f"_users={self.users}")

        joiner_str = ", "
        return f"{self.__class__.__name__}({joiner_str.join(args)})"

    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        if not self._match_fns(node) or len(node.args) != len(self.args):
            return FailedMatch("function_mismatch: node={}, pattern={}", node, self)

        if not self._match_users(node, ctx):
            return FailedMatch("multiple_users {}", self)

        _args = node.args
        _kwargs = node.kwargs
        if len(_kwargs) < len(self.kwargs):
            # The node may be passing our kwargs positionally; normalize its
            # call signature before comparing.
            from torch.fx.operator_schemas import normalize_function

            normalized_args_and_kwargs = normalize_function(
                node.target, node.args, node.kwargs  # type: ignore[arg-type]
            )

            if normalized_args_and_kwargs is None:
                return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
            else:
                _args, _kwargs = normalized_args_and_kwargs
                if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs):
                    # Drop kwargs the pattern doesn't constrain (defaults etc.).
                    _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
                else:
                    return FailedMatch(
                        "function_mismatch: node={}, pattern={}", node, self
                    )
        else:
            _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}

        node_items, node_spec = self.flatten(_args, _kwargs)
        self_items, self_spec = self.flat_args_kwargs
        if node_spec != self_spec:
            return FailedMatch("args_structure {} {}", node_spec, self_spec)
        assert len(node_items) == len(self_items)

        m = Match(ctx, self)
        for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
            if isinstance(pattern, PatternExpr):
                child_match = ctx.match(pattern, child_node)
                if not is_match(child_match):
                    return child_match
                m.extend(child_match)
            elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
                # BUG FIX: the format string previously used the named field
                # {pattern!r} with no matching kwarg, so rendering the
                # FailedMatch message raised KeyError; pass it positionally.
                return FailedMatch(
                    "constant_args: {} {!r}!={!r}", node, child_node, pattern
                )
        m.nodes.append(node)
        m.targets[self] = node.target
        return m

    def find_anchor_nodes(
        self, ctx: MatchContext, searched: Set[torch.fx.Node]
    ) -> Generator[Optional[torch.fx.Node], None, None]:
        """
        This is used when we are matching a pattern with multiple outputs.
        There is a partial match (stored in ctx) and we want to walk
        this pattern to find a connection to an already-matched node.

        Yields candidate nodes that `self._match` might like.
        """
        if self in ctx.pattern_to_node:
            yield ctx.pattern_to_node[self]
            return

        # Walk our sub-patterns' anchors and propose their users with a
        # compatible target, deduplicating via `searched`.
        for pattern in self.flat_args_kwargs[0]:
            if isinstance(pattern, PatternExpr):
                for other_node in pattern.find_anchor_nodes(ctx, searched):
                    if not isinstance(other_node, torch.fx.Node):
                        continue
                    for node in other_node.users:
                        if node not in searched:
                            if self._match_fns(node):
                                yield node
                                searched.add(node)

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        return (
            super().pattern_eq(other)
            and self.flat_args_kwargs[1] == other.flat_args_kwargs[1]
            and all(
                a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
                for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0])
            )
        )
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
class CallFunction(_TargetArgsExpr):
    """
    Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)`
    """

    # FX node opcode this pattern filters on (see _TargetExpr.op).
    op = "call_function"
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
class CallMethod(_TargetArgsExpr):
    """
    Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)`
    """

    # FX node opcode this pattern filters on (see _TargetExpr.op).
    op = "call_method"
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
class CallModule(_TargetArgsExpr):
    """
    Matches a call_module node in the FX graphs: `module(*args, **kwargs)`
    """

    # FX node opcode this pattern filters on (see _TargetExpr.op).
    op = "call_module"
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class _TargetExprVarArgs(_TargetExpr):
    """
    Matches a call_function node with any arguments which are passed into the pattern
    """

    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        # Only the target and user count are checked; the node's args/kwargs
        # are captured wholesale and forwarded to the handler.
        if not self._match_fns(node):
            return FailedMatch("function_mismatch")
        if not self._match_users(node, ctx):
            return FailedMatch("multiple_users")

        result = Match(ctx, self)
        result.nodes.append(node)
        result.targets[self] = node.target
        result.args.extend(node.args)
        result.kwargs.update(node.kwargs)
        return result
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class CallFunctionVarArgs(_TargetExprVarArgs):
    """Matches any call_function node with one of `fns` as target; captures all args."""

    op = "call_function"
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
class CallMethodVarArgs(_TargetExprVarArgs):
    """Matches any call_method node with one of `fns` as target; captures all args."""

    op = "call_method"
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
class CallModuleVarArgs(_TargetExprVarArgs):
    """Matches any call_module node with one of `fns` as target; captures all args."""

    op = "call_module"
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
class ListOf(PatternExpr):
    """
    Matches a repeated pattern

    Matches a list/tuple argument whose elements each match `pattern`. With
    partial=True, individual elements may fail as long as at least one matches.
    """

    def __init__(self, pattern: PatternExpr, partial: bool = False) -> None:
        super().__init__()
        assert isinstance(pattern, PatternExpr)
        self.pattern = pattern
        self.partial = partial

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.pattern})"

    def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult:  # type: ignore[override]
        if not isinstance(node, (list, tuple)) or len(node) == 0:
            return FailedMatch("non_list")
        m = Match(ctx, self)
        # Propagating patterns with multiple users will ensure we don't revisit
        # the same nodes
        pattern_to_node = ctx.filter_multi_user_patterns()
        matched = False
        for i, child_node in enumerate(node):
            # Each element is matched in a fresh child context seeded with the
            # multi-user bindings accumulated so far.
            child_ctx = MatchContext(
                ctx.outputs, pattern_to_node, graph=child_node.graph
            )
            child_match = child_ctx.match(self.pattern, child_node)
            pattern_to_node = child_ctx.filter_multi_user_patterns()
            if not is_match(child_match):
                if not self.partial:
                    return FailedMatch("list[{}]: {}", i, child_match)
                continue
            matched = True
            m.extend(child_match.bundle())
        if not matched:
            return FailedMatch("list: no_match")
        return m.bundle()

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        return (
            super().pattern_eq(other)
            and self.pattern.pattern_eq(other.pattern)
            and self.partial == other.partial
        )
)
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
class MultiOutputPattern(PatternExpr):
|
| 803 |
+
outputs: List[Optional[PatternExpr]]
|
| 804 |
+
|
| 805 |
+
def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None:
|
| 806 |
+
super().__init__()
|
| 807 |
+
assert isinstance(outputs[0], _TargetExpr)
|
| 808 |
+
assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs
|
| 809 |
+
self.outputs = list(outputs)
|
| 810 |
+
self.op = outputs[0].op
|
| 811 |
+
|
| 812 |
+
@property
|
| 813 |
+
def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]:
|
| 814 |
+
# This cast is checked above in __init__()
|
| 815 |
+
output = typing.cast(_TargetExpr, self.outputs[0])
|
| 816 |
+
return output.fns
|
| 817 |
+
|
| 818 |
+
def __repr__(self) -> str:
|
| 819 |
+
return f"{self.__class__.__name__}({self.outputs})"
|
| 820 |
+
|
| 821 |
+
def pretty_print(self, pp: PatternPrettyPrinter) -> str:
|
| 822 |
+
args = [pp.pretty_print(x) for x in self.outputs]
|
| 823 |
+
joiner_str = f",\n{' '}"
|
| 824 |
+
str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}"
|
| 825 |
+
str_out = f"{str_out}\n])"
|
| 826 |
+
return str_out
|
| 827 |
+
|
| 828 |
+
def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
|
| 829 |
+
output = typing.cast(_TargetExpr, self.outputs[0])
|
| 830 |
+
m = ctx.match(output, node)
|
| 831 |
+
if not is_match(m):
|
| 832 |
+
return m
|
| 833 |
+
|
| 834 |
+
for pattern in self.outputs[1:]:
|
| 835 |
+
if pattern is None:
|
| 836 |
+
continue
|
| 837 |
+
child_match = self._match_from_anchors(pattern, ctx)
|
| 838 |
+
if not is_match(child_match):
|
| 839 |
+
return child_match
|
| 840 |
+
m.extend(child_match)
|
| 841 |
+
|
| 842 |
+
return m
|
| 843 |
+
|
| 844 |
+
def _match_from_anchors(
|
| 845 |
+
self, pattern: PatternExpr, ctx: MatchContext
|
| 846 |
+
) -> MatchResult:
|
| 847 |
+
prior = dict(ctx.pattern_to_node)
|
| 848 |
+
m: MatchResult = FailedMatch("no anchor found")
|
| 849 |
+
for node in pattern.find_anchor_nodes(ctx, set()):
|
| 850 |
+
m = ctx.match(pattern, node)
|
| 851 |
+
if is_match(m):
|
| 852 |
+
return m
|
| 853 |
+
# revert any partial matches
|
| 854 |
+
ctx.pattern_to_node = dict(prior)
|
| 855 |
+
return m
|
| 856 |
+
|
| 857 |
+
    def match(self, node: torch.fx.Node) -> MatchResult:
        """Public entry point: match this pattern rooted at ``node``.

        A FailedMatch raised during matching is returned as a value rather
        than propagated, so callers can uniformly test the result with
        ``is_match``.
        """
        try:
            return MatchContext(self.outputs, graph=node.graph).match(self, node)
        except FailedMatch as e:
            return e
|
| 862 |
+
|
| 863 |
+
    def pattern_eq(self, other: Any) -> bool:
        """Structural equality: same base pattern and pairwise-equal outputs.

        None entries (wildcards) are compared with ``==``; PatternExpr
        entries recurse via their own ``pattern_eq``.
        """
        other = typing.cast(Self, other)  # super makes sure this is true
        return (
            super().pattern_eq(other)
            and len(self.outputs) == len(other.outputs)
            and all(
                a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
                for a, b in zip(self.outputs, other.outputs)
            )
        )
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
class RepeatedExpr(PatternExpr):
    """
    Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind`
    """

    def __init__(self, inner_pattern: _TargetExpr) -> None:
        super().__init__()
        # The single pattern expected to repeat; op is mirrored so this
        # expression registers under the same (op, target) bucket.
        self.inner_pattern = inner_pattern
        self.op = inner_pattern.op

    @property
    def fns(self) -> Sequence[FnsType]:
        """Target functions of the repeated inner pattern."""
        return self.inner_pattern.fns

    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        """Match the inner pattern at ``node``, then require every other
        anchor occurrence to match it too, merging all sub-matches."""
        m = ctx.match(self.inner_pattern, node)
        if not is_match(m):
            return m
        # Drop the binding so each anchor occurrence is matched fresh
        # rather than short-circuiting to the first binding.
        ctx.pattern_to_node.pop(
            self.inner_pattern,
        )
        # Check all anchor nodes match the pattern
        for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()):
            anchor_m = MatchContext([self], graph=node.graph).match(
                self.inner_pattern, anchor_node
            )
            if not is_match(anchor_m):
                return anchor_m
            m.extend(anchor_m)
        return m

    def pattern_eq(self, other: Any) -> bool:
        """Structural equality: same base pattern and equal inner patterns."""
        other = typing.cast(Self, other)  # super makes sure this is true
        return super().pattern_eq(other) and self.inner_pattern.pattern_eq(
            other.inner_pattern
        )
|
| 911 |
+
|
| 912 |
+
|
| 913 |
+
class PatternPrettyPrinter:
    """
    Serializes Patterns to executable python.
    XXX: currently only used and tested for fuse attention patterns. May not cover
    all patterns.
    """

    def __init__(self) -> None:
        # Namespace used to generate unique variable names for memoized
        # sub-patterns in the emitted source.
        self.namespace = torch.fx.graph._Namespace()
        # Maps a sub-pattern object to its emitted variable name / source.
        self.memoized_objs_names: Dict[PatternExpr, str] = {}
        self.memoized_objs_pp: Dict[PatternExpr, str] = {}

    @staticmethod
    @functools.lru_cache(None)
    def run(obj: PatternExpr, output_name: str = "output") -> str:
        """
        Serializes obj to python code with obj written out to `output_name`
        """
        # NOTE: the unbounded lru_cache keys on the PatternExpr object and
        # keeps it alive for the process lifetime; acceptable here because
        # patterns are module-level singletons.
        pp = PatternPrettyPrinter()
        assert hasattr(obj, "pretty_print")
        out_str = obj.pretty_print(pp=pp)

        # Emit memoized sub-patterns first (dict preserves insertion order),
        # then the final assignment to `output_name`.
        output = []
        for key in pp.memoized_objs_names:
            output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")

        output.append(f"{output_name} = {out_str}")

        return "\n".join(output)

    def pretty_print(self, obj: Any) -> str:
        """Render ``obj``: memoize call-style patterns, delegate to objects
        that know how to pretty-print themselves, else fall back to repr."""
        if isinstance(obj, _TargetArgsExpr):
            if memoized_name := self.memoized_objs_names.get(obj):
                return memoized_name
            else:
                return self.memoize(obj)
        if hasattr(obj, "pretty_print"):
            return obj.pretty_print(self)

        return repr(obj)

    def memoize(self, obj: _TargetArgsExpr) -> str:
        """Assign ``obj`` a fresh variable name derived from its target
        function and record its source for later emission."""
        obj_str = obj.pretty_print(self)
        obj_name = obj.fns_repr()
        # Strip namespace prefixes so the variable name stays short.
        for prefix in ("aten.", "torch.", "prims."):
            obj_name = obj_name.replace(prefix, "")

        tmp_name = self.namespace.create_name(obj_name, None)
        self.memoized_objs_names[obj] = tmp_name
        self.memoized_objs_pp[obj] = obj_str
        return tmp_name
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
class _PassDictsType(Protocol):
    """Structural type for pass containers: anything indexable by an
    (op, target) pair yielding the list of registered PatternEntry objects.
    Satisfied by both plain dicts and PatternMatcherPass."""

    def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
        ...
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
@dataclasses.dataclass
class PatternEntry:
    """A pattern plus an extra predicate, registrable into pass dicts.

    Subclasses implement ``apply`` to perform the actual rewrite once the
    pattern has matched.
    """

    pattern: PatternExpr
    extra_check: Callable[[Match], bool]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        raise NotImplementedError

    def register(
        self,
        pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
        target: Union[torch.fx.node.Target, None] = None,
        prepend: bool = False,
    ) -> None:
        """Register this entry under (pattern.op, target) in one or more
        pass dicts.

        With no explicit target, registers once per target function of the
        pattern. ``prepend`` puts the entry at the front of the bucket so it
        is tried first.
        """
        if target is None:
            # Fan out over every function the pattern can match.
            assert hasattr(self.pattern, "fns")
            for fn in self.pattern.fns:
                self.register(pass_dicts, fn, prepend=prepend)
        elif isinstance(pass_dicts, (dict, PatternMatcherPass)):
            assert hasattr(self.pattern, "op")
            if prepend:
                pass_dicts[(self.pattern.op, target)].insert(0, self)
            else:
                pass_dicts[(self.pattern.op, target)].append(self)
        else:
            # A sequence of pass dicts: register into each.
            pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts)
            for x in pass_dicts:
                self.register(x, target, prepend=prepend)
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
@dataclasses.dataclass
class LoweringPatternEntry(PatternEntry):
    """Pattern entry whose handler is invoked at lowering time; the matched
    subgraph is collapsed into a single call_function node wrapping the
    handler."""

    handler: Callable[..., Any]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        # Bind the match as the first argument while preserving the
        # handler's metadata for debuggability.
        handler = functools.wraps(self.handler)(functools.partial(self.handler, match))
        with graph.inserting_before(node):
            replacement = graph.call_function(handler, tuple(match.args), match.kwargs)
            replacement.meta.update(node.meta)
            node.replace_all_uses_with(replacement)
        assert match.nodes[-1] is node
        match.erase_nodes()
|
| 1014 |
+
|
| 1015 |
+
|
| 1016 |
+
@dataclasses.dataclass
class GraphPatternEntry(PatternEntry):
    """
    A pattern that runs a function on the FX graph
    """

    handler: Callable[..., Any]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        # The handler performs arbitrary graph mutation itself; we only
        # position the insertion point before the matched node.
        with graph.inserting_before(node):
            self.handler(match, *match.args, **match.kwargs)
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
@dataclasses.dataclass
class ReplacementPatternEntry(PatternEntry):
    """Pattern entry that splices a traced replacement graph in place of the
    matched subgraph."""

    # Reorders a match's kwargs into the positional argument list expected
    # by the replacement graph.
    normalize_args: Callable[..., List[Any]]

    @staticmethod
    def replace_with_graph(
        match: Match,
        graph: torch.fx.Graph,
        replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule],
        args: Sequence[torch.fx.Node],
    ) -> None:
        """Inline ``replacement_graph`` into ``graph`` before the earliest
        output node of ``match``, rewire all uses, and erase the matched
        nodes."""

        class Replacer(torch.fx.Interpreter):
            # Only placeholder/output/call_function nodes are expected in a
            # traced replacement graph; disable the other handlers.
            call_method = None  # type: ignore[assignment]
            call_module = None  # type: ignore[assignment]
            get_attr = None  # type: ignore[assignment]

            def run_node(self, node: torch.fx.Node) -> Any:
                if node.op in ("placeholder", "output"):
                    return super().run_node(node)
                if node.op == "call_function":
                    target = node.target
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    # Re-emit the call into the destination graph, carrying
                    # over fake-tensor metadata when present.
                    result = graph.call_function(target, args, kwargs)  # type: ignore[arg-type]
                    if "val" in node.meta and "val" not in result.meta:
                        result.meta["val"] = node.meta["val"]
                        if isinstance(node.meta["val"], torch.Tensor):
                            assert "tensor_meta" in node.meta
                            result.meta["tensor_meta"] = node.meta["tensor_meta"]
                    return result
                raise NotImplementedError(f"unhandled {node}")

        output_nodes = match.output_nodes()

        # Insert before the earliest (topologically first) output node so
        # every replacement node dominates all its eventual users.
        if len(output_nodes) == 1:
            last_node = output_nodes[0]
        else:
            assert output_nodes[0]
            nodes = list(output_nodes[0].graph.nodes)
            indices = [
                (nodes.index(n), n)
                for n in output_nodes
                if isinstance(n, torch.fx.Node)
            ]
            last_node = min(indices, key=operator.itemgetter(0))[1]

        def percolate_tags(
            node: torch.fx.Node,
            tag_name: str,
            tag_value: str,
            input_stops: Set[torch.fx.Node],
        ) -> None:
            # DFS from `node` toward its inputs, stamping tag_name on every
            # reachable node until hitting one of the original inputs.
            queue = [node]
            visited = set()

            while queue:
                arg = queue.pop()
                if (
                    arg not in visited
                    and arg not in input_stops
                    and hasattr(arg, "meta")
                ):
                    visited.add(arg)
                    arg.meta[tag_name] = tag_value
                    queue.extend(arg.all_input_nodes)

        with graph.inserting_before(last_node):
            replacement = Replacer(replacement_graph).run(*args)  # type: ignore[arg-type]
            if isinstance(replacement, torch.fx.Node):
                replacement = [replacement]

            def maybe_getitem(node: torch.fx.Node) -> Any:
                # Returns the getitem index if `node` is operator.getitem,
                # else None.
                if node.op != "call_function":
                    return None
                if node.target != operator.getitem:
                    return None
                assert len(node.args) == 2
                return node.args[1]

            def replace(
                old: Union[torch.fx.Node, None],
                new: Union[torch.fx.Node, Sequence[torch.fx.Node], None],
            ) -> None:
                if old is None:
                    assert new is None
                    return
                assert isinstance(old, torch.fx.Node)
                if new is None:
                    old.replace_all_uses_with(None)  # type: ignore[arg-type]
                    graph.erase_node(old)
                    return
                if isinstance(new, torch.fx.Node):
                    if "val" not in new.meta:
                        new.meta.update(old.meta)

                    # Preserve the recompute tags in the replacement graph. We
                    # look at the recompute tags of the original output node to
                    # propagate the tag from the output all the way to the input
                    # args (named as args in the replace_with_graph).
                    # Note that this is best effort. Since patterns are from
                    # many to many, there is no easy way to correctly map the
                    # recomputable tags. It is possible in some scenarios that we
                    # incorrectly tag some nodes as recomputables.
                    for tag_name in ["recompute", "ac_graph_id"]:
                        if tag_name in old.meta:
                            percolate_tags(new, tag_name, old.meta[tag_name], set(args))

                    old.replace_all_uses_with(new)
                    graph.erase_node(old)
                    return

                # `new` is not a node: it's a list of nodes.
                #
                # This happens when we want to replace a node that has a single
                # packed return with multiple unpacked returns. We need to do
                # some graph surgery here.
                #
                # Example:
                #   def original_graph(x):
                #      a = op(x)
                #      b = a[0]
                #      c = a[1]
                #      ...
                #
                # Assume that we want to replace op(x) with the graph
                #   def new_op(x):
                #      w = x + 1
                #      z = x + 2
                #      return (w, z)
                #
                # We need to replace `op` with the contents of `new_op`,
                # and then rewrite a[0] to be w and a[1] to be z, as so:
                #   def new_graph(x):
                #     w = x + 1
                #     z = x + 2
                #     b = w
                #     c = z
                #     ...
                old_uses = list(old.users.keys())
                for user in old_uses:
                    idx = maybe_getitem(user)
                    if idx is None:
                        raise AssertionError("can't handle")
                    replace(user, new[idx])  # type: ignore[index]
                graph.erase_node(old)

            if len(output_nodes) == len(replacement):
                for old, new in zip(output_nodes, replacement):
                    replace(old, new)
            else:
                assert len(output_nodes) == 1
                replace(output_nodes[0], replacement)

        match.erase_nodes()

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        assert match.replacement_graph is not None
        self.replace_with_graph(
            match,
            graph,
            match.replacement_graph,
            self.normalize_args(*match.args, **match.kwargs),
        )
|
| 1191 |
+
|
| 1192 |
+
|
| 1193 |
+
def _return_true(match: Match) -> bool:
    """Default ``extra_check`` predicate: accept every match unconditionally."""
    del match  # accepted only to satisfy the extra_check interface
    return True
|
| 1195 |
+
|
| 1196 |
+
|
| 1197 |
+
def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None:
    """Log (at INFO) that tracing ``search_fn`` for a replacement pattern
    failed, typically due to a shape mismatch with the matched inputs."""
    log.info(
        "Replacement pattern %s failed to apply due to shape mismatch: %s",
        search_fn.__name__,
        e,
    )
|
| 1203 |
+
|
| 1204 |
+
|
| 1205 |
+
def register_replacement(
    search_fn: SearchFn,
    replace_fn: ReplaceFn,
    example_inputs: Iterable[Any],
    trace_fn: TraceFn,
    pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
    extra_check: Callable[[Match], bool] = _return_true,
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
    search_fn_pattern: Union[PatternExpr, None] = None,
) -> bool:
    """
    Create a replacement rule based on example functions that get traced
    to create patterns. This supports both training and inference when
    run on a joint forward+backward graph.

    Args:
        search_fn: traced to give original pattern
        replace_fn: traced to give replacement graph
        example_inputs: example inputs for initial trace
        trace_fn: fwd_only or joint_fwd_bwd
        pass_dict: dict of passes to register to
        extra_check: additional check to run on match(using real shapes)
    """
    argnames_static = [*inspect.signature(search_fn).parameters.keys()]

    def check_fn(match: Match) -> bool:
        """
        Often shapes get burned into the pattern, so our initial match ran with
        `ignore_types=(int, ...)`.

        Recheck the match with the correct shapes.
        """
        argnames = list(argnames_static)
        for name in argnames:
            if name not in match.kwargs:
                raise RuntimeError(
                    f"Not all inputs to pattern found in match.kwargs. Perhaps one "
                    f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}"
                )

        # Pull the fake values of the matched inputs.
        args = list(
            torch.fx.map_arg(  # type: ignore[arg-type]
                [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
            )
        )
        sym_args: List[torch.SymInt] = []
        with torch._dynamo.utils.detect_fake_mode(args):
            for i, grad in enumerate(requires_grad):
                if isinstance(args[i], torch.Tensor):
                    if grad and is_integer_dtype(args[i].dtype):
                        return False

                    # Rebuild the tensor with the requires_grad flag from
                    # the original example inputs.
                    args[i] = torch.empty_strided(
                        args[i].size(),
                        args[i].stride(),
                        dtype=args[i].dtype,
                        device=args[i].device,
                        requires_grad=grad,
                    )
                    # Collect distinct symbolic sizes/strides so they can be
                    # passed to the trace as explicit inputs.
                    for v in itertools.chain(args[i].shape, args[i].stride()):
                        if isinstance(v, torch.SymInt) and all(
                            guard_size_oblivious(v != a) for a in sym_args
                        ):
                            sym_args.append(v)

            # If we were given a pre-traced pattern then use that instead of
            # retracing. Note that this means the pattern has to be independent
            # of its args.
            specific_pattern = search_fn_pattern

            if not specific_pattern:
                if sym_args:
                    # AOT Autograd and make fx will dedupe symbolic shape size
                    # accesses of sym ints that appear as inputs
                    # We don't want the sym_size uses to interfere with pattern matching
                    # so we provide them as inputs.
                    # Later, when we actually do the replacement, the symbolic shape
                    # sizes will get re-traced and added to the graph.

                    def search_fn_new(*args_new: Any) -> Any:
                        return search_fn(*args_new[len(args_new) - len(args) :])

                    try:
                        specific_graph = trace_fn(search_fn_new, sym_args + args)
                    except RuntimeError as e:
                        log_trace_failure(search_fn, e)
                        return False

                    # correct argnames in the graph
                    sym_arg_names = []
                    for i, placeholder in zip(
                        range(len(sym_args) + len(args)),
                        specific_graph.graph.nodes,
                    ):
                        if i < len(sym_args):
                            sym_arg_names.append(placeholder.target)
                            continue

                        with specific_graph.graph.inserting_after(placeholder):
                            new_node = specific_graph.graph.placeholder(
                                argnames[i - len(sym_args)]
                            )
                            new_node.target = new_node.name
                            placeholder.replace_all_uses_with(new_node)
                            specific_graph.graph.erase_node(placeholder)

                    argnames = sym_arg_names + argnames
                else:
                    try:
                        specific_graph = trace_fn(search_fn, args)
                    except RuntimeError as e:
                        log_trace_failure(search_fn, e)
                        return False

                specific_pattern = fx_to_pattern(
                    specific_graph,
                    argnames=argnames,
                    exclusive_arg_names=exclusive_arg_names,
                    scalar_workaround=scalar_workaround,
                )

            node = match.output_nodes()[0]
            assert node is not None
            specific_pattern_match = specific_pattern.match(node)

            if is_match(specific_pattern_match) and extra_check(specific_pattern_match):
                # trace the pattern using the shapes from the user program
                match.replacement_graph = trace_fn(replace_fn, args)  # type: ignore[assignment]
                return True
            return False

    def normalize_args(**kwargs: Any) -> List[Any]:
        # Static argnames first, then tangents_1..N (from joint graphs),
        # in order; anything left over is an error.
        args = []
        for name in argnames_static:
            args.append(kwargs.pop(name))
        for i in range(1, len(kwargs) + 1):
            if f"tangents_{i}" not in kwargs:
                break
            args.append(kwargs.pop(f"tangents_{i}"))
        assert not kwargs, f"leftover kwargs: {kwargs!r}"
        return args

    if trace_fn is joint_fwd_bwd:
        # If inference mode is enabled during compilation, assume that we don't
        # want to match on any training graph patterns
        if torch.is_inference_mode_enabled():
            return False

    # TODO: Revisit the functionalize_rng_ops for lowmem dropout
    with functorch_config.patch(functionalize_rng_ops=False):
        requires_grad: List[bool] = [
            isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
        ]
        if search_fn_pattern is None:
            pattern = gen_pattern(
                search_fn,
                example_inputs,
                trace_fn,
                scalar_workaround,
                exclusive_arg_names,
            )
        else:
            pattern = search_fn_pattern

        pattern_repr = PatternPrettyPrinter.run(pattern)
        # Duplicate registrations are a programming error.
        assert pattern_repr not in _seen_patterns
        _seen_patterns.add(pattern_repr)
        pattern = ReplacementPatternEntry(
            pattern=pattern,
            extra_check=check_fn,
            normalize_args=normalize_args,
        )
        pattern.register(pass_dicts)
        return pattern.pattern
|
| 1380 |
+
|
| 1381 |
+
|
| 1382 |
+
# Names of search functions whose serialized pattern file has already been
# written in this process; controls write-vs-append mode in _serialize_pattern.
_serialized_patterns: Set[str] = set()
|
| 1383 |
+
|
| 1384 |
+
|
| 1385 |
+
def _serialize_pattern(
    unique_name: str,
    search_fn: SearchFn,
    example_inputs: Iterable[Any],
    trace_fn: TraceFn,
    scalar_workaround: Union[Dict[str, Union[float, int]], None],
) -> PatternExpr:
    """Trace ``search_fn`` into a pattern, write its source to
    ``SERIALIZED_PATTERN_PATH/<search_fn name>.py`` under ``unique_name``,
    and return the pattern.

    The first pattern written for a given search function truncates the
    file and emits the header; subsequent ones append.
    """

    def get_file_template() -> str:
        # Header for the generated file: provenance note, imports, and one
        # `from ... import (...)` pulling in every pattern class.
        auto_generated_msg = textwrap.dedent(
            """\
            # This is an auto-generated file. Please do not modify it by hand.
            # To re-generate, run:
            # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
            """
        )

        file_template = textwrap.dedent(
            """\
            # mypy: ignore-errors

            # noqa: F401, E501
            {msg}
            import torch
            import torch._inductor

            aten = torch.ops.aten
            prims = torch.ops.prims

            """
        ).format(msg=auto_generated_msg)

        pattern_matcher_imports = []
        for name in dir(torch._inductor.pattern_matcher):
            attr = getattr(torch._inductor.pattern_matcher, name)
            if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
                pattern_matcher_imports.append(name)

        formatted_imports = ",\n ".join(pattern_matcher_imports)
        formatted_imports = f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n"
        return f"{file_template}{formatted_imports}"

    if not SERIALIZED_PATTERN_PATH.is_dir():
        raise RuntimeError(
            f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}"
        )

    pattern_name = search_fn.__name__

    from torch._functorch import config as functorch_config

    with functorch_config.patch(functionalize_rng_ops=False):
        pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround)

    serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name)
    if pattern_name not in _serialized_patterns:
        write_mode = "w"
        _serialized_patterns.add(pattern_name)
    else:
        write_mode = "a"

    file_template = get_file_template()

    with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f:
        if write_mode == "w":
            f.write(file_template)
        else:
            f.write("\n\n")
        f.write(serialized_pattern)
        f.write("\n")

    return pattern
|
| 1456 |
+
|
| 1457 |
+
|
| 1458 |
+
# Directory where pre-traced patterns are written by _serialize_pattern()
# and read back by gen_register_replacement().
SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns"
|
| 1459 |
+
|
| 1460 |
+
# This is the set of serialized patterns that we've registered. Used by
# test_serialized_patterns_up_to_date() to ensure the patterns are up
# to date.
# Each entry: (search_fn, example_inputs, trace_fn, scalar_workaround, pattern).
_known_precompiled_patterns: List[
    Tuple[
        Any,
        Iterable[Any],
        Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule],
        Any,
        PatternExpr,
    ]
] = []
|
| 1472 |
+
|
| 1473 |
+
|
| 1474 |
+
def gen_register_replacement(
    unique_name: str,
    search_fn: SearchFn,
    replace_fn: ReplaceFn,
    example_inputs: Iterable[Any],
    trace_fn: TraceFn,
    pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
    extra_check: Callable[[Match], bool] = _return_true,
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
    skip_duplicates: bool = False,
) -> None:
    """Register a replacement using a precompiled (serialized) pattern.

    With PYTORCH_GEN_PATTERNS set, the pattern is (re)traced and written to
    disk; otherwise it is imported from the serialized_patterns package.
    The pattern is then handed to register_replacement as
    ``search_fn_pattern`` so no retracing happens at registration.
    """
    # Make sure the example_inputs is materialized.
    example_inputs = tuple(example_inputs)

    if "PYTORCH_GEN_PATTERNS" in os.environ:
        pat = _serialize_pattern(
            unique_name, search_fn, example_inputs, trace_fn, scalar_workaround
        )
    else:
        pattern_name = search_fn.__name__
        m = importlib.import_module(
            f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}"
        )
        if not m or not hasattr(m, unique_name):
            log.warning(
                "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.",
                unique_name,
            )
        pat = getattr(m, unique_name)

    for arg in pytree.tree_iter(example_inputs):
        if isinstance(arg, FakeTensor) and arg.constant is not None:
            # This can be a problem - small fake tensors (e.g. `tensor(2)`) will
            # hold onto their original constant value - and by stashing it here
            # will cause a memory leak if the constant value is on GPU.
            # Since this is just an optimization we can clear it out.
            arg.constant = None

    if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates:
        return
    _known_precompiled_patterns.append(
        (search_fn, example_inputs, trace_fn, scalar_workaround, pat)
    )
    register_replacement(
        search_fn,
        replace_fn,
        example_inputs,
        trace_fn,
        pass_dicts,
        extra_check,
        scalar_workaround,
        exclusive_arg_names,
        search_fn_pattern=pat,
    )
|
| 1529 |
+
|
| 1530 |
+
|
| 1531 |
+
@functorch_config.patch(functionalize_rng_ops=False)
def gen_pattern(
    search_fn: SearchFn,
    example_inputs: Sequence[Any],
    trace_fn: TraceFn,
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
) -> PatternExpr:
    """Trace ``search_fn`` with ``trace_fn`` and convert the result into a
    PatternExpr.

    ``scalar_workaround`` supplies fixed scalar values by argument name;
    remaining arguments are taken positionally from ``example_inputs``.
    Ints/floats/lists/devices/dtypes are ignored during matching so shapes
    burned into the trace don't over-constrain the pattern.
    """
    argnames = [*inspect.signature(search_fn).parameters.keys()]

    if scalar_workaround is None:
        scalar_workaround = {}
    flat_inputs = []
    input_idx = 0  # Positional arguments index

    for argname in argnames:
        if argname in scalar_workaround:
            flat_inputs.append(scalar_workaround[argname])
        else:
            flat_inputs.append(example_inputs[input_idx])
            input_idx += 1

    search_gm = trace_fn(search_fn, flat_inputs)
    return fx_to_pattern(
        search_gm,
        ignore_types=(int, float, list, torch.device, torch.dtype),
        argnames=argnames,
        scalar_workaround=scalar_workaround,
        exclusive_arg_names=exclusive_arg_names,
    )
|
| 1561 |
+
|
| 1562 |
+
|
| 1563 |
+
def register_lowering_pattern(
    pattern: PatternExpr,
    extra_check: Callable[[Match], bool] = _return_true,
    *,
    pass_dict: _PassDictsType,
    prepend: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Register an aten to inductor IR replacement pattern. The decorated
    function is saved and then called a lowering time allowing direct
    pattern to inductor IR conversion.
    """

    def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
        assert callable(handler)
        LoweringPatternEntry(
            pattern=pattern, extra_check=extra_check, handler=handler
        ).register(pass_dict, prepend=prepend)
        # Marker consumed elsewhere to identify lowering handlers.
        handler._inductor_lowering_function = True  # type: ignore[attr-defined]
        return handler

    return decorator
|
| 1585 |
+
|
| 1586 |
+
|
| 1587 |
+
def register_graph_pattern(
    pattern: PatternExpr,
    extra_check: Callable[[Match], bool] = _return_true,
    *,
    pass_dict: _PassDictsType,
    prepend: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Register a pattern that runs a function on the FX graph, allowing
    custom transformation code.
    """

    def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
        assert callable(handler)
        GraphPatternEntry(
            pattern=pattern, extra_check=extra_check, handler=handler
        ).register(pass_dict, prepend=prepend)
        return handler

    return decorator
|
| 1607 |
+
|
| 1608 |
+
|
| 1609 |
+
def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
    """Return True iff ``node`` is the very first node of ``graph``."""
    first_node = next(iter(graph.nodes))
    return node is first_node
|
| 1612 |
+
|
| 1613 |
+
|
| 1614 |
+
# Heuristic for mutating ops based on naming convention (trailing underscore
# or set/enter/exit/seed words):
# match: copy_, relu_, _set_grad_enabled, manual_seed, _enter_autocast, etc
# doesn't match: __rshift__, etc
_mutation_op_re = re.compile(r"(?<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)")
|
| 1617 |
+
|
| 1618 |
+
|
| 1619 |
+
def is_mutation_op(node: torch.fx.Node) -> bool:
    """Heuristically decide whether ``node`` mutates state: its target name
    matches the mutation naming convention, or it uses an ``out=`` kwarg."""
    name: Optional[str] = None
    if node.op == "call_function":
        name = node.target.__name__  # type: ignore[union-attr]
    elif node.op == "call_method":
        name = node.target  # type: ignore[assignment]
    if name is not None and _mutation_op_re.search(name):
        return True
    return node.kwargs.get("out") is not None
|
| 1627 |
+
|
| 1628 |
+
|
| 1629 |
+
def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool:
    """Return True iff both nodes carry the same mutation_region_id tag.

    Both nodes must already have been tagged (see compute_mutation_region_ids).
    """
    for nd in (a, b):
        assert "mutation_region_id" in nd.meta
    return a.meta["mutation_region_id"] == b.meta["mutation_region_id"]
|
| 1633 |
+
|
| 1634 |
+
|
| 1635 |
+
def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
    """Return the mutation region id of ``node``, lazily tagging nodes.

    Walks backwards to the nearest already-tagged node (or the graph start),
    then walks forward to ``node``, bumping the id at each mutation op and
    stamping every visited node's meta along the way.
    """
    n = node
    # Walk back until we find a tagged node or hit the start of the graph.
    while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
        n = n.prev
    mutation_region_id = n.meta.get("mutation_region_id", 0)
    # Walk forward to `node`, tagging as we go.
    while n is not node:
        n = n.next
        if is_mutation_op(n):
            mutation_region_id += 1
        n.meta["mutation_region_id"] = mutation_region_id
    return mutation_region_id
|
| 1646 |
+
|
| 1647 |
+
|
| 1648 |
+
def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
    """Region ids need computing iff the first node has not been tagged yet."""
    first = next(iter(graph.nodes))
    return "mutation_region_id" not in first.meta
|
| 1650 |
+
|
| 1651 |
+
|
| 1652 |
+
def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
    """Tag every node with a region id that increases at each mutation op."""
    region = 0
    for node in graph.nodes:
        if is_mutation_op(node):
            region += 1
        node.meta["mutation_region_id"] = region
|
| 1658 |
+
|
| 1659 |
+
|
| 1660 |
+
class PatternMatcherPass:
    """A collection of registered patterns, applied to an FX graph as one pass.

    Patterns are keyed by ``(node.op, target)`` so candidate nodes can be
    found quickly with ``graph.find_nodes`` instead of scanning every node.
    """

    def __init__(
        self,
        pass_name: Optional[str] = None,
    ) -> None:
        super().__init__()
        # (op, target) -> list of PatternEntry registered for that anchor.
        self.patterns: DefaultDict[
            Tuple[str, torch.fx.node.Target], List[PatternEntry]
        ] = defaultdict(list)
        # Name reported to GraphTransformObserver (defaults in apply()).
        self.pass_name = pass_name

    def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
        """Return the entries registered under ``(op, target)``."""
        return self.patterns[item]

    def apply(self, gm: torch.fx.GraphModule) -> int:
        """Run all registered patterns over *gm*; return the match count.

        Accepts either a GraphModule or a bare Graph (whose owning module is
        then used for observation/logging).
        """
        if not self.patterns:
            return 0
        if isinstance(gm, torch.fx.GraphModule):
            graph = gm.graph
        elif isinstance(gm, torch.fx.Graph):
            graph = gm
            gm = graph.owning_module
        else:
            raise RuntimeError(
                f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
            )
        # Lazily assign mutation region ids so matches never cross a mutation.
        if should_compute_mutation_region_ids(graph):  # type: ignore[arg-type]
            compute_mutation_region_ids(graph)  # type: ignore[arg-type]
        get_mutation_region_id_partial = functools.partial(
            get_mutation_region_id, graph
        )
        count = 0
        # Collect candidate nodes per registered (op, target) anchor;
        # call_module targets are strings, so those are fetched op-wide.
        nodes = []
        has_call_module = False
        for op, target in self.patterns:
            if op == "call_module":
                has_call_module = True
            else:
                nodes.append(graph.find_nodes(op=op, target=target, sort=False))
        if has_call_module:
            nodes.append(graph.find_nodes(op="call_module", sort=False))
        pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
        with GraphTransformObserver(
            gm, pass_name, trace_config.log_url_for_graph_xform
        ):
            # Visit candidates in reverse order (nodes sort by graph position).
            for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
                target = extract_target(node)
                if node.op == "call_module":
                    if (node.op, target) not in self.patterns:
                        continue

                # conservatively not applying pattern for cpu input,
                # since some of the patterns induce codegen and split nodes.
                # Note: we will only skip cpu compute if disable_cpp_codegen=True
                if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False):
                    continue

                for entry in self.patterns[(node.op, target)]:
                    # A previous entry may have rewritten this node away.
                    if node._erased:
                        break
                    m = entry.pattern.match(node)
                    # pattern match crosses mutation barrier - discard
                    if (
                        is_match(m)
                        and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1  # type: ignore[possibly-undefined]
                    ):
                        continue
                    if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
                        log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
                    if is_match(m) and entry.extra_check(m):
                        count += 1
                        entry.apply(m, graph, node)  # type: ignore[arg-type]
                        counters["inductor"]["pattern_matcher_count"] += 1
                        counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
        return count

    def clear(self) -> None:
        """Drop every registered pattern."""
        self.patterns.clear()
|
| 1738 |
+
|
| 1739 |
+
|
| 1740 |
+
def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn:
|
| 1741 |
+
raise NotImplementedError
|
| 1742 |
+
|
| 1743 |
+
|
| 1744 |
+
def fx_to_pattern(
    gm: Union[torch.fx.GraphModule, torch.fx.Graph],
    ignore_types: Sequence[Type[Any]] = (),
    argnames: Sequence[str] = (),
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
) -> PatternExpr:
    """
    Convert an FX graph into a PatternExpr. This is useful for simple
    patterns that can only match single functions and fixed-length lists.
    """
    # scalar_workaround is a hack to capture dropout_p
    # see https://github.com/pytorch/pytorch/issues/97894
    scalar_workaround = scalar_workaround or {}
    inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()}
    # The inverse map must be a bijection or a scalar would be ambiguous.
    assert len(inv_scalar_workaround) == len(scalar_workaround)

    def process_arg(x: T) -> Union[T, KeywordArg, Ignored]:
        # Burned-in scalars listed in scalar_workaround become named kwargs.
        if isinstance(x, (float, int)) and x in inv_scalar_workaround:
            return KeywordArg(inv_scalar_workaround[x])
        if type(x) in ignore_types:
            return Ignored()
        # A non-empty list consisting entirely of Ignored collapses to one.
        if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x:
            return Ignored()
        return x

    argnum = itertools.count()

    class Converter(torch.fx.Interpreter):
        # Only call_function graphs can be converted to patterns.
        call_method = _not_implemented
        call_module = _not_implemented
        get_attr = _not_implemented

        def placeholder(
            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
        ) -> Union[ExclusiveKeywordArg, KeywordArg]:
            # Name placeholders from argnames while available; extras (beyond
            # argnames) must be tangents from the joint graph.
            n = next(argnum)
            if n < len(argnames):
                name = argnames[n]
            elif argnames:
                assert target.startswith("tangent")
                name = target
            else:
                target = re.sub(r"_\d+$", "", target)  # de-mangle arg name
                name = target
            if name in exclusive_arg_names:
                return ExclusiveKeywordArg(name)
            else:
                return KeywordArg(name)

        def call_function(
            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
        ) -> PatternExpr:
            args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
            if list in ignore_types:
                # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...]
                args = [process_arg(a) for a in args]
                kwargs = {k: process_arg(a) for k, a in kwargs.items()}
            return CallFunction(target, *args, **kwargs)

        def run_node(self, n: torch.fx.Node) -> Any:
            # Propagate user counts onto the pattern nodes so multi-user
            # subgraphs only match structurally-equivalent graphs.
            rv = super().run_node(n)
            if n.op == "output" and isinstance(rv, tuple):
                assert len(rv) == len(n.args[0])  # type: ignore[arg-type]
                for r, arg in zip(rv, n.args[0]):  # type: ignore[arg-type]
                    r.users = len(arg.users)
            else:
                rv.users = len(n.users)
            return rv

    pattern = Converter(gm).run()  # type: ignore[arg-type]
    # A tuple of outputs becomes a MultiOutputPattern.
    if not isinstance(pattern, PatternExpr):
        return MultiOutputPattern(pytree.tree_leaves(pattern))
    return pattern
|
| 1818 |
+
|
| 1819 |
+
|
| 1820 |
+
@torch.no_grad()
def fwd_only(
    fn: Callable[..., Any],
    args: Sequence[Any],
    *,
    run_functional_passes: bool = True,
    get_decomp_fn: Optional[Callable[..., Any]] = None,
) -> torch.fx.GraphModule:
    """Build a normalized inference graph, for use with fx_to_pattern.

    Traces *fn* with make_fx under real tensors, applying either the caller's
    decomposition table (``get_decomp_fn``) or the default inductor one, then
    optionally normalizes the graph (noop removal + DCE).
    """
    # TODO - look into using aot autograd, asserting no mutating ops here
    with enable_python_dispatcher():
        decompositions = (
            get_decomp_fn() if get_decomp_fn is not None else select_decomp_table()
        )
        gm = make_fx(fn, decompositions, tracing_mode="real")(*args)

    # Imported locally to avoid an import cycle with fx_passes.
    from .fx_passes.post_grad import remove_noop_ops

    if run_functional_passes:
        remove_noop_ops(gm.graph)
        gm.graph.eliminate_dead_code()

    gm.recompile()
    return gm
|
| 1844 |
+
|
| 1845 |
+
|
| 1846 |
+
@torch.enable_grad()
def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule:
    """Build a normalized training graph, for use with fx_to_pattern.

    Runs *fn* through aot_function and captures a clone of the joint
    forward+backward graph from inside the partitioner, then normalizes it
    (noop removal, pointless-view canonicalization, DCE).
    """
    gm: Optional[torch.fx.GraphModule] = None

    def record_joint_graph(
        joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any
    ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        # Side-channel: clone the joint graph before partitioning discards it.
        nonlocal gm
        assert not gm
        gm = clone_graph(joint_graph)
        return default_partition(joint_graph, inputs, **kwargs)

    with torch._guards.tracing(None):
        aot_function(
            fn,
            lambda g, i: make_boxed_func(g),
            partition_fn=record_joint_graph,
            decompositions=select_decomp_table(),
            keep_inference_input_mutations=True,
            enable_log=False,
        )(*args)
    assert gm

    # Imported locally to avoid import cycles with fx_passes.
    from .fx_passes.post_grad import remove_noop_ops

    remove_noop_ops(gm.graph)

    from .fx_passes.joint_graph import pointless_view

    matcher_pass = PatternMatcherPass()

    # Canonicalize no-op aten.view calls so traced patterns are stable.
    pattern = CallFunction(
        torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
    )
    GraphPatternEntry(
        pattern=pattern, handler=pointless_view, extra_check=_return_true
    ).register(matcher_pass.patterns)
    matcher_pass.apply(gm.graph)  # type: ignore[arg-type]

    # remove in/out specs
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
|
| 1891 |
+
|
| 1892 |
+
|
| 1893 |
+
def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
|
| 1894 |
+
args: List[torch.fx.node.Argument] = []
|
| 1895 |
+
torch.fx.map_arg((n.args, n.kwargs), args.append)
|
| 1896 |
+
return args
|
| 1897 |
+
|
| 1898 |
+
|
| 1899 |
+
def stable_topological_sort(graph: torch.fx.Graph) -> None:
    """Topologically sort *graph* in place.

    Stable: nodes that are already in a valid order are not moved, so an
    already-sorted graph comes out unchanged.
    """
    # Nodes are in exactly one of these three collections:

    # - Nodes in `pending` are waiting to be processed (in reverse order):
    pending = list(reversed(graph.nodes))

    # - Nodes in `ready` have been processed and are already in the correct
    # order.
    ready = set()

    # - `waiting` is a mapping from a dependency to nodes which depend on that
    # dependency.
    waiting = defaultdict(list)

    # The cursor indicates the last processed node so we can add new nodes
    # after it.
    cursor = None
    while pending:
        node = pending.pop()
        waiting_for = [x for x in _args(node) if x not in ready]
        if waiting_for:
            # We have unprocessed input nodes. Might as well wait for the last
            # arg so an already sorted list will only recheck this node once.
            waiting[waiting_for[-1]].append(node)
        else:
            ready.add(node)
            # Only physically move the node if it is out of place.
            if cursor and cursor.next is not node:
                cursor.append(node)
            cursor = node
            # Mark the nodes that have been waiting for this node to finish as
            # ready to check again.
            pending.extend(reversed(waiting.pop(node, ())))

    # Anything left in `waiting` would imply a dependency cycle.
    assert not waiting and len(ready) == len(graph.nodes)
|
| 1933 |
+
|
| 1934 |
+
|
| 1935 |
+
def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
    """Wrapper around lazy init functions in fx_passes/"""

    @functools.lru_cache(None)
    @functools.wraps(fn)
    def lazy_init() -> Any:
        saved_counters = counters["inductor"].copy()

        with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
            result = fn()

        # clear view matches encountered during tracing
        counters["inductor"] = saved_counters

        return result

    return lazy_init
|
| 1952 |
+
|
| 1953 |
+
|
| 1954 |
+
def config_flag(name: str) -> Callable[[Match], Any]:
    """Function for extra_check to put pass behind a flag"""

    def check_flag(match: Match) -> Any:
        # Read the flag lazily so runtime config changes take effect.
        return getattr(config, name)

    return check_flag
|
| 1961 |
+
|
| 1962 |
+
|
| 1963 |
+
def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Return a copy of *input_graph* preserving node metadata and names."""

    class CopyGraph(Transformer):
        def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node:
            new_node = super().run_node(old_node)
            if isinstance(new_node, torch.fx.Proxy):
                # Carry over per-node metadata, and re-register the original
                # name in the new graph's namespace so names stay stable.
                new_node.node.meta.update(old_node.meta)
                new_node.node.name = self.new_graph._graph_namespace.create_name(
                    old_node.name, None
                )
            return new_node

    return CopyGraph(input_graph).transform()
|
| 1975 |
+
|
| 1976 |
+
|
| 1977 |
+
# Pattern names registered so far. NOTE(review): consumers of this set are
# outside this chunk — presumably duplicate-registration detection; confirm.
_seen_patterns: Set[str] = set()
|
| 1978 |
+
|
| 1979 |
+
|
| 1980 |
+
def get_arg_value(
    node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
) -> Any:
    """Fetch a call argument positionally, falling back to its keyword form."""
    if len(node.args) > arg_number:
        return node.args[arg_number]
    return node.kwargs.get(kwarg_name)  # type: ignore[arg-type]
|
| 1988 |
+
|
| 1989 |
+
|
| 1990 |
+
def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]:
    """Keep nodes whose target is *fn* (or any overload of an overload packet)."""
    targets = [fn]
    if isinstance(fn, torch._ops.OpOverloadPacket):
        targets += [getattr(fn, overload) for overload in fn.overloads()]

    return [node for node in nodes if node.target in targets]
|
| 1996 |
+
|
| 1997 |
+
|
| 1998 |
+
def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
    """For call_function and call_method, we directly use the target function;
    For call_module, the target is string, and we treat the module class
    as a function.
    """
    if node.op != "call_module":
        return node.target
    submodule = getattr(node.graph.owning_module, node.target)  # type: ignore[arg-type]
    return submodule.__class__
|
.venv/lib/python3.11/site-packages/torch/_inductor/quantized_lowerings.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._inductor.kernel.mm_common import mm_args
|
| 6 |
+
|
| 7 |
+
from . import config as inductor_config, lowering
|
| 8 |
+
from .codegen.cpp_gemm_template import CppPackedGemmTemplate
|
| 9 |
+
from .codegen.cpp_utils import create_epilogue_with_attr
|
| 10 |
+
from .lowering import expand, register_lowering
|
| 11 |
+
from .select_algorithm import (
|
| 12 |
+
autotune_select_algorithm,
|
| 13 |
+
ExternKernelChoice,
|
| 14 |
+
realize_inputs,
|
| 15 |
+
)
|
| 16 |
+
from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
log = logging.getLogger(__name__)

# ATen fallback choice for weight-only int8 matmul; the op has no out=
# variant, so autotuning binds it by return value.
aten__weight_int8pack_mm = ExternKernelChoice(
    torch._weight_int8pack_mm, "at::_weight_int8pack_mm", has_out_variant=False
)


# Shorthand aliases for the op namespaces used below.
quantized = torch.ops.quantized
_quantized = torch.ops._quantized
aten = torch.ops.aten
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def register_quantized_ops():
    """Register quantized ops: realize their inputs and lower them as fallbacks."""
    ops = [
        quantized.max_pool2d,
        _quantized.wrapped_fbgemm_pack_gemm_matrix_fp16,
        _quantized.wrapped_fbgemm_linear_fp16_weight,
    ]
    lowering.add_needs_realized_inputs(ops)
    for op in ops:
        lowering.make_fallback(op)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def register_woq_mm_ops():
    """Register the lowering for weight-only-quantized (int8) matmul."""

    @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None)
    def int8pack_mm(input, weight, scale, *, layout=None):
        # mat2 (the int8 weight) is stored transposed.
        _, _, _, layout, mat1, mat2 = mm_args(
            input, weight, layout=layout, mat2_transposed=True
        )
        assert (
            mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float]
            and mat2.get_dtype() == torch.int8
        )
        aten_layout = layout

        # options to tune from
        choices = (
            [aten__weight_int8pack_mm.bind((mat1, mat2, scale), aten_layout)]
            if use_aten_gemm_kernels()
            else []
        )

        # scale is applied as an epilogue, and the scale tensor is expanded (with a view op)
        # for broadcasting, as it's 1D.
        def _mul_epilogue(buf):
            return create_epilogue_with_attr(
                buf, "mul", other=realize_inputs(expand(scale, layout.size))
            )

        if use_cpp_packed_gemm_template(aten_layout, mat1, mat2, mat2_transposed=True):
            CppPackedGemmTemplate.add_choices(
                choices,
                aten_layout,
                [mat1, mat2, scale],
                trans_w=True,
                epilogue_creator=_mul_epilogue,
            )

        # No candidates at all (ATen kernels disabled and no CPP template):
        # fall back to the ATen kernel instead of failing autotuning.
        if (
            len(choices) == 0
            and inductor_config.autotune_fallback_to_aten
            and not use_aten_gemm_kernels()
        ):
            log.warning("No choices for GEMM, using ATen backend as fallback")
            return aten__weight_int8pack_mm.bind(
                (mat1, mat2, scale), aten_layout
            ).output_node()

        return autotune_select_algorithm(
            "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout
        )
|
.venv/lib/python3.11/site-packages/torch/_inductor/remote_cache.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import typing
|
| 6 |
+
from abc import abstractmethod
|
| 7 |
+
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
|
| 8 |
+
from typing_extensions import override, TypeAlias
|
| 9 |
+
|
| 10 |
+
from torch._inductor import config
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
try:
    import redis
except ImportError:
    # redis is optional: backends below degrade to no-ops when it is missing.
    redis = None  # type: ignore[assignment]


if config.is_fbcode():
    from rfe.scubadata.scubadata_py3 import (  # type: ignore[import-not-found]
        Sample as Sample_,
    )

    Sample: TypeAlias = Sample_
else:
    # Outside fbcode there is no Scuba telemetry; use a placeholder alias so
    # annotations below still resolve.
    Sample: TypeAlias = Type[object]  # type: ignore[misc,no-redef]


# _T: the structured value type a cache stores; _U: its serialized form.
_T = TypeVar("_T")
_U = TypeVar("_U")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class RemoteCacheBackend(Generic[_T]):
    """
    A backend implementation for accessing a remote/distributed cache. Only
    works with bytes in/out. For structured data use a RemoteCache.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[_T]:
        """Return the stored payload for *key*, or None on a miss."""
        pass

    @abstractmethod
    def put(self, key: str, data: _T) -> None:
        """Store *data* under *key*."""
        pass
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Serde that encodes from _T to _U and decodes from _U to _T.
class RemoteCacheSerde(Generic[_T, _U]):
    @abstractmethod
    def encode(self, data: _T) -> _U:
        """Convert a structured value into its storage/wire form."""
        pass

    @abstractmethod
    def decode(self, data: _U) -> _T:
        """Inverse of encode."""
        pass
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Any JSON-representable value (recursively), including None.
JsonDataTy = Optional[
    Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]):
    """Serde that round-trips JSON-compatible values through ASCII bytes."""

    def encode(self, data: JsonDataTy) -> bytes:
        return json.dumps(data).encode("ascii")

    def decode(self, data: bytes) -> JsonDataTy:
        return json.loads(data)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]):
    """Identity serde: values are stored exactly as given."""

    def encode(self, data: _T) -> _T:
        return data

    def decode(self, data: _T) -> _T:
        return data
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class RemoteCache(Generic[_T]):
    """Structured remote cache: a byte-level backend paired with a serde.

    Values of type _T are encoded before hitting the backend and decoded on
    the way out. Subclasses can hook telemetry via _create_sample/_log_sample.
    """

    # Test hook: when set, instances ignore the passed backend and use this.
    backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None

    def __init__(
        self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U]
    ) -> None:
        # Support for testing.
        if (override_cls := self.__class__.backend_override_cls) is not None:
            self.backend = override_cls()
        else:
            self.backend = backend
        self.serde = serde

    def get(self, key: str) -> Optional[_T]:
        """Look up *key*; return the decoded value, or None on a miss."""
        sample = self._create_sample()
        result = self._get(key, sample)
        self._log_sample(sample)
        return result

    def put(self, key: str, value: _T) -> None:
        """Encode *value* and store it under *key*."""
        sample = self._create_sample()
        self._put(key, value, sample)
        self._log_sample(sample)

    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:
        return self.serde.decode(data)

    def _encode(self, value: _T, sample: Optional[Sample]) -> Any:  # returns _U
        return self.serde.encode(value)

    def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]:
        # NOTE(review): a falsy hit (e.g. b"") is treated as a miss here —
        # presumably acceptable for the payloads being cached; confirm.
        if data := self.backend.get(key):
            return self._decode(data, sample)
        return None

    def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None:
        data = self._encode(value, sample)
        self.backend.put(key, data)

    def _create_sample(self) -> Optional[Sample]:
        # Subclass hook: create a telemetry sample for one get/put.
        return None

    def _log_sample(self, sample: Optional[Sample]) -> None:
        # Subclass hook: emit the sample produced by _create_sample.
        pass
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
    """
    A Redis implementation of a remote/distributed cache.
    """

    _key_fmt: str
    # None means redis is unavailable (import failed or connection refused);
    # all operations then silently no-op.
    _redis: Optional[redis.Redis] = None

    def __init__(self, cache_id: str) -> None:
        if not redis:
            # We had trouble importing redis - just skip init.
            return

        # Namespace keys per cache_id so distinct caches don't collide.
        self._key_fmt = f"pt2:{cache_id}:{{key}}"
        self._redis = redis.Redis(
            host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)),
        )

    def __get_key(self, key: str) -> str:
        # Apply the per-cache namespace prefix.
        return self._key_fmt.format(key=key)

    @override
    def get(self, key: str) -> Optional[bytes]:
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return None

        try:
            value = self._redis.get(self.__get_key(key))
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark is as unavailable now.
            self._redis = None
            return None

        # In theory redis.get() can return an Awaitable as well...
        assert value is None or isinstance(value, bytes)
        return value

    @override
    def put(self, key: str, data: bytes) -> None:
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return

        try:
            self._redis.set(self.__get_key(key), data)
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark is as unavailable now.
            self._redis = None
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class RedisRemoteCache(RemoteCache[JsonDataTy]):
    """JSON-over-Redis remote cache."""

    def __init__(self, key: str) -> None:
        # Special test handling: If we're just going to override the backend
        # anyway don't require redis
        if self.__class__.backend_override_cls:
            # This is totally bogus but it works for now...
            backend = typing.cast(RemoteCacheBackend[bytes], None)
        else:
            backend = RedisRemoteCacheBackend(key)
        super().__init__(backend, RemoteCacheJsonSerde())
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class RemoteAutotuneCache(RedisRemoteCache):
    """Distinct subclass for autotune entries; adds no behavior of its own."""

    pass
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class RemoteFxGraphCache(RedisRemoteCache):
    """Distinct subclass for FX graph cache entries; adds no behavior of its own."""

    pass
|
.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py
ADDED
|
@@ -0,0 +1,1743 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import builtins
|
| 3 |
+
import contextlib
|
| 4 |
+
import functools
|
| 5 |
+
import inspect
|
| 6 |
+
import itertools
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import math
|
| 10 |
+
import operator
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import textwrap
|
| 14 |
+
import time
|
| 15 |
+
from collections import namedtuple
|
| 16 |
+
from concurrent.futures import as_completed, ThreadPoolExecutor
|
| 17 |
+
from io import StringIO
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
from unittest.mock import patch
|
| 20 |
+
|
| 21 |
+
import sympy
|
| 22 |
+
from filelock import FileLock
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
| 26 |
+
from torch._dynamo.testing import rand_strided
|
| 27 |
+
from torch._dynamo.utils import counters, identity, preserve_rng_state
|
| 28 |
+
|
| 29 |
+
from . import config, ir
|
| 30 |
+
from .autotune_process import TensorMeta, TritonBenchmarkRequest
|
| 31 |
+
from .codecache import code_hash, PersistentCache, PyCodeCache
|
| 32 |
+
from .codegen.common import IndentedBuffer, KernelTemplate
|
| 33 |
+
from .codegen.triton import (
|
| 34 |
+
gen_common_triton_imports,
|
| 35 |
+
texpr,
|
| 36 |
+
TritonKernel,
|
| 37 |
+
TritonPrinter,
|
| 38 |
+
TritonScheduling,
|
| 39 |
+
)
|
| 40 |
+
from .codegen.triton_utils import config_of, signature_to_meta
|
| 41 |
+
from .exc import CUDACompileError
|
| 42 |
+
from .ir import ChoiceCaller, PrimitiveInfoType
|
| 43 |
+
from .runtime.benchmarking import benchmarker
|
| 44 |
+
from .runtime.hints import DeviceProperties
|
| 45 |
+
from .utils import (
|
| 46 |
+
FakeIndentedBuffer,
|
| 47 |
+
get_dtype_size,
|
| 48 |
+
Placeholder,
|
| 49 |
+
restore_stdout_stderr,
|
| 50 |
+
sympy_dot,
|
| 51 |
+
sympy_index_symbol,
|
| 52 |
+
sympy_product,
|
| 53 |
+
unique,
|
| 54 |
+
)
|
| 55 |
+
from .virtualized import V
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
log = logging.getLogger(__name__)
|
| 59 |
+
|
| 60 |
+
# correctness checks struggle with fp16/tf32
|
| 61 |
+
VERIFY: Dict[str, Any] = {}
|
| 62 |
+
PRINT_AUTOTUNE = True
|
| 63 |
+
DEBUG = False
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class KernelNamespace:
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# these objects are imported from the generated wrapper code
|
| 71 |
+
extern_kernels = KernelNamespace()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class PartialRender:
|
| 75 |
+
"""
|
| 76 |
+
Some parts of a template need to be generated at the end, but
|
| 77 |
+
inserted into the template at the start. This allows doing a bunch
|
| 78 |
+
of replacements after the initial render.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def __init__(self, code, replacement_hooks) -> None:
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.code = code
|
| 84 |
+
self.replacement_hooks = replacement_hooks
|
| 85 |
+
|
| 86 |
+
def finalize_hook(self, hook_key: str, strict=True) -> None:
|
| 87 |
+
if hook_key not in self.replacement_hooks:
|
| 88 |
+
if strict:
|
| 89 |
+
raise RuntimeError(
|
| 90 |
+
f"{hook_key} not registered in self.replacement_hooks"
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
return
|
| 94 |
+
assert (
|
| 95 |
+
self.replacement_hooks[hook_key] is not None
|
| 96 |
+
), "hook_key can only be called once"
|
| 97 |
+
self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
|
| 98 |
+
self.replacement_hooks[hook_key] = None
|
| 99 |
+
|
| 100 |
+
def finalize_all(self) -> str:
|
| 101 |
+
for key, fn in self.replacement_hooks.items():
|
| 102 |
+
self.code = self.code.replace(key, fn())
|
| 103 |
+
return self.code
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# This is used to store info needed for lowering each subgraph in triton
|
| 107 |
+
# templates
|
| 108 |
+
SubgraphInfo = namedtuple(
|
| 109 |
+
"SubgraphInfo",
|
| 110 |
+
[
|
| 111 |
+
"body",
|
| 112 |
+
"template_mask",
|
| 113 |
+
"template_out",
|
| 114 |
+
],
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class TritonTemplateKernel(TritonKernel):
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
kernel_name,
|
| 122 |
+
input_nodes,
|
| 123 |
+
output_node,
|
| 124 |
+
defines,
|
| 125 |
+
num_stages,
|
| 126 |
+
num_warps,
|
| 127 |
+
grid_fn,
|
| 128 |
+
meta,
|
| 129 |
+
call_sizes,
|
| 130 |
+
use_jit=False,
|
| 131 |
+
prefix_args=0,
|
| 132 |
+
suffix_args=0,
|
| 133 |
+
epilogue_fn=identity,
|
| 134 |
+
subgraphs: Optional[List[ir.ComputedBuffer]] = None,
|
| 135 |
+
*,
|
| 136 |
+
index_dtype,
|
| 137 |
+
) -> None:
|
| 138 |
+
super().__init__(
|
| 139 |
+
sympy_product(output_node.get_size()),
|
| 140 |
+
sympy.Integer(1),
|
| 141 |
+
index_dtype=index_dtype,
|
| 142 |
+
)
|
| 143 |
+
self.input_nodes = input_nodes
|
| 144 |
+
self.output_node = output_node
|
| 145 |
+
self.named_input_nodes = {} # type: ignore[var-annotated]
|
| 146 |
+
self.defines = defines
|
| 147 |
+
self.kernel_name = kernel_name
|
| 148 |
+
self.use_jit = use_jit
|
| 149 |
+
self.num_stages = num_stages
|
| 150 |
+
self.num_warps = num_warps
|
| 151 |
+
self.grid_fn = grid_fn
|
| 152 |
+
self.meta = meta
|
| 153 |
+
self.call_sizes = call_sizes
|
| 154 |
+
# for templates with fixed epilogues
|
| 155 |
+
self.prefix_args = prefix_args
|
| 156 |
+
self.suffix_args = suffix_args
|
| 157 |
+
self.epilogue_fn = epilogue_fn
|
| 158 |
+
self.render_hooks = {} # type: ignore[var-annotated]
|
| 159 |
+
self.triton_meta: Optional[Dict[str, object]] = None
|
| 160 |
+
# For Templated Attention this can be a list of ir.Subgraph
|
| 161 |
+
self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs
|
| 162 |
+
|
| 163 |
+
# The following attributes (body, template_mask, output_val) are all
|
| 164 |
+
# used for triton kernel codegen.
|
| 165 |
+
# They are swapped onto the TritonTemplateKernel object by
|
| 166 |
+
# `set_subgraph_body`
|
| 167 |
+
self.subgraph_bodies: Dict[str, SubgraphInfo] = {}
|
| 168 |
+
|
| 169 |
+
self.body: IndentedBuffer = FakeIndentedBuffer()
|
| 170 |
+
self.template_mask: Optional[str] = None
|
| 171 |
+
self.template_out: Optional[str] = None
|
| 172 |
+
|
| 173 |
+
@contextlib.contextmanager
|
| 174 |
+
def set_subgraph_body(self, body_name: str):
|
| 175 |
+
old_body, old_mask, old_out = self.body, self.template_mask, self.template_out
|
| 176 |
+
assert body_name in self.subgraph_bodies, body_name
|
| 177 |
+
self.body, self.template_mask, self.template_out = self.subgraph_bodies[
|
| 178 |
+
body_name
|
| 179 |
+
]
|
| 180 |
+
yield
|
| 181 |
+
self.subgraph_bodies[body_name] = SubgraphInfo(
|
| 182 |
+
self.body, self.template_mask, self.template_out
|
| 183 |
+
)
|
| 184 |
+
self.body, self.template_mask, self.template_out = old_body, old_mask, old_out
|
| 185 |
+
|
| 186 |
+
@contextlib.contextmanager
|
| 187 |
+
def create_subgraph_body(self, body_name: str):
|
| 188 |
+
assert body_name not in self.subgraph_bodies
|
| 189 |
+
self.subgraph_bodies[body_name] = SubgraphInfo(IndentedBuffer(), None, None)
|
| 190 |
+
with self.set_subgraph_body(body_name):
|
| 191 |
+
yield
|
| 192 |
+
|
| 193 |
+
def need_numel_args(self):
|
| 194 |
+
return False
|
| 195 |
+
|
| 196 |
+
def estimate_kernel_num_bytes(self):
|
| 197 |
+
"""
|
| 198 |
+
Estimate the total number of bytes this kernel takes.
|
| 199 |
+
For in/out nodes, sizes are counted twice: once for reading and
|
| 200 |
+
once for writing.
|
| 201 |
+
"""
|
| 202 |
+
ninplace_args = len(unique(self.args.inplace_buffers.values()))
|
| 203 |
+
num_bytes = []
|
| 204 |
+
for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
|
| 205 |
+
size = V.graph.sizevars.size_hints(inp.get_size())
|
| 206 |
+
numel = functools.reduce(operator.mul, size, 1)
|
| 207 |
+
dtype_size = get_dtype_size(inp.get_dtype())
|
| 208 |
+
num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
|
| 209 |
+
return sum(num_bytes)
|
| 210 |
+
|
| 211 |
+
def jit_lines(self):
|
| 212 |
+
if self.use_jit:
|
| 213 |
+
return "@triton.jit"
|
| 214 |
+
|
| 215 |
+
argdefs, _, signature, _ = self.args.python_argdefs()
|
| 216 |
+
triton_meta = {
|
| 217 |
+
"signature": signature_to_meta(signature, size_dtype=self.index_dtype),
|
| 218 |
+
"device": DeviceProperties.create(self.output_node.get_device()),
|
| 219 |
+
"constants": {},
|
| 220 |
+
}
|
| 221 |
+
triton_meta["configs"] = [config_of(signature)]
|
| 222 |
+
for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index]
|
| 223 |
+
triton_meta["constants"][arg_num] = 1 # type: ignore[index]
|
| 224 |
+
matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0)
|
| 225 |
+
if matrix_instr_nonkdim != 0:
|
| 226 |
+
triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim
|
| 227 |
+
|
| 228 |
+
self.triton_meta = triton_meta
|
| 229 |
+
|
| 230 |
+
inductor_meta = {
|
| 231 |
+
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
|
| 232 |
+
**TritonKernel.inductor_meta_common(),
|
| 233 |
+
}
|
| 234 |
+
if config.profile_bandwidth or config.benchmark_kernel:
|
| 235 |
+
num_gb = self.estimate_kernel_num_bytes() / 1e9
|
| 236 |
+
inductor_meta["kernel_num_gb"] = num_gb
|
| 237 |
+
return f"""
|
| 238 |
+
@triton_heuristics.template(
|
| 239 |
+
num_stages={self.num_stages},
|
| 240 |
+
num_warps={self.num_warps},
|
| 241 |
+
triton_meta={triton_meta!r},
|
| 242 |
+
inductor_meta={inductor_meta!r},
|
| 243 |
+
)
|
| 244 |
+
@triton.jit
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
def gen_argdefs(self):
|
| 248 |
+
def hook():
|
| 249 |
+
# python_argdefs() cannot be run until after the rest of the template lazily adds more args
|
| 250 |
+
arg_defs, *_ = self.args.python_argdefs()
|
| 251 |
+
return f"{', '.join(arg_defs)}"
|
| 252 |
+
|
| 253 |
+
self.render_hooks["<ARGDEFS>"] = hook
|
| 254 |
+
return "<ARGDEFS>"
|
| 255 |
+
|
| 256 |
+
def gen_defines(self):
|
| 257 |
+
return self.defines
|
| 258 |
+
|
| 259 |
+
def def_kernel(self, *argnames):
|
| 260 |
+
"""
|
| 261 |
+
Hook called from template code to generate function def and
|
| 262 |
+
needed args.
|
| 263 |
+
"""
|
| 264 |
+
assert all(isinstance(x, str) for x in argnames)
|
| 265 |
+
renames = IndentedBuffer(initial_indent=1)
|
| 266 |
+
|
| 267 |
+
named_args = self.input_nodes[
|
| 268 |
+
self.prefix_args : len(self.input_nodes) - self.suffix_args
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
assert len(argnames) == len(named_args), (
|
| 272 |
+
len(argnames),
|
| 273 |
+
len(named_args),
|
| 274 |
+
self.prefix_args,
|
| 275 |
+
len(self.input_nodes),
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
for input_node in self.input_nodes[: self.prefix_args]:
|
| 279 |
+
# get args in correct order
|
| 280 |
+
self.args.input(input_node.get_name())
|
| 281 |
+
|
| 282 |
+
for name, input_node in zip(argnames, named_args):
|
| 283 |
+
arg_name = f"arg_{name}"
|
| 284 |
+
self.named_input_nodes[name] = input_node
|
| 285 |
+
self.args.input_buffers[input_node.get_name()] = arg_name
|
| 286 |
+
|
| 287 |
+
# The args may be duplicated, so renaming must be after args are de-duplicated.
|
| 288 |
+
for name in argnames:
|
| 289 |
+
input_node = self.named_input_nodes[name]
|
| 290 |
+
arg_name = self.args.input_buffers[input_node.get_name()]
|
| 291 |
+
if input_node.get_layout().offset == 0:
|
| 292 |
+
renames.writeline(f"{name} = {arg_name}")
|
| 293 |
+
else:
|
| 294 |
+
offset = texpr(self.rename_indexing(input_node.get_layout().offset))
|
| 295 |
+
renames.writeline(f"{name} = {arg_name} + {offset}")
|
| 296 |
+
|
| 297 |
+
for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
|
| 298 |
+
# get args in correct order
|
| 299 |
+
self.args.input(input_node.get_name())
|
| 300 |
+
|
| 301 |
+
def hook():
|
| 302 |
+
# python_argdefs() cannot be run until after the rest of the template lazily adds more args
|
| 303 |
+
arg_defs, *_ = self.args.python_argdefs()
|
| 304 |
+
code = IndentedBuffer()
|
| 305 |
+
code.splice(gen_common_triton_imports())
|
| 306 |
+
code.splice(self.jit_lines())
|
| 307 |
+
code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
|
| 308 |
+
with code.indent():
|
| 309 |
+
code.splice(self.defines)
|
| 310 |
+
code.splice(renames.getvalue())
|
| 311 |
+
return code.getvalue()
|
| 312 |
+
|
| 313 |
+
assert "<DEF_KERNEL>" not in self.render_hooks
|
| 314 |
+
self.render_hooks["<DEF_KERNEL>"] = hook
|
| 315 |
+
return "<DEF_KERNEL>"
|
| 316 |
+
|
| 317 |
+
def size(self, name: str, index: int):
|
| 318 |
+
"""
|
| 319 |
+
Hook called from template code to get the size of an arg.
|
| 320 |
+
Will add needed args to pass it in if it is dynamic.
|
| 321 |
+
"""
|
| 322 |
+
assert isinstance(index, int)
|
| 323 |
+
if name is None:
|
| 324 |
+
val = self.output_node.get_size()[index]
|
| 325 |
+
else:
|
| 326 |
+
assert isinstance(name, str)
|
| 327 |
+
val = self.named_input_nodes[name].get_size()[index]
|
| 328 |
+
return texpr(self.rename_indexing(val))
|
| 329 |
+
|
| 330 |
+
def stride(self, name, index=None):
|
| 331 |
+
"""
|
| 332 |
+
Hook called from template code to get the stride of an arg.
|
| 333 |
+
Will add needed args to pass it in if it is dynamic.
|
| 334 |
+
"""
|
| 335 |
+
if name is None:
|
| 336 |
+
val = self.output_node.get_stride()
|
| 337 |
+
else:
|
| 338 |
+
assert isinstance(name, str)
|
| 339 |
+
val = self.named_input_nodes[name].get_stride()
|
| 340 |
+
|
| 341 |
+
if isinstance(index, int):
|
| 342 |
+
return texpr(self.rename_indexing(val[index]))
|
| 343 |
+
else:
|
| 344 |
+
return ", ".join([texpr(self.rename_indexing(i)) for i in val])
|
| 345 |
+
|
| 346 |
+
def modification(
|
| 347 |
+
self, subgraph_number: int, output_name: str, **fixed_inputs
|
| 348 |
+
) -> str:
|
| 349 |
+
"""This creates a modification function for a subgraph.
|
| 350 |
+
To use this inside a template, the first argument should specify which subgraph to codegen for
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
subgraph_number (int): The index of the subgraph in self.subgraphs
|
| 354 |
+
"""
|
| 355 |
+
num = 0
|
| 356 |
+
while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies:
|
| 357 |
+
num += 1
|
| 358 |
+
with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"):
|
| 359 |
+
assert isinstance(subgraph_number, int)
|
| 360 |
+
assert isinstance(self.subgraphs, list)
|
| 361 |
+
assert (
|
| 362 |
+
self.body.getvalue() == ""
|
| 363 |
+
), "Body should be clear before adding a modification"
|
| 364 |
+
assert subgraph_number < len(
|
| 365 |
+
self.subgraphs
|
| 366 |
+
), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
|
| 367 |
+
|
| 368 |
+
subgraph = self.subgraphs[subgraph_number]
|
| 369 |
+
|
| 370 |
+
def add_input(name):
|
| 371 |
+
return self.args.input(name)
|
| 372 |
+
|
| 373 |
+
name = f"PlaceholderSubstitution_{subgraph_number}"
|
| 374 |
+
|
| 375 |
+
class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined]
|
| 376 |
+
self.name = name
|
| 377 |
+
|
| 378 |
+
def load(self, name: str, index: sympy.Expr):
|
| 379 |
+
if name not in fixed_inputs:
|
| 380 |
+
# If it's not a fixed input, it's a load from a captured
|
| 381 |
+
# tensor
|
| 382 |
+
var = add_input(name)
|
| 383 |
+
return f"tl.load({var} + {index})"
|
| 384 |
+
|
| 385 |
+
return f"({fixed_inputs[name]})"
|
| 386 |
+
|
| 387 |
+
def indirect_indexing(self, index_var, size, check, wrap_neg=True):
|
| 388 |
+
return sympy_index_symbol(str(index_var))
|
| 389 |
+
|
| 390 |
+
with V.set_ops_handler(PlaceholderSubstitution(V.ops)):
|
| 391 |
+
assert isinstance(
|
| 392 |
+
subgraph, ir.ComputedBuffer
|
| 393 |
+
), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}"
|
| 394 |
+
if isinstance(subgraph.data, ir.InputBuffer):
|
| 395 |
+
out = subgraph.data.make_loader()(())
|
| 396 |
+
else:
|
| 397 |
+
out = subgraph.data.inner_fn(())
|
| 398 |
+
|
| 399 |
+
self.codegen_body()
|
| 400 |
+
self.body.writeline(f"{output_name} = {out.value}")
|
| 401 |
+
|
| 402 |
+
body_val = self.body.getvalue()
|
| 403 |
+
self.cse.invalidate(set()) # type: ignore[arg-type]
|
| 404 |
+
return body_val
|
| 405 |
+
|
| 406 |
+
def store_output(
|
| 407 |
+
self,
|
| 408 |
+
indices: Union[List[Any], Tuple[Any]],
|
| 409 |
+
val: str,
|
| 410 |
+
mask: Optional[str] = None,
|
| 411 |
+
indent_width: int = 4,
|
| 412 |
+
):
|
| 413 |
+
"""Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of
|
| 417 |
+
these indices and output strides must match `val`.
|
| 418 |
+
val (str): The value to store.
|
| 419 |
+
mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask
|
| 420 |
+
will be applied to the store.
|
| 421 |
+
indent_width (int): The number of spaces to use for indentation. This is used when the call to
|
| 422 |
+
store_output is indented in the kernel definition.
|
| 423 |
+
"""
|
| 424 |
+
with self.create_subgraph_body("<STORE_OUTPUT>"):
|
| 425 |
+
assert isinstance(indices, (list, tuple))
|
| 426 |
+
assert isinstance(val, str)
|
| 427 |
+
assert isinstance(mask, (str, type(None)))
|
| 428 |
+
assert self.template_mask is None
|
| 429 |
+
indices = list(map(TritonPrinter.paren, indices))
|
| 430 |
+
index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
|
| 431 |
+
lengths = [
|
| 432 |
+
V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
|
| 433 |
+
]
|
| 434 |
+
assert len(indices) == len(lengths)
|
| 435 |
+
|
| 436 |
+
# glue to make generated code use same indexing from template
|
| 437 |
+
for name, range_tree_entry in zip(
|
| 438 |
+
indices, self.range_trees[0].construct_entries(lengths)
|
| 439 |
+
):
|
| 440 |
+
range_tree_entry.set_name(name)
|
| 441 |
+
contiguous_index = sympy_dot(
|
| 442 |
+
ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
|
| 443 |
+
)
|
| 444 |
+
contiguous_index = self.rename_indexing(contiguous_index)
|
| 445 |
+
self.body.writeline("xindex = " + texpr(contiguous_index))
|
| 446 |
+
self.range_trees[0].lookup(
|
| 447 |
+
sympy.Integer(1), sympy_product(lengths)
|
| 448 |
+
).set_name("xindex")
|
| 449 |
+
self.template_mask = mask
|
| 450 |
+
self.template_out = val
|
| 451 |
+
self.template_indices = indices
|
| 452 |
+
output_index = self.output_node.get_layout().make_indexer()(index_symbols)
|
| 453 |
+
output_index = self.rename_indexing(output_index)
|
| 454 |
+
if output_index == contiguous_index:
|
| 455 |
+
output_index = sympy.Symbol("xindex", integer=True)
|
| 456 |
+
|
| 457 |
+
epilogue_args = [val]
|
| 458 |
+
for input_node in itertools.chain(
|
| 459 |
+
self.input_nodes[: self.prefix_args],
|
| 460 |
+
self.input_nodes[len(self.input_nodes) - self.suffix_args :],
|
| 461 |
+
):
|
| 462 |
+
input_node.freeze_layout()
|
| 463 |
+
epilogue_args.append(input_node.make_loader()(index_symbols))
|
| 464 |
+
|
| 465 |
+
V.ops.store(
|
| 466 |
+
self.output_node.get_name(),
|
| 467 |
+
output_index,
|
| 468 |
+
self.epilogue_fn(*epilogue_args),
|
| 469 |
+
)
|
| 470 |
+
self.codegen_body()
|
| 471 |
+
|
| 472 |
+
def hook():
|
| 473 |
+
# more stuff might have been added since the codegen_body above
|
| 474 |
+
self.codegen_body()
|
| 475 |
+
|
| 476 |
+
return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()
|
| 477 |
+
|
| 478 |
+
assert "<STORE_OUTPUT>" not in self.render_hooks
|
| 479 |
+
self.render_hooks["<STORE_OUTPUT>"] = hook
|
| 480 |
+
return "<STORE_OUTPUT>"
|
| 481 |
+
|
| 482 |
+
def render(self, template, kwargs):
|
| 483 |
+
return PartialRender(
|
| 484 |
+
template.render(**self.template_env(), **kwargs),
|
| 485 |
+
self.render_hooks,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
def make_load(self, name, indices, mask):
|
| 489 |
+
"""
|
| 490 |
+
Optional helper called from template code to generate the code
|
| 491 |
+
needed to load from an tensor.
|
| 492 |
+
"""
|
| 493 |
+
assert isinstance(indices, (list, tuple))
|
| 494 |
+
assert isinstance(name, str)
|
| 495 |
+
assert isinstance(mask, str)
|
| 496 |
+
stride = self.named_input_nodes[name].get_stride()
|
| 497 |
+
indices = list(map(TritonPrinter.paren, indices))
|
| 498 |
+
assert len(indices) == len(stride)
|
| 499 |
+
index = " + ".join(
|
| 500 |
+
f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
|
| 501 |
+
)
|
| 502 |
+
return f"tl.load({name} + ({index}), {mask}, other=0.0)"
|
| 503 |
+
|
| 504 |
+
def template_env(self):
|
| 505 |
+
"""
|
| 506 |
+
Generate the namespace visible in the template.
|
| 507 |
+
"""
|
| 508 |
+
return {
|
| 509 |
+
fn.__name__: fn
|
| 510 |
+
for fn in [
|
| 511 |
+
self.def_kernel,
|
| 512 |
+
self.size,
|
| 513 |
+
self.stride,
|
| 514 |
+
self.store_output,
|
| 515 |
+
self.make_load,
|
| 516 |
+
self.modification,
|
| 517 |
+
self.gen_argdefs,
|
| 518 |
+
self.gen_defines,
|
| 519 |
+
]
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
def indexing(
|
| 523 |
+
self,
|
| 524 |
+
index: sympy.Expr,
|
| 525 |
+
*,
|
| 526 |
+
dense_indexing=False,
|
| 527 |
+
copy_shape=None,
|
| 528 |
+
override_mask=None,
|
| 529 |
+
block_ptr=False,
|
| 530 |
+
):
|
| 531 |
+
"""
|
| 532 |
+
Override the default indexing to use our custom mask and force
|
| 533 |
+
dense indexing.
|
| 534 |
+
"""
|
| 535 |
+
return super().indexing(
|
| 536 |
+
index,
|
| 537 |
+
dense_indexing=False,
|
| 538 |
+
# We pass template_out as the shape to broadcast the indexing to as
|
| 539 |
+
# the mask might be broadcast to the output shape
|
| 540 |
+
copy_shape=self.template_out,
|
| 541 |
+
override_mask=self.template_mask,
|
| 542 |
+
block_ptr=block_ptr,
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
def codegen_range_tree(self):
|
| 546 |
+
pass # ignore default codegen
|
| 547 |
+
|
| 548 |
+
def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
|
| 549 |
+
wrapper = V.graph.wrapper_code
|
| 550 |
+
_, call_args, _, arg_types = self.args.python_argdefs()
|
| 551 |
+
if V.graph.cpp_wrapper:
|
| 552 |
+
# In the cpp_wrapper case, we have to compute CUDA launch grid at runtime
|
| 553 |
+
# if any dynamic dimension is involved. We rely on the Python version
|
| 554 |
+
# of the grid function to generate those grid configs, which may contain
|
| 555 |
+
# symbolic values. The wrapper will use cexpr to print out C++ code
|
| 556 |
+
# appropriately for the grid configs.
|
| 557 |
+
grid = self.call_sizes + [self.meta]
|
| 558 |
+
wrapper.generate_kernel_call(
|
| 559 |
+
name,
|
| 560 |
+
call_args,
|
| 561 |
+
grid=self.grid_fn(*grid),
|
| 562 |
+
arg_types=arg_types,
|
| 563 |
+
triton_meta=self.triton_meta,
|
| 564 |
+
)
|
| 565 |
+
else:
|
| 566 |
+
wrapper.add_import_once(f"import {self.grid_fn.__module__}")
|
| 567 |
+
meta = wrapper.add_meta_once(self.meta)
|
| 568 |
+
grid = self.call_sizes + [meta]
|
| 569 |
+
wrapper.generate_kernel_call(
|
| 570 |
+
name,
|
| 571 |
+
call_args,
|
| 572 |
+
grid=grid,
|
| 573 |
+
grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}",
|
| 574 |
+
arg_types=arg_types,
|
| 575 |
+
triton_meta=self.triton_meta,
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
@functools.lru_cache(None)
|
| 580 |
+
def _jinja2_env():
|
| 581 |
+
try:
|
| 582 |
+
import jinja2
|
| 583 |
+
|
| 584 |
+
return jinja2.Environment(
|
| 585 |
+
undefined=jinja2.StrictUndefined,
|
| 586 |
+
)
|
| 587 |
+
except ImportError:
|
| 588 |
+
return None
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
class TritonTemplate(KernelTemplate):
|
| 592 |
+
index_counter = itertools.count()
|
| 593 |
+
all_templates: Dict[str, "TritonTemplate"] = {}
|
| 594 |
+
|
| 595 |
+
def __init__(self, name: str, grid: Any, source: str, debug=False) -> None:
|
| 596 |
+
super().__init__(name)
|
| 597 |
+
self.grid = grid
|
| 598 |
+
self.template = self._template_from_string(source)
|
| 599 |
+
assert name not in self.all_templates, "duplicate template name"
|
| 600 |
+
self.all_templates[name] = self
|
| 601 |
+
self.debug = debug
|
| 602 |
+
|
| 603 |
+
def generate( # type: ignore[override]
|
| 604 |
+
self,
|
| 605 |
+
input_nodes,
|
| 606 |
+
layout,
|
| 607 |
+
num_stages,
|
| 608 |
+
num_warps,
|
| 609 |
+
prefix_args=0,
|
| 610 |
+
suffix_args=0,
|
| 611 |
+
epilogue_fn=identity,
|
| 612 |
+
subgraphs=None,
|
| 613 |
+
mutated_inputs=None,
|
| 614 |
+
call_sizes=None,
|
| 615 |
+
**kwargs,
|
| 616 |
+
):
|
| 617 |
+
"""This function generates a TritonTemplateCaller
|
| 618 |
+
|
| 619 |
+
Args:
|
| 620 |
+
input_nodes: List of input nodes
|
| 621 |
+
layout: Output layout
|
| 622 |
+
num_stages: Number of stages for triton launch
|
| 623 |
+
num_warps: Number of warps for triton launch
|
| 624 |
+
prefix_args: Number of input nodes to be passed as arguments
|
| 625 |
+
suffix_args: Number of input nodes to be passed as arguments
|
| 626 |
+
epilogue_fn: Optional epilogue function to be called on the output
|
| 627 |
+
subgraphs: Optional subgraphs to be passed as arguments, these will be inlined
|
| 628 |
+
into the triton template string
|
| 629 |
+
mutated_inputs: Optional list of input nodes that are mutated by the kernel, this is helpful
|
| 630 |
+
if you need to return multiple outputs. You can pass them as inputs and mark them as
|
| 631 |
+
being mutated by the kernel.
|
| 632 |
+
"""
|
| 633 |
+
assert self.template, "requires jinja2"
|
| 634 |
+
defines = StringIO()
|
| 635 |
+
for name, val in kwargs.items():
|
| 636 |
+
defines.write(f"{name} : tl.constexpr = {val}\n")
|
| 637 |
+
defines = defines.getvalue()
|
| 638 |
+
|
| 639 |
+
fake_out = ir.Buffer("buf_out", layout)
|
| 640 |
+
kernel_name = f"triton_{self.name}"
|
| 641 |
+
|
| 642 |
+
numel = sympy_product(layout.size)
|
| 643 |
+
buffers = itertools.chain(input_nodes, (fake_out,))
|
| 644 |
+
if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
|
| 645 |
+
raise NotImplementedError(
|
| 646 |
+
"64-bit indexing is not yet implemented for triton templates"
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
if call_sizes is None:
|
| 650 |
+
call_sizes = layout.size
|
| 651 |
+
|
| 652 |
+
kernel_options = dict(
|
| 653 |
+
input_nodes=input_nodes,
|
| 654 |
+
defines=defines,
|
| 655 |
+
num_stages=num_stages,
|
| 656 |
+
num_warps=num_warps,
|
| 657 |
+
grid_fn=self.grid,
|
| 658 |
+
meta=kwargs,
|
| 659 |
+
call_sizes=call_sizes,
|
| 660 |
+
prefix_args=prefix_args,
|
| 661 |
+
suffix_args=suffix_args,
|
| 662 |
+
epilogue_fn=epilogue_fn,
|
| 663 |
+
index_dtype="tl.int32",
|
| 664 |
+
subgraphs=subgraphs,
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
with patch.object(
|
| 668 |
+
V.graph, "get_dtype", self._fake_get_dtype(fake_out)
|
| 669 |
+
), TritonTemplateKernel(
|
| 670 |
+
kernel_name=kernel_name,
|
| 671 |
+
output_node=fake_out,
|
| 672 |
+
use_jit=False,
|
| 673 |
+
**kernel_options,
|
| 674 |
+
) as kernel:
|
| 675 |
+
try:
|
| 676 |
+
template = kernel.render(self.template, kwargs)
|
| 677 |
+
with kernel.set_subgraph_body("<STORE_OUTPUT>"):
|
| 678 |
+
code = template.finalize_all()
|
| 679 |
+
except ZeroDivisionError:
|
| 680 |
+
# TODO(nmacchioni): fix sympy division by zero
|
| 681 |
+
return None
|
| 682 |
+
if self.debug:
|
| 683 |
+
print("Generated Code:\n", code)
|
| 684 |
+
extra = (
|
| 685 |
+
"-".join(
|
| 686 |
+
[
|
| 687 |
+
*[
|
| 688 |
+
f"{kwarg}={repr(kwargs[kwarg])}"
|
| 689 |
+
for kwarg in sorted(kwargs.keys())
|
| 690 |
+
],
|
| 691 |
+
f"num_stages={num_stages}",
|
| 692 |
+
f"num_warps={num_warps}",
|
| 693 |
+
]
|
| 694 |
+
)
|
| 695 |
+
+ "-"
|
| 696 |
+
)
|
| 697 |
+
mod = PyCodeCache.load(code, extra)
|
| 698 |
+
|
| 699 |
+
input_call_args = tuple(kernel.args.input_buffers.keys())
|
| 700 |
+
output_call_args = tuple(kernel.args.output_buffers.keys())
|
| 701 |
+
|
| 702 |
+
# We expect the input_buffer order to be [*input_nodes, *captured_buffers]
|
| 703 |
+
expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
|
| 704 |
+
expected_output_args = (fake_out.get_name(),)
|
| 705 |
+
assert input_call_args[: len(expected_input_args)] == expected_input_args, (
|
| 706 |
+
input_call_args,
|
| 707 |
+
expected_input_args,
|
| 708 |
+
)
|
| 709 |
+
assert output_call_args == expected_output_args, (
|
| 710 |
+
output_call_args,
|
| 711 |
+
expected_output_args,
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args])
|
| 715 |
+
extra_args = V.graph.sizevars.size_hints(
|
| 716 |
+
map(sympy.expand, tuple(kernel.args.sizevars.keys())),
|
| 717 |
+
fallback=config.unbacked_symint_fallback,
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
|
| 721 |
+
|
| 722 |
+
def make_kernel_render(out_node):
|
| 723 |
+
kernel = TritonTemplateKernel(
|
| 724 |
+
kernel_name=str(Placeholder.KERNEL_NAME),
|
| 725 |
+
output_node=out_node,
|
| 726 |
+
use_jit=False,
|
| 727 |
+
**kernel_options,
|
| 728 |
+
)
|
| 729 |
+
render = functools.partial(
|
| 730 |
+
kernel.render,
|
| 731 |
+
self.template,
|
| 732 |
+
kwargs,
|
| 733 |
+
)
|
| 734 |
+
return kernel, render
|
| 735 |
+
|
| 736 |
+
# create the BenchmarkRequest
|
| 737 |
+
assert mod.__file__ is not None
|
| 738 |
+
grid = self.grid(
|
| 739 |
+
*V.graph.sizevars.size_hints(
|
| 740 |
+
call_sizes,
|
| 741 |
+
fallback=config.unbacked_symint_fallback,
|
| 742 |
+
),
|
| 743 |
+
kwargs,
|
| 744 |
+
)
|
| 745 |
+
bmreq = TritonBenchmarkRequest(
|
| 746 |
+
module_path=mod.__file__,
|
| 747 |
+
module_cache_key=mod.key,
|
| 748 |
+
kernel_name=kernel_name,
|
| 749 |
+
grid=grid,
|
| 750 |
+
extra_args=extra_args,
|
| 751 |
+
num_stages=num_stages,
|
| 752 |
+
num_warps=num_warps,
|
| 753 |
+
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
|
| 754 |
+
input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes), # type: ignore[arg-type]
|
| 755 |
+
output_tensor_meta=TensorMeta.from_irnodes(layout),
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
return TritonTemplateCaller(
|
| 759 |
+
kernel_hash_name,
|
| 760 |
+
full_input_nodes,
|
| 761 |
+
layout,
|
| 762 |
+
make_kernel_render,
|
| 763 |
+
extra.strip("-").replace("-", ", "),
|
| 764 |
+
bmreq,
|
| 765 |
+
log_info={
|
| 766 |
+
"tile_shape": str(
|
| 767 |
+
(
|
| 768 |
+
kwargs.get("BLOCK_M", -1),
|
| 769 |
+
kwargs.get("BLOCK_K", -1),
|
| 770 |
+
kwargs.get("BLOCK_N", -1),
|
| 771 |
+
)
|
| 772 |
+
),
|
| 773 |
+
"num_stages": num_stages,
|
| 774 |
+
"num_warps": num_warps,
|
| 775 |
+
"allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
|
| 776 |
+
"acc_type": str(kwargs.get("ACC_TYPE", None)),
|
| 777 |
+
},
|
| 778 |
+
mutated_inputs=mutated_inputs,
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
class ExternKernelChoice:
    """Registry entry for a non-templated ("extern") kernel choice, e.g. an ATen op.

    Registering an instance installs ``kernel`` onto the module-level
    ``extern_kernels`` namespace under ``name`` so generated code can call it as
    ``extern_kernels.<name>``. Instances are bound to concrete inputs/layout via
    :meth:`bind`, which returns an ``ExternKernelCaller`` usable in autotuning.
    """

    def __init__(
        self,
        kernel,
        cpp_kernel=None,
        *,
        name=None,
        has_out_variant=True,
        op_overload=None,
        use_fallback_kernel=False,
        kernel_creator=None,
    ) -> None:
        """Register ``kernel`` under ``name`` (defaults to ``kernel.__name__``).

        Args:
            kernel: the Python callable implementing the op.
            cpp_kernel: optional C++ kernel name for cpp wrapper codegen.
            has_out_variant: whether the kernel accepts an ``out=`` tensor.
            op_overload: optional OpOverload used when falling back to
                ``ir.FallbackKernel``.
            use_fallback_kernel: prefer ``ir.FallbackKernel`` for the output node.
            kernel_creator: optional factory producing the IR node directly.
        """
        super().__init__()
        name = name or kernel.__name__
        assert callable(kernel)
        # Each extern kernel name must be globally unique on the namespace.
        assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}"
        self.name = name
        self.cpp_kernel_name = cpp_kernel
        self.has_out_variant = has_out_variant
        # Side effect: make the kernel reachable as extern_kernels.<name>.
        setattr(extern_kernels, name, kernel)
        self.op_overload = op_overload
        self.use_fallback_kernel = use_fallback_kernel
        self.kernel_creator = kernel_creator

    def to_callable(self):
        """Return the registered callable from the extern_kernels namespace."""
        return getattr(extern_kernels, self.name)

    def call_name(self):
        """Return the generated-code call expression for this kernel."""
        return f"extern_kernels.{self.name}"

    @functools.lru_cache(None)  # noqa: B019
    def hash_key(self):
        """Return a stable hash for this kernel (name, module, and source if available).

        Cached per-instance; the B019 caveat (cache keeps instances alive) is
        acceptable here because these registry objects live for the process.
        """
        fn = self.to_callable()
        parts = [
            self.name,
            getattr(fn, "__name__", ""),
            getattr(fn, "__module__", ""),
        ]
        try:
            # Source may be unavailable (e.g. C extensions); hash without it then.
            parts.append(inspect.getsource(fn))
        except Exception:
            pass
        return code_hash("-".join(parts))

    def bind(
        self,
        input_nodes,
        layout,
        ordered_kwargs_for_cpp_kernel=(),
        **kwargs,
    ):
        """Bind this choice to concrete inputs/layout, returning an ExternKernelCaller.

        NOTE: this mutates ``self.ordered_kwargs_for_cpp_kernel``, so the last
        ``bind`` call wins for cpp-wrapper kwarg ordering — presumably each
        instance is bound with consistent kwargs; verify against callers.
        """
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
        return ExternKernelCaller(
            self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant
        )
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
class TritonTemplateCaller(ir.TritonTemplateCallerBase):
    """A concrete, benchmarkable candidate produced by a Triton template.

    Wraps the benchmark request (``bmreq``) plus the deferred kernel-render
    callable so the scheduler can benchmark the choice and, if selected,
    materialize it as an ``ir.TritonTemplateBuffer``.
    """

    def __init__(
        self,
        name,
        input_nodes,
        layout,
        make_kernel_render,
        debug_extra,
        bmreq,
        log_info: Optional[
            Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]
        ] = None,
        mutated_inputs=None,
    ) -> None:
        super().__init__(name, input_nodes, layout)
        self.make_kernel_render = make_kernel_render
        self.debug_extra = debug_extra
        self.bmreq: TritonBenchmarkRequest = bmreq
        if log_info is None:
            log_info = {}
        self.log_info: Dict[str, Any] = log_info
        # Always record backend/launch metadata for autotune logging.
        self.log_info.update(
            {
                "backend": "Triton",
                "grid": str(self.bmreq.grid),
                "num_stages": self.bmreq.num_stages,
                "num_warps": self.bmreq.num_warps,
            }
        )
        self.mutated_inputs = mutated_inputs

    def benchmark(self, *args, out):
        """Benchmark this candidate via its TritonBenchmarkRequest."""
        assert self.bmreq is not None
        return self.bmreq.benchmark(*args, output_tensor=out)

    def precompile(self):
        """Compile the kernel ahead of benchmarking."""
        assert self.bmreq is not None
        self.bmreq.precompile()

    def __str__(self) -> str:
        return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"

    def call_name(self):
        return f"template_kernels.{self.name}"

    def hash_key(self):
        # Drop the trailing per-instance counter from the name so equivalent
        # configurations hash identically across compilations.
        return "-".join(
            [
                self.name.rsplit("_", 1)[0],
                self.bmreq.module_cache_key,
            ]
        )

    def output_node(self):
        """Materialize the selected choice as an IR buffer wrapped in a TensorBox."""
        return ir.TensorBox.create(
            ir.TritonTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                debug_extra=self.debug_extra,
                mutated_inputs=self.mutated_inputs,
            )
        )

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return self.log_info

    def get_make_kernel_render(self):
        return self.make_kernel_render

    def autoheuristic_id(self):
        """Return a stable string identifying this config for autoheuristic logging."""
        # tile_shape is stored as str((BLOCK_M, BLOCK_K, BLOCK_N)); parse it
        # back with ast.literal_eval rather than eval — same result for this
        # data, without executing arbitrary expressions.
        import ast

        type_name = "triton"
        info = self.info_dict()
        # TODO(AlnisM): Does tile_shape always exist?
        tile = info["tile_shape"]
        tile_vals = ast.literal_eval(tile)  # type: ignore[arg-type]
        BLOCK_M = tile_vals[0]
        BLOCK_K = tile_vals[1]
        BLOCK_N = tile_vals[2]
        num_stages = info["num_stages"]
        num_warps = info["num_warps"]
        return f"type={type_name}_BLOCK-M={BLOCK_M}_BLOCK-K={BLOCK_K}_BLOCK-N={BLOCK_N}_numstages={num_stages}_numwarps={num_warps}"
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
class ExternKernelCaller(ChoiceCaller):
    """A benchmarkable autotune candidate backed by an ExternKernelChoice."""

    def __init__(
        self,
        choice: ExternKernelChoice,
        input_nodes,
        layout,
        kwargs=None,
        *,
        has_out_variant=True,
    ) -> None:
        super().__init__(choice.name, input_nodes, layout)
        self.choice = choice
        self.kwargs = kwargs or {}
        self.has_out_variant = has_out_variant

    def __str__(self) -> str:
        return f"ExternKernelCaller({self.choice.call_name()})"

    def benchmark(self, *args, out):
        """Benchmark the extern kernel, writing its result into ``out``."""
        if out.numel() == 0:
            # no need to run the kernel or do benchmarking for an empty output
            return 0.0
        if self.has_out_variant:
            return super().benchmark(*args, out=out)
        else:
            # No out= variant: run once to capture the result, check that the
            # produced tensor matches the expected size/stride, then time it.
            algo = self.to_callable()
            out_new = algo(*args)
            torch._C._dynamo.guards.assert_size_stride(
                out_new, tuple(out.size()), tuple(out.stride())
            )
            out.copy_(out_new)  # for correctness checking
            return benchmarker.benchmark(algo, args, {})

    def to_callable(self):
        """Return the kernel callable with bound kwargs pre-applied, if any."""
        fn = self.choice.to_callable()
        if self.kwargs:
            return functools.partial(fn, **self.kwargs)
        else:
            return fn

    def hash_key(self):
        """Stable key combining kernel identity with the bound kwargs."""
        return "-".join(
            [
                self.choice.name,
                *[
                    f"{kwarg}={repr(self.kwargs[kwarg])}"
                    for kwarg in sorted(self.kwargs.keys())
                ],
                self.choice.hash_key(),
            ]
        )

    def output_node(self):
        """Materialize this choice as the appropriate extern/fallback IR node."""
        if config.abi_compatible and self.choice.use_fallback_kernel:
            assert (
                self.choice.op_overload is not None
            ), "Please provide an op_overload to use ir.FallbackKernel"
            inner = ir.FallbackKernel.create(
                self.choice.op_overload, *self.input_nodes, **self.kwargs
            )
        elif self.choice.kernel_creator is not None:
            inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs)
        else:
            # Out-variant kernels write into a preallocated buffer; others allocate.
            cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc
            inner = cls(
                layout=self.layout,
                inputs=self.input_nodes,
                python_kernel_name=self.choice.call_name(),
                cpp_kernel_name=self.choice.cpp_kernel_name,
                ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel,
                op_overload=self.choice.op_overload,
                kwargs=self.kwargs,
            )

        return ir.TensorBox.create(inner)

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "extern",
            "kernel_call_name": self.choice.call_name(),
        }

    def autoheuristic_id(self):
        """Stable identifier for autoheuristic logging."""
        return f"extern_{self.choice.name}"
|
| 1009 |
+
|
| 1010 |
+
|
| 1011 |
+
@functools.lru_cache(None)
def get_mm_log_filename() -> Optional[str]:
    """Return the matmul-logging file path from the environment, or None.

    Reads ``TORCHINDUCTOR_MM_LOGGING_FILE``; when the configured name does not
    already contain "json", a ``.json`` suffix is appended. The result is
    cached for the lifetime of the process.
    """
    filename = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None)
    if not filename:
        return None
    return filename if "json" in filename else f"{filename}.json"
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
def append_to_log(filename, data):
    """Append ``data`` to the JSON array stored in ``filename``.

    The file holds a JSON list; a missing or corrupt file is treated as an
    empty list. A sibling ``.lock`` file serializes concurrent writers.
    """
    lock = FileLock(filename.replace(".json", ".lock"))
    with lock:
        # Read-modify-write under the lock; tolerate a missing/garbled file.
        try:
            with open(filename) as fh:
                entries = json.load(fh)
        except (FileNotFoundError, json.JSONDecodeError):
            entries = []

        entries.append(data)

        with open(filename, "w") as fh:
            json.dump(entries, fh, indent=4)
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
class DataProcessorChoiceCallerWrapper:
    """Delegating wrapper that transforms data around a wrapped ChoiceCaller.

    ``preprocessor(args, out)`` adjusts benchmark inputs before the wrapped
    caller runs; ``postprocessor`` adjusts its output (and its output node).
    Missing processors default to identity transforms. All other attribute
    access is forwarded to the wrapped object.
    """

    def __init__(self, wrapped, preprocessor, postprocessor) -> None:
        self._wrapped = wrapped
        self._preprocessor = (
            preprocessor if preprocessor is not None else (lambda x, y: (x, y))
        )
        self._postprocessor = postprocessor if postprocessor is not None else (lambda x: x)

    def __getattr__(self, name):
        # Anything not defined here is served by the wrapped caller.
        return getattr(self._wrapped, name)

    def benchmark(self, *args, out) -> float:
        proc_args, proc_out = self._preprocessor(args, out)
        timing = self._wrapped.benchmark(*proc_args, out=proc_out)
        proc_out = self._postprocessor(proc_out)
        # Copy back only if postprocessing produced a different tensor object.
        if out is not proc_out:
            out.copy_(proc_out)
        return timing

    def output_node(self) -> ir.TensorBox:
        return self._postprocessor(self._wrapped.output_node())

    def __repr__(self) -> str:
        return f"DataProcessorChoiceCallerWrapper({self._wrapped})"
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
class DataProcessorTemplateWrapper:
    """
    A wrapper class for a kernel template.

    This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to
    preprocess and postprocess data before and after using the wrapped template. A typical
    usage is to reorder or filter the input nodes in order to match the expected input of other
    kernel choices like a ATen kernel. A more complicated usage is to prepack the weights.
    See the example from :mod:`cpp_gemm_template` for more details.
    """

    def __init__(
        self,
        wrapped_template_cls,
        preprocessor,
        postprocessor,
        **kwargs,
    ) -> None:
        """Instantiate ``wrapped_template_cls`` with preprocessed kwargs.

        Args:
            wrapped_template_cls: template class to instantiate with ``kwargs``.
            preprocessor: optional ``(input_nodes, layout) -> (input_nodes, layout)``
                transform applied before instantiation; None means identity.
            postprocessor: optional transform applied to generated output nodes;
                None means identity.
            kwargs: forwarded to the template class; must include ``input_nodes``
                and ``layout``.
        """
        if preprocessor is not None:
            self._preprocessor = preprocessor
        else:
            self._preprocessor = lambda x, y: (x, y)
        if postprocessor is not None:
            self._postprocessor = postprocessor
        else:
            self._postprocessor = lambda x: x
        assert "input_nodes" in kwargs
        assert "layout" in kwargs
        # BUGFIX: use the normalized self._preprocessor, not the raw
        # ``preprocessor`` parameter — the latter is None when no preprocessor
        # is supplied and would raise TypeError here.
        kwargs["input_nodes"], kwargs["layout"] = self._preprocessor(
            kwargs["input_nodes"], kwargs["layout"]
        )
        self._wrapped = wrapped_template_cls(**kwargs)

    def __getattr__(self, name):
        # Delegate everything else to the wrapped template instance.
        return getattr(self._wrapped, name)

    def maybe_append_choice(self, choices, **kwargs):
        """Invoke the wrapped class's maybe_append_choice with this wrapper as self,
        so that generate() below (which wraps the caller) is the one used."""
        return type(self._wrapped).maybe_append_choice(self, choices, **kwargs)

    def generate(self, **kwargs):
        """Generate a choice caller and wrap it so pre/post-processing applies."""
        choice_caller = self._wrapped.generate(**kwargs)
        return DataProcessorChoiceCallerWrapper(
            choice_caller, self._preprocessor, self._postprocessor
        )

    def __repr__(self) -> str:
        return f"DataProcessorTemplateWrapper({self._wrapped})"
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
class ErrorFromChoice(RuntimeError):
    """Raised when benchmarking a specific autotune choice fails.

    The offending choice and a description of the inputs are appended to the
    message; the choice itself is kept on ``self.choice`` for callers.
    """

    def __init__(self, msg, choice: ChoiceCaller, inputs_str) -> None:
        super().__init__(f"{msg}\nFrom choice {choice}\n{inputs_str}")
        self.choice = choice
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
class NoValidChoicesError(RuntimeError):
    """Raised when autotuning has no usable candidate choices."""
|
| 1128 |
+
|
| 1129 |
+
|
| 1130 |
+
@functools.lru_cache(None)
def get_env_num_workers() -> Optional[int]:
    """Return TORCHINDUCTOR_COMPILE_THREADS as an int, or None when unset.

    Cached for the lifetime of the process.
    """
    value = os.environ.get("TORCHINDUCTOR_COMPILE_THREADS")
    return int(value) if value is not None else None
|
| 1135 |
+
|
| 1136 |
+
|
| 1137 |
+
def create_inputs_key(input_nodes) -> str:
    """Build a cache key string from the per-input keys of ``input_nodes``."""
    return repr(list(map(AlgorithmSelectorCache.key_of, input_nodes)))
|
| 1139 |
+
|
| 1140 |
+
|
| 1141 |
+
def create_precompile_key(
    name: str, inputs_key: str, choices: List[ChoiceCaller]
) -> str:
    """Build the precompile-cache key for (op name, inputs, fp32 matmul
    precision, and the hash of every candidate choice)."""
    parts = [name, inputs_key, torch.get_float32_matmul_precision()]
    parts.extend(choice.hash_key() for choice in choices)
    return ":".join(parts)
|
| 1152 |
+
|
| 1153 |
+
|
| 1154 |
+
class AlgorithmSelectorCache(PersistentCache):
|
| 1155 |
+
    def __init__(self, *args, **kwargs) -> None:
        """Initialize the selector cache; extra args go to PersistentCache."""
        super().__init__(*args, **kwargs)

        # the autotuning will occur in the scheduler, so there is
        # no guarantee that the first lowering for a given key will also be the
        # first to benchmark it. share a single precompilation function for all
        # lowerings of a particular key
        self.precompile_cache: Dict[str, Callable[[], None]] = {}
        # list of callbacks that are called after benchmarking
        self.feedback_saver_fns: List[
            Callable[
                [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
            ]
        ] = []
|
| 1169 |
+
|
| 1170 |
+
    def __call__(
        self,
        name,
        choices: List[ChoiceCaller],
        input_nodes,
        layout,
        # optional dict mapping arg indices to the functions
        # generating a torch.Tensor for that input from the
        # corresponding ir.Buffer. if passed for a given
        # arg, the function will be called instead of
        # generating a random torch.Tensor for benchmarking.
        input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
        precompilation_timeout_seconds: int = 60 * 60,
        return_multi_template=False,
    ):
        """Select the best choice for op ``name`` by (cached) benchmarking.

        Precompiles candidates in parallel, benchmarks them through the
        persistent cache's ``lookup``, and returns the winning choice's output
        node. When ``return_multi_template`` is set (and max-autotune is on),
        selection is deferred: a MultiTemplateBuffer holding a lazy
        ``get_timings`` callable is returned instead.
        """
        from .codegen.cuda.cuda_kernel import CUDATemplateCaller

        # Templates selected with input_gen_fns require specific input data to avoid IMA
        # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
        # TODO(jgong5): support multi-template on CPU
        if input_gen_fns is not None or layout.device.type == "cpu":
            return_multi_template = False

        # TODO - assert that we have not mutating kernels here

        # TODO(nmacchioni): remove once CI tests are fixed
        choices = [choice for choice in choices if choice is not None]

        # Optional matmul shape logging, enabled via environment variable.
        if mm_file_name := get_mm_log_filename():
            M, K = input_nodes[-2].get_size()[:2]
            N = input_nodes[-1].get_size()[-1]
            append_to_log(mm_file_name, {"invoke": str((M, K, N))})

        if len(choices) == 0:
            backend_config = (
                "max_autotune_gemm_backends"
                if name != "convolution"
                else "max_autotune_conv_backends"
            )
            raise NoValidChoicesError(
                f"No choices to select, please consider adding ATEN into {backend_config} "
                "config (defined in torch/_inductor/config.py) to allow at least one choice. "
            )
        log.debug("Max autotune selects from %s choices.", str(len(choices)))

        if len(choices) == 1:
            if not isinstance(choices[0], CUDATemplateCaller):
                # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
                return choices[0].output_node()

        # Lazily build the benchmark function once; cache_info() below is used
        # to detect whether benchmarking actually ran (vs. a pure cache hit).
        @functools.lru_cache(None)
        def make_benchmark_fn():
            return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)

        inputs_key = create_inputs_key(input_nodes)

        def precompile(choices) -> Callable[[], None]:
            """Kick off parallel precompilation; return a wait() callable (or a no-op)."""

            def no_op(*args, **kwargs):
                return

            if (
                precompilation_timeout_seconds is None
                or precompilation_timeout_seconds <= 0
            ):
                return no_op

            env_workers = get_env_num_workers()
            num_workers = env_workers if env_workers is not None else (len(choices))

            if num_workers <= 0:
                return no_op

            # https://github.com/python/cpython/issues/106905
            if (
                sys.version_info.major == 3
                and sys.version_info.minor == 11
                and sys.version_info.micro <= 8
            ):
                return no_op

            # check local and global cache before precompiling
            timings = self.lookup(
                choices,
                name,
                inputs_key,
                benchmark=None,
            )

            if timings:
                return no_op

            # Share one precompilation future-set across all lowerings of this key.
            precompile_key = create_precompile_key(name, inputs_key, choices)
            if precompile_func := self.precompile_cache.get(precompile_key):
                return precompile_func

            log.info(
                "Multithreaded precompilation for %d choices using %d worker threads",
                len(choices),
                num_workers,
            )

            # In rare circumstances, because python threads inherit global state,
            # thread pool executor can race and leave stdout/stderr in a state
            # different than the original values. we explicitly restore the state
            # here to avoid this issue.

            initial_stdout = sys.stdout
            initial_stderr = sys.stderr

            def precompile_with_captured_stdout(choice):
                with restore_stdout_stderr(initial_stdout, initial_stderr):
                    return choice.precompile()

            executor = ThreadPoolExecutor(max_workers=num_workers)

            futures = {}
            for c in choices:
                if hasattr(c, "precompile"):
                    future = executor.submit(precompile_with_captured_stdout, c)
                    futures[future] = c

            # lru_cache makes waiting idempotent: only the first caller blocks
            # on the futures; later callers return immediately.
            @functools.lru_cache(None)
            @restore_stdout_stderr(initial_stdout, initial_stderr)
            def wait_on_futures():
                counters["inductor"]["select_algorithm_precompile"] += 1
                for future in as_completed(
                    futures,
                    timeout=precompilation_timeout_seconds,
                ):
                    if e := future.exception():
                        log.error(
                            "Exception %s for benchmark choice %s", e, futures[future]
                        )

                executor.shutdown(wait=True)

            self.precompile_cache[precompile_key] = wait_on_futures

            return wait_on_futures

        def autotune(choices):
            return make_benchmark_fn()(choices)

        if config.autotune_in_subproc:
            from .autotune_process import tuning_pool

            # do the optional warmup
            tuning_pool.initialize()

        def do_autotuning(precompile_fn):
            """Wait for precompilation, run/lookup benchmarks, fire feedback hooks."""
            precompile_start_ts = time.time()
            precompile_fn()
            precompile_elapse = time.time() - precompile_start_ts

            autotune_start_ts = time.time()
            timings = self.lookup(
                choices,
                name,
                inputs_key,
                autotune,
            )
            autotune_elapse = time.time() - autotune_start_ts

            # All-infinite timings means every candidate failed to run.
            if timings and all(
                not math.isfinite(timing) for timing in timings.values()
            ):
                raise NoValidChoicesError

            # currsize > 0 means benchmarking actually executed (not a pure cache hit).
            if make_benchmark_fn.cache_info().currsize:
                counters["inductor"]["select_algorithm_autotune"] += 1

            if (
                make_benchmark_fn.cache_info().currsize
                or log.getEffectiveLevel() == logging.DEBUG
                or config.trace.log_autotuning_results
            ):
                self.log_results(
                    name, input_nodes, timings, autotune_elapse, precompile_elapse
                )

            for feedback_fn in self.feedback_saver_fns:
                feedback_fn(timings, name, input_nodes, choices)

            return timings

        precompile_fn = precompile(choices)

        if return_multi_template and (config.max_autotune or config.max_autotune_gemm):

            def get_timings():
                # Deferred benchmarking for multi-template selection: drop any
                # non-extern choice slower than the best extern kernel.
                timings = do_autotuning(precompile_fn)
                min_extern_choice = float("inf")
                for choice, timing in timings.items():
                    if isinstance(choice, ExternKernelCaller):
                        min_extern_choice = min(min_extern_choice, timing)

                timings = {
                    choice: time
                    for choice, time in timings.items()
                    if (
                        time <= min_extern_choice
                        or not isinstance(choice, ExternKernelCaller)
                    )
                }

                return timings

            return torch._inductor.ir.TensorBox.create(
                torch._inductor.ir.MultiTemplateBuffer(
                    layout,
                    input_nodes,
                    get_timings,
                )
            )

        # TODO - dont want to precompile if we have a cache hit
        timings = do_autotuning(precompile_fn)
        if timings == {} or choices[0] not in timings:
            return choices[0].output_node()

        selected_key = builtins.min(timings, key=timings.__getitem__)
        selected_time = timings[selected_key]
        selected_choice = selected_key.output_node()
        log.debug("selected choice: %s", str(selected_choice))
        return selected_choice
|
| 1395 |
+
|
| 1396 |
+
@classmethod
|
| 1397 |
+
def make_benchmark_fn(
|
| 1398 |
+
cls,
|
| 1399 |
+
choices,
|
| 1400 |
+
input_nodes,
|
| 1401 |
+
layout,
|
| 1402 |
+
input_gen_fns=None,
|
| 1403 |
+
):
|
| 1404 |
+
if input_gen_fns is None:
|
| 1405 |
+
input_gen_fns = {}
|
| 1406 |
+
|
| 1407 |
+
def get_inputs():
|
| 1408 |
+
# de-duplicate args
|
| 1409 |
+
unique_example_inputs = {
|
| 1410 |
+
x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
|
| 1411 |
+
for i, x in enumerate(input_nodes)
|
| 1412 |
+
}
|
| 1413 |
+
example_inputs = list(unique_example_inputs.values())
|
| 1414 |
+
example_inputs_extern = [
|
| 1415 |
+
unique_example_inputs[input_node.get_name()]
|
| 1416 |
+
if unique_example_inputs[input_node.get_name()].is_mkldnn
|
| 1417 |
+
else torch.as_strided(
|
| 1418 |
+
unique_example_inputs[input_node.get_name()],
|
| 1419 |
+
V.graph.sizevars.size_hints(
|
| 1420 |
+
input_node.get_size(),
|
| 1421 |
+
fallback=config.unbacked_symint_fallback,
|
| 1422 |
+
),
|
| 1423 |
+
V.graph.sizevars.size_hints(
|
| 1424 |
+
input_node.get_stride(),
|
| 1425 |
+
fallback=config.unbacked_symint_fallback,
|
| 1426 |
+
),
|
| 1427 |
+
V.graph.sizevars.size_hint(
|
| 1428 |
+
input_node.get_layout().offset,
|
| 1429 |
+
fallback=config.unbacked_symint_fallback,
|
| 1430 |
+
),
|
| 1431 |
+
)
|
| 1432 |
+
for input_node in input_nodes
|
| 1433 |
+
]
|
| 1434 |
+
|
| 1435 |
+
out = cls.benchmark_example_value(layout)
|
| 1436 |
+
out_extern = torch.as_strided(
|
| 1437 |
+
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
|
| 1438 |
+
)
|
| 1439 |
+
expected = None
|
| 1440 |
+
if VERIFY:
|
| 1441 |
+
choices[0].benchmark(*example_inputs_extern, out=out_extern)
|
| 1442 |
+
expected = out_extern.clone()
|
| 1443 |
+
|
| 1444 |
+
return example_inputs, example_inputs_extern, out, out_extern, expected
|
| 1445 |
+
|
| 1446 |
+
if DEBUG:
|
| 1447 |
+
print(f"{len(choices)} tuning requests:")
|
| 1448 |
+
|
| 1449 |
+
def debug_str(example_inputs, out):
|
| 1450 |
+
def tensor_repr(x):
|
| 1451 |
+
return (
|
| 1452 |
+
f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, "
|
| 1453 |
+
f"dtype={x.dtype!r}, device={x.device.type!r})"
|
| 1454 |
+
)
|
| 1455 |
+
|
| 1456 |
+
lines = [
|
| 1457 |
+
"inputs = [",
|
| 1458 |
+
]
|
| 1459 |
+
for x in example_inputs:
|
| 1460 |
+
lines.append(f" {tensor_repr(x)},")
|
| 1461 |
+
lines += ["]", f"out = {tensor_repr(out)}", ""]
|
| 1462 |
+
return "\n".join(lines)
|
| 1463 |
+
|
| 1464 |
+
def benchmark_choice_in_current_process(
|
| 1465 |
+
choice, example_inputs, example_inputs_extern, out, out_extern, expected
|
| 1466 |
+
):
|
| 1467 |
+
out.zero_()
|
| 1468 |
+
if isinstance(choice, ExternKernelCaller):
|
| 1469 |
+
# aten kernels want the offset baked in for sliced tensors
|
| 1470 |
+
result = choice.benchmark(*example_inputs_extern, out=out_extern)
|
| 1471 |
+
else:
|
| 1472 |
+
# triton templates want the base pointer for sliced tensors
|
| 1473 |
+
result = choice.benchmark(*example_inputs, out=out)
|
| 1474 |
+
if VERIFY and expected is not None:
|
| 1475 |
+
torch.testing.assert_close(out_extern, expected, **VERIFY)
|
| 1476 |
+
if torch.cuda.is_available():
|
| 1477 |
+
torch.cuda.synchronize() # shake out any CUDA errors
|
| 1478 |
+
return result
|
| 1479 |
+
|
| 1480 |
+
def benchmark_in_current_process(choices):
    """Benchmark every choice in-process, mapping choice -> time in ms.

    Choices that fail to compile or hit recoverable runtime/resource
    errors are recorded as float("inf") rather than aborting autotuning;
    only incorrect results (AssertionError) abort.
    """
    bench_inputs = get_inputs()
    example_inputs, _, out, _, _ = bench_inputs
    results = {}
    for candidate in choices:
        try:
            measured = benchmark_choice_in_current_process(candidate, *bench_inputs)
        except CUDACompileError as e:
            log.error(
                "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
                str(e),
            )
            measured = float("inf")
        except NotImplementedError as e:
            log.warning("Not yet implemented: %s", e)
            measured = float("inf")
        except RuntimeError as e:
            msg = str(e)
            # Append a hint for the two known failure signatures before logging.
            if "invalid argument" in msg:
                msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n"
            elif "illegal memory access" in msg:
                msg += "\n\nEither error in template or triton bug.\n"
            log.error(
                "Runtime error during autotuning: \n%s. \nIgnoring this choice.",
                msg,
            )
            measured = float("inf")
        except AssertionError as e:
            # A wrong answer is never tolerated -- surface it immediately.
            raise AssertionError(  # noqa: B904
                f"Incorrect result from choice {candidate}\n\n{e}"
            )
        except Exception as e:
            # Triton's OutOfResources is only importable when triton is present;
            # without triton, re-raise the original error unchanged.
            try:
                from triton.runtime.autotuner import OutOfResources
            except ImportError:
                raise e from None
            if isinstance(e, OutOfResources):
                log.warning(e)
                measured = float("inf")
            else:
                raise e

        results[candidate] = measured

    return results
|
| 1527 |
+
|
| 1528 |
+
def benchmark_in_sub_process(choices):
    """Benchmark choices, running triton candidates in a subprocess."""
    from . import autotune_process

    # only benchmark triton kernel in sub process for now.
    # ATen/Extern kernel are still benchmarked in the current process.
    extern_choices = []
    triton_choices = []
    for c in choices:
        if isinstance(c, ExternKernelCaller):
            extern_choices.append(c)
        else:
            triton_choices.append(c)

    merged = benchmark_in_current_process(extern_choices)
    merged.update(autotune_process.benchmark_in_sub_process(triton_choices))
    return merged
|
| 1539 |
+
|
| 1540 |
+
benchmark = (
|
| 1541 |
+
benchmark_in_sub_process
|
| 1542 |
+
if config.autotune_in_subproc
|
| 1543 |
+
else benchmark_in_current_process
|
| 1544 |
+
)
|
| 1545 |
+
|
| 1546 |
+
return benchmark
|
| 1547 |
+
|
| 1548 |
+
@staticmethod
def log_results(
    name: str,
    input_nodes: List[ir.IRNode],
    timings: Dict[ChoiceCaller, float],
    elapse: float,
    precompile_elapse: float,
):
    """Record autotuning results to the debug log, the optional mm log
    file, and (under max_autotune + PRINT_AUTOTUNE) a stderr summary of
    the fastest choices."""
    # Always forward the raw results to the structured debug logger.
    V.debug.log_autotuning_results(
        name, input_nodes, timings, elapse, precompile_elapse
    )
    # stderr reporting only happens in max-autotune mode with printing enabled.
    if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE:
        return
    # Human-readable shape summary, e.g. "128x64, 64x256" for two inputs.
    sizes = ", ".join(
        [
            "x".join(
                map(
                    str,
                    V.graph.sizevars.size_hints(
                        n.get_size(), fallback=config.unbacked_symint_fallback
                    ),
                )
            )
            for n in input_nodes
        ]
    )

    # Show all timings at DEBUG level, else only the 10 fastest.
    n = None if log.getEffectiveLevel() == logging.DEBUG else 10
    top_k = sorted(timings, key=timings.__getitem__)[:n]
    best = top_k[0]

    def get_choice_info(choice):
        # Summarize a single choice for the mm log file.
        if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
            return {"type": "cublas", "time": timings[choice]}

        assert isinstance(
            choice, torch._inductor.select_algorithm.TritonTemplateCaller
        )

        info = choice.info_dict()
        tile = info["tile_shape"]

        # NOTE(review): eval on an internally-generated tile-shape string;
        # data is not user-controlled here, but ast.literal_eval would be safer.
        tile_vals = eval(tile)  # type: ignore[arg-type]
        BLOCK_M = tile_vals[0]
        BLOCK_K = tile_vals[1]
        BLOCK_N = tile_vals[2]

        return {
            "type": "triton",
            "time": timings[choice],
            "BLOCK_M": BLOCK_M,
            "BLOCK_K": BLOCK_K,
            "BLOCK_N": BLOCK_N,
            "num_stages": info["num_stages"],
            "num_warps": info["num_warps"],
        }

    mm_filename = get_mm_log_filename()
    if mm_filename and "mm" in name:
        # Keyed by problem size (M, K, N) taken from the last two inputs.
        M, K = input_nodes[-2].get_size()[:2]
        N = input_nodes[-1].get_size()[-1]

        out_dict = {
            str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()]
        }

        append_to_log(mm_filename, out_dict)

    best_time = timings[best]
    sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
    for choice in top_k:
        result = timings[choice]
        if result:
            kernel_info = (
                choice.debug_extra if hasattr(choice, "debug_extra") else ""
            )
            # Percentage column is best_time / result: 100% marks the winner.
            sys.stderr.write(
                f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_info}\n"
            )
        else:
            # A zero timing would divide by zero in the percentage column.
            sys.stderr.write(
                f" {choice.name} {result:.4f} ms <DIVIDED BY ZERO ERROR>\n"
            )

    autotune_type_str = (
        "SubProcess" if config.autotune_in_subproc else "SingleProcess"
    )
    sys.stderr.write(
        f"{autotune_type_str} AUTOTUNE benchmarking takes {elapse:.4f} seconds and {precompile_elapse:.4f}"
        " seconds precompiling\n"
    )
|
| 1639 |
+
|
| 1640 |
+
@staticmethod
def benchmark_example_value(node):
    """
    Convert an ir.Buffer into a concrete torch.Tensor we can use for
    benchmarking.
    """
    # A bare layout carries no buffer; wrap it in a placeholder one.
    if isinstance(node, ir.Layout):
        node = ir.Buffer("fake", node)
    # triton templates want the base tensor.
    if isinstance(node, ir.BaseView):
        node = node.unwrap_view()
    sizevars = V.graph.sizevars
    fallback = config.unbacked_symint_fallback
    concrete_size = sizevars.size_hints(node.get_size(), fallback=fallback)
    concrete_stride = sizevars.size_hints(node.get_stride(), fallback=fallback)
    return AlgorithmSelectorCache.generate_example_value(
        concrete_size,
        concrete_stride,
        node.get_device(),
        node.get_dtype(),
        node.layout.offset,
    )
|
| 1664 |
+
|
| 1665 |
+
@staticmethod
def generate_example_value(size, stride, device, dtype, extra_size):
    """Allocate a random strided tensor for benchmarking.

    RNG state is saved and restored so the rand_strided call below does
    not perturb the rng states observed by the real model code.
    """
    with preserve_rng_state():
        out = rand_strided(
            size, stride, device=device, dtype=dtype, extra_size=extra_size
        )
    return out
|
| 1677 |
+
|
| 1678 |
+
@staticmethod
def key_of(node):
    """
    Extract the pieces of an ir.Buffer that we should invalidate cached
    autotuning results on.
    """
    sizevars = V.graph.sizevars
    fallback = config.unbacked_symint_fallback
    # device type + dtype + concrete sizes + concrete strides + offset
    key = [node.get_device().type, str(node.get_dtype())]
    key.extend(sizevars.size_hints(node.get_size(), fallback=fallback))
    key.extend(sizevars.size_hints(node.get_stride(), fallback=fallback))
    key.append(sizevars.size_hint(node.get_layout().offset, fallback=fallback))
    return tuple(key)
|
| 1701 |
+
|
| 1702 |
+
def add_feedback_saver(
    self,
    fn: Callable[
        [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
    ],
):
    """Register *fn* to receive (timings, name, inputs, choices) feedback."""
    self.feedback_saver_fns.append(fn)
|
| 1709 |
+
|
| 1710 |
+
|
| 1711 |
+
# Process-wide cache instance, created lazily by the functions below.
_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None
|
| 1712 |
+
|
| 1713 |
+
|
| 1714 |
+
def autotune_select_algorithm(*args, **kwargs):
    """Dispatch algorithm selection through the process-wide cache,
    creating it on first use."""
    global _ALGORITHM_SELECTOR_CACHE
    if _ALGORITHM_SELECTOR_CACHE is None:
        _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()

    # Default multi-template return to the epilogue-fusion benchmarking flag.
    kwargs.setdefault(
        "return_multi_template", torch._inductor.config.benchmark_epilogue_fusion
    )

    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
|
| 1725 |
+
|
| 1726 |
+
|
| 1727 |
+
def add_feedback_saver(
    fn: Callable[[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None]
):
    """Register a feedback saver on the process-wide selector cache,
    creating the cache on first use."""
    global _ALGORITHM_SELECTOR_CACHE
    cache = _ALGORITHM_SELECTOR_CACHE
    if cache is None:
        cache = _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
    cache.add_feedback_saver(fn)
|
| 1734 |
+
|
| 1735 |
+
|
| 1736 |
+
def realize_inputs(*args):
    """Realize inputs as stride-1 kernel operands.

    A single argument yields a single realized node; multiple arguments
    yield a list of realized nodes.
    """
    if len(args) != 1:
        return [realize_inputs(x) for x in args]
    (single,) = args
    return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(single))
|
| 1740 |
+
|
| 1741 |
+
|
| 1742 |
+
# ensure lowering is imported so that `extern_kernels.*` is populated
|
| 1743 |
+
from . import lowering # noqa: F401
|