koichi12 commited on
Commit
d9bcc7f
·
verified ·
1 Parent(s): 7e7bbc5

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torch/_inductor/async_compile.py +297 -0
  2. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__init__.py +0 -0
  3. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py +296 -0
  6. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py +321 -0
  7. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py +150 -0
  8. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py +149 -0
  9. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py +109 -0
  10. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic.py +315 -0
  11. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py +339 -0
  12. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py +119 -0
  13. .venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py +92 -0
  14. .venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py +876 -0
  15. .venv/lib/python3.11/site-packages/torch/_inductor/codecache.py +0 -0
  16. .venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py +264 -0
  17. .venv/lib/python3.11/site-packages/torch/_inductor/comms.py +640 -0
  18. .venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py +1629 -0
  19. .venv/lib/python3.11/site-packages/torch/_inductor/config.py +1241 -0
  20. .venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py +348 -0
  21. .venv/lib/python3.11/site-packages/torch/_inductor/cpu_vec_isa.py +373 -0
  22. .venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py +0 -0
  23. .venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py +330 -0
  24. .venv/lib/python3.11/site-packages/torch/_inductor/debug.py +693 -0
  25. .venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py +980 -0
  26. .venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py +745 -0
  27. .venv/lib/python3.11/site-packages/torch/_inductor/exc.py +104 -0
  28. .venv/lib/python3.11/site-packages/torch/_inductor/extern_node_serializer.py +25 -0
  29. .venv/lib/python3.11/site-packages/torch/_inductor/freezing.py +269 -0
  30. .venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py +251 -0
  31. .venv/lib/python3.11/site-packages/torch/_inductor/graph.py +1930 -0
  32. .venv/lib/python3.11/site-packages/torch/_inductor/hooks.py +30 -0
  33. .venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py +373 -0
  34. .venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py +179 -0
  35. .venv/lib/python3.11/site-packages/torch/_inductor/ir.py +0 -0
  36. .venv/lib/python3.11/site-packages/torch/_inductor/jagged_lowerings.py +264 -0
  37. .venv/lib/python3.11/site-packages/torch/_inductor/lowering.py +0 -0
  38. .venv/lib/python3.11/site-packages/torch/_inductor/metrics.py +436 -0
  39. .venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_ir.py +1881 -0
  40. .venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_lowerings.py +1087 -0
  41. .venv/lib/python3.11/site-packages/torch/_inductor/package/__init__.py +1 -0
  42. .venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/package.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torch/_inductor/package/build_package.py +15 -0
  45. .venv/lib/python3.11/site-packages/torch/_inductor/package/package.py +237 -0
  46. .venv/lib/python3.11/site-packages/torch/_inductor/package/pt2_archive_constants.py +16 -0
  47. .venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py +2005 -0
  48. .venv/lib/python3.11/site-packages/torch/_inductor/quantized_lowerings.py +92 -0
  49. .venv/lib/python3.11/site-packages/torch/_inductor/remote_cache.py +198 -0
  50. .venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py +1743 -0
.venv/lib/python3.11/site-packages/torch/_inductor/async_compile.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import functools
5
+ import logging
6
+ import multiprocessing
7
+ import os
8
+ import sys
9
+ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
10
+ from concurrent.futures.process import BrokenProcessPool
11
+ from functools import partial
12
+ from time import time
13
+ from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
14
+
15
+ import torch
16
+ from torch._dynamo.device_interface import get_registered_device_interfaces
17
+ from torch._inductor import config
18
+ from torch._inductor.codecache import (
19
+ CodeCacheFuture,
20
+ CppCodeCache,
21
+ CppPythonBindingsCodeCache,
22
+ CUDACodeCache,
23
+ HalideCodeCache,
24
+ LambdaFuture,
25
+ ROCmCodeCache,
26
+ TritonCodeCache,
27
+ TritonFuture,
28
+ )
29
+ from torch._inductor.compile_worker.subproc_pool import (
30
+ _warm_process_pool,
31
+ AnyPool,
32
+ SubprocPool,
33
+ )
34
+ from torch._inductor.compile_worker.watchdog import _async_compile_initializer
35
+ from torch._inductor.runtime.compile_tasks import (
36
+ _set_triton_ptxas_path,
37
+ _worker_compile_triton,
38
+ )
39
+ from torch.hub import _Faketqdm, tqdm
40
+ from torch.utils._triton import has_triton_package
41
+
42
+
43
+ if TYPE_CHECKING:
44
+ from torch._inductor.runtime.hints import HalideMeta
45
+
46
+ # timing metrics for time spent in the compilation
47
+ _cumulative_compile_time = 0.0
48
+ _t0: Optional[float] = None
49
+
50
+ kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")
51
+
52
+
53
+ def pre_fork_setup():
54
+ """
55
+ Setup that must be done prior to forking with a process pool.
56
+ """
57
+ # ensure properties have been calculated before processes
58
+ # are forked
59
+ caching_device_properties()
60
+
61
+ # Computing the triton key can be slow. If we call it before fork,
62
+ # it will be cached for the forked subprocesses.
63
+ try:
64
+ from triton.compiler.compiler import triton_key
65
+
66
+ triton_key()
67
+ except ImportError:
68
+ # Triton might not be installed or might be an old version.
69
+ pass
70
+
71
+
72
+ def caching_device_properties():
73
+ for _, device_interface in get_registered_device_interfaces():
74
+ if device_interface.is_available():
75
+ device_interface.Worker.get_device_properties()
76
+
77
+
78
+ def _compile_start() -> None:
79
+ global _t0
80
+ if _t0 is None:
81
+ _t0 = time()
82
+
83
+
84
+ def _compile_end() -> None:
85
+ global _cumulative_compile_time, _t0
86
+ if _t0 is not None:
87
+ t1 = time()
88
+ _cumulative_compile_time += t1 - _t0
89
+ _t0 = None
90
+ # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)
91
+
92
+
93
+ _IS_WINDOWS = sys.platform == "win32"
94
+
95
+ log = logging.getLogger(__name__)
96
+
97
+
98
+ # Used to keep track of all process pools invoked so far.
99
+ _pool_set: Set[AnyPool] = set()
100
+
101
+
102
+ def shutdown_compile_workers() -> None:
103
+ """Shut down all outstanding compile-worker pools."""
104
+ for pool in _pool_set:
105
+ pool.shutdown()
106
+ after_fork()
107
+
108
+
109
+ def after_fork():
110
+ """Reset pools to initial state without shutting them down"""
111
+ _pool_set.clear()
112
+ AsyncCompile.process_pool.cache_clear()
113
+
114
+
115
+ try:
116
+ os.register_at_fork(after_in_child=after_fork)
117
+ except AttributeError:
118
+ pass # register_at_fork does not exists on windows
119
+
120
+
121
+ class AsyncCompile:
122
+ def __init__(self) -> None:
123
+ pass
124
+
125
+ @staticmethod
126
+ @functools.lru_cache(1)
127
+ def pool() -> ThreadPoolExecutor:
128
+ assert config.compile_threads > 1
129
+ return ThreadPoolExecutor(config.compile_threads)
130
+
131
+ @staticmethod
132
+ def _get_ready():
133
+ """No-op function to help mark when the subprocess pool is ready."""
134
+ return "ready"
135
+
136
+ @staticmethod
137
+ @functools.lru_cache(1)
138
+ def process_pool() -> AnyPool:
139
+ assert config.compile_threads > 1
140
+ pool: AnyPool
141
+ if config.worker_start_method == "subprocess":
142
+ # Wrapper around ProcessPoolExecutor forks in a new process we control
143
+ pool = SubprocPool(config.compile_threads)
144
+ else:
145
+ pre_fork_setup()
146
+ ctx = multiprocessing.get_context(config.worker_start_method)
147
+ pool = ProcessPoolExecutor(
148
+ config.compile_threads,
149
+ mp_context=ctx,
150
+ initializer=partial(_async_compile_initializer, os.getpid()),
151
+ )
152
+ # when this pool is created in a subprocess object, the normal exit handler
153
+ # doesn't run, and we need to register our own handler.
154
+ # exitpriority has to be high, because another one of the finalizers will
155
+ # kill the worker thread that sends the shutdown message to the workers...
156
+ multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
157
+
158
+ # Set an attribute we can check to see if the pool is ready.
159
+ pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[union-attr]
160
+ _pool_set.add(pool)
161
+ return pool
162
+
163
+ @classmethod
164
+ def warm_pool(cls) -> None:
165
+ if config.compile_threads <= 1:
166
+ return
167
+ _compile_start()
168
+ _warm_process_pool(cls.process_pool(), config.compile_threads)
169
+ _compile_end()
170
+
171
+ @classmethod
172
+ def submit(cls, task: Callable[..., Any]) -> Any:
173
+ if config.compile_threads <= 1:
174
+ return task()
175
+ return cls.pool().submit(task)
176
+
177
+ def _use_process_pool(self):
178
+ return (
179
+ config.compile_threads > 1
180
+ and self.process_pool().ready_future.done() # type: ignore[union-attr]
181
+ )
182
+
183
+ def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
184
+ kernel_code_log.info("Triton Kernel:\n%s", source_code)
185
+ _compile_start()
186
+ _set_triton_ptxas_path()
187
+
188
+ kernel = TritonCodeCache.load(kernel_name, source_code)
189
+ if self._use_process_pool():
190
+ # We want to support changing these env vars after (and while) the
191
+ # process pool is running, so pass them to the subprocess to reset.
192
+ env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
193
+ extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
194
+ return TritonFuture(
195
+ kernel,
196
+ self.process_pool().submit(
197
+ _worker_compile_triton,
198
+ kernel._reload_in_subproc,
199
+ extra_env,
200
+ ),
201
+ )
202
+ else:
203
+ kernel.precompile()
204
+ return kernel
205
+
206
+ def multi_kernel(self, *args, **kwargs) -> Any:
207
+ from torch._inductor.codegen.multi_kernel import MultiKernelCall
208
+
209
+ # no need to call this in parallel since the sub-kernels are already parallel tasks
210
+ return MultiKernelCall(*args, **kwargs)
211
+
212
+ def cpp(self, source_code: str):
213
+ kernel_code_log.info("CPP Kernel:\n%s", source_code)
214
+ if config.compile_threads <= 1:
215
+ return CppCodeCache.load(source_code).kernel
216
+ else:
217
+ get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
218
+ return LambdaFuture(lambda: get_result().kernel)
219
+
220
+ def cpp_pybinding(self, argtypes: List[str], source_code: str):
221
+ kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
222
+ if config.compile_threads <= 1:
223
+ return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
224
+ else:
225
+ get_result = CppPythonBindingsCodeCache.load_pybinding_async(
226
+ argtypes, source_code, submit_fn=self.submit
227
+ )
228
+ return LambdaFuture(get_result)
229
+
230
+ def cuda(self, source_code, dst_file_ext):
231
+ kernel_code_log.info("CUDA Kernel:\n%s", source_code)
232
+
233
+ def task():
234
+ return CUDACodeCache.load(source_code, dst_file_ext)[0]
235
+
236
+ return self.submit(task)
237
+
238
+ def rocm(self, source_code, dst_file_ext):
239
+ kernel_code_log.info("ROCm Kernel:\n%s", source_code)
240
+
241
+ def task():
242
+ return ROCmCodeCache.load(source_code, dst_file_ext)[0]
243
+
244
+ return self.submit(task)
245
+
246
+ def halide(self, meta: HalideMeta, source_code: str):
247
+ kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
248
+ if config.compile_threads <= 1:
249
+ return HalideCodeCache.generate_halide(meta, source_code)
250
+ else:
251
+ get_result = HalideCodeCache.generate_halide_async(
252
+ meta, source_code, submit_fn=self.submit
253
+ )
254
+ return LambdaFuture(get_result)
255
+
256
+ def wait(self, scope: Dict[str, Any]) -> None:
257
+ num_kernels = len(
258
+ [
259
+ value
260
+ for key, value in scope.items()
261
+ if isinstance(value, (Future, CodeCacheFuture))
262
+ ]
263
+ )
264
+ pbar = tqdm(
265
+ total=num_kernels,
266
+ desc="Inductor Compilation",
267
+ disable=config.disable_progress,
268
+ delay=0,
269
+ )
270
+ if config.compile_threads > 1:
271
+ for key, result in scope.items():
272
+ if config.verbose_progress and not isinstance(pbar, _Faketqdm):
273
+ pbar.set_postfix_str(key)
274
+ if isinstance(result, (Future, CodeCacheFuture)):
275
+ try:
276
+ scope[key] = result.result()
277
+ except BrokenProcessPool as e:
278
+ raise RuntimeError(
279
+ "A compilation subprocess exited unexpectedly. This "
280
+ "is likely due to a crash. To facilitate debugging, "
281
+ "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 "
282
+ "to cause compilation to occur in the main process."
283
+ ) from e
284
+ pbar.update(1)
285
+
286
+ _compile_end()
287
+
288
+
289
+ if (
290
+ os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
291
+ or os.environ.get("TORCH_WARM_POOL", "1") != "1"
292
+ # The subprocess pool is only used for the Triton backend
293
+ or not has_triton_package()
294
+ ):
295
+ pass
296
+ else:
297
+ AsyncCompile.warm_pool()
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (202 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-311.pyc ADDED
Binary file (19 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
+ class MMRankingA100(LearnedHeuristicDecision):
18
+
19
+ def __init__(self) -> None:
20
+ self.choices: List[Choice] = []
21
+ self.fill_choices()
22
+
23
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
24
+ return (
25
+ metadata.name == self.get_name()
26
+ and metadata.shared_memory == 166912
27
+ and str(metadata.device_capa) == "(8, 0)"
28
+ )
29
+
30
+ def get_confidence_threshold(self) -> float:
31
+ return 0.0
32
+
33
+ def get_choice(self, idx: int) -> Optional[str]:
34
+ if idx < len(self.choices):
35
+ return self.choices[idx]
36
+ return None
37
+
38
+ def fill_choices(self) -> None:
39
+ self.choices.append('extern_mm')
40
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
41
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
42
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
43
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
44
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
45
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
46
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
47
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
48
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
49
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
50
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
51
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
52
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
53
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
54
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
55
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
56
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
57
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
58
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
59
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
60
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
61
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
62
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
63
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
64
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
65
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
66
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
67
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
68
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
69
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
70
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
71
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
72
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
73
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
74
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
75
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
76
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
77
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
78
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
79
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
80
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
81
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
82
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
83
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
84
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
85
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
86
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
87
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
88
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
89
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
90
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
91
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
92
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
93
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
94
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
95
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
96
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
97
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
98
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
99
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
100
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
101
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
102
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
103
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
104
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
105
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
106
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
107
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
108
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
109
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
110
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
111
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
112
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
113
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
114
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
115
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
116
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
117
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
118
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
119
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
120
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
121
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
122
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
123
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
124
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
125
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
126
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
127
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
128
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
129
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
130
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
131
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
132
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
133
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
134
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
135
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
136
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
137
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
138
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
139
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
140
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
141
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
142
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
143
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
144
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
145
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
146
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
147
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
148
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
149
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2')
150
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
151
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
152
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
153
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
154
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
155
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
156
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
157
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
158
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
159
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
160
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
161
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
162
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
163
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
164
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
165
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
166
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
167
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
168
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
169
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
170
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
171
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
172
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
173
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
174
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
175
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
176
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
177
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
178
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
179
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
180
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
181
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
182
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
183
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
184
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
185
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
186
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
187
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
188
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
189
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
190
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
191
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
192
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
193
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
194
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
195
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
196
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
197
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
198
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
199
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
200
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
201
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
202
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
203
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
204
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
205
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
206
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
207
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
208
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
209
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
210
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
211
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
212
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
213
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
214
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
215
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
216
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
217
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
218
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
219
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
220
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
221
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
222
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
223
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
224
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
225
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
226
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
227
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
228
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
229
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
230
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
231
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
232
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
233
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
234
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
235
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
236
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
237
+
238
def get_name(self) -> str:
    """Return the name of the optimization this heuristic applies to."""
    heuristic_name = 'mm'
    return heuristic_name
240
+
241
+ def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
242
+ if context.get_value('arith_intensity') <= 52.6245059967041:
243
+ if context.get_value('n') <= 34.0:
244
+ if context.get_value('n') <= 18.0:
245
+ if context.get_value('k*n') <= 312.0:
246
+ return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)]
247
+ else:
248
+ if context.get_value('k') <= 40.0:
249
+ return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)]
250
+ else:
251
+ return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)]
252
+ else:
253
+ if context.get_value('mat1_stride_0') <= 20.0:
254
+ return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)]
255
+ else:
256
+ if context.get_value('k') <= 68.0:
257
+ return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)]
258
+ else:
259
+ return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)]
260
+ else:
261
+ if context.get_value('k') <= 35.0:
262
+ if context.get_value('k') <= 18.0:
263
+ if context.get_value('m*n') <= 19505152.0:
264
+ return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)]
265
+ else:
266
+ return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)]
267
+ else:
268
+ if context.get_value('n') <= 68.0:
269
+ return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)]
270
+ else:
271
+ return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)]
272
+ else:
273
+ if context.get_value('m*n') <= 309760.0:
274
+ return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)]
275
+ else:
276
+ if context.get_value('n') <= 72.0:
277
+ return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)]
278
+ else:
279
+ return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)]
280
+ else:
281
+ if str(context.get_value('using_tf32')) != 'False':
282
+ if context.get_value('m*n') <= 815360.0:
283
+ if context.get_value('k') <= 1184.0:
284
+ return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)]
285
+ else:
286
+ return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)]
287
+ else:
288
+ if context.get_value('arith_intensity') <= 187.23922729492188:
289
+ if context.get_value('mat1_stride_0') <= 198.0:
290
+ return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)]
291
+ else:
292
+ return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)]
293
+ else:
294
+ return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)]
295
+ else:
296
+ return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)]
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
+ class MMRankingH100(LearnedHeuristicDecision):
18
+
19
+ def __init__(self) -> None:
20
+ self.choices: List[Choice] = []
21
+ self.fill_choices()
22
+
23
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
24
+ return (
25
+ metadata.name == self.get_name()
26
+ and metadata.shared_memory == 232448
27
+ and str(metadata.device_capa) == "(9, 0)"
28
+ )
29
+
30
def get_confidence_threshold(self) -> float:
    """Return the confidence threshold of this heuristic, which is fixed at 0.0."""
    threshold = 0.0
    return threshold
32
+
33
def get_choice(self, idx: int) -> Optional[str]:
    """Return the choice stored at position ``idx``, or None if out of range.

    Args:
        idx: Zero-based position in ``self.choices``.

    Returns:
        The choice string at ``idx``, or ``None`` when ``idx`` is negative or
        not smaller than ``len(self.choices)``.
    """
    # Check the lower bound too: without it a negative idx would silently
    # wrap around to the end of the list (or raise IndexError when it is
    # below -len) instead of reporting "no such choice".
    if 0 <= idx < len(self.choices):
        return self.choices[idx]
    return None
37
+
38
def fill_choices(self) -> None:
    """Append every known choice string to ``self.choices``, in fixed order.

    The position of each string matters: the integer in each pair returned
    by get_best_choices() is presumably an index into this list (confirm
    against the caller), so the order below must not change.
    """
    ordered_choices = (
        'extern_mm',
        'type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2',
        'type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2',
        'type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2',
        'type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2',
        'type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2',
        'type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4',
        'type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4',
    )
    # One append per entry, exactly as the unrolled original did.
    for choice in ordered_choices:
        self.choices.append(choice)
241
+
242
def get_name(self) -> str:
    """Return the identifier string of this heuristic, which is 'mm'."""
    heuristic_name = 'mm'
    return heuristic_name
244
+
245
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
    """
    Walk a hard-coded decision tree over features read from ``context`` and return the table stored at the
    leaf that is reached.

    The features consulted via ``context.get_value`` are 'arith_intensity', 'n', 'k', 'm', 'k*n', 'm*k',
    'm*n', 'mat1_stride_0' and 'mat2_stride_0'.

    Returns:
        A list of ``(float, int)`` tuples, listed in non-increasing order of the float within each leaf.
        NOTE(review): the float is presumably a learned score/probability and the int an index into the
        list built by ``fill_choices`` -- confirm against the base class that consumes this result.
    """
    # This tree is generated code: thresholds and tables should only change by regenerating the file.
    if context.get_value('arith_intensity') <= 29.89772129058838:
        if context.get_value('n') <= 34.0:
            if context.get_value('n') <= 18.0:
                if context.get_value('k*n') <= 432.0:
                    if context.get_value('arith_intensity') <= 7.8700292110443115:
                        return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)]
                    else:
                        return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)]
                else:
                    if context.get_value('k') <= 40.0:
                        return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)]
                    else:
                        return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)]
            else:
                if context.get_value('mat1_stride_0') <= 40.0:
                    if context.get_value('mat1_stride_0') <= 20.0:
                        return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)]
                    else:
                        return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)]
                else:
                    if context.get_value('mat1_stride_0') <= 68.0:
                        return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)]
                    else:
                        return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)]
        else:
            if context.get_value('k') <= 18.0:
                if context.get_value('m*k') <= 528.0:
                    return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)]
                else:
                    if context.get_value('n') <= 80.0:
                        return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)]
                    else:
                        return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)]
            else:
                if context.get_value('k') <= 36.0:
                    if context.get_value('n') <= 68.0:
                        return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)]
                    else:
                        return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)]
                else:
                    if context.get_value('mat2_stride_0') <= 384.0:
                        return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)]
                    else:
                        return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)]
    else:
        if context.get_value('arith_intensity') <= 56.995582580566406:
            if context.get_value('n') <= 68.0:
                if context.get_value('k*n') <= 4448.0:
                    if context.get_value('m*n') <= 29626368.0:
                        return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)]
                    else:
                        return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)]
                else:
                    if context.get_value('k') <= 348.0:
                        return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)]
                    else:
                        return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)]
            else:
                if context.get_value('m') <= 3264.0:
                    return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)]
                else:
                    if context.get_value('k') <= 62.5:
                        return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)]
                    else:
                        return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)]
        else:
            if context.get_value('m*n') <= 1097728.0:
                return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)]
            else:
                if context.get_value('m*n') <= 3244032.0:
                    return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)]
                else:
                    if context.get_value('n') <= 136.0:
                        return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)]
                    else:
                        return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)]
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
class MixedMMA100(LearnedHeuristicDecision):
    """
    Generated decision-tree heuristic named 'mixed_mm'.

    ``check_precondition`` only accepts metadata reporting 166912 bytes of shared memory and a device
    capability whose string form is "(8, 0)" (presumably the A100 configuration, going by the class name).
    """

    def __init__(self) -> None:
        """Create the empty choice list and populate it via ``fill_choices``."""
        self.choices: List[Choice] = []
        self.fill_choices()

    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        """
        Return True only if ``metadata`` matches this heuristic's name, shared-memory size and device
        capability. ``context`` is accepted but not inspected here.
        """
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == 166912
            and str(metadata.device_capa) == "(8, 0)"
        )

    def get_confidence_threshold(self) -> float:
        """Return the confidence threshold of this heuristic, which is the constant 0.0."""
        return 0.0

    def get_choice(self, idx: int) -> Optional[str]:
        """
        Return ``self.choices[idx]`` when ``idx < len(self.choices)``, otherwise None.

        Only the upper bound is checked, so a negative ``idx`` is not rejected and follows Python's
        negative-indexing rules.
        """
        if idx < len(self.choices):
            return self.choices[idx]
        return None

    def fill_choices(self) -> None:
        """
        Append the 22 known choice strings to ``self.choices`` in a fixed order.

        NOTE(review): the integers returned by ``get_best_choices`` are presumably positions in this list
        (resolved through ``get_choice``), so the order must stay in sync with the tree -- confirm in the
        base class.
        """
        self.choices.append('extern_fallback_mixed_mm')
        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')

    def get_name(self) -> str:
        """Return the heuristic name compared against ``metadata.name`` in ``check_precondition``."""
        return 'mixed_mm'

    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
        """
        Walk the generated decision tree over features read from ``context`` and return the leaf's table.

        Returns:
            A list of ``(float, int)`` tuples, or None at the leaves that carry no table.
            NOTE(review): the float is presumably a confidence and the int an index accepted by
            ``get_choice`` -- confirm against the base class.
        """
        # Categorical features are compared through str(...), e.g. str(value) != 'True'.
        if str(context.get_value('1LEQmLEQ16')) != 'True':
            if context.get_value('m') <= 32.5:
                if context.get_value('n') <= 6976.0:
                    if context.get_value('n') <= 3520.0:
                        if context.get_value('m*n') <= 37632.0:
                            return None
                        else:
                            return [(1.000, 13)]
                    else:
                        if context.get_value('m*k') <= 452352.0:
                            return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)]
                        else:
                            return [(0.778, 8), (0.222, 13)]
                else:
                    if context.get_value('k*n') <= 102776832.0:
                        if context.get_value('n') <= 14656.0:
                            return [(1.000, 11)]
                        else:
                            return [(0.889, 11), (0.111, 13)]
                    else:
                        return [(1.000, 11)]
            else:
                if context.get_value('m*n') <= 446464.0:
                    if context.get_value('m*n') <= 223424.0:
                        # Both arms return None, so this split on mat1_stride_0 does not affect the result.
                        if context.get_value('mat1_stride_0') <= 3968.0:
                            return None
                        else:
                            return None
                    else:
                        if context.get_value('m*n') <= 346112.0:
                            return [(0.960, 16), (0.040, 7)]
                        else:
                            return [(0.750, 16), (0.136, 14), (0.114, 7)]
                else:
                    if str(context.get_value('33LEQmLEQ64')) != 'True':
                        if context.get_value('n') <= 6976.0:
                            return [(1.000, 14)]
                        else:
                            return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)]
                    else:
                        if context.get_value('n') <= 13888.0:
                            return [(0.710, 14), (0.275, 21), (0.014, 12)]
                        else:
                            return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)]
        else:
            if context.get_value('n') <= 3520.0:
                if context.get_value('arith_intensity') <= 3.994754433631897:
                    if str(context.get_value('mat2_dtype')) != 'torch.uint8':
                        if context.get_value('m*k') <= 18944.0:
                            return [(0.577, 5), (0.423, 6)]
                        else:
                            return [(0.988, 5), (0.012, 6)]
                    else:
                        # Both arms return None, so this split on arith_intensity does not affect the result.
                        if context.get_value('arith_intensity') <= 2.9899919033050537:
                            return None
                        else:
                            return None
                else:
                    if context.get_value('arith_intensity') <= 7.956453561782837:
                        if context.get_value('k*n') <= 9244032.0:
                            return [(0.822, 5), (0.178, 6)]
                        else:
                            return [(0.977, 5), (0.023, 0)]
                    else:
                        if context.get_value('m*k') <= 978944.0:
                            return [(1.000, 5)]
                        else:
                            return [(0.971, 5), (0.029, 0)]
            else:
                if context.get_value('n') <= 13632.0:
                    if context.get_value('n') <= 6976.0:
                        return [(1.000, 6)]
                    else:
                        if context.get_value('k') <= 3968.0:
                            return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)]
                        else:
                            return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)]
                else:
                    if context.get_value('k*n') <= 39518208.0:
                        return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)]
                    else:
                        if context.get_value('n') <= 20800.0:
                            return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)]
                        else:
                            return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)]
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
class MixedMMH100(LearnedHeuristicDecision):
    """
    Generated decision-tree heuristic named 'mixed_mm'.

    ``check_precondition`` only accepts metadata reporting 232448 bytes of shared memory and a device
    capability whose string form is "(9, 0)" (presumably the H100 configuration, going by the class name).
    """

    def __init__(self) -> None:
        """Create the empty choice list and populate it via ``fill_choices``."""
        self.choices: List[Choice] = []
        self.fill_choices()

    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        """
        Return True only if ``metadata`` matches this heuristic's name, shared-memory size and device
        capability. ``context`` is accepted but not inspected here.
        """
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == 232448
            and str(metadata.device_capa) == "(9, 0)"
        )

    def get_confidence_threshold(self) -> float:
        """Return the confidence threshold of this heuristic, which is the constant 0.0."""
        return 0.0

    def get_choice(self, idx: int) -> Optional[str]:
        """
        Return ``self.choices[idx]`` when ``idx < len(self.choices)``, otherwise None.

        Only the upper bound is checked, so a negative ``idx`` is not rejected and follows Python's
        negative-indexing rules.
        """
        if idx < len(self.choices):
            return self.choices[idx]
        return None

    def fill_choices(self) -> None:
        """
        Append the 21 known choice strings to ``self.choices`` in a fixed order.

        NOTE(review): the integers returned by ``get_best_choices`` are presumably positions in this list
        (resolved through ``get_choice``), so the order must stay in sync with the tree -- confirm in the
        base class.
        """
        self.choices.append('extern_fallback_mixed_mm')
        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')

    def get_name(self) -> str:
        """Return the heuristic name compared against ``metadata.name`` in ``check_precondition``."""
        return 'mixed_mm'

    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
        """
        Walk the generated decision tree over features read from ``context`` and return the leaf's table.

        Returns:
            A list of ``(float, int)`` tuples, or None at the leaves that carry no table.
            NOTE(review): the float is presumably a confidence and the int an index accepted by
            ``get_choice`` -- confirm against the base class.
        """
        # Categorical features are compared through str(...), e.g. str(value) != 'torch.int8'.
        if context.get_value('arith_intensity') <= 15.988086223602295:
            if context.get_value('n') <= 25280.0:
                if context.get_value('n') <= 1344.0:
                    if context.get_value('mat1_stride_0') <= 7808.0:
                        return [(0.581, 7), (0.419, 6)]
                    else:
                        if context.get_value('m*n') <= 7680.0:
                            return [(0.875, 0), (0.125, 6)]
                        else:
                            return [(0.833, 0), (0.167, 7)]
                else:
                    if context.get_value('n') <= 8512.0:
                        if str(context.get_value('mat2_dtype')) != 'torch.int8':
                            return [(0.763, 6), (0.237, 7)]
                        else:
                            return [(0.725, 7), (0.275, 6)]
                    else:
                        if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
                            return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)]
                        else:
                            return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)]
            else:
                if context.get_value('n') <= 42254.0:
                    if context.get_value('n') <= 33856.0:
                        if context.get_value('k*n') <= 68157440.0:
                            return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)]
                        else:
                            return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)]
                    else:
                        return [(0.659, 5), (0.341, 6)]
                else:
                    if context.get_value('k*n') <= 326052992.0:
                        if context.get_value('n') <= 55232.0:
                            return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)]
                        else:
                            return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)]
                    else:
                        if context.get_value('n') <= 57024.0:
                            return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)]
                        else:
                            return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)]
        else:
            if context.get_value('m*n') <= 543936.0:
                if str(context.get_value('17LEQmLEQ32')) != 'True':
                    if context.get_value('m*n') <= 262272.0:
                        if context.get_value('n') <= 1592.5:
                            return [(0.860, 0), (0.140, 9)]
                        else:
                            return None
                    else:
                        if context.get_value('m*k') <= 1294336.0:
                            return [(0.833, 17), (0.150, 18), (0.017, 15)]
                        else:
                            return [(0.917, 17), (0.083, 8)]
                else:
                    if context.get_value('n') <= 12416.0:
                        if context.get_value('m*n') <= 43008.0:
                            return None
                        else:
                            return [(0.853, 14), (0.147, 9)]
                    else:
                        return [(0.625, 12), (0.375, 14)]
            else:
                if context.get_value('m') <= 32.5:
                    if context.get_value('mat2_stride_1') <= 6656.0:
                        if context.get_value('n') <= 69184.0:
                            return [(0.611, 12), (0.361, 14), (0.028, 13)]
                        else:
                            return [(1.000, 12)]
                    else:
                        if context.get_value('mat2_stride_1') <= 20864.0:
                            return [(1.000, 12)]
                        else:
                            return [(0.958, 12), (0.042, 9)]
                else:
                    if context.get_value('m*n') <= 1085440.0:
                        if context.get_value('n') <= 9152.0:
                            return [(1.000, 18)]
                        else:
                            return [(0.780, 18), (0.160, 16), (0.060, 20)]
                    else:
                        if context.get_value('m') <= 67.0:
                            return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)]
                        else:
                            return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)]
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
5
+ from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
6
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
7
+ LearnedHeuristicRegression,
8
+ )
9
+
10
+
11
class PadMMA100(LearnedHeuristicRegression):
    """
    Generated regression-tree heuristic named 'pad_mm'.

    ``check_precondition`` only accepts metadata reporting 166912 bytes of shared memory and a device
    capability whose string form is "(8, 0)" (presumably the A100 configuration, going by the class name).
    """

    def __init__(self) -> None:
        """No instance state is set up; every method works from its arguments and constants."""
        pass

    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        """
        Return True only if ``metadata`` matches this heuristic's name, shared-memory size and device
        capability. ``context`` is accepted but not inspected here.
        """
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == 166912
            and str(metadata.device_capa) == "(8, 0)"
        )

    def get_feedback(self, context: AHContext, choice: Choice) -> float:
        """
        Store ``choice`` in ``context.context_dict`` under ``CHOICE_COL`` and return ``predict(context)``.

        The caller's context is mutated in place; the stored choice stays there after the call.
        """
        # NOTE(review): predict() reads the feature 'choice'; presumably CHOICE_COL == 'choice' -- confirm.
        context.context_dict[CHOICE_COL] = choice
        return self.predict(context)

    def get_confidence_threshold(self) -> float:
        """Return the confidence threshold of this heuristic, the constant 1.7025303314066."""
        return 1.7025303314066

    def get_name(self) -> str:
        """Return the heuristic name compared against ``metadata.name`` in ``check_precondition``."""
        return 'pad_mm'

    def predict(self, context: AHContext) -> float:
        """
        Walk the generated regression tree over features read from ``context`` and return the leaf value.

        Categorical features are compared through ``str(...)``; numerical ones through ``<=`` thresholds.
        """
        if str(context.get_value('choice')) != 'pad':
            if str(context.get_value('using_tf32')) != 'False':
                if context.get_value('m*n') <= 4171264.0:
                    if context.get_value('m*k') <= 3999308.0:
                        return 1.8751469764071178
                    else:
                        if str(context.get_value('n_multiple_32')) != 'True':
                            return 0.9117231355626345
                        else:
                            return 1.1607689608873861
                else:
                    if str(context.get_value('n_multiple_2')) != 'True':
                        # NOTE(review): this point is only reached when str(using_tf32) != 'False'; if the
                        # feature only ever stringifies to 'True'/'False', the first arm below is
                        # unreachable -- confirm the feature's value domain.
                        if str(context.get_value('using_tf32')) != 'True':
                            return 0.7430382200435992
                        else:
                            return 0.8531269794448678
                    else:
                        if str(context.get_value('k_multiple_2')) != 'True':
                            return 0.7577181972719917
                        else:
                            return 0.8977349440424219
            else:
                if context.get_value('m*n') <= 1299712.0:
                    return 1.1669723418995592
                else:
                    if context.get_value('mat2_stride_1') <= 45217.5:
                        if context.get_value('m*n') <= 55884158.0:
                            return 1.0262769936909601
                        else:
                            return 1.0022677428470845
                    else:
                        if context.get_value('m') <= 18478.0:
                            return 1.1127066261894312
                        else:
                            return 1.0337740659894263
        else:
            if str(context.get_value('mat1_dtype')) != 'torch.float32':
                if str(context.get_value('n_multiple_2')) != 'False':
                    if str(context.get_value('k_multiple_2')) != 'True':
                        if context.get_value('mat1_stride_0') <= 561.0:
                            return 1.2900382135142956
                        else:
                            return 1.5761737616057887
                    else:
                        if context.get_value('num_dims_needs_padding') <= 1.5:
                            return 1.0472263310239422
                        else:
                            return 1.1727673465762514
                else:
                    if context.get_value('k') <= 28238.5:
                        if context.get_value('k/(m*n)') <= 0.00026227018679492176:
                            return 1.6770542505397175
                        else:
                            return 1.3974785435105923
                    else:
                        if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
                            return 1.3952699800111992
                        else:
                            return 1.5759286511628336
            else:
                if str(context.get_value('using_tf32')) != 'False':
                    if context.get_value('m*n') <= 14119424.0:
                        return 0.8875772670422478
                    else:
                        if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
                            return 1.1467728924377265
                        else:
                            return 1.215842963532998
                else:
                    if context.get_value('arith_intensity') <= 396.8774871826172:
                        return 0.89940161869551
                    else:
                        if context.get_value('mat2_stride_1') <= 45217.5:
                            return 0.9964328169353532
                        else:
                            return 0.9493479238294826
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import partial
4
+ from typing import Any, Callable, Dict, List, Optional
5
+
6
+ import torch
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ AHOperation,
11
+ Choice,
12
+ CHOICE_COL,
13
+ Feedback,
14
+ FEEDBACK_COL,
15
+ get_metadata_str_from_log,
16
+ )
17
+ from torch._inductor.autoheuristic.learned_heuristic_controller import (
18
+ LearnedHeuristicController,
19
+ )
20
+ from torch._inductor.ir import ChoiceCaller
21
+ from torch._inductor.runtime.runtime_utils import cache_dir
22
+ from torch._inductor.utils import get_gpu_shared_memory
23
+
24
+
25
class LocalFeedback:
    """
    Callable wrapper around a function that produces feedback for a single choice.

    Collecting data for a choice requires a function that maps the choice to its feedback. Use
    LocalFeedback when AutoHeuristic should run that function right away for each choice
    (pad_mm.py, where autotuning happens locally, is an example).
    """

    def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
        # Kept unchanged; it is invoked each time the instance is called.
        self.feedback_fn = feedback_fn

    def __call__(self, choice: Choice) -> Feedback:
        """Run the wrapped function on ``choice`` and hand back whatever it returns."""
        measure = self.feedback_fn
        return measure(choice)
37
+
38
+
39
class InconsistentMetadata(Exception):
    """
    Raised when AutoHeuristic is about to append data to an existing log file whose stored metadata does
    not match the metadata it would have written had the file not existed yet.
    """
44
+
45
+
46
+ class AutoHeuristic:
47
+ """
48
+ AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and
49
+ generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train
50
+ a heuristic (see torchgen/_autoheuristic/).
51
+ """
52
+
53
+ collected_feedback: Dict[Choice, Feedback]
54
+
55
def __init__(
    self,
    fallback: Callable[[], Choice],
    choices: List[Choice],
    feedback: Optional[LocalFeedback],
    context: AHContext,
    name: str,
    augment_context: Optional[List[AHOperation]] = None,
    precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
) -> None:
    """
    Set up an AutoHeuristic instance and, if data collection is enabled, collect feedback immediately.

    Args:
        fallback: Callable returning the Choice to use when the heuristic is unsure which choice to
            make, or when AutoHeuristic is in data collection mode.
        choices: The possible choices the heuristic can make.
        feedback: LocalFeedback instance providing feedback for a given choice, or None.
        context: Context stored alongside each choice and its feedback.
        name: String identifying the heuristic.
        augment_context: Optional list of AHOperation instances used to augment the context.
        precondition: Optional callable deciding whether AutoHeuristic should run at all.
    """
    self.fallback = fallback
    self.choices = choices
    self.feedback = feedback
    self.context = context
    self.name = name
    self.collected_feedback = {}
    self.augment_context = augment_context
    # Metadata describes the GPU (shared memory, capability) together with the choices and name.
    self.metadata = AHMetadata(
        get_gpu_shared_memory(),
        torch.cuda.get_device_capability(),
        self.choices,
        self.name,
    )
    self.precondition = precondition

    # When the precondition fails, no log path is resolved and no data is collected.
    if not self.satisfies_precondition():
        return

    inductor_config = torch._inductor.config
    configured_path = inductor_config.autoheuristic_log_path
    if configured_path == "DEFAULT":
        self.log_path = self.get_default_log_path()
    else:
        self.log_path = configured_path

    collecting = inductor_config.collect_autoheuristic(self.name)
    if collecting and self.feedback is not None:
        # Evaluate every choice right away and record its feedback.
        for candidate in self.choices:
            self.save_data(candidate, self.feedback(candidate))
106
+
107
def satisfies_precondition(self) -> bool:
    """Return True when no precondition is set; otherwise return the precondition's own result."""
    check = self.precondition
    if check is None:
        return True
    return check(self.metadata, self.context)
111
+
112
def get_choice(self) -> Choice:
    """
    Pick a choice, consulting a learned heuristic only when autoheuristic_use enables it for self.name.

    The fallback option is returned when the precondition is not satisfied, when the learned heuristic
    is not enabled for this name, or when the controller's decision is None or not one of self.choices.
    """
    if not self.satisfies_precondition():
        return self.fallback()
    if not torch._inductor.config.use_autoheuristic(self.name):
        return self.fallback()

    if self.augment_context is not None:
        self.context.apply_operations(self.augment_context)
    decision = LearnedHeuristicController(
        self.metadata,
        self.context,
    ).get_decision()
    # TODO(AlnisM): We might want to allow decisions outside self.choices in the future
    if decision not in self.choices or decision is None:
        return self.fallback()
    return decision
136
+
137
def get_top_k_choices(
    self, top_k: int, always_included: Optional[List[str]] = None
) -> Optional[List[Choice]]:
    """
    Ask the learned heuristic for its ranked decisions, limited by ``top_k``.

    Returns None when the precondition is not satisfied, when the learned heuristic is not enabled for
    self.name, or when the controller has no ranking. Otherwise returns the controller's list, with any
    entry of ``always_included`` that is missing appended at the end.
    """
    if not self.satisfies_precondition():
        return None
    if not torch._inductor.config.use_autoheuristic(self.name):
        return None

    if self.augment_context is not None:
        self.context.apply_operations(self.augment_context)
    ranked = LearnedHeuristicController(
        self.metadata,
        self.context,
    ).get_decisions_ranked(top_k)
    if ranked is None:
        return None

    if always_included is not None:
        for extra in always_included:
            if extra not in ranked:
                ranked.append(extra)
    return ranked
158
+
159
def get_collected_feedback(self, choice: Choice) -> Any:
    """Return the feedback stored for choice in self.collected_feedback, or None if absent."""
    recorded = self.collected_feedback
    return recorded.get(choice)
161
+
162
@staticmethod
def get_device_identifier() -> str:
    """Return the CUDA device name with every space replaced by an underscore."""
    # A heuristic might work well for one GPU, but not for another, so the
    # collected data is stored per GPU model and a heuristic is learned per GPU model.

    # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names
    return "_".join(torch.cuda.get_device_name().split(" "))
170
+
171
def get_default_log_path(self) -> str:
    """
    Return the default log file path for this heuristic.

    The file lives in "<cache_dir>/autoheuristic/<device identifier>/" and is
    named "<self.name>.txt"; the directory is created if it does not exist.
    """
    device_dir = self.get_device_identifier()
    log_dir = f"{cache_dir()}/autoheuristic/{device_dir}/"
    os.makedirs(log_dir, exist_ok=True)
    return f"{log_dir}{self.name}.txt"
177
+
178
def serialize_metadata(self) -> str:
    """
    Return the metadata line of the log as a JSON string.

    The dictionary produced by self.metadata.to_dict() is extended with the
    names of the numerical and categorical features of the context.
    """
    payload = self.metadata.to_dict()
    numerical, categorical = self.context.get_numerical_and_categorical_features()
    payload.update(
        numerical_features=numerical,
        categorical_features=categorical,
    )
    return json.dumps(payload)
187
+
188
def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
    """
    Remember feedback_val for choice and append it to the log file.

    A new log starts with the serialized metadata and a CSV header (feature
    names, choice column, feedback column). For an existing log the first
    line has to equal the serialized metadata.

    Raises:
        InconsistentMetadata: if the existing log was written with different metadata.
    """
    self.collected_feedback[choice] = feedback_val
    log_path = self.log_path

    if os.path.exists(log_path):
        # if log already exists, make sure it is consistent
        expected_metadata = self.serialize_metadata()
        if get_metadata_str_from_log(log_path) != expected_metadata:
            raise InconsistentMetadata(
                "Given metadata does not match existing metadata"
            )
        lines = []
    else:
        lines = [
            self.serialize_metadata(),
            ",".join(
                (self.context.get_feature_names_csv(), CHOICE_COL, FEEDBACK_COL)
            ),
        ]

    lines.append(
        ",".join((self.context.get_feature_values_csv(), choice, str(feedback_val)))
    )

    with open(log_path, "a") as f:
        f.write("\n".join(lines) + "\n")
215
+
216
+
217
class AutoHeuristicSelectAlgorithm(AutoHeuristic):
    """
    AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic
    when one wants to use AutoHeuristic for kernel choice selection.

    The base class works on choice strings; this class maps every ChoiceCaller to its
    autoheuristic_id() and translates decisions back to ChoiceCaller objects.
    """

    def __init__(
        self,
        fallback: Callable[[], Optional[ChoiceCaller]],
        choices: List[ChoiceCaller],
        input_nodes: List[Any],
        context: AHContext,
        name: str,
        augment_context: Optional[List[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        The arguments choices, input_nodes and name have to match the ones used in the call to
        autotune_select_algorithm(), e.g. if the following call is made
        autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes
        have to be used here.
        """
        self.input_nodes = input_nodes
        self.choicestr2choice: Dict[str, ChoiceCaller] = {
            caller.autoheuristic_id(): caller for caller in choices
        }

        def fallback_str() -> str:
            chosen = fallback()
            # TODO: Find a nicer way to handle this
            return "unsure" if chosen is None else chosen.autoheuristic_id()

        super().__init__(
            fallback_str,
            list(self.choicestr2choice),
            None,
            context,
            name,
            augment_context,
            precondition,
        )

        collecting = torch._inductor.config.collect_autoheuristic(self.name)
        if collecting and self.satisfies_precondition():
            self.register_global_feedback(input_nodes, choices)

    def register_global_feedback(
        self, input_nodes: List[Any], choices: List[ChoiceCaller]
    ) -> None:
        """
        Registers a callback in select_algorithm, which is called with the timing of each choice.
        """

        from torch._inductor.select_algorithm import (
            add_feedback_saver,
            create_inputs_key,
            create_precompile_key,
        )

        def store_global_feedback(
            expected_inputs_key: str,
            expected_precompile_key: str,
            timings: Dict[ChoiceCaller, float],
            name: str,
            input_nodes: List[Any],
            choices: List[ChoiceCaller],
        ) -> None:
            # Timings are only stored when both keys computed from the callback
            # arguments equal the keys computed at registration time.
            inputs_key_now = create_inputs_key(input_nodes)
            if inputs_key_now != expected_inputs_key:
                return
            precompile_key_now = create_precompile_key(name, inputs_key_now, choices)
            if precompile_key_now != expected_precompile_key:
                return
            for caller, elapsed in timings.items():
                self.save_data(caller.autoheuristic_id(), elapsed)

        inputs_key = create_inputs_key(input_nodes)
        precompile_key = create_precompile_key(self.name, inputs_key, choices)
        add_feedback_saver(partial(store_global_feedback, inputs_key, precompile_key))

    def get_choice_caller(self) -> Optional[ChoiceCaller]:
        """Return the ChoiceCaller for the result of get_choice(), or None if it maps to no caller."""
        return self.choicestr2choice.get(self.get_choice())

    def get_top_k_choices_caller(
        self, top_k: int, always_included: Optional[List[str]] = None
    ) -> Optional[List[ChoiceCaller]]:
        """Return the ChoiceCaller objects for get_top_k_choices(), or None if that returns None."""
        ranked = self.get_top_k_choices(top_k, always_included)
        if ranked is None:
            return None
        return [self.choicestr2choice[choice_id] for choice_id in ranked]
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import Any, Callable, Dict, List, Tuple
3
+
4
+ import torch
5
+
6
+
7
# Type aliases shared by the AutoHeuristic modules.
Feedback = float
Choice = str
Value = Any

# Names of the choice and feedback columns.
CHOICE_COL = "choice"
FEEDBACK_COL = "feedback"
13
+
14
+
15
class AHFeature:
    """
    One named feature of the context that AutoHeuristic stores.

    AutoHeuristic needs to know whether a feature is categorical (i.e., not a
    continuous variable) to learn a machine learning model, so the flag is
    kept together with the feature's name and value.
    """

    def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None:
        self.name, self.value, self.is_categorical = name, value, is_categorical
25
+
26
+
27
class AHOperation:
    """
    A named function that derives a new feature from already collected data.

    One might for example store features like m, k, n, but also want to use
    features like m*n, or k*n, to learn a heuristic. Instead of storing
    features that can be created from the collected data, an AHOperation
    creates them on demand.
    """

    def __init__(
        self, name: str, func: Callable[[Any], Value], is_categorical: bool = False
    ) -> None:
        self.name = name
        self.func = func
        self.is_categorical = is_categorical

    def apply_operation(self, data: Any) -> None:
        """Compute func(data) and store the result in data under this operation's name."""
        derived = self.func(data)
        data[self.name] = derived
45
+
46
+
47
class AHContext:
    """
    This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeuristic will
    store the context and the collected feedback. The context could be something like the shape of a tensor, i.e.,
    information that will help to learn a heuristic.
    """

    # Features in the order in which they were added.
    features: List[AHFeature]
    # Feature values by name; apply_operations() stores derived values here as well.
    context_dict: Dict[str, Value]

    def __init__(self) -> None:
        self.features = []
        self.context_dict = {}

    def add_feature(
        self, name: str, value: Value, is_categorical: bool = False
    ) -> None:
        """Append a feature and make its value available through get_value()."""
        feature = AHFeature(name, value, is_categorical=is_categorical)
        self.features.append(feature)
        self.context_dict[name] = value

    def get_numerical_and_categorical_features(self) -> Tuple[List[str], List[str]]:
        """Return the names of the non-categorical and of the categorical features."""
        numerical = [f.name for f in self.features if not f.is_categorical]
        categorical = [f.name for f in self.features if f.is_categorical]
        return numerical, categorical

    def get_feature_names_csv(self) -> str:
        """Return the feature names joined by commas."""
        names = [feature.name for feature in self.features]
        return ",".join(names)

    def get_feature_values_csv(self) -> str:
        """Return the str() of every feature value joined by commas."""
        values = [str(feature.value) for feature in self.features]
        return ",".join(values)

    def get_value(self, name: str) -> Value:
        """Return the value stored under name; raises KeyError for unknown names."""
        return self.context_dict[name]

    def apply_operations(self, operations: List[AHOperation]) -> None:
        """Apply every operation, in order, to context_dict."""
        for operation in operations:
            operation.apply_operation(self.context_dict)
90
+
91
+
92
class AHMetadata:
    """Metadata of one AutoHeuristic use: GPU identification, available choices and name."""

    def __init__(
        self,
        shared_memory: Any,
        device_capa: Tuple[int, int],
        choices: List[Choice],
        name: str,
    ) -> None:
        # use amount of shared_memory and device_capability to identify GPU
        # TODO(AlnisM): there might be a better way to do this
        self.shared_memory = shared_memory
        self.device_capa = device_capa
        self.choices = choices
        self.name = name

    def to_dict(self) -> Dict[str, Value]:
        """Return shared_memory, device_capa and name as a dictionary; choices are not included."""
        exported = ("shared_memory", "device_capa", "name")
        return {key: getattr(self, key) for key in exported}
113
+
114
+
115
def get_metadata_str_from_log(log_path: str) -> str:
    """Return the first line of the file at log_path with surrounding whitespace stripped."""
    with open(log_path, newline="") as log_file:
        return log_file.readline().strip()
119
+
120
+
121
def check_minsize(context: AHContext, minsize: int) -> bool:
    """Return True when the context values "m", "k" and "n" are all at least minsize."""
    return all(context.get_value(dim) >= minsize for dim in ("m", "k", "n"))
127
+
128
+
129
def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
    """
    Precondition for pad_mm.

    For the two GPUs recognized by shared memory size and device capability,
    m, k and n must reach a minimum size; any other GPU always passes.
    """
    gpu = (metadata.shared_memory, metadata.device_capa)
    if gpu == (166912, (8, 0)):
        # A100 precondition
        return check_minsize(context, 512)
    if gpu == (232448, (9, 0)):
        # H100 precondition
        return check_minsize(context, 768)
    return True
137
+
138
+
139
def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
    """
    Precondition for mixed_mm.

    Requires m <= 128, k >= 1024 and n >= 1024, and additionally that
    "mat1_iscontig" is truthy while "mat2_iscontig" is not. metadata is unused.
    """
    m, k, n = (context.get_value(dim) for dim in ("m", "k", "n"))
    if m > 128 or k < 1024 or n < 1024:
        return False
    first_contig = context.get_value("mat1_iscontig")
    second_contig = context.get_value("mat2_iscontig")
    return first_contig and not second_contig
148
+
149
+
150
def get_mult_dims_ops() -> List[AHOperation]:
    """Return the operations "m*k", "m*n" and "k*n" that multiply two dimensions."""

    def product_op(lhs: str, rhs: str) -> AHOperation:
        # lhs and rhs are parameters, so every lambda captures its own pair
        return AHOperation(f"{lhs}*{rhs}", lambda data: data[lhs] * data[rhs])

    return [product_op("m", "k"), product_op("m", "n"), product_op("k", "n")]
155
+
156
+
157
def get_arith_intensity(data: Any) -> float:
    """Return m*k*n / (m*k + k*n + m*n), or 0.0 if any of m, k, n equals zero."""
    m, k, n = data["m"], data["k"], data["n"]
    if 0 in (m, k, n):
        # the denominator could be zero as well, so this case is answered directly
        return 0.0
    return m * k * n / (m * k + k * n + m * n)
164
+
165
+
166
def pad_mm_operations() -> List[AHOperation]:
    """
    Return the operations that derive the additional features for pad_mm.

    The order is: dimension products, "k/(m*n)", "bfloat_perf_hit",
    "arith_intensity", the padding operations, the multiple-of operations
    and the contiguity operations.
    """

    def k_div_m_times_n(data: Any) -> float:
        return data["k"] / (data["m"] * data["n"])

    def bfloat_perf_hit(data: Any) -> bool:
        m = data["m"]
        k = data["k"]
        n = data["n"]
        is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16"
        return k > (m * 1024) and k > (n * 1024) and is_bfloat

    return [
        *get_mult_dims_ops(),
        AHOperation("k/(m*n)", k_div_m_times_n),
        AHOperation("bfloat_perf_hit", bfloat_perf_hit, is_categorical=True),
        AHOperation("arith_intensity", get_arith_intensity),
        *get_dims_need_padding_ops(),
        *get_dims_multiple_ops(),
        *get_is_contig_ops(),
    ]
197
+
198
+
199
def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
    """Return True when data[dim] lies in the inclusive range [lower, upper]."""
    value = data[dim]
    return value >= lower and value <= upper
201
+
202
+
203
def between_ops() -> List[AHOperation]:
    """
    Return one categorical operation per dimension (m, k, n) and size range.

    Each operation checks whether the dimension lies within the inclusive range.
    """
    limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
    # using 'LEQ' instead of '<=' because '<=' cannot be exported to dot
    return [
        AHOperation(
            f"{lower}LEQ{dim}LEQ{upper}",
            functools.partial(between_op, dim=dim, lower=lower, upper=upper),
            is_categorical=True,
        )
        for dim in ("m", "k", "n")
        for lower, upper in limits
    ]
218
+
219
+
220
def pow2_op(data: Any, dim: str, exponent: int) -> bool:
    """Return True when data[dim] equals 2 raised to exponent."""
    target = pow(2, exponent)
    return data[dim] == target
222
+
223
+
224
def mm_operations() -> List[AHOperation]:
    """Return the dimension product operations followed by "arith_intensity"."""
    return [
        *get_mult_dims_ops(),
        AHOperation("arith_intensity", get_arith_intensity),
    ]
228
+
229
+
230
def mixed_mm_operations() -> List[AHOperation]:
    """Return the operations of mm_operations() followed by those of between_ops()."""
    return [*mm_operations(), *between_ops()]
232
+
233
+
234
def is_multiple(data: Any, dim: str, mult: int) -> bool:
    """Return True when data[dim] is divisible by mult without remainder."""
    remainder = data[dim] % mult
    return remainder == 0
236
+
237
+
238
def get_dims_multiple_ops() -> List[AHOperation]:
    """
    Return categorical "<dim>_multiple_<mult>" operations for m, k and n.

    Each operation checks whether the dimension is a multiple of 2, 4, 8, 16 or 32.
    """
    multiples = [2, 4, 8, 16, 32]
    return [
        AHOperation(
            f"{dim}_multiple_{mult}",
            functools.partial(is_multiple, dim=dim, mult=mult),
            is_categorical=True,
        )
        for dim in ("m", "k", "n")
        for mult in multiples
    ]
250
+
251
+
252
def get_dims_need_padding_ops() -> List[AHOperation]:
    """
    Return the operations that describe the padding needs of a matmul.

    "mat1_innermost_needs_padding" and "mat2_innermost_needs_padding" are
    categorical; "num_dims_needs_padding" counts the non-zero padded lengths.
    """

    def mat1_innermost_needs_padding_fn(data: Any) -> bool:
        mat1_stride_0 = data["mat1_stride_0"]
        mat1_stride_1 = data["mat1_stride_1"]
        m_padded_length = data["m_padded_length"]
        k_padded_length = data["k_padded_length"]
        # a dimension with stride 1 and a non-zero padded length needs padding
        return bool(
            (mat1_stride_0 == 1 and m_padded_length != 0)
            or (mat1_stride_1 == 1 and k_padded_length != 0)
        )

    def mat2_innermost_needs_padding_fn(data: Any) -> bool:
        mat2_stride_0 = data["mat2_stride_0"]
        mat2_stride_1 = data["mat2_stride_1"]
        k_padded_length = data["k_padded_length"]
        n_padded_length = data["n_padded_length"]
        return bool(
            (mat2_stride_0 == 1 and k_padded_length != 0)
            or (mat2_stride_1 == 1 and n_padded_length != 0)
        )

    def num_dims_needs_padding_fn(data: Any) -> int:
        padded_lengths = (
            data["m_padded_length"],
            data["k_padded_length"],
            data["n_padded_length"],
        )
        return sum(1 for length in padded_lengths if length != 0)

    return [
        AHOperation(
            "mat1_innermost_needs_padding",
            mat1_innermost_needs_padding_fn,
            is_categorical=True,
        ),
        AHOperation(
            "mat2_innermost_needs_padding",
            mat2_innermost_needs_padding_fn,
            is_categorical=True,
        ),
        AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn),
    ]
304
+
305
+
306
def get_is_contig_ops() -> List[AHOperation]:
    """
    Return the categorical operations "mat1_iscontig" and "mat2_iscontig".

    mat1 counts as contiguous when its strides are (k, 1), mat2 when its
    strides are (n, 1).
    """

    def mat1_is_contig_fn(data: Any) -> bool:
        outer_stride = data["mat1_stride_0"]
        inner_stride = data["mat1_stride_1"]
        row_length = data["k"]
        return outer_stride == row_length and inner_stride == 1

    def mat2_is_contig_fn(data: Any) -> bool:
        outer_stride = data["mat2_stride_0"]
        inner_stride = data["mat2_stride_1"]
        row_length = data["n"]
        return outer_stride == row_length and inner_stride == 1

    return [
        AHOperation("mat1_iscontig", mat1_is_contig_fn, is_categorical=True),
        AHOperation("mat2_iscontig", mat2_is_contig_fn, is_categorical=True),
    ]
328
+
329
+
330
def context_add_strides(context: AHContext, name: str, stride: Tuple[int, ...]) -> None:
    """Add one feature "<name>_stride_<i>" to context for every entry i of stride."""
    for dim_index, dim_stride in enumerate(stride):
        feature_name = f"{name}_stride_{dim_index}"
        context.add_feature(feature_name, dim_stride)
333
+
334
+
335
def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None:
    """
    Add the categorical feature "using_tf32" to context.

    For torch.float32 the value is torch.backends.cuda.matmul.allow_tf32; for
    every other dtype it is the string "not_float_32".
    """
    if dtype == torch.float32:
        using_tf32 = torch.backends.cuda.matmul.allow_tf32
    else:
        using_tf32 = "not_float_32"
    context.add_feature("using_tf32", using_tf32, is_categorical=True)
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import inspect
3
+ import pkgutil
4
+ from collections import defaultdict
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic
13
+
14
+
15
def find_and_instantiate_subclasses(
    package_name: str, base_class: Any
) -> List[LearnedHeuristic]:
    """
    Instantiate every strict subclass of base_class found in a package.

    Only modules whose basename starts with an underscore are imported. Each
    matching class is instantiated without arguments. An exception raised while
    processing a module is printed and the module is skipped.
    """

    def is_strict_subclass(obj: Any) -> bool:
        return (
            inspect.isclass(obj)
            and issubclass(obj, base_class)
            and obj != base_class
        )

    instances = []

    package = importlib.import_module(package_name)
    prefix = package.__name__ + "."
    for _, module_name, _ in pkgutil.walk_packages(package.__path__, prefix):
        try:
            # learned heuristics start with an underscore
            if not module_name.split(".")[-1].startswith("_"):
                continue
            module = importlib.import_module(module_name)

            # look for classes that are subclasses of base_class
            for _, obj in inspect.getmembers(module, is_strict_subclass):
                instances.append(obj())
        except Exception as e:
            print(f"Error processing module {module_name}: {e}")

    return instances
44
+
45
+
46
class LearnedHeuristicController:
    """
    Class that finds and instantiates all learned heuristics. It also provides
    a way to get the decision of a learned heuristic.
    """

    # A dictionary that stores all the learned heuristics for each optimization.
    # The key is the optimization name, and the value is a list of LearnedHeuristic objects.
    existing_heuristics: Dict[str, List[LearnedHeuristic]] = defaultdict(list)

    # A flag that indicates whether the learned heuristics have been initialized.
    # Set to True by the first call of get_heuristics().
    heuristics_initialized: bool = False

    def __init__(
        self,
        metadata: AHMetadata,
        context: AHContext,
    ) -> None:
        self.metadata = metadata
        self.context = context

    def get_heuristics(self, name: str) -> List[LearnedHeuristic]:
        """
        Returns a list of learned heuristics for the given optimization name.
        """
        controller = LearnedHeuristicController
        if not controller.heuristics_initialized:
            # learned heuristics are generated into this package and have to be
            # of type LearnedHeuristic
            found_heuristics = find_and_instantiate_subclasses(
                "torch._inductor.autoheuristic.artifacts", LearnedHeuristic
            )
            for learned_heuristic in found_heuristics:
                controller.existing_heuristics[learned_heuristic.get_name()].append(
                    learned_heuristic
                )
            controller.heuristics_initialized = True

        return controller.existing_heuristics[name]

    def _first_applicable_heuristic(self) -> Optional[LearnedHeuristic]:
        """Return the first heuristic for metadata.name whose precondition holds, or None."""
        for heuristic in self.get_heuristics(self.metadata.name):
            if heuristic.check_precondition(self.metadata, self.context):
                return heuristic
        return None

    def get_decision(self) -> Optional[Choice]:
        """
        Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure
        which choice to make.
        """
        heuristic = self._first_applicable_heuristic()
        if heuristic is None:
            return None
        return heuristic.get_decision(self.context, self.metadata.choices)

    def get_decisions_ranked(self, top_k: int) -> Optional[List[Choice]]:
        """
        Returns at most top_k ranked choices of the learned heuristic, restricted to metadata.choices,
        or None if no heuristic was found or the heuristic provides no ranking.
        """
        heuristic = self._first_applicable_heuristic()
        if heuristic is None:
            return None
        ranked = heuristic.get_decisions_ranked(self.context)
        if ranked is None:
            return None
        available = [choice for choice in ranked if choice in self.metadata.choices]
        return available[:top_k]
.venv/lib/python3.11/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+
3
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
4
+ AHContext,
5
+ AHMetadata,
6
+ Choice,
7
+ )
8
+
9
+
10
class LearnedHeuristic:
    """
    LearnedHeuristic is a base class for all learned heuristics.

    Every method provides a neutral default that subclasses override.
    """

    def __init__(self) -> None:
        """Nothing to set up in the base class."""

    def check_precondition(
        self,
        metadata: AHMetadata,
        context: AHContext,
    ) -> bool:
        """Whether this heuristic applies to metadata and context; the default is True."""
        return True

    def get_decision(
        self, context: AHContext, choices: List[Choice]
    ) -> Optional[Choice]:
        """The chosen option, or None for no decision; the default is None."""
        return None

    def get_confidence_threshold(self) -> float:
        """The threshold a prediction has to exceed to be trusted; the default is 1.0."""
        return 1.0

    def get_name(self) -> str:
        """The name of the optimization this heuristic belongs to; the default is ""."""
        return ""

    def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
        """The choices ordered by preference, or None if unsupported; the default is None."""
        return None
38
+
39
+
40
+ class LearnedHeuristicRegression(LearnedHeuristic):
41
+ def __init__(self) -> None:
42
+ super().__init__()
43
+
44
+ def get_feedback(self, context: AHContext, choice: Choice) -> float:
45
+ return 1.0
46
+
47
+ def get_decision(
48
+ self, context: AHContext, choices: List[Choice]
49
+ ) -> Optional[Choice]:
50
+ choice2feedback = {}
51
+ for choice in choices:
52
+ predicted_feedback = self.get_feedback(context, choice)
53
+ choice2feedback[choice] = predicted_feedback
54
+ sorted_choices_feedback = sorted(choice2feedback.items(), key=lambda t: t[1])
55
+ highest_feedback = sorted_choices_feedback[-1][1]
56
+ second_highest_feedback = sorted_choices_feedback[-2][1]
57
+ if highest_feedback / second_highest_feedback > self.get_confidence_threshold():
58
+ return sorted_choices_feedback[-1][0]
59
+ # We are not sure which choice is the best one
60
+ return None
61
+
62
+
63
class LearnedHeuristicDecision(LearnedHeuristic):
    """
    Learned heuristic that ranks choice indices.

    Subclasses override get_best_choices(), which yields (value, index) pairs,
    and get_choice(), which maps an index to a choice string.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_choice(self, idx: int) -> Optional[str]:
        """The choice for index idx, or None if there is none; the default is None."""
        return None

    def get_decision(
        self, context: AHContext, choices: List[Choice]
    ) -> Optional[Choice]:
        """
        Return the choice of the first entry of get_best_choices(), or None.

        None is returned when there are no entries or when the value of the
        first entry does not exceed get_confidence_threshold().
        """
        ranked = self.get_best_choices(context)
        if not ranked:
            return None
        top_proba, top_idx = ranked[0]
        if top_proba <= self.get_confidence_threshold():
            return None
        return self.get_choice(top_idx)

    def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
        """
        Return the choices of all entries of get_best_choices() in their given order.

        Entries whose index maps to no choice are dropped. None is returned
        when get_best_choices() returns None.
        """
        ranked = self.get_best_choices(context)
        if ranked is None:
            return None
        resolved = [self.get_choice(entry[1]) for entry in ranked]
        return [choice for choice in resolved if choice is not None]

    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
        """The (value, index) pairs of the candidate choices; the default is an empty list."""
        return []
.venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py ADDED
@@ -0,0 +1,876 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import contextlib
5
+ import ctypes
6
+ import dataclasses
7
+ import functools
8
+ import logging
9
+ import os
10
+ import queue
11
+ import time
12
+ import warnings
13
+ from concurrent.futures import ThreadPoolExecutor
14
+ from ctypes import byref, c_size_t, c_void_p, CDLL
15
+ from typing import (
16
+ Any,
17
+ Callable,
18
+ Dict,
19
+ Iterable,
20
+ List,
21
+ Optional,
22
+ Sequence,
23
+ TYPE_CHECKING,
24
+ Union,
25
+ )
26
+
27
+ import torch
28
+ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
29
+ from torch import multiprocessing
30
+ from torch._dynamo.testing import rand_strided
31
+ from torch._inductor import ir
32
+ from torch._inductor.codecache import (
33
+ CppCodeCache,
34
+ CUDACodeCache,
35
+ DLLWrapper,
36
+ get_hash,
37
+ PyCodeCache,
38
+ )
39
+
40
+
41
+ if TYPE_CHECKING:
42
+ from multiprocessing.process import BaseProcess
43
+ from multiprocessing.queues import Queue
44
+ from types import ModuleType
45
+
46
+ from torch._inductor.select_algorithm import TritonTemplateCaller
47
+
48
+ from . import config
49
+ from .runtime.benchmarking import benchmarker
50
+ from .virtualized import V
51
+
52
+
53
+ CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
54
+ EXIT_HANDLER_REGISTERED = False
55
+
56
+ log = logging.getLogger(__name__)
57
+
58
+
59
+ # Used to synchronize between parent and child processes
60
class Ping:
    """Request marker: the child's work loop answers it by putting a Pong on the response queue."""
62
+
63
+
64
class Pong:
    """Response marker that the child's work loop puts on the response queue for a Ping."""
66
+
67
+
68
class NonzeroWorkspaceNotSupportedError(Exception):
    """
    Error type without behavior beyond Exception.

    NOTE(review): presumably raised when a kernel needs a nonzero workspace - confirm at the raise sites.
    """
70
+
71
+
72
@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
    """
    Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the
    specified single device. If device is None, don't manipulate the environment.

    On exit the previous value is restored; if the variable was not set before,
    it is removed again.
    """
    if device is None:
        yield
        return

    previous = os.environ.get(CUDA_VISIBLE_DEVICES)
    os.environ[CUDA_VISIBLE_DEVICES] = str(device)
    try:
        yield
    finally:
        if previous is not None:
            os.environ[CUDA_VISIBLE_DEVICES] = previous
        else:
            del os.environ[CUDA_VISIBLE_DEVICES]
91
+
92
+
93
+ @dataclasses.dataclass
94
+ class TuningProcess:
95
+ """
96
+ Abstraction for launching a helper process to benchmark kernels. Spawns
97
+ the parent process and uses multiprocessing queues to send benchmark
98
+ requests and return results.
99
+ """
100
+
101
+ device: Optional[int] = None
102
+ process: Optional[BaseProcess] = None
103
+ request_queue: Optional[Queue[Any]] = None
104
+ response_queue: Optional[Queue[Any]] = None
105
+
106
+ @staticmethod
107
+ def process_main(
108
+ request_queue: Queue[Any],
109
+ response_queue: Queue[Any],
110
+ ) -> None:
111
+ """
112
+ Entry point for the child process.
113
+ """
114
+ log.debug(
115
+ "Entering TuningProcess child. Visible devices = %s",
116
+ os.environ.get(CUDA_VISIBLE_DEVICES),
117
+ )
118
+ try:
119
+ TuningProcess.workloop(request_queue, response_queue)
120
+ except Exception as ex:
121
+ log.exception("Exception in TuningProcess")
122
+
123
+ @staticmethod
124
+ def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
125
+ """
126
+ Work loop for the benchmarking subprocess.
127
+ """
128
+ while True:
129
+ obj = request_queue.get()
130
+
131
+ if obj is None:
132
+ break # None is a sentinel for the child to terminate
133
+ elif isinstance(obj, Ping):
134
+ response_queue.put(Pong())
135
+ elif isinstance(obj, BenchmarkRequest):
136
+ response_queue.put(obj.benchmark())
137
+ else:
138
+ raise RuntimeError(f"Invalid request type {type(obj)}")
139
+
140
+ def valid(self) -> bool:
141
+ """
142
+ True if the sub-process has been initialized.
143
+ """
144
+ return (
145
+ self.process is not None
146
+ and self.request_queue is not None
147
+ and self.response_queue is not None
148
+ )
149
+
150
+ def clear(self) -> None:
151
+ """
152
+ Reset to an uninitialized state.
153
+ """
154
+ self.process = self.request_queue = self.response_queue = None
155
+
156
+ def initialize(self) -> None:
157
+ """
158
+ Create child process, request/response queues, and do the warm up.
159
+ Set the environment to make only the provided GPU device visible
160
+ to the process.
161
+ """
162
+ if self.valid():
163
+ return
164
+
165
+ # cuda runtime does not work with "fork", use "spawn" to start processes.
166
+ ctx = multiprocessing.get_context("spawn")
167
+ self.request_queue = ctx.Queue()
168
+ self.response_queue = ctx.Queue()
169
+
170
+ self.process = ctx.Process(
171
+ target=self.process_main,
172
+ args=(
173
+ self.request_queue,
174
+ self.response_queue,
175
+ ),
176
+ )
177
+ assert self.process is not None
178
+ with set_cuda_visible_device(self.device):
179
+ self.process.start()
180
+
181
+ def put(self, obj: Any) -> None:
182
+ """
183
+ Push a work item to the child process.
184
+ """
185
+ # In case of a prior crash, ensure the subprocess is running
186
+ self.initialize()
187
+ assert self.request_queue is not None
188
+ self.request_queue.put(obj)
189
+
190
+ def get(
191
+ self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0
192
+ ) -> Any:
193
+ """
194
+ Get a response from the child process. Raises queue.Empty on timeout
195
+ or if the process dies.
196
+
197
+ This method is (so far) only used by TuningProcessPool, where torch._inductor.config entries are being used
198
+ to populate the timeouts:
199
+
200
+ Arguments:
201
+
202
+ @param result_timeout: Timeout in seconds, defaults to 120.0 or to
203
+ config.max_autotune_subproc_result_timeout_seconds when called by TuningProcessPool
204
+ @param graceful_timeout: Timeout in seconds to allow graceful shutdown (SIGTERM is sent after this time).
205
+ Defaults to 3.0 or to config.max_autotune_subproc_graceful_timeout_seconds
206
+ @param terminate_timeout: Timeout in seconds after SIGTERM, until we send SIGKILL if the process
207
+ remains alive. Defaults to 1.0 or to
208
+ config.max_autotune_subproc_terminate_timeout_seconds.
209
+ Returns:
210
+ A response from the child process (Any type)
211
+ """
212
+ assert self.process is not None
213
+ assert self.response_queue is not None
214
+ while True:
215
+ try:
216
+ remaining_timeout = result_timeout
217
+ res = None
218
+ while remaining_timeout is not None and remaining_timeout >= 1.0:
219
+ remaining_timeout -= 0.5
220
+ try:
221
+ res = self.response_queue.get(timeout=0.5)
222
+ break
223
+ except queue.Empty:
224
+ if not self.process.is_alive():
225
+ raise # is being caught a few lines below
226
+ if res is None:
227
+ res = self.response_queue.get(timeout=remaining_timeout)
228
+ return res
229
+ except queue.Empty:
230
+ status = self.process.exitcode
231
+ if status is None:
232
+ self.kill(
233
+ graceful_timeout=graceful_timeout,
234
+ terminate_timeout=terminate_timeout,
235
+ )
236
+ else:
237
+ # child process crashed
238
+ self.clear()
239
+ raise
240
+
241
+ def terminate(self) -> None:
242
+ """
243
+ Signal the child process to terminate.
244
+ """
245
+ if self.valid():
246
+ assert self.process is not None
247
+ assert self.request_queue is not None
248
+ self.request_queue.put(None)
249
+
250
+ def wait(self) -> None:
251
+ """
252
+ Wait for the child process to exit.
253
+ """
254
+ if self.process is not None:
255
+ self.process.join()
256
+ self.clear()
257
+
258
+ def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None:
259
+ # Tries to kill the process, using a graceful_timeout in which the process
260
+ # is allowed to exit gracefully. If the process is still alive,
261
+ # it will be terminated. If that is not sufficient to end it
262
+ # within terminate_timeout seconds, it will be killed.
263
+ if self.process is not None:
264
+ self.terminate()
265
+ self.process.join(timeout=graceful_timeout)
266
+ if self.process.is_alive():
267
+ log.warning(
268
+ "Sending SIGTERM to process with PID %d",
269
+ self.process.pid,
270
+ )
271
+ self.process.terminate()
272
+ self.process.join(timeout=terminate_timeout)
273
+ if self.process.is_alive():
274
+ log.error(
275
+ "Sending SIGKILL to process with PID %d",
276
+ self.process.pid,
277
+ )
278
+ self.process.kill() # This should definitely end the process
279
+ self.clear()
280
+
281
+
282
@dataclasses.dataclass
class TuningProcessPool:
    """
    Maintains a pool of TuningProcesses to benchmark kernels in parallel
    across devices. By default, we create one TuningProcess per device and
    set the sub-process environment to make only that device visible.
    """

    # Idle TuningProcess objects; helper threads take one, use it, put it back.
    processes: Optional[queue.Queue[TuningProcess]] = None
    # Thread pool that fans benchmark requests out to the sub-processes.
    executor: Optional[ThreadPoolExecutor] = None

    def initialize(self) -> None:
        """
        Start the child processes. No-op if the pool is already initialized.
        """
        assert (self.processes is None) == (self.executor is None)
        if self.processes is not None:
            return

        devices = self.get_device_list()
        log.debug("Sub-process autotune device list: %s", devices)

        # Launch the child processes and push a msg to "warm up"
        self.processes = queue.Queue()
        for device in devices:
            p = TuningProcess(device=device)
            p.initialize()
            p.put(Ping())
            self.processes.put(p)

        # Wait for the initialization to finish. The response is fetched
        # outside the assert so the Pong is still drained from the queue when
        # asserts are stripped (python -O); otherwise it would be returned as
        # the result of the first benchmark request.
        for p in self.processes.queue:
            response = p.get(result_timeout=None)
            assert isinstance(response, Pong)

        # Use a thread pool to manage distributing work to the subprocesses.
        # Threads block on an available process, so it makes sense to match
        # the number of threads with the number of devices.
        self.executor = ThreadPoolExecutor(max_workers=len(devices))

        # Register the exit handler for the parent process so it will terminate
        # the child processes.
        global EXIT_HANDLER_REGISTERED
        if not EXIT_HANDLER_REGISTERED:
            EXIT_HANDLER_REGISTERED = True
            import atexit

            atexit.register(self.terminate)

    def get_device_list(self) -> Sequence[Optional[int]]:
        """
        Gather the list of devices to be used in the pool.

        Returns [None] when config.autotune_multi_device is off, the indices
        parsed from CUDA_VISIBLE_DEVICES when that variable is set, and
        otherwise every index reported by torch.cuda.device_count().
        """
        if not config.autotune_multi_device:
            # Don't use multiple devices
            return [None]

        count = torch.cuda.device_count()

        # If the user specified the visible devices in the env, use those.
        if CUDA_VISIBLE_DEVICES in os.environ:
            devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
            assert len(devices) <= count
            return devices

        return list(range(count))

    def terminate(self) -> None:
        """
        Signal all child processes to terminate.
        """
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None

        if self.processes is not None:
            # Signal everyone first, then wait, so the children exit in parallel.
            for p in self.processes.queue:
                p.terminate()
            for p in self.processes.queue:
                p.wait()
            self.processes = None

    def target(self, choice: TritonTemplateCaller) -> float:
        """
        Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
        remove it from the queue, execute the benchmark in that subprocess, and return
        the TuningProcess to the queue.

        Returns float("inf") (after emitting a warning) when no result arrives.
        """
        assert choice.bmreq is not None
        assert self.processes is not None

        process = self.processes.get()
        try:
            # put() is inside the try so the process is handed back to the
            # pool even if sending the request fails.
            process.put(choice.bmreq)
            return process.get(
                config.max_autotune_subproc_result_timeout_seconds,
                config.max_autotune_subproc_graceful_timeout_seconds,
                config.max_autotune_subproc_terminate_timeout_seconds,
            )
        except queue.Empty:
            warnings.warn(
                f"Failed to benchmark choice '{choice}'. It will be ignored. "
                "Please debug the root cause in case the choice can bring perf gains."
            )
            # set to INF so this choice will be ignored
            return float("inf")
        finally:
            self.processes.put(process)

    def benchmark(
        self,
        choices: List[TritonTemplateCaller],
    ) -> Dict[TritonTemplateCaller, float]:
        """
        Benchmark each choice in a separate process.

        Returns a dict mapping every choice to the value produced by target().
        """
        assert self.processes is not None, "Tuning process pool is not initialized"
        assert self.executor is not None

        # Use a ThreadExecutorPool to spread the work across the subprocesses and
        # to grab subprocesses as soon as they're free.
        return dict(zip(choices, self.executor.map(self.target, choices)))
408
+
409
+
410
# Module-level pool instance; benchmark_in_sub_process() delegates to it.
tuning_pool = TuningProcessPool()


# Either an IR layout or an IR buffer; the input type of TensorMeta.from_irnodes.
LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
414
+
415
+
416
@dataclasses.dataclass
class TensorMeta:
    """
    Plain description of a tensor (device, dtype, sizes, strides, offset,
    name) from which to_tensor() can build a tensor via rand_strided.
    """

    device: torch.device
    dtype: torch.dtype
    sizes: torch._prims_common.ShapeType
    strides: torch._prims_common.StrideType
    offset: int
    name: Optional[str] = None

    @classmethod
    def from_irnodes(
        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
    ) -> Union[TensorMeta, List[TensorMeta]]:
        """
        Build a TensorMeta from one IR node, or a list of them from a sequence.

        A bare ir.Layout is first wrapped in a buffer named "fake". Sizes,
        strides and offset are resolved through V.graph.sizevars with
        config.unbacked_symint_fallback as the fallback.
        """
        if isinstance(irnodes, Sequence):
            metas: List[Any] = []
            for item in irnodes:
                metas.append(cls.from_irnodes(item))
            assert all(isinstance(meta, TensorMeta) for meta in metas)
            return metas

        node = ir.Buffer("fake", irnodes) if isinstance(irnodes, ir.Layout) else irnodes

        dtype = node.get_dtype()
        assert dtype is not None

        device = node.get_device()
        sizes = V.graph.sizevars.size_hints(
            node.get_size(),
            fallback=config.unbacked_symint_fallback,
        )
        strides = V.graph.sizevars.size_hints(
            node.get_stride(),
            fallback=config.unbacked_symint_fallback,
        )
        offset = V.graph.sizevars.size_hint(
            node.get_layout().offset,
            fallback=config.unbacked_symint_fallback,
        )
        return TensorMeta(
            device=device,
            dtype=dtype,
            sizes=sizes,
            strides=strides,
            offset=offset,
            name=node.get_name(),
        )

    def to_tensor(self) -> torch.Tensor:
        """
        Create a tensor with this metadata via rand_strided; the offset is
        passed as extra_size.
        """
        tensor = rand_strided(
            self.sizes,
            self.strides,
            device=self.device,
            dtype=self.dtype,
            extra_size=self.offset,
        )
        return tensor
467
+
468
+
469
@dataclasses.dataclass
class BenchmarkRequest:
    """
    Only handle triton template benchmark for now. The extern kernel benchmark
    can be done inside the same process since they usually don't cause crash.

    Important: Instances of this class and subclasses have to be serializable
    across process boundaries. Do not put CUDA Tensors in here!
    """

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
    ) -> None:
        # the kernel name defined in the module
        self.kernel_name = kernel_name

        # Inputs are always stored as a list of TensorMeta.
        if isinstance(input_tensor_meta, TensorMeta):
            input_tensor_meta = [input_tensor_meta]
        self.input_tensor_meta = input_tensor_meta

        # Exactly one output is supported; a one-element list/tuple is unwrapped.
        if isinstance(output_tensor_meta, (tuple, list)):
            assert len(output_tensor_meta) == 1
            output_tensor_meta = output_tensor_meta[0]
        self.output_tensor_meta = output_tensor_meta

        self.extra_args = extra_args

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """
        Return a zero-argument callable that runs the kernel once on the given
        tensors. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def cleanup_run_fn(self) -> None:
        """
        Hook called after benchmarking; subclasses release what make_run_fn
        acquired. The base implementation does nothing.
        """

    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Time `fn` and return the measurement. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def benchmark(
        self,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Benchmark the kernel and return the value produced by do_bench().

        When output_tensor is None, no input tensors may be given and all
        tensors are created from the stored TensorMeta. Returns float("inf")
        when make_run_fn raises NonzeroWorkspaceNotSupportedError.
        """
        debug = log.isEnabledFor(logging.DEBUG)
        if debug:
            start_ts = time.time()

        # create args and out tensor
        if output_tensor is None:
            assert len(input_tensors) == 0
            input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
            output_tensor = self.output_tensor_meta.to_tensor()

        if debug:
            create_tensor_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()
        try:
            fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
        except NonzeroWorkspaceNotSupportedError:
            # Skipping all ops with nonzero workspace requirements
            log.info("Skipping op due to nonzero workspace requirement")
            return float("inf")

        if debug:
            load_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()

        try:
            # output_tensor is keyword-only in do_bench's signature.
            out = self.do_bench(fn, *input_tensors, output_tensor=output_tensor)

            if debug:
                bench_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
                log.debug(
                    "InChildProcess %s: load %f, create tensor %f, bench %f",
                    str(self),
                    load_elapse,  # type: ignore[possibly-undefined]
                    create_tensor_elapse,  # type: ignore[possibly-undefined]
                    bench_elapse,
                )
        finally:
            # Run the cleanup hook even when benchmarking raises.
            self.cleanup_run_fn()
        return out
558
+
559
+
560
class TestBenchmarkRequest(BenchmarkRequest):
    """
    Benchmark request used by unit tests. It lives in this file so that the
    TuningProcess sub-process is able to unpickle instances of it.
    """

    def __init__(self, value: Optional[float] = None) -> None:
        # Deliberately does not call the base __init__; only `value` is stored.
        self.value = value

    def benchmark(
        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
    ) -> float:
        """
        Return the canned value, or raise Exception when it is None.
        """
        canned = self.value
        if canned is not None:
            return canned
        raise Exception("Failed to run")  # noqa: TRY002
575
+
576
+
577
class GPUDeviceBenchmarkRequest(BenchmarkRequest):
    """
    Benchmark request whose timing is done with benchmarker.benchmark_gpu.
    """

    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Benchmark `fn` under the CUDA device of the given tensors and return
        the benchmarker's result. Falls back to the current CUDA device when
        no tensor carries a CUDA device index.
        """
        # Collect the distinct CUDA device indices among all tensors.
        device_idx_set = set()
        for tensor in (*input_tensors, output_tensor):
            if not isinstance(tensor, torch.Tensor):
                continue
            if tensor.is_cuda and tensor.device.index is not None:
                device_idx_set.add(tensor.device.index)

        assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}"
        if device_idx_set:
            (device_idx,) = device_idx_set
        else:
            device_idx = torch.cuda.current_device()

        with torch.cuda.device(device_idx):
            elapsed = benchmarker.benchmark_gpu(fn)
            torch.cuda.synchronize()  # shake out any CUDA errors

        return elapsed
602
+
603
+
604
class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest):
    """
    Benchmark request for a triton kernel loaded through PyCodeCache.
    """

    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!
    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        module_path: str,  # the path of the module defining the triton kernel
        module_cache_key: str,
        grid: List[int],
        num_stages: int,
        num_warps: int,
        matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.module_path = module_path
        self.module_cache_key = module_cache_key
        self.grid = grid
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.matrix_instr_nonkdim = matrix_instr_nonkdim

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """
        Load the kernel module and return a partial that calls the kernel's
        `run` with the tensors, extra args, grid and the raw CUDA stream of
        the output tensor's device.
        """
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        log.debug(
            "benchmark module key: %s, path: %s",
            self.module_cache_key,
            self.module_path,
        )

        run_method = getattr(mod, self.kernel_name).run
        # Materialize once: extra_args is typed Iterable and may be one-shot.
        extra_args = list(self.extra_args)

        # Newer version of triton add warmup argument to JITFunction.run.
        # This code handles backward-compatibility.
        warmup_arg = {}
        import inspect

        if "warmup" in inspect.signature(run_method).parameters:
            warmup_arg["warmup"] = False

        from torch._C import _cuda_getCurrentRawStream as get_raw_stream

        # The hip-specific branch used to build exactly the same partial as
        # the generic one, so a single construction covers both.
        return functools.partial(
            run_method,
            *input_tensors,
            output_tensor,
            *extra_args,
            grid=self.grid,
            **warmup_arg,
            stream=get_raw_stream(self.output_tensor_meta.device.index),
        )

    def precompile(self):
        """
        Load the kernel module and call the kernel's precompile().
        """
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        getattr(mod, self.kernel_name).precompile()

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
678
+
679
+
680
class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest):
    """
    Benchmark request for a kernel built from CUDA source via CUDACodeCache.
    """

    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        # Workspace bookkeeping; the size is queried lazily from the kernel.
        self.workspace_size: int = 0
        self.workspace: Optional[torch.Tensor] = None
        self.DLL: Optional[DLLWrapper] = None
        self._workspace_size_updated = False
        self.hash_key: str = ""
        self.source_file: str = ""
        # Write the source out right away to obtain its hash key and file path.
        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")

    def precompile(self):
        """
        Compile the source ahead of time to prepopulate CUDACodeCache
        (may happen in a separate thread pool).
        """
        log.debug("Precompiling %s", self)
        CUDACodeCache.compile(self.source_code, "so")
        log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """
        Return a partial that invokes the kernel entry point with tensor
        pointers, extra args, a null workspace-size pointer, the workspace
        pointer and the current CUDA stream.
        """
        self.ensure_dll_loaded()
        self.update_workspace_size()
        args = [c_void_p(t.data_ptr()) for t in (*input_tensors, output_tensor)]
        log.debug(
            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
        kernel = getattr(self.DLL, self.kernel_name)

        if self.workspace_size > 0:
            # Allocate as float64 elements: byte size rounded up to a multiple of 8.
            self.workspace = torch.zeros(
                (self.workspace_size + 7) // 8,
                dtype=torch.float64,
                device=output_tensor.device,
            )
            workspace_ptr = c_void_p(self.workspace.data_ptr())
        else:
            workspace_ptr = c_void_p(0)

        # Generate partial function.
        return functools.partial(
            kernel,
            *args,
            *self.extra_args,
            None,  # null workspace size ptr
            workspace_ptr,  # set workspace ptr,
            stream_ptr,
        )

    def update_workspace_size(self) -> None:
        """
        Query the kernel once for its workspace size and cache the answer in
        self.workspace_size. Later calls return immediately.
        """
        if self._workspace_size_updated:
            return
        self.ensure_dll_loaded()
        # One null pointer per distinctly named input, plus one for the output.
        distinct_names = {meta.name for meta in self.input_tensor_meta}
        args = [c_void_p(None) for _ in range(len(distinct_names) + 1)]
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)

        kernel = getattr(self.DLL, self.kernel_name)
        # Retrieve workspace_size and initialize workspace.
        c_workspace_size = c_size_t()
        kernel(
            *args,  # input ptrs and output ptrs
            *self.extra_args,
            byref(c_workspace_size),  # set workspace size ptr to retrieve workspace size
            None,  # null workspace ptr
            stream_ptr,
        )
        torch.cuda.synchronize()  # shake out any CUDA errors
        self.workspace_size = c_workspace_size.value
        log.debug(
            "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",  # noqa: B950
            self.workspace_size,
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        self._workspace_size_updated = True

    def ensure_dll_loaded(self):
        """
        Load the compiled library through CUDACodeCache if not yet loaded.
        """
        if self.DLL is not None:
            return
        self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
            self.source_code, "so"
        )

    def cleanup_run_fn(self) -> None:
        """
        Close the loaded library (if any) and drop the workspace tensor.
        """
        dll = self.DLL
        if dll is not None:
            dll.close()
        self.workspace = None

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
795
+
796
+
797
class CPUDeviceBenchmarkRequest(BenchmarkRequest):
    """
    Benchmark request whose timing is done with benchmarker.benchmark_cpu.
    """

    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Return the CPU benchmark result for `fn`; the tensors are not used.
        """
        measurement = benchmarker.benchmark_cpu(fn)
        return measurement
805
+
806
+
807
class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
    """
    Benchmark request for a kernel built from C++ source via CppCodeCache.
    """

    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        self.hash_key = get_hash(source_code)
        self.DLL: Optional[Union[CDLL, ModuleType]] = None

    def precompile(self):
        """
        Load the source through CppCodeCache ahead of time to prepopulate it
        (may happen in a separate thread pool).
        """
        log.debug("Precompiling %s", self)
        CppCodeCache.load(self.source_code, cuda=False)
        log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """
        Load the library and return a partial that calls the kernel with the
        tensors' data pointers followed by the extra args.
        """
        # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf
        self.DLL = CppCodeCache.load(self.source_code, cuda=False)
        args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]]
        log.debug(
            "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.DLL,
            args,
            self.extra_args,
        )
        run_method = getattr(self.DLL, self.kernel_name)
        # Materialize once: extra_args is typed Iterable and is needed three
        # times below (validation, argtypes length, bound arguments).
        extra_args = list(self.extra_args)
        # Assume only size with type ctypes.c_ulonglong in extra_args
        assert all(isinstance(arg, ctypes.c_ulonglong) for arg in extra_args)
        run_method.argtypes = [ctypes.c_ulonglong] * (len(args) + len(extra_args))

        # Generate partial function.
        return functools.partial(
            run_method,
            *args,
            *extra_args,
        )

    def cleanup_run_fn(self) -> None:
        """
        Close the loaded library when it exposes a close() method.
        """
        if self.DLL is not None:
            # Check close attr due to it crash on Windows.
            if hasattr(self.DLL, "close"):
                self.DLL.close()

    def __str__(self) -> str:
        return f"{self.kernel_name=}"
868
+
869
+
870
def benchmark_in_sub_process(
    choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
    """
    Benchmark the given choices through the module-level tuning pool and
    return its per-choice results (latency).
    """
    timings = tuning_pool.benchmark(choices)
    return timings
.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import math
3
+ from enum import IntEnum
4
+
5
+ import sympy
6
+
7
+ import torch
8
+
9
+ from . import ir
10
+ from .utils import get_dtype_size, sympy_product
11
+ from .virtualized import V
12
+
13
+
14
class NCCL_COLL(IntEnum):
    """Collective kinds distinguished by the runtime estimate."""

    ALL_REDUCE = 0  # kernel name contains "all_reduce"
    ALL_GATHER = 1  # kernel name contains "all_gather"
    REDUCE_SCATTER = 2  # kernel name contains "reduce_scatter"
18
+
19
+
20
class NVIDIA_GPU_TYPE(IntEnum):
    """GPU generations; the values are used as row indices into llMaxBws."""

    VOLTA = 0  # reported for "V100"
    AMPERE = 1  # reported for "A100" and for unrecognized GPUs
    HOPPER = 2  # reported for "H100"
24
+
25
+
26
@functools.lru_cache
def get_gpu_type() -> NVIDIA_GPU_TYPE:
    """
    Classify the GPU from the string reported by torch.utils.collect_env.

    The first marker found, checked in the order V100, A100, H100, decides
    the result; anything else is treated as Ampere. The result is cached.
    """
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
    known_markers = (
        ("V100", NVIDIA_GPU_TYPE.VOLTA),
        ("A100", NVIDIA_GPU_TYPE.AMPERE),
        ("H100", NVIDIA_GPU_TYPE.HOPPER),
    )
    for marker, gpu_type in known_markers:
        if marker in gpu_info:
            return gpu_type
    # for other gpu types, assume Ampere
    return NVIDIA_GPU_TYPE.AMPERE
38
+
39
+
40
def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
    """
    Map a collective kernel node to its NCCL_COLL kind by looking for a known
    substring in the node's python kernel name.

    Raises:
        ValueError: if the node is not an ir._CollectiveKernel, or its kernel
            name matches none of the supported collectives.
    """
    if not isinstance(node, ir._CollectiveKernel):
        raise ValueError(f"node is not a collective kernel: {node}")

    kernel_name = node.python_kernel_name
    assert kernel_name is not None

    # Checked in this order; the first substring hit wins.
    name_to_coll = (
        ("all_reduce", NCCL_COLL.ALL_REDUCE),
        ("all_gather", NCCL_COLL.ALL_GATHER),
        ("reduce_scatter", NCCL_COLL.REDUCE_SCATTER),
    )
    for fragment, coll in name_to_coll:
        if fragment in kernel_name:
            return coll
    raise ValueError(f"Unsupported collective kernel: {kernel_name}")
54
+
55
+
56
def get_collective_input_size_bytes(node: ir.IRNode) -> int:
    """
    Sum, over all inputs of the node, element count times dtype size.

    Symbolic element counts are resolved with V.graph.sizevars.size_hint
    using a fallback of 0.
    """
    total_bytes = 0
    for inp in node.inputs:  # type: ignore[attr-defined]
        numel = sympy_product(inp.layout.size)
        if not isinstance(numel, sympy.Integer):
            numel = V.graph.sizevars.size_hint(numel, fallback=0)
        else:
            # For ease of testing
            numel = int(numel)
        total_bytes += numel * get_dtype_size(inp.layout.dtype)
    return total_bytes
67
+
68
+
69
def get_collective_group_size(node: ir.IRNode) -> int:
    """
    Return the size of the process group named by the node's last constant
    argument.

    Raises:
        TypeError: if the node's type is not exactly ir._CollectiveKernel.
    """
    # Exact type comparison on purpose: subclasses are rejected as well.
    if type(node) != ir._CollectiveKernel:
        raise TypeError(f"Unsupported collective type: {node}")

    from torch.distributed.distributed_c10d import _get_group_size_by_name

    group_name = node.constant_args[-1]
    return _get_group_size_by_name(group_name)
76
+
77
+
78
+ ####################################################################################################################
79
+ # The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
80
+ ####################################################################################################################
81
+
82
+
83
class NCCL_HW(IntEnum):
    """Interconnect kinds; the values are used as the first index of hwLat."""

    NVLINK = 0
    PCI = 1
    NET = 2
87
+
88
+
89
class NCCL_ALGO(IntEnum):
    """Algorithm kinds; the values index the algorithm axis of baseLat and hwLat."""

    TREE = 0
    RING = 1
92
+
93
+
94
class NCCL_PROTO(IntEnum):
    """Protocol kinds; only the Low-latency (LL) protocol is enabled here."""

    # The ordering and enum values here matches original in
    # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28
    # For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990
    LL = 0  # Low-latency
    # LL128 = 1  # Low-latency 128-byte
    # SIMPLE = 2
101
+
102
+
103
# Base latencies in us, indexed as baseLat[NCCL_ALGO][NCCL_PROTO]
# (len(NCCL_ALGO) x len(NCCL_PROTO)).
# NOTE: use array instead of tensor to prevent incompatibility with fake mode
baseLat = [
    [6.8],  # Tree (LL)
    [6.6],  # Ring (LL)
]
116
+
117
# Hardware latencies in us, indexed as hwLat[NCCL_HW][NCCL_ALGO][NCCL_PROTO]
# (len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)).
hwLat = [
    [[0.6], [0.6]],  # NVLINK: Tree (LL), Ring (LL)
    [[1.0], [1.0]],  # PCI: Tree (LL), Ring (LL)
    [[5.0], [2.7]],  # NET: Tree (LL), Ring (LL)
]
136
+
137
+
138
# LL128 max BW per channel
# NOTE(review): the table is consumed as the max bandwidth for the LL protocol
# in estimate_nccl_collective_runtime; confirm the "LL128" label above.
# Rows are selected by GPU type (or row 0 for multi-node), columns by node
# count: 1 node, 2 nodes, more than 2 nodes.
llMaxBws = [
    [39.0, 39.0, 20.4],  # Volta-N1/Intel-N2/Intel-N4
    [87.7, 22.5, 19.0],  # Ampere-N1/AMD-N2/AMD-N4 (22.5: avg of ring & tree)
    [87.7, 22.5, 19.0],  # Hopper-N1/AMD-N2/AMD-N4 (22.5: avg of ring & tree)
]
159
+
160
+
161
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).

    The heuristics follow https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc;
    the estimate is the sum of a transport term (size / bandwidth) and a
    latency term.

    Assumptions:
    - only ring algorithm (NCCL_ALGO_RING) is used
    - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    - collective is one of: allreduce, reducescatter, allgather

    Returns 0 when the collective's group has at most one rank.
    """
    size_bytes = get_collective_input_size_bytes(node)
    # Convert bytes to GB
    size_GB = size_bytes / 1024 / 1024 / 1024

    # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    gpus_per_node = 8
    nRanks = get_collective_group_size(node)  # total # of gpus globally in this collective op
    nNodes = math.ceil(nRanks / gpus_per_node)

    if nRanks <= 1:
        return 0

    # Assumes ring algorithm and the LL protocol
    algo = NCCL_ALGO.RING
    proto = NCCL_PROTO.LL
    coll = get_collective_type(node)
    is_all_reduce = coll == NCCL_COLL.ALL_REDUCE
    is_scatter_or_gather = coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER)
    multi_node = nNodes > 1

    # =============== bandwidth computation ===============
    # First compute bandwidth in GB/s; then at the end, convert it to GB/ns

    bwIntra = torch._inductor.config.intra_node_bw
    bwInter = torch._inductor.config.inter_node_bw

    gpu_type = get_gpu_type()
    # Column: 0 for one node, 1 for two nodes, 2 for anything larger.
    col = min(nNodes - 1, 2)
    # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
    row = gpu_type if nNodes == 1 else 0
    llMaxBw = llMaxBws[row][col]

    # NOTE: each step of ring algorithm is synchronized,
    # and is bottlenecked by the slowest link which is the inter-node interconnect.
    # hence when nNodes >= 2, bw is inter-node bandwidth.
    # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
    # have this as `if nNodes <= 2` which seems wrong. Corrected it here.
    bw = bwIntra if nNodes == 1 else bwInter
    nChannels = 2  # Assume # channels is 2
    busBw = nChannels * bw

    # Various model refinements
    if multi_node or is_all_reduce:
        scale = 1.0 / 4.0
    else:
        scale = 1.0 / 3.0
    busBw = min(llMaxBw, busBw * scale)

    # Ring step counts: total steps, and how many of them cross nodes.
    if is_all_reduce:
        nsteps = 2 * (nRanks - 1)
        nInterSteps = 2 * nNodes if multi_node else 0
    elif is_scatter_or_gather:
        nsteps = nRanks - 1
        nInterSteps = nNodes - 1

    # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
    ratio = (1.0 * nRanks) / nsteps  # type: ignore[possibly-undefined]
    bandwidth = busBw * ratio
    # Convert GB/s to GB/ns
    bandwidth_GB_per_ns = bandwidth / 1e9

    # =============== latency computation ===============
    # First compute latency in us; then at the end, convert it to ns
    latency = baseLat[algo][proto]
    intraLat = hwLat[NCCL_HW.NVLINK][algo][proto]
    interLat = hwLat[NCCL_HW.NET][algo][proto]

    # Inter-node rings still have to launch nsteps * net overhead.
    netOverhead = 1.0 if multi_node else 0.0  # getNetOverhead(comm);
    intraLat = max(intraLat, netOverhead)
    latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat  # type: ignore[possibly-undefined]
    # Convert us to ns
    latency_ns = latency * 1e3

    # =============== final result ===============
    transport_ns = size_GB / bandwidth_GB_per_ns
    return transport_ns + latency_ns
260
+
261
+
262
+ ################################################################################################################
263
+ # The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
264
+ ################################################################################################################
.venv/lib/python3.11/site-packages/torch/_inductor/comms.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ # pyre-strict
3
+ from __future__ import annotations
4
+
5
+ import heapq
6
+ import operator
7
+ import sys
8
+ from collections import defaultdict
9
+ from typing import Dict, List, Set, TYPE_CHECKING
10
+
11
+ import torch
12
+
13
+ from . import config, ir
14
+ from .dependencies import WeakDep
15
+ from .utils import (
16
+ contains_collective,
17
+ contains_wait,
18
+ find_recursive_deps_of_node,
19
+ find_recursive_users_of_node,
20
+ is_collective,
21
+ is_fallback_op,
22
+ is_wait,
23
+ )
24
+
25
+
26
# Artifact logger registered under the "overlap" artifact name; the overlap
# visualization helpers in this module write their debug output to it.
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")

# Imported for type annotations only, so it is skipped at runtime.
if TYPE_CHECKING:
    from .scheduler import BaseSchedulerNode
30
+
31
+
32
def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Greedily push wait nodes as late as possible in the schedule.

    Thin wrapper over `_schedule_for_comm` with only the `sink_waits`
    objective switched on.
    """
    objectives = {
        "raise_comms": False,
        "sink_waits": True,
        "reorder_for_overlap": False,
    }
    return _schedule_for_comm(snodes, **objectives)
39
+
40
+
41
def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Greedily pull communication nodes as early as possible in the schedule.

    Thin wrapper over `_schedule_for_comm` with only the `raise_comms`
    objective switched on.
    """
    objectives = {
        "raise_comms": True,
        "sink_waits": False,
        "reorder_for_overlap": False,
    }
    return _schedule_for_comm(snodes, **objectives)
48
+
49
+
50
def reorder_compute_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    """
    Reorder compute nodes so that they overlap with in-flight comms.

    The overall procedure, with all three objectives of `_schedule_for_comm`
    enabled, is:

    Step 1: Once comm N has been scheduled, schedule every compute node that
        comm N + 1 requires but that does not depend on comm N, so that it
        runs concurrently with comm N.
    Step 2: If those compute nodes are enough to cover comm N, stop there.
        Otherwise look elsewhere for compute to overlap with comm N, giving
        priority to compute nodes that are needed sooner.
    Step 3: Schedule the compute nodes that depend on comm N and are required
        for comm N + 1.
    Step 4: Schedule comm N + 1.

    The same steps are repeated for each following comm node.
    """
    objectives = {
        "raise_comms": True,
        "sink_waits": True,
        "reorder_for_overlap": True,
    }
    return _schedule_for_comm(snodes, **objectives)
67
+
68
+
69
def _schedule_for_comm(
    snodes: List[BaseSchedulerNode],
    raise_comms: bool,
    sink_waits: bool,
    reorder_for_overlap: bool,
) -> List[BaseSchedulerNode]:
    """
    Produce a new ordering of `snodes` for the requested comm objectives.

    Args:
        snodes: the nodes to be scheduled.
        raise_comms: whether to greedily schedule collectives as early as
            possible.
        sink_waits: whether to greedily schedule waits as late as possible.
        reorder_for_overlap: whether to reorder compute nodes to optimize
            for compute/communication overlapping.

    Returns:
        The new schedule order.

    Notes on how the options interact:
        - `raise_comms` creates more overlapping opportunities for
          `reorder_for_overlap`.
        - When `raise_comms` and `sink_waits` are both `True`, `raise_comms`
          takes priority.
    """
    # Every node gets a score triple (score_0, score_1, score_2), listed in
    # decreasing importance; a lower value means a higher ranking:
    #
    # - score_0: the smallest comm_idx among the comm nodes this node blocks,
    #     or sys.maxsize when it blocks none. Pulls comm nodes (and what they
    #     need) towards the front.
    # - score_1: 1 for a wait node, 0 otherwise. Pushes waits towards the back.
    # - score_2: the node's index in the incoming topological order. Keeps
    #     ties stable.
    #
    # With only raise_comms set, score_0 and score_2 matter.
    # With only sink_waits set, score_1 and score_2 matter.
    # With neither set, the incoming order is reproduced.
    buf_name_to_snode = {}
    name_to_fused_node = {}
    scores_0, scores_1, scores_2 = {}, {}, {}
    for position, snode in enumerate(snodes):
        for buf_name in snode.get_buffer_names():
            buf_name_to_snode[buf_name] = snode

        for op_name in snode.get_operation_names():
            name_to_fused_node[op_name] = snode
        node_name = snode.get_name()
        name_to_fused_node[node_name] = snode

        scores_0[node_name] = sys.maxsize
        scores_1[node_name] = 0
        scores_2[node_name] = position

    comm_idx = 0
    for snode in snodes:
        if raise_comms and contains_collective(snode):
            scores_0[snode.get_name()] = comm_idx
            # Everything this collective depends on inherits its priority,
            # unless it already serves an earlier collective.
            for anc in snode.ancestors:
                anc_fused_name = name_to_fused_node[anc].get_name()
                scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx)
            comm_idx += 1
        elif sink_waits and contains_wait(snode):
            scores_1[snode.get_name()] = 1

    class Runnable:
        """Heap entry wrapping a ready node together with its score triple."""

        def __init__(self, snode) -> None:
            self.snode = snode
            first_op_name = next(iter(snode.get_operation_names()))
            fused_name = name_to_fused_node[first_op_name].get_name()
            self.score = (
                scores_0[fused_name],
                scores_1[fused_name],
                scores_2[fused_name],
            )

        def __lt__(self, other):
            return self.score < other.score

    # Names of the dependencies each node is still waiting for.
    unmet_deps: Dict[BaseSchedulerNode, Set[str]] = {
        snode: {dep.name for dep in snode.unmet_dependencies} for snode in snodes
    }

    ready: List[Runnable] = []
    buffer_users: Dict[str, Set[BaseSchedulerNode]] = defaultdict(set)
    snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}

    for snode, deps in unmet_deps.items():
        if not deps:
            heapq.heappush(ready, Runnable(snode))
        for dep in deps:
            buffer_users[dep].add(snode)

    scheduled = []

    def schedule(snode):
        """
        Append `snode` to the schedule and enqueue every node it unblocks.
        """
        scheduled.append(snode)
        for buf_name in snode.get_buffer_names():
            for user in buffer_users[buf_name]:
                pending = unmet_deps[user]
                pending.remove(buf_name)
                if not pending:
                    heapq.heappush(ready, Runnable(user))

    def get_overlapping_candidate():
        """
        Return the best-scored ready entry that contains neither a collective
        nor a wait, or None when there is no such entry.
        """
        compute_only = (
            entry
            for entry in ready
            if not contains_collective(entry.snode) and not contains_wait(entry.snode)
        )
        return min(compute_only, key=operator.attrgetter("score"), default=None)

    def schedule_collective_for_overlap(snode):
        """
        Schedule collective `snode`, then keep scheduling compute nodes to
        overlap with it until its estimated cost is covered or no candidate
        is left. The strategy is described in the docstring of
        `reorder_compute_for_overlap`.
        """
        assert contains_collective(snode)
        schedule(snode)

        uncovered_cost = snode_to_cost[snode]
        while uncovered_cost > 0:
            candidate = get_overlapping_candidate()
            if candidate is None:
                break
            ready.remove(candidate)
            schedule(candidate.snode)
            uncovered_cost -= snode_to_cost[candidate.snode]
        # `ready.remove` above does not maintain heap order; rebuild it.
        heapq.heapify(ready)

    while ready:
        snode = heapq.heappop(ready).snode
        if reorder_for_overlap and contains_collective(snode):
            schedule_collective_for_overlap(snode)
        else:
            schedule(snode)

    for deps in unmet_deps.values():
        assert len(deps) == 0, (
            "Detected unscheduled nodes. "
            f"Nodes with unmet dependencies: {unmet_deps}"
        )
    return scheduled
221
+
222
+
223
def decide_global_ordering_of_comms(
    nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node
) -> List[BaseSchedulerNode]:
    """
    Fix a global order for the comm nodes by enforcing the order in which they
    appear in the input graph (which may differ from the order of the eager
    mode program).

    Each collective receives a `WeakDep` on every buffer of the collective
    that precedes it. When FSDP-specific fallback ops are present, the nodes
    are first passed through `enforce_comm_ordering_for_fsdp`.

    TODO: Come up with a better approach
    """

    def _is_fsdp_marker(snode) -> bool:
        # The op set is built lazily, per node, so nothing under
        # `torch.ops.fsdp` is touched when `nodes` is empty.
        return is_fallback_op(
            snode.node,
            {
                torch.ops.fsdp.all_gather_copy_in.default,
                torch.ops.fsdp.chunk_cat.default,
            },
        )

    # If FSDP2 is used, we apply FSDP-specific passes.
    if any(map(_is_fsdp_marker, nodes)):
        nodes = enforce_comm_ordering_for_fsdp(nodes, name_to_buf, name_to_fused_node)

    comm_nodes = [n for n in nodes if contains_collective(n)]

    # Walk consecutive (previous, current) pairs of collectives.
    for prev_comm, cur_comm in zip(comm_nodes, comm_nodes[1:]):
        # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
        mutating_buf = next(iter(cur_comm.get_buffer_names()))
        for buf in prev_comm.get_buffer_names():
            cur_comm.add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))

    return nodes
253
+
254
+
255
def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
    """
    Returns estimated op runtime in nanoseconds (ns).

    When `config.estimate_op_runtime` is the string "default", the node's own
    `get_estimated_runtime()` is used; otherwise the configured value must be
    a callable and is invoked with `snode`.
    """
    estimator = config.estimate_op_runtime
    if estimator == "default":
        return snode.get_estimated_runtime()
    assert callable(estimator)
    return estimator(snode)
265
+
266
+
267
def node_summary(snode):
    """
    Build a one-line, human-readable description of `snode.node`.

    The result is the IR node's class name, followed by the python kernel
    name for `ir.ExternKernelOut` nodes, the layout size/stride when the node
    exposes them, and the node name (empty if the node has none) in
    parentheses.
    """
    ir_node = snode.node

    if isinstance(ir_node, ir.ExternKernelOut):
        detail = f" ({ir_node.python_kernel_name})"
    else:
        detail = ""

    has_sized_layout = hasattr(ir_node, "layout") and all(
        hasattr(ir_node.layout, attr) for attr in ("size", "stride")
    )
    if has_sized_layout:
        out_tensor_info = (
            f" (size={ir_node.layout.size}, stride={ir_node.layout.stride})"
        )
    else:
        out_tensor_info = ""

    node_name = ir_node.name if hasattr(ir_node, "name") else ""
    return f"{ir_node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"
284
+
285
+
286
def visualize_overlap(order):
    """
    Log the schedule `order` to `overlap_log`, marking overlapped compute.

    Nodes that run while a collective is in flight are logged with a leading
    "| "; all other nodes are logged plainly. The estimated runtimes of
    collectives and of compute that is not overlapped are summed and logged
    at the end, in milliseconds.

    Raises:
        AssertionError: if a wait shows up while no collective is running, or
            if a second collective starts before the current one was waited
            on.
    """
    total_est_runtime: float = 0.0
    cur_comm_node = None
    for snode in order:
        if cur_comm_node is not None:
            # A collective is currently in flight.
            if contains_collective(snode):
                raise AssertionError(
                    "Found two collectives running at the same time. "
                    "`visualize_overlap` needs to be updated to handle this case"
                )
            if is_wait(snode.node):  # end of this comm op
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
                cur_comm_node = None
            else:  # overlapped compute op
                overlap_log.debug(f"| {node_summary(snode)}")  # noqa: G004
            continue

        # No collective in flight: the node is a new collective or exposed compute.
        starts_comm = contains_collective(snode)
        if not starts_comm and is_wait(snode.node):
            raise AssertionError(
                "Wait is not expected when there is no collective running"
            )
        total_est_runtime += estimate_op_runtime(snode)
        if starts_comm:
            cur_comm_node = snode.node
        overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004

    overlap_log.debug(
        f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}"  # noqa: G004
    )
315
+
316
+
317
def reorder_compute_and_comm_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    """
    Run every pass in `config.reorder_for_compute_comm_overlap_passes` in
    sequence over `snodes` and return the final order.

    A pass given as a string is looked up among this module's globals. On
    rank 0 the schedule is visualized to `overlap_log` before and after each
    pass; failures of the visualization are logged and otherwise ignored.
    """

    def _visualize_best_effort(current_order):
        # Visualization is purely diagnostic, so it must never abort the pass.
        try:
            visualize_overlap(current_order)
        except Exception as e:
            overlap_log.debug(str(e))

    order = snodes

    for p in config.reorder_for_compute_comm_overlap_passes:
        if isinstance(p, str) and p in globals():
            p = globals()[p]  # it is a builtin pass

        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap before reordering pass {p} ===="  # noqa: G004
            )
            _visualize_best_effort(order)

        order = p(order)  # type: ignore[operator]

        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap after reordering pass {p} ===="  # noqa: G004
            )
            _visualize_best_effort(order)

    return order
343
+
344
+
345
def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    """
    Rewrite FSDP all-gathers in `graph` to write into the copy-in output buffer.

    Returns without touching the graph when the FSDP collectives module cannot
    be imported, distributed is unavailable, or the required
    `_c10d_functional` ops are missing.

    The rewrite performed is:

        all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
        getitem = all_gather_copy_in[0];
        (getitem_1 = all_gather_copy_in[1];) # optional

        all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...);

        ->

        all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
        getitem = all_gather_copy_in[0];
        getitem_1 = all_gather_copy_in[1];

        all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1);
    """
    try:
        import torch.distributed._composable.fsdp._fsdp_collectives

        assert torch.distributed.is_available()
        # Assert existence of these ops
        assert (
            torch.ops._c10d_functional.all_gather_into_tensor
            and torch.ops._c10d_functional.all_gather_into_tensor_out
        )
    except (ImportError, AttributeError, AssertionError):
        return

    from .pattern_matcher import (
        CallFunction,
        KeywordArg,
        Match,
        PatternMatcherPass,
        register_graph_pattern,
    )

    # Keyword names bound to the arguments of `all_gather_copy_in`, in order.
    copy_in_arg_names = (
        "all_gather_inputs",
        "inp_split_sizes",
        "all_gather_input_numel",
        "world_size",
        "rank",
        "dtype",
        "device",
    )

    def remove_unused_getitem(g):
        # Remove `getitem_X = all_gather_copy_in[1]` which is never used.
        for n in list(g.nodes):
            if (
                n.target == operator.getitem
                and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default
                and n.args[1] == 1
            ):
                g.erase_node(n)

    graph_pass = PatternMatcherPass()

    def _selects_first_output(match):
        # Only `all_gather_copy_in[0]` feeding the all-gather is rewritten.
        return match.kwargs["item_idx"] == 0

    copy_in_pattern = CallFunction(
        torch.ops.fsdp.all_gather_copy_in.default,
        *[KeywordArg(arg_name) for arg_name in copy_in_arg_names],
    )
    all_gather_pattern = CallFunction(
        torch.ops._c10d_functional.all_gather_into_tensor.default,
        CallFunction(
            operator.getitem,
            copy_in_pattern,
            KeywordArg("item_idx"),
        ),
        KeywordArg("group_size"),
        KeywordArg("group_name"),
    )

    @register_graph_pattern(
        all_gather_pattern,
        pass_dict=graph_pass,
        extra_check=_selects_first_output,
    )
    def reinplace_all_gather(match: Match, *args, **kwargs):
        def repl(
            *args,
        ):
            # The last two positional values are the group size and name;
            # everything before them is forwarded to the copy-in op.
            *copy_in_args, group_size, group_name = args
            all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(
                *copy_in_args
            )
            getitem = all_gather_copy_in[0]
            getitem_1 = all_gather_copy_in[1]
            return torch.ops._c10d_functional.all_gather_into_tensor_out.default(
                getitem, group_size, group_name, out=getitem_1
            )

        example_arg_names = (*copy_in_arg_names, "group_size", "group_name")
        match.replace_by_example(
            repl,
            [kwargs[arg_name] for arg_name in example_arg_names],
        )

    remove_unused_getitem(graph)
    graph_pass.apply(graph)  # type: ignore[arg-type]
454
+
455
+
456
+ def get_op_idx(snode):
457
+ assert not isinstance(
458
+ snode,
459
+ (
460
+ torch._inductor.scheduler.FusedSchedulerNode,
461
+ torch._inductor.scheduler.GroupedSchedulerNode,
462
+ ),
463
+ )
464
+ return int(snode.get_name()[2:])
465
+
466
+
467
def enforce_comm_ordering_for_fsdp(
    snodes: List[torch._inductor.scheduler.BaseSchedulerNode],
    name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer],
    name_to_fused_node: Dict[str, BaseSchedulerNode],
) -> List[torch._inductor.scheduler.BaseSchedulerNode]:
    """
    Group FSDP all-gather / reduce-scatter code blocks and order them.

    For each FSDP all-gather and each FSDP reduce-scatter found in `snodes`,
    the related nodes are collected into two `GroupedSchedulerNode`s: one
    ending right before the first wait node, and one starting at that wait
    node. The returned schedule substitutes each grouped node for its members
    (at the position of the first member encountered). Fake `WeakDep`s are
    then added so that each comm group depends on the outputs of the previous
    wait group of the same kind.
    """
    from . import scheduler

    new_order: list[BaseSchedulerNode] = []
    scheduled = set()
    ag_exists = False
    rs_exists = False
    ag_grouped_node_to_wait_grouped_node = {}
    rs_grouped_node_to_wait_grouped_node = {}
    snode_name_to_final_snode = {}

    def _create_group_node(snodes_to_group):
        # Wrap the nodes into one grouped node and remember, for each member
        # name and for the group's own name, which node finally represents it.
        group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group)
        snode_name_to_final_snode.update(
            (member.get_name(), group_node) for member in snodes_to_group
        )
        snode_name_to_final_snode[group_node.get_name()] = group_node
        return group_node

    def _group_around_first_wait(related_snodes):
        # Split at the first wait kernel found from position 1 onwards:
        # everything before it forms the comm group, the rest the wait group.
        wait_node_idx = next(
            (
                idx
                for idx in range(1, len(related_snodes))
                if isinstance(related_snodes[idx].node, ir._WaitKernel)
            ),
            None,
        )
        assert wait_node_idx is not None
        comm_group_node = _create_group_node(related_snodes[:wait_node_idx])
        wait_group_node = _create_group_node(related_snodes[wait_node_idx:])
        return comm_group_node, wait_group_node

    def _depend_on_previous_wait(group_to_wait_group):
        # Make every comm group (after the first) weakly depend on all
        # outputs of the wait group that was recorded just before it.
        prev_wait = None
        for group_node, wait_group_node in group_to_wait_group.items():
            if prev_wait is not None:
                mutating_buf = next(iter(group_node.get_buffer_names()))
                for o in prev_wait.get_outputs():
                    group_node.add_fake_dep(
                        WeakDep(o.get_name(), mutating_buf=mutating_buf)
                    )
            prev_wait = wait_group_node

    # Create grouped nodes for specific sets of ops
    for snode in snodes:
        # Case 1: Handle AllGather
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default
        ) and any(
            is_fallback_op(
                name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default
            )
            for x in snode.ancestors
        ):
            ag_exists = True
            ag_related_snode_set: set[scheduler.BaseSchedulerNode] = set()

            # Find the "cast + copy_in + getitem + all_gather" code block
            find_recursive_deps_of_node(
                snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block
            allowed_ops = {
                torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                torch.ops._c10d_functional.wait_tensor.default,
                torch.ops.fsdp.split_with_sizes_copy.default,
                torch.ops.aten.set_.source_Tensor,
            }
            find_recursive_users_of_node(
                snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=lambda x: not (
                    isinstance(x, scheduler.NopKernelSchedulerNode)
                    or (
                        isinstance(x, scheduler.ExternKernelSchedulerNode)
                        and x.node.op_overload in allowed_ops  # type: ignore[union-attr]
                    )
                ),
            )

            # sort nodes by original operation order
            ag_related_snodes = sorted(ag_related_snode_set, key=get_op_idx)

            # In the "reuse layer" case, some ops in the 2nd all-gather code block could also
            # depend on ops in the 1st all-gather code block, and we don't want to group them together.
            # Cut the block right before the second copy-out op, if there is one.
            end_idx_of_current_ag_block = len(ag_related_snodes)
            copy_out_count = 0
            for pos, cur_snode in enumerate(ag_related_snodes):
                if is_fallback_op(
                    cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default
                ):
                    copy_out_count += 1
                if copy_out_count > 1:
                    end_idx_of_current_ag_block = pos
                    break

            ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block]

            # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode and
            # "all_gather_wait_tensor + copy_out + set_" into another GroupedSchedulerNode
            ag_group_node, ag_wait_group_node = _group_around_first_wait(
                ag_related_snodes
            )
            ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node

        # Case 2: Handle ReduceScatter
        elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default):
            rs_exists = True

            # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block
            rs_related_snode_set: set[scheduler.BaseSchedulerNode] = set()
            find_recursive_users_of_node(
                snode,
                rs_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # sort nodes by original operation order
            rs_related_snodes = sorted(rs_related_snode_set, key=get_op_idx)

            # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode and
            # "reduce_scatter wait + related output nodes" into another GroupedSchedulerNode
            rs_group_node, rs_wait_group_node = _group_around_first_wait(
                rs_related_snodes
            )
            rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node

    assert len(snode_name_to_final_snode) > 0
    if ag_exists:
        assert len(ag_grouped_node_to_wait_grouped_node) > 0
    if rs_exists:
        assert len(rs_grouped_node_to_wait_grouped_node) > 0

    # Build the new node schedule, taking GroupedSchedulerNode into account
    for snode in snodes:
        final_snode = snode_name_to_final_snode.get(snode.get_name(), snode)
        if final_snode in scheduled:
            continue
        new_order.append(final_snode)
        scheduled.add(final_snode)

    # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run
    # before next AllGather's "copy_in then AG" group node
    _depend_on_previous_wait(ag_grouped_node_to_wait_grouped_node)

    # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run
    # before next ReduceScatter's "copy_in then RS" group node
    _depend_on_previous_wait(rs_grouped_node_to_wait_grouped_node)

    return new_order  # type: ignore[return-value]
.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py ADDED
@@ -0,0 +1,1629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ import contextlib
4
+ import functools
5
+ import io
6
+ import itertools
7
+ import logging
8
+ import os
9
+ import sys
10
+ import time
11
+ import warnings
12
+ from itertools import count
13
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
14
+ from unittest import mock
15
+
16
+ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
17
+ import torch.fx
18
+ import torch.utils._pytree as pytree
19
+ from functorch.compile import min_cut_rematerialization_partition
20
+ from torch._dynamo import (
21
+ compiled_autograd,
22
+ config as dynamo_config,
23
+ logging as dynamo_logging,
24
+ utils as dynamo_utils,
25
+ )
26
+ from torch._dynamo.device_interface import get_interface_for_device
27
+ from torch._dynamo.repro.after_aot import wrap_compiler_debug
28
+ from torch._dynamo.utils import (
29
+ counters,
30
+ detect_fake_mode,
31
+ flatten_graph_inputs,
32
+ lazy_format_graph_code,
33
+ )
34
+ from torch._functorch import config as functorch_config
35
+ from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
36
+ from torch._inductor.codecache import (
37
+ _StrideExprStr,
38
+ code_hash,
39
+ CompiledFxGraph,
40
+ FxGraphCache,
41
+ )
42
+ from torch._inductor.cudagraph_utils import (
43
+ BoxedDeviceIndex,
44
+ CudagraphCachedInfo,
45
+ get_placeholder_info,
46
+ log_cudagraph_skip_and_bump_counter,
47
+ PlaceholderInfo,
48
+ )
49
+ from torch._inductor.debug import save_args_for_compile_fx_inner
50
+ from torch._inductor.runtime.runtime_utils import cache_dir
51
+ from torch._inductor.utils import (
52
+ BoxedBool,
53
+ count_tangents,
54
+ fresh_inductor_cache,
55
+ InputType,
56
+ is_gpu,
57
+ should_assume_input_aligned,
58
+ tensor_is_aligned,
59
+ )
60
+ from torch._logging import trace_structured
61
+ from torch._ops import OpOverload
62
+ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter
63
+ from torch.fx.passes.fake_tensor_prop import FakeTensorProp
64
+ from torch.monitor import _WaitCounter
65
+
66
+ from .._dynamo.backends.common import aot_autograd
67
+ from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined]
68
+ from ..fx.graph import _PyTreeCodeGen
69
+ from . import config, metrics
70
+ from .debug import DebugContext
71
+ from .decomposition import select_decomp_table
72
+ from .fx_passes.joint_graph import joint_graph_passes
73
+ from .fx_passes.post_grad import post_grad_passes, view_to_reshape
74
+ from .fx_passes.pre_grad import pre_grad_passes
75
+ from .graph import GraphLowering
76
+ from .ir import ExternKernelNode
77
+ from .utils import (
78
+ align_inputs_from_check_idxs,
79
+ clone_preserve_strides,
80
+ copy_misaligned_inputs,
81
+ get_cloned_parameter_buffer_name,
82
+ has_incompatible_cudagraph_ops,
83
+ maybe_get_suppress_shape_guards_ctx,
84
+ output_node,
85
+ remove_unaligned_input_idxs,
86
+ shape_env_from_inputs,
87
+ )
88
+ from .virtualized import V
89
+
90
+
91
# Select the `time_and_log` decorator factory: the fbcode build imports it
# (together with `log_optimus_to_scuba`) from `torch._inductor.fb.utils`;
# every other build gets a stand-in defined right here.
if config.is_fbcode():
    from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
else:
    # no-op decorator
    def time_and_log(attr: str):
        # `attr` is ignored; `dynamo_utils.identity` is returned as the decorator.
        return dynamo_utils.identity


# Module-level logger, followed by artifact loggers registered under the
# "perf_hints", "post_grad_graphs" and "cudagraph_static_inputs" artifact names.
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)
105
+
106
+
107
+ # copy_ fails when trying to write to tensors with memory overlap,
108
+ # for expanded dimensions (a dimension which used to have size 1 -> ?)
109
+ # we can select one element from that dimension and write to it
110
+ # to achieve writing to all values of that dimension of the input tensor
111
def get_expanded_dims(t):
    """
    Return the indices of the dimensions of `t` that have stride 0 and a size
    other than 1, or None when `t` is not a `torch.Tensor`.
    """
    if isinstance(t, torch.Tensor):
        return [
            dim
            for dim, (size, stride) in enumerate(zip(t.size(), t.stride()))
            if stride == 0 and size != 1
        ]
    return None
115
+
116
+
117
def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
    """
    Slice `t` down to its first element (`[0:1]`) along each dimension listed
    in `expanded_dims`, applied in the given order, and return the result.
    """

    def _keep_first(tensor: torch.Tensor, dim: int) -> torch.Tensor:
        return torch.ops.aten.slice(tensor, dim, 0, 1)

    return functools.reduce(_keep_first, expanded_dims, t)
121
+
122
+
123
def complex_memory_overlap(t: torch.Tensor) -> bool:
    """
    Decide whether `t` really overlaps itself in memory.

    Expanded dimensions are first sliced to a single element and size-1
    dimensions are squeezed away (so they cannot cause a false positive).
    If `torch._debug_has_internal_overlap` still reports a possible overlap,
    the dimensions are walked in ascending stride order and an overlap is
    reported as soon as a stride is smaller than the extent (stride * size)
    of the dimension before it.
    """
    t = index_expanded_dims(t, get_expanded_dims(t)).squeeze()
    if torch._debug_has_internal_overlap(t) == 0:
        return False

    strides = t.stride()
    sizes = t.shape
    # Dimension indices ordered by (stride, index).
    dims_by_stride = [dim for _, dim in sorted(zip(strides, range(len(strides))))]

    # Extent covered by the previously visited dimension; 1 before the first.
    covered = 1
    for dim in dims_by_stride:
        if strides[dim] < covered:
            return True
        covered = strides[dim] * sizes[dim]
    return False
140
+
141
+
142
+ def get_static_input_idxs(num_fixed):
143
+ # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
144
+ # of cudagraphs. Rather than copying these into cudagraph-owned memory
145
+ # like we do for normal inputs on each run, we will re-record a cudagraph if these
146
+ # parameter locations change.
147
+ context = torch._guards.TracingContext.try_get()
148
+ fixed = list(range(num_fixed))
149
+ if not context or not context.fw_metadata:
150
+ return fixed
151
+
152
+ return fixed + context.fw_metadata.static_input_indices
153
+
154
+
155
@functools.lru_cache(None)
def _step_logger():
    """Return the step logger built from the module logger (memoized by lru_cache)."""
    step_logger = dynamo_logging.get_step_logger(log)
    return step_logger
158
+
159
+
160
+ @functools.lru_cache(None)
161
+ def _warn_tf32_disabled():
162
+ if (
163
+ torch.cuda.is_available()
164
+ and not torch.backends.cuda.matmul.allow_tf32
165
+ and torch.cuda.get_device_capability() >= (8, 0)
166
+ ):
167
+ warnings.warn(
168
+ "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
169
+ "Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
170
+ )
171
+
172
+
173
def _unlift_graph(mod, gm, graph_signature):
    """
    Assign the parameters and buffers of ``mod`` onto ``gm`` and call
    torch.export._unlift._unlift on it.

    Args:
        mod: module whose named parameters and buffers are collected.
        gm: graph module that receives the attributes and is unlifted.
        graph_signature: signature object providing the input/output mappings
            (inputs_to_parameters, inputs_to_buffers, user_inputs,
            buffers_to_mutate, user_inputs_to_mutate, output_tokens).

    Returns:
        The graph module returned by ``_unlift``.
    """
    from torch.export._unlift import _unlift
    from torch.export.unflatten import _assign_attr, _AttrKind

    # Collect every parameter and buffer (duplicates included) and attach each
    # one to gm under its qualified name.
    state_dict = {}
    named_tensor_groups = (
        (mod.named_parameters(remove_duplicate=False), _AttrKind.PARAMETER),
        (mod.named_buffers(remove_duplicate=False), _AttrKind.BUFFER),
    )
    for named_tensors, kind in named_tensor_groups:
        for name, tensor in named_tensors:
            state_dict[name] = tensor
            _assign_attr(
                tensor,
                gm,
                name,
                attr_kind=kind,
            )

    # In AOTI, module parameters and buffers are not lifted as graph inputs.
    # As a result, mutation to buffers has side effect which makes their
    # initial values different from Eager. So we clone them here as a copy.
    # We are not cloning for parameters, although it will be needed if we want
    # to support training.
    lifted_inputs = []
    for node in gm.graph.find_nodes(op="placeholder"):
        node_name = node.name
        if node_name in graph_signature.inputs_to_parameters:
            lifted_inputs.append(graph_signature.inputs_to_parameters[node_name])
        elif node_name in graph_signature.inputs_to_buffers:
            buffer_name = graph_signature.inputs_to_buffers[node_name]
            lifted_inputs.append(buffer_name)
            gm.meta[
                get_cloned_parameter_buffer_name(buffer_name)
            ] = clone_preserve_strides(state_dict[buffer_name])
        else:
            assert node_name in graph_signature.user_inputs
            lifted_inputs.append(None)

    # The last node's first argument holds the graph outputs.
    outputs = list(gm.graph.nodes)[-1].args[0]
    buffer_mutations = graph_signature.buffers_to_mutate
    user_input_mutations = graph_signature.user_inputs_to_mutate
    output_tokens = graph_signature.output_tokens
    num_mutation_slots = (
        len(buffer_mutations) + len(user_input_mutations) + len(output_tokens)
    )

    # For each output, record the mutated buffer / user input it corresponds
    # to, or None when it is not one of the leading mutation outputs.
    mutated_outputs = []
    for idx, out in enumerate(outputs):
        mutation_target = None
        if idx < num_mutation_slots:
            if out.name in buffer_mutations:
                mutation_target = buffer_mutations[out.name]
            elif out.name in user_input_mutations:
                mutation_target = user_input_mutations[out.name]
        mutated_outputs.append(mutation_target)

    return _unlift(
        gm,
        lifted_inputs,
        mutated_outputs,
        pytree.LeafSpec(),
        None,
        state_dict,
        {},
    )
245
+
246
+
247
+ def _get_subgraph_names(gm):
248
+ for node in sorted(
249
+ itertools.chain(
250
+ gm.graph.find_nodes(op="call_function", target=torch.ops.higher_order.cond),
251
+ gm.graph.find_nodes(
252
+ op="call_function", target=torch.ops.higher_order.while_loop
253
+ ),
254
+ )
255
+ ):
256
+ if node.target == torch.ops.higher_order.cond:
257
+ true_subgraph_name = node.args[1].name
258
+ false_subgraph_name = node.args[2].name
259
+ yield true_subgraph_name
260
+ yield false_subgraph_name
261
+ elif node.target == torch.ops.higher_order.while_loop:
262
+ cond_subgraph_name = node.args[0].name
263
+ body_subgraph_name = node.args[1].name
264
+ yield cond_subgraph_name
265
+ yield body_subgraph_name
266
+
267
+
268
def _recursive_pre_grad_passes(gm, example_inputs):
    """
    Apply pre_grad_passes to every subgraph of ``gm`` (recursively) and then
    to ``gm`` itself, returning the result for ``gm``.

    Each subgraph attribute on ``gm`` is replaced by its processed version.
    """
    for name in _get_subgraph_names(gm):
        # as we don't have recursive example inputs, passing None here
        processed = _recursive_pre_grad_passes(getattr(gm, name), example_inputs=None)
        setattr(gm, name, processed)
    return pre_grad_passes(gm, example_inputs)
275
+
276
+
277
def _recursive_joint_graph_passes(gm):
    """Run joint_graph_passes on every subgraph of ``gm`` (recursively), then on ``gm``."""
    for name in _get_subgraph_names(gm):
        _recursive_joint_graph_passes(getattr(gm, name))
    joint_graph_passes(gm)
282
+
283
+
284
def _recursive_post_grad_passes(gm, is_inference: bool = False):
    """
    Run post_grad_passes on every subgraph of ``gm`` (recursively), then on
    ``gm`` itself, forwarding ``is_inference`` unchanged at every level.
    """
    for name in _get_subgraph_names(gm):
        _recursive_post_grad_passes(getattr(gm, name), is_inference)
    post_grad_passes(gm, is_inference)
289
+
290
+
291
def split_const_gm(
    gm: torch.fx.GraphModule,
    lifted_constants: Optional[Dict[str, Any]] = None,
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> Tuple[torch.fx.GraphModule, Dict[str, int]]:
    """
    Split the GraphModule ``gm`` into two components:

    1) const_gm, which consists of the subgraph of gm that can be constant
       folded.
    2) gm (modified in place), which is the graph after constant folding.

    If the optional ``lifted_constants`` argument is passed in, we assume the
    gm has been lifted and run the transformation accordingly.

    When a ``skip_folding_node_fn`` callback is passed, constant folding is
    skipped on the nodes for which the callback returns True.

    const_output_index maps the name of each folded constant in gm to the
    corresponding output index of const_gm.

    Returns (const_gm, const_output_index)
    """
    from torch._inductor.constant_folding import (
        CONST_MODULE_TAG,
        META_TAG,
        MODULE_TAG,
        replace_node_with_constant,
        run_and_get_constant_graph,
    )

    const_gm, const_result = run_and_get_constant_graph(
        gm, lifted_constants, skip_folding_node_fn
    )

    # Name of each const_gm output -> its position in the output tuple.
    const_gm_outputs = tuple(const_gm.graph.nodes)[-1].args[0]
    const_outputs = {out.name: idx for idx, out in enumerate(const_gm_outputs)}

    # Classify gm's nodes: outputs of const_gm get replaced by constants,
    # other non-placeholder nodes tagged as constant are candidates for removal.
    to_replace_node = []
    to_erase_node = []
    for node in gm.graph.nodes:
        if node.name in const_outputs:
            to_replace_node.append(node)
        elif node.meta[META_TAG] == CONST_MODULE_TAG and node.op != "placeholder":
            to_erase_node.append(node)

    const_output_index = {}
    for node in to_replace_node:
        output_idx = const_outputs[node.name]
        new_const_name = "_FOLDED_CONST_" + node.name
        replace_node_with_constant(
            gm,
            node,
            const_result[output_idx],
            new_const_name,
        )
        const_output_index[new_const_name] = const_outputs[node.name]

    # Walk the removal candidates in reverse graph order; a node that still
    # has users is kept, and all of its users must carry MODULE_TAG.
    for node in reversed(to_erase_node):
        if not node.users:
            gm.graph.erase_node(node)
            continue
        for user in node.users:
            assert user.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty."
    gm.recompile()

    return const_gm, const_output_index
355
+
356
+
357
+ def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
358
+ aten = torch.ops.aten
359
+ tf32_ops = {
360
+ aten.mm.default,
361
+ aten.addmm.default,
362
+ aten.bmm.default,
363
+ aten.baddbmm.default,
364
+ }
365
+ for target in tf32_ops:
366
+ for node in gm.graph.find_nodes(op="call_function", target=target):
367
+ if (
368
+ isinstance(node.meta.get("val", None), torch.Tensor)
369
+ and node.meta["val"].dtype == torch.float32
370
+ and node.meta["val"].device.type == "cuda"
371
+ ):
372
+ return True
373
+ return False
374
+
375
+
376
def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]):
    """
    Return a context manager that turns comprehensive padding off when no
    example input is a GPU tensor, and a no-op context manager otherwise.

    For the CPU backend, enabling comprehensive padding makes some unit tests
    fail because the number of generated kernels changes. Skip for now.
    """
    has_gpu = False
    for inp in example_inputs:
        if isinstance(inp, torch.Tensor) and is_gpu(inp.device.type):
            has_gpu = True
            break

    padding_applies = config.disable_padding_cpu and config.comprehensive_padding
    if not padding_applies or has_gpu:
        return contextlib.nullcontext()

    perf_hint_log.info("Skip comprehensive padding on CPU")
    return config.patch(comprehensive_padding=False)
390
+
391
+
392
def fake_tensor_prop(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    force_allow_non_fake_inputs: bool = False,
):
    """
    Run FakeTensorProp over ``gm`` with ``example_inputs``.

    If no fake mode can be detected from the inputs, a new FakeTensorMode
    (with allow_non_fake_inputs=True) is created and used. Otherwise the
    detected mode is used without converting the inputs; when
    ``force_allow_non_fake_inputs`` is set, its ``allow_non_fake_inputs``
    attribute is patched to True for the duration of the propagation.

    Returns the fake mode that was used.
    """
    fake_mode = detect_fake_mode(example_inputs)
    if not fake_mode:
        fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
        FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
        return fake_mode

    if force_allow_non_fake_inputs:
        ctx = mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
    else:
        ctx = contextlib.nullcontext()
    with ctx:  # type: ignore[attr-defined]
        FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
            *example_inputs
        )
    return fake_mode
418
+
419
+
420
def should_use_remote_fx_graph_cache():
    """
    Decide whether the remote FX graph cache should be used.

    An explicit ``config.fx_graph_remote_cache`` value (anything but None)
    wins. Otherwise the answer is False outside fbcode, in fb unit tests, or
    when the fb remote-cache module cannot be imported; in the remaining case
    REMOTE_CACHE_VERSION is compared against a JustKnobs integer value.
    """
    explicit_setting = config.fx_graph_remote_cache
    if explicit_setting is not None:
        return explicit_setting

    if not config.is_fbcode() or torch._utils_internal.is_fb_unit_test():
        return False

    try:
        from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
    except ModuleNotFoundError:
        return False

    # A separate knob is consulted when torch.version.hip is set.
    jk_name = (
        "pytorch/remote_cache:fx_graph_memcache_version_amd"
        if torch.version.hip is not None
        else "pytorch/remote_cache:fx_graph_memcache_version"
    )
    return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(jk_name)
439
+
440
+
441
+ # pass config dict back to user
442
def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
    """Return a copy of the config taken while ``config_patches`` is applied."""
    with config.patch(config_patches):
        patched_copy = config.get_config_copy()
    return patched_copy
445
+
446
+
447
@contextlib.contextmanager
def with_fresh_cache_if_config():
    """
    Context manager that runs its body inside fresh_inductor_cache when
    ``config.force_disable_caches`` is set, and does nothing special otherwise.
    """
    if not config.force_disable_caches:
        yield
        return

    # Don't delete the cache dir because it has to survive beyond the
    # compile_fx call. Let's put the temp dirs under the default cache
    # dir so they're easier to locate.
    with fresh_inductor_cache(dir=cache_dir(), delete=False):
        yield
457
+
458
+
459
def compile_fx_inner(*args, **kwargs):
    """
    Call _compile_fx_inner (wrapped by wrap_compiler_debug) with the given
    arguments, inside the set of context managers entered below.
    """
    # Need with_fresh_cache_if_config for compile_fx_inner even if we already
    # have one for compile_fx. The reason is the compilation for backward graph
    # may happen after compile_fx return and we may want to use the
    # _LazyGraphModule for compiling the backward graph as well.
    with contextlib.ExitStack() as ctx_stack:
        ctx_stack.enter_context(torch.utils._python_dispatch._disable_current_modes())
        ctx_stack.enter_context(
            _use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
        )
        ctx_stack.enter_context(
            dynamo_utils.dynamo_timed(
                "compile_fx_inner", phase_name="inductor_compile", fwd_only=False
            )
        )
        ctx_stack.enter_context(with_fresh_cache_if_config())
        ctx_stack.enter_context(DebugContext())

        debug_wrapped = wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")
        return debug_wrapped(*args, **kwargs)
478
+
479
+
480
@time_and_log(attr="compilation time (in seconds)")
def _compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    static_input_idxs: Optional[List[int]] = None,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    """
    Inductor API that compiles a single graph.

    The graph is compiled through fx_codegen_and_compile, either directly or
    via FxGraphCache.load when the FX graph cache is enabled. In AOT mode the
    string produced by codegen is returned as-is.

    If you change the argument list for this function, make sure you
    also update the call to save_args_for_compile_fx_inner below accordingly.
    """
    graph_has_no_calls = dynamo_utils.count_calls(gm.graph) == 0
    if graph_has_no_calls and not aot_mode:
        # trigger the real recompilation for _LazyGraphModule before returning
        # the forward method.
        from torch.fx._lazy_graph_module import _LazyGraphModule

        _LazyGraphModule.force_recompile(gm)
        return make_boxed_func(gm.forward)

    if static_input_idxs is None:
        static_input_idxs = []

    static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)

    # The last node of the graph is expected to return a tuple or list.
    assert isinstance(
        next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
    ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"

    if config.save_args:
        save_args_for_compile_fx_inner(
            gm,
            example_inputs,
            cudagraphs=cudagraphs,
            static_input_idxs=static_input_idxs,
            is_backward=is_backward,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
            is_inference=is_inference,
            boxed_forward_device_index=boxed_forward_device_index,
            user_visible_outputs=user_visible_outputs,
            layout_opt=layout_opt,
        )

    if cudagraphs is None:
        cudagraphs = BoxedBool(config.triton.cudagraphs)

    # Inputs to fx_codegen_and_compile
    # Anything that affects codegen should go here, so if the signature
    # of fx_codegen_and_compile changes, the dict should be updated accordingly
    graph_kwargs = {
        "cudagraphs": cudagraphs,
        "static_input_idxs": static_input_idxs,
        "is_backward": is_backward,
        "graph_id": graph_id,
        "cpp_wrapper": cpp_wrapper,
        "aot_mode": aot_mode,
        "is_inference": is_inference,
        "user_visible_outputs": user_visible_outputs,
        "layout_opt": layout_opt,
        "extern_node_serializer": extern_node_serializer,
    }

    start = time.time()

    fx_graph_remote_cache = should_use_remote_fx_graph_cache()

    inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)  # type: ignore[arg-type]

    def codegen_and_compile(
        gm,
        example_inputs,
        inputs_to_check,
        fx_kwargs,
    ):
        """
        This function calls fx_codegen_and_compile and also adds some extra metadata to the resulting
        compiled fx graph. The metadata is saved to FXGraphCache.
        """
        compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
        if isinstance(compiled_graph, str):
            # We only return a string in aot mode, in which case we don't
            # need to do any post-compilation steps: we just return the string,
            # which is the filename of the compiled code.
            return compiled_graph

        cudagraph_info = None
        if cudagraphs:
            # check cudagraph disabling reasons from inductor lowering
            if compiled_graph.disabled_cudagraphs_reason:
                if "cuda" in compiled_graph.device_types:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
                    )
                else:
                    counters["inductor"]["cudagraph_skips"] += 1
                BoxedBool.disable(cudagraphs)
            else:
                complex_memory_overlap_inputs = any(
                    complex_memory_overlap(t)
                    for t in example_inputs
                    if isinstance(t, torch.Tensor)
                )

                if config.triton.cudagraph_support_input_mutation:
                    # Check mutation later to support cudagraph-managed tensors
                    has_mutation = None
                else:
                    # Skip supports for cudagraph-managed tensors
                    from torch._inductor.cudagraph_utils import (
                        check_for_mutation_ignore_cuda_graph_managed_tensor,
                    )

                    has_mutation_str = (
                        check_for_mutation_ignore_cuda_graph_managed_tensor(
                            gm,
                            compiled_graph,
                            static_input_idxs,  # type:ignore[arg-type]
                        )
                    )
                    has_mutation = has_mutation_str is not None

                    if has_mutation:
                        compiled_graph.disabled_cudagraphs_reason = has_mutation_str

                # Each entry is (check passed?, reason reported when it did not).
                cudagraph_tests = [
                    (not has_mutation, "mutated inputs"),
                    (not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
                    (not complex_memory_overlap_inputs, "complex memory overlap"),
                    (
                        all(
                            isinstance(t, (torch.Tensor, torch.SymInt))
                            for t in example_inputs
                        ),
                        "non-Tensor inputs",
                    ),
                ]
                output = output_node(gm)
                # output args are tuple of first argument
                assert len(output.args) == 1
                stack_traces = [
                    (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
                    for arg in output.args[0]
                ]
                cudagraph_fail_reasons = [
                    reason for passed, reason in cudagraph_tests if not passed
                ]
                placeholders = tuple(get_placeholder_info(gm.graph))
                cudagraph_info = CudagraphCachedInfo(
                    placeholders, stack_traces, cudagraph_fail_reasons
                )

        compiled_graph.cudagraph_info = cudagraph_info
        compiled_graph.inputs_to_check = inputs_to_check
        compiled_graph.fx_kwargs = fx_kwargs
        # TODO: should this be part of fx_kwargs
        compiled_graph.boxed_forward_device_index = boxed_forward_device_index
        return compiled_graph

    with _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _:
        use_fx_graph_cache = (
            not config.force_disable_caches
            and (config.fx_graph_cache or fx_graph_remote_cache)
            and not aot_mode
        )
        if use_fx_graph_cache:
            # Tag CUDA tensors that sit at a static input index before the
            # cache lookup.
            for idx, inp in enumerate(example_inputs):
                if (
                    isinstance(inp, torch.Tensor)
                    and inp.device.type == "cuda"
                    and idx in static_input_idxs
                ):
                    inp._is_inductor_static = True  # type: ignore[attr-defined]

            compiled_graph = FxGraphCache.load(
                codegen_and_compile,
                gm,
                example_inputs,
                graph_kwargs,
                inputs_to_check,
                local=config.fx_graph_cache,
                remote=fx_graph_remote_cache,
            )
        else:
            compiled_graph = codegen_and_compile(
                gm, example_inputs, inputs_to_check, graph_kwargs  # type: ignore[arg-type]
            )
            if aot_mode:
                # AOT mode is special because codegen_and_compile returns a string.
                # In that case, we don't need to run all post compilation steps, we just need
                # to return the string directly.
                return compiled_graph
            compiled_graph = FxGraphCache.post_compile(
                compiled_graph, example_inputs, cudagraphs
            )

    log.debug("FX codegen and compilation took %.3fs", time.time() - start)

    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )
    # aot autograd needs to know to pass in inputs as a list
    compiled_graph._boxed_call = True
    return compiled_graph
694
+
695
+
696
def fx_codegen_and_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs: Optional[BoxedBool] = None,
    static_input_idxs: Optional[List[int]] = None,
    is_backward: bool = False,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    is_inference: bool = False,
    # Use a dict with None value rather than a set for deterministic
    # iteration order just in case.
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:
    """
    Run the post-grad passes on ``gm``, lower it with GraphLowering and
    compile the result.

    Returns a CompiledFxGraph, or the value of ``graph.compile_to_fn()``
    directly when ``V.aot_compilation`` is True.
    """
    if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
        import time

        log.warning("Sleeping for %s since sleep_sec_TESTING_ONLY is set", sleep_sec)
        time.sleep(sleep_sec)

    with dynamo_utils.preserve_rng_state():
        if is_tf32_warning_applicable(gm):
            _warn_tf32_disabled()

        # Snapshot of the counters, used below to compute this compile's delta.
        inductor_counters = counters["inductor"].copy()

        # lift the maximum depth of the Python interpreter stack
        # to adapt large/deep models
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

        _step_logger()(
            logging.INFO,
            "torchinductor compiling "
            f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
            f"graph {graph_id}",
        )

        def log_graph_runnable():
            # Render the graph repro into an in-memory buffer and return its text.
            fd = io.StringIO()
            torch._dynamo.repro.after_aot.save_graph_repro(
                fd, gm, example_inputs, "inductor", save_dir=None
            )
            return fd.getvalue()

        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_graph_runnable",
                "encoding": "string",
            },
            payload_fn=lambda: log_graph_runnable(),
        )

        V.debug.fx_graph(gm, example_inputs)
        # TODO: Should we actually dump this?  It should be redundant with the aot
        # structured logs...
        # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))

        shape_env = shape_env_from_inputs(example_inputs)

        # Convert view to reshape in the graph. This is necessary primarily for
        # layout optimization. Do it unconditionally for uniformity.
        #
        # It's needed because when we do layout optimization, a contiguous tensor
        # in eager mode may become a channels last tensor. A view op that
        # previously could be applied to the contiguous tensor may no longer be
        # applicable to the channels last tensor. An error like
        #   RuntimeError: view size is not compatible with input tensor's size and stride
        #   (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
        # will be printed.
        #
        # Replace view op to reshape op in this case.
        # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this.
        #
        # Also this has to be done before FakeTensorProp below to avoid the failed
        # .view() call.
        view_to_reshape(gm)

        # It is safe to run FakeTensorProp under no_grad because by the time
        # we're in inductor, we assume that AOTAutograd has already "taken care"
        # of autograd, so there should be no more autograd-related API's in the
        # graph.
        with torch.no_grad():
            fake_mode = fake_tensor_prop(gm, example_inputs)

        # pattern matcher passes might not preserve striding information
        # on node.meta["val"]. if in the future we rely on these being
        # correct we will need to fix.

        with V.set_fake_mode(fake_mode):
            # has some issues with memory in training
            _recursive_post_grad_passes(gm, is_inference=is_inference)
            V.debug.fx_graph_transformed(gm, example_inputs)
            post_grad_graphs_log.debug(
                "%s",
                lazy_format_graph_code(
                    "AFTER POST GRAD",
                    gm,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
            trace_structured(
                "inductor_post_grad_graph",
                payload_fn=lambda: gm.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )
            if config.is_fbcode():
                log_optimus_to_scuba(
                    extra_logging={"pt2_configs": str(get_patched_config_dict())}
                )

        with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding(
            example_inputs
        ):
            const_output_index = None
            const_graph = None
            const_code = None

            if aot_mode and config.aot_inductor.use_runtime_constant_folding:
                # Split off the constant-foldable part and lower it on its own.
                const_gm, const_output_index = split_const_gm(gm)

                const_graph = GraphLowering(
                    const_gm,
                    example_inputs=[],
                    shape_env=shape_env,
                    graph_id=graph_id,
                    cpp_wrapper=cpp_wrapper,
                    aot_mode=aot_mode,
                    user_visible_outputs=user_visible_outputs,
                    extern_node_serializer=extern_node_serializer,
                    is_inference=is_inference,
                    is_const_graph=True,
                )
                with V.set_graph_handler(const_graph):
                    assert cpp_wrapper, "AOT mode only supports C++ wrapper"
                    const_graph.run()

                    const_code, _ = const_graph.codegen_with_cpp_wrapper()

            graph = GraphLowering(
                gm,
                # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
                # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
                # we currently use fake tensors and defake them later.
                example_inputs=example_inputs,
                shape_env=shape_env,
                graph_id=graph_id,
                cpp_wrapper=cpp_wrapper,
                aot_mode=aot_mode,
                user_visible_outputs=user_visible_outputs,
                extern_node_serializer=extern_node_serializer,
                is_inference=is_inference,
                const_output_index=const_output_index,
                const_code=const_code,
                const_module=const_graph,
            )
            metrics_helper = metrics.CachedMetricsHelper()
            with V.set_graph_handler(graph):
                graph.run(*example_inputs)
                output_strides: List[Optional[Tuple[_StrideExprStr, ...]]] = []
                if graph.graph_outputs is not None:
                    # We'll put the output strides in the compiled graph so we
                    # can later return them to the caller via TracingContext
                    printer = SymExprPrinter()
                    for out in graph.graph_outputs:
                        has_printable_strides = (
                            hasattr(out, "layout")
                            and len(free_unbacked_symbols(out.layout.stride)) == 0
                        )
                        if not has_printable_strides:
                            output_strides.append(None)
                            continue
                        # Convert to string for eval on the load path
                        output_strides.append(
                            tuple(printer.doprint(s) for s in out.layout.stride)
                        )

                _check_triton_bf16_support(graph)
                compiled_fn = graph.compile_to_fn()
                num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
                metrics.num_bytes_accessed += num_bytes
                metrics.node_runtimes += node_runtimes
                metrics.nodes_num_elem += nodes_num_elem

                if (
                    cudagraphs
                    and config.triton.cudagraph_skip_dynamic_graphs
                    and not V.graph.disable_cudagraphs_reason
                    and torch._inductor.utils.any_is_symbolic(*example_inputs)
                ):
                    # Look for a stack trace on a non-placeholder node whose
                    # meta "val" is a symbolic tensor, to include in the reason.
                    stack_trace = None
                    for node in gm.graph.nodes:
                        meta_val = node.meta.get("val", None)
                        if (
                            node.op == "placeholder"
                            or not isinstance(meta_val, torch.Tensor)
                            or not torch._inductor.utils.any_is_symbolic(meta_val)
                        ):
                            continue

                        if stack_trace := node.meta.get("stack_trace", None):
                            break
                    disable = "graph with symbolic shapes inputs and config.triton.cudagraph_skip_dynamic_graphs=True."
                    if stack_trace:
                        disable = f"{disable} Found from {stack_trace}\n"
                    else:
                        disable = f"{disable}\n"
                    V.graph.disable_cudagraphs_reason = disable

                if V.aot_compilation is True:
                    return compiled_fn

                if cudagraphs and not V.graph.disable_cudagraphs_reason:
                    from torch._inductor.cudagraph_utils import (
                        check_lowering_disable_cudagraph,
                    )

                    V.graph.disable_cudagraphs_reason = (
                        check_lowering_disable_cudagraph(V.graph.device_node_mapping)
                    )

                compiled_graph = CompiledFxGraph(
                    compiled_fn,
                    graph,
                    output_strides,
                    V.graph.disable_cudagraphs_reason,
                    metrics_helper.get_deltas(),
                    counters["inductor"] - inductor_counters,
                )

    return compiled_graph
931
+
932
+
933
def get_input_idxs_to_check(
    inputs: List[InputType],
    static_input_idxs: Sequence[int],
) -> Sequence[int]:
    """
    This function runs at compile time, and generates a list of indices for which we
    might need to do a copy to preserve alignment requirements.
    """

    def _can_skip_runtime_check(idx, tensor) -> bool:
        with maybe_get_suppress_shape_guards_ctx():
            # suppress guards so that tensor_is_aligned and should_assume_input_aligned
            # do not add guards on input's storage offset
            if idx in static_input_idxs and tensor_is_aligned(tensor):
                return True
            return not should_assume_input_aligned(tensor)

    ids_to_check = []
    for idx, inp in enumerate(inputs):
        # non-tensors don't need alignment
        if not isinstance(inp, torch.Tensor):
            continue
        # right now we only care for gpu tensors
        if not is_gpu(inp.device.type):
            continue
        if _can_skip_runtime_check(idx, inp):
            continue

        # if we get here, then
        # (a) our triton code assumes that the input is aligned
        # (b) we can't be sure ahead of time that the input will actually be aligned.
        # therefore, at runtime, we'll need to check that the input is aligned
        # (and if not, clone it to make it aligned.)
        ids_to_check.append(idx)

    return ids_to_check
966
+
967
+
968
def cudagraphify(
    model: Callable[..., Any],
    static_input_idxs: Sequence[int] = (),
    *,
    device_index: int,
    stack_traces: List[Optional[str]],
    is_backward: bool,
    is_inference: bool,
    constants: Tuple[torch.Tensor, ...] = (),
    placeholders: Sequence[PlaceholderInfo] = (),
    mutated_input_idxs: Tuple[int, ...] = (),
) -> Callable[..., Any]:
    """
    Return a callable that cudagraphifies ``model`` lazily.

    The cudagraphify implementation (the cudagraph-trees one when
    ``config.triton.cudagraph_trees`` is set, otherwise cudagraphify_impl) is
    invoked on the first call of the returned function, using that call's
    inputs; later calls reuse the resulting compiled function.
    """
    from torch._inductor.cudagraph_trees import (
        cudagraphify_impl as new_cudagraphify_impl,
    )

    cudagraphify_fn: Callable[..., Any]
    if not config.triton.cudagraph_trees:
        cudagraphify_fn = cudagraphify_impl
    else:
        cudagraphify_fn = functools.partial(
            new_cudagraphify_impl,
            device_index=device_index,
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=constants,
            placeholders=placeholders,
            mutated_input_idxs=mutated_input_idxs,
        )

    # Filled in on the first invocation of run().
    compiled_fn = None

    def run(new_inputs):
        nonlocal compiled_fn
        if compiled_fn is not None:
            return compiled_fn(new_inputs)
        with dynamo_utils.dynamo_timed(
            "cudagraphify"
        ), dynamo_utils.preserve_rng_state():
            compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
        return compiled_fn(new_inputs)

    return run
1011
+
1012
+
1013
def static_input(x: torch.Tensor) -> torch.Tensor:
    """
    Allocate a new tensor with the same size, strides, dtype and device as
    ``x``. The result comes from torch.empty_strided, so the values of ``x``
    are not copied here.
    """
    size, stride = x.size(), x.stride()
    return torch.empty_strided(size, stride, dtype=x.dtype, device=x.device)
1018
+
1019
+
1020
def index_expanded_dims_and_copy_(
    dst: torch.Tensor,
    src: torch.Tensor,
    expanded_dims: List[int],
):
    "Index into the expanded dimensions of both dst and src, then copy_ src into dst"
    dst_view = index_expanded_dims(dst, expanded_dims)
    src_view = index_expanded_dims(src, expanded_dims)
    dst_view.copy_(src_view)
1029
+
1030
+
1031
+ def cudagraphify_impl(
1032
+ model: Callable[..., Any],
1033
+ inputs: List[torch.Tensor],
1034
+ static_input_idxs: Sequence[int] = (),
1035
+ ):
1036
+ """
1037
+ Assumes inputs[static_input_idxs[i]] are always the same memory address
1038
+ """
1039
+ check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type]
1040
+ static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type]
1041
+ copy_misaligned_inputs(inputs, check_input_idxs) # type: ignore[arg-type]
1042
+
1043
+ assert isinstance(inputs, list)
1044
+
1045
+ inps_expanded_dims = [
1046
+ get_expanded_dims(x) if idx not in static_input_idxs else []
1047
+ for idx, x in enumerate(inputs)
1048
+ ]
1049
+
1050
+ # allocate static tensor inputs
1051
+ static_inputs = [
1052
+ x
1053
+ if not isinstance(x, torch.Tensor)
1054
+ else static_input(x)
1055
+ if idx not in static_input_idxs
1056
+ else x.detach()
1057
+ for idx, x in enumerate(inputs)
1058
+ ]
1059
+
1060
+ # copy over input values for fresh allocations
1061
+ for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
1062
+ if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
1063
+ index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)
1064
+
1065
+ # warmup
1066
+ torch.cuda.synchronize()
1067
+ stream = torch.cuda.Stream()
1068
+ stream.wait_stream(torch.cuda.current_stream())
1069
+ # copy static_inputs because it will be cleared in model
1070
+ with torch.cuda.stream(stream):
1071
+ model(list(static_inputs))
1072
+ stream.synchronize()
1073
+ torch.cuda.current_stream().wait_stream(stream)
1074
+ torch.cuda.synchronize()
1075
+
1076
+ # record
1077
+ graph = torch.cuda.CUDAGraph()
1078
+ with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
1079
+ static_outputs = model(list(static_inputs))
1080
+ if not isinstance(static_outputs, (list, tuple)):
1081
+ static_outputs = (static_outputs,)
1082
+
1083
+ if config.size_asserts:
1084
+
1085
+ def run(new_inputs):
1086
+ assert len(static_inputs) == len(new_inputs)
1087
+ for idx, (dst, src, expanded_dims) in enumerate(
1088
+ zip(static_inputs, new_inputs, inps_expanded_dims)
1089
+ ):
1090
+ if not isinstance(dst, torch.Tensor):
1091
+ pass
1092
+ elif idx in static_input_idxs:
1093
+ assert dst.data_ptr() == src.data_ptr()
1094
+ else:
1095
+ # TODO - could make one single op of multiple slices
1096
+ # and avoid dispatch.
1097
+ # Could also pre-index the `dst` tensors
1098
+ index_expanded_dims_and_copy_(dst, src, expanded_dims)
1099
+ new_inputs.clear()
1100
+ graph.replay()
1101
+ return static_outputs
1102
+
1103
+ else:
1104
+ copy_indices = [
1105
+ idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
1106
+ ]
1107
+
1108
+ def run(new_inputs):
1109
+ for idx in copy_indices:
1110
+ expanded_dims = inps_expanded_dims[idx]
1111
+ index_expanded_dims_and_copy_(
1112
+ static_inputs[idx], new_inputs[idx], expanded_dims
1113
+ )
1114
+ new_inputs.clear()
1115
+ graph.replay()
1116
+ return static_outputs
1117
+
1118
+ return align_inputs_from_check_idxs(run, check_input_idxs)
1119
+
1120
+
1121
+ def compile_fx_aot(
1122
+ model_: torch.fx.GraphModule,
1123
+ example_inputs_: List[torch.Tensor],
1124
+ inner_compile: Callable[..., Any] = compile_fx_inner,
1125
+ config_patches: Optional[Dict[str, Any]] = None,
1126
+ ):
1127
+ config_patches: Dict[str, Any] = (
1128
+ {"cpp_wrapper": True}
1129
+ if config_patches is None
1130
+ else {**config_patches, "cpp_wrapper": True}
1131
+ )
1132
+ if (
1133
+ "aot_inductor.output_path" not in config_patches
1134
+ and not config.aot_inductor.output_path
1135
+ ):
1136
+ config_patches = {
1137
+ **config_patches,
1138
+ "aot_inductor.output_path": code_hash(model_.code),
1139
+ }
1140
+
1141
+ extern_node_serializer = config_patches.pop("extern_node_serializer", None)
1142
+ with V.set_aot_compilation(True):
1143
+ compiled_lib_path = compile_fx(
1144
+ model_,
1145
+ example_inputs_,
1146
+ inner_compile=functools.partial(
1147
+ inner_compile,
1148
+ aot_mode=True,
1149
+ extern_node_serializer=extern_node_serializer,
1150
+ ),
1151
+ config_patches=config_patches,
1152
+ )
1153
+ assert os.path.exists(
1154
+ compiled_lib_path
1155
+ ), f"AOTInductor compiled library does not exist at {compiled_lib_path}"
1156
+ return compiled_lib_path
1157
+
1158
+
1159
+ _graph_counter = count(0)
1160
+
1161
+
1162
+ def fw_compiler_freezing(
1163
+ aot_autograd_model: torch.fx.GraphModule,
1164
+ aot_example_inputs: List[torch.Tensor],
1165
+ dynamo_model: torch.fx.GraphModule,
1166
+ num_example_inputs: int,
1167
+ inner_compile: Callable[..., Any],
1168
+ cudagraphs: BoxedBool,
1169
+ graph_id: int,
1170
+ forward_device: BoxedDeviceIndex,
1171
+ ):
1172
+ from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze
1173
+
1174
+ # partition_fn won't be called
1175
+ _recursive_joint_graph_passes(aot_autograd_model)
1176
+
1177
+ layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True)
1178
+ if layout_opt:
1179
+ # make sure meta['val'] is properly setup
1180
+ fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
1181
+ convert_conv_weights_to_channels_last(aot_autograd_model)
1182
+
1183
+ opt_model, preserved_arg_indices = freeze(
1184
+ dynamo_model,
1185
+ aot_autograd_model,
1186
+ aot_example_inputs, # type: ignore[arg-type]
1187
+ )
1188
+
1189
+ aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
1190
+ num_fixed = len(preserved_arg_indices) - num_example_inputs
1191
+
1192
+ fake_mode = detect_fake_mode(aot_example_inputs)
1193
+
1194
+ # for freezing, all graph outputs should be user visible
1195
+ *_, model_outputs_node = opt_model.graph.nodes
1196
+ model_outputs = model_outputs_node.args[0]
1197
+ user_visible_outputs = dict.fromkeys(
1198
+ n.name for n in model_outputs if isinstance(n, torch.fx.Node)
1199
+ )
1200
+
1201
+ static_input_idxs = list(range(num_fixed))
1202
+ # constant params will be real tensors, not fake
1203
+ tracing_context = torch._guards.TracingContext.try_get()
1204
+ if tracing_context is not None:
1205
+ params_flat = tracing_context.params_flat
1206
+ assert params_flat is not None
1207
+ for i in range(len(params_flat)):
1208
+ if i not in preserved_arg_indices:
1209
+ params_flat[i] = None
1210
+
1211
+ if tracing_context.fw_metadata:
1212
+ static_input_idxs += tracing_context.fw_metadata.static_input_indices
1213
+
1214
+ with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
1215
+ optimized_function = inner_compile(
1216
+ opt_model,
1217
+ aot_example_inputs,
1218
+ static_input_idxs=static_input_idxs,
1219
+ cudagraphs=cudagraphs,
1220
+ graph_id=graph_id,
1221
+ is_inference=True,
1222
+ boxed_forward_device_index=forward_device,
1223
+ layout_opt=layout_opt,
1224
+ user_visible_outputs=user_visible_outputs,
1225
+ )
1226
+
1227
+ # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
1228
+ # that drops constant-ified params
1229
+ if V.aot_compilation is True:
1230
+ return optimized_function
1231
+
1232
+ def wrapper(args):
1233
+ args_new = [args[i] for i in preserved_arg_indices]
1234
+ args.clear()
1235
+ return optimized_function(args_new)
1236
+
1237
+ wrapper._boxed_call = True # type: ignore[attr-defined]
1238
+
1239
+ return wrapper
1240
+
1241
+
1242
+ def compile_fx(
1243
+ model_: torch.fx.GraphModule,
1244
+ example_inputs_: List[torch.Tensor],
1245
+ inner_compile: Callable[..., Any] = compile_fx_inner,
1246
+ config_patches: Optional[Dict[str, Any]] = None,
1247
+ decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
1248
+ ):
1249
+ with _use_lazy_graph_module(dynamo_config.use_lazy_graph_module):
1250
+ """Main entrypoint to compile a given FX graph"""
1251
+ if config_patches:
1252
+ with config.patch(config_patches):
1253
+ return compile_fx(
1254
+ model_,
1255
+ example_inputs_,
1256
+ # need extra layer of patching as backwards is compiled out of scope
1257
+ inner_compile=config.patch(config_patches)(inner_compile),
1258
+ decompositions=decompositions,
1259
+ )
1260
+
1261
+ if config.cpp_wrapper:
1262
+ with config.patch(
1263
+ {
1264
+ "cpp_wrapper": False,
1265
+ # For triton.autotune_at_compile_time, disable by default for
1266
+ # FBCode, but enabled by default for OSS.
1267
+ "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
1268
+ if config.is_fbcode()
1269
+ else os.environ.get(
1270
+ "TORCHINDUCTOR_TRITON_AUTOTUNE_AT_COMPILE_TIME", "1"
1271
+ )
1272
+ == "1",
1273
+ "triton.autotune_cublasLt": False,
1274
+ "triton.cudagraphs": False,
1275
+ "triton.store_cubin": True,
1276
+ }
1277
+ ), V.set_real_inputs(example_inputs_):
1278
+ inputs_ = example_inputs_
1279
+ if isinstance(model_, torch.fx.GraphModule):
1280
+ fake_inputs = [
1281
+ node.meta.get("val")
1282
+ for node in model_.graph.nodes
1283
+ if node.op == "placeholder"
1284
+ ]
1285
+ if all(v is not None for v in fake_inputs):
1286
+ # Validate devices before switching to fake tensors.
1287
+ for idx, fi, i in zip(count(), fake_inputs, inputs_):
1288
+ if fi.device != i.device:
1289
+ raise ValueError(
1290
+ f"Device mismatch between fake input and example input at position #{idx}: "
1291
+ f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
1292
+ "make sure torch.export() and torch.aot_compile() run on the same device."
1293
+ )
1294
+ inputs_ = fake_inputs
1295
+ return compile_fx(
1296
+ model_,
1297
+ inputs_,
1298
+ inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
1299
+ decompositions=decompositions,
1300
+ )
1301
+
1302
+ recursive_compile_fx = functools.partial(
1303
+ compile_fx,
1304
+ inner_compile=inner_compile,
1305
+ decompositions=decompositions,
1306
+ )
1307
+
1308
+ if not graph_returns_tuple(model_):
1309
+ return make_graph_return_tuple(
1310
+ model_,
1311
+ example_inputs_,
1312
+ recursive_compile_fx,
1313
+ )
1314
+
1315
+ if isinstance(model_, torch.fx.GraphModule):
1316
+ if isinstance(model_.graph._codegen, _PyTreeCodeGen):
1317
+ # this graph is the result of dynamo.export()
1318
+ return handle_dynamo_export_graph(
1319
+ model_,
1320
+ example_inputs_,
1321
+ recursive_compile_fx,
1322
+ )
1323
+
1324
+ model_ = _recursive_pre_grad_passes(model_, example_inputs_)
1325
+
1326
+ if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
1327
+ return flatten_graph_inputs(
1328
+ model_,
1329
+ example_inputs_,
1330
+ recursive_compile_fx,
1331
+ )
1332
+
1333
+ assert not config._raise_error_for_testing
1334
+ num_example_inputs = len(example_inputs_)
1335
+ cudagraphs = BoxedBool(config.triton.cudagraphs)
1336
+ forward_device = BoxedDeviceIndex(None)
1337
+
1338
+ graph_id = next(_graph_counter)
1339
+
1340
+ decompositions = (
1341
+ decompositions if decompositions is not None else select_decomp_table()
1342
+ )
1343
+
1344
+ def fw_compiler_base(
1345
+ model: torch.fx.GraphModule,
1346
+ example_inputs: List[torch.Tensor],
1347
+ is_inference: bool,
1348
+ ):
1349
+ with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
1350
+ return _fw_compiler_base(model, example_inputs, is_inference)
1351
+
1352
+ def _fw_compiler_base(
1353
+ model: torch.fx.GraphModule,
1354
+ example_inputs: List[torch.Tensor],
1355
+ is_inference: bool,
1356
+ ):
1357
+ if is_inference:
1358
+ # partition_fn won't be called
1359
+ _recursive_joint_graph_passes(model)
1360
+
1361
+ fixed = torch._inductor.utils.num_fw_fixed_arguments(
1362
+ num_example_inputs, len(example_inputs)
1363
+ )
1364
+
1365
+ user_visible_outputs = {}
1366
+
1367
+ if config.keep_output_stride:
1368
+ model_outputs_node = output_node(model)
1369
+ model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
1370
+ num_model_outputs = len(model_outputs)
1371
+
1372
+ context = torch._guards.TracingContext.try_get()
1373
+ # See Note [User Outputs in the inductor graph]
1374
+ if context is not None and context.fw_metadata and not is_inference:
1375
+ original_output_start_index = (
1376
+ context.fw_metadata.num_mutated_inp_runtime_indices
1377
+ )
1378
+ else:
1379
+ original_output_start_index = 0
1380
+
1381
+ if isinstance(model_, torch.fx.GraphModule):
1382
+ *_, orig_model_outputs_node = model_.graph.nodes
1383
+ assert orig_model_outputs_node.op == "output"
1384
+ orig_model_outputs, _ = pytree.tree_flatten(
1385
+ orig_model_outputs_node.args
1386
+ )
1387
+ num_orig_model_outputs = len(orig_model_outputs)
1388
+ else:
1389
+ num_orig_model_outputs = num_model_outputs
1390
+
1391
+ assert num_orig_model_outputs <= num_model_outputs
1392
+
1393
+ # Note [User Outputs in the inductor graph]
1394
+ # We make the following assumptions:
1395
+ # For inference
1396
+ # len(orig_model_outputs) == len(model_outputs)
1397
+ # For training
1398
+ # len(orig_model_outputs) <= len(model_outputs)
1399
+ # During training, most of the time the model_outputs starts with
1400
+ # original module's outputs followed by saved activations.
1401
+ # But this may not hold if the model has in-place updated tensors.
1402
+ # AOTAutograd will return those tensors before the original
1403
+ # module's output.
1404
+ # To make things safe, we'll use original_output_start_index field
1405
+ # set by AOTAutograd to decide where the original module outputs start.
1406
+ orig_output_end_idx = (
1407
+ original_output_start_index + num_orig_model_outputs
1408
+ )
1409
+ # Sanity check: we are about to splice out the "user" outputs from the full set
1410
+ # of "graph" outputs. Make sure we're within bounds.
1411
+ assert orig_output_end_idx <= num_model_outputs
1412
+
1413
+ user_visible_outputs = dict.fromkeys(
1414
+ n.name
1415
+ for n in model_outputs[
1416
+ original_output_start_index:orig_output_end_idx
1417
+ ]
1418
+ if isinstance(n, torch.fx.Node)
1419
+ )
1420
+
1421
+ return inner_compile(
1422
+ model,
1423
+ example_inputs,
1424
+ static_input_idxs=get_static_input_idxs(fixed),
1425
+ cudagraphs=cudagraphs,
1426
+ graph_id=graph_id,
1427
+ is_inference=is_inference,
1428
+ boxed_forward_device_index=forward_device,
1429
+ user_visible_outputs=user_visible_outputs,
1430
+ )
1431
+
1432
+ fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
1433
+
1434
+ if config.freezing and not torch.is_grad_enabled():
1435
+ inference_compiler = functools.partial(
1436
+ fw_compiler_freezing,
1437
+ dynamo_model=model_,
1438
+ num_example_inputs=num_example_inputs,
1439
+ inner_compile=inner_compile,
1440
+ cudagraphs=cudagraphs,
1441
+ graph_id=graph_id,
1442
+ forward_device=forward_device,
1443
+ )
1444
+ else:
1445
+ inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
1446
+
1447
+ def partition_fn(graph, joint_inputs, **kwargs):
1448
+ _recursive_joint_graph_passes(graph)
1449
+ return min_cut_rematerialization_partition(
1450
+ graph, joint_inputs, **kwargs, compiler="inductor"
1451
+ )
1452
+
1453
+ def bw_compiler(
1454
+ model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
1455
+ ):
1456
+ with dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"):
1457
+ user_visible_outputs = {}
1458
+
1459
+ if config.bw_outputs_user_visible:
1460
+ model_outputs_node = output_node(model)
1461
+ model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
1462
+ user_visible_outputs = dict.fromkeys(
1463
+ n.name for n in model_outputs if isinstance(n, torch.fx.Node)
1464
+ )
1465
+ fixed = count_tangents(model)
1466
+ return inner_compile(
1467
+ model,
1468
+ example_inputs,
1469
+ static_input_idxs=list(range(fixed)),
1470
+ cudagraphs=cudagraphs,
1471
+ is_backward=True,
1472
+ graph_id=graph_id,
1473
+ boxed_forward_device_index=forward_device,
1474
+ user_visible_outputs=user_visible_outputs,
1475
+ )
1476
+
1477
+ # TODO: can add logging before/after the call to create_aot_dispatcher_function
1478
+ # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
1479
+ # once torchdynamo is merged into pytorch
1480
+
1481
+ fake_mode = detect_fake_mode(
1482
+ example_inputs_
1483
+ ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
1484
+ tracing_context = (
1485
+ torch._guards.TracingContext.try_get()
1486
+ or torch._guards.TracingContext(fake_mode)
1487
+ )
1488
+
1489
+ if V.aot_compilation is True:
1490
+ with functorch_config.patch(unlift_effect_tokens=True):
1491
+ gm, graph_signature = aot_export_module(
1492
+ model_,
1493
+ example_inputs_,
1494
+ trace_joint=False,
1495
+ decompositions=decompositions,
1496
+ )
1497
+ unlifted_gm = _unlift_graph(model_, gm, graph_signature)
1498
+ if "dynamo_flat_name_to_original_fqn" in model_.meta:
1499
+ unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
1500
+ "dynamo_flat_name_to_original_fqn"
1501
+ ]
1502
+
1503
+ # Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515)
1504
+ # In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into
1505
+ # _sfdp_init() to register patterns.
1506
+ # When fallback_random is set to True, the sdpa patterns will be traced during runtime.
1507
+ # If amp is turned on, the traced FP32 patterns will have prims.convert_element_type which
1508
+ # will be the same as the generated FP16 patterns.
1509
+ disable_amp = torch._C._is_any_autocast_enabled()
1510
+ context = (
1511
+ torch._C._DisableAutocast if disable_amp else contextlib.nullcontext
1512
+ )
1513
+ with V.set_fake_mode(fake_mode), compiled_autograd.disable(), context():
1514
+ return inference_compiler(unlifted_gm, example_inputs_)
1515
+
1516
+ with V.set_fake_mode(fake_mode), torch._guards.tracing(
1517
+ tracing_context
1518
+ ), compiled_autograd.disable(), functorch_config.patch(
1519
+ unlift_effect_tokens=True
1520
+ ):
1521
+ return aot_autograd(
1522
+ fw_compiler=fw_compiler,
1523
+ bw_compiler=bw_compiler,
1524
+ inference_compiler=inference_compiler,
1525
+ decompositions=decompositions,
1526
+ partition_fn=partition_fn,
1527
+ keep_inference_input_mutations=True,
1528
+ cudagraphs=cudagraphs,
1529
+ )(model_, example_inputs_)
1530
+
1531
+
1532
+ def graph_returns_tuple(gm: torch.fx.GraphModule):
1533
+ """True if a FX graph returns a tuple"""
1534
+ if not isinstance(gm, torch.fx.GraphModule):
1535
+ return True # can't check this, assume true
1536
+ (rv,) = output_node(gm).args
1537
+ if isinstance(rv, (list, tuple)):
1538
+ return True
1539
+ if (
1540
+ isinstance(rv, torch.fx.node.Node)
1541
+ and hasattr(rv.target, "_schema")
1542
+ and len(rv.target._schema.returns) > 1
1543
+ and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
1544
+ ):
1545
+ # for graphs whose result is one node with multiple outputs
1546
+ return True
1547
+ return False
1548
+
1549
+
1550
+ def make_graph_return_tuple(
1551
+ gm: torch.fx.GraphModule,
1552
+ inputs: List[torch.Tensor],
1553
+ compile_gm: Callable[..., Any],
1554
+ ):
1555
+ """
1556
+ Mutate gm so it returns a tuple. This is only needed for graphs
1557
+ not created by torchdynamo that return non-tuples.
1558
+ """
1559
+ node = output_node(gm)
1560
+ (rv,) = node.args
1561
+ rv, spec = pytree.tree_flatten(rv)
1562
+ with gm.graph.inserting_before(node):
1563
+ gm.graph.output(rv)
1564
+ gm.graph.erase_node(node)
1565
+ assert graph_returns_tuple(gm)
1566
+
1567
+ compiled_fn = compile_gm(gm, inputs)
1568
+
1569
+ @functools.wraps(compiled_fn)
1570
+ def wrapper(*args, **kwargs):
1571
+ return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)
1572
+
1573
+ return wrapper
1574
+
1575
+
1576
+ def handle_dynamo_export_graph(
1577
+ gm: torch.fx.GraphModule,
1578
+ inputs: List[torch.Tensor],
1579
+ compile_gm: Callable[..., Any],
1580
+ ):
1581
+ """
1582
+ `torch._dynamo.export` embeds pytrees in the FX graph codegen object,
1583
+ convert that to a normal FX graph so inductor can compile it.
1584
+ """
1585
+ codegen = gm.graph._codegen
1586
+ gm.graph._codegen = torch.fx.graph.CodeGen()
1587
+ gm.recompile()
1588
+
1589
+ compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs))
1590
+
1591
+ @functools.wraps(compiled_fn)
1592
+ def wrapper(*args):
1593
+ return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))
1594
+
1595
+ return wrapper
1596
+
1597
+
1598
+ def _check_triton_bf16_support(graph: GraphLowering) -> None:
1599
+ def warn_and_skip(device) -> None:
1600
+ from torch._dynamo.exc import SkipFrame
1601
+
1602
+ device_interface = get_interface_for_device(device.type)
1603
+ device_props = device_interface.get_device_properties(device)
1604
+ warnings.warn(
1605
+ f"{device_props.name} does not support bfloat16 compilation natively, skipping"
1606
+ )
1607
+ raise SkipFrame("BF16 is not supported")
1608
+
1609
+ for inp in graph.graph_inputs.values():
1610
+ device = getattr(inp, "get_device", lambda: torch.device("meta"))()
1611
+ if (not is_gpu(device.type)) or inp.get_dtype() != torch.bfloat16:
1612
+ continue
1613
+ # Print warning and skip frame if attempting to compile for bfloat16
1614
+ # on device without hardware support for dtype
1615
+ device_interface = get_interface_for_device(device.type)
1616
+ if device_interface.is_bf16_supported(including_emulation=False):
1617
+ return
1618
+ warn_and_skip(device)
1619
+
1620
+ for out in graph.graph_outputs:
1621
+ device = getattr(out, "get_device", lambda: torch.device("meta"))()
1622
+ if (not is_gpu(device.type)) or out.get_dtype() != torch.bfloat16:
1623
+ continue
1624
+ # Print warning and skip frame if attempting to compile for bfloat16
1625
+ # on device without hardware support for dtype
1626
+ device_interface = get_interface_for_device(device.type)
1627
+ if device_interface.is_bf16_supported(including_emulation=False):
1628
+ return
1629
+ warn_and_skip(device)
.venv/lib/python3.11/site-packages/torch/_inductor/config.py ADDED
@@ -0,0 +1,1241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os # noqa: C101
2
+ import sys
3
+ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
4
+
5
+ import torch
6
+
7
+
8
+ def is_fbcode() -> bool:
9
+ return not hasattr(torch.version, "git_version")
10
+
11
+
12
+ def fx_graph_remote_cache_default() -> Optional[bool]:
13
+ if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "1":
14
+ return True
15
+ if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "0":
16
+ return False
17
+ return None
18
+
19
+
20
+ def autotune_remote_cache_default() -> Optional[bool]:
21
+ if os.environ.get("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") == "1":
22
+ return True
23
+ if os.environ.get("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") == "0":
24
+ return False
25
+ return None
26
+
27
+
28
+ # Enable auto_functionalized_v2 (off unless TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2=1 is set)
29
+ enable_auto_functionalized_v2 = (
30
+ os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "0") == "1"
31
+ )
32
+
33
+ # add some debug printouts
34
+ debug = False
35
+
36
+ # Whether to disable a progress bar for autotuning
37
+ disable_progress = True
38
+
39
+ # Whether to enable printing the source code for each future
40
+ verbose_progress = False
41
+
42
+ # use fx aot graph codegen cache
43
+ fx_graph_cache = (
44
+ os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE", "0" if is_fbcode() else "1") == "1"
45
+ )
46
+
47
+ # use remote fx aot graph codegen cache
48
+ # False: Disables the cache
49
+ # True: Enables the cache
50
+ # None: Not set -- Off for OSS, JustKnobs based for internal
51
+ fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()
52
+
53
+ # enable autotune local cache
54
+ autotune_local_cache = True
55
+
56
+ # enable autotune remote cache
57
+ # False: Disables the cache
58
+ # True: Enables the cache
59
+ # None: Not set -- Off for OSS, JustKnobs based for internal
60
+ autotune_remote_cache: Optional[bool] = autotune_remote_cache_default()
61
+
62
+ # Force-disable all inductor-level caching -- this will override any other caching flag
63
+ force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1"
64
+
65
+ # sleep in inductor for testing
66
+ sleep_sec_TESTING_ONLY: Optional[int] = None
67
+
68
+ # The default layout constraint for custom operators.
69
+ # This must be the name of one of the layout constraint tags
70
+ # (that is, one of {"needs_fixed_stride_order", "flexible_layout"}),
71
+ # If the custom op does not have a layout constraint tag already
72
+ # then we assume the following applies.
73
+ custom_op_default_layout_constraint = "flexible_layout"
74
+
75
+ # use cpp wrapper instead of python wrapper
76
+ cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"
77
+
78
+ # codegen cpp wrapper code in an ABI compatible mode
79
+ abi_compatible = (
80
+ os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1"
81
+ )
82
+
83
+ c_shim_version = os.environ.get("TORCHINDUCTOR_C_SHIM_VERSION", "2")
84
+
85
+ # dead code elimination
86
+ dce = False
87
+
88
+ # assume weight tensors are fixed size
89
+ static_weight_shapes = True
90
+
91
+ # put correctness assertions in generated code
92
+ size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
93
+ nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
94
+
95
+ # enable loop reordering based on input orders
96
+ pick_loop_orders = True
97
+
98
+ # reuse a kernel input as the output
99
+ inplace_buffers = True
100
+
101
+ # reuse a buffer for an unrelated purpose
102
+ allow_buffer_reuse = True
103
+
104
+ # Enable pooled allocations for non-output tensors
105
+ memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1"
106
+
107
+ # How to organize memory under memory_planning=True:
108
+ # - "none": do not try to pool storage, just reuse
109
+ # - "intermediates": all non-outputs share storage, outputs each get unique storage
110
+ # - "outputs": two pools, one for intermediates (freed on return) and one for outputs
111
+ # - "combined": a single pool for both intermediates and outputs
112
+ memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates")
113
+
114
+ # codegen benchmark harness
115
+ benchmark_harness = True
116
+
117
+ # fuse pointwise into templates
118
+ epilogue_fusion = True
119
+
120
+ # do epilogue fusions before other fusions
121
+ epilogue_fusion_first = False
122
+
123
+ # enable pattern match+replace optimizations
124
+ pattern_matcher = True
125
+
126
+ # set to True to enable the back-to-back GEMM pass
127
+ b2b_gemm_pass = False
128
+
129
+ # register custom graph optimization pass hook. so far, pre/post passes are
130
+ # only applied before/after pattern_matcher in post_grad_passes.
131
+ #
132
+ # def my_custom_pre_pass(graph: torch.fx.graph.Graph):
133
+ # # my custom graph optimization pass
134
+ # ...
135
+ #
136
+ # def my_custom_post_pass(graph: torch.fx.graph.Graph):
137
+ # # my custom graph optimization pass
138
+ # ...
139
+ #
140
+ # torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass
141
+ # torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass
142
+ post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
143
+ post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
144
+
145
+ # Registers a custom joint graph pass.
146
+ joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None
147
+ joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None
148
+
149
+ # Registers a custom pregrad pass. Note that the pre-grad IR is 1.
150
+ # non-functional, 2. non-normalized, and 3. prone to change. Ideally we should
151
+ # use post-grad passes.
152
+ pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
153
+
154
+ # Registers a custom pass to be run right before fusion in Inductor scheduler.
155
+ # WARNING: Inductor scheduler IR is at prototype stage and subject to change,
156
+ # hence custom IR passes built on top of it might break in the future.
157
+ _pre_fusion_custom_pass: Optional[
158
+ Callable[
159
+ [List["torch._inductor.scheduler.BaseSchedulerNode"]],
160
+ List["torch._inductor.scheduler.BaseSchedulerNode"],
161
+ ]
162
+ ] = None
163
+
164
+ # Deprecated
165
+ split_cat_fx_passes = True
166
+
167
+ # Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability.
168
+ efficient_conv_bn_eval_fx_passes = False
169
+
170
+ # Enable predispatch aten IR for export
171
+ is_predispatch = False
172
+
173
+ # Deprecated
174
+ group_fusion = False
175
+
176
+ # Deprecated
177
+ batch_fusion = True
178
+
179
+ # Pre grad fusion and options in order, set to empty dict to disable fusion.
180
+ # Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions.
181
+ # batch fusion options:
182
+ # batch_linear
183
+ # batch_linear_lhs
184
+ # batch_layernorm
185
+ # batch_tanh
186
+ # batch_relu
187
+ # batch_sigmoid
188
+
189
+ # split cat fusion options:
190
+ # normalization_pass
191
+ # remove_split_with_size_one_pass
192
+ # merge_getitem_cat_pass
193
+ # merge_stack_tahn_unbind
194
+ # merge_splits_pass
195
+ # mutate_cat_pass
196
+ # split_cat_pass
197
+ pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {
198
+ "batch_linear": {},
199
+ "batch_linear_lhs": {},
200
+ "batch_layernorm": {},
201
+ "batch_tanh": {},
202
+ "batch_relu": {},
203
+ "batch_sigmoid": {},
204
+ }
205
+
206
+ # Post grad fusion and options, set to empty dict to disable fusion.
207
+ # Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
208
+ post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
209
+
210
+ # enable reordering pass for improving memory locality
211
+ reorder_for_locality = True
212
+
213
+ # Scale down RBLOCK for better occupancy
214
+ dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1"
215
+
216
+ # this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32
217
+ # but the mul gets fused with other pointwise ops instead.
218
+ force_fuse_int_mm_with_mul = False
219
+
220
+ # for pattern torch.mm(a, b.to(dtype)) with cuda tensors,
221
+ # enable torch._inductor.kernel.mm.tuned_mixed_mm fused kernel.
222
+ # Autotune will compare perf with normal cast->then->mm option
223
+ use_mixed_mm = True
224
+
225
+ # enable runtime numeric check for pre/post grad fx passes
226
+ # floating point provides limited accuracy (about 7 decimal digits for single precision
227
+ # floating point numbers, about 16 decimal digits for double precision floating point numbers)
228
+ # according to PyTorch documentation.
229
+ # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
230
+ fx_passes_numeric_check: Dict[str, Any] = {
231
+ "pre_grad": False,
232
+ "precision": 1e-4,
233
+ "num_iterations": 1,
234
+ "requires_optimizer": True,
235
+ }
236
+
237
+ # mixed_mm_choice can be used to control the behaviour for pattern torch.mm(a, b.to(dtype)) with cuda tensors.
238
+ # The fallback aten implementation is normal cast->then->mm option.
239
+ # If mixed_mm_choice is "default": this flag will be ignored.
240
+ # If mixed_mm_choice is "triton":
241
+ # - Always use torch._inductor.kernel.mm.tuned_mixed_mm's fused kernel.
242
+ # - Autotune will not compare with fallback.
243
+ # If mixed_mm_choice is "aten": always use the fallback aten implementation.
244
+ # If mixed_mm_choice is "heuristic":
245
+ # - Enables the heuristic.
246
+ # - If the heuristic decides to add a config, it will add the config as the first choice.
247
+ # - If autotune is disabled, this config will always be chosen.
248
+ # - If autotune is enabled, it will also compare with fallback aten implementation and fused kernel.
249
+ # The use_mixed_mm flag will be ignored if mixed_mm_choice != "default".
250
+ mixed_mm_choice = "heuristic"
251
+
252
+ # enable reordering pass for increasing overlap between compute and communication
253
+ reorder_for_compute_comm_overlap = False
254
+
255
+ # passes (in execution order) for increasing overlap between compute and communication
256
+ # for built-in passes, use string name; for user-defined passes, pass in the function handle
257
+ # WARNING: Inductor scheduler IR is at prototype stage and subject to change,
258
+ # hence custom IR passes built on top of it might break in the future.
259
+ reorder_for_compute_comm_overlap_passes = [
260
+ "reorder_compute_for_overlap",
261
+ "sink_waits",
262
+ "raise_comms",
263
+ ]
264
+
265
+ # runtime estimation function for ops
266
+ # for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
267
+ estimate_op_runtime = "default"
268
+
269
+ # unit: GB/s, uni-directional P2P bandwidth per card
270
+ # default value is NVLink
271
+ intra_node_bw = 300
272
+
273
+ # unit: GB/s, uni-directional P2P bandwidth per node
274
+ # default value is InfiniBand
275
+ inter_node_bw = 25
276
+
277
+ # enable slow autotuning passes to select algorithms
278
+ max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
279
+
280
+ # enable slow autotuning passes to select pointwise/reductions algorithms
281
+ max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"
282
+
283
+ # enable slow autotuning passes to select gemm algorithms
284
+ max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"
285
+
286
+ # force cublas and triton to use the same precision; cublas supports TF32 for matmul operations
287
+ # when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations
288
+ # for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure
289
+ # that triton does not use TF32 wherever cublas would not use TF32
290
+ force_same_precision = (
291
+ True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1"
292
+ )
293
+
294
+ # Specify candidate backends for gemm autotune.
295
+ # Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CPP.
296
+ # ATen: default Pytorch ATen kernels.
297
+ # Triton: Triton templates defined in torch inductor (AMD and NVidia GPUs).
298
+ # CUTLASS: Cutlass templates and kernels (NVidia GPUs only).
299
+ # CK: Composable Kernel templates and kernels (AMD Instinct GPUs only).
300
+ # CPP: CPP templates and kernels for CPU.
301
+ max_autotune_gemm_backends = os.environ.get(
302
+ "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
303
+ ).upper()
304
+
305
+ # As above, specify candidate backends for conv autotune.
306
+ # NB: in some cases for 1x1 convs we emit as matmul,
307
+ # which will use the backends of `max_autotune_gemm_backends`
308
+ max_autotune_conv_backends = os.environ.get(
309
+ "TORCHINDUCTOR_MAX_AUTOTUNE_CONV_BACKENDS", "ATEN,TRITON"
310
+ ).upper()
311
+
312
+
313
+ # Specify the size of the search space for GEMM autotuning.
314
+ # DEFAULT - balance between compile time overhead and performance
315
+ # EXHAUSTIVE - maximize performance
316
+ max_autotune_gemm_search_space = os.environ.get(
317
+ "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
318
+ ).upper()
319
+
320
+ # Whether we fall back to ATen or hard error when no matches are found during autotuning
321
+ autotune_fallback_to_aten = (
322
+ os.environ.get("TORCHINDUCTOR_AUTOTUNE_FALLBACK_TO_ATEN", "1") == "1"
323
+ )
324
+
325
+ # the value used as a fallback for the unbacked SymInts
326
+ # that can appear in the input shapes (e.g., in autotuning)
327
+ unbacked_symint_fallback = 8192
328
+
329
+ # DEPRECATED, DO NOT USE
330
+ search_autotune_cache = False
331
+
332
+ save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"
333
+
334
+ # We will disable creating subprocess for autotuning if this is False
335
+ autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
336
+
337
+ # The following three timeouts are applicable if autotune_in_subproc is True:
338
+
339
+ # Max time that a valid benchmark result may take during autotuning
340
+ max_autotune_subproc_result_timeout_seconds = 60.0
341
+ # Additional time we allow subprocesses to terminate gracefully after the timeout until we send a SIGTERM
342
+ max_autotune_subproc_graceful_timeout_seconds = 1.0
343
+ # Additional time that we grant after a SIGTERM until we do a hard SIGKILL of subprocesses
344
+ max_autotune_subproc_terminate_timeout_seconds = 2.0
345
+
346
+ # If autotuning in subprocess, whether to use multiple devices
347
+ autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"
348
+
349
+ coordinate_descent_tuning = (
350
+ os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
351
+ )
352
+ coordinate_descent_check_all_directions = (
353
+ os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1"
354
+ )
355
+ coordinate_descent_search_radius = int(
356
+ os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1")
357
+ )
358
+
359
+ # AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and
360
+ # generate the learned heuristic to code which is shipped with the compiler
361
+ # Specify a list of comma separated optimizations to collect data for
362
+ autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "")
363
+ # Specify a list of comma separated optimizations to use learned heuristics for
364
+ autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm")
365
+
366
+
367
+ def run_autoheuristic(name: str) -> bool:
368
+ return collect_autoheuristic(name) or use_autoheuristic(name)
369
+
370
+
371
+ def collect_autoheuristic(name: str) -> bool:
372
+ return name in torch._inductor.config.autoheuristic_collect.split(",")
373
+
374
+
375
+ def use_autoheuristic(name: str) -> bool:
376
+ return name in torch._inductor.config.autoheuristic_use.split(",")
377
+
378
+
379
+ # If set to "DEFAULT", this will use the default log path specified in autoheuristic.py.
380
+ # If set to another path, autoheuristic will instead log results to the given path.
381
+ autoheuristic_log_path = os.environ.get(
382
+ "TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH", "DEFAULT"
383
+ )
384
+
385
+ # Disabled by default on ROCm, opt-in if model utilises NHWC convolutions
386
+ layout_opt_default = "1" if not torch.version.hip else "0"
387
+ layout_optimization = (
388
+ os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1"
389
+ )
390
+
391
+ force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1"
392
+
393
+
394
+ # Whether to keep the output strides the same as eager after layout optimization.
395
+ keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"
396
+
397
+ # Enabling this will let compiler print warning messages if a generated triton
398
+ # kernel has inputs with mixed layouts. This is helpful for perf debugging
399
+ # since kernel with mixed layout inputs may run much slower than one whose inputs
400
+ # have uniform layouts.
401
+ warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"
402
+
403
+ # control store vs recompute heuristic
404
+ # For fanouts, rematerialization can lead to exponential blowup. So, have
405
+ # smaller threshold
406
+ realize_reads_threshold = 4
407
+ realize_opcount_threshold = 30
408
+
409
+ # Threshold to prevent excessive accumulation of ops in one buffer during lowering
410
+ realize_acc_reads_threshold = 8
411
+
412
+ # fallback to eager for random/dropout, this is slow but useful for debugging
413
+ fallback_random = False
414
+
415
+ # automatically create fallbacks when encountering an unhandled op
416
+ implicit_fallbacks = True
417
+
418
+ # fuse even in cases without common reads
419
+ aggressive_fusion = False
420
+
421
+ # For each fused kernel in the wrapper, comment with the nodes that get fused.
422
+ # Useful for debugging fusion.
423
+ debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
424
+ benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
425
+ enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")
426
+ loop_ordering_after_fusion = (
427
+ os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1"
428
+ )
429
+
430
+ # For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel
431
+ benchmark_epilogue_fusion = (
432
+ os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1"
433
+ )
434
+
435
+ # Take how many of the top triton kernels to benchmark epilogue
436
+ max_epilogue_benchmarked_choices = 1
437
+
438
+ # how many nodes to allow into a single fusion
439
+ max_fusion_size = 64
440
+
441
+ # max number of inputs to generate cat as a pointwise op with masked loads
442
+ max_pointwise_cat_inputs = 8
443
+
444
+ # replace small reductions with pointwise, disable with `= 1`
445
+ unroll_reductions_threshold = 8
446
+
447
+ # Add extra comments to output code (causes compile cache misses)
448
+ comment_origin = False
449
+
450
+ # Convert 1x1 convs into matmuls
451
+ conv_1x1_as_mm = False
452
+
453
+ # Enable split reductions for better utilization when the dimension
454
+ # being reduced over is large (by splitting it)
455
+ split_reductions = True
456
+
457
+ benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"
458
+
459
+ # Enable constant and index_expr folding
460
+ constant_and_index_propagation = True
461
+
462
+ # we always add constants into graph.constants without
463
+ # performing any constant-inlining optimization
464
+ always_keep_tensor_constants = False
465
+
466
+ # assert that indirect indexing does not read / write out of bounds
467
+ assert_indirect_indexing = True
468
+
469
+ # compute CSE bounds on variables that do not appear in the FX graph
470
+ compute_all_bounds = False
471
+
472
+ # enable the combo kernel that combines data-independent kernels (additional
473
+ # to foreach kernels) into a single one (Experimental)
474
+ combo_kernels = False
475
+ # benchmark combo kernels and only allow ones with perf gains
476
+ benchmark_combo_kernel = False
477
+ # combo_kernel autotuning options: 0 - disable, 1 - enable except for foreach,
478
+ # 2 - enable for all
479
+ combo_kernels_autotune = 1
480
+ # Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable
481
+ # for all except for foreach, 2 - enable for all
482
+ combo_kernel_allow_mixed_sizes = 1
483
+ # Enable dynamic shapes for foreach kernels
484
+ combo_kernel_foreach_dynamic_shapes = False
485
+
486
+ # constant folding on the joint graph
487
+ joint_graph_constant_folding = True
488
+
489
+ # Enable indirect_indexing asserts for decompositions and lowerings
490
+ debug_index_asserts = False
491
+
492
+ # Mode to emulate pytorch eager numerics for lower precision (fp16, bf16)
493
+ # Pytorch eager computes bf16/fp16 by upcasting inputs to fp32 and downcasting after
494
+ # For multiple, fused pointwise nodes, inductor will elide the intermediary upcasts and downcasts
495
+ # Typically this should be closer to fp64 ref numerics. However, it can be useful for debugging
496
+ # to emulate the eager numerics.
497
+ emulate_precision_casts = False
498
+
499
+ # warnings intended for PyTorch developers, disable for point releases
500
+ is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
501
+ developer_warnings = is_fbcode() or is_nightly_or_source
502
+
503
+ # This pattern matches a special usage of scatter
504
+ # 1. It's applied to a constant tensor
505
+ # 2. The index tensor has size 1 in the scatter dimension
506
+ # Such pattern generates a sparse matrix when the const tensor is all-zero.
507
+ # We can lower this pattern to a pointwise kernel for more fusion opportunities
508
+ # and saving memory footprint.
509
+ optimize_scatter_upon_const_tensor = (
510
+ os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1"
511
+ )
512
+
513
+
514
+ # The multiprocessing start method to use for inductor workers in the codecache.
515
+ # Can be "subprocess" or "fork".
516
+ def decide_worker_start_method() -> str:
517
+ start_method = os.environ.get(
518
+ "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess"
519
+ )
520
+ assert start_method in (
521
+ "subprocess",
522
+ "fork",
523
+ ), f"Invalid start method: {start_method}"
524
+ return start_method
525
+
526
+
527
+ worker_start_method = decide_worker_start_method()
528
+
529
+ # Flags to turn on all_reduce fusion. These 2 flags should be automatically turned
530
+ # on by DDP and should not be set by the users.
531
+ _fuse_ddp_communication = False
532
+ _fuse_ddp_bucket_size = 25
533
+
534
+ # Flag to control which fusion passes to apply. Functions in the list will
535
+ # be applied in order. There are two different fusion passes
536
+ # --"fuse_ddp_with_concat_op" and "fuse_ddp_with_coalesced_op". The default
537
+ # one is "fuse_ddp_with_concat_op". Users can also change this to a customized
538
+ # fusion function.
539
+ #
540
+ # The fusion currently does not support multiple DDP with different PG or
541
+ # data type. This feature will be added in the future PRs.
542
+ #
543
+ # "schedule_comm_wait" is used to delay the wait ops to maximize comm/comp
544
+ # overlapping. At this moment, this pass performs better than
545
+ # reorder_for_compute_comm_overlap_passes but we will add the logic of
546
+ # "schedule_comm_wait" in the future and remove the one here.
547
+ _fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [
548
+ "fuse_ddp_with_concat_op",
549
+ "schedule_comm_wait",
550
+ ]
551
+
552
+ _micro_pipeline_tp: bool = False
553
+
554
+
555
+ def decide_compile_threads() -> int:
556
+ """
557
+ Here is the precedence used to decide compile_threads
558
+ 1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by
559
+ setting this to 1 to make pdb happy.
560
+ 2. Set to 1 if it's win32 platform
561
+ 3. decide by the number of CPU cores
562
+ """
563
+ if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
564
+ return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
565
+ elif sys.platform == "win32":
566
+ return 1
567
+ elif is_fbcode():
568
+ return 1
569
+ else:
570
+ cpu_count = (
571
+ len(os.sched_getaffinity(0))
572
+ if hasattr(os, "sched_getaffinity")
573
+ else os.cpu_count()
574
+ )
575
+ assert cpu_count
576
+ return min(32, cpu_count)
577
+
578
+
579
+ compile_threads = decide_compile_threads()
580
+
581
+ # gemm autotuning global cache dir
582
+ if is_fbcode():
583
+ try:
584
+ from libfb.py import parutil
585
+
586
+ if __package__:
587
+ global_cache_dir = parutil.get_dir_path(
588
+ os.path.join(__package__.replace(".", os.sep), "fb/cache")
589
+ )
590
+ else:
591
+ global_cache_dir = parutil.get_dir_path("fb/cache")
592
+ except (ValueError, ModuleNotFoundError):
593
+ global_cache_dir = None
594
+
595
+ else:
596
+ global_cache_dir = None
597
+
598
+ # If kernel is fused, the name is generated from the origin node op names
599
+ # for larger kernels limit this
600
+ kernel_name_max_ops = 10
601
+
602
+ # Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
603
+ shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"
604
+
605
+ # Control if we will do padding for pointwise/reductions
606
+ comprehensive_padding = (
607
+ os.environ.get("TORCHINDUCTOR_COMPREHENSIVE_PADDING", "1") == "1"
608
+ )
609
+ pad_channels_last = False
610
+
611
+ # Disable comprehensive padding on the CPU
612
+ disable_padding_cpu = True
613
+
614
+ # The width of comprehensive padding, in bytes.
615
+ # CUDA max memory transaction size is 128 bytes for a warp.
616
+ padding_alignment_bytes = 128
617
+
618
+ # Threshold on the minimum stride that will be padded.
619
+ #
620
+ # Don't align a too small stride since that causes too much memory increase.
621
+ # Padding a too small stride may also cause perf loss. We may end up with many tiny data blocks
622
+ # with gaps in between. That causes less coalesced GPU memory access!
623
+ #
624
+ # Initially we pick 320 as the threshold since for alignment=16,
625
+ # that results in at most 5% memory cost.
626
+ #
627
+ # But later on we raise the threshold to 1024 to avoid interfering with persistent reduction.
628
+ # Let's say an inner reduction has a row size 513. Inductor will generate
629
+ # persistent reduction code.
630
+ # If we do padding, the strides are not contiguous any more. Inductor
631
+ # uses a much smaller threshold for persistent reduction in this case and
632
+ # generates potentially worse non-persistent reduction code.
633
+ #
634
+ # This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x.
635
+ # (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms)
636
+ padding_stride_threshold = 1024
637
+
638
+ # Enable padding outputs, even if they would not be padded in eager mode.
639
+ # By default, we use the same strides as eager mode.
640
+ pad_outputs = False
641
+
642
+ # Whether to treat output of the backward graph as user visible.
643
+ # For user visible outputs, inductor will make sure the stride matches with eager.
644
+ bw_outputs_user_visible = True
645
+
646
+ # Whether to always use shape padding if it is enabled and possible
647
+ force_shape_pad: bool = False
648
+
649
+ # Fx-based linear/matmul/bmm + permute/transpose vertical fusion
650
+ permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"
651
+
652
+ # Mark the wrapper call in PyTorch profiler
653
+ profiler_mark_wrapper_call = False
654
+
655
+ # Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for
656
+ # every intermediate for which we can correlate it with an intermediate
657
+ # from the original FX graph
658
+ generate_intermediate_hooks = False
659
+
660
+ # Populate traceback field on IRNode; good for debugging why origin_node is
661
+ # not populated, or finding out where an IRNode was constructed
662
+ debug_ir_traceback = False
663
+
664
+ # used for debugging to make sure config is properly set
665
+ _raise_error_for_testing = False
666
+
667
+ _profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "")
668
+ profile_bandwidth = _profile_var != ""
669
+ profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var
670
+ # Specify a file where we print out the profiling results.
671
+ # None means we do not dump results to a file.
672
+ profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None)
673
+ # Switch to do_bench_using_profiling to exclude the CPU overheads
674
+ profile_bandwidth_with_do_bench_using_profiling = (
675
+ os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1"
676
+ )
677
+
678
+
679
+ # TODO: remove later
680
+ disable_cpp_codegen = False
681
+
682
+
683
+ # Freezing will attempt to inline weights as constants in optimization
684
+ # and run constant folding and other optimizations on them. After freezing, weights
685
+ # can no longer be updated.
686
+ freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"
687
+
688
+ # Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
689
+ # of potentially keeping multiple copies of weights.
690
+ freezing_discard_parameters: bool = False
691
+
692
+ # Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
693
+ # should be run with this flag both on and off to make sure we have coverage.
694
+ allow_stack_allocation: bool = (
695
+ os.environ.get("TORCHINDUCTOR_STACK_ALLOCATION", "1" if is_fbcode() else "0") == "1"
696
+ )
697
+
698
+ # Enables an alternate DSO interface (the "minimal ArrayRef interface") intended
699
+ # to maximize performance for use cases that it can accommodate at the expense of
700
+ # generality. In brief:
701
+ # - inputs and outputs are ArrayRefTensor<T> (note that strides are required, but the
702
+ # tensor must be contiguous)
703
+ # - constant handling is unchanged because it is not a per-inference-iteration bottleneck
704
+ #
705
+ # When the DSO is generated in this mode, the usual interface will also be supported,
706
+ # but performance for that interface may be degraded.
707
+ use_minimal_arrayref_interface: bool = False
708
+
709
+ # decompose some memory bound matmul/bmm to mul
710
+ decompose_mem_bound_mm: bool = False
711
+
712
+ # assume_aligned_inputs means that we assume that inputs will be aligned; we generate
713
+ # code using this assumption, and clone tensors before use if they aren't aligned.
714
+ # In the common case, most inputs will be aligned.
715
+ assume_aligned_inputs: bool = False
716
+
717
+ # For the user-written Triton kernels compiled with the model, ignore the unsupported
718
+ # arguments passed to the @triton.autotune in the user's code; this is unsafe, as
719
+ # ignoring the unsupported args may lead to unexpected autotuning behavior: don't
720
+ # set unless you know what you're doing.
721
+ unsafe_ignore_unsupported_triton_autotune_args: bool = False
722
+
723
+ # When True, we will check in scheduler.py _codegen that there are no "loops"
724
+ # in the call stack; that is to say, the same frame multiple times. This
725
+ # ensures that a cProfile trace to this frame will be a straight line without
726
+ # any cycles.
727
+ check_stack_no_cycles_TESTING_ONLY: bool = False
728
+
729
+
730
+ # config specific to codegen/cpp.py
731
+ class cpp:
732
+ # set to torch.get_num_threads()
733
+ threads = -1
734
+
735
+ # Do not generate loops when the condition doesn't hold, like:
736
+ # for(long i0=4096; i0<4096; i0+=1)
737
+ no_redundant_loops = (
738
+ os.environ.get("TORCHINDUCTOR_CPP_NO_REDUNDANT_LOOPS", "1") == "1"
739
+ )
740
+
741
+ # Assume number of threads is dynamic, don't specialize thread number.
742
+ # Kernels don't recompile on thread number changes with this flag on.
743
+ # For single-threaded workload, turning it on would incur a slight
744
+ # performance degradation.
745
+ dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1"
746
+
747
+ simdlen: Optional[int] = None
748
+ min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096"))
749
+ cxx = (
750
+ None, # download gcc12 from conda-forge if conda is installed
751
+ # "g++-12",
752
+ # "g++-11",
753
+ # "g++-10",
754
+ # "clang++",
755
+ os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
756
+ # "g++.par",
757
+ )
758
+ # Allow kernel performance profiling via PyTorch profiler
759
+ enable_kernel_profile = (
760
+ os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1"
761
+ )
762
+
763
+ # enable weight prepacking to get a better performance; may lead to large memory footprint
764
+ weight_prepack = os.environ.get("TORCHINDUCTOR_CPP_WEIGHT_PREPACK", "1") == "1"
765
+
766
+ # Inject a bug into our relu implementation; useful for testing our repro
767
+ # extraction and minification functionality.
768
+ # Valid values: "compile_error", "runtime_error", "accuracy"
769
+ inject_relu_bug_TESTING_ONLY: Optional[str] = None
770
+ inject_log1p_bug_TESTING_ONLY: Optional[str] = None
771
+
772
+ # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise,
773
+ # force usage as specified, without testing.
774
+ vec_isa_ok: Optional[bool] = None
775
+
776
+ # similar to config.triton.descriptive_names
777
+ descriptive_names = "original_aten"
778
+
779
+ # how many nodes to allow into a single horizontal fusion
780
+ max_horizontal_fusion_size = int(
781
+ os.environ.get("TORCHINDUCTOR_CPP_MAX_HORIZONTAL_FUSION_SIZE", "16")
782
+ )
783
+
784
+ # Make scatter_reduce fallback when reduce is sum to avoid performance regression
785
+ # using atomic_add.
786
+ fallback_scatter_reduce_sum = (
787
+ os.environ.get("TORCHINDUCTOR_CPP_FALLBACK_SCATTER_REDUCE_SUM", "1") == "1"
788
+ )
789
+
790
+ # Use funsafe-math-optimizations when compiling
791
+ enable_unsafe_math_opt_flag = (
792
+ os.environ.get("TORCHINDUCTOR_CPP_ENABLE_UNSAFE_MATH_OPT_FLAG", "0") == "1"
793
+ )
794
+
795
+ # Use ffp-contract when compiling
796
+ enable_floating_point_contract_flag = (
797
+ os.environ.get("TORCHINDUCTOR_CPP_ENABLE_FLOATING_POINT_CONTRACT_FLAG", "0")
798
+ == "1"
799
+ )
800
+
801
+ # Enable the tiling select heuristic
802
+ enable_tiling_heuristics = (
803
+ os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1"
804
+ )
805
+
806
+ # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls
807
+ # the maximal parallelism of K-slicing. Since K-slicing requires extra thread
808
+ # synchronization and buffers, the maximal number of slices is limited to
809
+ # mitigate the sync overhead and memory usage.
810
+ # When set to 0, the number of slices is unlimited.
811
+ gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1"))
812
+
813
+ # For perf tuning and debugging purpose, configure the pre-defined cache blocking for
814
+ # MxNxK dims respectively. The blockings are separated by comma and the unit is
815
+ # the number of register blocks.
816
+ # For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively.
817
+ gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None)
818
+
819
+ # For perf tuning and debugging purpose, configure the pre-defined thread blocking factors for
820
+ # MxNxK dims respectively. The factors are separated by comma and their product
821
+ # should be the same as the total number of threads.
822
+ # For example, if the total number of threads is 56, "7,4,2" means the work is
823
+ # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM.
824
+ gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None)
825
+
826
+ # Whether to enable masked vectorization for the tail_loop.
827
+ enable_loop_tail_vec = True
828
+
829
+
830
+ # config specific to codegen/triton.py
831
+ class triton:
832
+ # Use cudagraphs on output code
833
+ cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1"
834
+
835
+ # Use cudagraph trees for memory pooling if `cudagraphs` is True
836
+ cudagraph_trees = True
837
+
838
+ # Should we skip cudagraphing graphs with dynamic shape inputs
839
+ # If False, we will re-record a graph for each unique set of shape inputs
840
+ cudagraph_skip_dynamic_graphs = False
841
+
842
+ # assertions not on the fast path, steady state
843
+ slow_path_cudagraph_asserts = True
844
+
845
+ # TODO - need to debug why this prevents cleanup
846
+ cudagraph_trees_history_recording = False
847
+
848
+ # Enable cudagraph support for mutated inputs from prior cudagraph pool
849
+ cudagraph_support_input_mutation = False if is_fbcode() else True
850
+
851
+ # Maximal number of allowed cudagraph re-record for a function and
852
+ # a cudagraph node due to static input tensor address changes or
853
+ # cudagraph managed tensor data pointer changed.
854
+ # i.e., allow num_recording <= cudagraph_unexpected_rerecord_limit
855
+ # note: we are conservative here and choose a large limit.
856
+ cudagraph_unexpected_rerecord_limit = 128
857
+
858
+ # Warn loudly when the number of cudagraphs due to dynamic shape
859
+ # exceeds this limit
860
+ cudagraph_dynamic_shape_warn_limit: Optional[int] = 50
861
+
862
+ # synchronize after cudagraph invocation
863
+ force_cudagraph_sync = False
864
+
865
+ # always run cudagraphs in the eager warmup stage
866
+ # instead of recording and executing cudagraphs
867
+ force_cudagraphs_warmup = False
868
+
869
+ # assertions on the fast path
870
+ fast_path_cudagraph_asserts = False
871
+
872
+ # skip warmup for cudagraph trees
873
+ skip_cudagraph_warmup = False
874
+
875
+ # Synchronize before and after every compiled graph.
876
+ debug_sync_graph = False
877
+
878
+ # Synchronize after every kernel launch, to help pinpoint bugs
879
+ debug_sync_kernel = False
880
+
881
+ # Always load full blocks (rather than broadcasting inside the block)
882
+ dense_indexing = False
883
+
884
+ # limit tiling dimensions
885
+ max_tiles = 2
886
+
887
+ # Prefer higher dimensional tilings. This simplifies indexing expressions, making
888
+ # it easier to identify block pointers.
889
+ prefer_nd_tiling: bool = False
890
+
891
+ # use triton.autotune for pointwise ops with complex layouts
892
+ # this should only be disabled for debugging/testing
893
+ autotune_pointwise = True
894
+
895
+ # max autotune gemm with cublasLt
896
+ autotune_cublasLt = True
897
+
898
+ # Tune the generated Triton kernels at compile time instead of first time they run
899
+ autotune_at_compile_time = False
900
+
901
+ # should we stop a fusion to allow better tiling?
902
+ tiling_prevents_pointwise_fusion = True
903
+ tiling_prevents_reduction_fusion = True
904
+
905
+ # should we give different names to kernels
906
+ # Note: This is orthogonal to descriptive_names - this is deciding whether
907
+ # our triton kernel names should all be `triton_` (to maximize caching) or
908
+ # whether they should be unique.
909
+ unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"
910
+
911
+ # should we put op names in kernel names
912
+ # False: No special names (just triton__1, triton__2, etc.)
913
+ # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
914
+ # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
915
+ # "inductor_node": Maps to the node name in the FX graph passed to Inductor
916
+ descriptive_names = "original_aten"
917
+
918
+ # use alternate codegen for smaller reductions
919
+ persistent_reductions = (
920
+ os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
921
+ )
922
+
923
+ # 0/False: disable
924
+ # 1/True: enable, use tuning to pick between different subkernels
925
+ # 2: enable, force using persistent reduction (for debugging)
926
+ # 3: enable, force using non-persistent reduction (for debugging)
927
+ multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0"))
928
+
929
+ # hint to Triton when arguments are divisible by 16
930
+ divisible_by_16 = True
931
+
932
+ # Minimum RBLOCK to be used for a TritonSplitScanKernel
933
+ # NOTE: This also indirectly controls the size of workspace buffer required
934
+ min_split_scan_rblock = 256
935
+
936
+ # Store the generated cubin files for cpp wrapper code to load
937
+ store_cubin = False
938
+
939
+ # the max number of spills we allow for the configs we benchmark.
940
+ # Setting this to 0 means we skip a config if it spills even a single
941
+ # register.
942
+ # Setting it to a larger value allows a config spilling a small amount
943
+ # of registers being benchmarked.
944
+ #
945
+ # NOTE: triton will always report >0 register spills for kernels using sin/cos.
946
+ # (check this issue https://github.com/openai/triton/issues/1756 )
947
+ # So far we see a fixed 8 spilled registers for kernels using sin/cos.
948
+ # Raise the threshold to 16 to be safe.
949
+ # We should revisit this once we understand more of the source of register spills.
950
+ spill_threshold: int = 16
951
+
952
+ # Generate code containing the newer tl.make_block_ptr() API for loads/store
953
+ use_block_ptr = False
954
+
955
+ # Inject a bug into our relu implementation; useful for testing our repro
956
+ # extraction and minification functionality.
957
+ # Valid values: "compile_error", "runtime_error", "accuracy"
958
+ inject_relu_bug_TESTING_ONLY: Optional[str] = None
959
+
960
+ # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
961
+ codegen_upcast_to_fp32 = True
962
+
963
+
964
+ class aot_inductor:
965
+ # AOTInductor output path
966
+ # If an absolute path is specified, the generated lib files will be stored under the directory;
967
+ # If a relative path is specified, it will be used as a subdirectory under the default caching path;
968
+ # If not specified, a temp directory will be created under the default caching path.
969
+ # If the specified path contains something like "model.so", the sub-string will be used
970
+ # to name the generated library.
971
+ output_path = ""
972
+
973
+ debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1"
974
+
975
+ debug_dump_consts_bin: bool = (
976
+ os.environ.get("AOT_INDUCTOR_DEBUG_DUMP_CONSTS_BIN", "0") == "1"
977
+ )
978
+
979
+ # option for debug printing/saving for intermediate tensor values for aot inductor
980
+ # 0: disable debug dumping
981
+ # 1: enable saving intermediate tensor values
982
+ # 2: enable printing intermediate tensor values
983
+ debug_intermediate_value_printer = os.environ.get(
984
+ "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0"
985
+ )
986
+
987
+ # filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2
988
+ filtered_kernel_names = os.environ.get(
989
+ "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None
990
+ )
991
+
992
+ # Serialized tree spec for flattening inputs
993
+ serialized_in_spec = ""
994
+
995
+ # Serialized tree spec for flattening outputs
996
+ serialized_out_spec = ""
997
+
998
+ # flag to decide whether to create a submodule for constant graph.
999
+ use_runtime_constant_folding: bool = False
1000
+
1001
+ # flag to force weights to be appended to the shared library and mmapped by the runtime
1002
+ # rather than embedded into the data section. Needed to support 1B+ parameter models
1003
+ force_mmap_weights: bool = False
1004
+
1005
+ package: bool = False
1006
+
1007
+
1008
+ class cuda:
1009
+ # CUDA arch to use for CUDA template kernel compilation.
1010
+ # e.g. "70", "75", "80", "90", etc.
1011
+ # When arch is None, Inductor uses torch.cuda.get_device_capability(0).
1012
+ arch: Optional[str] = None
1013
+
1014
+ # CUDA version to use for CUDA template kernel compilation.
1015
+ # e.g. "11.4", "12.1", etc.
1016
+ # When version is None, Inductor uses torch.version.cuda.
1017
+ version: Optional[str] = None
1018
+
1019
+ # Optimization level for the host compiler.
1020
+ compile_opt_level = "-O1"
1021
+
1022
+ # Whether to enable device LTO (link-time-optimization).
1023
+ enable_cuda_lto = False
1024
+
1025
+ # Whether to keep intermediate files during compilation.
1026
+ enable_ptxas_info = False
1027
+
1028
+ # Whether to enable debug info, e.g. line number, cutlass debug info.
1029
+ enable_debug_info = False
1030
+
1031
+ # Whether to use fast math.
1032
+ use_fast_math = False
1033
+
1034
+ # Path to the CUTLASS repo root directory.
1035
+ # The default path only works under PyTorch local development environment.
1036
+ cutlass_dir = os.environ.get(
1037
+ "TORCHINDUCTOR_CUTLASS_DIR",
1038
+ os.path.abspath(
1039
+ os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
1040
+ ),
1041
+ )
1042
+
1043
+ # Configures the maximum number of CUTLASS configs to profile in max_autotune.
1044
+ # By default it's None, so that all CUTLASS configs are tuned.
1045
+ # This is mainly used to reduce test time in CI.
1046
+ cutlass_max_profiling_configs: Optional[int] = None
1047
+
1048
+ # Path to CUDA NVCC.
1049
+ # NVCC search order:
1050
+ # 1) cuda_cxx set in this config
1051
+ # 2) CUDACXX environment variable
1052
+ # 3) CUDA_HOME environment variable
1053
+ # 4) default system search PATH.
1054
+ cuda_cxx: Optional[str] = None
1055
+
1056
+ # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops.
1057
+ cutlass_backend_min_gemm_size: int = 1
1058
+
1059
+ # enable generation of inline standalone runner in CUDA CPP generated code
1060
+ # which allows to compile the generated code into a standalone executable.
1061
+ generate_test_runner: bool = (
1062
+ os.environ.get("INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE", "1") == "1"
1063
+ )
1064
+
1065
+ # Keep only Cutlass op configs which contain this regular expression pattern
1066
+ # Set this to "warpspecialized_cooperative_epi_tma" to enable only SM90 TMA Cutlass Kernels for large GEMMs
1067
+ cutlass_op_allowlist_regex: Optional[str] = None
1068
+
1069
+ # Note: Names of Cutlass ops names can be obtained by calling
1070
+ # op.configuration_name() on a Cutlass op instance, for example those
1071
+ # returned from cutlass_utils.gen_ops() or the op argument passed to
1072
+ # CUTLASSGemmTemplate.render(...)
1073
+
1074
+ # Filter Cutlass configs which contain this regular expression pattern
1075
+ # Set this to "pingpong" to avoid numerical issues
1076
+ # caused by the op ordering of the "pingpong" memory access
1077
+ # pattern used by some Cutlass Kernels.
1078
+ cutlass_op_denylist_regex: Optional[str] = "pingpong"
1079
+
1080
+
1081
+ class rocm:
1082
+ # Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"].
1083
+ # If empty, the `native` arch is used
1084
+ arch: List[str] = []
1085
+
1086
+ # Enable the CK backend for CDNA2 and CDNA3 only (for now)
1087
+ # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
1088
+ ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
1089
+
1090
+ # Optimization level, use to balance compilation speed and runtime performance
1091
+ compile_opt_level = "-O2"
1092
+
1093
+ # Flag to keep debug information in compiled objects
1094
+ is_debug = False
1095
+
1096
+ # Flag to keep intermediate files (assembly listings, preprocessed sources, etc.)
1097
+ save_temps = False
1098
+
1099
+ # Flag to add `-ffast-math` to compile flags
1100
+ use_fast_math = True
1101
+
1102
+ # Flag to add `-fgpu-flush-denormals-to-zero` to compile flags
1103
+ flush_denormals = True
1104
+
1105
+ # Flag to print register and LDS usage during compilation
1106
+ print_kernel_resource_usage = False
1107
+
1108
+ # Path to ROCm installation, if None, use env variable ROCM_HOME
1109
+ rocm_home: Optional[str] = None
1110
+
1111
+ # Path to Composable Kernel library.
1112
+ # Install with `pip install git+https://github.com/rocm/composable_kernel@develop`.
1113
+ ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR")
1114
+
1115
+ # Number of op instance choices to trade off between runtime perf and compilation time
1116
+ n_max_profiling_configs: Optional[int] = None
1117
+
1118
+ # Flag to use a short list of CK instances which perform well across a variety of shapes.
1119
+ # Currently RCR and F16 only
1120
+ use_preselected_instances: bool = False
1121
+
1122
+
1123
+ # Backend to use for CPU codegen either "cpp" or "halide" (experimental)
1124
+ cpu_backend = "cpp"
1125
+
1126
+ # Backend to use for CUDA codegen either "triton" or "halide" (experimental)
1127
+ cuda_backend = "triton"
1128
+
1129
+
1130
+ class halide:
1131
+ # Base halide target to use for CPU devices
1132
+ cpu_target = "host"
1133
+
1134
+ # Base halide target to use for CUDA devices
1135
+ gpu_target = "host-cuda"
1136
+
1137
+ # Halide autoscheduler to use, choices are:
1138
+ # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
1139
+ scheduler_cuda = "Anderson2021"
1140
+ scheduler_cpu = "Adams2019"
1141
+
1142
+ # Controls `no_asserts` flag passed to Halide target (warning: can false positive)
1143
+ asserts = False
1144
+
1145
+ # Controls `debug` flag passed to Halide target
1146
+ debug = False
1147
+
1148
+ # Enable (or fallback on) scan kernels such as cumsum
1149
+ # Halide autoschedulers struggle with these kernels
1150
+ scan_kernels = False
1151
+
1152
+
1153
+ # create a directory containing lots of debug information
1154
+ class trace:
1155
+ # master switch for all debugging flags below
1156
+ enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
1157
+
1158
+ # Save debug information to a temporary directory
1159
+ # If not specified, a temp directory will be created by system
1160
+ debug_dir: Optional[str] = None
1161
+
1162
+ # Save python logger call >=logging.DEBUG
1163
+ debug_log = False
1164
+
1165
+ # Save python logger call >=logging.INFO
1166
+ info_log = False
1167
+
1168
+ # Save input FX graph (post decomps, pre optimization)
1169
+ fx_graph = True
1170
+
1171
+ # Save FX graph after transformations
1172
+ fx_graph_transformed = True
1173
+
1174
+ # Save TorchInductor IR before fusion pass
1175
+ ir_pre_fusion = True
1176
+
1177
+ # Save TorchInductor IR after fusion pass
1178
+ ir_post_fusion = True
1179
+
1180
+ # Copy generated code to trace dir
1181
+ output_code = True
1182
+
1183
+ # SVG figure showing post-fusion graph
1184
+ graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1"
1185
+
1186
+ # SVG figure showing fx with fusion
1187
+ draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1"
1188
+
1189
+ # We draw our fx graphs with the "record" shape attribute by default.
1190
+ # Sometimes, when the graph is very complex, we may hit dot errors like below:
1191
+ # "flat edge between adjacent nodes one of which has a record shape -
1192
+ # replace records with HTML-like labels"
1193
+ # and thus fail to generate a graph. So, let's give the user an option
1194
+ # to specify the shape attribute for the dot graph. For example, passing
1195
+ # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like labels
1196
+ # to work around the above failure.
1197
+ dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None)
1198
+
1199
+ # If not None, this is the URL that saves the SVG files of the input/output
1200
+ # graph of each pass that changed the graph
1201
+ # The nodes that are being transformed in each pass will be colored in yellow
1202
+ # URL only supports local directory for now
1203
+ log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None)
1204
+
1205
+ # Store cProfile (see snakeviz to view)
1206
+ compile_profile = False
1207
+
1208
+ # Upload the .tar.gz file
1209
+ # Needs to be overridden based on specific environment needs
1210
+ upload_tar: Optional[Callable[[str], None]] = None
1211
+
1212
+ log_autotuning_results: bool = False
1213
+
1214
+
1215
+ _save_config_ignore = [
1216
+ # workaround: "Can't pickle <function ...>"
1217
+ "trace.upload_tar",
1218
+ "post_grad_custom_post_pass",
1219
+ "post_grad_custom_pre_pass",
1220
+ "joint_custom_pre_pass",
1221
+ "joint_custom_post_pass",
1222
+ "pre_grad_custom_pass",
1223
+ ]
1224
+
1225
+ _cache_config_ignore_prefix = [
1226
+ # trace functions are not relevant to config caching
1227
+ "trace",
1228
+ # uses absolute path
1229
+ "cuda.cutlass_dir",
1230
+ # not relevant
1231
+ "compile_threads",
1232
+ ]
1233
+
1234
+ if TYPE_CHECKING:
1235
+ from torch.utils._config_typing import * # noqa: F401, F403
1236
+
1237
+ from torch.utils._config_module import install_config_module
1238
+
1239
+
1240
+ # adds patch, save_config, etc
1241
+ install_config_module(sys.modules[__name__])
.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.utils._pytree as pytree
6
+
7
+
8
+ aten = torch.ops.aten
9
+
10
+ # We would like to split modules into two subgraphs for runtime weight updates to work correctly.
11
+ # The use case and more information could be found at:
12
+ # https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
13
+ META_TAG = "MODULE_TYPE"
14
+ MODULE_TAG = "_MAIN_MODULE"
15
+ CONST_MODULE_TAG = "_CONST_MODULE"
16
+
17
+
18
def replace_node_with_constant(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    constant: torch.Tensor,
    name: Optional[str] = None,
) -> None:
    """Swap ``node`` for a ``get_attr`` node that reads ``constant`` from ``gm``.

    Args:
        gm: Graph module that owns ``node``; it is mutated in place.
        node: Node to remove. All of its uses are redirected to the new node.
        constant: Tensor registered on ``gm`` as a buffer.
        name: Attribute name to store the tensor under. When falsy, the first
            unused ``_frozen_param{i}`` name is picked and the search position
            is remembered on ``gm._frozen_param_count``.

    The graph is edited but not recompiled here.
    """
    graph = gm.graph

    attr_name = name
    if not attr_name:
        # Resume the search where the previous call stopped (0 on first use).
        index = getattr(gm, "_frozen_param_count", 0)
        while hasattr(gm, f"_frozen_param{index}"):
            index += 1
        attr_name = f"_frozen_param{index}"
        gm._frozen_param_count = index + 1  # type: ignore[assignment]

    with graph.inserting_before(node):
        replacement = graph.create_node("get_attr", attr_name, (), {})
        node.replace_all_uses_with(replacement)
        # Carry the metadata of the folded node over to its replacement.
        replacement.meta.update(node.meta)
        graph.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(attr_name, constant)
    setattr(gm, attr_name, constant)
50
+
51
+
52
def is_const_source(
    node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]
) -> bool:
    """Return True when ``node`` directly provides a constant value.

    That is the case for every ``get_attr`` node, and for a ``placeholder``
    node whose name is a key of ``lifted_constants`` (when that mapping is
    given).
    """
    if node.op == "get_attr":
        return True
    if node.op != "placeholder" or lifted_constants is None:
        return False
    return node.name in lifted_constants
60
+
61
+
62
+ class ConstantFolder(torch.fx.Interpreter):
63
+ def __init__(
64
+ self,
65
+ gm: torch.fx.GraphModule,
66
+ skip_constructors: bool = False,
67
+ lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
68
+ skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
69
+ ) -> None:
70
+ super().__init__(gm)
71
+ self.node_replacements: Dict[torch.fx.Node, Any] = {}
72
+ self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
73
+ self.unknown_value = object()
74
+ self.skip_constructors: bool = skip_constructors
75
+
76
+ # overwrite this to deallocate env values if their only remaining use
77
+ # is the output
78
+ self.user_to_last_uses = self.node_to_last_non_output_use()
79
+ self.lifted_constants = lifted_constants
80
+
81
+ def _support_dynamic_shape(self) -> bool:
82
+ # ConstantFolder not support dynamic shape now
83
+ return False
84
+
85
+ def _deduce_value(self, node: torch.fx.Node) -> Any:
86
+ return super().run_node(node)
87
+
88
+ def is_impure(self, node: torch.fx.node.Node) -> bool:
89
+ if (
90
+ node.target == torch.ops.prims.convert_element_type.default
91
+ and is_const_source(node.args[0], self.lifted_constants) # type: ignore[arg-type]
92
+ and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr]
93
+ and node.args[1] == torch.bfloat16
94
+ ):
95
+ # For int8_weight -> dq -> bf16_weight
96
+ return True
97
+ if node.target in [
98
+ torch.ops.quantized_decomposed.dequantize_per_channel.default,
99
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
100
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
101
+ ]:
102
+ # For the pattern fp32_weight -> q -> dq
103
+ # We only folding fp32_weight -> q
104
+ # int8_weight and leave dq in graph to be fused
105
+ return True
106
+ return False
107
+
108
+ def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
109
+ last_non_output_use = collections.defaultdict(list)
110
+ seen_uses = set()
111
+ output_node = next(iter(reversed(self.module.graph.nodes)))
112
+
113
+ for node in reversed(self.module.graph.nodes):
114
+ if node.target == "output":
115
+ continue
116
+
117
+ def add_use(inp: torch.fx.Node) -> None:
118
+ if inp in seen_uses:
119
+ return
120
+
121
+ seen_uses.add(inp)
122
+ last_non_output_use[node].append(inp)
123
+
124
+ # In-place is fine since we don't mutate
125
+ pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))
126
+
127
+ # if this node is only used in output, we want to gc it right away
128
+ if len(node.users) == 1 and output_node in node.users:
129
+ last_non_output_use[node].append(node)
130
+
131
+ return last_non_output_use
132
+
133
+ def run_node(self, node: torch.fx.Node) -> Any:
134
+ if node.target == "output":
135
+ # because we remove nodes from env on last non output use,
136
+ # re-define them now or we'll get error in interpreter
137
+ def set_env(arg: torch.fx.Node) -> None:
138
+ self.env[arg] = self.unknown_value
139
+
140
+ # In-place is fine since we don't mutate
141
+ pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
142
+ return super().run_node(node)
143
+
144
+ args, kwargs = self.fetch_args_kwargs_from_env(node)
145
+ flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
146
+
147
+ # We need to do this weird thing because in cases where flattened_inputs
148
+ # contains a ScriptObject, equality checking results in a type error if
149
+ # the types are different.
150
+ if any(
151
+ type(self.unknown_value) == type(input_) and self.unknown_value == input_
152
+ for input_ in flattened_inputs
153
+ ):
154
+ return self.unknown_value
155
+
156
+ # TODO - fix errors with this
157
+ if (
158
+ node.op == "call_function"
159
+ and node.target == aten._efficientzerotensor.default
160
+ ):
161
+ return self.unknown_value
162
+
163
+ # TODO - constant folding triton kernel returns the inputs -- fix this
164
+ if (
165
+ node.op == "call_function"
166
+ and node.name == "triton_kernel_wrapper_functional_proxy"
167
+ ):
168
+ return self.unknown_value
169
+
170
+ # skip constructors, since inductor generates optimal code for them already
171
+ # and turning into tensor would result in an additional global memory read
172
+ # TODO - more complicated strategy
173
+ if (
174
+ self.skip_constructors
175
+ and not is_const_source(node, self.lifted_constants)
176
+ and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
177
+ ):
178
+ return self.unknown_value
179
+
180
+ # All mutations should either be removed or on inputs which we did not make constant
181
+ if (
182
+ isinstance(node.target, torch._ops.OpOverload)
183
+ and torch.Tag.nondeterministic_seeded in node.target.tags
184
+ ):
185
+ return self.unknown_value
186
+
187
+ out = self._deduce_value(node)
188
+ if out == self.unknown_value:
189
+ return self.unknown_value
190
+
191
+ if not is_const_source(node, self.lifted_constants) and isinstance(
192
+ out, torch.Tensor
193
+ ):
194
+ if out.device.type == "meta":
195
+ return out
196
+
197
+ if not self.insertable_tensor_check(out):
198
+ return out
199
+
200
+ if self.is_impure(node):
201
+ return self.unknown_value
202
+
203
+ self.add_node_replacement(node, out)
204
+
205
+ flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
206
+
207
+ for n in flattened_node_inps:
208
+ if not isinstance(n, torch.fx.Node):
209
+ continue
210
+
211
+ self.replaced_uses[n] += 1
212
+
213
+ for to_delete in self.user_to_last_uses.get(node, []):
214
+ if self.replaced_uses[to_delete] == len(to_delete.users):
215
+ self.node_replacements.pop(to_delete, None)
216
+
217
+ return out
218
+
219
+ def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
220
+ return True
221
+
222
+ def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
223
+ self.node_replacements[node] = tensor
224
+
225
+ def run(self) -> Any: # type: ignore[override]
226
+ env: Dict[torch.fx.Node, Any] = {}
227
+ self.insert_placerholder_values(env)
228
+ return super().run(initial_env=env)
229
+
230
+ def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
231
+ for n in self.module.graph.find_nodes(op="placeholder"):
232
+ if self.lifted_constants is not None and n.name in self.lifted_constants:
233
+ env[n] = self.lifted_constants[n.name]
234
+ else:
235
+ env[n] = self.unknown_value # type: ignore[assignment]
236
+
237
+
238
def constant_fold(
    gm: torch.fx.GraphModule,
    constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
    """Fold constant sub-expressions of ``gm`` in place.

    Runs a ``ConstantFolder`` over the graph, replaces every recorded node
    (optionally filtered by ``constraint_fn``) with a constant attribute,
    removes ``get_attr`` nodes that ended up unused together with the
    attributes they referenced, and finally recompiles ``gm``.

    Args:
        gm: Graph module to rewrite.
        constraint_fn: Optional predicate; a node is only replaced when it
            returns True for that node.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        folder = ConstantFolder(gm, skip_constructors=True)
        folder.run()

        for node, constant in folder.node_replacements.items():
            if constraint_fn is None or constraint_fn(node):
                replace_node_with_constant(gm, node, constant)

        unused_attr_nodes = []
        for node in gm.graph.find_nodes(op="get_attr"):
            if node.users:
                continue
            # Drop the backing attribute, then schedule the node for removal.
            if hasattr(gm, node.target):
                delattr(gm, node.target)
            unused_attr_nodes.append(node)

        for node in unused_attr_nodes:
            gm.graph.erase_node(node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
264
+
265
+
266
def constant_graph_tag(
    gm: torch.fx.GraphModule,
    lifted_constants: Optional[Dict[str, Any]],
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
) -> None:
    """Tag every node of ``gm`` as constant-module or main-module.

    ``node.meta[META_TAG]`` is set to ``CONST_MODULE_TAG`` when the node is a
    constant source, was recorded as foldable, or is an input of a recorded
    replacement. All other nodes, and every node for which
    ``skip_folding_node_fn`` returns True, get ``MODULE_TAG``.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        folder = ConstantFolder(
            gm, skip_constructors=True, lifted_constants=lifted_constants
        )
        folder.run()

        for node in gm.graph.nodes:
            skipped = skip_folding_node_fn is not None and skip_folding_node_fn(node)
            in_const_graph = not skipped and (
                is_const_source(node, lifted_constants)
                or node in folder.node_replacements
                or node in folder.replaced_uses
            )
            node.meta[META_TAG] = CONST_MODULE_TAG if in_const_graph else MODULE_TAG
289
+
290
+
291
def run_and_get_constant_graph(
    gm: torch.fx.GraphModule,
    lifted_constants: Optional[Dict[str, Any]],
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
) -> Tuple[torch.fx.GraphModule, Tuple[torch.Tensor, ...]]:
    """
    Construct a GraphModule which corresponds to the part which could be
    constant folded in provided gm.

    The nodes of ``gm`` are tagged first. The constant-tagged nodes are then
    copied into a new graph whose outputs are the copied nodes that have at
    least one main-module user. Returns the new module together with the
    result of calling it on the lifted constants that feed the folded part.
    """

    constant_graph_tag(gm, lifted_constants, skip_folding_node_fn)

    def untag(node: torch.fx.Node) -> bool:
        # Keep the constant tag only if some user is itself constant-tagged.
        feeds_folding = any(
            user.meta[META_TAG] == CONST_MODULE_TAG for user in node.users
        )
        if not feeds_folding:
            node.meta[META_TAG] = MODULE_TAG
        return feeds_folding

    const_args = []
    if lifted_constants is not None:
        for placeholder in list(gm.graph.find_nodes(op="placeholder")):
            if placeholder.meta[META_TAG] != MODULE_TAG and untag(placeholder):
                const_args.append(lifted_constants[placeholder.name])

    # We rewrite the tags, if it's a constant being directly consumed, without
    # any folding opportunity, we keep it in main gm.
    for attr_node in gm.graph.find_nodes(op="get_attr"):
        untag(attr_node)

    new_graph = torch.fx.Graph()

    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    output_nodes = []
    for node in gm.graph.nodes:
        if node.meta[META_TAG] == MODULE_TAG:
            continue

        copied = new_graph.node_copy(node, node_remapping.__getitem__)
        node_remapping[node] = copied

        # A folded value consumed by the main module becomes an output.
        if any(user.meta[META_TAG] == MODULE_TAG for user in node.users):
            output_nodes.append(copied)

    new_graph.output(tuple(output_nodes))
    new_graph.lint()
    new_gm = torch.fx.GraphModule(gm, new_graph)

    const_result = new_gm(*const_args)
    return new_gm, const_result
.venv/lib/python3.11/site-packages/torch/_inductor/cpu_vec_isa.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import dataclasses
3
+ import functools
4
+ import os
5
+ import platform
6
+ import re
7
+ import subprocess
8
+ import sys
9
+ from typing import Any, Callable, Dict, List
10
+
11
+ import torch
12
+ from torch._inductor import config
13
+
14
+
15
+ _IS_WINDOWS = sys.platform == "win32"
16
+
17
+
18
def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
    """Build the fingerprint string used to key the ISA dry-compile output.

    The result has the form ``<compiler version info>=<isa_flags>=<torch version>``.
    """
    # The ISA dry compile costs about one second at every startup, see
    # https://github.com/pytorch/pytorch/issues/100378
    # The dry compile only checks whether the compiler can build for the ISA,
    # so the compiler version, the ISA options and the PyTorch version are
    # recorded and folded into the hashed output path. An existing binary can
    # then be reused instead of being compiled again.
    from torch._inductor.cpp_builder import get_compiler_version_info, get_cpp_compiler

    compiler_version = get_compiler_version_info(get_cpp_compiler())
    return f"{compiler_version}={isa_flags}={torch.__version__}"
31
+
32
+
33
class VecISA:
    """Base description of a CPU vector instruction set.

    Subclasses provide the class attributes below and ``__str__``. Truthiness
    of an instance (``bool(isa)``) tells whether the ISA can be used: it is
    decided by ``config.cpp.vec_isa_ok`` when that is set, is always True in
    fbcode, and otherwise by dry-compiling and loading a small probe kernel.
    """

    # Vector register width in bits.
    _bit_width: int
    # Preprocessor macros to define when compiling for this ISA.
    _macro: List[str]
    # Compiler flags that enable this ISA.
    _arch_flags: str
    # Number of elements per vector, keyed by dtype.
    _dtype_nelements: Dict[torch.dtype, int]

    # Note [Checking for Vectorized Support in Inductor]
    # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
    # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions
    # like exp, pow, sin, cos and etc.
    # But PyTorch and TorchInductor might use different compilers to build code. If
    # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so
    # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass
    # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest
    # gcc/g++ compiler by default while it could support the AVX512 compilation.
    # Therefore, there would be a conflict sleef version between PyTorch and
    # TorchInductor. Hence, we dry-compile the following code to check whether current
    # HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM
    # also needs the logic
    # In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
    # making the runtime check unnecessary.
    _avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

alignas(64) float in_out_ptr0[16] = {0.0};

extern "C" void __avx_chk_kernel() {
    auto tmp0 = at::vec::Vectorized<float>(1);
    auto tmp1 = tmp0.exp();
    tmp1.store(in_out_ptr0);
}
"""  # noqa: B950

    # Script run in a subprocess to check that the built library can be loaded.
    _avx_py_load = """
import torch
from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""

    def bit_width(self) -> int:
        """Return the vector width in bits."""
        return self._bit_width

    def nelements(self, dtype: torch.dtype = torch.float) -> int:
        """Return the elements per vector for ``dtype``.

        Raises:
            KeyError: if ``dtype`` has no entry in ``_dtype_nelements``.
        """
        return self._dtype_nelements[dtype]

    def build_macro(self) -> List[str]:
        """Return the preprocessor macros for this ISA."""
        return self._macro

    def build_arch_flags(self) -> str:
        """Return the compiler flags for this ISA."""
        return self._arch_flags

    def __hash__(self) -> int:
        return hash(str(self))

    def check_build(self, code: str) -> bool:
        """Return True if ``code`` builds for this ISA and the result loads.

        The source is written to the code cache, compiled unless the target
        file already exists, and then loaded through ``ctypes`` in a fresh
        Python subprocess. Any failure along the way yields False.
        """
        from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write
        from torch._inductor.cpp_builder import (
            CppBuilder,
            CppTorchOptions,
            normalize_path_separator,
        )

        key, input_path = write(
            code,
            "cpp",
            extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
        )
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            output_dir = os.path.dirname(input_path)
            build_options = CppTorchOptions(vec_isa=self, warning_all=False)
            isa_check_builder = CppBuilder(
                key,
                [input_path],
                build_options,
                output_dir,
            )
            try:
                # Check if the output file exist, and compile when not.
                output_path = normalize_path_separator(
                    isa_check_builder.get_target_file_path()
                )
                if not os.path.isfile(output_path):
                    isa_check_builder.build()

                # Check build result
                subprocess.check_call(
                    [
                        sys.executable,
                        "-c",
                        VecISA._avx_py_load.replace("__lib_path__", output_path),
                    ],
                    cwd=output_dir,
                    stderr=subprocess.DEVNULL,
                    # PYTHONPATH entries are separated by os.pathsep
                    # (";" on Windows, ":" elsewhere).
                    env={**os.environ, "PYTHONPATH": os.pathsep.join(sys.path)},
                )
            except Exception:
                # Deliberately broad: a failed build or load simply means the
                # ISA cannot be used.
                return False

            return True

    @functools.lru_cache(None)  # noqa: B019
    def __bool__(self) -> bool:
        # An explicit config setting overrides any detection.
        if config.cpp.vec_isa_ok is not None:
            return config.cpp.vec_isa_ok

        if config.is_fbcode():
            return True

        return self.check_build(VecISA._avx_code)
149
+
150
+
151
@dataclasses.dataclass
class VecNEON(VecISA):
    # VecISA description for NEON; identified by the string "asimd".
    _bit_width = 256  # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
    _macro = ["CPU_CAPABILITY_NEON"]
    # An extra Sleef macro is added when running on macOS with an "arm" processor.
    if sys.platform == "darwin" and platform.processor() == "arm":
        _macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF")
    _arch_flags = ""  # Unused
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "asimd"  # detects the presence of advanced SIMD on armv8-a kernels

    # Explicitly reuse the base-class hash (hash of str(self)); with the
    # dataclass default eq=True the class would otherwise be unhashable.
    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
164
+
165
+
166
@dataclasses.dataclass
class VecAVX512(VecISA):
    """Vector ISA settings for AVX-512 (512-bit vectors)."""

    _bit_width = 512
    _macro = ["CPU_CAPABILITY_AVX512"]
    # TODO: use cflags
    _arch_flags = (
        "/arch:AVX512"
        if _IS_WINDOWS
        else "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
    )
    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}

    def __str__(self) -> str:
        return "avx512"

    # Reuse the base-class hash implementation.
    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
181
+
182
+
183
@dataclasses.dataclass
class VecAMX(VecAVX512):
    """AVX-512 settings extended with the AMX tile compiler flags."""

    _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"

    def __str__(self) -> str:
        return super().__str__() + " amx_tile"

    # Reuse the base-class hash implementation.
    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    # Probe source handed to check_build() by __bool__.
    _amx_code = """
#include <cstdint>
#include <immintrin.h>

struct amx_tilecfg {
  uint8_t palette_id;
  uint8_t start_row;
  uint8_t reserved_0[14];
  uint16_t colsb[16];
  uint8_t rows[16];
};

extern "C" void __amx_chk_kernel() {
  amx_tilecfg cfg = {0};
  _tile_loadconfig(&cfg);
  _tile_zero(0);
  _tile_dpbf16ps(0, 1, 2);
  _tile_dpbusd(0, 1, 2);
}
"""

    @functools.lru_cache(None)  # noqa: B019
    def __bool__(self) -> bool:
        """Return True only if the parent check, the AMX probe build and
        ``torch.cpu._init_amx()`` all succeed; fbcode builds report False."""
        if not super().__bool__():
            return False
        if config.is_fbcode():
            return False
        if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx():
            return True
        return False
221
+
222
+
223
@dataclasses.dataclass
class VecAVX2(VecISA):
    """Vector ISA settings for AVX2 (256-bit vectors)."""

    _bit_width = 256
    _macro = ["CPU_CAPABILITY_AVX2"]
    # TODO: use cflags
    _arch_flags = "/arch:AVX2" if _IS_WINDOWS else "-mavx2 -mfma -mf16c"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "avx2"

    # Reuse the base-class hash implementation.
    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
236
+
237
+
238
@dataclasses.dataclass
class VecZVECTOR(VecISA):
    """Vector ISA settings for ZVECTOR (built with ``-mvx -mzvector``)."""

    _bit_width = 256
    _macro = [
        "CPU_CAPABILITY_ZVECTOR",
        "CPU_CAPABILITY=ZVECTOR",
        "HAVE_ZVECTOR_CPU_DEFINITION",
    ]
    _arch_flags = "-mvx -mzvector"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "zvector"

    # Reuse the base-class hash implementation.
    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
253
+
254
+
255
@dataclasses.dataclass
class VecVSX(VecISA):
    """Vector ISA settings for VSX (built with ``-mvsx``)."""

    # VSX simd supports 128 bit_width, but aten is emulating it as 256
    _bit_width = 256
    _macro = ["CPU_CAPABILITY_VSX"]
    _arch_flags = "-mvsx"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "vsx"

    # Reuse the base-class hash implementation.
    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
266
+
267
+
268
class InvalidVecISA(VecISA):
    """Placeholder ISA that is always falsy and carries empty settings."""

    _bit_width = 0
    _macro = [""]
    _arch_flags = ""
    _dtype_nelements = {}

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"

    def __bool__(self) -> bool:  # type: ignore[override]
        # Never usable, so no probing is attempted.
        return False

    # Reuse the base-class hash implementation.
    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
281
+
282
+
283
+ def x86_isa_checker() -> List[str]:
284
+ supported_isa: List[str] = []
285
+
286
+ def _check_and_append_supported_isa(
287
+ dest: List[str], isa_supported: bool, isa_name: str
288
+ ) -> None:
289
+ if isa_supported:
290
+ dest.append(isa_name)
291
+
292
+ Arch = platform.machine()
293
+ """
294
+ Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
295
+ """
296
+ if Arch != "x86_64" and Arch != "AMD64":
297
+ return supported_isa
298
+
299
+ avx2 = torch.cpu._is_avx2_supported()
300
+ avx512 = torch.cpu._is_avx512_supported()
301
+ amx_tile = torch.cpu._is_amx_tile_supported()
302
+
303
+ _check_and_append_supported_isa(supported_isa, avx2, "avx2")
304
+ _check_and_append_supported_isa(supported_isa, avx512, "avx512")
305
+ _check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile")
306
+
307
+ return supported_isa
308
+
309
+
310
+ invalid_vec_isa = InvalidVecISA()
311
+ supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()]
312
+
313
+
314
+ # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
315
+ # might have too much redundant content that is useless for ISA check. Hence,
316
+ # we only cache some key isa information.
317
@functools.lru_cache(None)
def valid_vec_isa_list() -> List[VecISA]:
    """Return the vector ISAs usable on the current host.

    The result is cached for the lifetime of the process. Apple Arm hosts get
    NEON only. On Linux/Windows the choice depends on ``platform.machine()``:
    s390x is probed through the ``features`` line of ``/proc/cpuinfo``,
    ppc64le and aarch64 get a fixed ISA, and x86-64 keeps every entry of
    ``supported_vec_isa_list`` whose name flags are all reported by
    ``x86_isa_checker()`` and which is itself truthy.
    """
    isa_list: List[VecISA] = []
    if sys.platform == "darwin" and platform.processor() == "arm":
        isa_list.append(VecNEON())

    if sys.platform not in ["linux", "win32"]:
        return isa_list

    arch = platform.machine()
    if arch == "s390x":
        with open("/proc/cpuinfo") as _cpu_info:
            for line in _cpu_info:
                featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
                # Compare whole tokens so "vxe" is found wherever it sits in
                # the list (first, middle or last) without matching "vxe2".
                if featuresmatch and "vxe" in featuresmatch.group(1).split():
                    isa_list.append(VecZVECTOR())
                    # One entry is enough; stop so that repeated "features"
                    # lines cannot add duplicates.
                    break
    elif arch == "ppc64le":
        isa_list.append(VecVSX())
    elif arch == "aarch64":
        isa_list.append(VecNEON())
    elif arch in ["x86_64", "AMD64"]:
        # arch value is x86_64 on Linux, and the value is AMD64 on Windows.
        _cpu_supported_x86_isa = x86_isa_checker()
        for isa in supported_vec_isa_list:
            # str(isa) may hold several space-separated flags; all of them
            # must be reported by the CPU, and the ISA itself must be truthy.
            if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
                isa_list.append(isa)

    return isa_list
354
+
355
+
356
def pick_vec_isa() -> VecISA:
    """Choose the vector ISA to use.

    fbcode builds on x86-64 always get ``VecAVX2()``. Otherwise the first
    entry of ``valid_vec_isa_list()`` is used when ``config.cpp.simdlen`` is
    None, or the first entry whose ``bit_width()`` equals the requested
    simdlen. ``invalid_vec_isa`` is returned when nothing qualifies.
    """
    if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]):
        return VecAVX2()

    candidates: List[VecISA] = valid_vec_isa_list()
    if not candidates:
        return invalid_vec_isa

    # If the simdlen is None, the vectorization length is determined
    # automatically.
    if config.cpp.simdlen is None:
        return candidates[0]

    return next(
        (isa for isa in candidates if config.cpp.simdlen == isa.bit_width()),
        invalid_vec_isa,
    )
.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import dataclasses
5
+ from enum import Enum
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
7
+
8
+ import torch
9
+ from torch._dynamo.utils import counters
10
+ from torch._inductor.utils import InputType
11
+
12
+
13
+ perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
14
+ static_inputs_log = torch._logging.getArtifactLogger(
15
+ __name__, "cudagraph_static_inputs"
16
+ )
17
+
18
+
19
+ OutputType = List[Optional[Union[int, torch.Tensor]]]
20
+ ModelType = Callable[[List[InputType]], OutputType]
21
+
22
+
23
@dataclasses.dataclass(frozen=True)
class FunctionID:
    """Unique counter of a function wrapped in cudagraphify_impl."""

    id: int
27
+
28
+
29
@dataclasses.dataclass(frozen=True)
class PlaceholderInfo:
    """
    Serializable stand-in for a torch.fx.Node that keeps only the data
    relevant to placeholder stack traces. It is used in cudagraph-related
    logging and error messages, and these results are cached.
    """

    name: str
    stack_trace: Optional[str]
    # Recursive but never cyclic, since a node never uses itself.
    users: List[PlaceholderInfo]
    mutating_use_stack_trace: Optional[str]
42
+
43
+
44
@dataclasses.dataclass(frozen=True)
class WrappedFunction:
    """
    A function to be recorded for CUDA graph replay, bundled with enough
    metadata to tell whether the CUDA graph tree already holds an applicable
    CUDA graph for it.
    """

    model: Callable[..., Any]
    static_input_idxs: Sequence[int]
    id: FunctionID
    constants: Tuple[torch.Tensor, ...]
    placeholders: Sequence[PlaceholderInfo]
    mutated_input_idxs: Sequence[int]
58
+
59
+
60
def get_mutating_use_stack_trace_from_node(
    placeholder_node: torch.fx.Node,
) -> Optional[str]:
    """Return the stack trace of the use that mutates ``placeholder_node``.

    With exactly one user, that user's ``stack_trace`` metadata is returned
    (possibly None). Otherwise the first non-empty ``stack_trace`` among users
    whose target is ``aten.copy_.default`` is returned, or None.
    """
    node_users = placeholder_node.users

    # reinplaced uses might have a single, non-copy_ use
    if len(node_users) == 1:
        only_user = next(iter(node_users))
        return only_user.meta.get("stack_trace", None)

    copy_traces = (
        user.meta.get("stack_trace", None)
        for user in node_users
        if user.target == torch.ops.aten.copy_.default
    )
    return next((trace for trace in copy_traces if trace), None)
73
+
74
+
75
def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]:
    """Return the ``mutating_use_stack_trace`` recorded on ``placeholder_info``."""
    trace = placeholder_info.mutating_use_stack_trace
    return trace
77
+
78
+
79
def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
    """Convert an FX node into a ``PlaceholderInfo``.

    Users and the mutating-use stack trace are only collected when the node's
    op is ``"placeholder"``; for any other node they stay empty / None.
    """
    is_placeholder = placeholder_node.op == "placeholder"

    # Only recurse to users once, since we only care about user's stack traces
    user_infos = (
        [to_placeholder_info(user) for user in placeholder_node.users]
        if is_placeholder
        else []
    )
    mutating_trace = (
        get_mutating_use_stack_trace_from_node(placeholder_node)
        if is_placeholder
        else None
    )

    return PlaceholderInfo(
        placeholder_node.name,
        placeholder_node.meta.get("stack_trace", None),
        user_infos,
        mutating_trace,
    )
92
+
93
+
94
def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
    """Return a ``PlaceholderInfo`` for every placeholder node in ``graph``."""
    infos: List[PlaceholderInfo] = []
    for node in graph.nodes:
        if node.op == "placeholder":
            infos.append(to_placeholder_info(node))
    return infos
98
+
99
+
100
def format_default_skip_message(reason: str) -> str:
    """Return the standard "skipping cudagraphs" message for ``reason``."""
    return "skipping cudagraphs due to " + f"{reason}"
102
+
103
+
104
def get_mutation_stack_trace(
    placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
) -> str:
    """Build the skip message for mutated inputs.

    The message reports how many mutated inputs there are and, when one of
    them carries a mutating-use stack trace, appends the first such trace.
    """
    stack_trace: Optional[str] = ""

    for idx in mutation_indices:
        stack_trace = get_mutating_use_stack_trace(placeholders[idx])
        if stack_trace:
            # The first available trace is enough.
            break

    msg = format_default_skip_message(
        f"mutated inputs ({len(mutation_indices)} instances)"
    )
    return f"{msg}. Found from : \n {stack_trace}" if stack_trace else msg
121
+
122
+
123
def check_for_mutation(
    func: WrappedFunction,
    inputs: List[InputType],
    is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]:
    """Return a skip message if ``func`` mutates a disallowed input, else None.

    With cudagraph trees enabled, mutations of static inputs and of inputs
    accepted by ``is_cuda_graph_recorded_tensor`` are ignored; without trees
    every mutated input counts.
    """
    mutation_indices: Sequence[int]
    # doesnt work for non-trees because the warmup run would apply mutation twice
    if not torch._inductor.config.triton.cudagraph_trees:
        mutation_indices = func.mutated_input_idxs
    else:
        # checking if mutation is only on parameters/static inputs
        disallowed: List[int] = []
        for idx in func.mutated_input_idxs:
            if idx in func.static_input_idxs:
                continue
            if is_cuda_graph_recorded_tensor(inputs[idx]):  # type: ignore[arg-type]
                continue
            disallowed.append(idx)
        mutation_indices = disallowed

    static_inputs_log.debug(
        "check mutation static input indices: %s", func.static_input_idxs
    )
    static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices)

    if mutation_indices:
        return get_mutation_stack_trace(func.placeholders, mutation_indices)
    return None
152
+
153
+
154
+ def _get_use_stack_trace(node) -> Optional[str]:
155
+ for use in node.users:
156
+ if stack_trace := use.meta.get("stack_trace", None):
157
+ return stack_trace
158
+ return None
159
+
160
+
161
def check_multiple_devices_or_any_cpu_nodes(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
) -> Optional[str]:
    """Return a skip message unless the mapping holds exactly one cuda device.

    A cpu entry is reported first, with the stack trace of one of the cpu
    node's users when available. Any other mapping that is not a single cuda
    device is reported as "multiple devices".
    """
    cpu_node = device_node_mapping.get(torch.device("cpu"))
    if cpu_node:
        msg = f"cpu device ({cpu_node.name})"
        stack_trace = _get_use_stack_trace(cpu_node)
        if stack_trace:
            msg = f"{msg}. Found from : \n {stack_trace}"
        return format_default_skip_message(msg)

    single_cuda_device = (
        len(device_node_mapping) == 1
        and next(iter(device_node_mapping.keys())).type == "cuda"
    )
    if single_cuda_device:
        return None

    device_names = ", ".join(repr(key) for key in device_node_mapping.keys())
    return format_default_skip_message(f"multiple devices: {device_names}")
179
+
180
+
181
def check_lowering_disable_cudagraph(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
):
    """Delegate to ``check_multiple_devices_or_any_cpu_nodes`` and return its result."""
    skip_reason = check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
    return skip_reason
185
+
186
+
187
def log_cudagraph_skip_and_bump_counter(msg):
    """Log ``msg`` as a perf-hint warning and increment the skip counter."""
    perf_hint_log.warning(msg)
    inductor_counters = counters["inductor"]
    inductor_counters["cudagraph_skips"] += 1
190
+
191
+
192
@dataclasses.dataclass
class BoxedDeviceIndex:
    """Mutable holder for an optional integer device index."""

    value: Optional[int]

    def set(self, device_idx: Optional[int]):
        """Store ``device_idx``, which must be None or an int."""
        assert isinstance(device_idx, int) or device_idx is None
        self.value = device_idx
199
+
200
+
201
def check_for_mutation_ignore_cuda_graph_managed_tensor(
    gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: Sequence[int]
) -> Optional[str]:
    """Return a skip message if the compiled graph mutates inputs, else None.

    With cudagraph trees enabled, only mutated inputs outside
    ``static_input_idxs`` count and the message includes placeholder stack
    traces. Without trees, any entry in ``compiled_graph.mutated_inputs``
    yields the generic message.
    """
    # doesnt work for non-trees because the warmup run would apply mutation twice
    if not torch._inductor.config.triton.cudagraph_trees:
        if len(compiled_graph.mutated_inputs) != 0:
            return format_default_skip_message("mutated inputs")
        return None

    static_idxs = set(static_input_idxs)
    # checking if mutation is only on parameters/static inputs
    mutation_indices = [
        idx for idx in compiled_graph.mutated_input_idxs if idx not in static_idxs
    ]
    if len(mutation_indices) == 0:
        return None

    return get_mutation_stack_trace(get_placeholder_info(gm.graph), mutation_indices)
222
+
223
+
224
def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]:
    """
    Return the placeholder's own stack trace if non-empty, otherwise the
    first non-empty stack trace among its users, otherwise None.
    """
    own_trace = placeholder.stack_trace
    if own_trace:
        return own_trace

    return next(
        (user.stack_trace for user in placeholder.users if user.stack_trace),
        None,
    )
236
+
237
+
238
class CheckInvariantStatus(Enum):
    """Outcome of an invariant check, with a readable ``str()`` form."""

    # Check invariant succeeded
    SUCCESS = 1

    # Previously managed data pointers are not stable
    CudagraphManagedIdxMismatch = 2

    # Static tensor input addresses are not stable
    StaticInputIdxMismatch = 3

    # Expected dead indices before graph are live
    ExpectedDeadIndicesBeforeGraphMismatch = 4

    def __str__(self) -> str:
        # The table lives inside the method: a dict assigned in the class
        # body would be turned into an enum member.
        descriptions = {
            "CudagraphManagedIdxMismatch": "cudagraph managed tensor data pointer changed",
            "StaticInputIdxMismatch": "static input data pointer changed",
            "ExpectedDeadIndicesBeforeGraphMismatch": "expected dead indices before graph are live",
        }
        if self.name in descriptions:
            return descriptions[self.name]
        return f"{self.name}: {self.value}"
260
+
261
+
262
def log_data_ptr_mismatch(
    placeholders: Sequence[PlaceholderInfo],
    inputs: List[InputType],
    recorded_data_ptr: Sequence[Optional[int]],
    target_idxs: Sequence[int],
    mismatch: CheckInvariantStatus,
) -> str:
    """
    Build a message describing inputs whose data pointer differs from the
    recorded one. Only the indices in target_idxs are examined; the message
    starts with ``mismatch`` and gains one line per differing input.
    """
    assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(
        placeholders
    ), "length mismatch between inputs, recorded_data_ptr, and placeholders"

    # Resolve every target up front, before any per-tensor check runs.
    t_tensors = [inputs[i] for i in target_idxs]
    t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs]

    error_msg = f"{mismatch}.\n"
    for index, tensor, data_ptr in zip(target_idxs, t_tensors, t_data_ptrs):
        assert isinstance(tensor, torch.Tensor)
        if tensor.data_ptr() == data_ptr:
            continue
        placeholder = placeholders[index]
        error_msg += (
            f"input name: {placeholder.name}. "
            f"data pointer changed from {data_ptr} to {tensor.data_ptr()}. "
            f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n"
        )
    return error_msg
291
+
292
+
293
def maybe_warning_due_to_dynamic_shape(
    fn_cache: Dict[Tuple[int, ...], Callable[..., Any]],
    new_int_key: Any,
) -> bool:
    """Warn when the number of distinct recorded sizes exceeds the limit.

    The count is the number of entries in ``fn_cache`` plus one. Returns True
    if a warning was logged, and False when the limit is unset/zero or not
    exceeded.
    """
    num_cudagraphs = len(fn_cache.keys()) + 1
    warn_limit = torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit

    if not warn_limit or not num_cudagraphs > warn_limit:
        return False

    perf_hint_log.warning(
        "CUDAGraph supports dynamic shapes by recording a new graph for each "
        "distinct input size. Recording too many CUDAGraphs may lead to "
        f"extra overhead. We have observed {num_cudagraphs} distinct sizes. "
        "Please consider the following options for better performance: "
        "a) padding inputs to a few fixed number of shapes; or b) set "
        "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
        "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
        "to silence this warning."
    )
    return True
320
+
321
+
322
@dataclasses.dataclass(frozen=True)
class CudagraphCachedInfo:
    """
    Info needed to realign inputs: the placeholders, their stack traces and
    the recorded cudagraph failure reasons.
    """

    placeholders: Sequence[PlaceholderInfo]
    stack_traces: List[Optional[str]]
    cudagraph_fail_reasons: List[str]
.venv/lib/python3.11/site-packages/torch/_inductor/debug.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import contextlib
3
+ import dataclasses
4
+ import functools
5
+ import itertools
6
+ import logging
7
+ import os
8
+ import os.path
9
+ import pickle
10
+ import pstats
11
+ import shutil
12
+ import subprocess
13
+ from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union
14
+ from unittest.mock import patch
15
+
16
+ import torch
17
+ from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
18
+ from torch import fx as fx
19
+ from torch._dynamo.repro.after_aot import save_graph_repro
20
+ from torch._dynamo.utils import get_debug_dir
21
+ from torch.fx.graph_module import GraphModule
22
+ from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
23
+ from torch.fx.passes.tools_common import legalize_graph
24
+ from torch.utils._pytree import tree_map
25
+
26
+ from . import config, ir # noqa: F811, this is needed
27
+ from .scheduler import (
28
+ BaseSchedulerNode,
29
+ FusedSchedulerNode,
30
+ NopKernelSchedulerNode,
31
+ OutputNode,
32
+ SchedulerNode,
33
+ )
34
+ from .virtualized import V
35
+
36
+
37
+ log = logging.getLogger(__name__)
38
+
39
+ SchedulerNodeList = List[Any]
40
+ BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
41
+ GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
42
+
43
+
44
@functools.lru_cache(None)
def has_dot() -> bool:
    """Return True if a ``dot`` executable can be found on ``PATH``.

    The lookup uses ``shutil.which`` rather than spawning the external
    ``which`` command, so it also works on platforms that have no ``which``
    binary (there the spawn raised ``FileNotFoundError``, which the previous
    ``SubprocessError`` handler did not catch). The result is cached for the
    lifetime of the process.
    """
    return shutil.which("dot") is not None
51
+
52
+
53
def draw_buffers(
    nodes: List[BaseSchedulerNode],
    print_graph: bool = False,
    fname: Optional[str] = None,
) -> None:
    """
    Draw a graph in fname.svg.

    The scheduler nodes are converted to an FX graph, each node carrying
    fusion metadata is given ``tensor_meta``, and the result is passed to
    ``draw_graph``. ``fname`` defaults to the name of the graph being
    compiled. Nothing is drawn when ``dot`` is unavailable.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue

        group = node.meta["fusion_meta"].group
        if isinstance(group, tuple):
            # Keep only the second component, wrapped in a tuple when it is
            # a bare int.
            second = group[1]
            group = (second,) if isinstance(second, int) else second

        # gather meta data
        dtype = node.data.dtype if isinstance(node, ir.ComputedBuffer) else None

        node.meta["tensor_meta"] = TensorMetadata(group, dtype, None, None, None, None, None)  # type: ignore[arg-type]

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(
        gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape
    )
97
+
98
+
99
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.

    Each scheduler node becomes a ``call_function`` node tagged with
    ``fusion_meta``; reads of buffers produced by no listed node become
    placeholders; nodes with a buffer consumed by an ``OutputNode`` form the
    graph output.
    """

    def get_fake_func(name: str) -> Callable[..., int]:
        def func1(*args: Any) -> int:
            return 0

        func1.__name__ = name
        return func1

    def classify(snode: Any) -> Any:
        # Returns (node_type, group). Extern/template checks come before the
        # isinstance checks, so the order matters.
        if snode.is_extern():
            return "extern", "extern"
        if snode.is_template():
            return "template", "template"
        if isinstance(snode, NopKernelSchedulerNode):
            return "nop", "nop"
        if isinstance(snode, SchedulerNode):
            return "compute", snode.group
        if isinstance(snode, FusedSchedulerNode):
            return "fused", snode.group
        raise RuntimeError("Unknown node type")

    def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
        if isinstance(snode, FusedSchedulerNode):
            return any(in_output(x) for x in snode.snodes)
        return any(
            isinstance(user.node, OutputNode)
            for buf in snode.get_outputs()
            for user in buf.users
        )

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

    buf_to_fx_node = {}
    node_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None
    outputs = []

    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        node_type, group = classify(snode)

        fused_name = torch._inductor.utils.get_fused_kernel_name(
            snode.get_nodes(), "original_aten"
        )
        node_func = get_fake_func(f"{node_type}: {fused_name}")
        kwargs = {}
        if hasattr(snode, "get_device"):
            kwargs = {"device": snode.get_device()}
        fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)  # type: ignore[arg-type]

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

        node_to_fx_node[name] = fx_node
        for buf in snode.get_outputs():
            buf_to_fx_node[buf.get_name()] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        fx_node = node_to_fx_node[snode.get_name()]
        new_args = []
        for dep in snode.read_writes.reads:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                # Unknown buffer: add it as a placeholder ahead of the first
                # created node.
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            if dep_node == fx_node:  # to avoid cycles
                continue
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph
195
+
196
+
197
def update_orig_fx_node_name_to_buf_name(
    nodes: Optional[SchedulerNodeList],
    node_name_to_buf_name: Dict[str, str],
    parent_buf_name: Optional[str] = None,
    n_origins: int = 0,
) -> None:
    """Fill ``node_name_to_buf_name`` with origin-node-name -> buffer-name.

    Nodes whose ``get_nodes()`` returns more than one child are traversed
    recursively, and their children are recorded under the outermost name.
    Existing entries are never overwritten, so the first buffer seen for an
    origin wins. ``n_origins`` is not used by the body.
    """
    if nodes is None:
        return

    for node in nodes:
        buf_name = node.get_name()
        # Children of a fused node are attributed to the outermost parent.
        owner_name = buf_name if parent_buf_name is None else parent_buf_name

        # for FusedSchedulerNode, traverse recursively into get_nodes()
        children_nodes = node.get_nodes()
        if children_nodes is not None and len(children_nodes) > 1:
            update_orig_fx_node_name_to_buf_name(
                children_nodes,
                node_name_to_buf_name,
                owner_name,
            )
            continue
        assert len(children_nodes) == 1 and children_nodes[0] == node

        ir_node = node.node
        if ir_node is None or ir_node.origins is None:
            continue

        for origin in ir_node.origins:
            # when buf1 and buf2 both have origin=node1
            # we draw node1 according to buf1
            node_name_to_buf_name.setdefault(origin.name, owner_name)
230
+
231
+
232
def get_node_name_to_buf_meta(
    node_name_to_buf_name: Dict[str, str]
) -> Dict[str, BufMeta]:
    """Map each node name to a ``BufMeta`` of (buffer name, origin count).

    The count is the number of distinct node names mapped to that buffer.
    """
    # First pass: collect the node names attributed to each buffer.
    buf_name_to_n_node = {}
    for node_name, buf_name in node_name_to_buf_name.items():
        buf_name_to_n_node.setdefault(buf_name, set()).add(node_name)

    # Second pass: attach the per-buffer count to every node name.
    return {
        node_name: BufMeta(buf_name, len(buf_name_to_n_node[buf_name]))
        for node_name, buf_name in node_name_to_buf_name.items()
    }
247
+
248
+
249
def annotate_orig_fx_with_snodes(
    gm: torch.fx.GraphModule,
    snodes: SchedulerNodeList,
) -> None:
    """
    Annotate the nodes of ``gm`` with the scheduler buffer they map to.

    Every graph node whose name is an origin of one of ``snodes`` gets
    ``node.meta["buf_meta"]`` set to its ``BufMeta``. Other nodes are left
    untouched, and no new graph is created.
    """
    node_name_to_buf_name: Dict[str, str] = {}
    update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
    node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)

    for node in gm.graph.nodes:
        if node.name in node_name_to_buf_meta:
            node.meta["buf_meta"] = node_name_to_buf_meta[node.name]
264
+
265
+
266
@contextlib.contextmanager
def enable_aot_logging() -> Iterator[None]:
    """Context manager that captures AOT autograd debug logs to a file.

    It is a no-op unless the ``TORCH_COMPILE_DEBUG`` environment variable is
    ``"1"``. When enabled, ``functorch.compile.config.debug_partitioner`` is
    patched to True and a DEBUG-level file handler writing to
    ``<debug_dir>/torchinductor/aot_<graph name>_debug.log`` is attached to
    the AOT autograd logger for the duration of the block.

    On exit the handler is detached and closed and the patch is undone. This
    also happens if setup fails part-way through, so neither the file handle
    nor the patched flag is leaked.
    """
    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    import torch._functorch.aot_autograd

    aot_log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    if not compile_debug:
        yield
        return

    with contextlib.ExitStack() as stack:
        # Enable all graphs to be logged to a file by setting the flags to True
        # and the log level of the file logger to DEBUG
        stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))

        path = os.path.join(get_debug_dir(), "torchinductor")
        os.makedirs(path, exist_ok=True)

        fh = logging.FileHandler(
            os.path.join(
                path,
                f"aot_{get_aot_graph_name()}_debug.log",
            )
        )
        # Callbacks run in reverse order: the handler is detached first (see
        # below), then closed here, and finally the patch is reverted.
        stack.callback(fh.close)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        aot_log.addHandler(fh)
        stack.callback(aot_log.removeHandler, fh)

        yield
305
+
306
+
307
+ class DebugContext:
308
+ _counter = itertools.count()
309
+
310
+ @staticmethod
311
+ def create_debug_dir(folder_name: str) -> Optional[str]:
312
+ debug_dir = config.trace.debug_dir or get_debug_dir()
313
+ for n in DebugContext._counter:
314
+ dirname = os.path.join(
315
+ debug_dir,
316
+ "torchinductor",
317
+ f"{folder_name}.{n}",
318
+ )
319
+ if not os.path.exists(dirname):
320
+ os.makedirs(dirname)
321
+ return dirname
322
+ return None
323
+
324
+ def __init__(self) -> None:
325
+ self._prof = None
326
+ self._path = None
327
+ self._stack = contextlib.ExitStack()
328
+
329
+ def copy(self, new_path: str) -> None:
330
+ if not self._path:
331
+ return
332
+ assert new_path.endswith(".debug"), new_path
333
+ from filelock import FileLock
334
+
335
+ try:
336
+ with FileLock(f"{new_path}.lock"):
337
+ if os.path.exists(new_path):
338
+ shutil.rmtree(new_path)
339
+ shutil.copytree(self._path, new_path)
340
+ except OSError:
341
+ log.warning(
342
+ "Failed to copy debug files from %s to %s", self._path, new_path
343
+ )
344
+
345
+ def fopen(
346
+ self,
347
+ filename: str,
348
+ write_mode: str = "w",
349
+ *args: Any,
350
+ **kwargs: Any,
351
+ ) -> IO[Any]:
352
+ assert self._path
353
+ return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)
354
+
355
+ @contextlib.contextmanager
356
+ def fopen_context(
357
+ self,
358
+ filename: str,
359
+ write_mode: str = "w",
360
+ *args: Any,
361
+ **kwargs: Any,
362
+ ) -> Iterator[IO[Any]]:
363
+ assert self._path
364
+ with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
365
+ yield f
366
+
367
+ def filename(self, suffix: str) -> str:
368
+ assert self._path
369
+ return os.path.join(self._path, suffix)
370
+
371
+ def upload_tar(self) -> None:
372
+ if config.trace.upload_tar is not None:
373
+ import tarfile
374
+
375
+ assert self._path
376
+ tar_file = os.path.join(
377
+ self._path, f"{os.path.basename(self._path)}.tar.gz"
378
+ )
379
+ with tarfile.open(tar_file, "w:gz") as tar:
380
+ tar.add(self._path, arcname=os.path.basename(self._path))
381
+ config.trace.upload_tar(tar_file)
382
+
383
def __enter__(self) -> None:
    """Activate the debug context.

    With ``config.debug`` set, the ``torch._dynamo`` logger is switched to
    DEBUG and its previous level is scheduled for restoration on exit.  This
    object is then installed as the debug handler through
    ``V.set_debug_handler``.  When ``config.trace.enabled`` is set, a trace
    directory is created and, depending on ``config.trace.debug_log`` and
    ``config.trace.info_log``, log capture files are set up inside it.
    """
    if config.debug:
        dynamo_log = logging.getLogger("torch._dynamo")
        saved_level = dynamo_log.level
        dynamo_log.setLevel(logging.DEBUG)
        # Put the original level back when the exit stack unwinds.
        self._stack.callback(dynamo_log.setLevel, saved_level)

    self._stack.enter_context(V.set_debug_handler(self))

    if config.trace.enabled:
        self._path = self.create_debug_dir(get_aot_graph_name())  # type: ignore[assignment]

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
405
+
406
+ def _setup_log_capture(
407
+ self,
408
+ filename: str,
409
+ level: int,
410
+ ) -> None:
411
+ log = logging.getLogger("torch._inductor")
412
+ fd = self._stack.enter_context(self.fopen(filename))
413
+ ch = logging.StreamHandler(fd)
414
+ ch.setLevel(level)
415
+ ch.setFormatter(
416
+ logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
417
+ )
418
+ log.addHandler(ch)
419
+ log.setLevel(min(log.level, level))
420
+ self._stack.callback(log.removeHandler, ch)
421
+
422
+ def __exit__(
423
+ self,
424
+ exc_type: Optional[Type[BaseException]],
425
+ exc_val: Optional[BaseException],
426
+ exc_tb: Optional[Any],
427
+ ) -> None:
428
+ if self._prof:
429
+ self._prof.disable()
430
+ self._save_profile_data()
431
+
432
+ if self._path:
433
+ self.upload_tar()
434
+ log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
435
+ self._stack.close()
436
+
437
+ def _save_profile_data(self) -> None:
438
+ assert self._prof
439
+ self._prof.dump_stats(self.filename("compile.prof"))
440
+ with self.fopen("compile.stats") as fd:
441
+ stats = pstats.Stats(self._prof, stream=fd)
442
+ stats.strip_dirs()
443
+ stats.sort_stats("cumtime")
444
+ stats.print_stats(100)
445
+ stats.sort_stats("tottime")
446
+ stats.print_stats(100)
447
+
448
def __getattr__(self, name: str) -> Optional[Callable[..., None]]:
    """Resolve unknown attributes to debug-artifact writers.

    When tracing is enabled and ``config.trace.<name>`` is truthy, the
    attribute is looked up on a fresh ``DebugFormatter``; a failure during
    that lookup is logged and ``None`` is returned.  Otherwise a function
    that accepts any arguments and does nothing is returned.
    """
    if not (config.trace.enabled and getattr(config.trace, name)):

        def ignored(*args: Any, **kwargs: Any) -> None:
            pass

        return ignored

    try:
        return getattr(DebugFormatter(self), name)
    except Exception:
        log.warning("Ignoring exception in debug code", exc_info=True)
        return None
461
+
462
+
463
class DebugFormatter:
    """Writes individual debug artifacts into a ``DebugContext``'s directory.

    Each public method produces one artifact (graph dumps, IR listings,
    diagrams, generated code, autotuning records) using the file helpers
    borrowed from the handler passed to the constructor.
    """

    def __init__(self, handler: DebugContext) -> None:
        self.fopen = handler.fopen
        self.fopen_context = handler.fopen_context
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(
        self,
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor],
    ) -> None:
        """Write a runnable repro and a readable dump of ``gm``."""
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self,
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor],
    ) -> None:
        """Write a readable dump of the transformed ``gm`` (``inputs`` is unused)."""
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None:
        """Dump the scheduler nodes to ``ir_pre_fusion.txt``."""
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList) -> None:
        """Dump the scheduler nodes to ``ir_post_fusion.txt``."""
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(
        self,
        filename: str,
        nodes: SchedulerNodeList,
    ) -> None:
        """Write each node's ``debug_str()`` to ``filename``, blank-line separated."""
        with self.fopen(filename) as fd:
            log.info("Writing debug ir to %s", fd.name)
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList) -> None:
        """Render the scheduler nodes to ``graph_diagram.svg``."""
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def draw_orig_fx_graph(
        self,
        gm: torch.fx.GraphModule,
        nodes: SchedulerNodeList,
    ) -> None:
        """Annotate ``gm`` with the scheduler nodes and render it to SVG."""
        annotate_orig_fx_with_snodes(gm, nodes)
        draw_graph(
            gm,
            fname=self.filename("orig_fx_graph_diagram.svg"),
            clear_meta=False,
            prog=GRAPHVIZ_COMMAND_SCALABLE,
            parse_stack_trace=True,
            dot_graph_shape=config.trace.dot_graph_shape,
        )

    def output_code(self, filename: str) -> None:
        """Copy the generated code file to ``output_code.py``."""
        shutil.copy(filename, self.filename("output_code.py"))

    def log_autotuning_results(
        self,
        name: str,
        input_nodes: List[ir.IRNode],
        timings: Dict["ChoiceCaller", float],  # type: ignore[name-defined] # noqa: F821
        elapse: float,
        precompile_elapse: float,
    ) -> None:
        """Append one JSON line per benchmarked choice to the autotuning log.

        Every record combines ``caller.info_dict()`` with properties shared
        by the whole run (op name, CUDA device info, a description of each
        input node, timing totals) and the choice's benchmark result.  The
        records are appended to ``autotuning_result_json_list.txt``.
        """
        import json

        from .ir import FixedLayout

        def build_node_info(node: ir.IRNode) -> Dict[str, Any]:
            """Describe ``node`` best-effort; properties that fail are omitted."""
            node_info: Dict[str, Any] = {
                "name": getattr(node, "name", ""),
                "type": type(node).__name__,
            }
            try:
                layout = node.get_layout()
                if isinstance(layout, FixedLayout):
                    # Replace symbolic sizes/strides/offset by concrete hints
                    # so the layout string is stable and readable.
                    offset = 0
                    try:
                        offset = int(layout.offset)
                    except Exception:
                        try:
                            offset = V.graph.sizevars.size_hint(
                                layout.offset, fallback=0
                            )
                        except Exception:
                            pass
                    layout = FixedLayout(
                        layout.device,
                        dtype=layout.dtype,
                        size=list(V.graph.sizevars.size_hints(layout.size)),
                        stride=list(V.graph.sizevars.size_hints(layout.stride)),
                        offset=offset,
                    )
                node_info["layout"] = str(layout)
            except Exception:
                pass

            # Not every IR node supports every query; record what is available.
            optional_fields = (
                ("dtype", lambda: node.get_dtype()),
                ("device", lambda: node.get_device()),
                ("stride", lambda: V.graph.sizevars.size_hints(node.get_stride())),
                ("size", lambda: V.graph.sizevars.size_hints(node.get_size())),
                ("numel", lambda: V.graph.sizevars.size_hint(node.get_numel())),
            )
            for key, compute in optional_fields:
                try:
                    node_info[key] = str(compute())
                except Exception:
                    pass

            if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
                node_info["data"] = build_node_info(node.data)
            return node_info

        general_properties = {
            "op_name": name,
            # get_device_name() raises when CUDA is unavailable; record None
            # instead of aborting the debug dump.
            "cuda_device_name": (
                torch.cuda.get_device_name() if torch.cuda.is_available() else None
            ),
            "cuda_device_count": torch.cuda.device_count(),
            "input_nodes": [build_node_info(node) for node in input_nodes],
            "autotuning_time": elapse,
            "precompile_time": precompile_elapse,
        }
        with self.fopen_context(
            "autotuning_result_json_list.txt", "at", encoding="utf-8"
        ) as fd:
            for caller, benchmark_time in timings.items():
                info_dict = dict(caller.info_dict())
                info_dict.update(general_properties)
                info_dict["benchmark_result"] = benchmark_time
                json.dump(info_dict, fd)
                fd.write("\n")
616
+
617
+
618
# Picklable stand-in for a tensor argument: only its metadata and device are
# kept, from which a tensor is rebuilt when the saved call is replayed.
@dataclasses.dataclass
class TensorMetadataHolder:
    # Metadata extracted from the original tensor.
    tensor_metadata: TensorMetadata
    # Device the original tensor lived on.
    device: torch.device
622
+
623
+
624
# Monotonic counter giving each saved argument file a distinct name.
save_args_cnt = itertools.count()


def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
    """
    Save the arguments of a compile_fx_inner call to the file system.

    Tensors are replaced by TensorMetadataHolder records before pickling.
    The call can later be replayed from the saved file with
    load_args_and_run_compile_fx_inner.
    """

    folder = "/tmp/inductor_saved_args"
    # exist_ok avoids the check-then-create race when several processes
    # reach this point at the same time.
    os.makedirs(folder, exist_ok=True)

    def handle_tensor(x: Any) -> Any:
        """
        Pickle FakeTensor will result in error:
        AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

        Convert every Tensor to its metadata. This may also make pickling faster.
        """
        if isinstance(x, torch.Tensor):
            return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
        else:
            return x

    args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

    fn_name = "compile_fx_inner"
    path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
    with open(path, "wb") as f:
        pickle.dump((args_to_save, kwargs_to_save), f)

    if log.isEnabledFor(logging.DEBUG):
        message = f"""
Arguments for a compile_fx_inner call is saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
"""
        # call print rather than log.debug. log.debug will print message
        # prefix for each line which makes the code snippet harder to be
        # copied.
        # Not a big deal since the code is already been guarded by checking
        # the log level.
        print(message)
671
+
672
+
673
def load_args_and_run_compile_fx_inner(path: str) -> Any:
    """Replay a ``compile_fx_inner`` call from the arguments pickled at ``path``.

    Each ``TensorMetadataHolder`` in the saved arguments is turned back into
    a random strided tensor with the recorded shape, stride, dtype and
    device.  The call runs under a ``FakeTensorMode`` with the ``save_args``
    config patched to ``False``.  Returns whatever ``compile_fx_inner``
    returns.
    """
    from torch._inductor.compile_fx import compile_fx_inner

    # NOTE: unpickling can execute arbitrary code; only load files that were
    # written by a trusted process.
    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def materialize(x: Any) -> Any:
        if not isinstance(x, TensorMetadataHolder):
            return x
        meta = x.tensor_metadata
        return torch._dynamo.testing.rand_strided(
            meta.shape,
            meta.stride,
            meta.dtype,
            x.device,
        )

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        args, kwargs = tree_map(materialize, (args, kwargs))
        return compile_fx_inner(*args, **kwargs)
.venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py ADDED
@@ -0,0 +1,980 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ import functools
3
+ import logging
4
+ import math
5
+ import sys
6
+ import typing
7
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch._decomp as decomp
11
+ import torch._prims_common as utils
12
+ import torch.ao.quantization.fx._decomposed
13
+ from torch._decomp import (
14
+ core_aten_decompositions,
15
+ get_decompositions,
16
+ remove_decompositions,
17
+ )
18
+ from torch._decomp.decompositions import (
19
+ _grid_sampler_2d as decomp_grid_sampler_2d,
20
+ pw_cast_for_opmath,
21
+ )
22
+ from torch._decomp.decompositions_for_rng import extra_random_decomps
23
+ from torch._dynamo.utils import counters
24
+ from torch._higher_order_ops.out_dtype import out_dtype
25
+ from torch._inductor.utils import pad_listlike
26
+ from torch._prims_common import (
27
+ elementwise_dtypes,
28
+ ELEMENTWISE_TYPE_PROMOTION_KIND,
29
+ type_to_dtype,
30
+ )
31
+ from torch.fx.experimental.symbolic_shapes import definitely_true, guard_size_oblivious
32
+
33
+ from . import config, inductor_prims
34
+ from .utils import (
35
+ is_gpu,
36
+ needs_fallback_due_to_atomic_add_limitations,
37
+ use_scatter_fallback,
38
+ )
39
+
40
+
41
+ log = logging.getLogger(__name__)
42
+ aten = torch.ops.aten
43
+ prims = torch.ops.prims
44
+ quantized = torch.ops.quantized
45
+ _quantized = torch.ops._quantized
46
+ quantized_decomposed = torch.ops.quantized_decomposed
47
+
48
+ inductor_decompositions = get_decompositions(
49
+ [
50
+ aten._adaptive_avg_pool2d_backward,
51
+ aten.addmv,
52
+ aten.arange,
53
+ aten.bitwise_and_,
54
+ aten.bitwise_or_,
55
+ aten.clamp_min_,
56
+ aten.dist,
57
+ aten.empty_like,
58
+ aten.flip,
59
+ aten.gelu,
60
+ aten.hardtanh,
61
+ aten.index_select,
62
+ aten.lcm,
63
+ aten.leaky_relu,
64
+ aten.linalg_vector_norm,
65
+ aten._log_softmax,
66
+ aten.max_pool2d_with_indices_backward,
67
+ aten._native_batch_norm_legit,
68
+ aten._native_batch_norm_legit_functional,
69
+ aten._native_batch_norm_legit_no_training,
70
+ aten._batch_norm_with_update,
71
+ aten._batch_norm_with_update_functional,
72
+ aten._batch_norm_no_update,
73
+ aten.batch_norm_backward,
74
+ aten.native_batch_norm,
75
+ aten.native_group_norm,
76
+ aten.native_layer_norm,
77
+ aten.nll_loss2d_backward,
78
+ aten._softmax,
79
+ aten.sin_,
80
+ aten.sqrt_,
81
+ out_dtype,
82
+ aten._to_copy,
83
+ aten.tril_indices,
84
+ aten.triu_indices,
85
+ aten.upsample_bilinear2d.vec,
86
+ quantized.linear_dynamic_fp16_unpacked_weight,
87
+ _quantized.wrapped_quantized_linear,
88
+ ]
89
+ )
90
+ decompositions = {**core_aten_decompositions(), **inductor_decompositions}
91
+
92
+ # Remove unwanted decompositions included via the core ATen decompositions from
93
+ # the Inductor decomp table.
94
+ decomps_to_exclude = [
95
+ aten._unsafe_index,
96
+ aten._unsafe_masked_index,
97
+ aten._unsafe_masked_index_put_accumulate,
98
+ aten._scaled_dot_product_flash_attention_for_cpu.default, # See comments in torch/_decomp/decompositions.py
99
+ aten._softmax_backward_data,
100
+ aten.clamp_max,
101
+ aten.clamp_min,
102
+ aten.glu, # inductor lowers this directly
103
+ aten.select_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass
104
+ aten.slice_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass
105
+ aten.split.Tensor, # inductor lowers this directly
106
+ aten.squeeze, # inductor lowers this directly
107
+ aten.sum, # inductor lowers this directly
108
+ aten.unbind, # inductor lowers this directly
109
+ ]
110
+
111
+ remove_decompositions(decompositions, decomps_to_exclude)
112
+
113
+
114
def register_decomposition(
    ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
) -> Callable[..., Any]:
    """Return a decorator registering a function in the Inductor decomposition table.

    ``ops`` may be a list of operators or a single callable operator.  A
    warning is logged for each operator already present in ``decompositions``;
    registration still proceeds.
    """
    candidates = [ops] if callable(ops) else ops  # type: ignore[attr-defined]
    for op in candidates:
        if op in decompositions:
            log.warning("duplicate decomp: %s", ops)
    return decomp.register_decomposition(ops, decompositions)
121
+
122
+
123
+ # TODO: for now, inductor doesn't handle asserts
124
+ # because the condition is symbol -> tensor in the graph.
125
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    """Decompose ``_assert_async.msg`` to nothing; both arguments are ignored."""
    return None
128
+
129
+
130
+ # Like `assert_async_msg_decomp`, this is implemented as a no-op.
131
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    """Decompose ``_functional_assert_async.msg`` to nothing; arguments are ignored."""
    return None
134
+
135
+
136
@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(
    symbol: torch.SymInt,
    *,
    min: Optional[torch.types.Number] = None,
    max: Optional[torch.types.Number] = None,
) -> None:
    """Decompose ``sym_constrain_range_for_size`` to nothing; arguments are ignored."""
    return None
144
+
145
+
146
@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(
    x: torch.Tensor,
    min: Optional[torch.types.Number] = None,
    max: Optional[torch.types.Number] = None,
) -> torch.Tensor:
    """Express ``clamp`` as ``clamp_min`` followed by ``clamp_max``.

    A bound that is ``None`` is skipped; with both bounds ``None`` the input
    is returned unchanged.
    """
    lower, upper = min, max
    result = x if lower is None else x.clamp_min(lower)
    return result if upper is None else result.clamp_max(upper)
158
+
159
+
160
@register_decomposition([aten.full])
def full(
    size: List[Union[int, torch.SymInt]],
    fill_value: torch.types.Number,
    **kwargs: Any,
) -> torch.Tensor:
    """Fill in the dtype of ``full`` calls that did not specify one.

    When ``dtype`` is absent or ``None`` it is derived from the Python type
    of ``fill_value`` and the call is re-issued.  Calls with an explicit
    dtype are left alone (``NotImplemented``).
    """
    if kwargs.get("dtype") is not None:
        return NotImplemented
    kwargs["dtype"] = type_to_dtype(type(fill_value))
    return torch.full(size, fill_value, **kwargs)
171
+
172
+
173
+ # Not really sure how to put this into the main library. PrimTorch wants
174
+ # empty_permuted to go to the prim, and typically users don't really want
175
+ # to decompose to empty_strided (but inductor is OK with it, because we are
176
+ # cool with strides and everything goes to empty_strided)
177
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(
    size: List[Union[int, torch.SymInt]],
    physical_layout: List[int],
    **kwargs: Any,
) -> torch.Tensor:
    """Build ``empty_permuted`` from ``torch.empty`` plus a ``permute``.

    The tensor is allocated with its dimensions ordered as listed in
    ``physical_layout`` and then permuted back with the inverse permutation,
    so the result has logical shape ``size``.
    """
    perm = [0] * len(size)
    for position, logical_dim in enumerate(physical_layout):
        perm[logical_dim] = position
    physical_size = [size[logical_dim] for logical_dim in physical_layout]
    return torch.empty(physical_size, **kwargs).permute(perm)
187
+
188
+
189
@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias_sizes: List[int],
    stride: Union[int, List[int]],
    padding: Union[int, List[int]],
    dilation: Union[int, List[int]],
    transposed: bool,
    output_padding: List[int],
    groups: int,
    output_mask: List[bool],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split the bias gradient out of ``convolution_backward`` on GPU.

    Applies only when the bias gradient is requested (``output_mask[2]``) and
    ``grad_output`` is on a GPU device; otherwise returns ``NotImplemented``.
    The bias gradient is computed as a sum of ``grad_output`` over every
    dimension except dim 1, and the op is re-issued with the bias output
    masked off for the input and weight gradients.
    """
    wants_bias_grad = output_mask[2]
    if not (wants_bias_grad and is_gpu(grad_output.device.type)):
        return NotImplemented

    reduce_dims = [0] + list(range(2, grad_output.dim()))
    grad_bias = aten.sum(grad_output, reduce_dims)
    remaining_mask = [output_mask[0], output_mask[1], False]
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        remaining_mask,
    )
    return (grad_inp, grad_weight, grad_bias)
220
+
221
+
222
@register_decomposition([aten.round.decimals])
def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor:
    """Round to ``decimals`` places by scaling, rounding and scaling back."""
    scale = 10.0**decimals
    scaled = aten.round(x * scale)
    return scaled * (1.0 / scale)
226
+
227
+
228
@register_decomposition([aten.bmm])
@pw_cast_for_opmath
def bmm(
    self: torch.Tensor,
    batch2: torch.Tensor,
) -> torch.Tensor:
    """Rewrite degenerate batched matmuls as multiply-and-sum.

    With coordinate descent tuning enabled, a product where ``self`` has one
    row or ``batch2`` has one column is expressed as a broadcast multiply
    followed by a sum.  On CPU, a row-vector times column-vector product is
    rewritten as an elementwise product and sum.  Every other case returns
    ``NotImplemented``.
    """
    if config.coordinate_descent_tuning and (
        guard_size_oblivious(self.shape[1] == 1)
        or guard_size_oblivious(batch2.shape[2] == 1)
    ):
        product = self.unsqueeze(-1) * batch2.unsqueeze(1)
        return product.sum(dim=2)

    if (
        self.device.type == "cpu"
        and guard_size_oblivious(self.size(1) == 1)
        and guard_size_oblivious(batch2.size(-1) == 1)
    ):
        counters["inductor"]["decompose_bmm"] += 1
        dot = torch.sum(self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True)
        return dot.unsqueeze(1)

    return NotImplemented
249
+
250
+
251
@register_decomposition([aten.addmm])
@pw_cast_for_opmath
def addmm(
    self: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    beta: torch.types.Number = 1,
    alpha: torch.types.Number = 1,
) -> torch.Tensor:
    """Rewrite small CPU ``addmm`` calls as multiply-and-sum.

    Only CPU inputs are handled.  Two cases are decomposed: a row vector
    times a column vector, and a row vector times a matrix whose dimensions
    are both known to be at most 16.  Everything else returns
    ``NotImplemented``.
    """
    if self.device.type != "cpu":
        return NotImplemented

    if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious(
        mat2.size(-1) == 1
    ):
        counters["inductor"]["decompose_addmm"] += 1
        dot = torch.sum(mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True)
        out = dot.unsqueeze(0)
        return alpha * out + beta * self

    is_small_row_times_matrix = (
        guard_size_oblivious(mat1.size(0) == 1)
        and definitely_true(mat2.size(0) <= 16)
        and definitely_true(mat2.size(1) <= 16)
    )
    if is_small_row_times_matrix:
        counters["inductor"]["decompose_addmm"] += 1
        out = (mat1.T * mat2).sum(dim=0, keepdim=True)
        return alpha * out + beta * self

    return NotImplemented
278
+
279
+
280
@register_decomposition([aten.mm])
@pw_cast_for_opmath
def mm(
    self: torch.Tensor,
    input2: torch.Tensor,
) -> torch.Tensor:
    """Rewrite degenerate matmuls as elementwise multiply plus reduction.

    Returns ``NotImplemented`` for every shape not covered below.
    """
    # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
    # todo: Look into why and fix it (hopefully)
    if config.coordinate_descent_tuning and (
        guard_size_oblivious(self.shape[0] == 1)
        or guard_size_oblivious(input2.shape[1] == 1)
    ):
        return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)

    if self.device.type != "cpu":
        return NotImplemented

    # Tiny outer product: column vector times row vector, same dtype, at most
    # 32 elements in total.
    is_tiny_outer_product = (
        guard_size_oblivious(self.size(-1) == 1)
        and guard_size_oblivious(self.size(0) > 0)
        and guard_size_oblivious(input2.size(0) == 1)
        and (self.dtype == input2.dtype)
        and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
    )
    if is_tiny_outer_product:
        counters["inductor"]["decompose_mm"] += 1
        rows = [self[i, :] * input2 for i in range(self.size(0))]
        return torch.cat(rows)

    # Row vector times column vector: a single dot product.
    if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
        input2.size(-1) == 1
    ):
        counters["inductor"]["decompose_mm"] += 1
        dot = torch.sum(self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True)
        return dot.unsqueeze(0)

    return NotImplemented
311
+
312
+
313
+ # This pass does two things:
314
+ # - Eliminate cat when there is only one tensor input
315
+ # - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
316
+ # don't remove ALL empty tensors, only the naughty ones)
317
@register_decomposition([aten.cat.default])
def cat(
    tensors: List[torch.Tensor],
    dim: int = 0,
) -> torch.Tensor:
    """Simplify ``cat`` calls before lowering.

    Empty inputs are dropped.  A single remaining input is cloned; if some
    but not all inputs were dropped the op is re-dispatched on the remainder;
    a list made of one tensor repeated is turned into an expand.  When
    nothing can be simplified, ``NotImplemented`` is returned.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def non_empty_tensor(x: torch.Tensor) -> bool:
        # For better or worse, this is a valid cat:
        #
        #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
        #
        # Such inputs get in the way of downstream passes like split_cat, so
        # they are simply dropped (guarding that the kept ones are non-zero).
        #
        # Is it fine for this filtering to be size-oblivious?  It could matter
        # for cat([(2, 2), (u0,)], dim=0): had u0 been zero we would have
        # wanted to filter it.  But that cat can ONLY be valid if u0 == 0, so
        # by now a deferred runtime assert forcing u0 to be zero is already
        # installed.  If that has not happened, the unbacked SymInt has an
        # appropriate size and nothing goes wrong.
        is_legacy_empty = len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0)
        if is_legacy_empty:
            return False

        is_empty_along_dim = dim < len(x.shape) and guard_size_oblivious(
            x.shape[dim] == 0
        )
        if is_empty_along_dim:
            return False

        return True

    kept = [t for t in tensors if non_empty_tensor(t)]
    num_kept = len(kept)

    if num_kept == 1:
        return kept[0].clone()
    if 1 < num_kept < len(tensors):
        # Something was filtered out: re-dispatch once on the cleaned list.
        return aten.cat.default(kept, dim)

    # optimization, avoid concat for single, repeated input
    if num_kept > 1 and all(t is kept[0] for t in kept):
        inp = kept[0]
        shape = list(inp.shape)
        dim = dim + len(inp.shape) if dim < 0 else dim
        shape.insert(dim, num_kept)
        return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone()

    # Nothing was filtered: return NotImplemented so the op is not decomposed
    # again, which would recurse forever.
    return NotImplemented
369
+
370
+
371
@register_decomposition([aten.angle])
def angle(x: torch.Tensor) -> torch.Tensor:
    """Compute ``angle`` with ``atan2`` / ``where``.

    Complex input: ``atan2(imag, real)``, or NaN where the real part is NaN.
    Real input: 0 for ``x >= 0``, pi for ``x < 0``, NaN for NaN; the dtype of
    pi follows integer-to-float type promotion of ``x``.
    """
    nan = float("nan")
    if x.is_complex():
        phase = torch.atan2(x.imag, x.real)
        return torch.where(torch.isnan(x.real), nan, phase)

    _, result_dtype = elementwise_dtypes(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    pi = torch.scalar_tensor(math.pi, dtype=result_dtype, device=x.device)
    sign_based = torch.where(x < 0, pi, 0.0)
    return torch.where(torch.isnan(x), nan, sign_based)
389
+
390
+
391
@register_decomposition([aten.add])
def add(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    alpha: Optional[torch.types.Number] = None,
) -> torch.Tensor:
    """Add two complex tensors through their real-valued views.

    Applies only when both operands are complex tensors; otherwise returns
    ``NotImplemented``.  ``y`` is scaled by ``alpha`` when given.  Both
    operands are viewed as real, reshaped to ``[..., n, 2]``, added, and the
    sum is viewed back as the promoted complex dtype.
    """
    # Require both x and y to be complex tensors.
    both_complex = (
        torch.is_tensor(x)
        and x.is_complex()
        and torch.is_tensor(y)
        and y.is_complex()
    )
    if not both_complex:
        return NotImplemented

    z = y if alpha is None else alpha * y
    complex_type = torch.promote_types(x.dtype, y.dtype)

    # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension
    # and can cause problems when broadcasting the add.
    def split_real_imag(tensor: torch.Tensor) -> torch.Tensor:
        """Reshape [*initial_dims, last_dim] to [*initial_dims, last_dim/2, 2]."""
        *initial_dims, last_dim = tensor.shape

        # The last dimension should always be even here, because
        # `x.view(x.real.dtype)` doubles it for complex numbers.
        if last_dim % 2 != 0:
            raise AssertionError(
                "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]"
            )

        return tensor.view((*initial_dims, last_dim // 2, 2))

    x_pairs = split_real_imag(x.view(x.real.dtype))
    z_pairs = split_real_imag(z.view(y.real.dtype))
    summed = torch.flatten(x_pairs + z_pairs, start_dim=-2)
    return summed.view(complex_type)
431
+
432
+
433
@register_decomposition([aten.conj_physical])
def conj_physical(self: torch.Tensor) -> torch.Tensor:
    """Return the input unchanged; complex tensors are not supported (asserted)."""
    is_complex = self.is_complex()
    assert not is_complex, "TODO: implement this"
    return self
437
+
438
+
439
@register_decomposition([aten.lift, aten.detach_])
def lift(self: torch.Tensor) -> torch.Tensor:
    """Decompose ``lift`` and ``detach_`` to the identity."""
    result = self
    return result
442
+
443
+
444
@register_decomposition([aten.bernoulli.default])
def bernoulli(
    self: torch.Tensor,
    *,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Sample Bernoulli values by comparing uniform float32 noise with ``self``.

    Only the default generator is supported (asserted).  The result has the
    dtype of ``self``.
    """
    assert generator is None
    uniform = torch.rand_like(self, dtype=torch.float32)
    return (uniform < self).to(self.dtype)
452
+
453
+
454
@register_decomposition([aten.fmin, prims.fmin])
def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """Elementwise minimum that picks ``self`` wherever ``other`` is NaN."""
    pick_self = torch.isnan(other) | (other > self)
    return torch.where(pick_self, self, other)
457
+
458
+
459
@register_decomposition([aten.fmax, prims.fmax])
def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """Elementwise maximum that picks ``self`` wherever ``other`` is NaN."""
    pick_self = torch.isnan(other) | (other < self)
    return torch.where(pick_self, self, other)
462
+
463
+
464
@register_decomposition(aten.amax)
def amax(
    self: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    """Lower ``amax`` on boolean tensors to ``any``; other dtypes are untouched."""
    if self.dtype != torch.bool:
        return NotImplemented
    return torch.any(self, dim=dim, keepdim=keepdim)
473
+
474
+
475
@register_decomposition(aten.amin)
def amin(
    self: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    """Lower ``amin`` on boolean tensors to ``all``; other dtypes are untouched."""
    if self.dtype != torch.bool:
        return NotImplemented
    return torch.all(self, dim=dim, keepdim=keepdim)
484
+
485
+
486
@register_decomposition([aten.narrow_copy])
def narrow_copy(
    self: torch.Tensor,
    dim: int,
    start: int,
    length: int,
) -> torch.Tensor:
    """Decompose ``narrow_copy`` into ``narrow`` followed by ``clone``."""
    window = torch.narrow(self, dim, start, length)
    return window.clone()
494
+
495
+
496
@register_decomposition([aten.view_copy.default])
def view_copy_default(
    self: torch.Tensor,
    size: List[Union[int, torch.SymInt]],
) -> torch.Tensor:
    """Decompose ``view_copy`` into ``view`` followed by ``clone``."""
    viewed = aten.view(self, size)
    return viewed.clone()
502
+
503
+
504
@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(
    self: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Decompose ``view_copy.dtype`` into ``to(dtype)`` followed by ``clone``."""
    # NOTE(review): ``Tensor.to(dtype)`` converts values rather than
    # reinterpreting the underlying bytes -- confirm this is the intended
    # semantics for the dtype overload of view_copy.
    converted = self.to(dtype)
    return converted.clone()
510
+
511
+
512
def get_like_layout(
    tensor: torch.Tensor,
    memory_format: Optional[torch.memory_format] = None,
) -> torch.memory_format:
    """Resolve the memory format a ``*_like`` result should use.

    An explicit ``memory_format`` is returned as is.  ``None`` and
    ``torch.preserve_format`` resolve to the format suggested for ``tensor``.
    """
    # TODO: _to_copy tensor to stride permutation
    if memory_format is None or memory_format is torch.preserve_format:
        return utils.suggest_memory_format(tensor)
    return memory_format
521
+
522
+
523
@register_decomposition(aten.rand_like)
def rand_like(
    self: torch.Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Decompose ``rand_like`` into ``torch.rand`` with the size of ``self``.

    ``dtype`` and ``device`` default to those of ``self``; the result is
    converted to the memory format resolved by ``get_like_layout``.
    """
    sample = torch.rand(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    return sample.to(memory_format=get_like_layout(self, memory_format))
538
+
539
+
540
@register_decomposition(aten.randn_like)
def randn_like(
    self: torch.Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Decompose ``randn_like`` into ``torch.randn`` with the size of ``self``.

    ``dtype`` and ``device`` default to those of ``self``; the result is
    converted to the memory format resolved by ``get_like_layout``.
    """
    sample = torch.randn(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    return sample.to(memory_format=get_like_layout(self, memory_format))
555
+
556
+
557
@register_decomposition(aten.full_like)
def full_like(
    self: torch.Tensor,
    fill_value: Union[int, float],
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor:
    """Decompose ``full_like`` into ``torch.full`` with the size of ``self``.

    ``dtype``, ``layout`` and ``device`` default to those of ``self``; the
    result is converted to the memory format resolved by
    ``get_like_layout``.  ``pin_memory`` is accepted but not forwarded to
    ``torch.full``.
    """
    filled = torch.full(
        [*self.size()],
        fill_value,
        dtype=dtype or self.dtype,
        layout=layout or self.layout,
        device=device or self.device,
        requires_grad=requires_grad,
    )
    return filled.to(memory_format=get_like_layout(self, memory_format))
577
+
578
+
579
@register_decomposition(aten.randint_like.default)
def randint_like(
    self: torch.Tensor,
    high: int,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Decompose ``randint_like`` into ``randint.low`` with a lower bound of 0.

    ``dtype`` and ``device`` default to those of ``self``; the result is
    converted to the memory format resolved by ``get_like_layout``.
    """
    sample = aten.randint.low(
        0,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    return sample.to(memory_format=get_like_layout(self, memory_format))
597
+
598
+
599
@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
    self: torch.Tensor,
    low: int,
    high: int,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    memory_format: Optional[torch.memory_format] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Decompose ``aten.randint_like.low_dtype`` into ``aten.randint.low``.

    Args:
        self: Tensor whose size (and, by default, dtype and device) is reused.
        low: Lower bound forwarded to ``aten.randint.low``.
        high: Upper bound forwarded to ``aten.randint.low``.
        dtype: Optional override; falls back to ``self.dtype``.
        device: Optional override; falls back to ``self.device``.
        memory_format: Passed, together with ``self``, to ``get_like_layout``.
        **kwargs: Forwarded unchanged to ``aten.randint.low``.

    Returns:
        The generated tensor converted to the selected memory format.
    """
    shape = list(self.size())
    drawn = aten.randint.low(
        low,
        high,
        shape,
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    target_format = get_like_layout(self, memory_format)
    return drawn.to(memory_format=target_format)
618
+
619
+
620
@register_decomposition(aten.randint.default)
def randint(
    high: int,
    size: List[Union[int, torch.SymInt]],
    **kwargs: Any,
) -> torch.Tensor:
    """Decompose ``aten.randint.default`` into ``aten.randint.low`` with a lower bound of 0.

    Args:
        high: Upper bound forwarded to ``aten.randint.low``.
        size: Output size forwarded to ``aten.randint.low``.
        **kwargs: Forwarded unchanged.
    """
    lower_bound = 0
    return aten.randint.low(lower_bound, high, size, **kwargs)
627
+
628
+
629
@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
def linear_dynamic_fp16_unpacked_weight(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Decompose the unpacked-weight fp16 dynamic linear into pack + packed linear.

    The weight is first packed with ``wrapped_fbgemm_pack_gemm_matrix_fp16`` and
    the packed result is handed to ``wrapped_fbgemm_linear_fp16_weight`` along
    with the size of the weight's first dimension.
    """
    quantized_ops = torch.ops._quantized
    packed = quantized_ops.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
    first_dim = weight.size()[0]
    return quantized_ops.wrapped_fbgemm_linear_fp16_weight(
        input, packed, bias, first_dim
    )
639
+
640
+
641
@register_decomposition(_quantized.wrapped_quantized_linear.default)
def wrapped_quantized_linear(
    input: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    bias: torch.Tensor,
    out_scale: torch.Tensor,
    out_zero_point: torch.Tensor,
    out_channel: int,
) -> torch.Tensor:
    """Decompose ``wrapped_quantized_linear`` into a prepack step and a prepacked linear.

    The weight, its scale/zero-point and the bias are combined by
    ``_wrapped_linear_prepack``; the packed result replaces them in the call to
    ``_wrapped_quantized_linear_prepacked``.
    """
    quantized_ops = torch.ops._quantized
    prepacked = quantized_ops._wrapped_linear_prepack(
        weight, weight_scale, weight_zero_point, bias
    )
    linear_args = (
        input,
        input_scale,
        input_zero_point,
        prepacked,
        out_scale,
        out_zero_point,
        out_channel,
    )
    return quantized_ops._wrapped_quantized_linear_prepacked(*linear_args)
666
+
667
+
668
@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
    """Unpack a byte-packed embedding-bag tensor to float32.

    The last 8 entries along the final dimension are read as two groups of
    four bytes: the first group is reinterpreted as a float32 scale and the
    second as a float32 offset. The remaining entries are converted to float32
    and returned as ``values * scale + offset``.
    """

    def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
        # Widen each of the four bytes to int32, then assemble one 32-bit word
        # in the host's byte order and reinterpret its bits as float32.
        b0, b1, b2, b3 = [u8[..., i].to(torch.int32) for i in (0, 1, 2, 3)]
        if sys.byteorder == "little":
            word = b0 + (b1 << 8) + (b2 << 16) + (b3 << 24)
        else:
            word = (b0 << 24) + (b1 << 16) + (b2 << 8) + b3
        # Trailing unit dimension so the result broadcasts against the payload.
        return word.view(torch.float32)[..., None]

    scales = bitcast_u8_to_f32(packed[..., -8:-4])
    offsets = bitcast_u8_to_f32(packed[..., -4:])
    payload = packed[..., :-8].to(torch.float32)
    return payload * scales + offsets
680
+
681
+
682
@register_decomposition([aten.grid_sampler_2d])
@pw_cast_for_opmath
def grid_sampler_2d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> torch.Tensor:
    """Decompose ``aten.grid_sampler_2d`` via ``decomp_grid_sampler_2d``.

    The only decision made here is whether the helper should expand the grid.
    Local experiments found that expanding the grid from (N, H, W, 2) to
    (N, C, H, W, 2) speeds up compiled CUDA code by ~5x and CPU bicubic code by
    ~2x, but slows CPU bilinear mode with channels-first input to ~0.8x. The
    grid is therefore left unexpanded only for a contiguous CPU input with
    ``interpolation_mode == 0``.
    """
    on_cpu = a.device == torch.device("cpu")
    # De Morgan form of "not (cpu and mode 0 and contiguous)".
    _expand_grid = (
        not on_cpu
        or interpolation_mode != 0
        or not a.is_contiguous(memory_format=torch.contiguous_format)
    )

    return decomp_grid_sampler_2d(
        a,
        grid=grid,
        interpolation_mode=interpolation_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        _expand_grid=_expand_grid,
    )
711
+
712
+
713
@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(
    self: List[torch.Tensor],
    left_tensors: List[torch.Tensor],
    right_tensors: List[torch.Tensor],
    scalar: float = 1,
) -> List[torch.Tensor]:
    """Decompose ``_foreach_addcmul.Scalar`` into a foreach multiply followed by a foreach add.

    The element-wise products of ``left_tensors`` and ``right_tensors`` are
    added to ``self`` with ``alpha=scalar``.
    """
    products = aten._foreach_mul.List(left_tensors, right_tensors)
    return aten._foreach_add.List(self, products, alpha=scalar)
723
+
724
+
725
@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(
    self: List[torch.Tensor],
    left_tensors: List[torch.Tensor],
    right_tensors: List[torch.Tensor],
    scalar: float = 1,
) -> List[torch.Tensor]:
    """Decompose ``_foreach_addcdiv.Scalar`` into a foreach divide followed by a foreach add.

    The element-wise quotients of ``left_tensors`` by ``right_tensors`` are
    added to ``self`` with ``alpha=scalar``.
    """
    quotients = aten._foreach_div.List(left_tensors, right_tensors)
    return aten._foreach_add.List(self, quotients, alpha=scalar)
735
+
736
+
737
@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(
    start_tensors: List[torch.Tensor],
    end_tensors: List[torch.Tensor],
    weight: torch.types.Number,
) -> List[torch.Tensor]:
    """Decompose ``_foreach_lerp.Scalar`` as ``start + (end - start) * weight``.

    Built from the foreach sub, mul (scalar) and add ops, in that order.
    """
    deltas = aten._foreach_sub.List(end_tensors, start_tensors)
    scaled_deltas = aten._foreach_mul.Scalar(deltas, weight)
    return aten._foreach_add.List(start_tensors, scaled_deltas)
749
+
750
+
751
@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    running_mean: typing.Optional[torch.Tensor],
    running_var: typing.Optional[torch.Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Decompose ``aten.miopen_batch_norm`` in terms of ``aten.native_batch_norm``.

    All arguments are forwarded positionally to ``aten.native_batch_norm``.
    In training mode its three outputs are returned as-is; otherwise only the
    first output is kept and the other two are replaced by empty (size-0)
    tensors created from ``weight``.
    """
    result, second, third = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )

    if not training:
        # Two separate empty tensors, as in the eager return value.
        second = weight.new_zeros((0,))
        third = weight.new_zeros((0,))
    return (result, second, third)
781
+
782
+
783
@functools.lru_cache(None)
def fast_random_decomps() -> Dict[Any, Callable[..., Any]]:
    """Return ``decompositions`` overlaid with ``extra_random_decomps``.

    The merged table is built once and cached; entries from
    ``extra_random_decomps`` win on key collisions.
    """
    merged: Dict[Any, Callable[..., Any]] = dict(decompositions)
    merged.update(extra_random_decomps)
    return merged
786
+
787
+
788
# TODO(aakhundov): replace this (and the above) Any by more
# specific type and fix all the cascading mypy errors
def select_decomp_table() -> Dict[Any, Callable[..., Any]]:
    """Pick the decomposition table to use; the choice depends on config.

    With ``config.fallback_random`` set, the plain ``decompositions`` table is
    returned; otherwise the cached table from ``fast_random_decomps()`` is used.
    """
    return decompositions if config.fallback_random else fast_random_decomps()
795
+
796
+
797
@register_decomposition(aten.masked_scatter)
def masked_scatter(
    self: torch.Tensor,
    mask: torch.Tensor,
    source: torch.Tensor,
) -> torch.Tensor:
    """Decompose ``aten.masked_scatter`` for backends that support indexed masked scatter.

    Returns ``NotImplemented`` when the backend for ``self.device`` lacks
    ``BackendFeature.MASKED_SCATTER_WITH_INDEX``. Otherwise a two-step
    algorithm is used (the same as eager CUDA; eager CPU uses a 1-shot serial
    iteration): a cumulative sum over the flattened mask yields, for each
    position, the index into ``source`` to read from, and ``torch.where``
    selects between that value and the original element.
    """
    from .codegen.common import BackendFeature, has_backend_feature

    if not has_backend_feature(
        self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX
    ):
        return NotImplemented

    expanded_self, expanded_mask = aten.broadcast_tensors([self, mask])
    # Running count of True entries, minus one, is the source position.
    source_idx = expanded_mask.reshape(-1).cumsum(0) - 1
    flat_self = expanded_self.flatten()
    flat_mask = expanded_mask.flatten()
    flat_source = source.flatten()
    picked = aten._unsafe_masked_index(flat_source, flat_mask, [source_idx], 0)
    return torch.where(flat_mask, picked, flat_self).view(expanded_self.shape)
814
+
815
+
816
@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
    input: torch.Tensor,
    quant_min: int,
    quant_max: int,
    eps: float,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute a quantization scale and zero point from the range of ``input``.

    Args:
        input: Tensor whose min/max define the quantization range.
        quant_min: Smallest quantized value.
        quant_max: Largest quantized value.
        eps: Lower bound applied to the computed scale.
        dtype: Part of the op signature; not used by this decomposition.

    Returns:
        ``(scale, zero_point)`` where ``scale`` is float64 and ``zero_point``
        is int64.
    """
    min_val, max_val = torch.aminmax(input)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    # Build the eps floor on the same device as `scale`. A bare
    # `torch.Tensor([eps])` always lives on the CPU, so the element-wise max
    # would mix devices whenever `input` is not a CPU tensor. The 1-element
    # shape and default floating dtype are kept so the result is unchanged.
    eps_floor = torch.full((1,), eps, device=scale.device)
    scale = torch.max(scale, eps_floor)
    zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale.to(torch.float64), zero_point.to(torch.int64)
830
+
831
+
832
@register_decomposition(aten.put)
def put(
    self: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    accumulate: bool = False,
) -> torch.Tensor:
    """Decompose ``aten.put`` into ``torch.index_put`` on the flattened tensor.

    ``source`` is reshaped to ``index.shape``, written into a flattened view
    of ``self`` at ``index`` (accumulating if requested), and the result is
    reshaped back to ``self.shape``.
    """
    flat_self = self.flatten()
    shaped_source = source.reshape(index.shape)
    updated = torch.index_put(flat_self, [index], shaped_source, accumulate)
    return updated.reshape(self.shape)
844
+
845
+
846
@register_decomposition(aten.put_)
def put_(
    self: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    accumulate: bool = False,
) -> torch.Tensor:
    """Decompose in-place ``aten.put_`` as ``aten.put`` followed by ``copy_`` into ``self``."""
    return self.copy_(aten.put(self, index, source, accumulate=accumulate))
855
+
856
+
857
@register_decomposition(aten._softmax_backward_data.default)
@pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: torch.Tensor,
    output: torch.Tensor,
    dim: int,
    input_dtype: torch.dtype,
) -> torch.Tensor:
    """Decompose the softmax backward pass.

    Computes ``grad_output * output - output * sum(grad_output * output, dim)``
    using a fused multiply-add, casts to ``input_dtype`` when it differs from
    ``grad_output.dtype``, and returns a contiguous tensor.
    """
    weighted_grad = grad_output * output
    reduced = torch.sum(weighted_grad, dim=dim, keepdim=True)
    # Equivalent to: weighted_grad - output * reduced
    grad_input = inductor_prims.fma(-output, reduced, weighted_grad)

    # NOTE: the CPU kernel doesn't respect input_dtype, but special-casing
    # `grad_output.device == torch.device("cpu")` (returning
    # `grad_input.contiguous()` directly) doesn't work for meta tensors, so
    # the dtype check below is applied on every device.
    needs_cast = grad_output.dtype != input_dtype
    if needs_cast:
        return grad_input.to(input_dtype).contiguous()
    return grad_input.contiguous()
877
+
878
+
879
@register_decomposition(aten.index_reduce)
def index_reduce(
    self: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src: torch.Tensor,
    reduction_type: str,
    *,
    include_self: bool = True,
) -> torch.Tensor:
    """Decompose ``aten.index_reduce``.

    Three paths, tried in order:

    1. ``"mean"`` (when atomic-add limitations do not force a fallback for
       ``self.dtype``) is computed as an ``index_add`` of the values divided
       by an ``index_add`` of per-slot counts; floor division is used for
       dtypes that are neither floating point nor complex.
    2. If ``use_scatter_fallback`` says so, ``NotImplemented`` is returned.
    3. Otherwise ``index`` is expanded to a full scatter index and the op is
       expressed as ``scatter_reduce`` along ``dim``.
    """
    mean_via_index_add = (
        reduction_type == "mean"
        and not needs_fallback_due_to_atomic_add_limitations(self.dtype)
    )
    if mean_via_index_add:
        unit_src = torch.ones_like(src)
        if include_self:
            # The existing value counts as one contribution per slot.
            base = self
            hits = torch.ones_like(self).index_add(dim, index, unit_src)
        else:
            # Indexed slots start from zero; untouched slots keep a count of 1
            # so they are not divided by zero.
            base = self.index_fill(dim, index, 0)
            hits = torch.zeros_like(self).index_add(dim, index, unit_src)
            hits = hits.masked_fill(hits < 1, 1)
        total = base.index_add(dim, index, src)
        if self.dtype.is_floating_point or self.dtype.is_complex:
            return total / hits
        return total // hits

    if use_scatter_fallback(
        aten.scatter_reduce_.two,
        reduction_type,
        self.dtype,
        src.dtype,
        src.device.type,
        True,
    ):
        return NotImplemented

    leading = self.shape[:dim]
    trailing = self.shape[dim + 1 :]
    repeats = trailing.numel() * leading.numel()
    # Each index entry is repeated once per element of the other dimensions,
    # laid out as (index, trailing dims..., leading dims...) ...
    index_shape = (index.numel(), *trailing, *leading)
    # ... and then permuted so the leading dims come first and the index
    # dimension lands at position `dim`.
    perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim))
    long_index = index.to(torch.int64)
    scatter_index = (
        long_index.repeat_interleave(repeats).reshape(index_shape).permute(perm)
    )
    return self.scatter_reduce(
        dim,
        scatter_index,
        src,
        reduction_type,
        include_self=include_self,
    )
930
+
931
+
932
@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
    x: torch.Tensor,
    kernel_size: List[int],
    stride: Optional[Union[int, List[int]]] = None,
    padding: Union[int, List[int]] = 0,
    dilation: Union[int, List[int]] = 1,
    ceil_mode: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decompose ``aten.max_pool2d_with_indices`` via the low-memory offsets prims.

    Scalar/default arguments are normalized to 2-element lists (a falsy
    ``stride`` defaults to ``kernel_size``). ``NotImplemented`` is returned
    when ``should_fallback_max_pool2d_with_indices`` requests a fallback or
    when the window holds more elements than the int8 maximum. Otherwise the
    pooled values and per-window offsets are computed and the offsets are
    converted to indices.

    Returns:
        ``(values, indices)``.
    """
    dilation = [1, 1] if dilation == 1 else dilation
    padding = [0, 0] if padding == 0 else padding
    stride = stride or kernel_size

    kernel_size = pad_listlike(kernel_size, 2)
    dilation = pad_listlike(dilation, 2)
    padding = pad_listlike(padding, 2)
    stride = pad_listlike(stride, 2)

    kernel_h, kernel_w = kernel_size[0], kernel_size[1]
    window_size = kernel_h * kernel_w
    # We fallback when using non-default dilation or when the window size is too large
    must_fallback = (
        torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
            kernel_size, dilation
        )
        or window_size > torch.iinfo(torch.int8).max
    )
    if must_fallback:
        return NotImplemented

    vals, offsets = prims._low_memory_max_pool2d_with_offsets(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )
    input_width = x.size(-1)
    indices = prims._low_memory_max_pool2d_offsets_to_indices(
        offsets,
        kernel_w,
        input_width,
        stride,
        padding,
    )
    return vals, indices
.venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py ADDED
@@ -0,0 +1,745 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import abc
3
+ import dataclasses
4
+ import itertools
5
+ import logging
6
+ import re
7
+ import typing
8
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
9
+ from unittest.mock import patch
10
+
11
+ import sympy
12
+
13
+ import torch
14
+ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
15
+ from torch.utils._ordered_set import OrderedSet
16
+
17
+ from .codegen.common import index_prevent_reordering
18
+ from .utils import (
19
+ get_dtype_size,
20
+ reduction_num_outputs,
21
+ sympy_index_symbol,
22
+ sympy_str,
23
+ sympy_subs,
24
+ VarRanges,
25
+ )
26
+ from .virtualized import OpsHandler, ReductionType, V
27
+
28
+
29
+ log = logging.getLogger(__name__)
30
+ is_indirect = re.compile(r"indirect|tmp").search
31
+
32
+
33
class Dep(abc.ABC):
    """Abstract base class for a dependency on a named buffer."""

    name: str
    index: sympy.Expr

    @abc.abstractmethod
    def rename(self, renames: Dict[str, str]) -> "Dep":
        """Return this dependency with its buffer name mapped through ``renames``."""

    @abc.abstractmethod
    def get_numel(self) -> sympy.Expr:
        """Return the number of elements covered by this dependency."""

    @abc.abstractmethod
    def numbytes_hint(self):
        """Return a size hint, in bytes, for this dependency."""

    @abc.abstractmethod
    def has_unbacked_symbols(self) -> bool:
        """Return whether this dependency involves unbacked symbols."""

    @abc.abstractmethod
    def is_contiguous(self) -> bool:
        """Return whether this dependency is a contiguous access."""

    def normalize_with_stride_order(self, prefix="t"):
        """Return ``self`` unchanged; subclasses may override with a real normalization."""
        return self
59
+
60
+
61
@dataclasses.dataclass(frozen=True)
class MemoryDep(Dep):
    """A dependency on buffer ``name`` accessed at a symbolic ``index``.

    ``var_names`` and ``size`` are parallel tuples giving the iteration
    variables and their extents; ``mode`` is an optional string carried along
    with the dependency (see ``_RecordLoadStoreInner.store`` for where it is set).
    """

    # Name of the buffer being accessed.
    name: str
    # Index expression, written in terms of `var_names`.
    index: sympy.Expr
    # Iteration variables, parallel to `size`.
    var_names: Tuple[sympy.Symbol, ...]
    # Extent of each iteration variable.
    size: Tuple[sympy.Expr, ...]
    # Optional mode string recorded with the access.
    mode: Optional[str] = None

    def __repr__(self) -> str:
        return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}, {self.mode})"

    @property
    def num_vars(self) -> int:
        """Number of iteration variables."""
        return len(self.var_names)

    def decide_loop_order_to_match(self, other) -> Optional[List[int]]:
        """
        Compute a loop order for ``self`` that matches the strides of ``other``.

        For each stride of ``other`` (in order), the result holds the position
        of the variable in ``self`` that has the same stride.

        Can return None if not able to decide loop orders.
        """
        assert self.num_vars == other.num_vars

        # ignore broadcast for now since broadcast causes extra 0 strides
        # which makes it hard to decide the correct loop orders.
        if self.num_vars != len(self.index.free_symbols):
            return None
        if other.num_vars != len(other.index.free_symbols):
            return None

        # bail out if any size is 0 or 1
        # For size == 0, it's an empty tensor, any strides for that dimension
        # are equivalent. Skip for simplicity and it may not matter that much.
        #
        # For size == 1, it can cause ties between strides of different dimensions.
        # Also when we first create the LoopBody in ComputedBuffer.simplify_and_reorder
        # we call dependencies.index_vars_squeeze, which should already squeeze
        # the size == 1 dimensions.
        if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)):
            return None

        # Extract strides for both expressions
        self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names)
        other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names)

        # Even if the shape contains no 0/1, some complex index expression may
        # still have duplicate stride values. Here is an example:
        # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129
        # We don't reorder the loop for these cases for now, but in theory
        # we could improve the algorithm to detect the correct loop orders.
        if len(set(self_strides)) != len(self_strides) or len(
            set(other_strides)
        ) != len(other_strides):
            log.debug(
                "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s",
                self,
                other,
                self_strides,
                other_strides,
            )
            return None

        # May happen if self and other are as follows
        # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None)
        # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None)
        if set(self_strides) != set(other_strides):
            return None

        # Strides are unique at this point, so each one identifies one variable.
        stride_to_index = {s: i for i, s in enumerate(self_strides)}
        order = []
        for s in other_strides:
            order.append(stride_to_index[s])

        # Sanity check: the result must be a permutation of all variables.
        assert set(order) == set(range(0, self.num_vars))
        return order

    def get_offset(self):
        """
        Return the offset by setting every variable to be 0.
        """
        return sympy_subs(self.index, dict.fromkeys(self.var_names, 0))

    def normalize(self) -> "MemoryDep":
        """
        Normalize by merging loops. The difference from normalize_with_stride_order is
        that this method does not reorder loops, while normalize_with_stride_order
        reorders loops based on stride order.
        """
        return MemoryDep(
            self.name,
            *_RecordLoadStoreInner._normalize(self.index, self.ranges),  # type: ignore[arg-type]
            self.mode,
        )

    def normalize_with_stride_order(self, prefix="t"):
        r"""
        Used to decide if two MemoryDep does not equal due to different loop orders.
        More specifically, when dep1 and dep2 are not equal, we can normalize
        both and check if they are equal after that. If yes, then the mismatch is
        caused by different loop orders.

        ``prefix`` names the fresh variables used in the returned MemoryDep.
        Note that the returned MemoryDep is built without ``mode``.
        """
        # import here to avoid circular import
        from torch._inductor import ir

        strides = V.graph.sizevars.stride_hints(self.index, self.var_names)

        # pick a loop order with stride ordered decreasingly
        order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
        stride_reorder = ir.same_reorder(order)
        sizes = self.size
        var_names = self.var_names

        new_reordered_sizes = stride_reorder(sizes)
        new_reordered_var_names = stride_reorder(var_names)

        # `prune` is returned by _simplify_loops but not needed here.
        new_simplified_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            new_reordered_var_names,
            new_reordered_sizes,
            index_prevent_reordering(
                [self.index], new_reordered_var_names, new_reordered_sizes
            ),
        )

        # now let's create new symbols with the passed in prefix
        var_ranges, add_var = var_builder(prefix)
        replacement = dict(
            zip(
                new_reordered_var_names,
                reindex([add_var(x) for x in new_simplified_sizes]),
            )
        )
        new_index = sympy_subs(sympy.expand(self.index), replacement)  # type: ignore[arg-type] # next PR

        out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values()))  # type: ignore[arg-type]
        return out

    @property
    def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
        """{c0: 128, c1: 512, ...}"""
        return dict(zip(self.var_names, self.size))

    def get_numel(self) -> sympy.Expr:
        """Return the element count of this access.

        Indirect accesses use the full element count of the buffer as reported
        by the graph; otherwise it is the product of the sizes of the
        variables that actually appear in ``index``.
        """
        if self.is_indirect():
            numel = V.graph.get_numel(self.name)
        else:
            vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols)
            numel = sympy.Integer(1)
            for var, size in zip(self.var_names, self.size):
                if var in vars:
                    numel = numel * size
        return numel  # type: ignore[return-value]

    def rename(self, renames: Dict[str, str]) -> "MemoryDep":
        """Return a copy with the buffer name replaced if it appears in ``renames``, else ``self``."""
        if self.name in renames:
            return MemoryDep(
                renames[self.name],
                self.index,
                var_names=self.var_names,
                size=self.size,
                mode=self.mode,
            )
        return self

    def numbytes_hint(self):
        """Return the size hint of ``get_numel()`` multiplied by the buffer's dtype size."""
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
            V.graph.get_dtype(self.name)
        )

    def has_unbacked_symbols(self):
        """Return True if ``get_numel()`` contains any unbacked symbols."""
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        """Return True when ``index`` is exactly one of the iteration variables."""
        return isinstance(self.index, sympy.Symbol) and self.index in self.var_names

    def stride1_for_last_dim(self, result_for_complex_expression=True) -> bool:
        """
        Whether the stride for the last dimension is 1.

        ``result_for_complex_expression`` is returned when neither the
        stride-1 nor the integer-stride-greater-than-1 pattern is found.
        """
        # python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16
        # will exercise thru this corner case.
        if len(self.var_names) == 0:
            return True

        # Inspect the index one additive term at a time.
        terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index]

        last_sym = self.var_names[-1]
        for term in terms:
            # The last variable appears bare, i.e. with coefficient 1.
            if term is last_sym:
                return True

            # Having a >1 stride for the last dimension is bad for perf
            # return False.
            if (
                isinstance(term, sympy.Mul)
                and len(term.args) == 2
                and term.args[1] is last_sym
                and isinstance(term.args[0], (int, sympy.Integer))
                and term.args[0] > 1
            ):
                return False

        return result_for_complex_expression

    def is_scalar(self) -> bool:
        """Return True for an integer index, or a lone symbol that is neither a loop variable nor indirect."""
        if isinstance(self.index, sympy.Symbol):
            return self.index not in self.var_names and not self.is_indirect()
        return isinstance(self.index, (int, sympy.Integer))

    def is_indirect(self) -> bool:
        """Return True if any free symbol's name matches the module-level ``is_indirect`` pattern."""
        return any(is_indirect(v.name) for v in self.index.free_symbols)  # type: ignore[attr-defined]
269
+
270
+
271
@dataclasses.dataclass(frozen=True)
class StarDep(Dep):
    """A dependency on an entire buffer rather than on a specific index."""

    name: str
    mode: Optional[str] = None

    # depends on the entire buffer
    @property
    def index(self):
        raise NotImplementedError("StarDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        """Return the element count of the whole buffer, as reported by the graph."""
        return V.graph.get_numel(self.name)  # type: ignore[return-value]

    def rename(self, renames: Dict[str, str]) -> "StarDep":
        """Return a copy pointing at the renamed buffer, or ``self`` if the name is not remapped."""
        return StarDep(renames[self.name], self.mode) if self.name in renames else self

    def numbytes_hint(self):
        """Return the size hint of ``get_numel()`` multiplied by the buffer's dtype size."""
        element_count = V.graph.sizevars.size_hint(self.get_numel())
        return element_count * get_dtype_size(V.graph.get_dtype(self.name))

    def has_unbacked_symbols(self):
        """Return True if ``get_numel()`` contains any unbacked symbols."""
        unbacked = free_unbacked_symbols(self.get_numel())
        return len(unbacked) > 0

    def is_contiguous(self) -> bool:
        """Always False."""
        return False

    def is_scalar(self) -> bool:
        """Always False."""
        return False

    def is_indirect(self) -> bool:
        """Always False."""
        return False
305
+
306
+
307
# Used for tracking mutation ordering:
# if A reads a buffer and B mutates it, B must be ordered after A.
#
# This is useful for a variety of reasons.
# For example, if A's read is never actually used, we can eliminate it.
# Another case is if A's buffer ends up being fused away, we never need to
# materialize that buffer.
@dataclasses.dataclass(frozen=True)
class WeakDep(Dep):
    """An ordering-only dependency used to sequence a mutation after a read."""

    # Fake dependency on unused buffer
    name: str
    # Buffer that is doing the mutation
    mutating_buf: str

    @property
    def index(self):
        raise NotImplementedError("WeakDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        """Return the constant 1."""
        return sympy.Integer(1)

    def rename(self, renames: Dict[str, str]) -> "WeakDep":
        """Return a copy pointing at the renamed buffer, or ``self`` if the name is not remapped.

        ``mutating_buf`` is carried over unchanged.
        """
        if self.name not in renames:
            return self
        return WeakDep(renames[self.name], self.mutating_buf)

    def numbytes_hint(self):
        """Return 1: the dep exists purely for ordering and is not an actual dep."""
        return 1

    def has_unbacked_symbols(self):
        """Always False."""
        return False

    def is_contiguous(self) -> bool:
        """Always False."""
        return False
342
+
343
+
344
@dataclasses.dataclass(frozen=True)
class IndexExprDep:
    """Immutable record of an index expression together with its iteration variables and sizes.

    Instances are created by ``_RecordLoadStoreInner.index_expr`` from the
    canonicalized ``(index, var_names, size)`` triple.
    """

    # The index expression itself.
    index: sympy.Expr  # type: ignore[assignment]
    # Iteration variables, parallel to `size`.
    var_names: Tuple[sympy.Symbol, ...]
    # Extent of each iteration variable.
    size: Tuple[sympy.Expr, ...]
349
+
350
+
351
@dataclasses.dataclass
class ReadWrites:
    """The sets of reads, writes and index expressions recorded for a computation."""

    reads: OrderedSet[Dep]
    writes: OrderedSet[Dep]
    index_exprs: OrderedSet[IndexExprDep]
    range_vars: Optional[List[sympy.Expr]] = None
    var_ranges: Optional[VarRanges] = None

    def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
        """Return a copy in which every read and write has been renamed via ``renames``."""
        renamed_reads = OrderedSet(dep.rename(renames) for dep in self.reads)
        renamed_writes = OrderedSet(dep.rename(renames) for dep in self.writes)
        return ReadWrites(
            renamed_reads,
            renamed_writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
        )

    def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites":
        """Return a copy with ``dep`` (a WeakDep, a StarDep, or a set of deps) added to the reads."""
        assert isinstance(dep, (WeakDep, StarDep, set))
        extra_reads = dep if isinstance(dep, set) else {dep}
        return ReadWrites(
            OrderedSet.union(self.reads, extra_reads),
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
        )

    def merge(self, other: "ReadWrites"):
        """Combine with ``other``; reads that are also written are dropped from the reads.

        The result does not carry ``range_vars`` or ``var_ranges``.
        """
        combined_reads = OrderedSet.union(self.reads, other.reads)
        combined_writes = OrderedSet.union(self.writes, other.writes)
        combined_index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs)
        external_reads = combined_reads - combined_writes
        return ReadWrites(external_reads, combined_writes, combined_index_exprs)

    @staticmethod
    def merge_list(read_writes: List["ReadWrites"]):
        """Combine a list of ReadWrites; reads that are also written are dropped from the reads.

        The result does not carry ``range_vars`` or ``var_ranges``.
        """
        all_writes = OrderedSet.union(*[entry.writes for entry in read_writes])
        every_read = OrderedSet.union(*[entry.reads for entry in read_writes])
        all_reads = every_read - all_writes
        all_index_exprs = OrderedSet.union(
            *[entry.index_exprs for entry in read_writes]
        )
        return ReadWrites(all_reads, all_writes, all_index_exprs)

    def remove_reads(self, rem_reads):
        """Return a copy with ``rem_reads`` removed from the reads."""
        remaining_reads = self.reads - rem_reads
        return ReadWrites(
            remaining_reads,
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
        )

    def reads_and_writes(self):
        """Iterate over all reads followed by all writes."""
        return itertools.chain(self.reads, self.writes)

    def buffer_names(self, ignore_integer_index=True):
        """
        Collect the buffer names of all MemoryDep reads and writes.

        Integer index is used for load_seed; such deps are skipped unless
        ``ignore_integer_index`` is False.
        """
        names: OrderedSet[str] = OrderedSet()
        for dep in self.reads_and_writes():
            if not isinstance(dep, MemoryDep):
                continue
            if ignore_integer_index and isinstance(dep.index, (int, sympy.Integer)):
                continue
            names.add(dep.name)
        return names
418
+
419
+
420
class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
    """Ops handler that records reads, writes and index expressions as dependency objects.

    Each handled op adds a dependency to one of the internal sets and returns
    a string describing the call.
    """

    def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
        """Store the iteration ranges and whether indexes should be normalized."""
        super().__init__()
        self._reads: OrderedSet[Dep] = OrderedSet()
        self._writes: OrderedSet[MemoryDep] = OrderedSet()
        self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet()
        self._var_ranges: VarRanges = var_ranges
        self._should_normalize: bool = normalize

    @staticmethod
    def drop_unused_symbols(index, var_names, sizes):
        """
        Reduction has last (reduced) dim in its sizes, but
        downstream users won't. Normalize this away.

        Trailing variables that do not appear in ``index`` are popped from
        ``var_names`` and ``sizes``; both lists are modified in place.
        """
        if not isinstance(index, sympy.Expr):
            # index can be an int
            return
        free_symbols = index.free_symbols
        while var_names and var_names[-1] not in free_symbols:
            var_names.pop()
            sizes.pop()

    @classmethod
    def _normalize(
        cls, index: sympy.Expr, var_ranges: VarRanges
    ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
        """Simplify the loops of ``index`` and rewrite it over freshly numbered variables.

        Returns the rewritten index together with the new variables and their
        sizes, with unused trailing variables dropped.
        """
        # Try to further simplify the indexes even if simplify_loops didn't
        # convert it to the simplest form because of the interference from
        # different indexing formulas.
        index_vars = [*var_ranges.keys()]
        sizes = tuple(var_ranges.values())  # type: ignore[assignment]
        # `prune` is returned by _simplify_loops but not needed here.
        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars,
            sizes,
            index_prevent_reordering([index], index_vars, sizes),
        )

        # assign new variables each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        new_vars, add_var = var_builder(canonicalization_prefix())
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
        index = sympy_subs(sympy.expand(index), replacement)

        # Copy into lists so drop_unused_symbols can pop from them in place.
        new_vars = [*new_vars.keys()]
        new_sizes = [*new_sizes]
        cls.drop_unused_symbols(index, new_vars, new_sizes)
        return index, tuple(new_vars), tuple(new_sizes)  # type: ignore[arg-type]

    def canonicalize(
        self, index: sympy.Expr
    ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
        """Return ``(index, var_names, sizes)`` for ``index`` under the stored ranges.

        Without normalization, sizes are simplified, size-1 variables are
        removed and unused trailing variables are dropped; the index itself is
        returned unchanged. With normalization, the work is delegated to
        ``_normalize`` using the simplified ranges.
        """
        if not self._should_normalize:
            sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
            var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1]
            sizes = [v for v in sizes if v != 1]

            self.drop_unused_symbols(index, var_names, sizes)

            return index, tuple(var_names), tuple(sizes)  # type: ignore[return-value, arg-type]
        var_ranges = {
            k: V.graph.sizevars.simplify(v)
            for k, v in self._var_ranges.items()
            # TODO(jansel): explore this further normalization
            # if k in free_symbols
        }
        return self._normalize(index, var_ranges)

    def load(self, name: str, index: sympy.Expr) -> str:
        """Record a read of ``name`` at ``index`` as a MemoryDep."""
        self._reads.add(MemoryDep(name, *self.canonicalize(index)))
        return f"load({name}, {sympy_str(index)})"

    def load_seed(self, name: str, index: int):
        """Record a seed load as a regular load with an integer index."""
        assert isinstance(index, int)
        return self.load(name, sympy.Integer(index))

    def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
        """Record a write to ``name`` at ``index`` as a MemoryDep carrying ``mode``."""
        self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode))
        return f"store({name}, {sympy_str(index)}, {value}, {mode})"

    def store_reduction(self, name: str, index, value) -> str:
        """Record a reduction store by delegating to ``store`` (with the default mode)."""
        return self.store(name, index, f"store_reduction({value})")

    def index_expr(self, index: sympy.Expr, dtype) -> str:
        """Record ``index`` as an IndexExprDep."""
        self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
        return f"index_expr({sympy_str(index)}, {dtype})"

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        """Record a whole-buffer (StarDep) read of the offsets buffer."""
        self._reads.add(StarDep(offsets_name))
        return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"
517
+
518
+
519
class RecordLoadStore(V.KernelFormatterHandler):  # type: ignore[name-defined]
    """Formatter handler whose parent is a freshly built ``_RecordLoadStoreInner``."""

    def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
        recorder = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize)
        super().__init__(parent_handler=recorder)
525
+
526
+
527
+ # TODO: check call sites
528
def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
    """Create an empty range mapping plus a function that registers variables in it.

    Each call of the returned function creates an index symbol named
    ``f"{prefix}{n}"`` (``n`` counts up from 0), stores the given length under
    that symbol in the mapping, and returns the symbol.
    """
    ranges: VarRanges = {}
    counter = itertools.count()

    def add_var(length: sympy.Expr) -> sympy.Symbol:
        symbol = sympy_index_symbol(f"{prefix}{next(counter)}")
        ranges[symbol] = length
        return symbol

    return ranges, add_var
538
+
539
+
540
def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
    """Create one index variable per entry of every size tuple, without squeezing.

    Returns ``(args, var_ranges)``: ``args`` holds one list of symbols per
    input size, and ``var_ranges`` maps each symbol to its length.
    """
    var_ranges, add_var = var_builder(prefix)
    # Variables are created strictly left to right, so numbering follows argsizes.
    args: List[List[sympy.Symbol]] = [
        [add_var(length) for length in size] for size in argsizes
    ]
    return args, var_ranges
546
+
547
+
548
def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
    """Create index variables for every size tuple after squeezing it.

    Each tuple in ``argsizes`` is passed to ``SqueezeView.squeezer``, which
    returns a new size and a ``reindex`` callable.  One variable is created
    per entry of the new size, and ``reindex`` is applied to that variable
    list (presumably mapping it back onto the original, unsqueezed
    dimensions; ``SqueezeView`` is defined outside this file).

    Returns:
        ``(args, var_ranges)`` where ``args`` holds one reindexed list per
        input size and ``var_ranges`` maps each created variable to its
        length.
    """
    from .ir import SqueezeView

    var_ranges, add_var = var_builder(prefix)
    args: List[List[sympy.Expr]] = []
    for size in argsizes:
        new_size, reindex = SqueezeView.squeezer(size)
        # Variables are created in order, so their numbering follows argsizes.
        args.append(reindex([add_var(length) for length in new_size]))
    return args, var_ranges
559
+
560
+
561
def extract_read_writes(
    fn: Callable[..., Any],
    *argsizes: Tuple[sympy.Expr, ...],
    normalize: bool = False,
    prefix: str = "d",
    hidden_args=(),
):
    """Collect the reads, writes and index expressions performed by ``fn``.

    Index variables are built from ``argsizes`` with ``index_vars_squeeze``.
    If ``fn`` is a ``LoopBody`` its recorded ``memory_usage`` entries are
    replayed directly into a ``_RecordLoadStoreInner``; otherwise ``fn`` is
    called under a ``RecordLoadStore`` ops handler so that the handler sees
    every op it issues.

    Args:
        fn: The function (or ``LoopBody``) to analyze.
        *argsizes: One size tuple per positional index argument of ``fn``.
        normalize: Forwarded to the recording handler; when true the returned
            ``range_vars`` list is empty.
        prefix: Name prefix for the generated index variables.
        hidden_args: Extra arguments appended after the generated index args.

    Returns:
        A ``ReadWrites`` built from the recorded reads, writes, index
        expressions, the range variables and ``var_ranges``.
    """
    args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)

    from .loop_body import LoopBody, MemoryUsageType

    if isinstance(fn, LoopBody):
        # Fast path to avoid tracing when we already have a LoopBody
        inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize)
        name_to_index = fn.indexing_from_args([*args, *hidden_args])
        if fn.indirect_vars:
            # mimic the `tmpX` naming tracing gives us
            repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)}
            name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()}
        # Replay each recorded memory access against the recording handler.
        # Value/dtype arguments are irrelevant for dependency tracking, so
        # None is passed in their place.
        for entry in fn.memory_usage[MemoryUsageType.LOAD]:
            inner.load(entry.buffer_name, name_to_index[entry.index_name])
        for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]:
            inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name]))
        for entry in fn.memory_usage[MemoryUsageType.STORE]:
            inner.store(
                entry.buffer_name, name_to_index[entry.index_name], None, entry.mode
            )
        for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
            inner.store_reduction(
                entry.buffer_name, name_to_index[entry.index_name], None
            )
        for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
            inner.index_expr(name_to_index[entry.index_name], None)
        for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]:
            inner.bucketize(
                None, entry.buffer_name, name_to_index[entry.index_name], None, None
            )
        # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
    else:
        # Slow path tracing the function
        rw = RecordLoadStore(var_ranges, normalize=normalize)
        with V.set_ops_handler(rw):
            fn(*args, *hidden_args)
        inner = rw.parent_handler

    if normalize:
        range_vars = []  # Number of vars could differ due to normalization
    else:
        # Flatten the per-argument variable lists into a single list.
        range_vars = [*itertools.chain.from_iterable(args)]

    return ReadWrites(
        OrderedSet(inner._reads),
        OrderedSet(inner._writes),
        inner._index_exprs,
        range_vars,
        var_ranges,
    )
618
+
619
+
620
def extract_input_node_reduction_ranges(
    input_node: "torch._inductor.ir.TensorBox",
) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
    """
    Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
    It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
    In this case, reduction_sizes of the Reduction nodes need to be the same.
    Otherwise returns (None, None).
    """

    from .ir import ComputedBuffer, Loops

    if isinstance(input_node.data, ComputedBuffer):
        # Input node has already been realized. Return its size and reduction_size.
        size = input_node.get_size()
        reduction_size = input_node.get_reduction_size()
        if len(reduction_size) > 0:
            return (size, reduction_size)
        else:
            return (None, None)

    if not isinstance(input_node.data.data, Loops):  # type: ignore[attr-defined]
        # Other IRNodes do not have reduction_ranges.
        return (None, None)

    # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
    # The current method still uses reduction ranges from the dependent realized node, which is not ideal.
    # Is there a way to check whether there are permutations in between?
    reads = input_node.get_reads()
    reduction_size = None
    size = None
    # Walk backwards through the producers of the current reads, one level per
    # iteration, until a reduction is found or there is nothing left to follow.
    while reduction_size is None and len(reads) > 0:
        seen: OrderedSet[str] = OrderedSet()
        new_reads = []
        for read in reads:
            if not isinstance(read, MemoryDep):
                continue
            # Visit each buffer name at most once per level.
            if read.name in seen:
                continue
            seen.add(read.name)
            buffer = V.graph.try_get_buffer(read.name)
            if buffer is None:
                continue
            op = buffer.get_defining_op()
            if op is None:
                continue

            if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0:
                if reduction_size is None:
                    # First reduction found at this level: remember its ranges.
                    reduction_size = op.get_reduction_size()
                    size = op.get_size()
                elif reduction_size != op.get_reduction_size() or size != op.get_size():
                    # Two reductions with different ranges: no common answer.
                    return (None, None)
            else:
                # Not a reduction: continue the search through its own reads.
                new_reads.extend(op.get_reads())
        if reads == new_reads:
            # No progress was made, so stop to avoid looping forever.
            return (size, reduction_size)
        else:
            reads = new_reads
    return (size, reduction_size)
680
+
681
+
682
def canonicalization_prefix():
    """Return the variable-name prefix used for canonicalization (``"c"``)."""
    prefix = "c"
    return prefix
684
+
685
+
686
+ # ops handler which computes all the free unbacked symbols for an IR
687
class FreeUnbackedSymbolsOpsHandler:
    """Ops handler that accumulates free unbacked symbols found in op arguments.

    Ops without an explicit method are served by ``__getattr__``, which scans
    their arguments and returns ``None``.  The explicit methods exist for ops
    whose result is unpacked or must be a symbol.
    """

    symbols: OrderedSet[sympy.Symbol]

    def __init__(self) -> None:
        self.symbols = OrderedSet()

    def __getattr__(self, name: str) -> Callable[..., Any]:
        symbolic_types = (sympy.Expr, sympy.logic.boolalg.Boolean)

        def inner(*args, **kwargs):
            # Only sympy expressions/booleans can carry symbols; skip the rest.
            for arg in itertools.chain(args, kwargs.values()):
                if isinstance(arg, symbolic_types):
                    self.symbols |= free_unbacked_symbols(arg)

        return inner

    def indirect_indexing(
        self, index_var, size, check=True, wrap_neg=True
    ) -> sympy.Symbol:
        """Collect symbols from ``size`` and return a symbol named after ``index_var``."""
        assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
        self.symbols |= free_unbacked_symbols(size)
        symbol_name = f"({str(index_var)})"
        return sympy_index_symbol(symbol_name)

    def frexp(self, x):
        """Return a pair of placeholders, one per ``frexp`` output."""
        return (None, None)

    def scan(self, dtypes, combine_fn, values):
        """Return one placeholder per scanned value."""
        return (None,) * len(values)

    def sort(self, dtypes, values, stable, descending):
        """Return one placeholder per sorted value."""
        return (None,) * len(values)

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[None, Tuple[None, ...]],
    ) -> Union[None, Tuple[None, ...]]:
        """Return placeholders matching the number of outputs of the reduction."""
        num_values = reduction_num_outputs(reduction_type)
        if num_values > 1:
            return (None,) * num_values
        return None
726
+
727
+
728
def _typecheck_FreeUnbackedSymbolsOpsHandler(
    h: FreeUnbackedSymbolsOpsHandler,
) -> OpsHandler[None]:
    """Return ``h`` unchanged; the annotations let a type checker compare the handler against ``OpsHandler[None]``."""
    handler = h
    return handler
732
+
733
+
734
def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None):
    """Run ``fn`` under a ``FreeUnbackedSymbolsOpsHandler`` and return its symbols.

    ``fn`` is called with ``index`` and, when it is not ``None``, ``rindex``.
    """
    from .ir import FlexibleLayout

    call_args = [index]
    if rindex is not None:
        call_args.append(rindex)

    handler = FreeUnbackedSymbolsOpsHandler()
    # NB: I cargo culted the allow_indexing patch here, I don't understand why
    # people do this all over
    allow_indexing = patch.object(FlexibleLayout, "allow_indexing", True)
    with V.set_ops_handler(handler), allow_indexing:
        fn(*call_args)
    return handler.symbols
.venv/lib/python3.11/site-packages/torch/_inductor/exc.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import tempfile
6
+ import textwrap
7
+ from functools import lru_cache
8
+
9
+
10
+ if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
11
+
12
+ @lru_cache(None)
13
+ def _record_missing_op(target):
14
+ with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd:
15
+ fd.write(str(target) + "\n")
16
+
17
+ else:
18
+
19
+ def _record_missing_op(target): # type: ignore[misc]
20
+ pass
21
+
22
+
23
class OperatorIssue(RuntimeError):
    """Base error for operator problems; can render an operator call as text."""

    @staticmethod
    def operator_str(target, args, kwargs):
        """Describe a call as indented lines: target, each positional arg, kwargs.

        The ``kwargs`` line is only present when ``kwargs`` is non-empty.
        """
        lines = [f"target: {target}"]
        lines.extend(f"args[{i}]: {arg}" for i, arg in enumerate(args))
        if kwargs:
            lines.append(f"kwargs: {kwargs}")
        return textwrap.indent("\n".join(lines), " ")
32
+
33
+
34
class MissingOperatorWithoutDecomp(OperatorIssue):
    """Raised for a target with no lowering; the target is also recorded."""

    def __init__(self, target, args, kwargs) -> None:
        _record_missing_op(target)
        call_description = self.operator_str(target, args, kwargs)
        super().__init__(f"missing lowering\n{call_description}")
38
+
39
+
40
class MissingOperatorWithDecomp(OperatorIssue):
    """Raised for a target whose decomposition exists but is not enabled."""

    def __init__(self, target, args, kwargs) -> None:
        _record_missing_op(target)
        header = f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
        # The hint is dedented so the message reads flush-left regardless of
        # how the literal is indented in this file.
        hint = textwrap.dedent(
            f"""

            There is a decomposition available for {target} in
            torch._decomp.get_decompositions(). Please add this operator to the
            `decompositions` list in torch._inductor.decomposition
            """
        )
        super().__init__(header + hint)
54
+
55
+
56
class LoweringException(OperatorIssue):
    """Wraps an exception raised during lowering with the failing operator call."""

    def __init__(self, exc: Exception, target, args, kwargs) -> None:
        cause = f"{type(exc).__name__}: {exc}"
        call_description = self.operator_str(target, args, kwargs)
        super().__init__(f"{cause}\n{call_description}")
61
+
62
+
63
class SubgraphLoweringException(RuntimeError):
    """``RuntimeError`` subclass that adds no behavior of its own."""
65
+
66
+
67
class InvalidCxxCompiler(RuntimeError):
    """Raised when none of the configured C++ compilers works."""

    def __init__(self) -> None:
        # Imported here rather than at module level, as in the original.
        from . import config

        message = f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}"
        super().__init__(message)
74
+
75
+
76
class CppWrapperCodeGenError(RuntimeError):
    """Error whose message is ``msg`` prefixed with a C++ wrapper codegen label."""

    def __init__(self, msg: str) -> None:
        message = f"C++ wrapper codegen error: {msg}"
        super().__init__(message)
79
+
80
+
81
class CppCompileError(RuntimeError):
    """Error reporting a failed C++ compile: the command line and its output."""

    def __init__(self, cmd: list[str], output: str) -> None:
        # Output may arrive as bytes; decode it so it can be formatted as text.
        if isinstance(output, bytes):
            output = output.decode("utf-8")

        template = textwrap.dedent(
            """
            C++ compile error

            Command:
            {cmd}

            Output:
            {output}
            """
        ).strip()
        super().__init__(template.format(cmd=" ".join(cmd), output=output))
101
+
102
+
103
class CUDACompileError(CppCompileError):
    """``CppCompileError`` subclass that adds no behavior of its own."""
.venv/lib/python3.11/site-packages/torch/_inductor/extern_node_serializer.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import List
3
+
4
+ from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
5
+ from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder
6
+ from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode
7
+
8
+
9
def serialize_extern_kernel_node(
    extern_kernel_node: inductor_ExternKernelNode,
) -> ExternKernelNode:
    """Convert an inductor extern kernel node into its schema counterpart.

    The wrapped ``node`` must already be a schema ``Node``.
    """
    node = extern_kernel_node.node
    assert isinstance(node, Node)
    return ExternKernelNode(name=extern_kernel_node.name, node=node)
17
+
18
+
19
def extern_node_json_serializer(
    extern_kernel_nodes: List[inductor_ExternKernelNode],
) -> str:
    """Serialize a list of inductor extern kernel nodes to a JSON string."""
    serialized = [serialize_extern_kernel_node(node) for node in extern_kernel_nodes]
    payload = _dataclass_to_dict(ExternKernelNodes(nodes=serialized))
    return json.dumps(payload, cls=EnumEncoder)
.venv/lib/python3.11/site-packages/torch/_inductor/freezing.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import itertools
5
+ import logging
6
+ import weakref
7
+ from typing import Any, List, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.utils._pytree as pytree
11
+ from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
12
+ from torch._functorch.aot_autograd import MutationType
13
+ from torch._functorch.compile_utils import fx_graph_cse
14
+ from torch._inductor.constant_folding import constant_fold, replace_node_with_constant
15
+ from torch._inductor.fx_passes.freezing_patterns import freezing_passes
16
+ from torch._inductor.fx_passes.post_grad import view_to_reshape
17
+
18
+ from . import config
19
+
20
+
21
+ aten = torch.ops.aten
22
+ prims = torch.ops.prims
23
+
24
+ log = logging.getLogger(__name__)
25
+
26
+
27
+ def replace_params_with_constants(
28
+ gm: torch.fx.GraphModule,
29
+ flat_params: list[Any],
30
+ fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
31
+ ) -> List[int]:
32
+ """
33
+ Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
34
+ Returns a list of indices representing the input parameters that were not converted to constants.
35
+ """
36
+ params = gm.graph.find_nodes(op="placeholder")
37
+ fake_inp_nodes = params[: len(params)]
38
+ preserved_arg_indices = []
39
+ aliased_input_args = [
40
+ out_info.base_idx
41
+ for out_info in fw_metadata.output_info
42
+ if out_info.base_idx is not None
43
+ ]
44
+
45
+ # TODO (tmanlaibaatar) figure out why this is different
46
+ # from mutated_inp_runtime_indices
47
+ mutated_inps = [
48
+ i
49
+ for i, m in enumerate(fw_metadata.input_info)
50
+ if m.mutation_type
51
+ in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
52
+ ]
53
+
54
+ for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
55
+ if i in mutated_inps or i in aliased_input_args:
56
+ preserved_arg_indices.append(i)
57
+ continue
58
+ replace_node_with_constant(gm, node, real_input)
59
+ # add on non param inputs
60
+ preserved_arg_indices.extend(range(len(flat_params), len(params)))
61
+ # is this necessary ?
62
+ gm.recompile()
63
+ return preserved_arg_indices
64
+
65
+
66
def freeze(
    dynamo_gm: torch.fx.GraphModule,
    aot_autograd_gm: torch.fx.GraphModule,
    example_inputs: List[torch._subclasses.FakeTensor],
) -> Tuple[torch.fx.GraphModule, List[int]]:
    """
    Turn non-mutated parameters into constants and optimize the graph with
    constant propagation and related passes. When
    ``config.freezing_discard_parameters`` is set, the original module
    parameters are also discarded to save memory.

    Assumes that this function is run in dynamo tracing post aot_autograd.

    Args:
        dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule.
        aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen.
        example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process.

    Returns:
        Tuple[torch.fx.GraphModule, List[int]]: The frozen GraphModule and the
        indices of the inputs that were kept (not turned into constants).
    """
    # Conv weights may have been converted to channels last, which can make
    # .view fail during fake_tensor_prop, so views are converted to reshapes first.
    # See the details in fx_codegen_and_compile of compile_fx.py.
    view_to_reshape(aot_autograd_gm)

    tracing_context = torch._guards.TracingContext.try_get()
    if tracing_context:
        fw_metadata = tracing_context.fw_metadata
        params_flat = tracing_context.params_flat
        assert fw_metadata is not None and params_flat is not None

        preserved_arg_indices = replace_params_with_constants(
            aot_autograd_gm, params_flat, fw_metadata
        )
    else:
        # Without a tracing context nothing is inlined: every input is kept.
        placeholders = aot_autograd_gm.graph.find_nodes(op="placeholder")
        preserved_arg_indices = list(range(len(placeholders)))

    # TODO - further restrict cse ? right now needed to dedup aliasing ops
    aot_autograd_gm.graph = fx_graph_cse(aot_autograd_gm.graph)
    aot_autograd_gm.recompile()

    kept_example_inputs = [example_inputs[idx] for idx in preserved_arg_indices]
    freezing_passes(aot_autograd_gm, kept_example_inputs)

    constant_fold(aot_autograd_gm)
    # invalidate nn Modules
    if config.freezing_discard_parameters:
        invalidate_eager_modules()
        discard_traced_gm_params(dynamo_gm)

    frozen_graph_text = lazy_format_graph_code(
        "FROZEN GRAPH", aot_autograd_gm, colored=True
    )
    log.debug("%s", frozen_graph_text)

    return aot_autograd_gm, preserved_arg_indices
122
+
123
+
124
+ class ErasedTensor(torch.Tensor):
125
+ @staticmethod
126
+ def __new__(cls, elem, name, owning_mod):
127
+ return super().__new__(cls, elem.to(device="meta"))
128
+
129
+ def __init__(self, elem, name: Optional[str], mod) -> None:
130
+ self.erased_name = name
131
+ self.owning_mod_ref = weakref.ref(mod)
132
+
133
+ @classmethod
134
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
135
+ erased_tensors = [
136
+ e
137
+ for e in pytree.arg_tree_leaves(*args, **kwargs)
138
+ if isinstance(e, ErasedTensor)
139
+ ]
140
+ assert len(erased_tensors) > 0
141
+ e = erased_tensors[0]
142
+
143
+ raise RuntimeError(
144
+ f"Trying to run Pytorch Eager Module after Dynamo Freezing. "
145
+ "The original parameters have been discarded for memory efficiency. "
146
+ f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}"
147
+ )
148
+
149
+
150
def invalidate_eager_modules():
    """Replace direct parameters/buffers of traced nn.Modules with ``ErasedTensor``s.

    The modules come from the current tracing context's
    ``module_context.nn_modules``; entries that are not ``torch.nn.Module``
    instances are skipped.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        nn_modules = torch._guards.TracingContext.get().module_context.nn_modules
        for mod in nn_modules.values():
            if not isinstance(mod, torch.nn.Module):
                continue

            # Snapshot the names first: the loop below reassigns attributes.
            named_tensors = [
                *mod.named_parameters(recurse=False),
                *mod.named_buffers(recurse=False),
            ]
            for attr_name, tensor in named_tensors:
                with torch._dispatch.python.no_python_dispatcher():
                    erased = ErasedTensor(tensor, attr_name, mod)
                if isinstance(tensor, torch.nn.Parameter):
                    erased.requires_grad_(True)
                    erased._is_param = True  # type: ignore[attr-defined]
                setattr(mod, attr_name, erased)
170
+
171
+
172
def discard_traced_gm_params(mod: torch.fx.GraphModule):
    """Replace the direct parameters and buffers of ``mod`` with ``ErasedTensor``s."""
    with torch.utils._python_dispatch._disable_current_modes():
        # Snapshot the names first: the loop below reassigns attributes.
        named_tensors = [
            *mod.named_parameters(recurse=False),
            *mod.named_buffers(recurse=False),
        ]
        for attr_name, tensor in named_tensors:
            with torch._dispatch.python.no_python_dispatcher():
                erased = ErasedTensor(tensor, attr_name, mod)
            if isinstance(tensor, torch.nn.Parameter):
                erased.requires_grad_(True)
                erased._is_param = True  # type: ignore[attr-defined]
            setattr(mod, attr_name, erased)
185
+
186
+
187
def enforce_output_layout(gm: torch.fx.GraphModule):
    """
    Keep the layout of graph outputs stable under compiler optimizations by
    routing each dense tensor output through a node that forces the strides
    recorded in its ``meta["val"]``.

    Only used for inference so we can assume all graph outputs are model outputs.
    """
    graph = gm.graph
    output_node = list(graph.nodes)[-1]
    outputs = output_node.args[0]
    with graph.inserting_before(output_node):
        for out in outputs:
            fake = out.meta["val"]
            is_dense_tensor = isinstance(
                fake, torch.Tensor
            ) and torch._prims_common.is_non_overlapping_and_dense(fake)
            if not is_dense_tensor:
                continue

            # add a node to enforce eager layout
            forced = graph.call_function(
                prims.inductor_force_stride_order.default, (out, fake.stride())
            )

            # out.replace_all_uses_with(forced) cannot be used here: it would
            # also rewrite the use of `out` inside `forced` itself.
            output_node.replace_input_with(out, forced)

    graph.lint()
    gm.recompile()
216
+
217
+
218
def enforce_as_strided_input_layout(gm: torch.fx.GraphModule):
    """
    Pin the layout of the first input of every as_strided-style node, since
    the strides given to as_strided depend on the input tensor's stride info.
    """
    as_strided_ops = [
        torch.ops.aten.as_strided.default,
        torch.ops.aten.as_strided_.default,
        torch.ops.aten.as_strided_scatter.default,
    ]
    graph = gm.graph
    # Collect matches up front because new nodes are inserted while looping.
    strided_nodes = [node for node in graph.nodes if node.target in as_strided_ops]
    for node in strided_nodes:
        source = node.args[0]
        with graph.inserting_before(node):
            # add a node to enforce eager layout
            fake = source.meta["val"]
            forced = graph.call_function(
                prims.inductor_force_stride_order.default, (source, fake.stride())
            )
            node.replace_input_with(source, forced)

    graph.lint()
    gm.recompile()
241
+
242
+
243
def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule):
    """
    Insert a channels-last clone in front of every 4d convolution weight that
    is not already channels last.

    This pass is performed before freezing so the added nodes can be constant
    folded by freezing.
    """
    with dynamo_timed("convert_conv_weights_to_channels_last"):
        graph = gm.graph
        convs = [node for node in graph.nodes if node.target == aten.convolution.default]
        for conv in convs:
            weight_node = conv.args[1]
            weight_fake = weight_node.meta["val"]
            is_4d = len(weight_fake.size()) == 4
            if not is_4d or weight_fake.is_contiguous(
                memory_format=torch.channels_last
            ):
                # not a 4d tensor or already channels last, skip
                continue

            with graph.inserting_before(conv):
                channels_last_weight = graph.call_function(
                    aten.clone.default,
                    (weight_node,),
                    {"memory_format": torch.channels_last},
                )
                conv.replace_input_with(weight_node, channels_last_weight)

        enforce_as_strided_input_layout(gm)
        enforce_output_layout(gm)
.venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import operator
3
+ from collections import defaultdict
4
+ from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type
5
+
6
+ import sympy
7
+
8
+ import torch
9
+ import torch.fx
10
+ from torch.fx.experimental.symbolic_shapes import (
11
+ compute_unbacked_bindings,
12
+ rebind_unbacked,
13
+ statically_known_true,
14
+ sym_eq,
15
+ )
16
+ from torch.utils import _pytree as pytree
17
+ from torch.utils._pytree import tree_map
18
+
19
+ from .virtualized import V
20
+
21
+
22
+ # Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
23
+ # Works for length 2 patterns with 1 module and 1 function/method.
24
def matches_module_function_pattern(
    pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: Dict[str, torch.nn.modules.Module],
) -> bool:
    """Check whether ``node`` applies ``pattern[1]`` to the output of a ``pattern[0]`` module.

    ``node`` matches when its first argument is a ``call_module`` node whose
    module has exactly the type ``pattern[0]``, ``node`` itself is a
    ``call_function``/``call_method`` with target ``pattern[1]``, and the
    module output has no more than one user.
    """
    if len(node.args) == 0:
        return False
    producer = node.args[0]
    if not (isinstance(producer, torch.fx.Node) and isinstance(node, torch.fx.Node)):
        return False

    # the first node is call_module
    if producer.op != "call_module":
        return False
    module_name = producer.target
    if not isinstance(module_name, str) or module_name not in modules:
        return False
    if type(modules[module_name]) is not pattern[0]:
        return False

    # the second node is call_function or call_method
    if node.op not in ("call_function", "call_method"):
        return False
    if node.target != pattern[1]:
        return False

    # make sure the module output is only used by the current node.
    return len(producer.users) <= 1
53
+
54
+
55
class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.

    In order to detect which nodes have changed, we first hash its node,
    target, and argument lists (which are immutable in FX).

    Then, whenever we call incremental_update, we check which FX nodes have a
    new hash, and recompute the faketensor metadata for that node. Then, we
    continue to recursively compute the faketensors for all users until the
    fake tensors stop changing.
    """

    def __init__(self, graph: torch.fx.Graph) -> None:
        # Hashes of every node state already accounted for; seeded with the
        # graph as it looks at construction time.
        self.processed_hashes = set()
        self.graph = graph

        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node):
        """Return a tuple identifying the node together with its current target/args/kwargs objects."""
        # todo(chilli): Not a great hash function
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
        """Recompute ``meta["val"]`` for nodes whose hash is new, and for their users while values keep changing."""
        # NOTE(review): `processed` is never read below.
        processed = set()
        # Number of nodes whose fake tensor lives in each storage (None for
        # nodes without a tensor value or without storage).
        existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new, old):
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old):
            # Compare two metadata values structurally; containers are
            # compared element by element.
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                # Symbolic scalars: equal only if the shape env can prove it.
                assert isinstance(
                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
                ), f"Unknown type {type(new)} in {self.graph}"
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            # Strides and storage offset are only compared for strided tensors.
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False

            if new.device != old.device:
                return False

            if get_storage(new) == get_storage(old):
                return True

            # This is the case where it returns a completely fresh storage that's used nowhere else.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
            ):
                return True
            return False

        def should_process_node(node):
            # node.target for nodes returning true from this function
            # are called under fake mode and does not work for inductor
            # lowerings. We check if the node.target is an aten operator
            # or operator.getitem which is used when returning multiple
            # tensors from an op.
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target == operator.getitem
            )

        # ids of nodes that must be revisited because one of their inputs changed.
        to_process = set()
        for node in self.graph.nodes:
            # Skip nodes that are unchanged and not downstream of a change.
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            with V.fake_mode:
                new_fake_tensor = node.target(*args, **kwargs)
            # Nothing changed for this node, so its users need no refresh.
            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"]
            ):
                continue

            rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor)

            node.meta["val"] = new_fake_tensor
            if (shape_env := V.fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1

            # The value changed, so every user has to be recomputed as well.
            to_process.update([id(user) for user in node.users])

            self.processed_hashes.add(self.hash_node(node))
186
+
187
+
188
def get_storage(t: torch.Tensor) -> int:
    """Return the ``_cdata`` integer of the untyped storage backing ``t``."""
    storage = t.untyped_storage()
    return storage._cdata
190
+
191
+
192
def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    """Return the storage id of the node's ``meta["val"]`` tensor, or ``None``.

    ``None`` is returned when there is no ``"val"`` entry, when it is not a
    tensor, or when the tensor has no storage.
    """
    val = node.meta.get("val") if "val" in node.meta else None
    if "val" not in node.meta or not isinstance(val, torch.Tensor):
        return None
    if not torch._C._has_storage(val):
        return None
    return get_storage(val)
200
+
201
+
202
def get_fake(x):
    """Return ``x.meta["val"]`` for an FX node that has one; otherwise return ``x`` itself."""
    if isinstance(x, torch.fx.Node) and "val" in x.meta:
        return x.meta["val"]
    return x
208
+
209
+
210
def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
    """
    Map the node's args/kwargs through ``get_fake``.

    The first returned value is False if any input node is still an FX node
    after mapping (i.e. it has no fake value), and True otherwise.
    """
    args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
    for leaf in pytree.arg_tree_leaves(*args, **kwargs):
        if isinstance(leaf, torch.fx.Node):
            return False, args, kwargs
    return True, args, kwargs
220
+
221
+
222
def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns true if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives. e.g. it doesn't
    handle buffers realized heuristically during lowering, or
    buffers realized indirectly through view ops.
    """
    from torch._inductor.lowering import fallbacks, needs_realized_inputs

    # For nodes with multiple outputs, we get the fx graph:
    #     foo = torch.ops.aten.foo(...)
    #     getitem = foo[0]
    #     getitem_1 = foo[1]
    # so walk through getitem nodes to reach `foo` and check that instead.
    producer = node
    while producer.op == "call_function" and producer.target is operator.getitem:
        producer = producer.args[0]  # type: ignore[assignment]

    # Graph inputs/outputs and fallback kernels are treated as buffers.
    if producer.op in ("placeholder", "output") or producer.target in fallbacks:
        return True

    # Otherwise the node counts as realized only if some user realizes its inputs.
    return any(
        user.op == "output" or user.target in needs_realized_inputs
        for user in node.users
    )
.venv/lib/python3.11/site-packages/torch/_inductor/graph.py ADDED
@@ -0,0 +1,1930 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import itertools
3
+ import logging
4
+ import operator
5
+ import os
6
+ import re
7
+ import sys
8
+ import time
9
+ from collections import defaultdict
10
+ from contextlib import contextmanager
11
+ from types import ModuleType
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ DefaultDict,
16
+ Dict,
17
+ Iterable,
18
+ List,
19
+ NoReturn,
20
+ Optional,
21
+ Sequence,
22
+ Tuple,
23
+ TYPE_CHECKING,
24
+ Union,
25
+ )
26
+
27
+ import sympy
28
+ from sympy import Expr
29
+
30
+ import torch
31
+ import torch._logging
32
+ import torch.fx
33
+ from torch import device, Tensor
34
+ from torch._decomp import get_decompositions
35
+ from torch._dynamo.utils import defake, dynamo_timed
36
+ from torch._logging import LazyString, trace_structured
37
+ from torch._prims_common import make_channels_last_strides_for
38
+ from torch._subclasses.fake_tensor import FakeTensor
39
+ from torch.fx import GraphModule
40
+ from torch.fx.experimental._backward_state import BackwardState
41
+ from torch.fx.experimental.sym_node import magic_methods, method_to_operator
42
+ from torch.fx.experimental.symbolic_shapes import (
43
+ free_unbacked_symbols,
44
+ has_free_symbols,
45
+ resolve_unbacked_bindings,
46
+ RuntimeAssert,
47
+ ShapeEnv,
48
+ SymTypes,
49
+ )
50
+ from torch.fx.graph import Graph
51
+ from torch.fx.node import Node
52
+ from torch.utils._mode_utils import no_dispatch
53
+ from torch.utils._ordered_set import OrderedSet
54
+ from torch.utils._sympy.numbers import int_oo
55
+
56
+ from . import config, ir
57
+ from .codegen.common import (
58
+ BackendFeature,
59
+ DeviceOpOverrides,
60
+ get_backend_features,
61
+ get_device_op_overrides,
62
+ get_wrapper_codegen_for_device,
63
+ init_backend_registration,
64
+ )
65
+ from .exc import (
66
+ CppWrapperCodeGenError,
67
+ LoweringException,
68
+ MissingOperatorWithDecomp,
69
+ MissingOperatorWithoutDecomp,
70
+ )
71
+ from .ir import (
72
+ Constant,
73
+ FixedLayout,
74
+ get_device_type,
75
+ InputBuffer,
76
+ Pointwise,
77
+ Reduction,
78
+ StorageBox,
79
+ TensorBox,
80
+ TorchBindObject,
81
+ )
82
+ from .lowering import (
83
+ FALLBACK_ALLOW_LIST,
84
+ fallback_handler,
85
+ fallback_node_due_to_unsupported_type,
86
+ lowerings,
87
+ make_fallback,
88
+ maybe_layout_constraints,
89
+ needs_realized_inputs,
90
+ unsupported_output_tensor,
91
+ )
92
+ from .scheduler import BaseSchedulerNode
93
+ from .sizevars import SizeVarAllocator
94
+ from .utils import (
95
+ convert_shape_to_inductor,
96
+ gather_origins,
97
+ get_cloned_parameter_buffer_name,
98
+ get_sympy_Expr_dtype,
99
+ maybe_get_suppress_shape_guards_ctx,
100
+ should_assume_input_aligned,
101
+ )
102
+ from .virtualized import NullHandler, V
103
+
104
+
105
+ if TYPE_CHECKING:
106
+ from torch._higher_order_ops.effects import _EffectType
107
+ from .codegen.wrapper import WrapperCodeGen
108
+
109
+ from torch._inductor.codecache import output_code_log
110
+
111
+
112
# Module logger, plus a separate artifact logger registered under "perf_hints".
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")

# Shorthand for the ATen operator namespace.
aten = torch.ops.aten

# Module-wide counter; GraphLowering.__init__ draws its post_grad_graph_id from it.
_post_grad_graph_counter = itertools.count()

if config.is_fbcode():
    from torch._inductor.fb.utils import log_module_code
else:

    def log_module_code(*args: Any, **kwargs: Any) -> None:
        # No-op stand-in used when config.is_fbcode() is false.
        pass
125
+
126
+
127
+ def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool:
128
+ supported_dtype = {
129
+ torch.float32,
130
+ torch.float64,
131
+ torch.int64,
132
+ torch.int32,
133
+ torch.int16,
134
+ torch.int8,
135
+ torch.uint8,
136
+ torch.bool,
137
+ torch.bfloat16,
138
+ torch.complex32,
139
+ torch.complex64,
140
+ torch.complex128,
141
+ torch.float16,
142
+ }
143
+ if cuda:
144
+ supported_dtype.add(torch.float8_e4m3fn)
145
+ supported_dtype.add(torch.float8_e5m2)
146
+ supported_dtype.add(torch.float8_e4m3fnuz)
147
+ supported_dtype.add(torch.float8_e5m2fnuz)
148
+
149
+ return dtype in supported_dtype
150
+
151
+
152
def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
    """Return the torch dtype to use for a sympy-valued constant buffer.

    Args:
        constant_buffer: A sympy ``Integer``, ``Symbol`` or general ``Expr``.

    Returns:
        ``torch.int64`` for a sympy ``Integer``; otherwise whatever
        ``get_sympy_Expr_dtype`` reports for the expression.

    Raises:
        AssertionError: If ``constant_buffer`` is not one of the accepted
            sympy types.
    """
    # The message now names this function; it previously referred to a
    # non-existent "get_constant_buffer_dtype".
    assert isinstance(
        constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
    ), "may_get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
    if isinstance(constant_buffer, sympy.core.numbers.Integer):
        return torch.int64

    if isinstance(constant_buffer, sympy.Expr):
        return get_sympy_Expr_dtype(constant_buffer)

    # NOTE(review): Symbol and Integer are Expr subclasses in sympy, so this
    # tail should be unreachable while the assertion above holds. It is kept
    # so behavior is unchanged when assertions are disabled (python -O).
    if constant_buffer.is_integer:
        return torch.int64
    elif constant_buffer.is_float:
        return torch.float32
    else:
        return None
168
+
169
+
170
+ def is_magic_method(op: Any) -> bool:
171
+ magic_ops = {method_to_operator(m) for m in magic_methods}
172
+ return op in magic_ops
173
+
174
+
175
def getattr_recursive(
    obj: GraphModule, target: str
) -> Union[Tensor, torch._C.ScriptObject, GraphModule]:
    """Resolve a dotted attribute path such as ``"sub.weight"`` on ``obj``.

    Args:
        obj: The object the path starts from.
        target: Attribute names joined by ``"."``.

    Returns:
        The attribute found at the end of the path.

    Raises:
        RuntimeError: If some atom along the path does not exist. The message
            contains the path up to and including the missing atom.
    """
    target_atoms = target.split(".")
    attr_itr = obj
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Slice through i + 1 so the missing atom itself is reported;
            # slicing to i dropped it (and printed nothing when i == 0).
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
187
+
188
+
189
def mark_nodes_dislike_padding(
    g: Graph, user_visible_outputs: Optional[Dict[str, None]]
) -> None:
    """
    Tag nodes of ``g`` that should not be padded with ``meta["dislike_padding"] = True``.

    Nodes like convolution/convolution_backward want their inputs to be dense.
    If we pad their inputs, we end up with extra calls to copy kernels! On the
    other hand, padding usually helps reductions.

    The pass finds nodes that dislike padding. These are nodes that can be reached
    from a convolution/convolution_backward in the backward direction without
    going through a reduction.

    When ``config.pad_outputs`` is off, nodes whose name appears in
    ``user_visible_outputs`` are tagged as well. The pass does nothing when
    ``config.comprehensive_padding`` is disabled.
    """
    if not config.comprehensive_padding:
        return
    ops_dislike_padding = {
        aten.convolution,
        aten.convolution_backward,
    }
    # what's a better way to collect the reduction ops?
    ops_like_padding = {
        aten.var_mean,
        aten.sum,
        aten.mean,
        aten.prod,
        aten.any,
        aten.amin,
        aten.amax,
        aten.min,
        aten.max,
        aten.argmin,
        aten.argmax,
        aten.scatter_reduce,
    }

    def _get_overload_packet(
        node: torch.fx.Node,
    ) -> Optional[torch._ops.OpOverloadPacket]:
        # Returns the overload packet for call_function nodes whose target is
        # an OpOverload, and None for every other node.
        return (
            node.target._overloadpacket
            if node.op == "call_function"
            # hasattr on OpOverloadPacket is slow, do isinstance first
            and isinstance(node.target, torch._ops.OpOverload)
            and hasattr(node.target, "_overloadpacket")
            else None
        )

    # Walk in reverse so a node is visited before the nodes that feed it; the
    # tag set on an input below is then seen when that input is visited.
    for cur in reversed(g.nodes):
        op = _get_overload_packet(cur)
        if not op:
            continue
        if op in ops_dislike_padding:
            cur.meta["dislike_padding"] = True

        if cur.meta.get("dislike_padding", False):
            # propagate to inputs, stopping at ops that like padding
            for prior in cur.all_input_nodes:
                prior_op = _get_overload_packet(prior)
                if not prior_op:
                    continue
                if prior_op not in ops_like_padding:
                    prior.meta["dislike_padding"] = True
        # We only want to mark output nodes. So, move it after the above prior nodes process.
        if (
            not config.pad_outputs
            and user_visible_outputs
            and cur.name in user_visible_outputs
        ):
            cur.meta["dislike_padding"] = True
256
+
257
+
258
+ class GraphLowering(torch.fx.Interpreter):
259
+ graph_outputs: List[ir.IRNode]
260
+
261
def symbolic_sizes_strides(
    self, ex: torch.Tensor
) -> Tuple[Union[List[int], List[Expr]], Union[List[int], List[Expr]]]:
    """
    Produce (sizes, strides) for ``ex`` with dynamic dimensions expressed
    symbolically. Tensors are duck-shaped: two tensors with the same size
    get assigned the same symbolic variable.
    """
    if self.reuse_shape_env:
        # The shape env came from the caller: just convert what ``ex`` reports.
        sizes = convert_shape_to_inductor(ex.size())
        strides = convert_shape_to_inductor(ex.stride())
        return sizes, strides

    from torch._dynamo.source import ConstantSource

    # TODO: this should not be needed once #93059 lands
    # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
    # TODO: make a dedicated UnknownSource for this?
    # NB: This is using the legacy default behavior from
    # create_symbolic_sizes_strides_storage_offset but we hope we can
    # just delete this entirely
    source = ConstantSource(
        f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
    )
    (
        sym_sizes,
        sym_strides,
        _,
    ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(ex, source)

    def unwrap(value):
        # SymInts are reduced to their underlying expression; other values pass through.
        return value.node.expr if isinstance(value, torch.SymInt) else value

    size = [unwrap(s) for s in sym_sizes]
    stride = [unwrap(s) for s in sym_strides]
    return size, stride
297
+
298
def static_sizes_strides(
    self, ex: torch.Tensor
) -> Tuple[List[sympy.Expr], List[sympy.Expr]]:
    """
    Return the sizes and strides of ``ex`` as lists of ``sympy.Integer``.
    Primarily used for weights.
    """
    size = list(map(sympy.Integer, ex.size()))
    stride = list(map(sympy.Integer, ex.stride()))
    return size, stride
307
+
308
def __init__(
    self,
    gm: torch.fx.GraphModule,
    example_inputs: Optional[List[torch.Tensor]] = None,
    shape_env: Optional[ShapeEnv] = None,
    graph_id: Optional[int] = None,
    cpp_wrapper: bool = False,
    aot_mode: bool = False,
    user_visible_outputs: Optional[Dict[str, None]] = None,
    layout_opt: Optional[bool] = None,
    extern_node_serializer: Optional[
        Callable[[List[ir.ExternKernelNode]], Any]
    ] = None,
    is_inference: bool = False,
    is_const_graph: bool = False,
    const_output_index: Optional[Dict[str, int]] = None,
    const_code: Optional[str] = None,
    const_module: Optional["GraphLowering"] = None,
    name: Optional[str] = None,
) -> None:
    """
    Set up the lowering state for the graph module ``gm``.

    Notable behavior visible here:
    - ``layout_opt=None`` defers the decision to ``decide_layout_opt``.
    - ``shape_env=None`` creates a fresh ``ShapeEnv`` and clears
      ``reuse_shape_env``; a supplied one is reused. Either way its runtime
      asserts are frozen.
    - When ``const_module`` is given, ``device_types``, ``device_idxs``,
      ``constants`` and ``allocated_constant_name`` are taken from it
      (the same objects, not copies).
    - ``name`` is used by ``qualify_name`` as a prefix.
    """
    super().__init__(gm)
    self.example_inputs = example_inputs
    self.layout_opt = (
        layout_opt
        if layout_opt is not None
        else self.decide_layout_opt(gm, is_inference=is_inference)
    )
    self.num_channels_last_conv = 0
    self.is_inference = is_inference
    self.is_const_graph = is_const_graph
    self.const_code = const_code
    self.const_module = const_module

    self.extra_traceback = False  # we do our own error wrapping
    if shape_env is None:
        shape_env = ShapeEnv()
        self.reuse_shape_env = False
    else:
        self._shape_env = shape_env
        self.reuse_shape_env = True
    # Assigned on both paths (redundant for the else branch above).
    self._shape_env = shape_env
    # We are going to start code generating runtime asserts, so make sure
    # you don't start adding new ones in the lowering process
    shape_env.freeze_runtime_asserts()
    # We're going to mutate ras_by_symbol as we finish generating them
    self.ras_by_symbol: Dict[
        sympy.Symbol, List[RuntimeAssert]
    ] = shape_env.deferred_runtime_asserts.copy()
    self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
    self.sizevars = SizeVarAllocator(shape_env)
    self.graph_input_names: List[str] = []
    self.graph_inputs: Dict[str, TensorBox] = {}
    self.graph_inputs_original: Dict[str, InputBuffer] = {}
    self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet()
    self.device_types: OrderedSet[str] = (
        const_module.device_types if const_module else OrderedSet()
    )
    self.device_idxs: OrderedSet[int] = (
        const_module.device_idxs if const_module else OrderedSet()
    )
    self.cuda = False
    self.buffers: List[ir.Buffer] = []
    self.operations: List[ir.Operation] = []
    self.const_output_index: Dict[str, int] = (
        const_output_index if const_output_index else {}
    )
    self.folded_constants: OrderedSet[str] = (
        OrderedSet(const_output_index.keys())
        if const_output_index
        else OrderedSet()
    )
    self.constants: Dict[str, torch.Tensor] = (
        const_module.constants if const_module else {}
    )
    self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
    self.constant_reprs: Dict[str, str] = {}
    self.removed_operations: OrderedSet[str] = OrderedSet()
    self.removed_buffers: OrderedSet[str] = OrderedSet()
    self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
    self.mutated_buffers: OrderedSet[str] = OrderedSet()
    self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
    self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
    self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
    self.wrapper_code: WrapperCodeGen = None  # type: ignore[assignment]
    # See `ProxyExecutor Design Note` in ir.py for more details
    self.extern_kernel_nodes: List[ir.ExternKernelNode] = []

    from torch._inductor.extern_node_serializer import extern_node_json_serializer

    # A caller-supplied serializer is only honored under fbcode; otherwise
    # the JSON serializer is used.
    self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = (
        extern_node_serializer
        if config.is_fbcode() and extern_node_serializer
        else extern_node_json_serializer
    )

    self.current_node: torch.fx.Node = None  # type: ignore[assignment]
    self.lists: Dict[str, List[str]] = {}
    self.mutated_inputs: OrderedSet[str] = OrderedSet()
    self.mutated_input_idxs: List[int] = []
    self.name_to_buffer: Dict[str, ir.Buffer] = {}
    self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
    self.name_to_op: Dict[str, ir.Operation] = {}
    self.creation_time = time.time()
    self.name = name  # type: ignore[assignment]
    self.cpp_wrapper = cpp_wrapper

    # record multi_kernel choice for cpp_wrapper so the second pass knows
    # which sub-kernel is picked. Copy cpp_wrapper to another variable
    # since cpp_wrapper flag is set to false for the first pass of codegen.
    self.record_multi_kernel_choice = cpp_wrapper
    self.multi_kernel_to_choice: Dict[str, int] = {}

    self.aot_mode = aot_mode
    self.graph_id = graph_id
    self.post_grad_graph_id = next(_post_grad_graph_counter)
    self.scheduler: torch._inductor.scheduler.Scheduler = None  # type: ignore[assignment]
    self.nodes_prefer_channels_last = (
        self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet()
    )
    # Pre-seeded, so warn_fallback never reports "aten.convolution_backward".
    self._warned_fallback = {"aten.convolution_backward"}
    self.user_visible_outputs = (
        user_visible_outputs if user_visible_outputs is not None else {}
    )
    mark_nodes_dislike_padding(gm.graph, user_visible_outputs)
    self.cache_key: str = ""  # This is the cache key for the compiled artifact
    self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
    self.cache_linemap: List[
        Tuple[int, str]
    ] = (
        []
    )  # This is the linemap used by the profiler to mark custom compiled kernels getting run
    # Used if lowering encounters cases where cudagraphs are not supported
    self.disable_cudagraphs_reason: Optional[str] = None

    # only keeping one node per device for stack trace purposes
    self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
    self.orig_gm: torch.fx.GraphModule = gm.__copy__()
    self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
        "dynamo_flat_name_to_original_fqn", {}
    )
    self.allocated_constant_name: Dict[str, str] = (
        const_module.allocated_constant_name if const_module is not None else {}
    )
    init_backend_registration()
    # Per-instance memoized wrapper around get_backend_features.
    self.get_backend_features = functools.lru_cache(None)(get_backend_features)

    self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}
    self.aligned_inputs: OrderedSet[str] = OrderedSet()
    self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet()

    # Below field is related to printing debug intermediate tensor values info for debugging
    self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet()
460
+
461
def has_feature(
    self, device: Union[torch._inductor.ir.IRNode, device], feature: BackendFeature
) -> bool:
    """Return whether the backend for ``device``'s device type offers ``feature``."""
    assert isinstance(feature, BackendFeature), feature
    device_type = get_device_type(device)
    available = self.get_backend_features(device_type)
    return feature in available
466
+
467
@staticmethod
def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool:
    """
    Decide if we should enable layout optimization for this graph based on
    heuristics.

    The checks run in order and the first one that applies decides:
    config switches, absence of convolutions, the CPU + mkldnn fast path,
    too few convolutions relative to graph size, dynamic shapes on conv
    inputs, then either a flop-weighted estimate (inference) or a set of
    per-convolution shape heuristics (otherwise).
    """
    if not config.layout_optimization:
        return False

    if config.force_layout_optimization:
        return True

    conv_nodes = [
        n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
    ]
    nconv = len(conv_nodes)

    if nconv == 0:
        return False

    # For cpu backend and mkldnn enabled, we always use channels_last for better performance.
    if (
        torch.backends.mkldnn.enabled
        and torch.backends.mkldnn.is_available()
        and all(
            n.args[idx].meta["val"].device == torch.device("cpu")
            for n in conv_nodes
            for idx in [0, 1]
        )
    ):
        return True

    # Following models are skipped due to this:
    # jx_nest_base
    # volo_d1_224
    if len(list(gm.graph.nodes)) >= 300 * nconv:
        log.debug("Skipped layout opt because only a few conv")
        return False

    if any(
        has_free_symbols(n.args[idx].meta["val"])
        for n in conv_nodes
        for idx in [0, 1]
    ):
        log.debug(
            "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
        )
        return False

    # The helpers below all inspect args[1] of a convolution node
    # (presumably the weight tensor — confirm against the aten.convolution schema).
    def is_grouped(n: Any) -> bool:
        # True when the last argument is > 1 and dim 1 of args[1] is > 1.
        meta_val = n.args[1].meta["val"]  # type: ignore[union-attr, operator]
        assert isinstance(meta_val, torch.Tensor)
        return n.args[-1] > 1 and meta_val.size(1) > 1  # type: ignore[union-attr, operator]

    def is_in_out_channel(n: torch.fx.Node) -> bool:
        # True when dim 1 of args[1] is at least twice dim 0, and dim 2 is > 1.
        return (
            n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)  # type: ignore[union-attr, operator]
            and n.args[1].meta["val"].size(2) > 1  # type: ignore[union-attr, operator]
        )

    def is_small_channel(n: torch.fx.Node) -> bool:
        # True when both dim 0 and dim 1 of args[1] are at most 64.
        return (
            n.args[1].meta["val"].size(0) <= 64  # type: ignore[union-attr, operator]
            and n.args[1].meta["val"].size(1) <= 64  # type: ignore[union-attr, operator]
        )

    # only grouped convolutions benchmarked as slower in conv samples for inference only
    if is_inference:
        from torch.utils.flop_counter import FlopCounterMode

        # Flops accumulated per convolution category.
        flop_counts: Dict[str, float] = defaultdict(float)
        for node in conv_nodes:
            success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
                node
            )

            if success:
                # Re-run the conv under the fake mode just to count its flops.
                with FlopCounterMode(display=False) as flop_counter_mode:
                    with V.fake_mode:
                        node.target(*args, **kwargs)

                counted_flops = flop_counter_mode.get_total_flops()
                if is_grouped(node):
                    node_type = "grouped"
                elif is_small_channel(node):
                    node_type = "small"
                elif is_in_out_channel(node):
                    node_type = "in_out"
                else:
                    node_type = "default"

                flop_counts[node_type] += counted_flops
            else:
                log.debug("Conv inputs meta not found")

        # average benchmarked channels last speedup / slowdown, < 1 is speedup.
        # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
        # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
        GROUPED_MULTIPLIER = 1.358
        DEFAULT_MULTIPLIER = 0.823
        IN_OUT_MULTIPLIER = 0.725
        SMALL_MULTIPLIER = 0.783

        total_flops = sum(flop_counts.values())
        # TODO - get different values per hardware
        weighted_flops = (
            flop_counts["grouped"] * GROUPED_MULTIPLIER
            + flop_counts["small"] * SMALL_MULTIPLIER
            + flop_counts["in_out"] * IN_OUT_MULTIPLIER
            + flop_counts["default"] * DEFAULT_MULTIPLIER
        )
        # Enable only if the weighted estimate does not exceed the plain total.
        do_layout_opt = weighted_flops <= total_flops
        if not do_layout_opt:
            log.debug(
                "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
                total_flops,
                weighted_flops,
            )
        return do_layout_opt

    # Channels last layout can dramatically hurt grouped conv perf. E.g.
    # Conv with arguments like
    #   {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
    #    "stride": [2, 2], "padding": [1, 1], "groups": 2}
    # slows down 31x using channels last..

    # But a lot of timm models use depthwise separable convolution which will
    # result in grouped convolution with in-channel size == 1.
    # For those grouped convolution, channels last still helps a lot.
    # E.g.
    # Conv with arguments
    #   {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
    #    "stride": [2, 2], "padding": [1, 1], "groups": 58}
    # get 1.86x speedup with channels last layout.
    #
    # The following heuristics skip using channels-last if the model contains
    # grouped convolution with in-channels > 1.
    if any(map(is_grouped, conv_nodes)):
        log.debug(
            "Skip layout opt because found grouped convolution with >1 in_channels!"
        )
        return False

    # For some models that contain convolution with larger in-channel than out-channel, applying
    # channels last hurts performance.
    # Following models are skipped due to this:
    # - pytorch_unet
    # - phlippe_densenet (slightly worse)
    # - Background_Matting (1.22x -> 0.821x)
    # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
    if any(map(is_in_out_channel, conv_nodes)):
        log.debug(
            "Skip layout opt because some convolutions have smaller out_channel"
        )
        return False

    # Following models are skipped due to this:
    # - functorch_maml_omniglot
    if all(map(is_small_channel, conv_nodes)):
        log.debug("Skip layout opt because all convolution channels are too small")
        return False

    return True
630
+
631
def qualify_name(self, name: str) -> str:
    """Return ``name`` prefixed with ``"<graph name>_"`` when the graph has a name."""
    return name if self.name is None else f"{self.name}_{name}"
636
+
637
def make_subgraph(
    self,
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    subgraph_name: str,
) -> "GraphLowering":
    """
    Build a ``GraphLowering`` for ``gm`` that inherits this graph's settings.

    Everything except the graph module and the example inputs is taken
    from the parent. Subgraphs are lowered separately but are meant to be
    inlined into the parent's generated code, which is why the same
    ``shape_env`` and other properties are carried over. The subgraph
    name is qualified by the parent graph's name.
    """
    inherited = dict(
        shape_env=self._shape_env,
        cpp_wrapper=self.cpp_wrapper,
        aot_mode=self.aot_mode,
        extern_node_serializer=self.extern_node_serializer,
        is_inference=self.is_inference,
    )
    qualified = self.qualify_name(subgraph_name)
    return GraphLowering(
        gm=gm, example_inputs=example_inputs, name=qualified, **inherited
    )
661
+
662
+ def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]:
663
+ """
664
+ The rule to decide if an node prefer channels last is simple.
665
+ 1. if it's input/output of a convolution
666
+ 2. if one of its user prefers channels last
667
+
668
+ We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs;
669
+ Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers
670
+ channels last.
671
+
672
+ Consider the scenario: conv -> batch-norm -> relu -> conv
673
+ Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies:
674
+ 1. the output of batch-norm should be channels last initially since its input is a conv's output.
675
+ Forcing the batch-norm's output to be contiguous results in the first copy
676
+ 2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output.
677
+ We need convert it to channels last layout which results in the second copy.
678
+ With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies
679
+ can be saved.
680
+ """
681
+ output_set: OrderedSet[Node] = OrderedSet()
682
+ for n in reversed(self.module.graph.nodes):
683
+ if n.target == torch.ops.aten.convolution.default:
684
+ output_set.add(n)
685
+ continue
686
+
687
+ for user in n.users:
688
+ if user in output_set:
689
+ output_set.add(n)
690
+ break
691
+
692
+ # need a second pass to add downstream nodes of those channel last nodes to the sets.
693
+ # This pass is especially needed to avoid mix-layout kernel inputs in backward pass.
694
+ #
695
+ # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned
696
+ # from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
697
+ # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last
698
+ # tensors and passed to a kernel.
699
+ #
700
+ # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x.
701
+ # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x .
702
+ # This also helps the following models:
703
+ # - res2net101_26w_4s
704
+ # - res2net50_14w_8s
705
+ # - sebotnet33ts_256
706
+ for n in self.module.graph.nodes:
707
+ if n in output_set:
708
+ output_set.update(n.users)
709
+
710
+ return output_set
711
+
712
def warn_fallback(self, name: str) -> None:
    """Log a perf hint the first time the fallback kernel ``name`` is seen."""
    if name in self._warned_fallback:
        return
    self._warned_fallback.add(name)
    perf_hint_log.info("Using FallbackKernel: %s", name)
716
+
717
def add_device_info(self, device: torch.device) -> None:
    """Record ``device``'s type and index, and remember one graph node per device."""
    self.device_types.add(device.type)
    if device.index is not None:
        self.device_idxs.add(device.index)
    # Only the first node seen for a device is kept.
    if not V.graph.current_node or device in self.device_node_mapping:
        return
    self.device_node_mapping[device] = V.graph.current_node
723
+
724
@property
def fake_mode(self) -> torch._subclasses.fake_tensor.FakeTensorMode:
    """Expose ``V.fake_mode`` as a read-only attribute of the graph."""
    return V.fake_mode
727
+
728
def try_get_buffer(
    self, buffer_name: str
) -> Optional[Union[ir.TensorBox, ir.Buffer]]:
    """
    Look ``buffer_name`` up among registered buffers, then graph inputs,
    then constants. Constants are wrapped in a fresh ``ir.ConstantBuffer``
    with a fixed layout. Returns None when the name is unknown.
    """
    for table_name in ("name_to_buffer", "graph_inputs"):
        table = getattr(self, table_name)
        if buffer_name in table:
            return table[buffer_name]

    if buffer_name not in self.constants:
        return None

    data = V.graph.constants[buffer_name]
    layout = ir.FixedLayout(
        data.device, data.dtype, *V.graph.static_sizes_strides(data)
    )
    return ir.ConstantBuffer(buffer_name, layout)
745
+
746
def get_buffer(self, buffer_name: str) -> Union[ir.TensorBox, ir.Buffer]:
    """Like ``try_get_buffer``, but raise ``RuntimeError`` for an unknown name."""
    buf = self.try_get_buffer(buffer_name)
    if buf is None:
        raise RuntimeError(f"Failed to find buffer matching name {buffer_name}")
    return buf
751
+
752
def get_dtype(self, buffer_name: str) -> torch.dtype:
    """Return the dtype of the constant, buffer or graph input ``buffer_name``.

    Lookup order is constants, registered buffers, graph inputs. If none
    match and the name has the form ``as_strided(<buf>, ...)`` or
    ``reinterpret_tensor(<buf>, ...)``, the dtype of the wrapped ``<buf>``
    is returned instead.

    Raises:
        KeyError: If the name cannot be resolved.
    """
    if buffer_name in self.constants:
        return self.constants[buffer_name].dtype
    if buffer_name in self.name_to_buffer:
        return self.name_to_buffer[buffer_name].get_dtype()
    if buffer_name in self.graph_inputs:
        return self.graph_inputs[buffer_name].get_dtype()
    m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
    if m:
        # Group 1 is the wrapper function name; the wrapped buffer name is
        # group 2. Recursing on group 1 could only ever end in KeyError.
        return self.get_dtype(m.group(2))
    raise KeyError(f"could not find {buffer_name}")
763
+
764
def get_numel(self, buffer_name: str) -> Union[int, Expr]:
    """
    Return the element count of the constant, buffer or graph input
    ``buffer_name``. A buffer whose layout is a ``MultiOutputLayout``
    counts as 1. Raises ``KeyError`` for an unknown name.
    """
    from .ir import MultiOutputLayout

    if buffer_name in self.constants:
        return self.constants[buffer_name].numel()

    if buffer_name in self.name_to_buffer:
        buf = self.name_to_buffer[buffer_name]
        layout = getattr(buf, "layout", None)
        return 1 if isinstance(layout, MultiOutputLayout) else buf.get_numel()

    if buffer_name not in self.graph_inputs:
        raise KeyError(f"could not find {buffer_name}")
    return self.graph_inputs[buffer_name].get_numel()
777
+
778
def run(self, *args: Any) -> Any:  # type: ignore[override]
    """Run the parent implementation inside a ``dynamo_timed`` region."""
    with dynamo_timed("GraphLowering.run"):
        outcome = super().run(*args)
    return outcome
781
+
782
def register_operation(self, op: ir.Operation) -> str:
    """Append ``op`` to the graph's operations and return its new name.

    The name is ``op<N>`` (``N`` being the number of operations registered
    so far) passed through ``self.qualify_name``. It is stored on the
    operation and indexed in ``self.name_to_op``.
    """
    assert op.operation_name is None, f"Operation registered twice: {op}"
    assert isinstance(op, ir.Operation)

    position = len(self.operations)
    op_name = self.qualify_name(f"op{position}")
    self.name_to_op[op_name] = op
    self.operations.append(op)
    op.operation_name = op_name
    return op_name
790
+
791
def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
    """Append ``buffer`` to the graph's buffers and return its new name.

    The name is ``buf<N>`` passed through ``self.qualify_name``. The
    buffer's device is recorded via ``add_device_info`` unless the buffer is
    a zero-element ``ComputedBuffer`` or has no device. When ``set_name`` is
    true the name is also written to ``buffer.name``.
    """
    buf_name = self.qualify_name(f"buf{len(self.buffers)}")
    self.buffers.append(buffer)
    self.name_to_buffer[buf_name] = buffer

    # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
    is_empty_computed = (
        isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements()
    )
    if not is_empty_computed and buffer.get_device() is not None:
        self.add_device_info(buffer.get_device())

    if set_name:
        buffer.name = buf_name
    return buf_name
805
+
806
def register_operation_list(self, operation_names: List[str]) -> str:
    """Store a group of operation names under one qualified ``list_...`` name.

    Returns the name under which the list was stored in ``self.lists``.
    """
    joined = "_".join(operation_names)
    list_name = self.qualify_name(f"list_{joined}")
    self.lists[list_name] = operation_names
    return list_name
810
+
811
def register_users_of(
    self, node_output: Union[Iterable[ir.IRNode], ir.IRNode]
) -> None:
    """Index every ``TensorBox`` in ``node_output`` by the names it reads.

    Lists and tuples are walked recursively; each ``TensorBox`` found is
    appended to ``self.name_to_users[<read name>]`` for every name returned
    by its ``get_read_names()``.
    """
    users = self.name_to_users

    def visit(item: Union[Iterable[ir.IRNode], ir.IRNode]) -> None:
        if isinstance(item, (list, tuple)):
            for child in item:
                visit(child)
        if isinstance(item, ir.TensorBox):
            for read_name in item.get_read_names():
                users[read_name].append(item)

    visit(node_output)
823
+
824
def mark_buffer_mutated(self, name: str) -> None:
    """
    Note that buffer ``name`` is mutated and realize every recorded reader of
    it, so that reads of the old version happen before the mutation.
    """
    assert isinstance(name, str)
    self.mutated_buffers.add(name)

    # Membership test first: avoids touching the mapping for unknown names.
    if name in self.name_to_users:
        for reader in self.name_to_users[name]:
            reader.realize()
837
+
838
def get_original_value_of_constant(self, name: str) -> torch.Tensor:
    """
    Return the original tensor for constant ``name``.

    In AOTI, module buffers may have been mutated during tracing and
    compilation, so the previously stored clone in ``self.module.meta`` is
    preferred; when no clone was stored, the entry in ``self.constants`` is
    returned instead.
    """
    assert name in self.allocated_constant_name and name in self.constants, (
        "Can not find the original value for " + name
    )
    orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name])
    meta = self.module.meta
    if orig_name in meta:
        return meta[orig_name]
    return self.constants[name]
853
+
854
def allocate_non_dup_const_name(
    self, name: Optional[str], data: Union[Tensor]
) -> str:
    """Register ``data`` as a graph constant and return the name it lives under.

    Unless runtime constant folding is enabled, an already registered
    constant with the same size, stride, dtype, device, storage pointer and
    storage offset is reused and its name returned. Otherwise a sanitized
    name, made unique within ``self.constants``, is derived from ``name``
    (or generated when ``name`` is ``None``) and the tensor is recorded in
    ``self.constants``, ``self.constant_reprs`` and
    ``self.allocated_constant_name``.
    """
    requested_name = name

    def is_same_tensor(existing: Tensor) -> bool:
        return (
            not data.is_mkldnn
            and data.size() == existing.size()
            and data.stride() == existing.stride()
            and data.dtype == existing.dtype
            and data.device == existing.device
            and data.untyped_storage().data_ptr()
            == existing.untyped_storage().data_ptr()
            and data.storage_offset() == existing.storage_offset()
        )

    if not config.aot_inductor.use_runtime_constant_folding:
        for known_name, known in self.constants.items():
            if is_same_tensor(known):
                return known_name

    if name is None:
        name = f"constant{len(self.constants)}"
    assert name is not None
    if name[0].isdigit():
        name = f"constant_{name}"

    # Codegen may emit a variable per constant, so keep only characters that
    # are safe in an identifier.
    prefix = re.sub(r"[^a-zA-Z0-9_]", "_", self.qualify_name(name))

    unique = prefix
    suffix = 0
    while unique in self.constants:
        unique = f"{prefix}_{suffix}"
        suffix += 1

    self.constants[unique] = data
    self.constant_reprs[unique] = " ".join(
        [
            repr(data.device),
            repr(data.dtype),
            repr(tuple(data.size())),
            repr(tuple(data.stride())),
            format(hash(data), "x"),
        ]
    )
    self.allocated_constant_name[unique] = requested_name  # type: ignore[assignment]
    return unique
894
+
895
def add_tensor_constant(
    self, data: Tensor, name: Optional[str] = None
) -> TensorBox:
    """Register ``data`` as a constant and wrap it in a ``TensorBox``.

    The constant name comes from ``allocate_non_dup_const_name``; the
    returned box holds an ``ir.ConstantBuffer`` with a fixed layout built
    from the tensor's device, dtype and static sizes/strides.
    """
    const_name = self.allocate_non_dup_const_name(name, data)
    layout = FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data))
    buffer = ir.ConstantBuffer(const_name, layout)
    return TensorBox.create(buffer)
905
+
906
def constant_name(self, name: str, device_override: Optional[torch.device]) -> str:
    """
    Return the name of constant ``name`` as seen from ``device_override``.

    Constants are copied ahead of time to the devices that need them. When
    ``device_override`` is ``None`` or equals the constant's device, ``name``
    is returned as is; otherwise the constant is moved to the requested
    device, registered, and the name of that copy is returned.
    """
    constant = self.constants[name]
    if constant.device == device_override or device_override is None:
        return name

    copy_name = f"{name}_{device_override.type}{device_override.index or 0}"
    # The caller might have set a fake tensor mode, which would make `.to`
    # create a fake tensor, so run with the current modes disabled.
    with torch.utils._python_dispatch._disable_current_modes():
        return self.allocate_non_dup_const_name(
            copy_name,
            self.constants[name].to(device_override),
        )
921
+
922
def placeholder(
    self, target: str, args: Tuple[object], kwargs: Dict[str, object]  # type: ignore[override]
) -> Union[Expr, TensorBox, None]:
    """Lower a graph input.

    Symbolic and plain Python scalars become sympy expressions stored in
    ``self.graph_inputs``. ``None`` and ``BackwardState`` inputs produce
    ``None``. Tensors become a ``TensorBox`` around an ``InputBuffer`` with
    a fixed layout, registered under the qualified target name.
    """
    example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
    self.graph_input_names.append(target)

    is_symbolic_scalar = isinstance(example, SymTypes)
    if is_symbolic_scalar or isinstance(example, (int, bool, float)):
        scalar = example.node.expr if is_symbolic_scalar else sympy.sympify(example)
        self.graph_inputs[target] = scalar
        return scalar

    if example is None or isinstance(example, BackwardState):
        # BackwardState is an ignored arg that must be unused; alternately
        # it could be filtered out in AotAutograd.
        return None

    assert isinstance(example, torch.Tensor), example
    # todo(chilli): We can remove the last check once we turn buffers into
    # static shape tensors. That's a hack to workaround Inductor believing
    # the buffer should be static but us passing in a fake tensor with
    # symbolic shapes.
    if example._has_symbolic_sizes_strides:
        sizes, strides = self.symbolic_sizes_strides(example)  # type: ignore[assignment]
    else:
        # the first N inputs are weights
        sizes, strides = self.static_sizes_strides(example)

    # TODO(jansel): handle input aliasing
    target = self.qualify_name(target)
    layout = FixedLayout(example.device, example.dtype, sizes, strides)
    tensor = TensorBox.create(InputBuffer(target, layout))
    self.graph_inputs[target] = tensor
    self.graph_inputs_original[target] = tensor.data.data
    if self.current_node.users:  # cudagraphs should work with an unused CPU input
        self.add_device_info(example.device)

    # Note: [Input Alignment handling in Inductor]
    # Efficient code (e.g. vectorized loads) needs aligned inputs, but
    # assuming alignment and then receiving unaligned inputs at runtime
    # forces a clone, which hurts both perf and memory usage.
    #
    # Guarding on storage_offset % ALIGNMENT turned out to be expensive and
    # to cause recompiles, so code is generated from the alignment of the
    # example input without installing a guard.
    with maybe_get_suppress_shape_guards_ctx():
        if should_assume_input_aligned(example):
            self.aligned_inputs.add(target)
    return tensor
980
+
981
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any:  # type: ignore[type-arg, override]
    """Lower a ``call_function`` node by dispatching to its registered lowering.

    ``operator.getitem`` on a list/tuple/dict is delegated to the parent
    class, and targets tagged with ``_inductor_lowering_function`` are
    called directly. A target without a lowering gets a fallback when it is
    allow-listed or implicit fallbacks are enabled.

    Raises:
        MissingOperatorWithDecomp, MissingOperatorWithoutDecomp: when there
            is no lowering and no fallback may be created.
        LoweringException: wrapping any exception raised by the lowering.
    """
    if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
        return super().call_function(target, args, kwargs)

    # hasattr on OpOverloadPacket is slow, check isinstance first
    is_packet = isinstance(target, torch._ops.OpOverloadPacket)
    if not is_packet and hasattr(target, "_inductor_lowering_function"):
        # passthrough lowerings from .pattern_matcher
        return target(*args, **kwargs)

    if target not in lowerings:
        assert isinstance(
            target, torch._ops.OpOverload
        ), f"{target} is not an OpOverload"
        allow_listed = target.name().split(".")[0] in FALLBACK_ALLOW_LIST
        if not allow_listed and not config.implicit_fallbacks:
            # There isn't a good way to dynamically patch a decomposition in
            # since AOT Autograd already ran. The error message tells the
            # user how to fix it.
            if get_decompositions([target]):
                raise MissingOperatorWithDecomp(target, args, kwargs)
            raise MissingOperatorWithoutDecomp(target, args, kwargs)

        if not allow_listed:
            if get_decompositions([target]):
                error = MissingOperatorWithDecomp
            else:
                error = MissingOperatorWithoutDecomp
            log.info(
                "Creating implicit fallback for:\n%s",
                error.operator_str(target, args, kwargs),
            )
        make_fallback(target)

    try:
        log.debug(" via %s", lowerings[target])  # type: ignore[index]
        return lowerings[target](*args, **kwargs)  # type: ignore[index]
    except Exception as e:
        wrapped = LoweringException(e, target, args, kwargs)
        raise wrapped.with_traceback(e.__traceback__) from None
1027
+
1028
@staticmethod
def can_inline_constant(t: torch.Tensor) -> bool:
    """
    Tell whether ``t`` is a small constant attr that will be inlined:
    a one-dimensional tensor with at most 8 elements.
    """
    shape = t.shape
    return len(shape) == 1 and shape[0] <= 8
1034
+
1035
def get_attr(
    self, target: str, args: Tuple[()], kwargs: Dict[str, object]  # type: ignore[override]
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
    """Lower a ``get_attr`` node, i.e. a constant attribute of the module.

    Graph modules become ``ir.Subgraph``, script objects become
    ``TorchBindObject`` (and are recorded as torchbind constants), zero-dim
    tensors become ``Constant`` and small 1-D tensors are inlined through
    the ``tensor`` lowering. Every other tensor, and every tensor when
    constants must be kept, is registered via ``add_tensor_constant``.
    """
    # this is a constant
    value = getattr_recursive(self.module, target)  # type: ignore[arg-type]

    if isinstance(value, torch.fx.GraphModule):
        return ir.Subgraph(name=target, graph_module=value)

    if isinstance(value, torch._C.ScriptObject):
        self.torchbind_constants[target] = value
        self.constant_reprs[target] = ""
        return TorchBindObject(target, value)

    assert isinstance(value, torch.Tensor)
    keep_as_tensor_constant = (
        config.aot_inductor.use_runtime_constant_folding
        or config.always_keep_tensor_constants
        or unsupported_output_tensor(value)
    )
    if not keep_as_tensor_constant:
        with no_dispatch():
            if value.shape == ():
                return Constant(value.item(), value.dtype, value.device)
            if self.can_inline_constant(value):
                log.debug("Inlining constant: %s ", str(target))
                # tensor lowering has constant inlining logic
                from .lowering import tensor

                return tensor(value.tolist(), dtype=value.dtype, device=value.device)

    return self.add_tensor_constant(value, target)
1068
+
1069
def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
    """Always raise ``AssertionError``; this method never returns."""
    raise AssertionError()
1071
+
1072
def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
    """Always raise ``AssertionError``; this method never returns."""
    raise AssertionError()
1074
+
1075
def output(
    self, target: str, args: Tuple[object], kwargs: Dict[str, object]  # type: ignore[override]
) -> None:
    """Lower the FX output node and finish the graph.

    The parent's result is normalized to a sequence, each entry is realized
    as an extern-kernel input, and tensor results get their insignificant
    strides matched against the FX node metadata before being stored in
    ``self.graph_outputs``. Graph inputs that no longer resolve to their own
    ``InputBuffer`` are treated as mutated and written back into the
    original input. Finally ``self.finalize()`` is called.
    """
    result = super().output(target, args, kwargs)  # type: ignore[arg-type]
    if not isinstance(result, (tuple, list)):
        # nested subgraphs can have singleton outputs
        result = (result,)
    assert isinstance(result, (tuple, list)), type(result)
    assert all(
        isinstance(
            x,
            (
                TensorBox,
                ir.Constant,
                type(None),
                ir.ConstantBuffer,
                sympy.Expr,
                sympy.logic.boolalg.Boolean,
                int,
                ir.EffectfulKernel,
            ),
        )
        for x in result
    ), result

    fx_node_args = V.graph.current_node.args[0]  # type: ignore[arg-type]
    if not isinstance(fx_node_args, (tuple, list)):
        # nested subgraphs can have singleton outputs
        fx_node_args = (fx_node_args,)
    result = [ir.ExternKernel.realize_input(x) for x in result]
    result_correct_strides = []

    assert len(fx_node_args) == len(result)
    for r, fx_node in zip(result, fx_node_args):
        if not isinstance(r, (ir.TensorBox, ir.BaseView)):
            result_correct_strides.append(r)
        else:
            # AOT Autograd tries to detect stride divergence of inductor from output metadata.
            # Here, we try to avoid spurious divergence by matching insignificant strides,
            # i.e. the strides of dimensions of size 0 or 1.
            result_correct_strides.append(
                self.try_match_insignificant_strides(
                    r, fx_node.meta["val"].stride()
                )
            )

    self.graph_outputs = result_correct_strides
    value: ir.IRNode
    for name, value in self.graph_inputs.items():
        assert isinstance(
            value, (TensorBox, sympy.Expr)
        ), f"Unsupported inductor graph input type: {type(value)}"
        if not isinstance(value, TensorBox):
            continue
        value.realize()
        assert isinstance(value, TensorBox)
        # Unwrap TensorBox -> StorageBox -> underlying buffer.
        value = value.data
        assert isinstance(value, ir.StorageBox)
        value_storage_box = value
        value = value.data
        if not isinstance(value, InputBuffer) or value.get_name() != name:
            # one of our inputs was mutated, need to turn that into a copy
            ir.MutationLayoutSHOULDREMOVE.realize_into(
                value, self.graph_inputs_original[name]
            )
            # replace output with mutated input
            try:
                ind = self.graph_outputs.index(value_storage_box)
                self.graph_outputs[ind] = self.graph_inputs_original[name]
            except ValueError:
                # The mutated input is not one of the graph outputs.
                pass

    self.finalize()
    log.debug(
        "Force channels last inputs for %d conv for the current graph with id %d",
        self.num_channels_last_conv,
        self.graph_id if self.graph_id is not None else -1,
    )
1152
+
1153
def finalize(self) -> None:
    """Ask every registered buffer to decide its layout."""
    for buffer in self.buffers:
        buffer.decide_layout()
1156
+
1157
@contextmanager
def set_current_node(self, node: torch.fx.Node):  # type: ignore[no-untyped-def]
    """Context manager that sets ``self.current_node`` to ``node``.

    The previous value is restored on exit, including when the body raises.
    """
    previous = self.current_node
    self.current_node = node
    try:
        yield
    finally:
        self.current_node = previous
1165
+
1166
def try_match_insignificant_strides(
    self,
    tensor: Union[ir.TensorBox, ir.BaseView],
    meta_strides_inp: Tuple[Union[int, torch.SymInt], ...],
) -> Union[ir.TensorBox, ir.BaseView]:
    """
    Try to make the strides of ``tensor`` agree with ``meta_strides_inp``.

    Only strides of insignificant dimensions - size 0 or 1 - are rewritten.
    If the strides already match, or if a significant dimension has a
    genuinely different stride (e.g. NHWC vs NCHW), the input is returned
    unchanged. Otherwise a new ``TensorBox`` wrapping a ``ReinterpretView``
    with the adjusted strides is returned.
    """
    # should have already been realized
    assert torch._inductor.ir.is_storage_and_layout(tensor)

    sizevars = self.sizevars
    meta_strides = []
    for stride in meta_strides_inp:
        meta_strides.append(
            stride.node.expr if isinstance(stride, torch.SymInt) else stride
        )

    def is_insignificant(dim: Union[Expr, int]) -> bool:
        return sizevars.statically_known_leq(dim, 1)  # type: ignore[arg-type]

    already_equal = all(
        sizevars.statically_known_equals(expected, actual)
        for expected, actual in zip(meta_strides, tensor.get_stride())
    )
    if already_equal:
        return tensor  # type: ignore[arg-type]

    significant_strides_equal = all(
        is_insignificant(dim) or sizevars.statically_known_equals(expected, actual)
        for dim, expected, actual in zip(
            tensor.get_size(), meta_strides, tensor.get_stride()
        )
    )
    if not significant_strides_equal:
        return tensor

    storage, old_layout = torch._inductor.ir.as_storage_and_layout(tensor)
    new_stride = list(old_layout.stride)
    for axis, dim in enumerate(tensor.get_size()):
        if is_insignificant(dim):
            new_stride[axis] = meta_strides[axis]

    new_layout = torch._inductor.ir.FixedLayout(
        old_layout.device,
        old_layout.dtype,
        old_layout.size,
        new_stride,
        old_layout.offset,
    )
    view = torch._inductor.ir.ReinterpretView(storage, new_layout)
    return ir.TensorBox(view)
1224
+
1225
def propagate_mutation(
    self,
    fx_node: torch.fx.Node,
    old_args: Tuple[Any],
    old_kwargs: Dict[str, Any],
    new_args: Tuple[Any],
    new_kwargs: Dict[str, Any],
) -> None:
    """Write mutations made to new_args/new_kwargs back to the originals.

    The caller may have cloned old_args/old_kwargs into new_args/new_kwargs
    before invoking fx_node(*new_args, **new_kwargs). For every argument the
    op schema marks as written to, and whose new value is a different object
    from the old one, a ``copy_`` from the new value into the old one is
    lowered.
    """
    assert isinstance(fx_node.target, torch._ops.OpOverload)
    assert len(old_args) == len(new_args)
    assert len(old_kwargs) == len(new_kwargs)

    schema = fx_node.target._schema

    def copy_back_if_written(
        schema_arg: torch._C.Argument, original: ir.IRNode, replacement: ir.IRNode
    ) -> None:
        if original is replacement:
            return
        alias_info = schema_arg.alias_info
        if alias_info is None or not alias_info.is_write:
            return
        # The lowering for copy_ is smart enough to "replace" the original
        # with the replacement in all future uses so a copy_ kernel never
        # gets emitted.
        self.call_function(torch.ops.aten.copy_.default, (original, replacement), {})

    for position, (original, replacement) in enumerate(zip(old_args, new_args)):
        copy_back_if_written(schema.arguments[position], original, replacement)

    schema_by_name = {schema_arg.name: schema_arg for schema_arg in schema.arguments}
    for key in old_kwargs:
        original = old_kwargs[key]
        replacement = new_kwargs[key]
        copy_back_if_written(schema_by_name[key], original, replacement)
1267
+
1268
def run_node(self, n: torch.fx.Node) -> object:
    """Lower the FX node ``n`` and return the resulting IR value.

    The node is lowered through the fallback handler, the layout-constraint
    path, the magic-method path, or the parent's ``run_node``. The result is
    then post-processed: stride requirements for outputs and as_strided-like
    users, realization hints, best-effort ``origin_node`` tracking, and
    registration of the result's readers. Finally, for non-placeholder
    nodes, runtime asserts are emitted for unbacked symbols newly defined by
    the buffers/operations this node created, and those definitions are
    checked against the node's ``unbacked_bindings`` metadata.
    """

    def debug(msg: str) -> None:
        log.debug("lowering %s %s", LazyString(n.format_node), msg)

    # Remember how many buffers/operations existed before lowering so the
    # ones created for this node can be inspected afterwards.
    buffer_watermark = len(self.buffers)
    operation_watermark = len(self.operations)

    origins = {n}
    is_call_function = n.op == "call_function"
    if is_call_function:
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        origins |= gather_origins(args, kwargs)
    with ir.IRNode.current_origins(origins), self.set_current_node(  # type: ignore[arg-type]
        n
    ), V.set_current_node(
        n
    ):
        if (
            n.op == "call_function"
            and n.target is not operator.getitem
            and fallback_node_due_to_unsupported_type(n)
        ):
            debug("fallback_handler")
            result = fallback_handler(n.target, add_to_fallback_set=False)(
                *args, **kwargs  # type: ignore[possibly-undefined]
            )
        elif n.op == "call_function" and (
            layout_constraints := maybe_layout_constraints(n.target)  # type: ignore[arg-type]
        ):
            debug("layout_constraints")
            old_args = args  # type: ignore[possibly-undefined]
            old_kwargs = kwargs  # type: ignore[possibly-undefined]
            args, kwargs = layout_constraints(n, *args, **kwargs)  # type: ignore[index]
            result = self.call_function(n.target, args, kwargs)  # type: ignore[arg-type]
            # layout_constraints are allowed to make new copies of the inputs.
            # if they do, and if the target is mutable, then we need to
            # write the new values back into the original inputs.
            self.propagate_mutation(n, old_args, old_kwargs, args, kwargs)  # type: ignore[possibly-undefined]
        elif is_magic_method(n.target):
            # TODO: this is sus, it probably should be handled in the
            # lowerings themselves similarly to sym_size/sym-stride
            # https://github.com/pytorch/pytorch/issues/127789
            debug("is_magic_method")
            if isinstance(
                n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
            ):
                result = n.meta["val"].node.expr
            else:
                result = super().run_node(n)
        else:
            debug("")
            result = super().run_node(n)

        # require the same stride order for dense outputs,
        # 1. user-land view() will not throw because inductor
        # output different strides than eager
        # long term the solution is to make view() always succeed
        # with infallible strides.
        # 2: as_strided ops, we need make sure its input has same size/stride with
        # eager model to align with eager behavior.
        as_strided_ops = [
            torch.ops.aten.as_strided.default,
            torch.ops.aten.as_strided_.default,
            torch.ops.aten.as_strided_scatter.default,
            torch.ops.aten.resize.default,
            torch.ops.aten.resize_as.default,
        ]
        is_output = any(user.op == "output" for user in n.users)
        is_input_for_as_strided = any(
            user.target in as_strided_ops for user in n.users
        )

        if n.meta.get("inductor_realize_to_strides", False) and isinstance(
            result, TensorBox
        ):
            result.realize()
            strides = n.meta["val"].stride()
            sym_strides = torch._inductor.utils.any_is_symbolic(*strides)
            # NOTE: `and` binds tighter than `or`, so this reads as
            # (no get_stride) or (strides differ and strides are not symbolic).
            if (
                not hasattr(result, "get_stride")
                or result.get_stride() != strides
                and not sym_strides
            ):
                stride_order = ir.get_stride_order(strides)
                result = ir.ExternKernel.require_stride_order(result, stride_order)
        if (
            is_output
            and isinstance(result, TensorBox)
            and isinstance(result.data, ir.BaseView)
        ):
            # Realize so that outputs are correctly aliased
            result.realize()

        if (is_output or is_input_for_as_strided) and isinstance(
            n.meta["val"], torch.Tensor
        ):
            strides = n.meta["val"].stride()
            if len(strides):
                allow_padding = (
                    config.pad_outputs or n.name not in self.user_visible_outputs
                ) and not is_input_for_as_strided
                dense = torch._prims_common.is_non_overlapping_and_dense(
                    n.meta["val"]
                )
                unbacked_symbols_in_strides = (
                    len(free_unbacked_symbols(strides)) > 0
                )
                if (
                    not unbacked_symbols_in_strides
                    and dense
                    and len(result.get_size()) == 4
                    and n in self.nodes_prefer_channels_last
                    and n.name not in self.user_visible_outputs
                    and not is_input_for_as_strided
                ):
                    strides = ir.FlexibleLayout.stride_ordered_for_memory_format(
                        result.get_size(), torch.channels_last
                    )
                if not unbacked_symbols_in_strides and len(strides):
                    # To avoid converting possible view ops to a copy kernel, we use the previous
                    # require_exact_strides to handle views. But ultimately it's better to require
                    # the right strides at the tensor definition.
                    if n.meta["val"]._is_view() or isinstance(
                        result.data, ir.BaseView
                    ):
                        result = ir.ExternKernel.require_stride_order(
                            result,
                            ir.get_stride_order(strides),
                            allow_padding=allow_padding,
                        )
                    else:
                        strides = [
                            s.node.expr if isinstance(s, torch.SymInt) else s
                            for s in strides
                        ]
                        result = ir.ExternKernel.require_exact_strides(
                            result, strides, allow_padding=allow_padding
                        )

        # Realize if (1) any user need inputs realized, or (2) there is
        # already too many reads and rematerializing can be bad.
        num_users = len(OrderedSet(n.users))
        if num_users > 1 and isinstance(result, TensorBox):
            for user in n.users:
                if user.target in needs_realized_inputs:
                    result.realize_hint()
                    # This inclusion is somewhat controversial (from
                    # discussion between Horace, Natalia, and Elias).
                    # Currently, it's not very clear why this is helpful.
                    # The general idea here is that even though a node may
                    # have FlexibleLayout, we still often *treat* it as if
                    # it was contiguous. This appears to sometimes result in
                    # suboptimal behavior.
                    #
                    # When we do a better job selecting layout, we should
                    # revisit this.
                    need_fixed_layout = [
                        torch.ops.aten.convolution_backward.default,
                        torch.ops.aten.mm.default,
                        torch.ops.aten._int_mm.default,
                    ]
                    need_fixed_channels_last_layout = []
                    if not self.layout_opt:
                        need_fixed_layout.append(torch.ops.aten.convolution.default)
                    if torch._C._has_mkldnn:
                        need_fixed_layout += [
                            torch.ops.mkldnn._linear_pointwise.default,
                            torch.ops.mkldnn._linear_pointwise.binary,
                            torch.ops.aten.mkldnn_rnn_layer.default,
                            torch.ops.onednn.qlinear_pointwise.default,
                            torch.ops.onednn.qlinear_pointwise.tensor,
                            torch.ops.onednn.qlinear_pointwise.binary,
                            torch.ops.onednn.qlinear_pointwise.binary_tensor,
                        ]
                        need_fixed_channels_last_layout += [
                            torch.ops.mkldnn._convolution_pointwise.default,
                            torch.ops.mkldnn._convolution_pointwise.binary,
                            torch.ops.mkldnn._convolution_pointwise_.binary,
                            torch.ops.mkldnn._convolution_transpose_pointwise.default,
                            torch.ops.onednn.qconv2d_pointwise.default,
                            torch.ops.onednn.qconv2d_pointwise.binary,
                        ]
                        if torch._C.has_mkl:
                            need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
                    if user.target in need_fixed_layout:
                        result = ir.ExternKernel.require_stride_order(
                            result,
                            ir.get_stride_order(n.meta["val"].stride()),
                            allow_padding=True,
                        )
                    if (
                        user.target in need_fixed_channels_last_layout
                        and n is user.args[0]
                    ):
                        result = ir.ExternKernel.require_stride_order(
                            result,
                            ir.get_stride_order(
                                make_channels_last_strides_for(n.meta["val"].shape)
                            ),
                        )
                if user.op == "output":
                    if isinstance(result.data.data, (Pointwise, Reduction)):
                        result.realize()

            # TODO(jansel): introduce a store vs inline choice
            result.mark_reuse(len(n.users))

        # Realize if the IRNode already has accumulated lots of reads
        if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
            # Prevent excessive accumulation in a computed buffer, when
            # there are multiple branches each with small number of memory
            # reads, but they converge to a user.
            result.realize_hint()

        # Realize if a Pointwise has too much stuff to be inlined.
        # As this may cause RecursionError during Inductor's evaluation.
        if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
            curr = result.data.data
            if isinstance(curr, Pointwise):
                # Use inner fn as a rough proxy. Good enough.
                if curr.has_large_inner_fn():
                    result.realize()

    # This is not complete, but it doesn't have to be: origin_node
    # tracking is best effort. The logic here critically relies on direct
    # TensorBox -> StorageBox denoting a non-view; we don't bother trying
    # to get views to work. Feel free to add any extra cases as needed.
    #
    # Note: we can't YOLO tree_map over this result, because if there are
    # buffers or a view involved, we might not be able to validly assign
    # the origin_node here.
    if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
        if isinstance(result.data.data, ir.Loops):
            result.data.data.origin_node = n
        elif isinstance(result.data.data, ir.Buffer):
            result.data.data.origin_node = n
            if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
                result.data.data.data, ir.Loops
            ):
                result.data.data.data.origin_node = n
            # Not really multi-output, can straightforwardly recurse in
            elif (
                isinstance(result.data.data, ir.MultiOutput)
                and not result.data.data.indices
            ):
                if isinstance(result.data.data.inputs[0], ir.Buffer):
                    result.data.data.inputs[0].origin_node = n

    self.register_users_of(result)

    # Collect the unbacked symbols defined by the buffers and operations
    # that were created while lowering this node.
    new_unbacked_defs: OrderedSet[sympy.Symbol] = OrderedSet()
    for buf in self.buffers[buffer_watermark:]:
        new_unbacked_defs |= buf.get_unbacked_symbol_defs()
    for op in self.operations[operation_watermark:]:
        new_unbacked_defs |= op.get_unbacked_symbol_defs()

    def format_new_defs() -> str:
        # Only used to build the assertion message at the end of this method.
        r = []
        for buf in self.buffers[buffer_watermark:]:
            r.append(
                f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
            )
        for op in self.operations[operation_watermark:]:
            r.append(
                f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
            )
        return "***\n".join(r)

    if n.op != "placeholder":
        # Note [Backwards runtime asserts]
        # Backwards poses an interesting problem for deferred runtime
        # asserts. In the easy case, we may solely close over data
        # dependent sized tensors, and there are no binding sites for
        # unbacked SymInts. In this case, we can just drop all the
        # runtime asserts on the floor: no non-placeholder bindings, no
        # problem.
        #
        # However, it is *possible* for a fresh runtime assert to show up
        # between forwards and backwards. Right now, the freezing process
        # that happens when we lower forwards means that we will freeze
        # runtime asserts, and then the moment the backwards lowering
        # process attempts to add a new deferred runtime assert, we will
        # fail. Let's say you remove that assert. Now when we get here,
        # we need to make sure we actually emit these asserts (because we
        # can't emit them in forwards, we already compiled it). So we
        # have to do something here. But we don't want to reemit ALL
        # deferred runtime asserts, we only want to emit the NEW ones.
        # Therefore needing some sort of stratification in the ShapeEnv.
        # This is all doable, it just hasn't been done yet.
        shape_env = V.graph.sizevars.shape_env

        def make_assert(expr: Expr, msg: str) -> None:
            # The assert is registered both as a buffer and as an operation.
            assert_op = ir.AssertScalar(expr, msg)
            self.register_buffer(assert_op, set_name=True)
            self.register_operation(assert_op)

        for i0 in new_unbacked_defs:
            ras = self.ras_by_symbol.pop(i0, [])
            # NB: size-like not needed, we won't retrace
            vr = shape_env.var_to_range[i0]
            if not shape_env._default_unspecified_value_range().issubset(vr):

                def is_convertible(s: Expr) -> bool:
                    # Infinite bounds and bounds that int() rejects are skipped.
                    if s in (int_oo, -int_oo):
                        return False
                    try:
                        int(s)
                        return True
                    except TypeError:
                        return False

                if is_convertible(vr.lower):
                    make_assert(i0 >= vr.lower, f"{i0} >= {vr.lower}")
                if is_convertible(vr.upper):
                    make_assert(i0 <= vr.upper, f"{i0} <= {vr.upper}")

            for ra in ras:
                fvs = free_unbacked_symbols(ra.expr)
                missing = fvs - self.bound_unbacked_symbols
                if missing:
                    # Defer the assert until one of its still-unbound
                    # symbols (the smallest by string order) gets defined.
                    i1 = min(missing, key=str)
                    self.ras_by_symbol.setdefault(i1, []).append(ra)
                else:
                    make_assert(ra.expr, f"{ra.expr}")

    self.bound_unbacked_symbols |= new_unbacked_defs

    unbacked_bindings = resolve_unbacked_bindings(
        V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
    )
    # When we do lowering, it is possible we reallocate unbacked SymInts.
    # So we need to line up the unbacked SymInts when performing the test
    # here
    #
    # In principle, we could permit lowering to introduce MORE unbacked
    # SymInts: as long as all the old unbacked ones are accounted for,
    # it's fine for inductor to introduce extra calls to item()/unbacked()
    # whatever. This actually happens in practice when an unbacked SymInt
    # gets memoized away; naively, when Inductor reprocesses a kernel, it
    # doesn't know that the memo still applies, and ends up allocating a
    # new symbol. However, this is generally a bad thing: we may still
    # end up needing to test equalities on the symbols, and a fresh
    # symbol is likely to hit lots of GuardOnDataDependent errors that
    # we already know facts for.
    renamed_unbacked_bindings = OrderedSet(
        V.fake_mode.shape_env.unbacked_renamings.get(s, s)
        for s in unbacked_bindings.keys()
    )
    assert new_unbacked_defs >= renamed_unbacked_bindings, (
        f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
        f"fx node is: {n.format_node()}\n"
        f"new operations are:\n\n{format_new_defs()}"
    )

    return result
1623
+
1624
def validate_can_generate_cpp_wrapper(self) -> None:
    """Raise ``CppWrapperCodeGenError`` when the C++ wrapper cannot be emitted.

    Three conditions are checked in order: the ``config.disable_cpp_codegen``
    switch, the host platform, and the dtype of every graph input.
    """
    if config.disable_cpp_codegen:
        raise CppWrapperCodeGenError("C++ codegen is disabled")

    if sys.platform not in ["linux", "darwin", "win32"]:
        raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}")

    for graph_input in self.graph_inputs.values():
        if isinstance(graph_input, TensorBox):
            input_dtype = graph_input.get_dtype()
        elif isinstance(
            graph_input, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
        ):
            input_dtype = may_get_constant_buffer_dtype(graph_input)
        else:
            # Any other kind of input is validated with dtype=None.
            input_dtype = None

        if not supported_dtype_of_cpp_wrapper(input_dtype, self.cuda):
            raise CppWrapperCodeGenError(f"Unsupported input dtype {input_dtype}")
1642
+
1643
def init_wrapper_code(self) -> None:
    """Select and instantiate the wrapper code generator for this graph.

    Sets ``self.cuda``, ``self.device_ops`` and ``self.wrapper_code``. When a
    ``const_module`` is present, its kernel-name iterator and source-to-kernel
    map are shared with the new wrapper.
    """
    self.cuda = "cuda" in self.device_types
    if self.cpp_wrapper:
        self.validate_can_generate_cpp_wrapper()

    # Work on a copy so that self.device_types itself is left untouched.
    non_host_devices = self.device_types.copy()
    for ignored in ("cpu", "meta"):
        non_host_devices.discard(ignored)
    # TODO(Eikan): Only support mixing cpu and other device now.
    assert (
        len(non_host_devices) <= 1
    ), f"Does not support mixing {'+'.join(non_host_devices)}"
    if len(non_host_devices) == 0:
        device_type = "cpu"
    else:
        device_type = non_host_devices.pop()

    self.device_ops = get_device_op_overrides(device_type)
    wrapper_code_gen_cls = get_wrapper_codegen_for_device(
        device_type, self.cpp_wrapper
    )
    assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
    self.wrapper_code = wrapper_code_gen_cls()

    if self.const_module:
        # If we have const module, we could reuse the kernels
        # This could avoid duplication and save time on doing recompilation (if Triton.)
        const_wrapper = self.const_module.wrapper_code
        self.wrapper_code._names_iter = const_wrapper._names_iter
        self.wrapper_code.src_to_kernel = const_wrapper.src_to_kernel
1672
+
1673
def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]:
    """
    Generate wrapper code when the C++ wrapper is requested.

    For CPU, the cpp wrapper codegen is done in one pass.
    For GPU, it is done in two passes: first the model is JIT-compiled with the
    python wrapper and (unless autotuning happens at compile time) run once so
    that autotuned kernel binaries are produced; then the cpp wrapper code is
    generated.
    """
    if "cuda" not in self.device_types:
        # cpu
        return self.codegen()

    # first pass
    self.cpp_wrapper = False
    # Although triton.store_cubin was set in compile_fx, the backward pass didn't pick
    # that up. In theory it should work by only setting triton.store_cubin to True here,
    # but that will cause a problem when use_runtime_constant_folding is set.
    with config.patch({"triton.store_cubin": True}):
        compiled = self.compile_to_module().call

    if not config.triton.autotune_at_compile_time:

        def materialize(
            x: Union[torch.SymInt, torch.SymFloat, torch.Tensor]
        ) -> Union[int, float, torch.Tensor]:
            # Turn a symbolic / fake example input into something runnable.
            if x is None:
                return None
            if isinstance(x, (torch.SymInt, torch.SymFloat)):
                # Need concrete value to run dynamic shapes and tune the result
                return x.node.hint
            if isinstance(x, FakeTensor):
                return defake(x)
            assert isinstance(
                x, torch.Tensor
            ), "Unknown type when creating real inputs" + str(type(x))
            return x

        tracing_context = torch._guards.TracingContext.try_get()
        have_real_inputs = not isinstance(V.real_inputs, NullHandler)
        if tracing_context is not None and have_real_inputs:
            if tracing_context.output_strides:
                tracing_context.output_strides.clear()

            params_flat = [
                param
                for param in tracing_context.params_flat  # type: ignore[union-attr]
                if param is not None
            ]
            input_source = itertools.chain(params_flat, V.real_inputs)
        elif have_real_inputs:
            input_source = V.real_inputs
        else:
            # In the backward pass, V.real_inputs is not set.
            # Generating random inputs based on self.example_inputs sometimes can be problematic,
            # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
            input_source = self.example_inputs
        real_inputs = [materialize(x) for x in input_source]

        if self.mutated_inputs:
            from .compile_fx import clone_preserve_strides

            # Clone mutated Tensor inputs so the first pass of the cpp-wrapper
            # compilation does not mutate them: otherwise, e.g. when
            # torch.compile(f)(x) is called on an input-mutating f, x would be
            # mutated twice (once here, once when running the compiled model),
            # which would also give a numerically incorrect output.
            for idx, name in enumerate(self.graph_inputs):
                if name not in self.mutated_inputs:
                    continue
                mutated_inp = real_inputs[idx]
                if not isinstance(mutated_inp, torch.Tensor):
                    del mutated_inp
                    continue
                real_inputs[idx] = clone_preserve_strides(mutated_inp)
                del mutated_inp

        with torch.utils._python_dispatch._disable_current_modes():
            compiled(real_inputs)
        del real_inputs

    # second pass
    self.cpp_wrapper = True
    for stale in (
        self.removed_buffers,
        self.removed_operations,
        self.inplaced_to_remove,
        V.graph.sizevars.precomputed_replacements,
        V.graph.sizevars.inv_precomputed_replacements,
    ):
        stale.clear()
    with config.patch({"triton.autotune_at_compile_time": False}):
        return self.codegen()
1774
+
1775
def codegen(self) -> Tuple[str, List[Tuple[int, Node]]]:
    """Schedule ``self.operations`` and return what the wrapper generates.

    The return value is whatever ``self.wrapper_code.generate`` produces; per
    the annotation, the code text plus a list of ``(line number, node)`` pairs.
    """
    from .scheduler import Scheduler

    self.init_wrapper_code()

    self.scheduler = Scheduler(self.operations)
    V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)

    # The push/pop pair brackets the scheduling and final generation of this graph.
    self.wrapper_code.push_codegened_graph(self)
    self.scheduler.codegen()

    log.debug(
        "Finished codegen for all nodes. The list of kernel names available: %s",
        V.graph.all_codegen_kernel_names,
    )

    generated = self.wrapper_code.generate(self.is_inference)
    self.wrapper_code.pop_codegened_graph()
    return generated
1794
+
1795
def codegen_subgraph(self, parent_graph: "GraphLowering") -> None:
    """
    Codegen this graph as a subgraph of ``parent_graph``.

    A more compact variant of `codegen()`: the parent's wrapper code, device
    ops and cpp_wrapper flag are adopted so that this graph's code (including
    the generated kernels) is emitted inline into the parent's wrapper. The
    wrapper code is not finalized here (no `.generate()` call); that is left
    to the parent graph's `codegen()`.
    """
    from .scheduler import Scheduler

    for attr in ("wrapper_code", "device_ops", "cpp_wrapper"):
        setattr(self, attr, getattr(parent_graph, attr))

    self.scheduler = Scheduler(self.operations)
    self.scheduler.codegen()
1813
+
1814
def count_bytes(
    self,
) -> Tuple[
    int, List[Tuple[BaseSchedulerNode, int]], List[Tuple[BaseSchedulerNode, float]]
]:
    """Summarize buffer traffic and estimated runtime of the scheduled nodes.

    Returns a triple ``(total_bytes, node_counts, node_runtimes)`` where
    ``node_counts`` pairs each node with its byte count floor-divided by 4 and
    ``node_runtimes`` pairs each node with ``get_estimated_runtime()``.
    """
    running_total = 0
    per_node_counts = []
    per_node_runtimes = []
    for sched_node in self.scheduler.nodes:
        node_bytes = sched_node.get_read_write_buffers_sizes()
        running_total += node_bytes
        per_node_counts.append((sched_node, node_bytes // 4))
        per_node_runtimes.append((sched_node, sched_node.get_estimated_runtime()))

    return running_total, per_node_counts, per_node_runtimes
1829
+
1830
+ @staticmethod
1831
+ def save_output_code(code: str) -> None:
1832
+ # No-op to be patched for unit tests
1833
+ pass
1834
+
1835
def compile_to_module(self) -> ModuleType:
    """Run ``_compile_to_module`` inside a ``dynamo_timed`` region."""
    timer = dynamo_timed(
        "GraphLowering.compile_to_module", phase_name="code_gen", fwd_only=False
    )
    with timer:
        return self._compile_to_module()
1840
+
1841
def _compile_to_module(self) -> ModuleType:
    """Generate the wrapper code, write it through ``PyCodeCache`` and load it.

    Also records ``cache_key``, ``cache_path`` and ``cache_linemap`` on the
    graph and reports the path of the written module to the logs / debug
    handler.
    """
    from .codecache import PyCodeCache

    if self.cpp_wrapper:
        code, node_linemap = self.codegen_with_cpp_wrapper()
    else:
        code, node_linemap = self.codegen()

    GraphLowering.save_output_code(code)
    output_code_log.debug("Output code: \n%s", code)
    try:
        # Replace each node by its stack trace before handing the map to the cache.
        linemap = [(line_no, node.stack_trace) for line_no, node in node_linemap]  # type: ignore[misc]
        key, path = PyCodeCache.write(code)
    except Exception:
        trace_structured(
            "inductor_output_code",
            # Just omit the filename, I still want the code though!
            payload_fn=lambda: code,
        )
        raise
    else:
        trace_structured(
            "inductor_output_code",
            lambda: {"filename": path},
            payload_fn=lambda: code,
        )

    module_attrs = {**self.constants, **self.torchbind_constants}
    mod = PyCodeCache.load_by_key_path(
        key,
        path,
        linemap=linemap,  # type: ignore[arg-type]
        attrs=module_attrs,
    )
    self.cache_key = key
    self.cache_path = path
    self.cache_linemap = linemap  # type: ignore[assignment]

    # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
    # TODO. Revisit this once the logging API is more mature
    assert mod.__file__ is not None
    module_file = mod.__file__

    log_module_code(module_file)
    log.debug("Output code written to: %s", module_file)
    output_code_log.info("Output code written to: %s", module_file)
    if config.benchmark_kernel:
        print(f"Compiled module path: {module_file}", file=sys.stderr)
    V.debug.output_code(module_file)
    V.debug.copy(os.path.splitext(module_file)[0] + ".debug")
    return mod
1889
+
1890
def compile_to_fn(self) -> Any:
    """Compile the graph and return the result for the current mode.

    Outside AOT mode this is the ``call`` attribute of the compiled module.
    In AOT mode it is whatever ``AotCodeCompiler.compile`` returns.
    """
    if not self.aot_mode:
        return self.compile_to_module().call

    from .codecache import AotCodeCompiler

    assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
    code, linemap = self.codegen_with_cpp_wrapper()
    output_code_log.debug("Output code: \n%s", code)

    if self.extern_kernel_nodes:
        serialized_extern_kernel_nodes = self.extern_node_serializer(
            self.extern_kernel_nodes
        )
        output_code_log.debug(
            "Serialized Extern Kernel Nodes: \n%s",
            serialized_extern_kernel_nodes,
        )
    else:
        serialized_extern_kernel_nodes = None

    # Directly return the file path with the compiled code
    return AotCodeCompiler.compile(
        self, code, serialized_extern_kernel_nodes, cuda=self.cuda
    )
1914
+
1915
def get_output_names(self) -> List[str]:
    """Return the names of graph outputs, skipping None/shape constant buffers."""
    skipped_kinds = (ir.NoneAsConstantBuffer, ir.ShapeAsConstantBuffer)
    names = []
    for output in self.graph_outputs:
        if isinstance(output, skipped_kinds):
            continue
        names.append(output.get_name())
    return names
1922
+
1923
def is_unspec_arg(self, name: str) -> bool:
    """Tell whether ``name`` refers to a single-element CPU input.

    dynamo wraps unspec variable as 0d CPU tensor, which needs to be
    converted to a scalar during codegen (triton only). Names listed in
    ``self.zero_dim_cpu_tensor_list`` also qualify.
    """
    if name in self.graph_inputs.keys():
        graph_input = self.graph_inputs[name]
        if graph_input.get_numel() == 1 and graph_input.get_device().type == "cpu":
            return True
    return name in self.zero_dim_cpu_tensor_list
.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ from typing import Callable, List, TYPE_CHECKING
4
+
5
+
6
+ if TYPE_CHECKING:
7
+ import torch
8
+
9
# Hooks run in registration order.
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []


@contextlib.contextmanager
def intermediate_hook(fn):
    """Register ``fn`` as an intermediate hook for the duration of the block.

    On exit the most recently registered hook is removed, even if the body
    raised.
    """
    INTERMEDIATE_HOOKS.append(fn)
    try:
        yield
    finally:
        INTERMEDIATE_HOOKS.pop()


def run_intermediate_hooks(name, val):
    """Call every registered hook with ``(name, val)``.

    The registry is swapped for an empty list while hooks run, so a hook that
    itself triggers ``run_intermediate_hooks`` sees no hooks. The original
    list is restored afterwards, even on error.
    """
    global INTERMEDIATE_HOOKS
    active_hooks, INTERMEDIATE_HOOKS = INTERMEDIATE_HOOKS, []
    try:
        for callback in active_hooks:
            callback(name, val)
    finally:
        INTERMEDIATE_HOOKS = active_hooks
.venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ """This file implements the IndexPropagation ops handler, which wraps an
3
+ underlying handler to add a limited form of constant propagation, as well as
4
+ propagation of sympy expressions downstream of ops.index_expr calls.
5
+
6
+ For example, say we have the IR:
7
+
8
+ tmp0 = ops.index_expr(x, torch.int32)
9
+ tmp1 = ops.constant(2, torch.int32)
10
+ tmp2 = ops.mul(tmp0, tmp1)
11
+ tmp3 = ops.indirect_indexing(tmp2, x_size)
12
+ tmp4 = ops.load("buf0", tmp3)
13
+
14
+ The underlying handler would just see:
15
+
16
+ ops.load("buf0", x * 2)
17
+
18
+ This is limited by the set of operators handled in the sympy expression
19
+ printers. So simple operations like minimum and maximum cannot be translated to
20
+ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
21
+
22
+ """
23
+ import itertools
24
+ from dataclasses import dataclass
25
+ from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union
26
+ from typing_extensions import TypeAlias
27
+
28
+ import sympy
29
+
30
+ import torch
31
+ from torch._prims_common import dtype_to_type, is_integer_dtype
32
+ from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
33
+ from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
34
+
35
+ from .sizevars import evaluate_expr
36
+ from .utils import generate_assert
37
+ from .virtualized import V
38
+
39
+
40
+ _ExprType = Union[sympy.Expr, float, int, bool]
41
+
42
+
43
def _is_constant(val: _ExprType):
    """Return whether ``val`` is a number rather than a symbolic expression.

    SymPy objects report their own ``is_number``; anything else counts as
    constant only if it is a plain ``int``, ``float`` or ``bool``.
    """
    if not isinstance(val, sympy.Basic):
        return isinstance(val, (int, float, bool))
    return val.is_number
47
+
48
+
49
def upper_bound(val: _ExprType):
    """Return the ``bound_sympy`` upper bound of a SymPy expression.

    Values that are not ``sympy.Expr`` instances are returned unchanged.
    """
    if isinstance(val, sympy.Expr):
        return bound_sympy(val).upper
    return val
51
+
52
+
53
@dataclass
class TypedExpr:
    """A SymPy expression (or plain Python number) paired with a torch dtype."""

    expr: _ExprType
    dtype: torch.dtype

    def is_constant(self):
        """Return whether the wrapped expression is a constant."""
        return _is_constant(self.expr)

    def __post_init__(self):
        # Constants are normalized to the Python type matching the dtype.
        if not _is_constant(self.expr):
            return
        python_type = dtype_to_type(self.dtype)
        self.expr = python_type(self.expr)
66
+
67
+
68
class SymPyOps:
    """Ops handler whose IR values are all ``TypedExpr`` SymPy expressions.

    An operation that cannot be expressed in SymPy is either left undefined
    on this class or returns ``NotImplemented``.
    """

    @staticmethod
    def identity(value: Any) -> Any:
        return value

    @staticmethod
    def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def index_expr(value: Union[sympy.Expr, int], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def to_dtype(
        value: TypedExpr,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = False,
    ) -> TypedExpr:
        # Only the dtype tag changes; src_dtype / use_compute_types are unused here.
        return TypedExpr(value.expr, dtype)

    @staticmethod
    def abs(x: TypedExpr) -> TypedExpr:
        magnitude = abs(x.expr)  # type: ignore[arg-type]
        return TypedExpr(magnitude, x.dtype)

    @staticmethod
    def square(x: TypedExpr) -> TypedExpr:
        squared = x.expr * x.expr
        return TypedExpr(squared, x.dtype)

    @staticmethod
    def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        out_dtype = torch.promote_types(x.dtype, y.dtype)
        total = x.expr + y.expr
        return TypedExpr(total, out_dtype)

    @staticmethod
    def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        out_dtype = torch.promote_types(x.dtype, y.dtype)
        difference = x.expr - y.expr
        return TypedExpr(difference, out_dtype)

    @staticmethod
    def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        out_dtype = torch.promote_types(x.dtype, y.dtype)
        product = x.expr * y.expr
        return TypedExpr(product, out_dtype)

    @staticmethod
    def neg(x: TypedExpr) -> TypedExpr:
        negated = -x.expr
        return TypedExpr(negated, x.dtype)

    @staticmethod
    def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        out_dtype = torch.promote_types(x.dtype, y.dtype)
        if is_integer_dtype(out_dtype):
            return TypedExpr(FloorDiv(x.expr, y.expr), out_dtype)
        # Only integer results are expressed symbolically.
        return NotImplemented

    @staticmethod
    def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        out_dtype = torch.promote_types(x.dtype, y.dtype)
        if is_integer_dtype(out_dtype):
            modulus = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
            return TypedExpr(modulus, out_dtype)
        return NotImplemented

    @staticmethod
    def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        out_dtype = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(out_dtype):
            return NotImplemented

        dividend = sympy.sympify(x.expr)
        divisor = sympy.sympify(y.expr)
        # In these cases, remainder in Python == remainder in C++, so this transformation
        # is sound
        dividend_nonneg = dividend.is_nonnegative
        if dividend_nonneg is None or dividend_nonneg != divisor.is_positive:
            return NotImplemented
        modulus = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
        return TypedExpr(modulus, out_dtype)

    @staticmethod
    def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        out_dtype = torch.promote_types(x.dtype, y.dtype)
        smaller = sympy.Min(x.expr, y.expr)
        return TypedExpr(smaller, out_dtype)

    @staticmethod
    def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        out_dtype = torch.promote_types(x.dtype, y.dtype)
        larger = sympy.Max(x.expr, y.expr)
        return TypedExpr(larger, out_dtype)
168
+
169
+
170
@dataclass
class IndexPropVar:
    """Value flowing through ``IndexPropagation``: either a value of the
    wrapped handler or, when ``is_symbolic`` is set, a ``TypedExpr``."""

    value: Any  # Either an IR value, or TypedExpr if is_symbolic is true
    is_symbolic: bool = False

    @staticmethod
    def new_symbolic(expr: TypedExpr) -> "IndexPropVar":
        """Wrap ``expr`` as a symbolic variable."""
        return IndexPropVar(value=expr, is_symbolic=True)

    def __post_init__(self):
        if self.is_symbolic:
            assert isinstance(
                self.value, TypedExpr
            ), "Symbolic IndexPropVar must contain a TypedExpr"
183
+
184
+
185
+ IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]]
186
+
187
+
188
class IndexPropagation:
    """Ops wrapper that propagates constants and index_expr values symbolically.

    The goal is to do as much simplification as possible at compile time and
    to turn indirect indexing that originates from arange into normal static
    indexing.
    """

    def __init__(
        self,
        inner: Any,
        iter_ranges: Dict[sympy.Symbol, sympy.Expr],
        indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr],
    ) -> None:
        self._inner = inner
        self.shape_env = V.graph.sizevars.shape_env

        # Each iteration variable ranges over [0, upper_bound(size) - 1].
        iter_var_ranges = {
            sym: ValueRanges(0, upper_bound(size) - 1)
            for sym, size in iter_ranges.items()
        }
        self.var_to_range = (
            *self.shape_env.var_to_range.items(),
            *iter_var_ranges.items(),
        )
        # NOTE: this is intentionally kept as a reference so the caller can
        # update it in-place
        self.indirect_var_ranges = indirect_var_ranges

        iter_axioms = tuple(
            itertools.chain.from_iterable(
                (0 <= sym, sym < size) for sym, size in iter_ranges.items()
            )
        )
        self.axioms = iter_axioms + self.shape_env.get_axioms()

    def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
        """Build a constant / index_expr of the wrapped handler from ``expr``."""
        if not _is_constant(expr):
            return self._inner.index_expr(expr, dtype)
        python_value = dtype_to_type(dtype)(expr)
        return self._inner.constant(python_value, dtype)

    def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
        """Convert (possibly nested) ``IndexPropVar`` values to inner-handler values."""
        if isinstance(a, (list, tuple)):
            return tuple(self.unwrap(item) for item in a)

        if isinstance(a, IndexPropVar):
            # Prefer the sympy representation if possible
            if a.is_symbolic:
                return self.materialize_expr(a.value.expr, a.value.dtype)
            return a.value

        return a

    def wrap(self, a) -> IndexPropResult:
        """Wrap inner-handler results (recursively for sequences) in ``IndexPropVar``."""
        if not isinstance(a, (list, tuple)):
            return IndexPropVar(a)
        return tuple(self.wrap(item) for item in a)

    @overload
    def fallback(
        self,
        name: Literal["indirect_indexing"],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> IndexPropVar:
        ...

    @overload
    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        ...

    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        """Forward the op ``name`` to the wrapped handler and wrap its result."""
        inner_args = [self.unwrap(arg) for arg in args]
        inner_kwargs = {key: self.unwrap(arg) for key, arg in kwargs.items()}
        inner_op = getattr(self._inner, name)
        return self.wrap(inner_op(*inner_args, **inner_kwargs))

    def propagate_sympy(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        """Evaluate the op through ``SymPyOps``; fall back if that is not valid."""

        def unwrap(a: Union[Any, IndexPropVar]) -> Any:
            return a.value if isinstance(a, IndexPropVar) else a

        sym_args = [unwrap(arg) for arg in args]
        sym_kwargs = {key: unwrap(arg) for key, arg in kwargs.items()}
        new_expr = getattr(SymPyOps, name)(*sym_args, **sym_kwargs)
        if new_expr is NotImplemented:
            return self.fallback(name, args, kwargs)
        # Inductor doesn't expect floating point in sympy expressions, but
        # allow floating point constants to be propagated
        if not (new_expr.is_constant() or new_expr.expr.is_integer):
            return self.fallback(name, args, kwargs)
        return IndexPropVar.new_symbolic(new_expr)

    def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
        def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
            # Ops unknown to SymPyOps always go to the wrapped handler.
            if not hasattr(SymPyOps, name):
                return self.fallback(name, args, kwargs)

            # Symbolic propagation requires every IndexPropVar argument to be symbolic.
            has_non_symbolic = any(
                isinstance(arg, IndexPropVar) and not arg.is_symbolic
                for arg in itertools.chain(args, kwargs.values())
            )
            if has_non_symbolic:
                return self.fallback(name, args, kwargs)

            return self.propagate_sympy(name, args, kwargs)

        return inner

    def statically_true(self, e):
        """
        Evaluate ``e`` using value ranges, guard knowledge and runtime_asserts
        and return the outcome of ``evaluate_expr``.

        FIXME I think this may not be entirely right, as we may not be able to use all runtime_asserts
              If this is an issue, just use guards in `self.axioms`.

              The proper way of handling this would be to have a global shape_env that adds
              runtime_asserts as they happen in the code. Then, it should be used in SimplifyIndexing
              to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bounds also
              for indirect_indexing
        """
        # The indirect ranges are re-read on every call because the caller may
        # have updated the dict in place since construction.
        indirect_ranges = tuple(
            (sym, ValueRanges(0, upper_bound(size) - 1))
            for sym, size in self.indirect_var_ranges.items()
        )
        var_to_range = self.var_to_range + indirect_ranges
        return evaluate_expr(self.shape_env, e, self.axioms, var_to_range)

    def indirect_indexing(
        self,
        index: Union[Any, IndexPropVar],
        size: Any,
        check: bool = True,
        wrap_neg=True,
    ) -> Any:
        """Resolve an indirect index, statically when the index is symbolic."""
        if not (isinstance(index, IndexPropVar) and index.is_symbolic):
            return self.fallback(
                "indirect_indexing", (index, size, check, wrap_neg), {}
            ).value

        # If we find something we can convert into a direct indexing we do so
        # We still need to (perhaps) wrap the expression and add bound checks
        # We want to do this "constant folding", as we don't allow to fuse
        # kernels into indirect indexing

        expr = sympy.sympify(index.value.expr)

        # TODO Perhaps move this logic to the simplify indexing pass
        def wrap_expr(expr):
            # Positive, negative, mixed
            if self.statically_true(0 <= expr):
                return expr
            if self.statically_true(expr < 0):
                return expr + size
            return Where(expr < 0, expr + size, expr)

        # Sometimes it's easier to prove 0 <= expr than the weaker -size <= expr
        can_prove_lower = self.statically_true(0 <= expr) or self.statically_true(
            -size <= expr
        )
        can_prove_upper = self.statically_true(expr < size)
        if wrap_neg:
            expr = wrap_expr(expr)
        if generate_assert(check):
            # Only the bounds that could not be proven are checked at runtime.
            self.fallback(
                "check_bounds",
                (expr, size),
                {"lower": not can_prove_lower, "upper": not can_prove_upper},
            )
        return expr
.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ from typing import Optional, Sequence
6
+
7
+ import torch
8
+ from torch import _prims, Tensor
9
+
10
+
11
+ log = logging.getLogger(__name__)
12
+
13
+
14
+ def make_prim(
15
+ schema: str,
16
+ impl_aten,
17
+ return_type=_prims.RETURN_TYPE.NEW,
18
+ doc: str = "",
19
+ tags: Optional[Sequence[torch.Tag]] = None,
20
+ ):
21
+ if isinstance(return_type, tuple):
22
+
23
+ def meta(*args, **kwargs):
24
+ return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs))
25
+
26
+ else:
27
+
28
+ def meta(*args, **kwargs):
29
+ return _prims.TensorMeta(impl_aten(*args, **kwargs))
30
+
31
+ return _prims._make_prim(
32
+ schema=schema,
33
+ return_type=return_type,
34
+ meta=meta,
35
+ impl_aten=impl_aten,
36
+ doc=doc,
37
+ tags=tags,
38
+ )
39
+
40
+
41
def eager_force_stride(input_tensor: Tensor, stride) -> Tensor:
    """Return ``input_tensor`` laid out with the requested ``stride``.

    The input itself is returned when its stride already equals ``stride``;
    otherwise a new tensor with that stride is created and filled with the
    input's values.
    """
    if input_tensor.stride() != stride:
        restrided = input_tensor.clone().as_strided(input_tensor.shape, stride)
        restrided.copy_(input_tensor)
        return restrided
    return input_tensor
50
+
51
+
52
# Custom prims used for handling randomness
seed = make_prim(
    schema="inductor_seed(Device device) -> Tensor",
    impl_aten=lambda device: torch.randint(2**63 - 1, [], device=device),
    doc="create a fresh seed (one per call) for use with inductor_rand",
    tags=(torch.Tag.nondeterministic_seeded,),
)
seeds = make_prim(
    schema="inductor_seeds(int count, Device device) -> Tensor",
    impl_aten=lambda count, device: torch.randint(
        2**63 - 1, [count], device=device
    ),
    doc="Horizontal fusion of many inductor_seed() calls",
    tags=(torch.Tag.nondeterministic_seeded,),
)
lookup_seed = make_prim(
    # if inductor_lookup_seed changes, update partitioners.py
    schema="inductor_lookup_seed(Tensor seeds, int index) -> Tensor",
    impl_aten=lambda seeds, index: seeds[index],
    doc="Extract a single seed from the result of inductor_seeds()",
)
random = make_prim(
    schema="inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor",
    impl_aten=lambda size, seed, mode: getattr(torch, mode)(
        size, device=seed.device
    ),
    doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused",
)
randint = make_prim(
    schema="inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor",
    impl_aten=lambda low, high, size, seed: torch.randint(
        low, high, size, device=seed.device
    ),
    doc="torch.randint() using backend-specific RNG that can be fused",
)
force_stride_order = make_prim(
    schema="inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor",
    impl_aten=eager_force_stride,
    doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise",
)
_unsafe_index_put_ = make_prim(
    schema="_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)",
    impl_aten=lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_(
        self, indices, values, accumulate
    ),
    doc="Unsafe index_put_ (doesn't issue device asserts)",
)
fma = make_prim(
    schema="fma(Tensor a, Tensor b, Tensor c) -> Tensor",
    impl_aten=lambda a, b, c: (a * b) + c,
    doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication",
)
98
+
99
+
100
+ def _low_memory_max_pool2d_with_offsets_aten(
101
+ self,
102
+ kernel_size,
103
+ stride,
104
+ padding,
105
+ dilation,
106
+ ceil_mode,
107
+ ):
108
+ vals, indices = torch.ops.aten.max_pool2d_with_indices(
109
+ self, kernel_size, stride, padding, dilation, ceil_mode
110
+ )
111
+
112
+ input_width = self.shape[-1]
113
+ kernel_width = kernel_size[1]
114
+
115
+ bh_shape = [1] * self.ndim
116
+ bh_shape[-2] = -1
117
+ bh = torch.arange(indices.shape[-2], dtype=torch.int64, device=self.device).view(
118
+ bh_shape
119
+ )
120
+
121
+ bw_shape = [1] * self.ndim
122
+ bw_shape[-1] = -1
123
+ bw = torch.arange(indices.shape[-1], dtype=torch.int64, device=self.device).view(
124
+ bw_shape
125
+ )
126
+
127
+ hbase = bh * stride[0] - padding[0]
128
+ wbase = bw * stride[1] - padding[1]
129
+
130
+ ih = indices // input_width
131
+ iw = indices - (ih * input_width)
132
+
133
+ h_inc = ih - hbase
134
+ w_inc = iw - wbase
135
+
136
+ offsets = h_inc * kernel_width + w_inc
137
+
138
+ return vals, offsets.to(torch.int8)
139
+
140
+
141
+ def _low_memory_max_pool2d_offsets_to_indices_aten(
142
+ offsets, kernel_width, input_width, stride, padding
143
+ ):
144
+ offsets = offsets.to(torch.int64)
145
+ h_inc = offsets // kernel_width
146
+ w_inc = offsets - (h_inc * kernel_width)
147
+
148
+ bh_shape = [1] * offsets.ndim
149
+ bh_shape[-2] = -1
150
+ bh = torch.arange(offsets.shape[-2], dtype=torch.int64, device=offsets.device).view(
151
+ bh_shape
152
+ )
153
+
154
+ bw_shape = [1] * offsets.ndim
155
+ bw_shape[-1] = -1
156
+ bw = torch.arange(offsets.shape[-1], dtype=torch.int64, device=offsets.device).view(
157
+ bw_shape
158
+ )
159
+
160
+ hbase = bh * stride[0] - padding[0]
161
+ wbase = bw * stride[1] - padding[1]
162
+
163
+ ih = hbase + h_inc
164
+ iw = wbase + w_inc
165
+ return ih * input_width + iw
166
+
167
+
168
# Prim returning (values, in-window offsets) instead of (values, indices).
_low_memory_max_pool2d_with_offsets = make_prim(
    schema="_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride,  SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)",  # noqa: B950
    impl_aten=_low_memory_max_pool2d_with_offsets_aten,
    return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW),
    doc="Instead of returning indices, returns indices offsets.",
)

# Prim mapping those offsets back to regular indices.
_low_memory_max_pool2d_offsets_to_indices = make_prim(
    schema="_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor",  # noqa: B950
    impl_aten=_low_memory_max_pool2d_offsets_to_indices_aten,
    doc="Convert small int offsets to regular indices.",
)
179
+ )
.venv/lib/python3.11/site-packages/torch/_inductor/ir.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/_inductor/jagged_lowerings.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import sympy
6
+
7
+ import torch
8
+
9
+ from .ir import Pointwise, TensorBox
10
+ from .lowering import fallback_handler, is_integer_type, register_lowering
11
+ from .virtualized import ops
12
+
13
+
14
+ # pyre-ignore[2,3]
15
def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len):
    """
    Translate a dense (batch, seq) position into a jagged row index.

    Returns ``(jagged_idx, end_idx)`` where ``jagged_idx`` is the start offset
    of ``batch_idx`` plus ``seq_idx`` and ``end_idx`` is the value loaded from
    the offsets at ``batch_idx + 1``.
    """
    raw_start = offsets_loader([batch_idx])
    # The bound is jagged_len + 1 rather than jagged_len because the last
    # sequence may have zero length.
    seq_start = ops.indirect_indexing(raw_start, jagged_len + 1)
    seq_end = offsets_loader([batch_idx + 1])
    return seq_start + seq_idx, seq_end
25
+
26
+
27
def get_inverse_offsets(
    offsets: TensorBox,
    jagged_len: Union[int, sympy.Expr],
    realize: bool = True,
) -> TensorBox:
    """
    Build (or fetch) the inverse of ``offsets``.

    ``offsets`` maps a dense batch index to a jagged index (an offset into the
    jagged tensor); the result maps a jagged index back to its batch index.

    Example: offsets ``[0, 3, 4, 9, 10]`` produce
    ``inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3]``.

    The result is stored on ``offsets`` as the ``inverse_offsets`` attribute
    the first time it is computed and returned from there on later calls.
    """
    try:
        # already computed for this offsets object: reuse it
        return offsets.inverse_offsets
    except AttributeError:
        pass

    # ops.bucketize needs offsets.get_name(), which a Pointwise kernel does not
    # provide; realizing puts offsets in global memory so the whole tensor can
    # be binary searched.
    offsets.realize()
    device: torch.device = offsets.get_device()
    dtype: torch.dtype = offsets.get_dtype()

    # pyre-ignore[2,3]
    def inner_fn(index):
        one_based_bucket = ops.bucketize(
            values=ops.index_expr(index[0], dtype),
            offsets_name=offsets.get_name(),
            offsets_size=offsets.get_size()[0],
            indexing_dtype=dtype,
            right=True,
        )
        # bucketize yields 1-based bucket indices; batch indices are 0-based
        return one_based_bucket - 1

    result = Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=[jagged_len],
    )

    if realize:
        # "freeze" the node so that it doesn't get inlined downstream
        result.realize()

    # remember the result on the offsets object for later calls
    offsets.inverse_offsets = result  # type: ignore[attr-defined]
    return result
86
+
87
+
88
def jagged_idx_to_dense_idx(
    jagged_idx,  # pyre-ignore[2]
    inverse_offsets_loader,  # pyre-ignore[2]
    offsets_loader,  # pyre-ignore[2]
    batch_size: Union[int, sympy.Expr],
    max_seq_len: Union[int, sympy.Expr],
    offsets_dtype: torch.dtype,
) -> Tuple[sympy.Expr, sympy.Expr]:
    """
    Translate a jagged row index into a dense ``(batch_idx, seq_idx)`` pair.

    The batch index is looked up through ``inverse_offsets_loader``; the
    sequence index is the distance of ``jagged_idx`` from that batch's start
    offset.
    """
    owning_batch = inverse_offsets_loader([jagged_idx])
    batch_idx = ops.indirect_indexing(owning_batch, batch_size + 1)
    batch_start = offsets_loader([batch_idx])
    position_in_seq = ops.index_expr(jagged_idx, offsets_dtype) - batch_start
    # check=False because there may be sequences longer than max_seq_len
    seq_idx = ops.indirect_indexing(position_in_seq, max_seq_len, check=False)
    return batch_idx, seq_idx
105
+
106
+
107
def register_jagged_ops():
    """
    Register lowerings for the jagged <-> padded-dense conversion ops.

    Both lowerings only generate a Pointwise kernel for the common case of a
    single jagged dimension on a CUDA device; every other case is routed to
    the ATen fallback.
    """

    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
    def _jagged_to_padded_dense_forward(
        jagged_values: TensorBox,
        jagged_offsets: List[TensorBox],
        max_lengths: List[int],  # list of ints/SymInts
        padding_value: float = 0.0,
    ) -> TensorBox:
        device = jagged_values.get_device()
        dtype = jagged_values.get_dtype()
        values_shape = jagged_values.get_size()

        # only handle the common case of a single jagged dimension
        natively_supported = (
            len(jagged_offsets) == 1
            and device.type == "cuda"
            and device == jagged_offsets[0].get_device()
            and len(values_shape) == 2
            and len(jagged_offsets[0].get_size()) == 1
            and len(max_lengths) == len(jagged_offsets)
            and is_integer_type(jagged_offsets[0])
        )
        if not natively_supported:
            aten_fallback = fallback_handler(
                torch.ops.aten._jagged_to_padded_dense_forward.default,
                add_to_fallback_set=False,
            )
            return aten_fallback(
                jagged_values,
                jagged_offsets,
                max_lengths,
                padding_value,
            )

        offsets: TensorBox = jagged_offsets[0]
        offsets_dtype = offsets.get_dtype()
        batch_size = offsets.get_size()[0] - 1
        jagged_len = values_shape[0]

        # dense output is [batch, max sequence length, embedding]
        output_size = [batch_size, max_lengths[0], values_shape[1]]

        values_loader = jagged_values.make_loader()
        offsets_loader = offsets.make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # dense tensor size: [B, N, D]
            batch_idx, seq_idx, emb_idx = index
            jagged_idx, end_idx = dense_idx_to_jagged_idx(
                batch_idx=batch_idx,
                seq_idx=seq_idx,
                offsets_loader=offsets_loader,
                jagged_len=jagged_len,
            )
            # positions at or past the end of the sequence take padding_value
            inside_sequence = ops.lt(
                ops.index_expr(jagged_idx, offsets_dtype),
                end_idx,
            )
            return ops.masked(
                inside_sequence,
                lambda: values_loader([jagged_idx, emb_idx]),
                padding_value,
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=output_size,
        )

    def _dense_to_jagged_forward_impl(
        fallback_op,  # pyre-ignore[2]
        dense: TensorBox,
        jagged_offsets: List[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> TensorBox:
        device = dense.get_device()
        dtype = dense.get_dtype()
        dense_shape = dense.get_size()

        # only handle the common case of a single jagged dimension
        natively_supported = (
            len(jagged_offsets) == 1
            and device.type == "cuda"
            and device == jagged_offsets[0].get_device()
            and len(jagged_offsets[0].get_size()) == 1
            and len(dense_shape) == 3
            and jagged_len is not None
            and is_integer_type(jagged_offsets[0])
        )
        if not natively_supported:
            aten_fallback = fallback_handler(fallback_op, add_to_fallback_set=False)
            return aten_fallback(
                dense,
                jagged_offsets,
                jagged_len,
            )

        offsets: TensorBox = jagged_offsets[0]
        offsets_dtype = offsets.get_dtype()
        batch_size = dense_shape[0]
        max_seq_len = dense_shape[1]

        # jagged output is [total rows, embedding]
        output_size = [jagged_len, dense_shape[-1]]

        # Loaders are created before get_inverse_offsets, as in the original
        # statement order.
        dense_loader = dense.make_loader()
        offsets_loader = offsets.make_loader()

        inverse_offsets_loader = get_inverse_offsets(
            offsets=offsets,
            jagged_len=jagged_len,
        ).make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # jagged tensor size: [sum_B(N_B), D]
            jagged_idx, emb_idx = index
            batch_idx, seq_idx = jagged_idx_to_dense_idx(
                jagged_idx=jagged_idx,
                offsets_loader=offsets_loader,
                inverse_offsets_loader=inverse_offsets_loader,
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                offsets_dtype=offsets_dtype,
            )
            fits_in_dense = ops.lt(
                ops.index_expr(seq_idx, offsets_dtype),
                ops.index_expr(max_seq_len, offsets_dtype),
            )
            return ops.masked(
                fits_in_dense,
                lambda: dense_loader([batch_idx, seq_idx, emb_idx]),
                0.0,  # jagged sequence longer than max_seq_len
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=output_size,
        )

    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward)
    def _dense_to_jagged_forward(
        dense: TensorBox,
        jagged_offsets: List[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> TensorBox:
        return _dense_to_jagged_forward_impl(
            fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default,
            dense=dense,
            jagged_offsets=jagged_offsets,
            jagged_len=jagged_len,
        )
.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/_inductor/metrics.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import csv
5
+ import dataclasses
6
+ import inspect
7
+ import os
8
+ import re
9
+ from dataclasses import dataclass
10
+ from functools import lru_cache
11
+ from typing import Dict, List, Set, Tuple, TYPE_CHECKING
12
+
13
+ from torch._inductor import config
14
+ from torch._inductor.utils import get_benchmark_name
15
+
16
+
17
+ # Prevent circular import
18
+ if TYPE_CHECKING:
19
+ from torch._inductor.scheduler import BaseSchedulerNode
20
+
21
+ # counter for tracking how many kernels have been generated
22
+ generated_kernel_count = 0
23
+ generated_cpp_vec_kernel_count = 0
24
+ num_bytes_accessed = 0
25
+ nodes_num_elem: List[
26
+ Tuple[
27
+ BaseSchedulerNode,
28
+ int,
29
+ ]
30
+ ] = []
31
+ node_runtimes: List[Tuple[BaseSchedulerNode, float]] = []
32
+
33
+ # counters for tracking fusions
34
+ ir_nodes_pre_fusion = 0
35
+
36
+ # counters for tracking to_dtype inserted
37
+ cpp_to_dtype_count = 0
38
+
39
+
40
@dataclasses.dataclass
class CppOuterLoopFusedCount:
    """Counters recorded for one cpp outer-loop fusion."""

    # required: number of inner kernels
    inner_kernel_number: int
    # optional: number of local buffers, zero unless specified
    local_buffer_number: int = 0
44
+
45
+
46
+ # The length counts the number of outer loop fusions.
47
+ cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = []
48
+
49
+ num_comprehensive_padding = 0
50
+ num_matches_for_scatter_upon_const_tensor = 0
51
+
52
+ num_loop_reordering = 0
53
+
54
+
55
+ # reset all counters
56
def reset():
    """Set every scalar counter in this module to zero and empty the list metrics."""
    global generated_kernel_count, generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion, cpp_to_dtype_count
    global cpp_outer_loop_fused_inner_counts
    global num_comprehensive_padding
    global num_matches_for_scatter_upon_const_tensor
    global num_loop_reordering

    # scalar counters are rebound to zero
    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    num_comprehensive_padding = 0
    num_matches_for_scatter_upon_const_tensor = 0
    num_loop_reordering = 0

    # list-valued metrics are emptied in place rather than rebound
    for tracked_list in (
        nodes_num_elem,
        node_runtimes,
        cpp_outer_loop_fused_inner_counts,
    ):
        tracked_list.clear()
78
+
79
+
80
@dataclass
class CachedMetricsDeltas:
    """
    The subset of metrics we want to update across cache hits, e.g., the
    FxGraphCache.

    Each field name matches a module-level counter of the same name.
    """

    generated_kernel_count: int
    generated_cpp_vec_kernel_count: int
    ir_nodes_pre_fusion: int
    cpp_to_dtype_count: int
    num_bytes_accessed: int
    num_matches_for_scatter_upon_const_tensor: int
93
+
94
+
95
def get_metric_fields():
    """Return the field names of CachedMetricsDeltas, in declaration order."""
    names = []
    for spec in dataclasses.fields(CachedMetricsDeltas):
        names.append(spec.name)
    return names
97
+
98
+
99
class CachedMetricsHelper:
    """
    Snapshot the cache-relevant module counters at construction time so that
    their change can later be computed (``get_deltas``) and replayed on a
    cache hit (``apply_deltas``), e.g. for the FxGraphCache.
    """

    def __init__(self) -> None:
        module_state = globals()
        # baseline value of every tracked counter
        self.cached_metrics = {
            metric: module_state[metric] for metric in get_metric_fields()
        }

    def get_deltas(self) -> CachedMetricsDeltas:
        """Return how much each tracked counter grew since construction."""
        module_state = globals()
        growth = {
            metric: module_state[metric] - self.cached_metrics[metric]
            for metric in get_metric_fields()
        }
        return CachedMetricsDeltas(**growth)

    @staticmethod
    def apply_deltas(delta: CachedMetricsDeltas):
        """Add every delta onto the module-level counter of the same name."""
        module_state = globals()
        for metric in get_metric_fields():
            module_state[metric] = module_state[metric] + getattr(delta, metric)
122
+
123
+
124
+ REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {}
125
+
126
+
127
@dataclass
class MetricTable:
    """
    A named CSV table of metrics.

    Rows are appended to ``metric_table_<table_name>.csv`` (relative to the
    current working directory). The first column of every row is the
    benchmark name, followed by the values for ``column_names``.
    """

    table_name: str
    column_names: List[str]

    # number of rows written through this instance
    num_rows_added: int = 0

    def add_row(self, row_fn):
        """
        Append the row produced by ``row_fn`` if this table is enabled.

        ``row_fn`` is only invoked when the table is enabled; it must return a
        dict whose keys are exactly ``column_names``.
        """
        if self.table_name not in enabled_metric_tables():
            return

        row_dict = row_fn()
        expected_columns = self.column_names
        assert len(expected_columns) == len(
            row_dict
        ), f"{len(self.column_names)} v.s. {len(row_dict)}"
        assert set(expected_columns) == set(
            row_dict.keys()
        ), f"{set(self.column_names)} v.s. {set(row_dict.keys())}"

        row = [get_benchmark_name()]
        row.extend(row_dict[column_name] for column_name in expected_columns)
        self._write_row(row)

    def output_filename(self):
        """Return the CSV file name used by this table."""
        return f"metric_table_{self.table_name}.csv"

    def write_header(self):
        """(Re)create the CSV file containing only the header line."""
        with open(self.output_filename(), "w") as fd:
            csv.writer(fd, lineterminator="\n").writerow(
                ["model_name"] + self.column_names
            )

    def _write_row(self, row):
        """Format ``row`` in place and append it to the CSV file."""
        filename = self.output_filename()
        # only the first write from this instance may need to create the file
        if self.num_rows_added == 0 and not os.path.exists(filename):
            self.write_header()

        self.num_rows_added += 1

        def _as_cell(value):
            # floats get a fixed 6-digit precision, None becomes an empty cell
            if isinstance(value, float):
                return f"{value:.6f}"
            if value is None:
                return ""
            return value

        for position, raw in enumerate(row):
            row[position] = _as_cell(raw)

        with open(filename, "a") as fd:
            csv.writer(fd, lineterminator="\n").writerow(row)

    @staticmethod
    def register_table(name, column_names):
        """Create a table and record it in REGISTERED_METRIC_TABLES under ``name``."""
        REGISTERED_METRIC_TABLES[name] = MetricTable(name, column_names)
185
+
186
+
187
+ MetricTable.register_table(
188
+ "slow_fusion",
189
+ [
190
+ "kernel1_path",
191
+ "kernel1_latency",
192
+ "kernel2_path",
193
+ "kernel2_latency",
194
+ "fused_kernel_path",
195
+ "fused_kernel_latency",
196
+ "slow_down_ratio",
197
+ ],
198
+ )
199
+
200
+ # track the fusion statistics for each graph
201
+ MetricTable.register_table(
202
+ "graph_stats",
203
+ [
204
+ "graph_id",
205
+ "num_nodes_before_fusion",
206
+ "num_nodes_after_fusion",
207
+ ],
208
+ )
209
+
210
+ # track the perf difference between persistent reduction and non-persistent
211
+ # reductions
212
+ MetricTable.register_table(
213
+ "persistent_red_perf",
214
+ [
215
+ "kernel1_name",
216
+ "kernel2_name",
217
+ "kernel1_latency",
218
+ "kernel2_latency",
219
+ "size_hints",
220
+ "reduction_hint",
221
+ "speedup",
222
+ ],
223
+ )
224
+
225
+ # Log the fusion failures due to indexing mismatch
226
+ MetricTable.register_table(
227
+ "fusion_failure_due_to_indexing_mismatch",
228
+ [
229
+ "pre_grad_graph_id",
230
+ "post_grad_graph_id",
231
+ "node1_name",
232
+ "node2_name",
233
+ "node1_debug_str",
234
+ "node2_debug_str",
235
+ "common_buffer_names",
236
+ "failure_reason",
237
+ ],
238
+ )
239
+
240
+ # Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint
241
+ MetricTable.register_table(
242
+ "kernel_metadata",
243
+ [
244
+ "kernel_name",
245
+ "kernel_path",
246
+ "kernel_category", # pointwise/reduction/foreach etc.
247
+ "size_hints",
248
+ "reduction_hint",
249
+ "line_of_code",
250
+ "num_load",
251
+ "num_store",
252
+ "num_for_loop",
253
+ "num_atomic_add",
254
+ "num_args",
255
+ # xyz numel can be different to size_hints since size_hints are rounded
256
+ # up to the nearest power of 2.
257
+ # Inductor kernel will burn in the xyz numel in kernel code for static
258
+ # shape kernels.
259
+ # Logging them will be helpful to find unaligned shape for reduction
260
+ "xnumel",
261
+ "ynumel",
262
+ "rnumel",
263
+ "kernel_args_num_gb",
264
+ ],
265
+ )
266
+
267
+
268
def _parse_kernel_fn_code(kernel_module_code):
    """
    Return the source of the kernel function inside a generated kernel module.

    ``kernel_module_code`` is the Python module source that contains the kernel
    function code; the kernel function is the proper triton kernel function
    annotated with @triton.jit.
    """
    from .codecache import PyCodeCache
    from .wrapper_benchmark import get_triton_kernel

    kernel = get_triton_kernel(PyCodeCache.load(kernel_module_code))
    # kernel is a CachingAutotune; kernel.fn is the JITFunction;
    # kernel.fn.fn is the function being decorated by triton.jit
    jit_function = kernel.fn
    return inspect.getsource(jit_function.fn)
282
+
283
+
284
+ def _parse_kernel_line_of_code(proper_kernel_fn_code):
285
+ """
286
+ Return the line of code for the kernel excluding the decorators.
287
+ """
288
+ return len(proper_kernel_fn_code.splitlines())
289
+
290
+
291
+ def _parse_size_hints(kernel_module_code, kernel_category):
292
+ if kernel_category == "foreach":
293
+ # foreach kernel does not have size_hints
294
+ return None
295
+ m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code)
296
+ assert m, "size_hints missing!"
297
+ return m.group(1)
298
+
299
+
300
+ def _parse_reduction_hint(kernel_category, kernel_module_code):
301
+ if kernel_category not in ("reduction", "persistent_reduction"):
302
+ return None
303
+ m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code)
304
+ assert m, "reduction_hint not found in kernel source code!"
305
+ return m.group(1)
306
+
307
+
308
+ def _count_pattern(proper_kernel_fn_code, pattern):
309
+ return proper_kernel_fn_code.count(pattern)
310
+
311
+
312
+ def _count_args(proper_kernel_fn_code):
313
+ def_line = proper_kernel_fn_code.splitlines()[0]
314
+ assert def_line.startswith("def ")
315
+ start_idx = def_line.index("(")
316
+ end_idx = def_line.index("):")
317
+ decl_csv = def_line[start_idx + 1 : end_idx]
318
+ comps = decl_csv.split(",")
319
+ return len(comps)
320
+
321
+
322
+ def _parse_proper_kernel_fn_code(kernel_fn_code):
323
+ """
324
+ Skip decorators.
325
+ """
326
+ start_pos = kernel_fn_code.index("def ")
327
+ return kernel_fn_code[start_pos:]
328
+
329
+
330
+ def _parse_numel(proper_kernel_fn_code, numel_arg_name):
331
+ m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code)
332
+ if m:
333
+ return int(m.group(1))
334
+ else:
335
+ return None
336
+
337
+
338
+ def _parse_kernel_args_num_gb(kernel_fn_code, kernel_category):
339
+ """
340
+ inductor meta looks like:
341
+ inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0},
342
+ """
343
+ m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code)
344
+ if m:
345
+ return float(m.group(1))
346
+ else:
347
+ """
348
+ There are a few cases that kernel_num_gdb field can be missing:
349
+ 1. the field will be missing if config.benchmark_kernel and
350
+ config.profile_bandwidth are false
351
+ 2. even if config.benchmark_kernel or config.profile_bandwidth is true.
352
+ foreach kernel does not have kernel_num_gb field in the metadata
353
+ """
354
+ return None
355
+
356
+
357
def log_kernel_metadata(kernel_name, kernel_path, kernel_module_code):
    """
    Add one row describing a generated kernel to the "kernel_metadata" table.

    Part of the metadata is parsed from the generated kernel source. Parsing
    the generated code is acceptable here because this logging is disabled by
    default; when enabled it would hurt compilation time.
    """
    from .wrapper_benchmark import get_kernel_category_by_source_code

    kernel_category = get_kernel_category_by_source_code(kernel_module_code)
    reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code)
    size_hints = _parse_size_hints(kernel_module_code, kernel_category)
    kernel_fn_code = _parse_kernel_fn_code(kernel_module_code)

    proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code)

    # the line count excludes the decorators
    kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code)

    def _build_row():
        # Built lazily: add_row only calls this when the table is enabled.
        row = {
            "kernel_name": kernel_name,
            "kernel_path": kernel_path,
            "kernel_category": kernel_category,
            "size_hints": size_hints,
            "reduction_hint": reduction_hint,
            "line_of_code": kernel_line_of_code,
        }
        for column, needle in (
            ("num_load", "tl.load"),
            ("num_store", "tl.store"),
            ("num_for_loop", "for "),
            ("num_atomic_add", "tl.atomic_add"),
        ):
            row[column] = _count_pattern(proper_kernel_fn_code, needle)
        row["num_args"] = _count_args(proper_kernel_fn_code)
        for numel_name in ("xnumel", "ynumel", "rnumel"):
            row[numel_name] = _parse_numel(proper_kernel_fn_code, numel_name)
        row["kernel_args_num_gb"] = _parse_kernel_args_num_gb(
            kernel_fn_code, kernel_category
        )
        return row

    get_metric_table("kernel_metadata").add_row(_build_row)
397
+
398
+
399
def purge_old_log_files():
    """
    Remove the old log file of every enabled metric table and start a fresh
    one containing only the header.

    Meant to run at the beginning of the benchmark script, in the parent
    process rather than in the child processes running each individual model.
    """
    for name, table in REGISTERED_METRIC_TABLES.items():
        if name not in enabled_metric_tables():
            continue

        stale_path = table.output_filename()
        if os.path.exists(stale_path):
            os.unlink(stale_path)

        table.write_header()
412
+
413
+
414
@lru_cache
def enabled_metric_tables() -> Set[str]:
    """
    Parse the comma-separated ``config.enabled_metric_tables`` string into a
    set of table names.

    Blank entries are ignored; every remaining name must be registered. The
    result is cached by ``lru_cache``.
    """
    stripped_names = (
        piece.strip() for piece in config.enabled_metric_tables.split(",")
    )

    enabled = set()
    for name in stripped_names:
        if name:
            assert (
                name in REGISTERED_METRIC_TABLES
            ), f"Metric table name {name} is not registered"
            enabled.add(name)
    return enabled
428
+
429
+
430
def is_metric_table_enabled(name):
    """Return True when ``name`` is one of the enabled metric tables."""
    active_tables = enabled_metric_tables()
    return name in active_tables
432
+
433
+
434
def get_metric_table(name):
    """Return the registered MetricTable called ``name``; asserts that it exists."""
    registry = REGISTERED_METRIC_TABLES
    assert name in registry, f"Metric table {name} is not defined"
    return registry[name]
.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_ir.py ADDED
@@ -0,0 +1,1881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from typing import Any, List, Optional
3
+
4
+ import sympy
5
+
6
+ import torch
7
+ from torch._prims_common import make_channels_last_strides_for
8
+ from torch.utils._ordered_set import OrderedSet
9
+
10
+ from .ir import (
11
+ ExternKernelAlloc,
12
+ FixedLayout,
13
+ FlexibleLayout,
14
+ ir_node_to_tensor,
15
+ IRNode,
16
+ is_contiguous_storage_and_layout,
17
+ Layout,
18
+ may_convert_to_optional,
19
+ MultiOutput,
20
+ MultiOutputLayout,
21
+ MutationOutput,
22
+ NoneLayout,
23
+ TensorBox,
24
+ )
25
+ from .utils import convert_shape_to_inductor, pad_listlike
26
+ from .virtualized import V
27
+
28
+
29
def _prepare_convolution_fusion_create(
    cls,
    x: "TensorBox",
    weight: "TensorBox",
    bias: Optional["TensorBox"],
    padding: List[int],
    stride: List[int],
    dilation: List[int],
    groups: int,
    transposed: bool = False,
    output_padding: Optional[List[int]] = None,
):
    """
    Prepare inputs, layout and constant args for the create function of a
    convolution post-op fusion.

    This realizes the inputs, infers the output size, requires a channels-last
    stride order on ``x`` and decides the output layout (channels first or
    channels last). Only the CPU device is supported, since the conv post-op
    fusion kernel is only supported on CPU right now.

    Returns ``(inputs, constant_args, kernel_layout, req_stride_order)``:

    * ``inputs`` is ``[x, weight]`` plus ``bias`` when it is not None;
    * ``constant_args`` is ``[padding, stride, dilation, groups]`` with
      ``output_padding`` inserted at position 1 when ``transposed`` and with
      the (None) bias inserted at position 0 when there is no bias;
    * ``kernel_layout`` is the FixedLayout of the output;
    * ``req_stride_order`` is the stride order that was required on ``x``.
    """

    # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
    def _conv_input_size(
        output_size, weight_size, padding, output_padding, stride, dilation, groups
    ):
        assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
        dim = len(output_size)
        assert dim > 2, "Expect input dim > 2"

        BATCH_DIM = 0
        WEIGHT_INPUT_CHANNELS_DIM = 1
        input_size = [
            output_size[BATCH_DIM],
            weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups,
        ]
        for d in range(2, dim):
            kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
            input_size.append(
                (output_size[d] - 1) * stride[d - 2]
                - (padding[d - 2] * 2)
                + kernel
                + output_padding[d - 2]
            )
        return list(map(int, input_size))

    # The size of prepacked_weight is the prepacked weight size of deconv:
    #   Groups > 1:  [g*o, i/g, ...]
    #   Groups == 1: [o, i, ...]
    # Returns original weight size in [i, o, ...]
    def _original_deconv_weight_size(
        prepacked_weight,
        groups,
    ):
        prepacked_weight_size = prepacked_weight.size()
        dim = len(prepacked_weight_size)
        assert dim > 2, "Expect weight dim > 2"
        if groups > 1:
            # Dimension 0 is g*o, so integer division recovers o exactly
            # and keeps the size list free of floats.
            return [
                prepacked_weight_size[1] * groups,
                prepacked_weight_size[0] // groups,
                *(prepacked_weight_size[d] for d in range(2, dim)),
            ]
        return prepacked_weight.transpose(0, 1).size()

    x.realize()
    weight.realize()
    if bias is not None:
        bias.realize()
    with V.graph.fake_mode:
        # TODO <Leslie> cleaned up the fake_tensor trace as Linear implementation
        x_fake = ir_node_to_tensor(x, guard_shape=True)
        weight_fake = ir_node_to_tensor(weight, guard_shape=True)
        dims = len(x_fake.size()) - 2
        assert 0 < len(padding) <= dims
        assert 0 < len(dilation) <= dims
        assert 0 < len(stride) <= dims
        # expand every per-dimension argument to one entry per spatial dim
        padding = pad_listlike(padding, dims)
        dilation = pad_listlike(dilation, dims)
        stride = pad_listlike(stride, dims)
        if output_padding is None:
            output_padding = pad_listlike([0], dims)
        else:
            assert 0 < len(output_padding) <= dims
            output_padding = pad_listlike(output_padding, dims)
        assert isinstance(groups, (int, sympy.core.numbers.Integer))
        if transposed:
            # When transposed, the size of the prepacked oneDNN weight is different
            # from the PyTorch weight. We're not able to run aten conv with such
            # size. We infer the output size from the input params here:
            weight_size = _original_deconv_weight_size(weight_fake, groups)
            input_size = x_fake.size()
            output_size = _conv_input_size(
                input_size,
                weight_size,
                padding,
                output_padding,
                stride,
                dilation,
                groups,
            )
        else:
            bias_fake = (
                ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
            )
            output = torch.ops.aten.convolution(
                x_fake,
                weight_fake,
                bias_fake,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )
            output_size = output.size()

        req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
        req_stride_order = [len(req_stride_order)] + req_stride_order

    x = cls.require_stride_order(x, req_stride_order)

    # We won't do weight prepack for Conv if dynamic_shapes.
    # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel.
    # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1),
    # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order
    # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel,
    # this tensor is considered as channels first and the output will be in contiguous format.
    # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last.
    dynamic_shapes = not all(isinstance(i, int) for i in (output_size))
    if dynamic_shapes and is_contiguous_storage_and_layout(x):
        output_stride = FlexibleLayout.contiguous_strides(output_size)
    else:
        output_stride = make_channels_last_strides_for(output_size)

    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
    inputs = [x, weight]

    kernel_layout = FixedLayout(
        x.get_device(),
        x.get_dtype(),
        convert_shape_to_inductor(output_size),
        convert_shape_to_inductor(output_stride),
    )
    constant_args = [padding, stride, dilation, groups]
    if transposed:
        constant_args.insert(1, output_padding)

    # a real bias is a kernel input; a missing bias is passed as a constant
    if bias is not None:
        inputs.append(bias)
    else:
        constant_args.insert(0, bias)
    return inputs, constant_args, kernel_layout, req_stride_order
183
+
184
+
185
def _prepare_linear_fusion_create(
    cls,
    x: "TensorBox",
    weight: "TensorBox",
    bias: "TensorBox",
):
    """
    Assemble the inputs, constant args and output layout consumed by the
    ``create`` functions of linear post-op fusion kernels.

    Only the CPU device is handled, since the linear post-op fusion kernel is
    only supported on CPU right now.

    Args:
        cls: IR node class; only its ``require_stride_order`` is used here.
        x: activation; every dimension except the last is carried to the output.
        weight: prepacked weight whose size unpacks as ``(_, oc)``.
        bias: optional bias, ``None`` when the linear has no bias.

    Returns:
        ``(inputs, constant_args, kernel_layout, req_stride_order)``. A bias
        that is present is appended to ``inputs``; a missing one is recorded as
        a single ``None`` entry in ``constant_args``.
    """
    has_bias = bias is not None

    x.realize()
    weight.realize()
    if has_bias:
        bias.realize()

    x_size = x.get_size()
    *leading_dims, _ = x_size
    # The weight has been transposed during the qlinear weight prepack process.
    # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/
    # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291
    _, out_channels = weight.get_size()
    output_size = [*leading_dims, out_channels]

    # Stride order n-1, n-2, ..., 0 over the dimensions of x.
    req_stride_order = list(range(len(x_size) - 1, -1, -1))
    x = cls.require_stride_order(x, req_stride_order)

    assert all(t.get_device().type == "cpu" for t in (x, weight))

    kernel_layout = FixedLayout(
        x.get_device(),
        x.get_dtype(),
        output_size,
        FlexibleLayout.contiguous_strides(output_size),
    )

    inputs = [x, weight]
    constant_args: List[Any] = []
    if has_bias:
        inputs.append(bias)
    else:
        # No bias tensor: keep a None placeholder among the constant args.
        constant_args.append(bias)
    return inputs, constant_args, kernel_layout, req_stride_order
227
+
228
+
229
class ConvolutionUnary(ExternKernelAlloc):
    """Extern kernel node bound to ``torch.ops.mkldnn._convolution_pointwise.default``."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """Register the op overload and record the matching C++ schema string."""
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_pointwise.default,
        )
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& input_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                at::IntArrayRef padding,
                at::IntArrayRef stride,
                at::IntArrayRef dilation,
                int64_t groups,
                c10::string_view attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        """Emit the extern-kernel call through ``wrapper``, then size asserts."""
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
            op_overload=self.op_overload,
            raw_args=list(self.inputs) + list(self.constant_args),
        )
        if not isinstance(self.layout, Layout):
            return
        self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        attr,
        scalars: Optional[List[Any]],
        algorithm,
    ):
        """
        Build a ``ConvolutionUnary`` node.

        The convolution inputs, leading constant args and output layout come
        from ``_prepare_convolution_fusion_create``; the post-op description
        (``attr``, ``scalars``, ``algorithm``) is appended to the constant args.
        """
        prepared = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        inputs, conv_constant_args, kernel_layout, _ = prepared
        post_op_args = [attr, may_convert_to_optional(scalars), algorithm]
        return ConvolutionUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=conv_constant_args + post_op_args,
        )
297
+
298
+
299
class ConvolutionBinary(ExternKernelAlloc):
    """Extern kernel node bound to ``torch.ops.mkldnn._convolution_pointwise.binary``."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        cpp_constant_args=(),
    ) -> None:
        """Register the op overload, the C++ schema string and ``cpp_constant_args``."""
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_pointwise.binary,
        )
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& input_t,
                const at::Tensor& other_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                at::IntArrayRef padding,
                at::IntArrayRef stride,
                at::IntArrayRef dilation,
                int64_t groups,
                c10::string_view binary_attr,
                std::optional<at::Scalar> alpha,
                std::optional<c10::string_view> unary_attr,
                torch::List<std::optional<at::Scalar>> unary_scalars,
                std::optional<c10::string_view> unary_algorithm)"""
        self.cpp_constant_args = cpp_constant_args

    def codegen(self, wrapper):
        """Emit the extern-kernel call through ``wrapper``, then size asserts."""
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
            self.cpp_kernel_overload_name,
            self.op_overload,
            list(self.inputs) + list(self.constant_args),
        )
        if not isinstance(self.layout, Layout):
            return
        self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[List[Any]],
        unary_algorithm: Optional[str],
    ):
        """
        Build a ``ConvolutionBinary`` node.

        ``other`` is brought to the same required stride order as ``x`` and
        placed right after it in the input list; the binary and unary post-op
        descriptions are appended to the constant args.
        """
        prepared = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        inputs, conv_constant_args, kernel_layout, req_stride_order = prepared

        other = cls.require_stride_order(other, req_stride_order)
        inputs.insert(1, other)

        post_op_args = [
            binary_attr,
            binary_alpha,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
        return ConvolutionBinary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=conv_constant_args + post_op_args,
        )
385
+
386
+
387
class ConvolutionBinaryInplace(ExternKernelAlloc):
    """Extern kernel node bound to ``torch.ops.mkldnn._convolution_pointwise_.binary``."""

    def __init__(
        self,
        kernel_layout,
        inputs,
        constant_args=(),
    ) -> None:
        """
        Register the op overload, C++ schema string and mutation outputs.

        ``inputs`` arrives as ``[x, other, ...]``; the first two entries are
        swapped before being handed to the base class.
        """
        # Due to a constraint of op.call, other (Tensor&) should be at input[0]
        first, second = inputs[0], inputs[1]
        reordered_inputs = [second, first, *inputs[2:]]

        super().__init__(
            kernel_layout,
            reordered_inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_pointwise_.binary,
        )
        # TODO: op.call: input[0] should be at::Tensor&
        self.cpp_op_schema = """
            at::Tensor&(
                at::Tensor& other_t,
                const at::Tensor& input_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                at::IntArrayRef padding,
                at::IntArrayRef stride,
                at::IntArrayRef dilation,
                int64_t groups,
                c10::string_view binary_attr,
                std::optional<at::Scalar> alpha,
                std::optional<c10::string_view> unary_attr,
                torch::List<std::optional<at::Scalar>> unary_scalars,
                std::optional<c10::string_view> unary_algorithm)"""

        # One MutationOutput per entry of the first two (un-reordered) inputs.
        self.mutation_outputs = [
            MutationOutput(NoneLayout(buf.get_device()), buf, self)
            for buf in (first, second)
        ]

    def codegen(self, wrapper):
        """Emit the extern-kernel call through ``wrapper``."""
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
            self.cpp_kernel_overload_name,
            self.op_overload,
            list(self.inputs) + list(self.constant_args),
        )

    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        """Return an empty set: this node reports no unbacked symbol definitions."""
        return OrderedSet()

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[List[Any]],
        unary_algorithm: Optional[str],
    ):
        """
        Build the in-place node and return the buffer it mutates.

        The value returned is ``packed.inputs[0]`` (i.e. ``other`` after the
        stride-order requirement), not the newly created node.
        """
        prepared = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        # The layout computed by the helper is unused; a NoneLayout is used below.
        inputs, conv_constant_args, _, req_stride_order = prepared

        other = cls.require_stride_order(other, req_stride_order)
        inputs.insert(1, other)

        post_op_args = [
            binary_attr,
            binary_alpha,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
        packed = ConvolutionBinaryInplace(
            kernel_layout=NoneLayout(inputs[1].get_device()),  # type: ignore[arg-type]
            inputs=inputs,
            constant_args=conv_constant_args + post_op_args,
        )
        # This op mutates in place which means that the result is not the
        # target but rather the input that is being mutated
        # init reorders the inputs, so inputs[1] becomes packed.inputs[0]
        return packed.inputs[0]
485
+
486
+
487
class ConvolutionTransposeUnary(ExternKernelAlloc):
    """
    Extern kernel node bound to
    ``torch.ops.mkldnn._convolution_transpose_pointwise.default``.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """Register the op overload and record the matching C++ schema string."""
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default,
        )
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& input_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                at::IntArrayRef padding,
                at::IntArrayRef output_padding,
                at::IntArrayRef stride,
                at::IntArrayRef dilation,
                int64_t groups,
                c10::string_view attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        """Emit the extern-kernel call through ``wrapper``."""
        emit = wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed
        emit(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
        )

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        output_padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups_: int,
        attr,
        scalars: Optional[List[Any]],
        algorithm,
    ):
        """
        Build a ``ConvolutionTransposeUnary`` node.

        The shared convolution helper is invoked with ``transposed=True`` and
        the given output padding; the post-op description (``attr``,
        ``scalars``, ``algorithm``) is appended to the constant args.
        """
        prepared = _prepare_convolution_fusion_create(
            cls,
            x,
            weight,
            bias,
            padding_,
            stride_,
            dilation_,
            groups_,
            True,  # transposed
            output_padding_,
        )
        inputs, conv_constant_args, kernel_layout, _ = prepared
        post_op_args = [attr, may_convert_to_optional(scalars), algorithm]
        return ConvolutionTransposeUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=conv_constant_args + post_op_args,
        )
568
+
569
+
570
class QConvPointWisePT2E(ExternKernelAlloc):
    """Extern kernel node bound to ``torch.ops.onednn.qconv2d_pointwise.default``."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """
        if bias is not None
            - inputs = [x, w, b, weight_scale, weight_zp]
            - const_args is: [x_scale, x_zp, stride, padding, dilation, groups,
              o_scale, o_zp, output_dtype, unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, w, weight_scale, weight_zp]
            - const_args is: [x_scale, x_zp, bias, stride, padding, dilation, groups,
              o_scale, o_zp, output_dtype, unary_attr, unary_scalars, unary_algorithm]
        """
        # Five tensor inputs means a bias tensor sits at index 2.
        self.has_bias = len(inputs) == 5
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.onednn.qconv2d_pointwise.default,
        )
        self.cpp_op_schema = """
            at::Tensor(
                at::Tensor act,
                double act_scale,
                int64_t act_zero_point,
                at::Tensor weight,
                at::Tensor weight_scales,
                at::Tensor weight_zero_points,
                std::optional<at::Tensor> bias,
                torch::List<int64_t> stride,
                torch::List<int64_t> padding,
                torch::List<int64_t> dilation,
                int64_t groups,
                double output_scale,
                int64_t output_zero_point,
                std::optional<c10::ScalarType> output_dtype,
                c10::string_view attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        """
        Emit the extern-kernel call with arguments rearranged into op order.

        The same rearrangement is applied twice: once to the generated
        argument strings and once to the raw inputs/constant args.
        """
        # Parse the inputs and constants
        # The raw_args setup can be skipped if there is a C shim implementation
        args = [inp.codegen_reference() for inp in self.inputs]

        names_before_bias = ["x_scale", "x_zero_point"]
        names_after_bias = [
            "stride",
            "padding",
            "dilation",
            "groups",
            "output_scale",
            "output_zero_point",
            "output_dtype",
            "attr",
            "scalars",
            "algorithm",
        ]
        # Without a bias tensor, the (None) bias travels as constant arg #2.
        bias_name = [] if self.has_bias else ["bias"]
        const_arg_names = names_before_bias + bias_name + names_after_bias
        const_args = list(self.codegen_const_args(const_arg_names))

        def in_op_order(tensors, consts):
            # Interleave tensor and constant arguments in cpp_op_schema order.
            act = tensors[0]
            packed_weight = tensors[1]
            bias = tensors[2] if self.has_bias else consts[2]
            w_scale, w_zp = tensors[-2], tensors[-1]
            x_scale, x_zp = consts[:2]
            (
                stride,
                padding,
                dilation,
                groups,
                o_scale,
                o_zp,
                output_dtype,
                unary_attr,
                unary_scalars,
                unary_algorithm,
            ) = consts[-10:]
            return (
                act,
                x_scale,
                x_zp,
                packed_weight,
                w_scale,
                w_zp,
                bias,
                stride,
                padding,
                dilation,
                groups,
                o_scale,
                o_zp,
                output_dtype,
                unary_attr,
                unary_scalars,
                unary_algorithm,
            )

        codegen_args = in_op_order(args, const_args)
        raw_args = in_op_order(self.inputs, self.constant_args)

        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            codegen_args,
            self.cpp_op_schema,
            self.cpp_kernel_key,
            op_overload=self.op_overload,
            raw_args=raw_args,
        )
        if not isinstance(self.layout, Layout):
            return
        self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale: float,
        x_zero_point: int,
        qw: "TensorBox",  # qw
        w_scale: "TensorBox",
        w_zero_point: "TensorBox",
        bias: "TensorBox",
        stride: List[int],
        padding: List[int],
        dilation: List[int],
        groups: int,
        output_scale: float,
        output_zero_point: int,
        output_dtype,
        attr,
        scalars,
        algorithm,
    ):
        """
        Build a ``QConvPointWisePT2E`` node.

        Tensor inputs end up as ``[qx, qw, (bias,) w_scale, w_zero_point]`` and
        constant args as ``[x_scale, x_zero_point, (None,) stride, padding,
        dilation, groups, output_scale, output_zero_point, output_dtype, attr,
        scalars, algorithm]``, where the parenthesised entry depends on whether
        ``bias`` is given.
        """
        prepared = _prepare_convolution_fusion_create(
            cls,
            qx,
            qw,
            bias,
            padding,
            stride,
            dilation,
            groups,
            False,  # transposed
            None,  # output_padding
        )
        inputs, constant_args, kernel_layout, _ = prepared

        # swap padding and stride to align with functional conv arg order;
        # with no bias, a None placeholder occupies slot 0 and shifts both by one.
        pad_idx = 1 if bias is None else 0
        stride_idx = pad_idx + 1
        constant_args[pad_idx], constant_args[stride_idx] = (
            constant_args[stride_idx],
            constant_args[pad_idx],
        )

        w_scale.realize()
        w_zero_point.realize()
        inputs = inputs + [w_scale, w_zero_point]

        input_qparams = [x_scale, x_zero_point]
        output_and_post_op = [
            output_scale,
            output_zero_point,
            output_dtype,
            attr,
            may_convert_to_optional(scalars),
            algorithm,
        ]
        constant_args = input_qparams + constant_args + output_and_post_op

        assert output_dtype is not None
        if output_dtype in (torch.float32, torch.bfloat16):
            # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout
            # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8.
            kernel_layout.dtype = output_dtype

        return QConvPointWisePT2E(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
        )
800
+
801
+
802
class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
    """Extern kernel node bound to ``torch.ops.onednn.qconv2d_pointwise.binary``."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """
        Needs input/weight/output qparams
        if bias is not None
            - inputs = [x, w, b, accum, w_scale, w_zp]
            - const_args = [x_scale, x_zp, accum_scale, accum_zp, stride, padding,
              dilation, groups, o_scale, o_zp, output_dtype, binary_attr, alpha,
              unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, w, accum, w_scale, w_zp]
            - const_args = [x_scale, x_zp, accum_scale, accum_zp, bias, stride,
              padding, dilation, groups, o_scale, o_zp, output_dtype, binary_attr,
              alpha, unary_attr, unary_scalars, unary_algorithm]
        """
        # Six tensor inputs means a bias tensor sits at index 2, which pushes
        # the accumulation buffer from index 2 to index 3.
        self.has_bias = len(inputs) == 6
        self.idx_for_inplace_sum = 3 if self.has_bias else 2
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.onednn.qconv2d_pointwise.binary,
        )
        self.cpp_op_schema = """
            at::Tensor(
                at::Tensor act,
                double act_scale,
                int64_t act_zero_point,
                at::Tensor accum,
                double accum_scale,
                int64_t accum_zero_point,
                at::Tensor weight,
                at::Tensor weight_scales,
                at::Tensor weight_zero_points,
                std::optional<at::Tensor> bias,
                torch::List<int64_t> stride,
                torch::List<int64_t> padding,
                torch::List<int64_t> dilation,
                int64_t groups,
                double output_scale,
                int64_t output_zero_point,
                std::optional<c10::ScalarType> output_dtype,
                c10::string_view binary_attr,
                std::optional<at::Scalar> alpha,
                std::optional<c10::string_view> attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        """
        Emit the extern-kernel call with arguments rearranged into op order.

        The same rearrangement is applied twice: once to the generated
        argument strings and once to the raw inputs/constant args.
        """
        # Parse the inputs and constants
        # The raw_args setup can be skipped if there is a C shim implementation
        args = [inp.codegen_reference() for inp in self.inputs]

        names_before_bias = [
            "x_scale",
            "x_zero_point",
            "accum_scale",
            "accum_zero_point",
        ]
        names_after_bias = [
            "stride",
            "padding",
            "dilation",
            "groups",
            "output_scale",
            "output_zero_point",
            "output_dtype",
            "binary_attr",
            "alpha",
            "unary_attr",
            "unary_scalars",
            "unary_algorithm",
        ]
        # Without a bias tensor, the (None) bias travels as constant arg #4.
        bias_name = [] if self.has_bias else ["bias"]
        const_arg_names = names_before_bias + bias_name + names_after_bias
        const_args = list(self.codegen_const_args(const_arg_names))

        def in_op_order(tensors, consts):
            # Interleave tensor and constant arguments in cpp_op_schema order.
            act = tensors[0]
            packed_weight = tensors[1]
            bias = tensors[2] if self.has_bias else consts[4]
            accum, w_scale, w_zp = tensors[-3], tensors[-2], tensors[-1]
            x_scale, x_zp, accum_scale, accum_zp = consts[:4]
            (
                stride,
                padding,
                dilation,
                groups,
                o_scale,
                o_zp,
                output_dtype,
                binary_attr,
                alpha,
                unary_attr,
                unary_scalars,
                unary_algorithm,
            ) = consts[-12:]
            return (
                act,
                x_scale,
                x_zp,
                accum,
                accum_scale,
                accum_zp,
                packed_weight,
                w_scale,
                w_zp,
                bias,
                stride,
                padding,
                dilation,
                groups,
                o_scale,
                o_zp,
                output_dtype,
                binary_attr,
                alpha,
                unary_attr,
                unary_scalars,
                unary_algorithm,
            )

        conv_args = in_op_order(args, const_args)
        raw_args = in_op_order(self.inputs, self.constant_args)

        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            conv_args,
            self.cpp_op_schema,
            self.cpp_kernel_key,
            self.cpp_kernel_overload_name,
            op_overload=self.op_overload,
            raw_args=raw_args,
        )
        if not isinstance(self.layout, Layout):
            return
        self.codegen_size_asserts(wrapper)

    def get_mutation_names(self):
        """Return a one-element list holding the name of the accumulation input."""
        accum = self.inputs[self.idx_for_inplace_sum]
        return [accum.get_name()]

    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        """Return an empty set: this node reports no unbacked symbol definitions."""
        return OrderedSet()

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale,
        x_zero_point,
        qaccum: "TensorBox",
        accum_scale,
        accum_zero_point,
        qw: "TensorBox",  # packed_weight
        w_scale,
        w_zero_point,
        bias: "TensorBox",
        stride: List[int],
        padding: List[int],
        dilation: List[int],
        groups: int,
        output_scale: "TensorBox",
        output_zero_point: "TensorBox",
        output_dtype,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        """
        Build the node and return the accumulation buffer it writes into.

        Only ``binary_attr == "sum"`` is accepted. ``qaccum`` is marked as
        mutated on the graph and the value returned is the node's input at
        ``idx_for_inplace_sum`` rather than the node itself.
        """
        prepared = _prepare_convolution_fusion_create(
            cls,
            qx,
            qw,
            bias,
            padding,
            stride,
            dilation,
            groups,
            False,  # transposed
            None,  # output_padding
        )
        inputs, constant_args, kernel_layout, req_stride_order = prepared

        qaccum = cls.require_stride_order(qaccum, req_stride_order)
        inputs.append(qaccum)

        # swap padding and stride to align with functional conv arg order;
        # with no bias, a None placeholder occupies slot 0 and shifts both by one.
        pad_idx = 1 if bias is None else 0
        stride_idx = pad_idx + 1
        constant_args[pad_idx], constant_args[stride_idx] = (
            constant_args[stride_idx],
            constant_args[pad_idx],
        )

        w_scale.realize()
        w_zero_point.realize()
        inputs = inputs + [w_scale, w_zero_point]

        input_qparams = [
            x_scale,
            x_zero_point,
            accum_scale,
            accum_zero_point,
        ]
        output_and_post_op = [
            output_scale,
            output_zero_point,
            output_dtype,
            binary_attr,
            alpha,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
        constant_args = input_qparams + constant_args + output_and_post_op

        assert (
            binary_attr == "sum"
        ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."

        V.graph.mark_buffer_mutated(qaccum.get_name())
        packed = QConvPointWiseBinaryPT2E(
            layout=NoneLayout(qaccum.get_device()),
            inputs=inputs,
            constant_args=constant_args,
        )

        # Return accum since it has been inplace changed.
        return packed.inputs[packed.idx_for_inplace_sum]
1091
+
1092
+
1093
class MKLPackedLinear(ExternKernelAlloc):
    """Extern kernel node bound to ``torch.ops.mkl._mkl_linear.default``."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """Register the op overload and record the matching C++ schema string."""
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkl._mkl_linear.default,
        )
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& self,
                const at::Tensor& mkl_weight_t,
                const at::Tensor& origin_weight_t,
                const std::optional<at::Tensor>& bias_opt,
                const int64_t prepack_batch_size)"""

    def codegen(self, wrapper):
        """Emit the extern-kernel call through ``wrapper``."""
        emit = wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed
        emit(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
        )

    @classmethod
    def create(cls, x, packed_w, orig_w, B, batch_size):
        """
        Build an ``MKLPackedLinear`` node with a fixed, contiguous output layout.

        The output size is every dimension of ``x`` except the last, followed
        by the first dimension of ``orig_w``. A bias ``B`` that is present goes
        into the inputs; otherwise ``None`` is placed ahead of ``batch_size``
        in the constant args.
        """
        x = cls.require_stride1(cls.realize_input(x))
        orig_w = cls.require_stride1(cls.realize_input(orig_w))

        *leading_dims, _ = x.get_size()
        out_channels, _ = orig_w.get_size()
        output_size = [*leading_dims, out_channels]
        output_stride = FlexibleLayout.contiguous_strides(output_size)

        if B is not None:
            inputs = [x, packed_w, orig_w, B]
            constant_args = [batch_size]
        else:
            inputs = [x, packed_w, orig_w]
            constant_args = [None, batch_size]

        layout = FixedLayout(
            x.get_device(), x.get_dtype(), output_size, output_stride
        )
        return MKLPackedLinear(
            layout=layout,
            inputs=inputs,
            constant_args=constant_args,
        )
1147
+
1148
+
1149
class LinearUnary(ExternKernelAlloc):
    """Extern kernel node bound to ``torch.ops.mkldnn._linear_pointwise.default``."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """Register the op overload, the C++ kernel key and the C++ schema string."""
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._linear_pointwise.default,
        )
        self.cpp_kernel_key = "linear_pointwise"
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& input_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                c10::string_view attr,
                torch::List<std::optional<at::Scalar>> scalars,
                std::optional<c10::string_view> algorithm)"""

    def codegen(self, wrapper):
        """Emit the extern-kernel call through ``wrapper``."""
        emit = wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed
        emit(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
        )

    @classmethod
    def create(cls, x, w, B, attr, scalars, algorithm):
        """
        Build a ``LinearUnary`` node with a flexible output layout.

        The output size is every dimension of ``x`` except the last, followed
        by the first dimension of ``w``. An empty/``None`` ``scalars`` is
        replaced by ``[-1]``. A bias ``B`` that is present is made contiguous
        and appended to the inputs; otherwise ``None`` leads the constant args.
        """
        x = cls.require_contiguous(cls.realize_input(x))
        w = cls.require_contiguous(cls.realize_input(w))

        *leading_dims, _ = x.get_size()
        out_channels, _ = w.get_size()

        inputs = [x, w]
        constant_args = [attr, scalars or [-1], algorithm]
        if B is None:
            constant_args.insert(0, None)
        else:
            B = cls.require_contiguous(cls.realize_input(B))
            inputs.append(B)

        layout = FlexibleLayout(
            device=x.get_device(),
            dtype=x.get_dtype(),
            size=[*leading_dims, out_channels],
        )
        return LinearUnary(
            layout=layout,
            inputs=inputs,
            constant_args=constant_args,
        )

    def apply_constraint(self):
        """Intentionally a no-op for this node."""
        pass
1210
+
1211
+
1212
class LinearBinary(ExternKernelAlloc):
    """Extern kernel node bound to ``torch.ops.mkldnn._linear_pointwise.binary``."""

    kernel = "torch.ops.mkldnn._linear_pointwise.binary"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        """Register the op overload and record the matching C++ schema string."""
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.mkldnn._linear_pointwise.binary,
        )
        self.cpp_op_schema = """
            at::Tensor(
                const at::Tensor& input_t,
                const at::Tensor& other_t,
                const at::Tensor& weight_t,
                const std::optional<at::Tensor>& bias_opt,
                c10::string_view attr)
        """

    def codegen(self, wrapper):
        """Emit the extern-kernel call through ``wrapper``."""
        emit = wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed
        emit(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            self.codegen_args(),
            self.cpp_op_schema,
            self.cpp_kernel_key,
            self.cpp_kernel_overload_name,
        )

    @classmethod
    def create(cls, x, y, w, B, attr):
        """
        Build a ``LinearBinary`` node with a flexible output layout.

        The output size is every dimension of ``x`` except the last, followed
        by the first dimension of ``w``. A bias ``B`` that is present is made
        contiguous and appended to the inputs; otherwise ``None`` leads the
        constant args.
        """
        x = cls.require_contiguous(cls.realize_input(x))
        y = cls.require_contiguous(cls.realize_input(y))
        w = cls.require_contiguous(cls.realize_input(w))

        *leading_dims, _ = x.get_size()
        out_channels, _ = w.get_size()

        inputs = [x, y, w]
        constant_args = [attr]
        if B is None:
            constant_args.insert(0, B)
        else:
            B = cls.require_contiguous(cls.realize_input(B))
            inputs.append(B)

        layout = FlexibleLayout(
            device=x.get_device(),
            dtype=x.get_dtype(),
            size=[*leading_dims, out_channels],
        )
        return LinearBinary(
            layout=layout,
            inputs=inputs,
            constant_args=constant_args,
        )

    def apply_constraint(self):
        """Intentionally a no-op for this node."""
        pass
1277
+
1278
+
1279
class QLinearPointwisePT2E(ExternKernelAlloc):
    """Extern-kernel IR node for ``onednn::qlinear_pointwise`` (default / tensor overloads)."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        has_bias=True,
        x_scale_zp_are_tensors=False,
    ) -> None:
        """
        Argument layout, as produced by ``create`` and consumed by ``codegen``:

        if bias is not None
            - inputs = [x, w, b, weight_scale, weight_zp]
            - const_args is: [x_scale, x_zp, o_scale, o_zp,
              output_dtype, unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, w, weight_scale, weight_zp]
            - const_args is: [bias, x_scale, x_zp, o_scale, o_zp,
              output_dtype, unary_attr, unary_scalars, unary_algorithm]

        When ``x_scale_zp_are_tensors`` is True, x_scale and x_zp are not in
        const_args; they sit in ``inputs`` right before weight_scale, and the
        ``.tensor`` overload is used instead of ``.default``.
        """
        self.has_bias = has_bias
        self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
        if x_scale_zp_are_tensors:
            overload = torch.ops.onednn.qlinear_pointwise.tensor
            x_scale_type_str, x_zp_type_str = "at::Tensor", "at::Tensor"
        else:
            overload = torch.ops.onednn.qlinear_pointwise.default
            x_scale_type_str, x_zp_type_str = "double", "int64_t"
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=overload,
        )
        self.cpp_op_schema = f"""
            at::Tensor(
                at::Tensor act,
                {x_scale_type_str} act_scale,
                {x_zp_type_str} act_zero_point,
                at::Tensor weight,
                at::Tensor weight_scales,
                at::Tensor weight_zero_points,
                std::optional<at::Tensor> bias,
                double output_scale,
                int64_t output_zero_point,
                std::optional<c10::ScalarType> output_dtype,
                c10::string_view post_op_name,
                torch::List<std::optional<at::Scalar>> post_op_args,
                c10::string_view post_op_algorithm)"""

    def codegen(self, wrapper):
        """Reorder inputs/constants into op-schema order and emit the extern call.

        Two parallel tuples are built: the rendered (codegen) arguments and
        the raw IR/Python values, both in the order of ``cpp_op_schema``.
        Size asserts are emitted afterwards when the layout is a ``Layout``.
        """
        # Parse the inputs and constants.
        # The raw_args setup can be skipped if there is a C shim implementation
        refs = [node.codegen_reference() for node in self.inputs]
        consts = list(self.codegen_const_args())
        raw_inputs = self.inputs
        raw_consts = self.constant_args

        x, x_raw = refs[0], raw_inputs[0]
        packed_weight, packed_weight_raw = refs[1], raw_inputs[1]
        if self.has_bias:
            bias, bias_raw = refs[2], raw_inputs[2]
        else:
            # Without a bias input, the bias placeholder leads the constants.
            bias, bias_raw = consts[0], raw_consts[0]
        w_scale, w_zp = refs[-2], refs[-1]
        w_scale_raw, w_zp_raw = raw_inputs[-2], raw_inputs[-1]

        if self.x_scale_zp_are_tensors:
            assert len(refs) >= 4
            # Tensor x_scale/x_zp sit right before weight_scale/weight_zp.
            x_scale, x_zp = refs[-4], refs[-3]
            x_scale_raw, x_zp_raw = raw_inputs[-4], raw_inputs[-3]
            post = consts[-6:]
            post_raw = raw_consts[-6:]
        else:
            assert len(consts) >= 8
            # Scalar x_scale/x_zp precede the six trailing constants.
            x_scale, x_zp, *post = consts[-8:]
            x_scale_raw, x_zp_raw, *post_raw = raw_consts[-8:]

        (
            o_scale,
            o_zp,
            output_dtype,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ) = post
        (
            o_scale_raw,
            o_zp_raw,
            output_dtype_raw,
            unary_attr_raw,
            unary_scalars_raw,
            unary_algorithm_raw,
        ) = post_raw

        codegen_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            bias,
            o_scale,
            o_zp,
            output_dtype,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        )
        raw_args = (
            x_raw,
            x_scale_raw,
            x_zp_raw,
            packed_weight_raw,
            w_scale_raw,
            w_zp_raw,
            bias_raw,
            o_scale_raw,
            o_zp_raw,
            output_dtype_raw,
            unary_attr_raw,
            unary_scalars_raw,
            unary_algorithm_raw,
        )
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            codegen_args,
            self.cpp_op_schema,
            self.cpp_kernel_key,
            self.cpp_kernel_overload_name,
            self.op_overload,
            raw_args,
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale: float,
        x_zero_point: int,
        qw: "TensorBox",  # packed_weight
        w_scale: "TensorBox",
        w_zero_point: "TensorBox",
        bias: "TensorBox",
        output_scale: float,
        output_zero_point: int,
        output_dtype,
        post_op_name,
        post_op_args,
        post_op_algorithm,
    ):
        """Assemble inputs/constants for the quantized linear op and build the node.

        ``x_scale`` / ``x_zero_point`` may both be ``TensorBox`` (appended to
        the inputs) or a ``float`` / ``int`` pair (appended to the
        constants); any other combination fails the assertion.
        """
        (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create(
            cls,
            qx,
            qw,
            bias,
        )

        x_scale_zp_are_tensors = isinstance(x_scale, TensorBox) and isinstance(
            x_zero_point, TensorBox
        )
        if x_scale_zp_are_tensors:
            x_scale.realize()
            x_zero_point.realize()
            inputs = inputs + [x_scale, x_zero_point]
        else:
            assert isinstance(x_scale, float) and isinstance(x_zero_point, int)
            constant_args = constant_args + [x_scale, x_zero_point]

        w_scale.realize()
        w_zero_point.realize()
        inputs = inputs + [w_scale, w_zero_point]
        constant_args = constant_args + [
            output_scale,
            output_zero_point,
            output_dtype,
            post_op_name,
            may_convert_to_optional(post_op_args),
            post_op_algorithm,
        ]

        assert output_dtype is not None
        if output_dtype in (torch.float32, torch.bfloat16):
            # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout
            # if a float output is requested, the output buf must use that dtype instead of uint8.
            kernel_layout.dtype = output_dtype

        return QLinearPointwisePT2E(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
            has_bias=(bias is not None),
            x_scale_zp_are_tensors=x_scale_zp_are_tensors,
        )
1490
+
1491
+
1492
class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
    """Extern-kernel IR node for ``onednn::qlinear_pointwise`` (binary / binary_tensor overloads)."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        has_bias=True,
        x_scale_zp_are_tensors=False,
    ) -> None:
        """
        Argument layout, as produced by ``create`` and consumed by ``codegen``:

        if bias is not None
            - inputs = [x, w, b, weight_scale, weight_zp, x2]
            - const_args is: [x_scale, x_zp, o_scale, o_zp, output_dtype,
              other_scale, other_zp, binary_attr, alpha,
              unary_attr, unary_scalars, unary_algorithm]
        else
            - inputs = [x, w, weight_scale, weight_zp, x2]
            - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, output_dtype,
              other_scale, other_zp, binary_attr, alpha,
              unary_attr, unary_scalars, unary_algorithm]

        When ``x_scale_zp_are_tensors`` is True, x_scale and x_zp are not in
        const_args; they sit in ``inputs`` right before weight_scale, and the
        ``.binary_tensor`` overload is used instead of ``.binary``.
        """
        self.has_bias = has_bias
        self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.onednn.qlinear_pointwise.binary_tensor
            if x_scale_zp_are_tensors
            else torch.ops.onednn.qlinear_pointwise.binary,
        )
        x_scale_type_str, x_zp_type_str = (
            ("at::Tensor", "at::Tensor")
            if x_scale_zp_are_tensors
            else ("double", "int64_t")
        )
        self.cpp_op_schema = f"""
            at::Tensor(
                at::Tensor act,
                {x_scale_type_str} act_scale,
                {x_zp_type_str} act_zero_point,
                at::Tensor weight,
                at::Tensor weight_scales,
                at::Tensor weight_zero_points,
                std::optional<at::Tensor> other,
                std::optional<at::Tensor> bias,
                double inv_output_scale,
                int64_t output_zero_point,
                std::optional<c10::ScalarType> output_dtype,
                double other_scale,
                int64_t other_zero_point,
                c10::string_view binary_post_op,
                double binary_alpha,
                c10::string_view unary_post_op,
                torch::List<std::optional<at::Scalar>> unary_post_op_args,
                c10::string_view unary_post_op_algorithm)"""

    def codegen(self, wrapper):
        """Reorder inputs/constants into op-schema order and emit the extern call.

        Two parallel tuples are built: the rendered (codegen) arguments and
        the raw IR/Python values, both in the order of ``cpp_op_schema``.
        Size asserts are emitted afterwards when the layout is a ``Layout``.
        """
        # Parse the inputs and constants.
        # The raw_args setup can be skipped if there is a C shim implementation
        args = [x.codegen_reference() for x in self.inputs]
        const_args = []
        const_args.extend(self.codegen_const_args())

        x = args[0]
        x_raw = self.inputs[0]
        packed_weight = args[1]
        packed_weight_raw = self.inputs[1]
        # Without a bias input, the bias placeholder leads the constants.
        bias = args[2] if self.has_bias else const_args[0]
        bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0]
        w_scale, w_zp, other = args[-3], args[-2], args[-1]
        w_scale_raw, w_zp_raw, other_raw = (
            self.inputs[-3],
            self.inputs[-2],
            self.inputs[-1],
        )
        if self.x_scale_zp_are_tensors:
            assert len(args) >= 5
            # Ten trailing constants are unpacked below; guard them explicitly.
            assert len(const_args) >= 10
            x_scale, x_zp = args[-5], args[-4]
            x_scale_raw, x_zp_raw = self.inputs[-5], self.inputs[-4]
            (
                o_scale,
                o_zp,
                output_dtype,
                other_scale,
                other_zp,
                binary_attr,
                alpha,
                unary_attr,
                unary_scalars,
                unary_algorithm,
            ) = const_args[-10:]
            (
                o_scale_raw,
                o_zp_raw,
                output_dtype_raw,
                other_scale_raw,
                other_zp_raw,
                binary_attr_raw,
                alpha_raw,
                unary_attr_raw,
                unary_scalars_raw,
                unary_algorithm_raw,
            ) = self.constant_args[-10:]
        else:
            # Twelve trailing constants are unpacked below (the guard used to
            # check for only 8, which did not match the unpacking).
            assert len(const_args) >= 12
            (
                x_scale,
                x_zp,
                o_scale,
                o_zp,
                output_dtype,
                other_scale,
                other_zp,
                binary_attr,
                alpha,
                unary_attr,
                unary_scalars,
                unary_algorithm,
            ) = const_args[-12:]
            (
                x_scale_raw,
                x_zp_raw,
                o_scale_raw,
                o_zp_raw,
                output_dtype_raw,
                other_scale_raw,
                other_zp_raw,
                binary_attr_raw,
                alpha_raw,
                unary_attr_raw,
                unary_scalars_raw,
                unary_algorithm_raw,
            ) = self.constant_args[-12:]

        codegen_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            other,
            bias,
            o_scale,
            o_zp,
            output_dtype,
            other_scale,
            other_zp,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        )
        raw_args = (
            x_raw,
            x_scale_raw,
            x_zp_raw,
            packed_weight_raw,
            w_scale_raw,
            w_zp_raw,
            other_raw,
            bias_raw,
            o_scale_raw,
            o_zp_raw,
            output_dtype_raw,
            other_scale_raw,
            other_zp_raw,
            binary_attr_raw,
            alpha_raw,
            unary_attr_raw,
            unary_scalars_raw,
            unary_algorithm_raw,
        )
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
            self.get_name(),
            self.python_kernel_name,
            self.cpp_kernel_name,
            codegen_args,
            self.cpp_op_schema,
            self.cpp_kernel_key,
            self.cpp_kernel_overload_name,
            self.op_overload,
            raw_args,
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    def get_mutation_names(self):
        """Return the name of the last input when the binary post-op is "sum", else an empty list."""
        # binary_post_op is the fifth constant from the end (see ``create``).
        binary_post_op = self.constant_args[-5]
        if binary_post_op == "sum":
            return [self.inputs[-1].get_name()]
        else:
            return []

    @classmethod
    def create(
        cls,
        qx: "TensorBox",
        x_scale: float,
        x_zero_point: int,
        qw: "TensorBox",  # packed_weight
        w_scale: "TensorBox",
        w_zero_point: "TensorBox",
        other: "TensorBox",
        bias: "TensorBox",
        output_scale: float,
        output_zero_point: int,
        output_dtype,
        other_scale,
        other_zp,
        binary_post_op,
        binary_alpha,
        unary_post_op,
        unary_post_op_args,
        unary_post_op_algorithm,
    ):
        """Assemble inputs/constants for the binary quantized linear op and build the node.

        ``x_scale`` / ``x_zero_point`` may both be ``TensorBox`` (appended to
        the inputs) or a ``float`` / ``int`` pair (appended to the
        constants); any other combination fails the assertion.

        For ``binary_post_op == "sum"`` the node is created with a
        ``NoneLayout``, ``other`` is marked as mutated, and the node's last
        input (``other``) is returned instead of the node itself.
        """
        (
            inputs,
            constant_args,
            kernel_layout,
            req_stride_order,
        ) = _prepare_linear_fusion_create(
            cls,
            qx,
            qw,
            bias,
        )

        if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox):
            x_scale.realize()
            x_zero_point.realize()
            inputs = inputs + [x_scale, x_zero_point]
            x_scale_zp_are_tensors = True
        else:
            assert isinstance(x_scale, float) and isinstance(x_zero_point, int)
            constant_args = constant_args + [x_scale, x_zero_point]
            x_scale_zp_are_tensors = False
        w_scale.realize()
        w_zero_point.realize()
        inputs = inputs + [w_scale, w_zero_point]
        if binary_post_op == "sum":
            other = cls.require_stride_order(other, req_stride_order)
        inputs.append(other)
        constant_args = constant_args + [
            output_scale,
            output_zero_point,
            output_dtype,
            other_scale,
            other_zp,
            binary_post_op,
            binary_alpha,
            unary_post_op,
            may_convert_to_optional(unary_post_op_args),
            unary_post_op_algorithm,
        ]

        if binary_post_op == "sum":
            V.graph.mark_buffer_mutated(other.get_name())
            packed = QLinearPointwiseBinaryPT2E(
                layout=NoneLayout(other.get_device()),
                inputs=inputs,
                constant_args=constant_args,
                has_bias=(bias is not None),
                x_scale_zp_are_tensors=x_scale_zp_are_tensors,
            )
            # Return other since it has been changed in place.
            return packed.inputs[-1]

        assert output_dtype is not None
        if output_dtype in [torch.float32, torch.bfloat16]:
            # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout
            # if a float output is requested, the output buf must use that dtype instead of uint8.
            kernel_layout.dtype = output_dtype

        return QLinearPointwiseBinaryPT2E(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
            has_bias=(bias is not None),
            x_scale_zp_are_tensors=x_scale_zp_are_tensors,
        )
1774
+
1775
+
1776
class MkldnnRnnLayer(ExternKernelAlloc):
    """Extern-kernel IR node for ``aten.mkldnn_rnn_layer.default``."""

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ) -> None:
        super().__init__(
            layout,
            inputs,
            constant_args,
            None,
            op_overload=torch.ops.aten.mkldnn_rnn_layer.default,
        )

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        w0: "TensorBox",
        w1: "TensorBox",
        w2: "TensorBox",
        w3: "TensorBox",
        hx: "TensorBox",
        cx: "TensorBox",
        reverse: bool,
        batch_sizes: List[int],
        mode: int,
        hidden_size: int,
        num_layers: int,
        has_biases: bool,
        bidirectional: bool,
        batch_first: bool,
        train: bool,
    ):
        """Build the multi-output RNN layer node and return its three outputs.

        Returns a list of three ``MultiOutput`` nodes with fixed contiguous
        layouts, sized ``[seq_length, mini_batch, hidden_size]``,
        ``hx.get_size()`` and ``cx.get_size()`` respectively.
        """
        x = cls.require_stride1(cls.realize_input(x))
        # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer.
        # Make sure x is contiguous in batch_first case.
        x.freeze_layout()
        w0, w1, w2, w3 = (
            cls.require_stride1(cls.realize_input(weight))
            for weight in (w0, w1, w2, w3)
        )
        hx = cls.require_stride1(cls.realize_input(hx))
        hx.freeze_layout()
        cx = cls.require_stride1(cls.realize_input(cx))
        cx.freeze_layout()

        x_shape = x.get_size()
        assert len(x_shape) == 3, "Expect lstm input to be 3D"
        # batch_first is handled in the lstm OP. When entering
        # rnn_layer here, we'll always have batch_first = False
        seq_length, mini_batch, _ = x_shape
        output_shape = [seq_length, mini_batch, hidden_size]

        hy_shape = hx.get_size()
        cy_shape = cx.get_size()

        packed = MkldnnRnnLayer(
            MultiOutputLayout(x.get_device()),
            inputs=[x, w0, w1, w2, w3, hx, cx],
            constant_args=[
                reverse,
                batch_sizes,
                mode,
                hidden_size,
                num_layers,
                has_biases,
                bidirectional,
                batch_first,
                train,
            ],
        )

        assert len(output_shape) == 3, "Expect output_shape to be 3D"
        # All three outputs use contiguous strides.
        shapes = [output_shape, hy_shape, cy_shape]
        strides = [FlexibleLayout.contiguous_strides(shape) for shape in shapes]

        output_ir = []
        for i, (output_size, output_stride) in enumerate(zip(shapes, strides)):
            out_layout = FixedLayout(
                x.get_device(),
                x.get_dtype(),
                output_size,
                output_stride,
            )
            output_ir.append(MultiOutput(out_layout, packed, [(tuple, i)]))

        return output_ir
.venv/lib/python3.11/site-packages/torch/_inductor/mkldnn_lowerings.py ADDED
@@ -0,0 +1,1087 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ import functools
4
+ from typing import List, Optional
5
+
6
+ import torch
7
+ import torch.utils._pytree as pytree
8
+ from torch._inductor.kernel.mm_common import mm_args
9
+
10
+ from . import ir
11
+ from .codegen.cpp_gemm_template import CppPackedGemmTemplate
12
+ from .codegen.cpp_utils import create_epilogue_with_attr
13
+ from .ir import TensorBox
14
+ from .lowering import (
15
+ add,
16
+ add_needs_realized_inputs,
17
+ aten,
18
+ permute,
19
+ register_lowering,
20
+ to_dtype,
21
+ view,
22
+ )
23
+ from .select_algorithm import (
24
+ autotune_select_algorithm,
25
+ ChoiceCaller,
26
+ ExternKernelChoice,
27
+ )
28
+ from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune
29
+ from .virtualized import ops, V
30
+
31
+
32
+ def register_onednn_fusion_ops():
33
+ if torch._C._has_mkldnn:
34
+ from . import mkldnn_ir
35
+
36
+ aten_mkldnn_linear_unary = ExternKernelChoice(
37
+ torch.ops.mkldnn._linear_pointwise,
38
+ "mkldnn::_linear_pointwise",
39
+ has_out_variant=False,
40
+ kernel_creator=mkldnn_ir.LinearUnary.create,
41
+ )
42
+ aten_mkldnn_linear_binary = ExternKernelChoice(
43
+ torch.ops.mkldnn._linear_pointwise.binary,
44
+ "mkldnn::_linear_pointwise",
45
+ has_out_variant=False,
46
+ kernel_creator=mkldnn_ir.LinearBinary.create,
47
+ )
48
+ aten_mkldnn_qlinear_unary = ExternKernelChoice(
49
+ torch.ops.onednn.qlinear_pointwise,
50
+ "onednn::qlinear_pointwise",
51
+ has_out_variant=False,
52
+ kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create,
53
+ )
54
+ aten_mkldnn_qlinear_binary = ExternKernelChoice(
55
+ torch.ops.onednn.qlinear_pointwise.binary,
56
+ "onednn::qlinear_pointwise",
57
+ has_out_variant=False,
58
+ kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create,
59
+ )
60
+ cpu_needs_realized_inputs = [
61
+ torch.ops.mkldnn._convolution_pointwise,
62
+ torch.ops.mkldnn._convolution_pointwise_,
63
+ torch.ops.mkldnn._convolution_transpose_pointwise,
64
+ torch.ops.mkldnn._linear_pointwise,
65
+ aten.mkldnn_rnn_layer.default,
66
+ torch.ops.onednn.qconv2d_pointwise,
67
+ ]
68
+
69
@register_lowering(torch.ops.mkldnn._convolution_pointwise)
def convolution_unary(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    padding,
    stride,
    dilation,
    groups,
    attr,
    scalars,
    algorithm,
):
    """Lower ``mkldnn._convolution_pointwise`` to a boxed ``ConvolutionUnary`` node."""
    conv_node = mkldnn_ir.ConvolutionUnary.create(
        x,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    )
    return TensorBox.create(conv_node)
96
+
97
@register_lowering(torch.ops.mkldnn._convolution_pointwise.binary)
def convolution_binary(
    x: TensorBox,
    other: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    padding,
    stride,
    dilation,
    groups,
    binary_attr,
    binary_alpha,
    unary_attr,
    unary_scalars,
    unary_algorithm,
):
    """Lower ``mkldnn._convolution_pointwise.binary`` to a boxed ``ConvolutionBinary`` node."""
    conv_node = mkldnn_ir.ConvolutionBinary.create(
        x,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        binary_alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    )
    return TensorBox.create(conv_node)
130
+
131
@register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary)
def convolution_binary_inplace(
    x: TensorBox,
    other: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    padding,
    stride,
    dilation,
    groups,
    binary_attr,
    binary_alpha,
    unary_attr,
    unary_scalars,
    unary_algorithm,
):
    """Lower ``mkldnn._convolution_pointwise_.binary`` to a boxed ``ConvolutionBinaryInplace`` result."""
    conv_node = mkldnn_ir.ConvolutionBinaryInplace.create(
        x,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        binary_alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    )
    return TensorBox.create(conv_node)
164
+
165
@register_lowering(torch.ops.mkldnn._linear_pointwise)
def linear_unary(
    x: TensorBox,
    w: TensorBox,
    b: TensorBox,
    attr,
    scalars,
    algorithm,
    layout=None,
):
    """Lower ``mkldnn._linear_pointwise`` by selecting among candidate kernels.

    Inputs with more than two dims are flattened to 2-D first and the
    leading dims are restored on the result. Under max-autotune the C++
    packed GEMM template may contribute choices. The extern
    ``mkldnn::_linear_pointwise`` kernel is added when no choice exists yet
    or when ATen GEMM kernels are enabled.
    """
    x_size = x.get_size()
    flattened = len(x_size) > 2
    if flattened:
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
    if b is not None:
        b = ir.ExternKernel.realize_input(b)

    def operands():
        # Built fresh on every use so that each consumer gets its own list
        # and sees the current ``x`` (it may be rebound by ``mm_args``).
        return [x, w] if b is None else [x, w, b]

    choices: List[ChoiceCaller] = []
    if use_max_autotune():
        transposed_w = permute(w, [1, 0])
        *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
        if use_cpp_packed_gemm_template(layout, x, transposed_w):

            def epilogue_creator(buf):
                return create_epilogue_with_attr(
                    buf, attr, scalars=scalars, algorithm=algorithm
                )

            template_kwargs = {
                "has_bias": b is not None,
                "trans_w": True,
                "epilogue_creator": None if attr == "none" else epilogue_creator,
            }
            if b is not None:
                template_kwargs["input_indices"] = [2, 0, 1]  # type: ignore[assignment]
            CppPackedGemmTemplate.add_choices(
                choices,
                layout,
                operands(),
                **template_kwargs,  # type: ignore[arg-type]
            )
    if not choices or use_aten_gemm_kernels():
        extern_kwargs = {"attr": attr, "scalars": scalars, "algorithm": algorithm}
        if b is None:
            extern_kwargs["B"] = None
        choices.append(
            aten_mkldnn_linear_unary.bind(
                operands(),
                layout,
                **extern_kwargs,
            )
        )
    assert w.get_name() in V.graph.constants

    def load_weight_constant(buf):
        return V.graph.constants[buf.get_name()]

    # Input 1 (the weight) is generated from the graph constant of the same name.
    result = autotune_select_algorithm(
        "linear_unary",
        choices,
        operands(),
        layout,
        input_gen_fns={1: load_weight_constant},
    )
    if flattened:
        result = view(result, (*x_size[:-1], result.get_size()[-1]))
    return result
230
+
231
@register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
def linear_binary(
    x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None
):
    """Lower ``mkldnn._linear_pointwise.binary`` by selecting among candidate kernels.

    ``x`` and ``y`` are flattened to 2-D when they have more than two dims;
    the leading dims of ``x`` are restored on the result. Under max-autotune
    the C++ packed GEMM template may contribute choices. The extern
    ``mkldnn::_linear_pointwise`` binary kernel is added when no choice
    exists yet or when ATen GEMM kernels are enabled.
    """
    x_size = x.get_size()
    flattened = len(x_size) > 2
    if flattened:
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
    y_size = y.get_size()
    if len(y_size) > 2:
        y = view(y, [-1, y_size[-1]])
    if b is not None:
        b = ir.ExternKernel.realize_input(b)

    def operands():
        # Built fresh on every use so that each consumer gets its own list
        # and sees the current ``x`` / ``y`` (``mm_args`` may rebind them).
        return [x, y, w] if b is None else [x, y, w, b]

    choices: List[ChoiceCaller] = []
    if use_max_autotune():
        transposed_w = permute(w, [1, 0])
        *_, layout, x, transposed_w, y = mm_args(x, transposed_w, y, layout=layout)
        if use_cpp_packed_gemm_template(layout, x, transposed_w):

            def epilogue_creator(buf):
                return create_epilogue_with_attr(buf, attr, other=y)

            template_kwargs = {
                "has_bias": b is not None,
                "trans_w": True,
                "epilogue_creator": epilogue_creator,
                "input_indices": [0, 2, 1] if b is None else [3, 0, 2, 1],
            }
            CppPackedGemmTemplate.add_choices(
                choices,
                layout,
                operands(),
                **template_kwargs,  # type: ignore[arg-type]
            )
    if not choices or use_aten_gemm_kernels():
        extern_kwargs = {"attr": attr}
        if b is None:
            extern_kwargs["B"] = None
        choices.append(
            aten_mkldnn_linear_binary.bind(
                operands(),
                layout,
                **extern_kwargs,
            )
        )
    assert w.get_name() in V.graph.constants

    def load_weight_constant(buf):
        return V.graph.constants[buf.get_name()]

    # Input 2 (the weight) is generated from the graph constant of the same name.
    result = autotune_select_algorithm(
        "linear_binary",
        choices,
        operands(),
        layout,
        input_gen_fns={2: load_weight_constant},
    )
    if flattened:
        result = view(result, (*x_size[:-1], result.get_size()[-1]))
    return result
292
+
293
@register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
def convolution_transpose_unary(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    padding,
    output_padding,
    stride,
    dilation,
    groups,
    attr,
    scalars,
    algorithm,
):
    """Lower ``mkldnn._convolution_transpose_pointwise`` to a boxed ``ConvolutionTransposeUnary`` node."""
    conv_node = mkldnn_ir.ConvolutionTransposeUnary.create(
        x,
        weight,
        bias,
        padding,
        output_padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    )
    return TensorBox.create(conv_node)
322
+
323
@register_lowering(aten.mkldnn_rnn_layer.default)
def mkldnn_rnn_layer(
    x: TensorBox,
    w0: TensorBox,
    w1: TensorBox,
    w2: TensorBox,
    w3: TensorBox,
    hx: TensorBox,
    cx: TensorBox,
    reverse: bool,
    batch_sizes: List[int],
    mode: int,
    hidden_size: int,
    num_layers: int,
    has_biases: bool,
    bidirectional: bool,
    batch_first: bool,
    train: bool,
):
    """Lower ``aten.mkldnn_rnn_layer`` and box every output of ``MkldnnRnnLayer.create``."""
    rnn_outputs = mkldnn_ir.MkldnnRnnLayer.create(
        x,
        w0,
        w1,
        w2,
        w3,
        hx,
        cx,
        reverse,
        batch_sizes,
        mode,
        hidden_size,
        num_layers,
        has_biases,
        bidirectional,
        batch_first,
        train,
    )
    # tree_map applies TensorBox.create to each leaf of the returned structure.
    return pytree.tree_map(TensorBox.create, rnn_outputs)
363
+
364
@register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None)
def qconvolution_unary(
    x: TensorBox,
    x_scale,
    x_zp,
    packed_weight: TensorBox,
    w_scale: TensorBox,
    w_zp: TensorBox,
    bias: TensorBox,
    stride,
    padding,
    dilation,
    groups,
    o_inv_scale,
    o_zero_point,
    output_dtype,
    attr,
    scalars,
    algorithm,
):
    """Lower ``onednn.qconv2d_pointwise`` to a boxed ``QConvPointWisePT2E`` node."""
    qconv_node = mkldnn_ir.QConvPointWisePT2E.create(
        x,
        x_scale,
        x_zp,
        packed_weight,
        w_scale,
        w_zp,
        bias,
        stride,
        padding,
        dilation,
        groups,
        o_inv_scale,
        o_zero_point,
        output_dtype,
        attr,
        scalars,
        algorithm,
    )
    return TensorBox.create(qconv_node)
405
+
406
@register_lowering(
    torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None
)
def qconvolution_binary(
    x: TensorBox,
    x_scale,
    x_zp,
    accum: TensorBox,
    accum_scale,
    accum_zp,
    packed_weight: TensorBox,
    w_scale: TensorBox,
    w_zp: TensorBox,
    bias: TensorBox,
    stride,
    padding,
    dilation,
    groups,
    o_inv_scale,
    o_zero_point,
    output_dtype,
    binary_attr,
    alpha,
    unary_attr,
    unary_scalars,
    unary_algorithmm,
):
    """Lower ``onednn.qconv2d_pointwise.binary`` to a boxed ``QConvPointWiseBinaryPT2E`` result.

    For the "sum" post-op with float32/bfloat16 output, ``accum`` is first
    cast to ``output_dtype`` when its own float dtype differs.
    """
    float_dtypes = [torch.float32, torch.bfloat16]
    if binary_attr == "sum" and output_dtype in float_dtypes:
        accum_dtype = accum.get_dtype()
        if accum_dtype in float_dtypes and accum_dtype != output_dtype:
            # For int8-mixed-bf16 quantization with in-place add, the accum
            # dtype can be float32 while the output dtype is bfloat16.
            # Since accum is changed in place by the sum post-op, convert
            # its dtype here.
            accum = to_dtype(accum, output_dtype)
    qconv_node = mkldnn_ir.QConvPointWiseBinaryPT2E.create(
        x,
        x_scale,
        x_zp,
        accum,
        accum_scale,
        accum_zp,
        packed_weight,
        w_scale,
        w_zp,
        bias,
        stride,
        padding,
        dilation,
        groups,
        o_inv_scale,
        o_zero_point,
        output_dtype,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithmm,
    )
    return TensorBox.create(qconv_node)
470
+
471
@register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None)
def qlinear_unary(
    x: TensorBox,
    x_scale,
    x_zp,
    packed_weight: TensorBox,
    w_scale: TensorBox,
    w_zp: TensorBox,
    bias: TensorBox,
    o_scale,
    o_zero_point,
    output_dtype,
    attr,
    scalars,
    algorithm,
    layout=None,
):
    """Lower ``onednn.qlinear_pointwise`` (quantized linear + optional unary post op).

    Inputs with more than two dimensions are flattened to 2D for the GEMM and
    the result is reshaped back before returning. Scalar ``x_scale``/``x_zp``
    are converted into graph tensor constants; tensor ones are realized.

    Under max-autotune a C++ packed GEMM template choice is added when
    ``x_zp`` and ``w_zp`` are constant buffers, ``x_zp`` is zero-dimensional,
    ``w_zp`` is all zeros and ``use_cpp_packed_gemm_template`` accepts the
    problem. The ATen oneDNN kernel is added as a choice when no template
    choice exists or ``use_aten_gemm_kernels()`` is true. The final kernel is
    picked by ``autotune_select_algorithm``.

    Returns:
        The selected lowering result, reshaped to ``(*x_size[:-1], N)`` when
        the original input had more than two dimensions.
    """
    x_size = x.get_size()
    if len(x_size) > 2:
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
    if not isinstance(x_scale, ir.TensorBox):
        assert type(x_scale) == float
        x_scale = V.graph.add_tensor_constant(
            torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
        )
    else:
        x_scale.realize()
    if not isinstance(x_zp, ir.TensorBox):
        assert type(x_zp) == int
        x_zp = V.graph.add_tensor_constant(
            torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
        )
    else:
        x_zp.realize()

    # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer
    # Refer to https://github.com/pytorch/pytorch/blob
    # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
    w_scale.realize()
    w_zp.realize()
    if w_zp.get_dtype() != torch.int32 and isinstance(
        ir.InputsKernel.unwrap_storage_for_input(w_zp),
        ir.ConstantBuffer,
    ):
        # W_zp might be a ConstantBuffer with int64, convert it to int32.
        # ``.to(torch.int32)`` already yields a fresh int32 tensor here (the
        # dtype is known to differ), so it is registered directly instead of
        # being copy-constructed again through ``torch.tensor``.
        w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
        w_zp = V.graph.add_tensor_constant(w_zp_tensor, name=w_zp.get_name())

    bias_dtype = None if bias is None else bias.get_dtype()

    choices: List[ChoiceCaller] = []
    if use_max_autotune():
        *_, layout, x, packed_weight = mm_args(
            x, packed_weight, layout=layout, out_dtype=output_dtype
        )
        if (
            isinstance(
                ir.InputsKernel.unwrap_storage_for_input(x_zp),
                ir.ConstantBuffer,
            )
            and len(x_zp.get_layout().size) == 0  # Per tensor quant of act
            and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            )
            and torch.equal(
                torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                V.graph.constants[w_zp.get_name()],
            )  # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
            and use_cpp_packed_gemm_template(layout, x, packed_weight)
        ):
            # Column sums of the dense weight, used by the epilogue to remove
            # the contribution of the activation zero point.
            W_tensor = V.graph.constants[packed_weight.get_name()].to_dense()
            weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
            weight_compens = V.graph.add_tensor_constant(
                weight_compens_tensor,
                name=packed_weight.get_name() + "_BMatrixCompens",
            )

            def epilogue_creator(input_buffer):
                """Build the pointwise epilogue applied to the GEMM output."""
                # Epilogue to convert from s32 to f32 for u8s8f32
                assert output_dtype in [
                    torch.float32,
                    torch.bfloat16,
                    torch.uint8,
                ]
                input_loader = input_buffer.make_loader()
                weight_compens_loader = weight_compens.make_loader()
                x_scale_loader = x_scale.make_loader()
                w_scale_loader = w_scale.make_loader()
                x_zp_loader = x_zp.make_loader()
                bias_loader = None
                if bias is not None:
                    bias_loader = bias.make_loader()

                def inner_fn(index):
                    acc = input_loader(index)
                    # MicroKernel Output is with int32
                    # cvt to FP32 before doing compensation
                    acc = ops.to_dtype(acc, torch.float32)
                    weight_compens_index = (index[-1],)
                    _x_scale = x_scale_loader(())
                    _x_zp = x_zp_loader(())
                    _w_scale = w_scale_loader(weight_compens_index)
                    _weight_compo = weight_compens_loader(weight_compens_index)
                    # Step 1: Doing compensation to cvt fp32
                    #   acc * x_scale * w_scale
                    #     - x_scale * w_scale * x_zp * weight_compens
                    temp = ops.mul(
                        ops.mul(
                            acc,
                            _x_scale,
                        ),
                        _w_scale,
                    )
                    temp = ops.sub(
                        temp,
                        ops.mul(
                            ops.mul(
                                ops.mul(
                                    _x_scale,
                                    _w_scale,
                                ),
                                _x_zp,
                            ),
                            _weight_compo,
                        ),
                    )
                    # Step 2: add Bias if applicable
                    if bias is not None:
                        _bias = bias_loader(weight_compens_index)
                        assert bias_dtype in [torch.float32, torch.bfloat16]
                        if bias_dtype == torch.bfloat16:
                            _bias = ops.to_dtype(_bias, torch.float32)
                        temp = ops.add(temp, _bias)

                    return temp

                output_buf = ir.Pointwise(
                    device=input_buffer.get_device(),
                    dtype=torch.float32,  # Hardcode to FP32 for u8s8f32
                    inner_fn=inner_fn,
                    ranges=input_buffer.get_size(),
                )

                # Step 3: Doing the unary post op fusion
                if attr != "none":
                    output_buf = create_epilogue_with_attr(
                        output_buf, attr, scalars=scalars, algorithm=algorithm
                    )

                # Step 4: Cast output to Target Dtype
                if output_dtype == torch.bfloat16:
                    output_cast_loader = output_buf.make_loader()

                    def inner_fn_cast_output_to_bf16(index):
                        value = output_cast_loader(index)
                        return ops.to_dtype(value, output_dtype)

                    output_buf = ir.Pointwise(
                        device=output_buf.get_device(),
                        dtype=output_dtype,
                        inner_fn=inner_fn_cast_output_to_bf16,
                        ranges=output_buf.get_size(),
                    )
                elif output_dtype == torch.uint8:
                    from .lowering import _create_constants

                    requant_input_loader = output_buf.make_loader()

                    def inner_fn_requant(index, scale, zero_point):
                        value = requant_input_loader(index)
                        inv_scale, zero_point = _create_constants(
                            1.0 / scale, zero_point, dtype=torch.float32
                        )
                        val = ops.round(value * inv_scale) + zero_point
                        # Clamp to the uint8 range before the final cast.
                        qmin, qmax = _create_constants(
                            0, 255, dtype=torch.float32
                        )
                        clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                        return ops.to_dtype(clamped, torch.uint8)

                    output_buf = ir.Pointwise(
                        device=output_buf.get_device(),
                        dtype=output_dtype,
                        inner_fn=functools.partial(
                            inner_fn_requant,
                            scale=float(o_scale),
                            zero_point=int(o_zero_point),
                        ),
                        ranges=output_buf.get_size(),
                    )

                return output_buf

            assert x.get_dtype() == torch.uint8
            CppPackedGemmTemplate.add_choices(
                choices,
                layout,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                has_bias=bias is not None,
                epilogue_creator=epilogue_creator,
                input_indices=[0, 3, 1, 2, 4, 5]
                if bias is None
                else [6, 0, 3, 1, 2, 4, 5],
            )
    if len(choices) == 0 or use_aten_gemm_kernels():
        kwargs = dict(
            output_scale=o_scale,
            output_zero_point=o_zero_point,
            output_dtype=output_dtype,
            post_op_name=attr,
            post_op_args=scalars,
            post_op_algorithm=algorithm,
        )
        if bias is None:
            kwargs["bias"] = None
        choices.append(
            aten_mkldnn_qlinear_unary.bind(
                (x, x_scale, x_zp, packed_weight, w_scale, w_zp)
                if bias is None
                else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias),
                layout,
                **kwargs,
            )
        )
    assert packed_weight.get_name() in V.graph.constants
    # For these input positions, autotuning uses the tensors registered in
    # the graph constants.
    input_gen_fns = {
        3: lambda x: V.graph.constants[x.get_name()],
        4: lambda x: V.graph.constants[x.get_name()],
        5: lambda x: V.graph.constants[x.get_name()],
    }
    if bias is not None:
        # Index 6 only exists in the input list when a bias is passed.
        input_gen_fns[6] = lambda x: V.graph.constants[x.get_name()]  # For bias
    result = autotune_select_algorithm(
        "qlinear_unary",
        choices,
        [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
        if bias is None
        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
        layout,
        input_gen_fns=input_gen_fns,
    )
    if len(x_size) > 2:
        # Restore the leading dimensions flattened at the top.
        result = view(result, (*x_size[:-1], result.get_size()[-1]))
    return result
721
+
722
@register_lowering(
    torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None
)
@register_lowering(
    torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None
)
def qlinear_binary(
    x: TensorBox,
    x_scale,
    x_zp,
    packed_weight: TensorBox,
    w_scale: TensorBox,
    w_zp: TensorBox,
    x2: TensorBox,
    bias: TensorBox,
    o_scale,
    o_zero_point,
    output_dtype,
    x2_scale,
    x2_zp,
    binary_attr,
    alpha,
    unary_attr,
    unary_scalars,
    unary_algorithmm,
    layout=None,
):
    """Lower ``onednn.qlinear_pointwise`` binary variants (linear + binary + unary).

    ``x`` and ``x2`` must have the same rank. For ``binary_attr == "add"``
    inputs with more than two dimensions are flattened to 2D and the result is
    reshaped back before returning. For ``binary_attr == "sum"`` with a
    float32/bfloat16 output and accumulator, ``x2`` is cast to
    ``output_dtype`` when the dtypes differ; otherwise the dtypes must already
    match.

    The C++ packed GEMM template is only considered under max-autotune for
    ``binary_attr == "add"``; the ATen oneDNN kernel is added as a choice when
    no template choice exists or ``use_aten_gemm_kernels()`` is true. The
    final kernel is picked by ``autotune_select_algorithm``.
    """
    x_size = x.get_size()
    x2_size = x2.get_size()
    assert len(x_size) == len(x2_size)
    if len(x_size) > 2 and binary_attr == "add":
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
        x2 = view(x2, [-1, x2_size[-1]])
    if not isinstance(x_scale, ir.TensorBox):
        assert type(x_scale) == float
        x_scale = V.graph.add_tensor_constant(
            torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
        )
    else:
        x_scale.realize()
    if not isinstance(x_zp, ir.TensorBox):
        assert type(x_zp) == int
        x_zp = V.graph.add_tensor_constant(
            torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
        )
    else:
        x_zp.realize()

    # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer
    # Refer to https://github.com/pytorch/pytorch/blob
    # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
    w_scale.realize()
    w_zp.realize()
    if w_zp.get_dtype() != torch.int32 and isinstance(
        ir.InputsKernel.unwrap_storage_for_input(w_zp),
        ir.ConstantBuffer,
    ):
        # ``.to(torch.int32)`` already yields a fresh int32 tensor here (the
        # dtype is known to differ), so it is registered directly instead of
        # being copy-constructed again through ``torch.tensor``.
        w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
        w_zp = V.graph.add_tensor_constant(w_zp_tensor, name=w_zp.get_name())
    if binary_attr == "sum":
        if output_dtype in [
            torch.float32,
            torch.bfloat16,
        ] and x2.get_dtype() in [torch.float32, torch.bfloat16]:
            if x2.get_dtype() != output_dtype:
                # For int8-mixed-bf16 quantization and inplace add,
                # there is case when accum dtype is float32 but output dtype is bfloat16.
                # Since the accum will be inplaced changed with post op sum,
                # we will do accum dtype convertion here.
                x2 = to_dtype(x2, output_dtype)
        else:
            assert (
                x2.get_dtype() == output_dtype
            ), "dtype of accum for qlinear post op sum should be the same as output"
    x2_dtype = x2.get_dtype()
    bias_dtype = bias.get_dtype() if bias is not None else None
    choices: List[ChoiceCaller] = []
    if (
        use_max_autotune() and binary_attr == "add"
    ):  # <TODO> Support inplace sum fusion
        *_, layout, x, packed_weight, x2 = mm_args(
            x, packed_weight, x2, layout=layout, out_dtype=output_dtype
        )
        if (
            isinstance(
                ir.InputsKernel.unwrap_storage_for_input(x_zp),
                ir.ConstantBuffer,
            )
            and len(x_zp.get_layout().size) == 0  # Per tensor quant of act
            and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            )
            and torch.equal(
                torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                V.graph.constants[w_zp.get_name()],
            )  # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
            and use_cpp_packed_gemm_template(layout, x, packed_weight)
        ):
            # Column sums of the dense weight, used by the epilogue to remove
            # the contribution of the activation zero point.
            W_tensor = V.graph.constants[packed_weight.get_name()]
            W_tensor = W_tensor.to_dense()
            weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
            weight_compens = V.graph.add_tensor_constant(
                weight_compens_tensor,
                name=packed_weight.get_name() + "_BMatrixCompens",
            )

            def epilogue_creator(input_buffer):
                """Build the pointwise epilogue applied to the GEMM output."""
                # Epilogue to convert from s32 to f32 for u8s8f32
                assert output_dtype in [
                    torch.float32,
                    torch.bfloat16,
                    torch.uint8,
                ]

                input_loader = input_buffer.make_loader()
                x2_loader = x2.make_loader()
                weight_compens_loader = weight_compens.make_loader()
                x_scale_loader = x_scale.make_loader()
                w_scale_loader = w_scale.make_loader()
                x_zp_loader = x_zp.make_loader()
                bias_loader = None
                if bias is not None:
                    bias_loader = bias.make_loader()

                def inner_fn(index):
                    acc = input_loader(index)
                    _x2 = x2_loader(index)
                    _x_scale = x_scale_loader(())
                    _x_zp = x_zp_loader(())

                    # MicroKernel Output is with int32
                    # cvt to FP32 before doing compensation
                    acc = ops.to_dtype(acc, torch.float32)
                    weight_compens_index = (index[-1],)
                    _w_scale = w_scale_loader(weight_compens_index)
                    _weight_compens = weight_compens_loader(
                        weight_compens_index
                    )
                    # Step 1: Doing compensation to cvt fp32
                    #   acc * x_scale * w_scale
                    #     - x_scale * w_scale * x_zp * weight_compens
                    temp = ops.mul(
                        ops.mul(
                            acc,
                            _x_scale,
                        ),
                        _w_scale,
                    )
                    temp = ops.sub(
                        temp,
                        ops.mul(
                            ops.mul(
                                ops.mul(
                                    _x_scale,
                                    _w_scale,
                                ),
                                _x_zp,
                            ),
                            _weight_compens,
                        ),
                    )

                    # Step 2: add Bias if applicable
                    if bias is not None:
                        _bias = bias_loader(weight_compens_index)
                        assert bias_dtype in [torch.float32, torch.bfloat16]
                        if bias_dtype == torch.bfloat16:
                            _bias = ops.to_dtype(_bias, torch.float32)
                        temp = ops.add(temp, _bias)

                    # Step 3: Binary add
                    assert x2_dtype in [torch.float32, torch.bfloat16]
                    if x2_dtype == torch.bfloat16:
                        _x2 = ops.to_dtype(_x2, torch.float32)
                    temp = ops.add(temp, _x2)

                    return temp

                output_buf = ir.Pointwise(
                    device=input_buffer.get_device(),
                    dtype=torch.float32,  # Hardcode to FP32 for u8s8f32
                    inner_fn=inner_fn,
                    ranges=input_buffer.get_size(),
                )

                # Step 4: Unary post op if has
                if unary_attr != "none":
                    output_buf = create_epilogue_with_attr(
                        output_buf,
                        unary_attr,
                        scalars=unary_scalars,
                        algorithm=unary_algorithmm,
                    )

                # Step 5: Cast output to Target Dtype
                if output_dtype == torch.bfloat16:
                    output_cast_loader = output_buf.make_loader()

                    def inner_fn_cast_output_to_bf16(index):
                        value = output_cast_loader(index)
                        return ops.to_dtype(value, output_dtype)

                    output_buf = ir.Pointwise(
                        device=output_buf.get_device(),
                        dtype=output_dtype,
                        inner_fn=inner_fn_cast_output_to_bf16,
                        ranges=output_buf.get_size(),
                    )
                elif output_dtype == torch.uint8:
                    from .lowering import _create_constants

                    requant_input_loader = output_buf.make_loader()

                    def inner_fn_requant(index, scale, zero_point):
                        value = requant_input_loader(index)
                        inv_scale, zero_point = _create_constants(
                            1.0 / scale, zero_point, dtype=torch.float32
                        )
                        val = ops.round(value * inv_scale) + zero_point
                        # Clamp to the uint8 range before the final cast.
                        qmin, qmax = _create_constants(
                            0, 255, dtype=torch.float32
                        )
                        clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                        return ops.to_dtype(clamped, torch.uint8)

                    output_buf = ir.Pointwise(
                        device=output_buf.get_device(),
                        dtype=torch.uint8,
                        inner_fn=functools.partial(
                            inner_fn_requant,
                            scale=float(o_scale),
                            zero_point=int(o_zero_point),
                        ),
                        ranges=output_buf.get_size(),
                    )

                return output_buf

            CppPackedGemmTemplate.add_choices(
                choices,
                layout,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
                has_bias=bias is not None,
                epilogue_creator=epilogue_creator,
                # Reorder bias and x2
                input_indices=[0, 3, 1, 2, 4, 5, 6]
                if bias is None
                else [7, 0, 3, 1, 2, 4, 5, 6],
            )

    if len(choices) == 0 or use_aten_gemm_kernels():
        kwargs = dict(
            output_scale=o_scale,
            output_zero_point=o_zero_point,
            output_dtype=output_dtype,
            other_scale=x2_scale,
            other_zp=x2_zp,
            binary_post_op=binary_attr,
            binary_alpha=alpha,
            unary_post_op=unary_attr,
            unary_post_op_args=unary_scalars,
            unary_post_op_algorithm=unary_algorithmm,
        )
        if bias is None:
            kwargs["bias"] = None
        choices.append(
            aten_mkldnn_qlinear_binary.bind(
                (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2)
                if bias is None
                else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias),
                layout,
                **kwargs,
            )
        )
    assert packed_weight.get_name() in V.graph.constants
    # For these input positions, autotuning uses the tensors registered in
    # the graph constants.
    input_gen_fns = {
        3: lambda x: V.graph.constants[x.get_name()],
        4: lambda x: V.graph.constants[x.get_name()],
        5: lambda x: V.graph.constants[x.get_name()],
    }
    if bias is not None:
        input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()]  # For bias
    result = autotune_select_algorithm(
        "qlinear_binary",
        choices,
        [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2]
        if bias is None
        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias],
        layout,
        input_gen_fns=input_gen_fns,
    )
    if len(x_size) > 2 and binary_attr == "add":
        # Restore the leading dimensions flattened at the top.
        result = view(result, (*x_size[:-1], result.get_size()[-1]))
    return result
1024
+
1025
+ if torch._C.has_mkl:
1026
+ aten_mkl_linear = ExternKernelChoice(
1027
+ torch.ops.mkl._mkl_linear,
1028
+ "mkl::_mkl_linear",
1029
+ has_out_variant=False,
1030
+ kernel_creator=mkldnn_ir.MKLPackedLinear.create,
1031
+ )
1032
+ cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear)
1033
+
1034
@register_lowering(torch.ops.mkl._mkl_linear)
def mkl_packed_linear(
    x: TensorBox,
    packed_w: TensorBox,
    orig_w: TensorBox,
    b: Optional[TensorBox],
    batch_size,
    *,
    layout=None,
):
    """Lower ``mkl._mkl_linear`` by autotuning over the available kernels.

    Under max-autotune a C++ packed GEMM template choice is offered when
    ``use_cpp_packed_gemm_template`` accepts the problem; the ATen MKL kernel
    is offered when no template choice exists or ``use_aten_gemm_kernels()``
    is true. The kernels are bound without a bias (``B=None``); ``b`` is
    added to the selected result afterwards when it is not ``None``.
    """
    candidates: List[ChoiceCaller] = []
    if use_max_autotune():
        w_t = permute(orig_w, [1, 0])
        *_, layout, x, w_t = mm_args(x, w_t, layout=layout)
        if use_cpp_packed_gemm_template(layout, x, w_t):
            CppPackedGemmTemplate.add_choices(
                candidates,
                layout,
                [x, packed_w, orig_w],
                trans_w=True,
                input_indices=[0, 2],
            )

    if not candidates or use_aten_gemm_kernels():
        aten_choice = aten_mkl_linear.bind(
            (x, packed_w, orig_w), layout, B=None, batch_size=batch_size
        )
        candidates.append(aten_choice)

    for weight in (packed_w, orig_w):
        assert weight.get_name() in V.graph.constants

    def graph_constant(node):
        return V.graph.constants[node.get_name()]

    # packed_w is a mkldnn tensor which we can't generate directly
    # so we use the weights from the original tensor in autotune.
    tuned = autotune_select_algorithm(
        "packed_linear",
        candidates,
        [x, packed_w, orig_w],
        layout,
        input_gen_fns={1: graph_constant, 2: graph_constant},
    )
    return tuned if b is None else add(tuned, b)
1084
+
1085
+ add_needs_realized_inputs(cpu_needs_realized_inputs)
1086
+ else:
1087
+ pass
.venv/lib/python3.11/site-packages/torch/_inductor/package/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .package import load_package, package_aoti
.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-311.pyc ADDED
Binary file (532 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/package/__pycache__/package.cpython-311.pyc ADDED
Binary file (15.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_inductor/package/build_package.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ build_package_contents = """
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from torch._inductor.package.package import compile_so
6
+
7
+ curr_dir = Path(__file__).parent
8
+ aoti_files = [
9
+ os.path.join(root, file)
10
+ for root, dirs, files in os.walk(curr_dir)
11
+ for file in files
12
+ ]
13
+
14
+ output_so = compile_so(curr_dir, aoti_files, curr_dir)
15
+ """
.venv/lib/python3.11/site-packages/torch/_inductor/package/package.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import os
4
+ import shlex
5
+ import subprocess
6
+ import tempfile
7
+ import zipfile
8
+ from pathlib import Path
9
+ from typing import Callable, List, Optional, Union
10
+
11
+ import torch
12
+ import torch._inductor
13
+ import torch.utils._pytree as pytree
14
+ from torch._inductor import config, exc
15
+ from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
16
+ from torch.export._tree_utils import reorder_kwargs
17
+
18
+ from .build_package import build_package_contents
19
+ from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION
20
+
21
+
22
class PT2ArchiveWriter:
    """Context manager that writes entries into a PT2 (zip) archive.

    Entering opens ``archive_path`` for writing with ``ZIP_STORED`` and
    records the ``version`` and ``archive_format`` entries; exiting closes the
    archive and resets the handle.
    """

    def __init__(self, archive_path: str) -> None:
        self.archive_path: str = archive_path
        # Open zip handle; ``None`` whenever the writer is not entered.
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> "PT2ArchiveWriter":
        assert self.archive_file is None
        self.archive_file = zipfile.ZipFile(
            self.archive_path, "w", compression=zipfile.ZIP_STORED
        )
        try:
            self.writestr("version", str(ARCHIVE_VERSION))
            self.writestr("archive_format", "pt2")
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so release the
            # handle here instead of leaking an open file.
            self.archive_file.close()
            self.archive_file = None
            raise
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        assert self.archive_file is not None
        self.archive_file.close()
        self.archive_file = None
        return None

    def writestr(self, name: str, data: Union[bytes, str]) -> None:
        """Write ``data`` into the archive under the entry ``name``."""
        assert self.archive_file is not None
        self.archive_file.writestr(name, data)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert Path(file_path).is_file(), f"{file_path} is not a valid file path"
        assert self.archive_file is not None
        self.archive_file.write(file_path, arcname=name)
55
+
56
+
57
class PT2ArchiveReader:
    """Context manager that reads entries from a PT2 (zip) archive.

    Entering opens ``archive_path`` for reading; exiting closes the archive
    and resets the handle, so the accessor methods fail their
    ``archive_file is not None`` assertion once the reader has been exited.
    """

    def __init__(self, archive_path: str) -> None:
        self.archive_path: str = archive_path
        # Open zip handle; ``None`` whenever the reader is not entered.
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> "PT2ArchiveReader":
        self.archive_file = zipfile.ZipFile(
            self.archive_path, "r", compression=zipfile.ZIP_STORED
        )
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        if self.archive_file is not None:
            self.archive_file.close()
            # Drop the closed handle, mirroring PT2ArchiveWriter.__exit__.
            self.archive_file = None
        return None

    def read(self, name: str) -> bytes:
        """Return the raw bytes of the archive entry ``name``."""
        assert self.archive_file is not None
        return self.archive_file.read(name)

    def extract_to_path(self, member: str, path: str) -> str:
        """Extract ``member`` under directory ``path``; return the created path."""
        assert self.archive_file is not None
        return self.archive_file.extract(member, path)

    def extractall(self, path: str) -> None:
        """Extract every archive entry under directory ``path``."""
        assert self.archive_file is not None
        self.archive_file.extractall(path)

    def get_file_names(self) -> List[str]:
        """Return the names of all entries in the archive."""
        assert self.archive_file is not None
        return self.archive_file.namelist()
88
+
89
+
90
+ def _run_command_and_check(cmd: str) -> None:
91
+ cmd = shlex.split(cmd)
92
+ try:
93
+ subprocess.run(cmd, check=True)
94
+ except subprocess.CalledProcessError as e:
95
+ raise exc.CppCompileError(cmd, e.output) from e
96
+
97
+
98
def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:
    """Compile and link the AOTInductor sources into a shared library.

    Args:
        aoti_dir: Directory the entries of ``aoti_files`` are joined onto.
        aoti_files: Candidate file names; the first one ending in ``.cpp`` is
            compiled and the first one ending in ``.o`` is linked in.
        so_path: Used as the link output directory; its last path component
            is used as the library name.

    Returns:
        The path of the linked library as reported by the builder.

    Raises:
        RuntimeError: if no ``.cpp`` or ``.o`` file is listed.
    """

    def get_aoti_file_with_suffix(suffix: str) -> str:
        for file in aoti_files:
            if file.endswith(suffix):
                return file
        raise RuntimeError(f"Unable to find file with suffix {suffix}")

    # Compile all the files into a .so
    cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp"))
    consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o"))

    file_name = os.path.splitext(cpp_file)[0]

    # Parse compile flags and build the .o file
    with open(file_name + "_compile_flags.json") as f:
        compile_flags = json.load(f)

    compile_options = BuildOptionsBase(**compile_flags)
    object_builder = CppBuilder(
        name=file_name,
        sources=cpp_file,
        BuildOption=compile_options,
    )
    compile_cmd = object_builder.get_command_line()
    output_o = object_builder.get_target_file_path()

    _run_command_and_check(compile_cmd)

    # Parse linker flags and build the .so file
    with open(file_name + "_linker_flags.json") as f:
        linker_flags = json.load(f)

    linker_options = BuildOptionsBase(**linker_flags)
    so_builder = CppBuilder(
        name=os.path.split(so_path)[-1],
        sources=[output_o, consts_o],
        BuildOption=linker_options,
        output_dir=so_path,
    )
    link_cmd = so_builder.get_command_line()
    output_so = so_builder.get_target_file_path()

    _run_command_and_check(link_cmd)

    # mmapped weights
    serialized_weights_filename = file_name + "_serialized_weights.bin"
    # ``serialized_weights_filename`` is derived from a path joined with
    # ``aoti_dir``, so compare it against the entries resolved the same way
    # (archive-relative names would otherwise never match). The raw entries
    # are kept as well so previously matching inputs still match.
    known_files = set(aoti_files)
    known_files.update(os.path.join(aoti_dir, file) for file in aoti_files)
    if serialized_weights_filename in known_files:
        with open(serialized_weights_filename, "rb") as f_weights:
            serialized_weights = f_weights.read()

        with open(output_so, "a+b") as f_so:
            so_size = f_so.tell()
            # Page align the weights
            f_so.write(b" " * (16384 - so_size % 16384))
            f_so.write(serialized_weights)

    return output_so
155
+
156
+
157
def package_aoti(aoti_output_dir: str) -> str:
    """
    Package the AOTInductor output directory according to the configured output path.

    A ``build_package.py`` script and a ``Makefile`` that invokes it are
    always written into ``aoti_output_dir`` first. If
    ``config.aot_inductor.output_path`` ends with ``.pt2``, every path matched
    by ``{aoti_output_dir}/*`` is stored in a PT2 archive at that path and the
    archive path is returned; a ``.so`` output path is rejected; for any other
    output path ``aoti_output_dir`` itself is returned.
    """
    # Add a makefile and python script
    script_name = "build_package.py"
    script_path = os.path.join(aoti_output_dir, script_name)
    with open(script_path, "w") as script:
        script.write(build_package_contents)

    makefile_path = os.path.join(aoti_output_dir, "Makefile")
    with open(makefile_path, "w") as makefile:
        makefile.write(f"all:\n\tpython3 {script_name}\n")

    output_path = config.aot_inductor.output_path
    if output_path.endswith(".so"):
        raise RuntimeError(
            "Unable to save package as a .so. It should be a .pt2 format or a directory."
        )

    if not output_path.endswith(".pt2"):
        # Directly put the files into the directory, without any archiving
        return aoti_output_dir

    # Save using the PT2 packaging format
    # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
    with PT2ArchiveWriter(output_path) as archive_writer:
        for source in glob.glob(f"{aoti_output_dir}/*"):
            entry_name = f"{AOTINDUCTOR_DIR}{os.path.basename(source)}"
            archive_writer.write_file(entry_name, source)

    return output_path
191
+
192
+
193
def load_package(path: str, device: str) -> Callable:  # type: ignore[type-arg]
    """Build the shared library for a packaged model and return a callable for it.

    Args:
        path: Either a ``.pt2`` archive (extracted to a temporary directory
            and compiled from there) or a directory of AOTInductor files.
        device: ``"cpu"``, ``"cuda"`` or a string starting with ``"cuda:"``.

    Returns:
        A function that flattens its arguments according to the runner's
        input spec, runs the model and unflattens the outputs.

    Raises:
        RuntimeError: if ``path`` ends with ``.so`` or ``device`` is not
            supported.
    """
    if path.endswith(".so"):
        raise RuntimeError(
            "Unable to load .so. It should be a .pt2 format or a directory."
        )

    elif path.endswith(".pt2"):
        so_path = os.path.splitext(path)[0]
        with PT2ArchiveReader(path) as archive_reader:
            with tempfile.TemporaryDirectory() as tmp_dir:
                archive_reader.extractall(tmp_dir)
                # Only the AOTInductor payload of the archive is compiled.
                aoti_files = [
                    file
                    for file in archive_reader.get_file_names()
                    if file.startswith(AOTINDUCTOR_DIR)
                ]

                so_path = compile_so(tmp_dir, aoti_files, so_path)

    else:
        assert os.path.isdir(path), "Must specify a directory or a .pt2 file"
        aoti_files = [
            os.path.join(root, file)
            for root, dirs, files in os.walk(path)
            for file in files
        ]
        so_path = compile_so(path, aoti_files, path)

    if device == "cpu":
        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
    elif device == "cuda" or device.startswith("cuda:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
    else:
        raise RuntimeError("Unsupported device " + device)

    def optimized(*args, **kwargs):  # type: ignore[no-untyped-def]
        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return optimized
.venv/lib/python3.11/site-packages/torch/_inductor/package/pt2_archive_constants.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARCHIVE_ROOT_NAME = "package"
2
+ ARCHIVE_FORMAT_PATH = "archive_format"
3
+ MODELS_DIR = "models/"
4
+ MODELS_FILENAME_FORMAT = "models/{}.json"
5
+ AOTINDUCTOR_DIR = "data/aotinductor/"
6
+ WEIGHTS_DIR = "data/weights/"
7
+ WEIGHT_FILENAME_PREFIX = "weight_"
8
+ CONSTANTS_DIR = "data/constants/"
9
+ TENSOR_CONSTANT_FILENAME_PREFIX = "tensor_"
10
+ CUSTOM_OBJ_FILENAME_PREFIX = "custom_obj_"
11
+ SAMPLE_INPUTS_DIR = "data/sample_inputs/"
12
+ SAMPLE_INPUTS_FILENAME_FORMAT = "data/sample_inputs/{}.pt"
13
+ EXTRA_DIR = "extra/"
14
+ MODULE_INFO_PATH = "extra/module_info.json"
15
+
16
+ ARCHIVE_VERSION = 0
.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py ADDED
@@ -0,0 +1,2005 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ """
3
+ # Inductor Pattern Matcher
4
+
5
+ The pattern matcher enables search/replace within an FX graph.
6
+
7
+ The main entrypoint to the pattern matcher is register_replacement(). Given a
8
+ search function and a replacement function this will register a replacement with
9
+ a pass (such as torch._inductor.fx_passes.joint_graph.patterns).
10
+
11
+ Internally the pattern matcher represents patterns as a graph (a DAG). Creating
12
+ new patterns manually as a graph is cumbersome and error-prone so the standard
13
+ way to create patterns (using register_replacement()) is to provide a search
14
+ function and a replacement function which is traced and converted into a graph.
15
+
16
+ Because the search functions are built somewhat generic (they tend to ignore
17
+ tensor sizes, for example) register_replacement() allows you to specify an
18
+ `extra_check` function which performs additional checks to verify that the
19
+ matched pattern fully matches before returning it.
20
+
21
+ ## Precompiled Patterns
22
+
23
+ New patterns are added using register_replacement(). Patterns added in this way
24
+ can have a compile-time overhead because they need to be traced before
25
+ use. Patterns can be precompiled and added using gen_register_replacement()
26
+ instead. To do this you call gen_register_replacement() instead of
27
+ register_replacement(). The arguments are the same except for an additional
28
+ unique name which is used as a lookup key.
29
+
30
+ ## Internals
31
+
32
+ The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr
33
+ implements a `_match` method which returns either a `Match` object for a
34
+ successful match or a `FailedMatch` object for a failure to match.
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import contextlib
40
+ import dataclasses
41
+ import functools
42
+ import importlib
43
+ import inspect
44
+ import itertools
45
+ import logging
46
+ import operator
47
+ import os
48
+ import re
49
+ import textwrap
50
+ import typing
51
+ from abc import ABC, abstractmethod
52
+ from collections import defaultdict
53
+ from pathlib import Path
54
+ from typing import (
55
+ Any,
56
+ Callable,
57
+ DefaultDict,
58
+ Dict,
59
+ Generator,
60
+ Iterable,
61
+ List,
62
+ Mapping,
63
+ NoReturn,
64
+ Optional,
65
+ Protocol,
66
+ Sequence,
67
+ Set,
68
+ Tuple,
69
+ Type,
70
+ TypeVar,
71
+ Union,
72
+ )
73
+ from typing_extensions import Self, TypeGuard
74
+
75
+ import torch
76
+ import torch._guards
77
+ import torch.fx
78
+ import torch.utils._pytree as pytree
79
+ from torch._dispatch.python import enable_python_dispatcher
80
+ from torch._dynamo.utils import counters
81
+ from torch._inductor.config import trace as trace_config
82
+ from torch._prims_common import is_integer_dtype
83
+ from torch._subclasses.fake_tensor import unset_fake_temporarily
84
+ from torch.fx.experimental.proxy_tensor import make_fx
85
+ from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
86
+ from torch.fx.immutable_collections import immutable_dict, immutable_list
87
+ from torch.fx.passes.graph_transform_observer import GraphTransformObserver
88
+
89
+ from .._functorch import config as functorch_config
90
+ from .._functorch.aot_autograd import aot_function, make_boxed_func
91
+ from .._functorch.partitioners import default_partition
92
+ from .._subclasses import FakeTensor, FakeTensorMode
93
+ from ..fx import Transformer
94
+ from . import config
95
+ from .decomposition import select_decomp_table
96
+ from .lowering import fallback_node_due_to_unsupported_type
97
+
98
+
99
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Short aliases for the torch op namespaces.
aten = torch.ops.aten
prims = torch.ops.prims

# A "constant" is deliberately untyped: anything may appear as one.
Constant = Any
# A value handed to a pattern for matching: an FX node or a constant.
NodeOrConstant = Union[Constant, torch.fx.Node]
105
+
106
+
107
class SearchFn(Protocol):
    """Structural type of a search function: a callable that has a ``__name__``."""

    __name__: str

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        ...
112
+
113
+
114
class ReplaceFn(Protocol):
    """Structural type of a replacement function: any callable."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        ...
117
+
118
+
119
class TraceFn(Protocol):
    """
    Structural type of a tracer: called with a search/replace function plus
    its arguments, it returns a ``torch.fx.GraphModule``.
    """

    def __call__(
        self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
    ) -> torch.fx.GraphModule:
        ...
124
+
125
+
126
# Generic type variable.
T = TypeVar("T")

# What's a better name for this?
# A match target: either an FX node target or a plain string.
FnsType = Union[torch.fx.node.Target, str]
130
+
131
+
132
class Multiple:
    """Type of the ``MULTIPLE`` sentinel; only a single instance may exist."""

    def __init__(self) -> None:
        # Ensure we're really a singleton: once the module-level MULTIPLE has
        # been bound, constructing any other instance trips the assertion.
        module_globals = globals()
        if "MULTIPLE" in module_globals:
            assert self is module_globals["MULTIPLE"]


# Sentinel indicating multiple quantities can be matched
MULTIPLE = Multiple()
140
+
141
+
142
class Match:
    """
    Represents a successfully matched pattern.

    The `Match` object is returned to represent a successfully matched
    pattern. Included in the Match are the pattern that was matched, the graph
    nodes matched, and any args that were used during the matching.

    The args and kwargs are specific to the type of pattern that was matched and
    provide hints about what was matched.
    """

    pattern: PatternExpr
    args: List[Any]
    kwargs: Dict[str, Any]
    nodes: List[torch.fx.Node]
    targets: Dict[_TargetExpr, torch.fx.node.Target]
    ctx: MatchContext
    replacement_graph: Optional[torch.fx.Graph]

    def __init__(
        self,
        ctx: MatchContext,
        pattern: PatternExpr,
        args: Optional[Sequence[Any]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            ctx: the match context this match was produced in.
            pattern: the pattern expression that matched.
            args: captured positional values (copied into a new list).
            kwargs: captured keyword values (stored as given when non-empty).
        """
        super().__init__()
        self.pattern = pattern
        # The input nodes that must be passed in to the result
        self.args = list(args or [])
        self.kwargs = kwargs or {}
        # The nodes matched in this expression
        self.nodes = []
        # Mapping CallFunction to the node.target
        self.targets = {}
        self.ctx = ctx
        self.replacement_graph = None

    @property
    def graph(self) -> torch.fx.Graph:
        """The graph held by this match's context."""
        return self.ctx.graph

    def extend(self, other: Match) -> None:
        """
        Merge `other` into this match in place.

        Raises:
            FailedMatch: if both matches captured the same kwarg name with
                different values.
        """
        if self.kwargs:
            for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
                if self.kwargs[key] != other.kwargs[key]:
                    raise FailedMatch("kwarg mismatch: {}", key)
        self.args.extend(other.args)
        self.nodes.extend(other.nodes)
        self.kwargs.update(other.kwargs)
        self.targets.update(other.targets)

    def bundle(self) -> Match:
        """Collapse the captured args into a single tuple entry; returns self."""
        # Wrap args in an extra list
        self.args = [tuple(self.args)] if self.args else []
        return self

    def __repr__(self) -> str:
        return f"Match(..., {self.args}, {self.kwargs})"

    def erase_nodes(self) -> None:
        """Erase the matched nodes that are not yet erased and have no users."""
        graph = self.graph
        # Walk in reverse so a later-recorded node is considered before the
        # earlier-recorded nodes it may be using.
        for n in reversed(self.nodes):
            if not n._erased and not n.users:
                graph.erase_node(n)

    def output_nodes(self) -> List[Optional[torch.fx.Node]]:
        """Graph node for each entry of ``ctx.outputs`` (None stays None)."""
        return [
            (self.ctx.pattern_to_node[p] if p is not None else None)
            for p in self.ctx.outputs
        ]

    def output_node(self) -> torch.fx.Node:
        """The first truthy entry of `output_nodes()`."""
        return next(p for p in self.output_nodes() if p)

    def replace_with_graph(
        self, replacement_graph: torch.fx.Graph, args: Sequence[Any]
    ) -> None:
        """Delegate to ReplacementPatternEntry.replace_with_graph for this match."""
        ReplacementPatternEntry.replace_with_graph(
            self, self.ctx.graph, replacement_graph, args
        )

    def replace_by_example(
        self,
        replacement_fn: ReplaceFn,
        args: Sequence[Any],
        trace_fn: Optional[TraceFn] = None,
        run_functional_passes: bool = True,
    ) -> None:
        """Replace with a graph generated by tracing the replacement_fn.

        Args:
            replacement_fn: function that is traced to build the replacement.
            args: values from the matched graph; tracing is done on each
                node's ``meta["val"]``.
            trace_fn: tracer to use. Defaults to `fwd_only` bound with
                `run_functional_passes`.
            run_functional_passes (bool). If we should run passes that
                assume functional IR (like DCE, remove_noop_ops), on the
                replacement graph.

        """
        from torch._inductor.virtualized import NullHandler, V

        # Trace under the active fake mode when there is one. The original
        # condition selected V.fake_mode even when it was None, which cannot
        # be used in a `with` statement; treat None like the NullHandler
        # placeholder and use a no-op context instead.
        fake_mode = V.fake_mode
        context = (
            contextlib.nullcontext()
            if fake_mode is None or isinstance(fake_mode, NullHandler)
            else fake_mode
        )

        with context:
            if trace_fn is None:
                trace_fn = functools.partial(
                    fwd_only, run_functional_passes=run_functional_passes
                )
            replacement = trace_fn(
                replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])  # type: ignore[arg-type]
            )
            ReplacementPatternEntry.replace_with_graph(
                self,
                self.ctx.graph,
                replacement,
                args,
            )
262
+
263
+
264
class FailedMatch(RuntimeError):
    """
    Represents an unsuccessful match.

    A `FailedMatch` is what a matcher hands back (or raises) when a pattern
    does not match. It is falsy, and its message is only rendered on demand.
    """

    format_string: str

    def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None:
        # We want to construct error messages lazily instead of eagerly, as
        # constructing them eagerly can significantly worsen compile times.
        # An over-long template suggests the caller pre-rendered the message.
        too_long = len(format_string) > 200
        self.format_string = format_string
        if too_long:
            raise RuntimeError(
                f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}"
            )
        self.args = args
        self.kwargs = kwargs

    def __str__(self) -> str:
        # Deferred rendering: format only when someone asks for the text.
        template = self.format_string
        return template.format(*self.args, **self.kwargs)

    def __bool__(self) -> bool:
        # A failure is always falsy so results can be tested with `if m:`.
        return False
290
+
291
+
292
# Outcome of a match attempt: a Match on success, a FailedMatch otherwise.
MatchResult = Union[Match, FailedMatch]
293
+
294
+
295
def is_match(m: MatchResult) -> TypeGuard[Match]:
    """
    Return the truthiness of `m`.

    TypeGuards cannot act on `self`, so this free function exists to let mypy
    treat FailedMatch.__bool__ as a TypeGuard.
    """
    truthy = bool(m)
    return truthy
301
+
302
+
303
class MatchContext:
    """
    Bookkeeping shared by PatternExpr._match() calls during one match attempt.
    """

    outputs: List[Optional[PatternExpr]]
    pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]]
    graph: torch.fx.Graph
    exclusive_node_set: List[NodeOrConstant]

    def __init__(
        self,
        outputs: List[Optional[PatternExpr]],
        pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None,
        *,
        graph: torch.fx.Graph,
    ) -> None:
        self.outputs = outputs
        # Always work on a private copy so the caller's mapping is untouched.
        if pattern_to_node is None:
            self.pattern_to_node = {}
        else:
            self.pattern_to_node = dict(pattern_to_node)
        self.graph = graph
        self.exclusive_node_set = []

    def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult:
        """wrapper to check reused nodes in patterns"""
        seen = self.pattern_to_node
        if pattern in seen:
            if not (seen[pattern] == node):
                return FailedMatch("repeated pattern differs")
            return Match(self, pattern)  # already checked this node

        result = pattern._match(node, self)
        # _match must not have recorded this pattern itself.
        assert pattern not in self.pattern_to_node
        # Failures are recorded too, with None standing in for the node.
        self.pattern_to_node[pattern] = node if result else None
        return result

    def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]:
        """Recorded pattern->node pairs whose pattern has multiple users."""
        kept: Dict[PatternExpr, torch.fx.Node] = {}
        for pattern, node in self.pattern_to_node.items():
            if pattern.has_multiple_users() and node is not None:
                kept[pattern] = node
        return kept
343
+
344
+
345
class PatternExpr(ABC):
    """
    Abstract root of all pattern node types.
    """

    @abstractmethod
    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        ...

    def match(self, node: torch.fx.Node) -> MatchResult:
        """
        Match this pattern against `node` in a fresh context whose only
        output is this pattern. A raised FailedMatch is returned, not
        propagated.
        """
        try:
            ctx = MatchContext([self], graph=node.graph)
            return ctx.match(self, node)
        except FailedMatch as failure:
            return failure

    def has_multiple_users(self) -> bool:
        return False

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def find_anchor_nodes(
        self, ctx: MatchContext, searched: Set[torch.fx.Node]
    ) -> Generator[Optional[torch.fx.Node], None, None]:
        """Yield the node already recorded for this pattern in `ctx`, if any."""
        try:
            anchor = ctx.pattern_to_node[self]
        except KeyError:
            return
        yield anchor

    def pattern_eq(self, other: Any) -> bool:
        """
        Structural comparison of two `PatternExpr`s (for debugging). This
        does NOT match a pattern against a graph; at this level it only checks
        that `other` is an instance of this pattern's class.
        """
        return isinstance(other, self.__class__)
379
+
380
+
381
class Arg(PatternExpr):
    """
    Wildcard that captures whatever it is matched against as a positional
    arg for the handler. Args are passed in depth first order.
    """

    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
        # matches anything; the matched value becomes the single captured arg
        captured = [node]
        return Match(ctx, self, args=captured)
389
+
390
+
391
class Ignored(PatternExpr):
    """
    Wildcard that matches any value without capturing it for the handler.
    """

    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
        # matches anything; nothing is recorded in args/kwargs
        empty_match = Match(ctx, self)
        return empty_match

    def __repr__(self) -> str:
        return "*"

    def pretty_print(self, pp: PatternPrettyPrinter) -> str:
        # Unlike __repr__, this produces the constructor call.
        return "Ignored()"
404
+
405
+
406
class KeywordArg(PatternExpr):
    """
    Wildcard that captures the matched value under a keyword name, to be
    passed to the handler as a kwarg.
    """

    def __init__(self, name: str) -> None:
        super().__init__()
        self.name = name

    def __repr__(self) -> str:
        return f"KeywordArg({self.name!r})"

    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
        # matches anything
        captured = {self.name: node}
        return Match(ctx, self, kwargs=captured)

    def pattern_eq(self, other: Any) -> bool:
        """Equal when `other` is the same kind of pattern with the same name."""
        if not super().pattern_eq(other):
            return False
        # super makes sure the cast is valid
        return self.name == typing.cast(Self, other).name
424
+
425
+
426
class ExclusiveKeywordArg(PatternExpr):
    """
    Capture a kwarg which will become an input to the handler, but refuse
    to match a value that is already in the context's `exclusive_node_set`.
    """

    name: str

    def __init__(self, name: str) -> None:
        super().__init__()
        self.name = name

    def __repr__(self) -> str:
        return f"ExclusiveKeywordArg({self.name!r})"

    def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
        claimed = ctx.exclusive_node_set
        if node in claimed:
            return FailedMatch("exclusive arg appears twice")

        # Record the value so a later exclusive capture of it fails.
        claimed.append(node)
        captured = {self.name: node}
        return Match(ctx, self, kwargs=captured)  # matches anything

    def pattern_eq(self, other: Any) -> bool:
        """Equal when `other` is the same kind of pattern with the same name."""
        if not super().pattern_eq(other):
            return False
        # super makes sure the cast is valid
        return self.name == typing.cast(Self, other).name
450
+
451
+
452
class _TargetExpr(PatternExpr):
    """
    Base class for filtering match by node.target
    """

    fns: List[FnsType]
    fns_set: Set[FnsType]

    def __init__(
        self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1
    ) -> None:
        """
        Args:
            fns: one target (a callable or a string) or a sequence of targets.
                Every OpOverloadPacket in it also contributes all of its
                overloads as accepted targets.
            users: the number of users a matched node must have, or MULTIPLE
                to accept any number.
        """
        super().__init__()
        fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
        # Collect the overloads first and append them afterwards, rather than
        # growing the list while iterating over it. The resulting order is the
        # same: the given targets, then each packet's overloads in turn.
        overloads = [
            getattr(fn, overload)
            for fn in fns
            if isinstance(fn, torch._ops.OpOverloadPacket)
            for overload in fn.overloads()
        ]
        fns.extend(overloads)

        self.fns = fns
        self.fns_set = set(fns)
        self.users = users

    @property
    @abstractmethod
    def op(self) -> str:
        ...

    def fns_repr(self) -> str:
        """Short printable form of the accepted targets."""
        first_repr = self.fns[0]
        if not isinstance(first_repr, str):
            first_repr = first_repr.__name__

        if len(self.fns) > 1:
            return f"[{first_repr}, ...]"
        elif self.fns[0] is getattr(torch, first_repr, None):
            return f"torch.{first_repr}"
        elif isinstance(self.fns[0], torch._ops.OpOverload):
            return str(self.fns[0])
        else:
            return first_repr

    def __repr__(self) -> str:
        if self.users is MULTIPLE:
            comma_users = ", MULTIPLE"
        elif self.users != 1:
            # The original appended a stray ")" here, which left the repr
            # with unbalanced parentheses.
            comma_users = f", {self.users}"
        else:
            comma_users = ""
        return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})"

    def has_multiple_users(self) -> bool:
        return isinstance(self.users, Multiple) or self.users > 1

    def find_anchor_nodes(
        self, ctx: MatchContext, searched: Set[torch.fx.Node]
    ) -> Generator[Optional[torch.fx.Node], None, None]:
        raise NotImplementedError

    def _match_fns(self, node: torch.fx.Node) -> bool:
        """True if `node` is an FX node with this pattern's op and an accepted target."""
        return (
            isinstance(node, torch.fx.Node)
            and node.op == self.op
            and extract_target(node) in self.fns_set
        )

    def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool:
        """
        True if the node's user count is acceptable: the check is skipped for
        patterns listed in `ctx.outputs` and for `users=MULTIPLE`.
        """
        return (
            self in ctx.outputs
            or self.users is MULTIPLE
            or len(node.users) == self.users
        )

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        return (
            super().pattern_eq(other)
            and self.op == other.op
            and self.fns == other.fns
            and self.users == other.users
        )
531
+
532
+
533
+ _SimpleSpec = Tuple[Any, ...]
534
+
535
+
536
class _TargetArgsExpr(_TargetExpr):
    """
    Base class for filtering match by node.{target,args,kwargs}
    """

    def __init__(
        self,
        fns: Union[torch.fx.node.Target, str, Sequence[Any]],
        *args: Any,
        _users: Union[int, Multiple] = 1,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            fns: accepted target(s), forwarded to `_TargetExpr`.
            *args: expected positional arguments; each is either a sub-pattern
                or a constant that must compare equal.
            _users: expected user count, forwarded to `_TargetExpr`.
            **kwargs: expected keyword arguments, same convention as `args`.
        """
        super().__init__(fns, _users)
        self.args = tuple(args)
        self.kwargs = dict(kwargs)
        # Only pay for pytree flattening when some argument is itself a
        # container; otherwise the cheap flat form is used.
        if any(
            isinstance(x, (dict, list, tuple))
            for x in itertools.chain(args, kwargs.values())
        ):
            self.flatten = self.pytree_flatten
        else:
            self.flatten = self.simple_flatten
        self.flat_args_kwargs = self.flatten(self.args, self.kwargs)

    @staticmethod
    def simple_flatten(
        args: Sequence[Any], kwargs: Mapping[Any, Any]
    ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
        """Flatten to (values, spec), spec being (len(args), *kwarg names)."""
        values = (*args, *kwargs.values())
        spec = (len(args), *kwargs.keys())
        return values, spec

    @staticmethod
    def pytree_flatten(
        args: Sequence[Any], kwargs: Mapping[Any, Any]
    ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
        """Flatten ``[args, kwargs]`` with pytree and normalize the spec."""

        def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
            # Leaves have no container type and are kept as they are.
            if s.type is None:
                return s
            # Map list-like and dict-like container types onto list / dict so
            # specs that differ only in those types compare equal.
            mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
            return pytree.TreeSpec(
                mapping.get(s.type, s.type),
                s.context,
                list(map(norm_spec, s.children_specs)),
            )

        flat, spec = pytree.tree_flatten([args, kwargs])
        spec = norm_spec(spec)
        return flat, spec

    def __repr__(self) -> str:
        args = [
            self.fns_repr(),
            *map(repr, self.args),
            *[f"{k}={v}" for k, v in self.kwargs.items()],
        ]
        if self.users is MULTIPLE:
            args.append("_users=MULTIPLE")
        elif self.users != 1:
            args.append(f"_users={self.users}")
        return f"{self.__class__.__name__}({', '.join(args)})"

    def pretty_print(self, pp: PatternPrettyPrinter) -> str:
        """Like `__repr__`, but renders sub-patterns through the printer `pp`."""
        args = [
            self.fns_repr(),
            *(pp.pretty_print(x) for x in self.args),
            *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
        ]
        if self.users is MULTIPLE:
            args.append("_users=MULTIPLE")
        elif self.users != 1:
            args.append(f"_users={self.users}")

        joiner_str = ", "
        return f"{self.__class__.__name__}({joiner_str.join(args)})"

    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        """
        Match `node`'s target, user count, and (flattened) args/kwargs against
        this pattern, recursing into sub-patterns through `ctx`.
        """
        if not self._match_fns(node) or len(node.args) != len(self.args):
            return FailedMatch("function_mismatch: node={}, pattern={}", node, self)

        if not self._match_users(node, ctx):
            return FailedMatch("multiple_users {}", self)

        _args = node.args
        _kwargs = node.kwargs
        if len(_kwargs) < len(self.kwargs):
            # The node passes fewer kwargs than the pattern names; normalize
            # the call so the missing ones can be looked up by keyword.
            from torch.fx.operator_schemas import normalize_function

            normalized_args_and_kwargs = normalize_function(
                node.target, node.args, node.kwargs  # type: ignore[arg-type]
            )

            if normalized_args_and_kwargs is None:
                return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
            else:
                _args, _kwargs = normalized_args_and_kwargs
                if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs):
                    _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
                else:
                    return FailedMatch(
                        "function_mismatch: node={}, pattern={}", node, self
                    )
        else:
            # Kwargs the pattern does not mention are ignored.
            _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}

        node_items, node_spec = self.flatten(_args, _kwargs)
        self_items, self_spec = self.flat_args_kwargs
        if node_spec != self_spec:
            return FailedMatch("args_structure {} {}", node_spec, self_spec)
        assert len(node_items) == len(self_items)

        m = Match(ctx, self)
        for pattern, child_node in zip(self_items, node_items):
            if isinstance(pattern, PatternExpr):
                child_match = ctx.match(pattern, child_node)
                if not is_match(child_match):
                    return child_match
                m.extend(child_match)
            elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
                # `pattern` has to be supplied as a keyword: the format string
                # refers to it by name, and without it rendering this failure
                # with str() raised KeyError.
                return FailedMatch(
                    "constant_args: {} {!r}!={pattern!r}",
                    node,
                    child_node,
                    pattern=pattern,
                )
        m.nodes.append(node)
        m.targets[self] = node.target
        return m

    def find_anchor_nodes(
        self, ctx: MatchContext, searched: Set[torch.fx.Node]
    ) -> Generator[Optional[torch.fx.Node], None, None]:
        """
        This is used when we are matching a pattern with multiple outputs.
        There is a partial match (stored in ctx) and we want to walk
        this pattern to find a connection to an already-matched node.

        Yields candidate nodes that `self._match` might like.
        """
        if self in ctx.pattern_to_node:
            yield ctx.pattern_to_node[self]
            return

        for pattern in self.flat_args_kwargs[0]:
            if isinstance(pattern, PatternExpr):
                for other_node in pattern.find_anchor_nodes(ctx, searched):
                    if not isinstance(other_node, torch.fx.Node):
                        continue
                    # Users of an anchored argument node are the candidates.
                    for node in other_node.users:
                        if node not in searched:
                            if self._match_fns(node):
                                yield node
                                searched.add(node)

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        return (
            super().pattern_eq(other)
            and self.flat_args_kwargs[1] == other.flat_args_kwargs[1]
            and all(
                a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
                for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0])
            )
        )
697
+
698
+
699
class CallFunction(_TargetArgsExpr):
    """
    Pattern for a `call_function` node in the FX graph, i.e.
    `fns[i](*args, **kwargs)`.
    """

    op = "call_function"
705
+
706
+
707
class CallMethod(_TargetArgsExpr):
    """
    Pattern for a `call_method` node in the FX graph, i.e.
    `fns[i].method(*args, **kwargs)`.
    """

    op = "call_method"
713
+
714
+
715
class CallModule(_TargetArgsExpr):
    """
    Pattern for a `call_module` node in the FX graph, i.e.
    `module(*args, **kwargs)`.
    """

    op = "call_module"
721
+
722
+
723
class _TargetExprVarArgs(_TargetExpr):
    """
    Matches a node by target (and user count) regardless of its arguments;
    the node's own args and kwargs are captured into the match.
    """

    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        if not self._match_fns(node):
            return FailedMatch("function_mismatch")
        if not self._match_users(node, ctx):
            return FailedMatch("multiple_users")

        result = Match(ctx, self)
        result.nodes.append(node)
        result.targets[self] = node.target
        # Hand the node's arguments through unchanged.
        result.args.extend(node.args)
        result.kwargs.update(node.kwargs)
        return result
741
+
742
+
743
class CallFunctionVarArgs(_TargetExprVarArgs):
    """Any-arguments pattern for `call_function` nodes."""

    op = "call_function"
745
+
746
+
747
class CallMethodVarArgs(_TargetExprVarArgs):
    """Any-arguments pattern for `call_method` nodes."""

    op = "call_method"
749
+
750
+
751
class CallModuleVarArgs(_TargetExprVarArgs):
    """Any-arguments pattern for `call_module` nodes."""

    op = "call_module"
753
+
754
+
755
class ListOf(PatternExpr):
    """
    Matches a repeated pattern
    """

    def __init__(self, pattern: PatternExpr, partial: bool = False) -> None:
        """
        Args:
            pattern: the pattern applied to every element of the list.
            partial: if True, elements that do not match are skipped and the
                match succeeds as long as at least one element matched; if
                False, the first non-matching element fails the whole match.
        """
        super().__init__()
        assert isinstance(pattern, PatternExpr)
        self.pattern = pattern
        self.partial = partial

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.pattern})"

    def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult:  # type: ignore[override]
        """
        Match `self.pattern` against each element of the list/tuple `node`,
        each in its own child context, and bundle the per-element captures.
        """
        if not isinstance(node, (list, tuple)) or len(node) == 0:
            return FailedMatch("non_list")
        m = Match(ctx, self)
        # Propagating patterns with multiple users will ensure we don't revisit
        # the same nodes
        pattern_to_node = ctx.filter_multi_user_patterns()
        matched = False
        for i, child_node in enumerate(node):
            # Elements need not be FX nodes. Reading `.graph` off a non-node
            # element raised AttributeError; use the parent context's graph
            # for those so the element pattern can report a normal failure.
            child_graph = (
                child_node.graph
                if isinstance(child_node, torch.fx.Node)
                else ctx.graph
            )
            child_ctx = MatchContext(ctx.outputs, pattern_to_node, graph=child_graph)
            child_match = child_ctx.match(self.pattern, child_node)
            pattern_to_node = child_ctx.filter_multi_user_patterns()
            if not is_match(child_match):
                if not self.partial:
                    return FailedMatch("list[{}]: {}", i, child_match)
                continue
            matched = True
            m.extend(child_match.bundle())
        if not matched:
            return FailedMatch("list: no_match")
        return m.bundle()

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        return (
            super().pattern_eq(other)
            and self.pattern.pattern_eq(other.pattern)
            and self.partial == other.partial
        )
800
+
801
+
802
class MultiOutputPattern(PatternExpr):
    """
    A pattern with several outputs. The first output is matched directly
    against the candidate node; every other non-None output is located by
    searching from nodes that have already been matched.
    """

    outputs: List[Optional[PatternExpr]]

    def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None:
        super().__init__()
        first = outputs[0]
        assert isinstance(first, _TargetExpr)
        assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs
        self.outputs = list(outputs)
        self.op = first.op

    @property
    def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]:
        # This cast is checked above in __init__()
        return typing.cast(_TargetExpr, self.outputs[0]).fns

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.outputs})"

    def pretty_print(self, pp: PatternPrettyPrinter) -> str:
        rendered = [pp.pretty_print(x) for x in self.outputs]
        joiner_str = f",\n{' '}"
        body = joiner_str.join(rendered)
        return f"{self.__class__.__name__}([{body}\n])"

    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        primary = typing.cast(_TargetExpr, self.outputs[0])
        m = ctx.match(primary, node)
        if not is_match(m):
            return m

        # Remaining outputs are found relative to what is already matched.
        for extra in self.outputs[1:]:
            if extra is not None:
                child_match = self._match_from_anchors(extra, ctx)
                if not is_match(child_match):
                    return child_match
                m.extend(child_match)

        return m

    def _match_from_anchors(
        self, pattern: PatternExpr, ctx: MatchContext
    ) -> MatchResult:
        """
        Try `pattern` on each candidate from `pattern.find_anchor_nodes` and
        return the first success; otherwise the last failure (or a generic
        one when there were no candidates).
        """
        snapshot = dict(ctx.pattern_to_node)
        outcome: MatchResult = FailedMatch("no anchor found")
        for candidate in pattern.find_anchor_nodes(ctx, set()):
            outcome = ctx.match(pattern, candidate)
            if is_match(outcome):
                return outcome
            # revert any partial matches
            ctx.pattern_to_node = dict(snapshot)
        return outcome

    def match(self, node: torch.fx.Node) -> MatchResult:
        """Match in a fresh context whose outputs are this pattern's outputs."""
        try:
            ctx = MatchContext(self.outputs, graph=node.graph)
            return ctx.match(self, node)
        except FailedMatch as failure:
            return failure

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        if not super().pattern_eq(other):
            return False
        if len(self.outputs) != len(other.outputs):
            return False
        return all(
            a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
            for a, b in zip(self.outputs, other.outputs)
        )
873
+
874
+
875
class RepeatedExpr(PatternExpr):
    """
    Requires the inner pattern to match at every anchor node it can find.
    Useful for repeated operations after a node such as `split` or `unbind`
    """

    def __init__(self, inner_pattern: _TargetExpr) -> None:
        super().__init__()
        self.inner_pattern = inner_pattern
        self.op = inner_pattern.op

    @property
    def fns(self) -> Sequence[FnsType]:
        return self.inner_pattern.fns

    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        inner = self.inner_pattern
        m = ctx.match(inner, node)
        if not is_match(m):
            return m
        # Drop the entry ctx.match just recorded for the inner pattern before
        # asking it for anchor nodes.
        ctx.pattern_to_node.pop(inner)
        # Check all anchor nodes match the pattern
        for anchor_node in inner.find_anchor_nodes(ctx, set()):
            fresh_ctx = MatchContext([self], graph=node.graph)
            anchor_m = fresh_ctx.match(inner, anchor_node)
            if not is_match(anchor_m):
                return anchor_m
            m.extend(anchor_m)
        return m

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        if not super().pattern_eq(other):
            return False
        return self.inner_pattern.pattern_eq(other.inner_pattern)
911
+
912
+
913
class PatternPrettyPrinter:
    """
    Serializes Patterns to executable python.
    XXX: currently only used and tested for fuse attention patterns. May not cover
    all patterns.
    """

    def __init__(self) -> None:
        self.namespace = torch.fx.graph._Namespace()
        # Variable name assigned to each memoized pattern, and its printed form.
        self.memoized_objs_names: Dict[PatternExpr, str] = {}
        self.memoized_objs_pp: Dict[PatternExpr, str] = {}

    @staticmethod
    @functools.lru_cache(None)
    def run(obj: PatternExpr, output_name: str = "output") -> str:
        """
        Serializes obj to python code with obj written out to `output_name`
        """

        printer = PatternPrettyPrinter()
        assert hasattr(obj, "pretty_print")
        out_str = obj.pretty_print(pp=printer)

        # One assignment per memoized sub-pattern, then the final output.
        lines = [
            f"{name} = {printer.memoized_objs_pp[key]}"
            for key, name in printer.memoized_objs_names.items()
        ]
        lines.append(f"{output_name} = {out_str}")

        return "\n".join(lines)

    def pretty_print(self, obj: Any) -> str:
        """
        Render `obj`: `_TargetArgsExpr` objects become a (memoized) variable
        name, objects with a `pretty_print` method use it, the rest use repr.
        """
        if isinstance(obj, _TargetArgsExpr):
            return self.memoized_objs_names.get(obj) or self.memoize(obj)
        if hasattr(obj, "pretty_print"):
            return obj.pretty_print(self)

        return repr(obj)

    def memoize(self, obj: _TargetArgsExpr) -> str:
        """Print `obj`, assign it a fresh variable name, and return that name."""
        obj_str = obj.pretty_print(self)
        obj_name = obj.fns_repr()
        # Strip namespace prefixes from the name suggestion.
        for prefix in ("aten.", "torch.", "prims."):
            obj_name = obj_name.replace(prefix, "")

        tmp_name = self.namespace.create_name(obj_name, None)
        self.memoized_objs_names[obj] = tmp_name
        self.memoized_objs_pp[obj] = obj_str
        return tmp_name
965
+
966
+
967
class _PassDictsType(Protocol):
    """
    Structural type for the containers that pattern entries register into:
    anything indexable by an ``(op, target)`` key that yields a list of
    ``PatternEntry`` objects.
    """

    def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
        """Return the list of entries stored under the ``(op, target)`` key ``k``."""
        ...
970
+
971
+
972
@dataclasses.dataclass
class PatternEntry:
    """
    Base record pairing a pattern with an extra acceptance check.

    Subclasses supply ``apply`` to say what happens once a match is accepted.
    """

    pattern: PatternExpr
    extra_check: Callable[[Match], bool]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        """Act on an accepted ``match``; must be overridden by subclasses."""
        raise NotImplementedError

    def register(
        self,
        pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
        target: Union[torch.fx.node.Target, None] = None,
        prepend: bool = False,
    ) -> None:
        """
        Add this entry to one or several pass dictionaries.

        Args:
            pass_dicts: a single dict / ``PatternMatcherPass``, or a sequence
                of them.
            target: target to register under; when ``None`` the entry is
                registered once for every function in ``self.pattern.fns``.
            prepend: insert at the front of the entry list instead of
                appending to it.
        """
        if target is None:
            # No explicit target: fan out over every target of the pattern.
            assert hasattr(self.pattern, "fns")
            for fn in self.pattern.fns:
                self.register(pass_dicts, fn, prepend=prepend)
            return

        if not isinstance(pass_dicts, (dict, PatternMatcherPass)):
            # A sequence of pass dicts: register into each one in turn.
            many = typing.cast(Sequence[_PassDictsType], pass_dicts)
            for single in many:
                self.register(single, target, prepend=prepend)
            return

        assert hasattr(self.pattern, "op")
        entries = pass_dicts[(self.pattern.op, target)]
        if prepend:
            entries.insert(0, self)
        else:
            entries.append(self)
1000
+
1001
+
1002
@dataclasses.dataclass
class LoweringPatternEntry(PatternEntry):
    """
    Entry whose accepted match is replaced by a single ``call_function`` node
    that invokes ``handler`` with the match bound as first argument.
    """

    handler: Callable[..., Any]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        """Swap ``node`` for a call to the bound handler and erase the matched nodes."""
        # Bind the match up front, then copy the handler's metadata onto the
        # resulting partial.
        bound = functools.partial(self.handler, match)
        handler = functools.wraps(self.handler)(bound)

        with graph.inserting_before(node):
            replacement = graph.call_function(handler, tuple(match.args), match.kwargs)
            replacement.meta.update(node.meta)
            node.replace_all_uses_with(replacement)

        assert match.nodes[-1] is node
        match.erase_nodes()
1014
+
1015
+
1016
@dataclasses.dataclass
class GraphPatternEntry(PatternEntry):
    """
    Entry whose handler is a function that operates on the FX graph itself.
    """

    handler: Callable[..., Any]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        """Call the handler with the match, with the insert point set just before ``node``."""
        insert_point = graph.inserting_before(node)
        with insert_point:
            self.handler(match, *match.args, **match.kwargs)
1027
+
1028
+
1029
@dataclasses.dataclass
class ReplacementPatternEntry(PatternEntry):
    """
    Entry that replaces a matched subgraph with the nodes of
    ``match.replacement_graph``.
    """

    # Turns the match's args/kwargs into the positional argument list that is
    # fed to the replacement graph (see `apply`).
    normalize_args: Callable[..., List[Any]]

    @staticmethod
    def replace_with_graph(
        match: Match,
        graph: torch.fx.Graph,
        replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule],
        args: Sequence[torch.fx.Node],
    ) -> None:
        """
        Copy the call_function nodes of ``replacement_graph`` into ``graph``,
        rewire the users of the match's output nodes to the copies, and erase
        the matched nodes.

        Args:
            match: the match being replaced; supplies the output nodes.
            graph: graph that is edited in place.
            replacement_graph: graph whose nodes are re-created in ``graph``.
            args: values bound to the placeholders of ``replacement_graph``.
        """

        class Replacer(torch.fx.Interpreter):
            # Only placeholder, output and call_function nodes are supported;
            # the other handlers are disabled.
            call_method = None  # type: ignore[assignment]
            call_module = None  # type: ignore[assignment]
            get_attr = None  # type: ignore[assignment]

            def run_node(self, node: torch.fx.Node) -> Any:
                if node.op in ("placeholder", "output"):
                    return super().run_node(node)
                if node.op == "call_function":
                    target = node.target
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    # Rather than executing the function, emit an equivalent
                    # node into the graph being edited.
                    result = graph.call_function(target, args, kwargs)  # type: ignore[arg-type]
                    # Carry over "val" (and "tensor_meta" for tensors) unless
                    # the new node already has a "val".
                    if "val" in node.meta and "val" not in result.meta:
                        result.meta["val"] = node.meta["val"]
                        if isinstance(node.meta["val"], torch.Tensor):
                            assert "tensor_meta" in node.meta
                            result.meta["tensor_meta"] = node.meta["tensor_meta"]
                    return result
                raise NotImplementedError(f"unhandled {node}")

        output_nodes = match.output_nodes()

        if len(output_nodes) == 1:
            last_node = output_nodes[0]
        else:
            assert output_nodes[0]
            nodes = list(output_nodes[0].graph.nodes)
            indices = [
                (nodes.index(n), n)
                for n in output_nodes
                if isinstance(n, torch.fx.Node)
            ]
            # Despite the name, this selects the output node with the smallest
            # index in the graph's node list; the replacement is inserted
            # before it.
            last_node = min(indices, key=operator.itemgetter(0))[1]

        def percolate_tags(
            node: torch.fx.Node,
            tag_name: str,
            tag_value: str,
            input_stops: Set[torch.fx.Node],
        ) -> None:
            # Walk from `node` back through its inputs, setting
            # meta[tag_name] = tag_value on every node reached; the walk does
            # not enter nodes in `input_stops`.
            queue = [node]
            visited = set()

            while queue:
                arg = queue.pop()
                if (
                    arg not in visited
                    and arg not in input_stops
                    and hasattr(arg, "meta")
                ):
                    visited.add(arg)
                    arg.meta[tag_name] = tag_value
                    queue.extend(arg.all_input_nodes)

        with graph.inserting_before(last_node):
            replacement = Replacer(replacement_graph).run(*args)  # type: ignore[arg-type]
            if isinstance(replacement, torch.fx.Node):
                replacement = [replacement]

            def maybe_getitem(node: torch.fx.Node) -> Any:
                # Returns the index used by an `operator.getitem` call node,
                # or None when `node` is anything else.
                if node.op != "call_function":
                    return None
                if node.target != operator.getitem:
                    return None
                assert len(node.args) == 2
                return node.args[1]

            def replace(
                old: Union[torch.fx.Node, None],
                new: Union[torch.fx.Node, Sequence[torch.fx.Node], None],
            ) -> None:
                # Redirect every use of `old` to `new` and erase `old`.
                if old is None:
                    assert new is None
                    return
                assert isinstance(old, torch.fx.Node)
                if new is None:
                    old.replace_all_uses_with(None)  # type: ignore[arg-type]
                    graph.erase_node(old)
                    return
                if isinstance(new, torch.fx.Node):
                    if "val" not in new.meta:
                        new.meta.update(old.meta)

                    # Preserve the recompute tags in the replacement graph. We
                    # look at the recompute tags of the original output node to
                    # propagate the tag from the output all the way to the input
                    # args (named as args in the replace_with_graph).
                    # Note that this is best effort. Since patterns are from
                    # many to many, there is no easy way to correctly map the
                    # recomputable tags. It is possible in some scenarios that we
                    # incorrectly tag some nodes as recomputables.
                    for tag_name in ["recompute", "ac_graph_id"]:
                        if tag_name in old.meta:
                            percolate_tags(new, tag_name, old.meta[tag_name], set(args))

                    old.replace_all_uses_with(new)
                    graph.erase_node(old)
                    return

                # `new` is not a node: it's a list of nodes.
                #
                # This happens when we want to replace a node that has a single
                # packed return with multiple unpacked returns. We need to do
                # some graph surgery here.
                #
                # Example:
                #   def original_graph(x):
                #      a = op(x)
                #      b = a[0]
                #      c = a[1]
                #      ...
                #
                # Assume that we want to replace op(x) with the graph
                #   def new_op(x):
                #      w = x + 1
                #      z = x + 2
                #      return (w, z)
                #
                # We need to replace `op` with the contents of `new_op`,
                # and then rewrite a[0] to be w and a[1] to be z, as so:
                #   def new_graph(x):
                #     w = x + 1
                #     z = x + 2
                #     b = w
                #     c = z
                #     ...
                old_uses = list(old.users.keys())
                for user in old_uses:
                    idx = maybe_getitem(user)
                    # Every user of the packed result must be a getitem.
                    if idx is None:
                        raise AssertionError("can't handle")
                    replace(user, new[idx])  # type: ignore[index]
                graph.erase_node(old)

            if len(output_nodes) == len(replacement):
                # One replacement value per output node: pair them up.
                for old, new in zip(output_nodes, replacement):
                    replace(old, new)
            else:
                # Otherwise a single packed output is replaced by a list.
                assert len(output_nodes) == 1
                replace(output_nodes[0], replacement)

        match.erase_nodes()

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        """
        Replace the matched nodes with ``match.replacement_graph``, which must
        already be set, passing it the normalized match arguments.
        """
        assert match.replacement_graph is not None
        self.replace_with_graph(
            match,
            graph,
            match.replacement_graph,
            self.normalize_args(*match.args, **match.kwargs),
        )
1191
+
1192
+
1193
def _return_true(match: Match) -> bool:
    """Accept every match; used as the default ``extra_check``."""
    return True
1195
+
1196
+
1197
def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None:
    """
    Emit an INFO log line recording that tracing ``search_fn`` raised ``e``
    and the replacement pattern was therefore not applied.
    """
    pattern_name = search_fn.__name__
    log.info(
        "Replacement pattern %s failed to apply due to shape mismatch: %s",
        pattern_name,
        e,
    )
1203
+
1204
+
1205
def register_replacement(
    search_fn: SearchFn,
    replace_fn: ReplaceFn,
    example_inputs: Iterable[Any],
    trace_fn: TraceFn,
    pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
    extra_check: Callable[[Match], bool] = _return_true,
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
    search_fn_pattern: Union[PatternExpr, None] = None,
) -> bool:
    """
    Create a replacement rule based on example functions that get traced
    to create patterns.  This supports both training and inference when
    run on a joint forward+backward graph.

    Args:
        search_fn: traced to give original pattern
        replace_fn: traced to give replacement graph
        example_inputs: example inputs for initial trace
        trace_fn: fwd_only or joint_fwd_bwd
        pass_dicts: dict of passes (or sequence of them) to register to
        extra_check: additional check to run on match (using real shapes)
        scalar_workaround: forwarded to gen_pattern / fx_to_pattern
        exclusive_arg_names: forwarded to gen_pattern / fx_to_pattern
        search_fn_pattern: pre-traced pattern to use instead of tracing
            search_fn

    Returns:
        False when trace_fn is joint_fwd_bwd and inference mode is enabled
        (nothing is registered).  Otherwise the pattern is registered and
        `pattern.pattern` of the new entry is returned.
        NOTE(review): that value is the pattern expression, not a bool,
        despite the annotation - confirm what callers rely on.
    """
    argnames_static = [*inspect.signature(search_fn).parameters.keys()]

    def check_fn(match: Match) -> bool:
        """
        Often shapes get burned into the pattern, so our initial match ran with
        `ignore_types=(int, ...)`.

        Recheck the match with the correct shapes.
        """
        argnames = list(argnames_static)
        for name in argnames:
            if name not in match.kwargs:
                raise RuntimeError(
                    f"Not all inputs to pattern found in match.kwargs. Perhaps one "
                    f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}"
                )

        # Swap each matched input node for the value stored in its meta["val"].
        args = list(
            torch.fx.map_arg(  # type: ignore[arg-type]
                [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
            )
        )
        sym_args: List[torch.SymInt] = []
        with torch._dynamo.utils.detect_fake_mode(args):
            # `requires_grad` is assigned further down in the enclosing
            # function, before this closure can be invoked.
            for i, grad in enumerate(requires_grad):
                if isinstance(args[i], torch.Tensor):
                    # Reject the match when an integer tensor would need grad.
                    if grad and is_integer_dtype(args[i].dtype):
                        return False

                    # Fresh tensor with the same size/stride/dtype/device and
                    # the requires_grad flag of the matching example input.
                    args[i] = torch.empty_strided(
                        args[i].size(),
                        args[i].stride(),
                        dtype=args[i].dtype,
                        device=args[i].device,
                        requires_grad=grad,
                    )
                    # Collect the distinct symbolic sizes and strides.
                    for v in itertools.chain(args[i].shape, args[i].stride()):
                        if isinstance(v, torch.SymInt) and all(
                            guard_size_oblivious(v != a) for a in sym_args
                        ):
                            sym_args.append(v)

            # If we were given a pre-traced pattern then use that instead of
            # retracing. Note that this means the pattern has to be independent
            # of its args.
            specific_pattern = search_fn_pattern

            if not specific_pattern:
                if sym_args:
                    # AOT Autograd and make fx will dedupe symbolic shape size
                    # accesses of sym ints that appear as inputs
                    # We don't want the sym_size uses to interfere with pattern matching
                    # so we provide them as inputs.
                    # Later, when we actually do the replacement, the symbolic shape
                    # sizes will get re-traced and added to the graph.

                    def search_fn_new(*args_new: Any) -> Any:
                        # Drop the leading sym-int inputs; only the trailing
                        # len(args) values go to search_fn.
                        return search_fn(*args_new[len(args_new) - len(args) :])

                    try:
                        specific_graph = trace_fn(search_fn_new, sym_args + args)
                    except RuntimeError as e:
                        log_trace_failure(search_fn, e)
                        return False

                    # correct argnames in the graph
                    sym_arg_names = []
                    for i, placeholder in zip(
                        range(len(sym_args) + len(args)),
                        specific_graph.graph.nodes,
                    ):
                        if i < len(sym_args):
                            sym_arg_names.append(placeholder.target)
                            continue

                        # Replace the traced placeholder with one named after
                        # the matching search_fn parameter.
                        with specific_graph.graph.inserting_after(placeholder):
                            new_node = specific_graph.graph.placeholder(
                                argnames[i - len(sym_args)]
                            )
                            new_node.target = new_node.name
                            placeholder.replace_all_uses_with(new_node)
                            specific_graph.graph.erase_node(placeholder)

                    argnames = sym_arg_names + argnames
                else:
                    try:
                        specific_graph = trace_fn(search_fn, args)
                    except RuntimeError as e:
                        log_trace_failure(search_fn, e)
                        return False

                specific_pattern = fx_to_pattern(
                    specific_graph,
                    argnames=argnames,
                    exclusive_arg_names=exclusive_arg_names,
                    scalar_workaround=scalar_workaround,
                )

            node = match.output_nodes()[0]
            assert node is not None
            specific_pattern_match = specific_pattern.match(node)

            if is_match(specific_pattern_match) and extra_check(specific_pattern_match):
                # trace the pattern using the shapes from the user program
                match.replacement_graph = trace_fn(replace_fn, args)  # type: ignore[assignment]
                return True
            return False

    def normalize_args(**kwargs: Any) -> List[Any]:
        # Order the keyword arguments as search_fn's parameters, followed by
        # tangents_1, tangents_2, ... for as long as they are present.
        args = []
        for name in argnames_static:
            args.append(kwargs.pop(name))
        for i in range(1, len(kwargs) + 1):
            if f"tangents_{i}" not in kwargs:
                break
            args.append(kwargs.pop(f"tangents_{i}"))
        assert not kwargs, f"leftover kwargs: {kwargs!r}"
        return args

    if trace_fn is joint_fwd_bwd:
        # If inference mode is enabled during compilation, assume that we don't
        # want to match on any training graph patterns
        if torch.is_inference_mode_enabled():
            return False

    # TODO: Revisit the functionalize_rng_ops for lowmem dropout
    with functorch_config.patch(functionalize_rng_ops=False):
        requires_grad: List[bool] = [
            isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
        ]
        if search_fn_pattern is None:
            pattern = gen_pattern(
                search_fn,
                example_inputs,
                trace_fn,
                scalar_workaround,
                exclusive_arg_names,
            )
        else:
            pattern = search_fn_pattern

        # A pattern with the same printed form must not be registered twice.
        pattern_repr = PatternPrettyPrinter.run(pattern)
        assert pattern_repr not in _seen_patterns
        _seen_patterns.add(pattern_repr)
        pattern = ReplacementPatternEntry(
            pattern=pattern,
            extra_check=check_fn,
            normalize_args=normalize_args,
        )
        pattern.register(pass_dicts)
        return pattern.pattern
1380
+
1381
+
1382
# Names of the search functions whose pattern file has already been written
# (in "w" mode) by _serialize_pattern; later patterns for the same name are
# appended to that file.
_serialized_patterns: Set[str] = set()
1383
+
1384
+
1385
def _serialize_pattern(
    unique_name: str,
    search_fn: SearchFn,
    example_inputs: Iterable[Any],
    trace_fn: TraceFn,
    scalar_workaround: Union[Dict[str, Union[float, int]], None],
) -> PatternExpr:
    """
    Generate the pattern for ``search_fn`` and write its python source into
    ``SERIALIZED_PATTERN_PATH / "<search_fn.__name__>.py"``.

    The first pattern written for a given search function name (re)creates
    the file with a header; later ones are appended.  The serialized pattern
    is bound to ``unique_name`` in the file.

    Returns:
        The generated pattern.

    Raises:
        RuntimeError: if SERIALIZED_PATTERN_PATH is not a directory.
    """

    def get_file_template() -> str:
        # Builds the file header: notice, base imports, and an import of every
        # PatternExpr / _TargetExpr subclass found in the pattern_matcher module.
        auto_generated_msg = textwrap.dedent(
            """\
            # This is an auto-generated file. Please do not modify it by hand.
            # To re-generate, run:
            # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
            """
        )

        file_template = textwrap.dedent(
            """\
            # mypy: ignore-errors

            # noqa: F401, E501
            {msg}
            import torch
            import torch._inductor

            aten = torch.ops.aten
            prims = torch.ops.prims

            """
        ).format(msg=auto_generated_msg)

        pattern_matcher_imports = []
        for name in dir(torch._inductor.pattern_matcher):
            attr = getattr(torch._inductor.pattern_matcher, name)
            if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
                pattern_matcher_imports.append(name)

        formatted_imports = ",\n ".join(pattern_matcher_imports)
        formatted_imports = f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n"
        return f"{file_template}{formatted_imports}"

    if not SERIALIZED_PATTERN_PATH.is_dir():
        raise RuntimeError(
            f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}"
        )

    pattern_name = search_fn.__name__

    from torch._functorch import config as functorch_config

    with functorch_config.patch(functionalize_rng_ops=False):
        pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround)

    serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name)
    # First pattern for this search function name: start the file over.
    # Otherwise append to what an earlier call wrote.
    if pattern_name not in _serialized_patterns:
        write_mode = "w"
        _serialized_patterns.add(pattern_name)
    else:
        write_mode = "a"

    file_template = get_file_template()

    with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f:
        if write_mode == "w":
            f.write(file_template)
        else:
            # Blank lines separate this pattern from the previous one.
            f.write("\n\n")
        f.write(serialized_pattern)
        f.write("\n")

    return pattern
1456
+
1457
+
1458
# Directory, next to this file, that holds the generated pattern modules.
SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns"

# This is the list of serialized patterns that we've registered.  Used by
# test_serialized_patterns_up_to_date() to ensure the patterns are up
# to date.
# Each entry is (search_fn, example_inputs, trace_fn, scalar_workaround, pattern),
# as appended by gen_register_replacement.
_known_precompiled_patterns: List[
    Tuple[
        Any,
        Iterable[Any],
        Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule],
        Any,
        PatternExpr,
    ]
] = []
1472
+
1473
+
1474
def gen_register_replacement(
    unique_name: str,
    search_fn: SearchFn,
    replace_fn: ReplaceFn,
    example_inputs: Iterable[Any],
    trace_fn: TraceFn,
    pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
    extra_check: Callable[[Match], bool] = _return_true,
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
    skip_duplicates: bool = False,
) -> None:
    """
    Register a replacement using a serialized (precompiled) search pattern.

    When the ``PYTORCH_GEN_PATTERNS`` environment variable is set the pattern
    is generated and written to disk; otherwise it is read as attribute
    ``unique_name`` of the module
    ``torch._inductor.fx_passes.serialized_patterns.<search_fn.__name__>``.
    The pattern is then passed to ``register_replacement`` as
    ``search_fn_pattern``.

    Args:
        unique_name: name the pattern is stored under in the serialized module.
        skip_duplicates: return without registering when the printed form of
            the pattern has already been seen.
        (remaining arguments are forwarded to ``register_replacement``)
    """
    # Materialize once: the inputs are iterated more than once below.
    example_inputs = tuple(example_inputs)

    if "PYTORCH_GEN_PATTERNS" not in os.environ:
        pattern_name = search_fn.__name__
        module_path = f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}"
        m = importlib.import_module(module_path)
        if not m or not hasattr(m, unique_name):
            log.warning(
                "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.",
                unique_name,
            )
        pat = getattr(m, unique_name)
    else:
        pat = _serialize_pattern(
            unique_name, search_fn, example_inputs, trace_fn, scalar_workaround
        )

    for arg in pytree.tree_iter(example_inputs):
        if not isinstance(arg, FakeTensor):
            continue
        if arg.constant is not None:
            # This can be a problem - small fake tensors (e.g. `tensor(2)`) will
            # hold onto their original constant value - and by stashing it here
            # will cause a memory leak if the constant value is on GPU.
            # Since this is just an optimization we can clear it out.
            arg.constant = None

    already_seen = PatternPrettyPrinter.run(pat) in _seen_patterns
    if already_seen and skip_duplicates:
        return

    known_entry = (search_fn, example_inputs, trace_fn, scalar_workaround, pat)
    _known_precompiled_patterns.append(known_entry)
    register_replacement(
        search_fn,
        replace_fn,
        example_inputs,
        trace_fn,
        pass_dicts,
        extra_check,
        scalar_workaround,
        exclusive_arg_names,
        search_fn_pattern=pat,
    )
1529
+
1530
+
1531
@functorch_config.patch(functionalize_rng_ops=False)
def gen_pattern(
    search_fn: SearchFn,
    example_inputs: Sequence[Any],
    trace_fn: TraceFn,
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
) -> PatternExpr:
    """
    Trace ``search_fn`` and convert the resulting graph into a pattern.

    Parameters of ``search_fn`` that are named in ``scalar_workaround`` take
    their value from that mapping; the remaining parameters consume
    ``example_inputs`` in order.
    """
    argnames = list(inspect.signature(search_fn).parameters)

    if scalar_workaround is None:
        scalar_workaround = {}

    flat_inputs = []
    next_positional = 0  # index of the next unused entry of example_inputs
    for argname in argnames:
        if argname not in scalar_workaround:
            flat_inputs.append(example_inputs[next_positional])
            next_positional += 1
        else:
            flat_inputs.append(scalar_workaround[argname])

    search_gm = trace_fn(search_fn, flat_inputs)
    ignored = (int, float, list, torch.device, torch.dtype)
    return fx_to_pattern(
        search_gm,
        ignore_types=ignored,
        argnames=argnames,
        scalar_workaround=scalar_workaround,
        exclusive_arg_names=exclusive_arg_names,
    )
1561
+
1562
+
1563
def register_lowering_pattern(
    pattern: PatternExpr,
    extra_check: Callable[[Match], bool] = _return_true,
    *,
    pass_dict: _PassDictsType,
    prepend: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Register an aten to inductor IR replacement pattern.  The decorated
    function is saved and then called at lowering time, allowing direct
    pattern to inductor IR conversion.
    """

    def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
        assert callable(handler)
        entry = LoweringPatternEntry(
            pattern=pattern, extra_check=extra_check, handler=handler
        )
        entry.register(pass_dict, prepend=prepend)
        # Flag the handler as a lowering function.
        handler._inductor_lowering_function = True  # type: ignore[attr-defined]
        return handler

    return decorator
1585
+
1586
+
1587
def register_graph_pattern(
    pattern: PatternExpr,
    extra_check: Callable[[Match], bool] = _return_true,
    *,
    pass_dict: _PassDictsType,
    prepend: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Register a pattern whose handler is a function run on the FX graph,
    which allows custom transformation code.
    """

    def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
        assert callable(handler)
        entry = GraphPatternEntry(
            pattern=pattern, extra_check=extra_check, handler=handler
        )
        entry.register(pass_dict, prepend=prepend)
        return handler

    return decorator
1607
+
1608
+
1609
def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
    """Return True when ``node`` is the very first node of ``graph``."""
    first_node = next(iter(graph.nodes))
    return first_node is node
1612
+
1613
+
1614
# match: copy_, relu_, _set_grad_enabled, manual_seed, _enter_autocast, etc
# doesn't match: __rshift__, etc
_mutation_op_re = re.compile(r"(?<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)")


def is_mutation_op(node: torch.fx.Node) -> bool:
    """
    Tell whether ``node`` looks like a mutating operation.

    A ``call_function`` node is judged by its target's ``__name__`` and a
    ``call_method`` node by its target string, both against
    ``_mutation_op_re``.  Any node that carries a non-None ``out`` keyword
    argument also counts as mutating.
    """
    if node.op == "call_function":
        name_hit = _mutation_op_re.search(node.target.__name__)  # type: ignore[union-attr]
    elif node.op == "call_method":
        name_hit = _mutation_op_re.search(node.target)  # type: ignore[union-attr, arg-type]
    else:
        name_hit = None

    if name_hit:
        return True
    return node.kwargs.get("out") is not None
1627
+
1628
+
1629
def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool:
    """
    Return True when both nodes carry the same ``mutation_region_id``.

    Both nodes must already have the id in their ``meta``.
    """
    key = "mutation_region_id"
    assert key in a.meta
    assert key in b.meta
    region_a = a.meta[key]
    region_b = b.meta[key]
    return region_a == region_b
1633
+
1634
+
1635
def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
    """
    Return the ``mutation_region_id`` of ``node``, computing it on demand.

    Walks backwards from ``node`` to the nearest node that already has an id
    (or to the start of the graph, where the id defaults to 0), then walks
    forward again, bumping the id at every mutation op and storing it in the
    ``meta`` of each node passed on the way.
    """
    cursor = node
    # Backward: find the closest already-labelled node, or the graph start.
    while "mutation_region_id" not in cursor.meta and not is_start_of_fx_graph(
        graph, cursor
    ):
        cursor = cursor.prev

    region_id = cursor.meta.get("mutation_region_id", 0)

    # Forward: label every node up to and including `node`.
    while cursor is not node:
        cursor = cursor.next
        if is_mutation_op(cursor):
            region_id += 1
        cursor.meta["mutation_region_id"] = region_id
    return region_id
1646
+
1647
+
1648
def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
    """
    Report whether ``graph`` still needs ``mutation_region_id`` metadata.

    Only the first node is inspected: the ids are considered missing when
    that node's ``meta`` has no ``mutation_region_id`` entry.

    Args:
        graph: object exposing an iterable ``nodes`` attribute.

    Returns:
        True when the first node lacks the id.  A graph without any nodes
        also yields True (nothing has been labelled) instead of letting a
        bare ``StopIteration`` escape from ``next()``.
    """
    first_node = next(iter(graph.nodes), None)
    if first_node is None:
        # Empty graph: computing ids over it is a harmless no-op.
        return True
    return "mutation_region_id" not in first_node.meta
1650
+
1651
+
1652
def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
    """
    Store a ``mutation_region_id`` in the ``meta`` of every node of ``graph``.

    The id starts at 0 and goes up by one at each mutation op, so the
    mutation op itself is the first node of its new region.
    """
    current_region = 0
    for graph_node in graph.nodes:
        if is_mutation_op(graph_node):
            current_region += 1
        graph_node.meta["mutation_region_id"] = current_region
1658
+
1659
+
1660
class PatternMatcherPass:
    """
    A collection of pattern entries, keyed by ``(op, target)``, that can be
    applied to an FX graph in one pass.
    """

    def __init__(
        self,
        pass_name: Optional[str] = None,
    ) -> None:
        super().__init__()
        # (op, target) -> entries to try on nodes with that op and target.
        self.patterns: DefaultDict[
            Tuple[str, torch.fx.node.Target], List[PatternEntry]
        ] = defaultdict(list)
        # Optional name handed to the GraphTransformObserver in `apply`.
        self.pass_name = pass_name

    def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
        """Return the entry list for ``item`` (created empty if missing)."""
        return self.patterns[item]

    def apply(self, gm: torch.fx.GraphModule) -> int:
        """
        Try the registered patterns on the nodes of ``gm`` and apply every
        accepted match.

        Args:
            gm: a GraphModule, or a Graph (its owning module is then used
                for the transform observer).

        Returns:
            The number of matches that were applied.

        Raises:
            RuntimeError: if ``gm`` is neither a GraphModule nor a Graph.
        """
        if not self.patterns:
            return 0
        if isinstance(gm, torch.fx.GraphModule):
            graph = gm.graph
        elif isinstance(gm, torch.fx.Graph):
            graph = gm
            gm = graph.owning_module
        else:
            raise RuntimeError(
                f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
            )
        if should_compute_mutation_region_ids(graph):  # type: ignore[arg-type]
            compute_mutation_region_ids(graph)  # type: ignore[arg-type]
        get_mutation_region_id_partial = functools.partial(
            get_mutation_region_id, graph
        )
        count = 0
        nodes = []
        has_call_module = False
        # Gather candidate nodes per registered (op, target) key.  call_module
        # nodes are fetched in a single query and filtered by target below.
        for op, target in self.patterns:
            if op == "call_module":
                has_call_module = True
            else:
                nodes.append(graph.find_nodes(op=op, target=target, sort=False))
        if has_call_module:
            nodes.append(graph.find_nodes(op="call_module", sort=False))
        pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
        with GraphTransformObserver(
            gm, pass_name, trace_config.log_url_for_graph_xform
        ):
            # Candidates are visited in descending sorted order.
            for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
                target = extract_target(node)
                if node.op == "call_module":
                    if (node.op, target) not in self.patterns:
                        continue

                # conservatively not applying pattern for cpu input,
                # since some of the patterns induce codegen and split nodes.
                # Note: we will only skip cpu compute if disable_cpp_codegen=True
                if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False):
                    continue

                for entry in self.patterns[(node.op, target)]:
                    # An earlier entry may already have erased this node.
                    if node._erased:
                        break
                    m = entry.pattern.match(node)
                    # pattern match crosses mutation barrier - discard
                    if (
                        is_match(m)
                        and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1  # type: ignore[possibly-undefined]
                    ):
                        continue
                    # Debug aid: log the attempt for the node named in the
                    # TORCHINDUCTOR_PATTERN_MATCH_DEBUG environment variable.
                    if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
                        log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
                    if is_match(m) and entry.extra_check(m):
                        count += 1
                        entry.apply(m, graph, node)  # type: ignore[arg-type]
                        counters["inductor"]["pattern_matcher_count"] += 1
                        counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
        return count

    def clear(self) -> None:
        """Remove every registered pattern."""
        self.patterns.clear()
1738
+
1739
+
1740
def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn:
    """Accept any arguments and always raise ``NotImplementedError``."""
    raise NotImplementedError
1742
+
1743
+
1744
def fx_to_pattern(
    gm: Union[torch.fx.GraphModule, torch.fx.Graph],
    ignore_types: Sequence[Type[Any]] = (),
    argnames: Sequence[str] = (),
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
) -> PatternExpr:
    """
    Convert an FX graph into a PatternExpr.  This is useful for simple
    patterns that can only match single functions and fixed-length lists.

    Args:
        gm: graph to convert; only placeholder, call_function and output
            nodes are supported.
        ignore_types: argument values whose exact type is listed here become
            ``Ignored()``.
        argnames: names given to the placeholders, in order.
        scalar_workaround: maps argument names to scalar values; a matching
            scalar found in the graph becomes ``KeywordArg(name)``.
        exclusive_arg_names: placeholder names that become
            ``ExclusiveKeywordArg`` instead of ``KeywordArg``.

    Returns:
        The converted pattern; a ``MultiOutputPattern`` when the graph's
        result is not itself a PatternExpr.
    """
    # scalar_workaround is a hack to capture dropout_p
    # see https://github.com/pytorch/pytorch/issues/97894
    scalar_workaround = scalar_workaround or {}
    inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()}
    # The values must be unique, otherwise the inversion would lose names.
    assert len(inv_scalar_workaround) == len(scalar_workaround)

    def process_arg(x: T) -> Union[T, KeywordArg, Ignored]:
        if isinstance(x, (float, int)) and x in inv_scalar_workaround:
            return KeywordArg(inv_scalar_workaround[x])
        if type(x) in ignore_types:
            return Ignored()
        # A non-empty list made only of Ignored() collapses to one Ignored().
        if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x:
            return Ignored()
        return x

    # Running count of the placeholders seen so far.
    argnum = itertools.count()

    class Converter(torch.fx.Interpreter):
        # Any of these node kinds raises NotImplementedError.
        call_method = _not_implemented
        call_module = _not_implemented
        get_attr = _not_implemented

        def placeholder(
            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
        ) -> Union[ExclusiveKeywordArg, KeywordArg]:
            n = next(argnum)
            if n < len(argnames):
                name = argnames[n]
            elif argnames:
                # Placeholders beyond the named ones must be tangents.
                assert target.startswith("tangent")
                name = target
            else:
                target = re.sub(r"_\d+$", "", target)  # de-mangle arg name
                name = target
            if name in exclusive_arg_names:
                return ExclusiveKeywordArg(name)
            else:
                return KeywordArg(name)

        def call_function(
            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
        ) -> PatternExpr:
            args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
            if list in ignore_types:
                # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...]
                args = [process_arg(a) for a in args]
                kwargs = {k: process_arg(a) for k, a in kwargs.items()}
            return CallFunction(target, *args, **kwargs)

        def run_node(self, n: torch.fx.Node) -> Any:
            rv = super().run_node(n)
            # Record the user count of the graph node on the converted value;
            # for a tuple output, each element gets the count of its own node.
            if n.op == "output" and isinstance(rv, tuple):
                assert len(rv) == len(n.args[0])  # type: ignore[arg-type]
                for r, arg in zip(rv, n.args[0]):  # type: ignore[arg-type]
                    r.users = len(arg.users)
            else:
                rv.users = len(n.users)
            return rv

    pattern = Converter(gm).run()  # type: ignore[arg-type]
    if not isinstance(pattern, PatternExpr):
        return MultiOutputPattern(pytree.tree_leaves(pattern))
    return pattern
1818
+
1819
+
1820
@torch.no_grad()
def fwd_only(
    fn: Callable[..., Any],
    args: Sequence[Any],
    *,
    run_functional_passes: bool = True,
    get_decomp_fn: Optional[Callable[..., Any]] = None,
) -> torch.fx.GraphModule:
    """
    Build a normalized inference graph, for use with fx_to_pattern.

    Args:
        fn: function to trace.
        args: example arguments used for the trace.
        run_functional_passes: when true, no-op ops are removed and dead code
            is eliminated after tracing.
        get_decomp_fn: optional factory for the decomposition table; the
            result of ``select_decomp_table()`` is used when it is omitted.
    """
    # TODO - look into using aot autograd, asserting no mutating ops here
    with enable_python_dispatcher():
        if get_decomp_fn is not None:
            decompositions = get_decomp_fn()
        else:
            decompositions = select_decomp_table()
        tracer = make_fx(fn, decompositions, tracing_mode="real")
        gm = tracer(*args)

    from .fx_passes.post_grad import remove_noop_ops

    if run_functional_passes:
        traced_graph = gm.graph
        remove_noop_ops(traced_graph)
        traced_graph.eliminate_dead_code()

    gm.recompile()
    return gm
1844
+
1845
+
1846
@torch.enable_grad()
def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule:
    """
    Build a normalized training graph, for use with fx_to_pattern.

    Runs ``fn`` through ``aot_function`` and keeps a clone of the joint graph
    handed to the partitioner, then cleans that clone up (no-op removal, the
    ``pointless_view`` pass, dead code elimination) and returns it.
    """
    gm: Optional[torch.fx.GraphModule] = None

    def record_joint_graph(
        joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any
    ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        # Partition hook: stash a clone of the joint graph (at most once),
        # then defer to the default partitioner.
        nonlocal gm
        assert not gm
        gm = clone_graph(joint_graph)
        return default_partition(joint_graph, inputs, **kwargs)

    with torch._guards.tracing(None):
        aot_function(
            fn,
            lambda g, i: make_boxed_func(g),
            partition_fn=record_joint_graph,
            decompositions=select_decomp_table(),
            keep_inference_input_mutations=True,
            enable_log=False,
        )(*args)
    # The partition hook must have run and captured the graph.
    assert gm

    from .fx_passes.post_grad import remove_noop_ops

    remove_noop_ops(gm.graph)

    from .fx_passes.joint_graph import pointless_view

    matcher_pass = PatternMatcherPass()

    # Run `pointless_view` on every aten.view.default call in the graph.
    pattern = CallFunction(
        torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
    )
    GraphPatternEntry(
        pattern=pattern, handler=pointless_view, extra_check=_return_true
    ).register(matcher_pass.patterns)
    matcher_pass.apply(gm.graph)  # type: ignore[arg-type]

    # remove in/out specs
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
1891
+
1892
+
1893
+ def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
1894
+ args: List[torch.fx.node.Argument] = []
1895
+ torch.fx.map_arg((n.args, n.kwargs), args.append)
1896
+ return args
1897
+
1898
+
1899
def stable_topological_sort(graph: torch.fx.Graph) -> None:
    """Reorder ``graph`` in place so every node follows the nodes it uses.

    Nodes are visited in their current order; a node whose inputs have not
    all been placed yet is parked until its last missing input is placed.
    """
    # Every node lives in exactly one of the three collections below.

    # `pending`: nodes still to be examined, stored reversed so pop() yields
    # them in graph order.
    pending = list(reversed(graph.nodes))

    # `ready`: nodes already examined and sitting in their final position.
    ready = set()

    # `waiting`: maps a dependency to the nodes parked until it is ready.
    waiting = defaultdict(list)

    # `cursor`: the most recently placed node; new nodes go right after it.
    cursor = None
    while pending:
        node = pending.pop()
        unmet = [dep for dep in _args(node) if dep not in ready]
        if unmet:
            # Park on the *last* unmet input so that an already sorted graph
            # re-examines this node only once.
            waiting[unmet[-1]].append(node)
            continue

        ready.add(node)
        if cursor and cursor.next is not node:
            cursor.append(node)
        cursor = node
        # Nodes parked on this one can be examined again, in parking order.
        pending.extend(reversed(waiting.pop(node, ())))

    assert not waiting and len(ready) == len(graph.nodes)
1933
+
1934
+
1935
def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
    """Wrapper around lazy init functions in fx_passes/

    The returned zero-argument callable is memoised with ``lru_cache``.  It
    calls ``fn`` with tracing context ``None``, fake mode temporarily unset,
    and a fresh ``FakeTensorMode`` active.  ``counters["inductor"]`` is
    snapshotted beforehand and put back afterwards.
    """

    @functools.lru_cache(None)
    @functools.wraps(fn)
    def lazy_init() -> Any:
        saved_counters = counters["inductor"].copy()

        with torch._guards.tracing(None):
            with unset_fake_temporarily():
                with FakeTensorMode():
                    value = fn()

        # clear view matches encountered during tracing
        counters["inductor"] = saved_counters

        return value

    return lazy_init
1952
+
1953
+
1954
def config_flag(name: str) -> Callable[[Match], Any]:
    """Function for extra_check to put pass behind a flag

    Args:
        name: Attribute name looked up on ``config`` each time the returned
            check is called.

    Returns:
        A callable that ignores its ``match`` argument and returns the
        current value of ``config.<name>``.
    """

    def flag_check(match: Match) -> Any:
        # The flag is read at call time, not captured when config_flag runs.
        current_value = getattr(config, name)
        return current_value

    return flag_check
1961
+
1962
+
1963
def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Copy ``input_graph`` with a ``Transformer``, keeping node meta and names.

    For every node whose transformed result is a ``torch.fx.Proxy``, the old
    node's ``meta`` is merged into the new node and the new node is named
    via the new graph's namespace from the old node's name.
    """

    class CopyGraph(Transformer):
        def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node:
            new_node = super().run_node(old_node)
            if not isinstance(new_node, torch.fx.Proxy):
                return new_node

            copied = new_node.node
            copied.meta.update(old_node.meta)
            copied.name = self.new_graph._graph_namespace.create_name(
                old_node.name, None
            )
            return new_node

    return CopyGraph(input_graph).transform()
1975
+
1976
+
1977
+ _seen_patterns: Set[str] = set()
1978
+
1979
+
1980
def get_arg_value(
    node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
) -> Any:
    """Fetch an argument of ``node`` by position, falling back to a keyword.

    Args:
        node: Node whose ``args`` / ``kwargs`` are inspected.
        arg_number: Positional index to read when ``node.args`` is long
            enough.
        kwarg_name: Key looked up in ``node.kwargs`` otherwise.

    Returns:
        The positional value if present, else ``node.kwargs.get(kwarg_name)``
        (``None`` when the key is absent).
    """
    if len(node.args) > arg_number:
        return node.args[arg_number]
    return node.kwargs.get(kwarg_name)  # type: ignore[arg-type]
1988
+
1989
+
1990
def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]:
    """Return the nodes whose ``target`` is ``fn`` or one of its overloads.

    When ``fn`` is an ``OpOverloadPacket``, every overload listed by
    ``fn.overloads()`` is accepted in addition to the packet itself.
    """
    accepted = [fn]
    if isinstance(fn, torch._ops.OpOverloadPacket):
        for overload_name in fn.overloads():
            accepted.append(getattr(fn, overload_name))

    matches: List[torch.fx.Node] = []
    for candidate in nodes:
        if candidate.target in accepted:
            matches.append(candidate)
    return matches
1996
+
1997
+
1998
def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
    """For call_function and call_method, we directly use the target function;
    For call_module, the target is string, and we treat the module class
    as a function.

    A ``call_module`` target may be a dotted qualified name (for example
    ``"block.conv"``), so it is resolved one attribute at a time from the
    graph's owning module; a single ``getattr`` would fail on such names.

    Args:
        node: The node to inspect.

    Returns:
        The class of the referenced submodule for ``call_module`` nodes,
        otherwise ``node.target`` unchanged.

    Raises:
        AttributeError: If any component of a ``call_module`` target cannot
            be found on the owning module.
    """
    if node.op == "call_module":
        submodule = functools.reduce(
            getattr,
            node.target.split("."),  # type: ignore[union-attr]
            node.graph.owning_module,
        )
        return submodule.__class__
    return node.target
.venv/lib/python3.11/site-packages/torch/_inductor/quantized_lowerings.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+
4
+ import torch
5
+ from torch._inductor.kernel.mm_common import mm_args
6
+
7
+ from . import config as inductor_config, lowering
8
+ from .codegen.cpp_gemm_template import CppPackedGemmTemplate
9
+ from .codegen.cpp_utils import create_epilogue_with_attr
10
+ from .lowering import expand, register_lowering
11
+ from .select_algorithm import (
12
+ autotune_select_algorithm,
13
+ ExternKernelChoice,
14
+ realize_inputs,
15
+ )
16
+ from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template
17
+
18
+
19
+ log = logging.getLogger(__name__)
20
+
21
+ aten__weight_int8pack_mm = ExternKernelChoice(
22
+ torch._weight_int8pack_mm, "at::_weight_int8pack_mm", has_out_variant=False
23
+ )
24
+
25
+
26
+ quantized = torch.ops.quantized
27
+ _quantized = torch.ops._quantized
28
+ aten = torch.ops.aten
29
+
30
+
31
def register_quantized_ops():
    """Register the quantized ops that are lowered as fallbacks.

    The three ops are first registered through
    ``lowering.add_needs_realized_inputs`` and then each one is handed to
    ``lowering.make_fallback``.
    """
    fallback_ops = [
        quantized.max_pool2d,
        _quantized.wrapped_fbgemm_pack_gemm_matrix_fp16,
        _quantized.wrapped_fbgemm_linear_fp16_weight,
    ]

    lowering.add_needs_realized_inputs(fallback_ops)

    for op in fallback_ops:
        lowering.make_fallback(op)
43
+
44
+
45
def register_woq_mm_ops():
    """Register the lowering for ``aten._weight_int8pack_mm``."""

    @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None)
    def int8pack_mm(input, weight, scale, *, layout=None):
        # Candidates are the ATen extern kernel (when ATen GEMM kernels are
        # enabled) and the C++ packed GEMM template (when applicable); the
        # winner is picked by autotune_select_algorithm.
        _, _, _, layout, mat1, mat2 = mm_args(
            input, weight, layout=layout, mat2_transposed=True
        )
        assert (
            mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float]
            and mat2.get_dtype() == torch.int8
        )
        aten_layout = layout

        # options to tune from
        if use_aten_gemm_kernels():
            choices = [aten__weight_int8pack_mm.bind((mat1, mat2, scale), aten_layout)]
        else:
            choices = []

        # scale is applied as an epilogue, and the scale tensor is expanded (with a view op)
        # for broadcasting, as it's 1D.
        def _mul_epilogue(buf):
            broadcast_scale = realize_inputs(expand(scale, layout.size))
            return create_epilogue_with_attr(buf, "mul", other=broadcast_scale)

        if use_cpp_packed_gemm_template(aten_layout, mat1, mat2, mat2_transposed=True):
            CppPackedGemmTemplate.add_choices(
                choices,
                aten_layout,
                [mat1, mat2, scale],
                trans_w=True,
                epilogue_creator=_mul_epilogue,
            )

        # Nothing to tune: optionally fall back to the ATen kernel directly.
        if (
            not choices
            and inductor_config.autotune_fallback_to_aten
            and not use_aten_gemm_kernels()
        ):
            log.warning("No choices for GEMM, using ATen backend as fallback")
            fallback = aten__weight_int8pack_mm.bind((mat1, mat2, scale), aten_layout)
            return fallback.output_node()

        return autotune_select_algorithm(
            "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout
        )
.venv/lib/python3.11/site-packages/torch/_inductor/remote_cache.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import typing
6
+ from abc import abstractmethod
7
+ from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
8
+ from typing_extensions import override, TypeAlias
9
+
10
+ from torch._inductor import config
11
+
12
+
13
+ try:
14
+ import redis
15
+ except ImportError:
16
+ redis = None # type: ignore[assignment]
17
+
18
+
19
+ if config.is_fbcode():
20
+ from rfe.scubadata.scubadata_py3 import ( # type: ignore[import-not-found]
21
+ Sample as Sample_,
22
+ )
23
+
24
+ Sample: TypeAlias = Sample_
25
+ else:
26
+ Sample: TypeAlias = Type[object] # type: ignore[misc,no-redef]
27
+
28
+
29
+ _T = TypeVar("_T")
30
+ _U = TypeVar("_U")
31
+
32
+
33
class RemoteCacheBackend(Generic[_T]):
    """
    A backend implementation for accessing a remote/distributed cache.  Only
    works with bytes in/out.  For structured data use a RemoteCache.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[_T]:
        # Subclasses fetch the value stored under `key`; this base body
        # does nothing and returns None.
        return None

    @abstractmethod
    def put(self, key: str, data: _T) -> None:
        # Subclasses store `data` under `key`; this base body does nothing.
        return None
46
+
47
+
48
# Serde that encodes from _T to _U and decodes from _U to _T.
class RemoteCacheSerde(Generic[_T, _U]):
    @abstractmethod
    def encode(self, data: _T) -> _U:
        # Subclasses convert a _T value to its _U form; this base body
        # does nothing and returns None.
        return None  # type: ignore[return-value]

    @abstractmethod
    def decode(self, data: _U) -> _T:
        # Subclasses convert a _U value back to _T; this base body does
        # nothing and returns None.
        return None  # type: ignore[return-value]
57
+
58
+
59
+ JsonDataTy = Optional[
60
+ Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
61
+ ]
62
+
63
+
64
class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]):
    """Serde that stores JSON-compatible data as ASCII-encoded JSON bytes."""

    def encode(self, data: JsonDataTy) -> bytes:
        # json.dumps escapes non-ASCII characters by default, so the ASCII
        # encoding step cannot fail on its output.
        text = json.dumps(data)
        return text.encode("ascii")

    def decode(self, data: bytes) -> JsonDataTy:
        parsed = json.loads(data)
        return parsed
70
+
71
+
72
class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]):
    """Serde that hands values through unchanged in both directions."""

    def encode(self, data: _T) -> _T:
        # No conversion: the very same object is returned.
        return data

    def decode(self, data: _T) -> _T:
        # No conversion: the very same object is returned.
        return data
78
+
79
+
80
class RemoteCache(Generic[_T]):
    """Structured cache layered on a ``RemoteCacheBackend`` and a serde.

    Values are encoded with the serde before being handed to the backend and
    decoded on the way back.  ``get`` / ``put`` wrap the private ``_get`` /
    ``_put`` with the ``_create_sample`` / ``_log_sample`` hooks, which do
    nothing in this base class.
    """

    # Zero-argument factory; when set it replaces the backend passed to
    # __init__.
    backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None

    def __init__(
        self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U]
    ) -> None:
        # Support for testing.
        override_cls = self.__class__.backend_override_cls
        if override_cls is None:
            self.backend = backend
        else:
            self.backend = override_cls()
        self.serde = serde

    def get(self, key: str) -> Optional[_T]:
        """Return the decoded value for ``key``, or None when nothing is found."""
        sample = self._create_sample()
        found = self._get(key, sample)
        self._log_sample(sample)
        return found

    def put(self, key: str, value: _T) -> None:
        """Encode ``value`` and store it under ``key``."""
        sample = self._create_sample()
        self._put(key, value, sample)
        self._log_sample(sample)

    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:
        decoded = self.serde.decode(data)
        return decoded

    def _encode(self, value: _T, sample: Optional[Sample]) -> Any:  # returns _U
        encoded = self.serde.encode(value)
        return encoded

    def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]:
        # Any falsy backend result (None, empty payload) counts as a miss.
        data = self.backend.get(key)
        if not data:
            return None
        return self._decode(data, sample)

    def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None:
        self.backend.put(key, self._encode(value, sample))

    def _create_sample(self) -> Optional[Sample]:
        # No sample is produced by the base class.
        return None

    def _log_sample(self, sample: Optional[Sample]) -> None:
        # Nothing is logged by the base class.
        return None
124
+
125
+
126
class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
    """
    A Redis implementation of a remote/distributed cache.

    Keys are stored as ``pt2:<cache_id>:<key>``.  The server is taken from
    the ``TORCHINDUCTOR_REDIS_HOST`` / ``TORCHINDUCTOR_REDIS_PORT``
    environment variables (defaults ``localhost`` / ``6379``).  If the
    ``redis`` module is unavailable, or a connection error occurs, the
    backend disables itself: ``get`` returns None and ``put`` does nothing.
    """

    # str.format template with a single ``{key}`` field; only set when the
    # redis module is available.
    _key_fmt: str
    # Client handle; None means the backend is disabled.
    _redis: Optional[redis.Redis] = None

    def __init__(self, cache_id: str) -> None:
        if not redis:
            # We had trouble importing redis - just skip init.
            return

        # Escape literal braces so a cache_id containing "{" or "}" is not
        # misread as a replacement field when the template is formatted.
        escaped_id = cache_id.replace("{", "{{").replace("}", "}}")
        self._key_fmt = f"pt2:{escaped_id}:{{key}}"
        self._redis = redis.Redis(
            host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)),
        )

    def __get_key(self, key: str) -> str:
        # Build the namespaced Redis key for `key`.
        return self._key_fmt.format(key=key)

    @override
    def get(self, key: str) -> Optional[bytes]:
        """Return the bytes stored for ``key``, or None on a miss or failure."""
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return None

        try:
            value = self._redis.get(self.__get_key(key))
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None
            return None

        # In theory redis.get() can return an Awaitable as well...
        assert value is None or isinstance(value, bytes)
        return value

    @override
    def put(self, key: str, data: bytes) -> None:
        """Store ``data`` under ``key``; silently give up on connection errors."""
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return

        try:
            self._redis.set(self.__get_key(key), data)
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None
178
+
179
+
180
class RedisRemoteCache(RemoteCache[JsonDataTy]):
    """``RemoteCache`` of JSON data backed by ``RedisRemoteCacheBackend``."""

    def __init__(self, key: str) -> None:
        # Special test handling: If we're just going to override the backend
        # anyway don't require redis
        backend = (
            # This is totally bogus but it works for now...
            typing.cast(RemoteCacheBackend[bytes], None)
            if self.__class__.backend_override_cls
            else RedisRemoteCacheBackend(key)
        )
        super().__init__(backend, RemoteCacheJsonSerde())
191
+
192
+
193
class RemoteAutotuneCache(RedisRemoteCache):
    # Adds nothing to RedisRemoteCache; it only provides a distinct type.
    ...
195
+
196
+
197
class RemoteFxGraphCache(RedisRemoteCache):
    # Adds nothing to RedisRemoteCache; it only provides a distinct type.
    ...
.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py ADDED
@@ -0,0 +1,1743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import builtins
3
+ import contextlib
4
+ import functools
5
+ import inspect
6
+ import itertools
7
+ import json
8
+ import logging
9
+ import math
10
+ import operator
11
+ import os
12
+ import sys
13
+ import textwrap
14
+ import time
15
+ from collections import namedtuple
16
+ from concurrent.futures import as_completed, ThreadPoolExecutor
17
+ from io import StringIO
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+ from unittest.mock import patch
20
+
21
+ import sympy
22
+ from filelock import FileLock
23
+
24
+ import torch
25
+ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
26
+ from torch._dynamo.testing import rand_strided
27
+ from torch._dynamo.utils import counters, identity, preserve_rng_state
28
+
29
+ from . import config, ir
30
+ from .autotune_process import TensorMeta, TritonBenchmarkRequest
31
+ from .codecache import code_hash, PersistentCache, PyCodeCache
32
+ from .codegen.common import IndentedBuffer, KernelTemplate
33
+ from .codegen.triton import (
34
+ gen_common_triton_imports,
35
+ texpr,
36
+ TritonKernel,
37
+ TritonPrinter,
38
+ TritonScheduling,
39
+ )
40
+ from .codegen.triton_utils import config_of, signature_to_meta
41
+ from .exc import CUDACompileError
42
+ from .ir import ChoiceCaller, PrimitiveInfoType
43
+ from .runtime.benchmarking import benchmarker
44
+ from .runtime.hints import DeviceProperties
45
+ from .utils import (
46
+ FakeIndentedBuffer,
47
+ get_dtype_size,
48
+ Placeholder,
49
+ restore_stdout_stderr,
50
+ sympy_dot,
51
+ sympy_index_symbol,
52
+ sympy_product,
53
+ unique,
54
+ )
55
+ from .virtualized import V
56
+
57
+
58
+ log = logging.getLogger(__name__)
59
+
60
+ # correctness checks struggle with fp16/tf32
61
+ VERIFY: Dict[str, Any] = {}
62
+ PRINT_AUTOTUNE = True
63
+ DEBUG = False
64
+
65
+
66
class KernelNamespace:
    # Empty attribute container; instances are plain namespaces.
    ...
68
+
69
+
70
+ # these objects are imported from the generated wrapper code
71
+ extern_kernels = KernelNamespace()
72
+
73
+
74
class PartialRender:
    """
    Some parts of a template need to be generated at the end, but
    inserted into the template at the start.  This allows doing a bunch
    of replacements after the initial render.

    ``replacement_hooks`` maps a placeholder string found in ``code`` to a
    zero-argument callable producing its replacement text.  A hook consumed
    through ``finalize_hook`` is stored back as ``None``.
    """

    def __init__(self, code, replacement_hooks) -> None:
        super().__init__()
        self.code = code
        self.replacement_hooks = replacement_hooks

    def finalize_hook(self, hook_key: str, strict=True) -> None:
        """Replace one placeholder in ``self.code`` and mark its hook consumed.

        Args:
            hook_key: Placeholder whose hook should run.
            strict: When true, an unregistered ``hook_key`` raises; when
                false it is silently ignored.

        Raises:
            RuntimeError: If ``hook_key`` is not registered and ``strict``.
            AssertionError: If the hook was already consumed.
        """
        if hook_key not in self.replacement_hooks:
            if strict:
                raise RuntimeError(
                    f"{hook_key} not registered in self.replacement_hooks"
                )
            return
        hook = self.replacement_hooks[hook_key]
        assert hook is not None, "hook_key can only be called once"
        self.code = self.code.replace(hook_key, hook())
        self.replacement_hooks[hook_key] = None

    def finalize_all(self) -> str:
        """Run every hook that has not been consumed and return the code.

        Hooks already consumed by ``finalize_hook`` are stored as ``None``
        and are skipped here; calling them would raise ``TypeError``.
        """
        for key, fn in self.replacement_hooks.items():
            if fn is None:
                continue
            self.code = self.code.replace(key, fn())
        return self.code
104
+
105
+
106
+ # This is used to store info needed for lowering each subgraph in triton
107
+ # templates
108
+ SubgraphInfo = namedtuple(
109
+ "SubgraphInfo",
110
+ [
111
+ "body",
112
+ "template_mask",
113
+ "template_out",
114
+ ],
115
+ )
116
+
117
+
118
+ class TritonTemplateKernel(TritonKernel):
119
+ def __init__(
120
+ self,
121
+ kernel_name,
122
+ input_nodes,
123
+ output_node,
124
+ defines,
125
+ num_stages,
126
+ num_warps,
127
+ grid_fn,
128
+ meta,
129
+ call_sizes,
130
+ use_jit=False,
131
+ prefix_args=0,
132
+ suffix_args=0,
133
+ epilogue_fn=identity,
134
+ subgraphs: Optional[List[ir.ComputedBuffer]] = None,
135
+ *,
136
+ index_dtype,
137
+ ) -> None:
138
+ super().__init__(
139
+ sympy_product(output_node.get_size()),
140
+ sympy.Integer(1),
141
+ index_dtype=index_dtype,
142
+ )
143
+ self.input_nodes = input_nodes
144
+ self.output_node = output_node
145
+ self.named_input_nodes = {} # type: ignore[var-annotated]
146
+ self.defines = defines
147
+ self.kernel_name = kernel_name
148
+ self.use_jit = use_jit
149
+ self.num_stages = num_stages
150
+ self.num_warps = num_warps
151
+ self.grid_fn = grid_fn
152
+ self.meta = meta
153
+ self.call_sizes = call_sizes
154
+ # for templates with fixed epilogues
155
+ self.prefix_args = prefix_args
156
+ self.suffix_args = suffix_args
157
+ self.epilogue_fn = epilogue_fn
158
+ self.render_hooks = {} # type: ignore[var-annotated]
159
+ self.triton_meta: Optional[Dict[str, object]] = None
160
+ # For Templated Attention this can be a list of ir.Subgraph
161
+ self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs
162
+
163
+ # The following attributes (body, template_mask, output_val) are all
164
+ # used for triton kernel codegen.
165
+ # They are swapped onto the TritonTemplateKernel object by
166
+ # `set_subgraph_body`
167
+ self.subgraph_bodies: Dict[str, SubgraphInfo] = {}
168
+
169
+ self.body: IndentedBuffer = FakeIndentedBuffer()
170
+ self.template_mask: Optional[str] = None
171
+ self.template_out: Optional[str] = None
172
+
173
+ @contextlib.contextmanager
174
+ def set_subgraph_body(self, body_name: str):
175
+ old_body, old_mask, old_out = self.body, self.template_mask, self.template_out
176
+ assert body_name in self.subgraph_bodies, body_name
177
+ self.body, self.template_mask, self.template_out = self.subgraph_bodies[
178
+ body_name
179
+ ]
180
+ yield
181
+ self.subgraph_bodies[body_name] = SubgraphInfo(
182
+ self.body, self.template_mask, self.template_out
183
+ )
184
+ self.body, self.template_mask, self.template_out = old_body, old_mask, old_out
185
+
186
+ @contextlib.contextmanager
187
+ def create_subgraph_body(self, body_name: str):
188
+ assert body_name not in self.subgraph_bodies
189
+ self.subgraph_bodies[body_name] = SubgraphInfo(IndentedBuffer(), None, None)
190
+ with self.set_subgraph_body(body_name):
191
+ yield
192
+
193
+ def need_numel_args(self):
194
+ return False
195
+
196
+ def estimate_kernel_num_bytes(self):
197
+ """
198
+ Estimate the total number of bytes this kernel takes.
199
+ For in/out nodes, sizes are counted twice: once for reading and
200
+ once for writing.
201
+ """
202
+ ninplace_args = len(unique(self.args.inplace_buffers.values()))
203
+ num_bytes = []
204
+ for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
205
+ size = V.graph.sizevars.size_hints(inp.get_size())
206
+ numel = functools.reduce(operator.mul, size, 1)
207
+ dtype_size = get_dtype_size(inp.get_dtype())
208
+ num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
209
+ return sum(num_bytes)
210
+
211
+ def jit_lines(self):
212
+ if self.use_jit:
213
+ return "@triton.jit"
214
+
215
+ argdefs, _, signature, _ = self.args.python_argdefs()
216
+ triton_meta = {
217
+ "signature": signature_to_meta(signature, size_dtype=self.index_dtype),
218
+ "device": DeviceProperties.create(self.output_node.get_device()),
219
+ "constants": {},
220
+ }
221
+ triton_meta["configs"] = [config_of(signature)]
222
+ for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index]
223
+ triton_meta["constants"][arg_num] = 1 # type: ignore[index]
224
+ matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0)
225
+ if matrix_instr_nonkdim != 0:
226
+ triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim
227
+
228
+ self.triton_meta = triton_meta
229
+
230
+ inductor_meta = {
231
+ "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
232
+ **TritonKernel.inductor_meta_common(),
233
+ }
234
+ if config.profile_bandwidth or config.benchmark_kernel:
235
+ num_gb = self.estimate_kernel_num_bytes() / 1e9
236
+ inductor_meta["kernel_num_gb"] = num_gb
237
+ return f"""
238
+ @triton_heuristics.template(
239
+ num_stages={self.num_stages},
240
+ num_warps={self.num_warps},
241
+ triton_meta={triton_meta!r},
242
+ inductor_meta={inductor_meta!r},
243
+ )
244
+ @triton.jit
245
+ """
246
+
247
+ def gen_argdefs(self):
248
+ def hook():
249
+ # python_argdefs() cannot be run until after the rest of the template lazily adds more args
250
+ arg_defs, *_ = self.args.python_argdefs()
251
+ return f"{', '.join(arg_defs)}"
252
+
253
+ self.render_hooks["<ARGDEFS>"] = hook
254
+ return "<ARGDEFS>"
255
+
256
+ def gen_defines(self):
257
+ return self.defines
258
+
259
+ def def_kernel(self, *argnames):
260
+ """
261
+ Hook called from template code to generate function def and
262
+ needed args.
263
+ """
264
+ assert all(isinstance(x, str) for x in argnames)
265
+ renames = IndentedBuffer(initial_indent=1)
266
+
267
+ named_args = self.input_nodes[
268
+ self.prefix_args : len(self.input_nodes) - self.suffix_args
269
+ ]
270
+
271
+ assert len(argnames) == len(named_args), (
272
+ len(argnames),
273
+ len(named_args),
274
+ self.prefix_args,
275
+ len(self.input_nodes),
276
+ )
277
+
278
+ for input_node in self.input_nodes[: self.prefix_args]:
279
+ # get args in correct order
280
+ self.args.input(input_node.get_name())
281
+
282
+ for name, input_node in zip(argnames, named_args):
283
+ arg_name = f"arg_{name}"
284
+ self.named_input_nodes[name] = input_node
285
+ self.args.input_buffers[input_node.get_name()] = arg_name
286
+
287
+ # The args may be duplicated, so renaming must be after args are de-duplicated.
288
+ for name in argnames:
289
+ input_node = self.named_input_nodes[name]
290
+ arg_name = self.args.input_buffers[input_node.get_name()]
291
+ if input_node.get_layout().offset == 0:
292
+ renames.writeline(f"{name} = {arg_name}")
293
+ else:
294
+ offset = texpr(self.rename_indexing(input_node.get_layout().offset))
295
+ renames.writeline(f"{name} = {arg_name} + {offset}")
296
+
297
+ for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
298
+ # get args in correct order
299
+ self.args.input(input_node.get_name())
300
+
301
+ def hook():
302
+ # python_argdefs() cannot be run until after the rest of the template lazily adds more args
303
+ arg_defs, *_ = self.args.python_argdefs()
304
+ code = IndentedBuffer()
305
+ code.splice(gen_common_triton_imports())
306
+ code.splice(self.jit_lines())
307
+ code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
308
+ with code.indent():
309
+ code.splice(self.defines)
310
+ code.splice(renames.getvalue())
311
+ return code.getvalue()
312
+
313
+ assert "<DEF_KERNEL>" not in self.render_hooks
314
+ self.render_hooks["<DEF_KERNEL>"] = hook
315
+ return "<DEF_KERNEL>"
316
+
317
+ def size(self, name: str, index: int):
318
+ """
319
+ Hook called from template code to get the size of an arg.
320
+ Will add needed args to pass it in if it is dynamic.
321
+ """
322
+ assert isinstance(index, int)
323
+ if name is None:
324
+ val = self.output_node.get_size()[index]
325
+ else:
326
+ assert isinstance(name, str)
327
+ val = self.named_input_nodes[name].get_size()[index]
328
+ return texpr(self.rename_indexing(val))
329
+
330
+ def stride(self, name, index=None):
331
+ """
332
+ Hook called from template code to get the stride of an arg.
333
+ Will add needed args to pass it in if it is dynamic.
334
+ """
335
+ if name is None:
336
+ val = self.output_node.get_stride()
337
+ else:
338
+ assert isinstance(name, str)
339
+ val = self.named_input_nodes[name].get_stride()
340
+
341
+ if isinstance(index, int):
342
+ return texpr(self.rename_indexing(val[index]))
343
+ else:
344
+ return ", ".join([texpr(self.rename_indexing(i)) for i in val])
345
+
346
+ def modification(
347
+ self, subgraph_number: int, output_name: str, **fixed_inputs
348
+ ) -> str:
349
+ """This creates a modification function for a subgraph.
350
+ To use this inside a template, the first argument should specify which subgraph to codegen for
351
+
352
+ Args:
353
+ subgraph_number (int): The index of the subgraph in self.subgraphs
354
+ """
355
+ num = 0
356
+ while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies:
357
+ num += 1
358
+ with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"):
359
+ assert isinstance(subgraph_number, int)
360
+ assert isinstance(self.subgraphs, list)
361
+ assert (
362
+ self.body.getvalue() == ""
363
+ ), "Body should be clear before adding a modification"
364
+ assert subgraph_number < len(
365
+ self.subgraphs
366
+ ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
367
+
368
+ subgraph = self.subgraphs[subgraph_number]
369
+
370
+ def add_input(name):
371
+ return self.args.input(name)
372
+
373
+ name = f"PlaceholderSubstitution_{subgraph_number}"
374
+
375
+ class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined]
376
+ self.name = name
377
+
378
+ def load(self, name: str, index: sympy.Expr):
379
+ if name not in fixed_inputs:
380
+ # If it's not a fixed input, it's a load from a captured
381
+ # tensor
382
+ var = add_input(name)
383
+ return f"tl.load({var} + {index})"
384
+
385
+ return f"({fixed_inputs[name]})"
386
+
387
+ def indirect_indexing(self, index_var, size, check, wrap_neg=True):
388
+ return sympy_index_symbol(str(index_var))
389
+
390
+ with V.set_ops_handler(PlaceholderSubstitution(V.ops)):
391
+ assert isinstance(
392
+ subgraph, ir.ComputedBuffer
393
+ ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}"
394
+ if isinstance(subgraph.data, ir.InputBuffer):
395
+ out = subgraph.data.make_loader()(())
396
+ else:
397
+ out = subgraph.data.inner_fn(())
398
+
399
+ self.codegen_body()
400
+ self.body.writeline(f"{output_name} = {out.value}")
401
+
402
+ body_val = self.body.getvalue()
403
+ self.cse.invalidate(set()) # type: ignore[arg-type]
404
+ return body_val
405
+
406
def store_output(
    self,
    indices: Union[List[Any], Tuple[Any]],
    val: str,
    mask: Optional[str] = None,
    indent_width: int = 4,
):
    """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away.

    Args:
        indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of
            these indices and output strides must match `val`.
        val (str): The value to store.
        mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask
            will be applied to the store.
        indent_width (int): The number of spaces to use for indentation. This is used when the call to
            store_output is indented in the kernel definition.

    Returns:
        str: The literal placeholder "<STORE_OUTPUT>". The actual store code is produced later by the
            hook registered under that key in `self.render_hooks`.
    """
    with self.create_subgraph_body("<STORE_OUTPUT>"):
        assert isinstance(indices, (list, tuple))
        assert isinstance(val, str)
        assert isinstance(mask, (str, type(None)))
        # store_output may only be called once per kernel: the mask must not be set yet
        assert self.template_mask is None
        indices = list(map(TritonPrinter.paren, indices))
        index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
        lengths = [
            V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
        ]
        assert len(indices) == len(lengths)

        # glue to make generated code use same indexing from template
        for name, range_tree_entry in zip(
            indices, self.range_trees[0].construct_entries(lengths)
        ):
            range_tree_entry.set_name(name)
        contiguous_index = sympy_dot(
            ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
        )
        contiguous_index = self.rename_indexing(contiguous_index)
        self.body.writeline("xindex = " + texpr(contiguous_index))
        self.range_trees[0].lookup(
            sympy.Integer(1), sympy_product(lengths)
        ).set_name("xindex")
        # remember the template's mask/value/indices on the kernel for later use
        self.template_mask = mask
        self.template_out = val
        self.template_indices = indices
        output_index = self.output_node.get_layout().make_indexer()(index_symbols)
        output_index = self.rename_indexing(output_index)
        # when the output layout indexes exactly like the contiguous layout,
        # reuse the "xindex" variable written to the body above
        if output_index == contiguous_index:
            output_index = sympy.Symbol("xindex", integer=True)

        # the epilogue receives the template value followed by loads of the
        # first `prefix_args` and last `suffix_args` input nodes
        epilogue_args = [val]
        for input_node in itertools.chain(
            self.input_nodes[: self.prefix_args],
            self.input_nodes[len(self.input_nodes) - self.suffix_args :],
        ):
            input_node.freeze_layout()
            epilogue_args.append(input_node.make_loader()(index_symbols))

        V.ops.store(
            self.output_node.get_name(),
            output_index,
            self.epilogue_fn(*epilogue_args),
        )
        self.codegen_body()

    def hook():
        # more stuff might have been added since the codegen_body above
        self.codegen_body()

        return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()

    # only one hook may be registered for this placeholder
    assert "<STORE_OUTPUT>" not in self.render_hooks
    self.render_hooks["<STORE_OUTPUT>"] = hook
    return "<STORE_OUTPUT>"
481
+
482
def render(self, template, kwargs):
    """
    Render ``template`` with the helper namespace from ``template_env()``
    plus ``kwargs``, and wrap the text together with this kernel's render
    hooks in a ``PartialRender``.
    """
    rendered_text = template.render(**self.template_env(), **kwargs)
    return PartialRender(rendered_text, self.render_hooks)
487
+
488
def make_load(self, name, indices, mask):
    """
    Optional helper called from template code to generate the code
    needed to load from a tensor.

    ``name`` selects an entry of ``self.named_input_nodes``; ``indices`` holds one
    index expression per stride of that node; ``mask`` is the mask expression
    inserted into the emitted ``tl.load`` call.
    """
    assert isinstance(indices, (list, tuple))
    assert isinstance(name, str)
    assert isinstance(mask, str)
    stride = self.named_input_nodes[name].get_stride()
    indices = [TritonPrinter.paren(idx) for idx in indices]
    assert len(indices) == len(stride)
    # one "<stride> * <index>" term per dimension, summed into a flat offset
    terms = []
    for dim_stride, dim_index in zip(stride, indices):
        terms.append(f"{texpr(self.rename_indexing(dim_stride))} * {dim_index}")
    index = " + ".join(terms)
    return f"tl.load({name} + ({index}), {mask}, other=0.0)"
503
+
504
def template_env(self):
    """
    Build the namespace visible in the template: a dict mapping each
    exposed helper's ``__name__`` to the bound helper itself.
    """
    exposed_helpers = (
        self.def_kernel,
        self.size,
        self.stride,
        self.store_output,
        self.make_load,
        self.modification,
        self.gen_argdefs,
        self.gen_defines,
    )
    env = {}
    for helper in exposed_helpers:
        env[helper.__name__] = helper
    return env
521
+
522
def indexing(
    self,
    index: sympy.Expr,
    *,
    dense_indexing=False,
    copy_shape=None,
    override_mask=None,
    block_ptr=False,
):
    """
    Delegate to the parent ``indexing`` using the template's own mask and
    output value.

    Only ``index`` and ``block_ptr`` are forwarded. The caller-supplied
    ``dense_indexing``, ``copy_shape`` and ``override_mask`` are ignored:
    ``dense_indexing`` is always passed as False, ``copy_shape`` is
    ``self.template_out`` and ``override_mask`` is ``self.template_mask``.
    """
    forced_kwargs = dict(
        dense_indexing=False,
        # template_out is given as the shape to broadcast the indexing to,
        # since the mask might be broadcast to the output shape
        copy_shape=self.template_out,
        override_mask=self.template_mask,
        block_ptr=block_ptr,
    )
    return super().indexing(index, **forced_kwargs)
544
+
545
def codegen_range_tree(self):
    """Intentionally emit nothing: the default range-tree codegen is ignored."""
    return None
547
+
548
def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
    """
    Emit the call to kernel ``name`` into the graph's wrapper code.

    ``node`` is accepted but not used here.
    """
    wrapper = V.graph.wrapper_code
    _, call_args, _, arg_types = self.args.python_argdefs()

    if not V.graph.cpp_wrapper:
        # Python wrapper: pass the grid function by qualified name, so its
        # module has to be imported by the generated code.
        grid_module = self.grid_fn.__module__
        wrapper.add_import_once(f"import {grid_module}")
        meta = wrapper.add_meta_once(self.meta)
        wrapper.generate_kernel_call(
            name,
            call_args,
            grid=self.call_sizes + [meta],
            grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}",
            arg_types=arg_types,
            triton_meta=self.triton_meta,
        )
        return

    # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime
    # if any dynamic dimension is involved. We rely on the Python version
    # of the grid function to generate those grid configs, which may contain
    # symbolic values. The wrapper will use cexpr to print out C++ code
    # appropriately for the grid configs.
    grid_inputs = self.call_sizes + [self.meta]
    wrapper.generate_kernel_call(
        name,
        call_args,
        grid=self.grid_fn(*grid_inputs),
        arg_types=arg_types,
        triton_meta=self.triton_meta,
    )
577
+
578
+
579
@functools.lru_cache(None)
def _jinja2_env():
    """
    Return a cached ``jinja2.Environment`` with ``StrictUndefined``, or None
    when jinja2 cannot be imported.
    """
    try:
        import jinja2

        env = jinja2.Environment(undefined=jinja2.StrictUndefined)
    except ImportError:
        env = None
    return env
589
+
590
+
591
class TritonTemplate(KernelTemplate):
    """
    A named Triton kernel template that can be turned into benchmarkable
    ``TritonTemplateCaller`` choices via ``generate``.
    """

    # class-wide counter used to give each generated caller a unique name suffix
    index_counter = itertools.count()
    # registry of every template created, keyed by template name
    all_templates: Dict[str, "TritonTemplate"] = {}

    def __init__(self, name: str, grid: Any, source: str, debug=False) -> None:
        """
        Args:
            name: Unique template name; registering a duplicate name fails an assertion.
            grid: Grid function, called with the call sizes followed by the meta kwargs.
            source: Template source text, passed to ``_template_from_string``.
            debug: When true, ``generate`` prints the generated code.
        """
        super().__init__(name)
        self.grid = grid
        self.template = self._template_from_string(source)
        assert name not in self.all_templates, "duplicate template name"
        self.all_templates[name] = self
        self.debug = debug

    def generate(  # type: ignore[override]
        self,
        input_nodes,
        layout,
        num_stages,
        num_warps,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        subgraphs=None,
        mutated_inputs=None,
        call_sizes=None,
        **kwargs,
    ):
        """This function generates a TritonTemplateCaller

        Args:
            input_nodes: List of input nodes
            layout: Output layout
            num_stages: Number of stages for triton launch
            num_warps: Number of warps for triton launch
            prefix_args: Number of input nodes to be passed as arguments
            suffix_args: Number of input nodes to be passed as arguments
            epilogue_fn: Optional epilogue function to be called on the output
            subgraphs: Optional subgraphs to be passed as arguments, these will be inlined
                into the triton template string
            mutated_inputs: Optional list of input nodes that are mutated by the kernel, this is helpful
                if you need to return multiple outputs. You can pass them as inputs and mark them as
                being mutated by the kernel.
            call_sizes: Sizes passed to the grid function; defaults to ``layout.size``.
            **kwargs: Written into the kernel as ``tl.constexpr`` defines and passed on as the
                kernel meta.

        Returns:
            A TritonTemplateCaller, or None when rendering raises ZeroDivisionError.

        Raises:
            NotImplementedError: if 32-bit indexing cannot be used for the buffers involved.
        """
        assert self.template, "requires jinja2"
        # one "NAME : tl.constexpr = VALUE" line per keyword argument
        defines = StringIO()
        for name, val in kwargs.items():
            defines.write(f"{name} : tl.constexpr = {val}\n")
        defines = defines.getvalue()

        # stand-in output buffer used only while rendering for benchmarking
        fake_out = ir.Buffer("buf_out", layout)
        kernel_name = f"triton_{self.name}"

        numel = sympy_product(layout.size)
        buffers = itertools.chain(input_nodes, (fake_out,))
        if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
            raise NotImplementedError(
                "64-bit indexing is not yet implemented for triton templates"
            )

        if call_sizes is None:
            call_sizes = layout.size

        # shared by the benchmark kernel below and by make_kernel_render
        kernel_options = dict(
            input_nodes=input_nodes,
            defines=defines,
            num_stages=num_stages,
            num_warps=num_warps,
            grid_fn=self.grid,
            meta=kwargs,
            call_sizes=call_sizes,
            prefix_args=prefix_args,
            suffix_args=suffix_args,
            epilogue_fn=epilogue_fn,
            index_dtype="tl.int32",
            subgraphs=subgraphs,
        )

        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(fake_out)
        ), TritonTemplateKernel(
            kernel_name=kernel_name,
            output_node=fake_out,
            use_jit=False,
            **kernel_options,
        ) as kernel:
            try:
                template = kernel.render(self.template, kwargs)
                with kernel.set_subgraph_body("<STORE_OUTPUT>"):
                    code = template.finalize_all()
            except ZeroDivisionError:
                # TODO(nmacchioni): fix sympy division by zero
                return None
            if self.debug:
                print("Generated Code:\n", code)
            # "k=v-...-num_stages=N-num_warps=M-": passed to the code cache along
            # with the source, and reused (reformatted) as the caller's debug string
            extra = (
                "-".join(
                    [
                        *[
                            f"{kwarg}={repr(kwargs[kwarg])}"
                            for kwarg in sorted(kwargs.keys())
                        ],
                        f"num_stages={num_stages}",
                        f"num_warps={num_warps}",
                    ]
                )
                + "-"
            )
            mod = PyCodeCache.load(code, extra)

        input_call_args = tuple(kernel.args.input_buffers.keys())
        output_call_args = tuple(kernel.args.output_buffers.keys())

        # We expect the input_buffer order to be [*input_nodes, *captured_buffers]
        expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
        expected_output_args = (fake_out.get_name(),)
        assert input_call_args[: len(expected_input_args)] == expected_input_args, (
            input_call_args,
            expected_input_args,
        )
        assert output_call_args == expected_output_args, (
            output_call_args,
            expected_output_args,
        )

        # includes any captured buffers the kernel registered as inputs
        full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args])
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, tuple(kernel.args.sizevars.keys())),
            fallback=config.unbacked_symint_fallback,
        )

        kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"

        def make_kernel_render(out_node):
            # builds a fresh kernel for the real output node; rendering is
            # deferred until the returned partial is called
            kernel = TritonTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
                output_node=out_node,
                use_jit=False,
                **kernel_options,
            )
            render = functools.partial(
                kernel.render,
                self.template,
                kwargs,
            )
            return kernel, render

        # create the BenchmarkRequest
        assert mod.__file__ is not None
        grid = self.grid(
            *V.graph.sizevars.size_hints(
                call_sizes,
                fallback=config.unbacked_symint_fallback,
            ),
            kwargs,
        )
        bmreq = TritonBenchmarkRequest(
            module_path=mod.__file__,
            module_cache_key=mod.key,
            kernel_name=kernel_name,
            grid=grid,
            extra_args=extra_args,
            num_stages=num_stages,
            num_warps=num_warps,
            matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
            input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),  # type: ignore[arg-type]
            output_tensor_meta=TensorMeta.from_irnodes(layout),
        )

        return TritonTemplateCaller(
            kernel_hash_name,
            full_input_nodes,
            layout,
            make_kernel_render,
            extra.strip("-").replace("-", ", "),
            bmreq,
            log_info={
                # stored as the string form of (BLOCK_M, BLOCK_K, BLOCK_N); -1 marks a missing entry
                "tile_shape": str(
                    (
                        kwargs.get("BLOCK_M", -1),
                        kwargs.get("BLOCK_K", -1),
                        kwargs.get("BLOCK_N", -1),
                    )
                ),
                "num_stages": num_stages,
                "num_warps": num_warps,
                "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
                "acc_type": str(kwargs.get("ACC_TYPE", None)),
            },
            mutated_inputs=mutated_inputs,
        )
780
+
781
+
782
class ExternKernelChoice:
    """
    A named extern (non-template) kernel that can be offered as an autotuning choice.

    Constructing an instance registers ``kernel`` as an attribute of
    ``extern_kernels`` under ``name``; ``bind`` then produces an
    ``ExternKernelCaller`` for concrete inputs.
    """

    def __init__(
        self,
        kernel,
        cpp_kernel=None,
        *,
        name=None,
        has_out_variant=True,
        op_overload=None,
        use_fallback_kernel=False,
        kernel_creator=None,
    ) -> None:
        """
        Args:
            kernel: Callable implementing the kernel.
            cpp_kernel: Stored as ``cpp_kernel_name``.
            name: Registration name; defaults to ``kernel.__name__``.
            has_out_variant: Forwarded to the callers created by ``bind``.
            op_overload: Stored for use by the callers created by ``bind``.
            use_fallback_kernel: Stored for use by the callers created by ``bind``.
            kernel_creator: Optional factory stored for use by the callers created by ``bind``.
        """
        super().__init__()
        name = name or kernel.__name__
        assert callable(kernel)
        assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}"
        self.name = name
        self.cpp_kernel_name = cpp_kernel
        self.has_out_variant = has_out_variant
        setattr(extern_kernels, name, kernel)
        self.op_overload = op_overload
        self.use_fallback_kernel = use_fallback_kernel
        self.kernel_creator = kernel_creator
        # ExternKernelCaller.output_node reads this attribute. Give it the same
        # default bind() uses, so a caller constructed without going through
        # bind() does not fail with AttributeError.
        self.ordered_kwargs_for_cpp_kernel = ()

    def to_callable(self):
        """Return the kernel registered on ``extern_kernels`` under this name."""
        return getattr(extern_kernels, self.name)

    def call_name(self):
        """Return the qualified name used to reference the kernel."""
        return f"extern_kernels.{self.name}"

    @functools.lru_cache(None)  # noqa: B019
    def hash_key(self):
        """
        Return a hash of the kernel's name, ``__name__``, ``__module__`` and,
        when it can be retrieved, its source text.
        """
        fn = self.to_callable()
        parts = [
            self.name,
            getattr(fn, "__name__", ""),
            getattr(fn, "__module__", ""),
        ]
        try:
            parts.append(inspect.getsource(fn))
        except Exception:
            # best effort: source is simply left out of the hash when it is
            # not available for this callable
            pass
        return code_hash("-".join(parts))

    def bind(
        self,
        input_nodes,
        layout,
        ordered_kwargs_for_cpp_kernel=(),
        **kwargs,
    ):
        """
        Create an ``ExternKernelCaller`` for the given inputs and layout.

        Note that ``ordered_kwargs_for_cpp_kernel`` is recorded on this choice
        object itself, so it is shared by every caller bound from it.
        """
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
        return ExternKernelCaller(
            self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant
        )
837
+
838
+
839
class TritonTemplateCaller(ir.TritonTemplateCallerBase):
    """
    An autotuning choice backed by a rendered Triton template and its
    benchmark request.
    """

    def __init__(
        self,
        name,
        input_nodes,
        layout,
        make_kernel_render,
        debug_extra,
        bmreq,
        log_info: Optional[
            Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]
        ] = None,
        mutated_inputs=None,
    ) -> None:
        """
        Args:
            name: Caller name, passed to the base class.
            input_nodes: Input nodes, passed to the base class.
            layout: Output layout, passed to the base class.
            make_kernel_render: Callable stored and handed to the template buffer.
            debug_extra: Human-readable description of the configuration.
            bmreq: The benchmark request used by ``benchmark`` and ``precompile``.
            log_info: Optional extra entries for the autotune log; it is updated
                in place with backend, grid, num_stages and num_warps.
            mutated_inputs: Stored and handed to the template buffer.
        """
        super().__init__(name, input_nodes, layout)
        self.make_kernel_render = make_kernel_render
        self.debug_extra = debug_extra
        self.bmreq: TritonBenchmarkRequest = bmreq
        if log_info is None:
            log_info = {}
        self.log_info: Dict[str, Any] = log_info
        self.log_info.update(
            {
                "backend": "Triton",
                "grid": str(self.bmreq.grid),
                "num_stages": self.bmreq.num_stages,
                "num_warps": self.bmreq.num_warps,
            }
        )
        self.mutated_inputs = mutated_inputs

    def benchmark(self, *args, out):
        """Benchmark through the benchmark request, writing into ``out``."""
        assert self.bmreq is not None
        return self.bmreq.benchmark(*args, output_tensor=out)

    def precompile(self):
        """Precompile the kernel through the benchmark request."""
        assert self.bmreq is not None
        self.bmreq.precompile()

    def __str__(self) -> str:
        return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"

    def call_name(self):
        """Return the qualified name used to reference this kernel."""
        return f"template_kernels.{self.name}"

    def hash_key(self):
        """
        Combine the name, with everything after its last underscore dropped,
        and the benchmark request's module cache key.
        """
        return "-".join(
            [
                self.name.rsplit("_", 1)[0],
                self.bmreq.module_cache_key,
            ]
        )

    def output_node(self):
        """Create the TensorBox-wrapped template buffer for this choice."""
        return ir.TensorBox.create(
            ir.TritonTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                debug_extra=self.debug_extra,
                mutated_inputs=self.mutated_inputs,
            )
        )

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return self.log_info

    def get_make_kernel_render(self):
        """Return the stored ``make_kernel_render`` callable."""
        return self.make_kernel_render

    def autoheuristic_id(self):
        """
        Return an identifier string built from the tile shape, num_stages and
        num_warps recorded in ``info_dict()``.
        """
        # local import: ast is only needed here
        import ast

        type_name = "triton"
        info = self.info_dict()
        # TODO(AlnisM): Does tile_shape always exist?
        tile = info["tile_shape"]
        # tile_shape is the string form of a tuple; parse it as a literal
        # instead of evaluating it as arbitrary code
        tile_vals = ast.literal_eval(tile)  # type: ignore[arg-type]
        BLOCK_M = tile_vals[0]
        BLOCK_K = tile_vals[1]
        BLOCK_N = tile_vals[2]
        num_stages = info["num_stages"]
        num_warps = info["num_warps"]
        return f"type={type_name}_BLOCK-M={BLOCK_M}_BLOCK-K={BLOCK_K}_BLOCK-N={BLOCK_N}_numstages={num_stages}_numwarps={num_warps}"
922
+
923
+
924
class ExternKernelCaller(ChoiceCaller):
    """An autotuning choice that calls an ``ExternKernelChoice`` with bound kwargs."""

    def __init__(
        self,
        choice: ExternKernelChoice,
        input_nodes,
        layout,
        kwargs=None,
        *,
        has_out_variant=True,
    ) -> None:
        super().__init__(choice.name, input_nodes, layout)
        self.choice = choice
        self.kwargs = kwargs or {}
        self.has_out_variant = has_out_variant

    def __str__(self) -> str:
        return f"ExternKernelCaller({self.choice.call_name()})"

    def benchmark(self, *args, out):
        """Benchmark the kernel; an empty output short-circuits to 0.0."""
        if out.numel() == 0:
            # no need to run the kernel or do benchmarking
            return 0.0
        if self.has_out_variant:
            return super().benchmark(*args, out=out)

        # No out= variant: run once to validate and capture the result, then
        # time the bare callable.
        algo = self.to_callable()
        fresh_out = algo(*args)
        torch._C._dynamo.guards.assert_size_stride(
            fresh_out, tuple(out.size()), tuple(out.stride())
        )
        out.copy_(fresh_out)  # for correctness checking
        return benchmarker.benchmark(algo, args, {})

    def to_callable(self):
        """Return the kernel callable, with the bound kwargs applied if there are any."""
        fn = self.choice.to_callable()
        return functools.partial(fn, **self.kwargs) if self.kwargs else fn

    def hash_key(self):
        """Join the choice name, the sorted ``key=repr(value)`` kwargs and the choice hash."""
        kwarg_parts = [
            f"{key}={repr(self.kwargs[key])}" for key in sorted(self.kwargs.keys())
        ]
        return "-".join([self.choice.name, *kwarg_parts, self.choice.hash_key()])

    def output_node(self):
        """Create the TensorBox-wrapped IR node for this choice."""
        choice = self.choice
        if config.abi_compatible and choice.use_fallback_kernel:
            assert (
                choice.op_overload is not None
            ), "Please provide an op_overload to use ir.FallbackKernel"
            node = ir.FallbackKernel.create(
                choice.op_overload, *self.input_nodes, **self.kwargs
            )
        elif choice.kernel_creator is not None:
            node = choice.kernel_creator(*self.input_nodes, **self.kwargs)
        else:
            if self.has_out_variant:
                kernel_cls = ir.ExternKernelOut
            else:
                kernel_cls = ir.ExternKernelAlloc
            node = kernel_cls(
                layout=self.layout,
                inputs=self.input_nodes,
                python_kernel_name=choice.call_name(),
                cpp_kernel_name=choice.cpp_kernel_name,
                ordered_kwargs_for_cpp_kernel=choice.ordered_kwargs_for_cpp_kernel,
                op_overload=choice.op_overload,
                kwargs=self.kwargs,
            )

        return ir.TensorBox.create(node)

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        info = {
            "backend": "extern",
            "kernel_call_name": self.choice.call_name(),
        }
        return info

    def autoheuristic_id(self):
        """Return the identifier ``extern_<choice name>``."""
        return f"extern_{self.choice.name}"
1009
+
1010
+
1011
@functools.lru_cache(None)
def get_mm_log_filename() -> Optional[str]:
    """
    Return the log file name from the TORCHINDUCTOR_MM_LOGGING_FILE
    environment variable, or None when it is unset or empty.

    The returned name always ends in ".json"; the suffix is appended when
    missing. The result is cached after the first call.
    """
    mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None)
    if not mm_file_name:
        return None

    # Check the suffix, not a substring: a name such as "jsonlogs/mm" contains
    # "json" but still needs the extension.
    if not mm_file_name.endswith(".json"):
        mm_file_name = f"{mm_file_name}.json"

    return mm_file_name
1021
+
1022
+
1023
def append_to_log(filename, data):
    """
    Append ``data`` to the JSON list stored in ``filename``, under a file lock.

    A missing or unparsable file is treated as an empty list. The whole list
    is rewritten on every call.
    """
    # Only swap a trailing ".json" for ".lock". str.replace would also rewrite
    # ".json" inside directory names, and for a name without the suffix it
    # would make the lock path identical to the data file.
    root, ext = os.path.splitext(filename)
    lock_file = f"{root}.lock" if ext == ".json" else f"{filename}.lock"
    lock = FileLock(lock_file)
    with lock:
        try:
            with open(filename) as f:
                log_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            log_data = []

        log_data.append(data)

        with open(filename, "w") as f:
            json.dump(log_data, f, indent=4)
1037
+
1038
+
1039
class DataProcessorChoiceCallerWrapper:
    """
    Wrap a choice caller so that benchmark inputs are preprocessed and its
    outputs postprocessed; every other attribute is delegated to the wrapped
    caller.
    """

    def __init__(self, wrapped, preprocessor, postprocessor) -> None:
        self._wrapped = wrapped
        # a missing processor falls back to the identity
        self._preprocessor = (
            preprocessor if preprocessor is not None else (lambda x, y: (x, y))
        )
        self._postprocessor = (
            postprocessor if postprocessor is not None else (lambda x: x)
        )

    def __getattr__(self, name):
        return getattr(self._wrapped, name)

    def benchmark(self, *args, out) -> float:
        """Benchmark the wrapped caller on preprocessed inputs."""
        processed_args, processed_out = self._preprocessor(args, out)
        timing = self._wrapped.benchmark(*processed_args, out=processed_out)
        processed_out = self._postprocessor(processed_out)
        # copy back only when the processors produced a different object
        if out is not processed_out:
            out.copy_(processed_out)
        return timing

    def output_node(self) -> ir.TensorBox:
        """Return the wrapped caller's output node after postprocessing."""
        return self._postprocessor(self._wrapped.output_node())

    def __repr__(self) -> str:
        return f"DataProcessorChoiceCallerWrapper({self._wrapped})"
1068
+
1069
+
1070
class DataProcessorTemplateWrapper:
    """
    A wrapper class for a kernel template.

    This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to
    preprocess and postprocess data before and after using the wrapped template. A typical
    usage is to reorder or filter the input nodes in order to match the expected input of other
    kernel choices like a ATen kernel. A more complicated usage is to prepack the weights.
    See the example from :mod:`cpp_gemm_template` for more details.
    """

    def __init__(
        self,
        wrapped_template_cls,
        preprocessor,
        postprocessor,
        **kwargs,
    ) -> None:
        """
        Args:
            wrapped_template_cls: Template class, instantiated with ``**kwargs`` after
                ``input_nodes`` and ``layout`` have been preprocessed.
            preprocessor: Callable ``(input_nodes, layout) -> (input_nodes, layout)``,
                or None for the identity.
            postprocessor: Callable applied to outputs, or None for the identity.
            **kwargs: Must contain ``input_nodes`` and ``layout``.
        """
        if preprocessor is not None:
            self._preprocessor = preprocessor
        else:
            self._preprocessor = lambda x, y: (x, y)
        if postprocessor is not None:
            self._postprocessor = postprocessor
        else:
            self._postprocessor = lambda x: x
        assert "input_nodes" in kwargs
        assert "layout" in kwargs
        # Call the resolved self._preprocessor, not the raw argument, so that
        # preprocessor=None uses the identity default instead of raising TypeError.
        kwargs["input_nodes"], kwargs["layout"] = self._preprocessor(
            kwargs["input_nodes"], kwargs["layout"]
        )
        self._wrapped = wrapped_template_cls(**kwargs)

    def __getattr__(self, name):
        # only reached for attributes not found on the wrapper itself
        return getattr(self._wrapped, name)

    def maybe_append_choice(self, choices, **kwargs):
        """
        Call the wrapped class's ``maybe_append_choice`` with this wrapper as
        ``self``, so that this wrapper's ``generate`` is the one it resolves.
        """
        return type(self._wrapped).maybe_append_choice(self, choices, **kwargs)

    def generate(self, **kwargs):
        """Generate a choice from the wrapped template and wrap it with the processors."""
        choice_caller = self._wrapped.generate(**kwargs)
        return DataProcessorChoiceCallerWrapper(
            choice_caller, self._preprocessor, self._postprocessor
        )

    def __repr__(self) -> str:
        return f"DataProcessorTemplateWrapper({self._wrapped})"
1117
+
1118
+
1119
class ErrorFromChoice(RuntimeError):
    """RuntimeError whose message names the choice and inputs it came from."""

    def __init__(self, msg, choice: ChoiceCaller, inputs_str) -> None:
        origin = f"\nFrom choice {choice}\n{inputs_str}"
        msg = msg + origin
        super().__init__(msg)
        # kept so handlers can inspect the failing choice
        self.choice = choice
1124
+
1125
+
1126
class NoValidChoicesError(RuntimeError):
    """Raised when algorithm selection has no usable choice."""
1128
+
1129
+
1130
@functools.lru_cache(None)
def get_env_num_workers() -> Optional[int]:
    """
    Return TORCHINDUCTOR_COMPILE_THREADS as an int, or None when the
    variable is not set. The result is cached after the first call.
    """
    raw_value = os.environ.get("TORCHINDUCTOR_COMPILE_THREADS")
    return None if raw_value is None else int(raw_value)
1135
+
1136
+
1137
def create_inputs_key(input_nodes) -> str:
    """Return the repr of the list of ``AlgorithmSelectorCache.key_of`` values of the inputs."""
    node_keys = []
    for node in input_nodes:
        node_keys.append(AlgorithmSelectorCache.key_of(node))
    return repr(node_keys)
1139
+
1140
+
1141
def create_precompile_key(
    name: str, inputs_key: str, choices: List[ChoiceCaller]
) -> str:
    """
    Build a ":"-separated key from the name, the inputs key, the current
    float32 matmul precision and each choice's hash key, in that order.
    """
    parts = [name, inputs_key, torch.get_float32_matmul_precision()]
    parts.extend(choice.hash_key() for choice in choices)
    return ":".join(parts)
1152
+
1153
+
1154
+ class AlgorithmSelectorCache(PersistentCache):
1155
def __init__(self, *args, **kwargs) -> None:
    """Forward all arguments to the base class and set up empty per-instance state."""
    super().__init__(*args, **kwargs)

    # callbacks invoked after benchmarking
    self.feedback_saver_fns: List[
        Callable[
            [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
        ]
    ] = []
    # Autotuning happens in the scheduler, so there is no guarantee that the
    # first lowering for a given key is also the first to benchmark it. One
    # precompilation function is therefore shared by all lowerings of a key.
    self.precompile_cache: Dict[str, Callable[[], None]] = {}
1169
+
1170
def __call__(
    self,
    name,
    choices: List[ChoiceCaller],
    input_nodes,
    layout,
    # optional dict mapping arg indices to the functions
    # generating a torch.Tensor for that input from the
    # corresponding ir.Buffer. if passed for a given
    # arg, the function will be called instead of
    # generating a random torch.Tensor for benchmarking.
    input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
    precompilation_timeout_seconds: int = 60 * 60,
    return_multi_template=False,
):
    """
    Select one of ``choices`` for the operation ``name`` and return its output node.

    ``None`` entries in ``choices`` are dropped first. A single remaining choice
    that is not a CUDATemplateCaller is returned without benchmarking. Otherwise
    the choices are optionally precompiled in a thread pool and timed through
    ``self.lookup``; the choice with the smallest timing is returned. If the
    timings are empty or do not contain ``choices[0]``, ``choices[0]`` is returned.

    When ``return_multi_template`` is still set after the checks below and
    max-autotune is enabled, a MultiTemplateBuffer is returned instead and the
    timings are computed when its ``get_timings`` callback is invoked.

    Raises:
        NoValidChoicesError: if there are no choices, or if every timing is non-finite.
    """
    from .codegen.cuda.cuda_kernel import CUDATemplateCaller

    # Templates selected with input_gen_fns require specific input data to avoid IMA
    # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
    # TODO(jgong5): support multi-template on CPU
    if input_gen_fns is not None or layout.device.type == "cpu":
        return_multi_template = False

    # TODO - assert that we have not mutating kernels here

    # TODO(nmacchioni): remove once CI tests are fixed
    choices = [choice for choice in choices if choice is not None]

    # optional logging of the sizes read from the last two input nodes;
    # note this runs before the empty-choices check below
    if mm_file_name := get_mm_log_filename():
        M, K = input_nodes[-2].get_size()[:2]
        N = input_nodes[-1].get_size()[-1]
        append_to_log(mm_file_name, {"invoke": str((M, K, N))})

    if len(choices) == 0:
        backend_config = (
            "max_autotune_gemm_backends"
            if name != "convolution"
            else "max_autotune_conv_backends"
        )
        raise NoValidChoicesError(
            f"No choices to select, please consider adding ATEN into {backend_config} "
            "config (defined in torch/_inductor/config.py) to allow at least one choice. "
        )
    log.debug("Max autotune selects from %s choices.", str(len(choices)))

    if len(choices) == 1:
        if not isinstance(choices[0], CUDATemplateCaller):
            # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
            return choices[0].output_node()

    # cached so the benchmark function is built at most once; its cache size is
    # also used below to tell whether benchmarking actually ran
    @functools.lru_cache(None)
    def make_benchmark_fn():
        return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)

    inputs_key = create_inputs_key(input_nodes)

    def precompile(choices) -> Callable[[], None]:
        # Returns a zero-argument callable that waits for precompilation to
        # finish, or a no-op when precompilation is skipped.
        def no_op(*args, **kwargs):
            return

        if (
            precompilation_timeout_seconds is None
            or precompilation_timeout_seconds <= 0
        ):
            return no_op

        env_workers = get_env_num_workers()
        num_workers = env_workers if env_workers is not None else (len(choices))

        if num_workers <= 0:
            return no_op

        # https://github.com/python/cpython/issues/106905
        if (
            sys.version_info.major == 3
            and sys.version_info.minor == 11
            and sys.version_info.micro <= 8
        ):
            return no_op

        # check local and global cache before precompiling
        timings = self.lookup(
            choices,
            name,
            inputs_key,
            benchmark=None,
        )

        if timings:
            return no_op

        # reuse the wait function of an identical precompilation already started
        precompile_key = create_precompile_key(name, inputs_key, choices)
        if precompile_func := self.precompile_cache.get(precompile_key):
            return precompile_func

        log.info(
            "Multithreaded precompilation for %d choices using %d worker threads",
            len(choices),
            num_workers,
        )

        # In rare circumstances, because python threads inherit global state,
        # thread pool executor can race and leave stdout/stderr in a state
        # different than the original values. we explicitly restore the state
        # here to avoid this issue.

        initial_stdout = sys.stdout
        initial_stderr = sys.stderr

        def precompile_with_captured_stdout(choice):
            with restore_stdout_stderr(initial_stdout, initial_stderr):
                return choice.precompile()

        executor = ThreadPoolExecutor(max_workers=num_workers)

        # work is submitted immediately; wait_on_futures only collects results
        futures = {}
        for c in choices:
            if hasattr(c, "precompile"):
                future = executor.submit(precompile_with_captured_stdout, c)
                futures[future] = c

        # cached so that repeated calls wait (and count) only once
        @functools.lru_cache(None)
        @restore_stdout_stderr(initial_stdout, initial_stderr)
        def wait_on_futures():
            counters["inductor"]["select_algorithm_precompile"] += 1
            for future in as_completed(
                futures,
                timeout=precompilation_timeout_seconds,
            ):
                # a failing choice is logged, not raised
                if e := future.exception():
                    log.error(
                        "Exception %s for benchmark choice %s", e, futures[future]
                    )

            executor.shutdown(wait=True)

        self.precompile_cache[precompile_key] = wait_on_futures

        return wait_on_futures

    def autotune(choices):
        return make_benchmark_fn()(choices)

    if config.autotune_in_subproc:
        from .autotune_process import tuning_pool

        # do the optional warmup
        tuning_pool.initialize()

    def do_autotuning(precompile_fn):
        # Waits for precompilation, then gets timings through self.lookup,
        # passing autotune as the benchmark function.
        precompile_start_ts = time.time()
        precompile_fn()
        precompile_elapse = time.time() - precompile_start_ts

        autotune_start_ts = time.time()
        timings = self.lookup(
            choices,
            name,
            inputs_key,
            autotune,
        )
        autotune_elapse = time.time() - autotune_start_ts

        if timings and all(
            not math.isfinite(timing) for timing in timings.values()
        ):
            raise NoValidChoicesError

        # a non-empty cache means make_benchmark_fn was actually called
        if make_benchmark_fn.cache_info().currsize:
            counters["inductor"]["select_algorithm_autotune"] += 1

        if (
            make_benchmark_fn.cache_info().currsize
            or log.getEffectiveLevel() == logging.DEBUG
            or config.trace.log_autotuning_results
        ):
            self.log_results(
                name, input_nodes, timings, autotune_elapse, precompile_elapse
            )

        for feedback_fn in self.feedback_saver_fns:
            feedback_fn(timings, name, input_nodes, choices)

        return timings

    precompile_fn = precompile(choices)

    if return_multi_template and (config.max_autotune or config.max_autotune_gemm):

        def get_timings():
            timings = do_autotuning(precompile_fn)
            min_extern_choice = float("inf")
            for choice, timing in timings.items():
                if isinstance(choice, ExternKernelCaller):
                    min_extern_choice = min(min_extern_choice, timing)

            # keep every non-extern choice, plus only the extern choice(s)
            # with the smallest extern timing
            timings = {
                choice: time
                for choice, time in timings.items()
                if (
                    time <= min_extern_choice
                    or not isinstance(choice, ExternKernelCaller)
                )
            }

            return timings

        return torch._inductor.ir.TensorBox.create(
            torch._inductor.ir.MultiTemplateBuffer(
                layout,
                input_nodes,
                get_timings,
            )
        )

    # TODO - dont want to precompile if we have a cache hit
    timings = do_autotuning(precompile_fn)
    if timings == {} or choices[0] not in timings:
        return choices[0].output_node()

    selected_key = builtins.min(timings, key=timings.__getitem__)
    # NOTE(review): selected_time is assigned but not used in this method
    selected_time = timings[selected_key]
    selected_choice = selected_key.output_node()
    log.debug("selected choice: %s", str(selected_choice))
    return selected_choice
1395
+
1396
@classmethod
def make_benchmark_fn(
    cls,
    choices,
    input_nodes,
    layout,
    input_gen_fns=None,
):
    """
    Build the callable used to time a list of autotuning choices.

    Args:
        choices: candidate kernels. ``choices[0]`` is run once to produce the
            reference output when ``VERIFY`` is set.
        input_nodes: IR nodes describing the kernel inputs. One example tensor
            is generated per distinct ``get_name()``.
        layout: output layout; used to allocate the benchmark output buffer.
        input_gen_fns: optional mapping from input position to a function that
            builds the example tensor for that input. Positions without an
            entry fall back to ``cls.benchmark_example_value``.

    Returns:
        A function that takes a list of choices and returns a dict mapping
        each choice to its timing. Choices that fail with a handled error are
        recorded with a timing of ``float("inf")``. The sub-process variant is
        returned when ``config.autotune_in_subproc`` is set, otherwise the
        in-process variant.
    """
    if input_gen_fns is None:
        input_gen_fns = {}

    def get_inputs():
        """Create example inputs/outputs (and the reference result, if verifying)."""
        # de-duplicate args
        unique_example_inputs = {
            x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
            for i, x in enumerate(input_nodes)
        }
        example_inputs = list(unique_example_inputs.values())
        # Extern kernels get a strided view built from the node's hinted
        # size/stride/offset; mkldnn tensors are passed through unchanged.
        example_inputs_extern = [
            unique_example_inputs[input_node.get_name()]
            if unique_example_inputs[input_node.get_name()].is_mkldnn
            else torch.as_strided(
                unique_example_inputs[input_node.get_name()],
                V.graph.sizevars.size_hints(
                    input_node.get_size(),
                    fallback=config.unbacked_symint_fallback,
                ),
                V.graph.sizevars.size_hints(
                    input_node.get_stride(),
                    fallback=config.unbacked_symint_fallback,
                ),
                V.graph.sizevars.size_hint(
                    input_node.get_layout().offset,
                    fallback=config.unbacked_symint_fallback,
                ),
            )
            for input_node in input_nodes
        ]

        out = cls.benchmark_example_value(layout)
        out_extern = torch.as_strided(
            out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
        )
        expected = None
        if VERIFY:
            # The first choice provides the reference output.
            choices[0].benchmark(*example_inputs_extern, out=out_extern)
            expected = out_extern.clone()

        return example_inputs, example_inputs_extern, out, out_extern, expected

    if DEBUG:
        print(f"{len(choices)} tuning requests:")

    # NOTE(review): debug_str is not called anywhere inside this method; it is
    # kept so that the set of local definitions is unchanged.
    def debug_str(example_inputs, out):
        def tensor_repr(x):
            return (
                f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, "
                f"dtype={x.dtype!r}, device={x.device.type!r})"
            )

        lines = [
            "inputs = [",
        ]
        for x in example_inputs:
            lines.append(f" {tensor_repr(x)},")
        lines += ["]", f"out = {tensor_repr(out)}", ""]
        return "\n".join(lines)

    def benchmark_choice_in_current_process(
        choice, example_inputs, example_inputs_extern, out, out_extern, expected
    ):
        """Time a single choice and, when verifying, compare against `expected`."""
        out.zero_()
        if isinstance(choice, ExternKernelCaller):
            # aten kernels want the offset baked in for sliced tensors
            result = choice.benchmark(*example_inputs_extern, out=out_extern)
        else:
            # triton templates want the base pointer for sliced tensors
            result = choice.benchmark(*example_inputs, out=out)
        if VERIFY and expected is not None:
            torch.testing.assert_close(out_extern, expected, **VERIFY)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # shake out any CUDA errors
        return result

    def benchmark_in_current_process(choices):
        """Benchmark every choice in this process; failures map to inf."""
        inputs = get_inputs()
        timings = {}
        for choice in choices:
            try:
                timing = benchmark_choice_in_current_process(choice, *inputs)
            except CUDACompileError as e:
                log.error(
                    "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
                    str(e),
                )
                timing = float("inf")
            except NotImplementedError as e:
                log.warning("Not yet implemented: %s", e)
                timing = float("inf")
            except RuntimeError as e:
                msg = str(e)
                if "invalid argument" in msg:
                    msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n"
                elif "illegal memory access" in msg:
                    msg += "\n\nEither error in template or triton bug.\n"
                log.error(
                    "Runtime error during autotuning: \n%s. \nIgnoring this choice.",
                    msg,
                )
                timing = float("inf")
            except AssertionError as e:
                # A verification mismatch is fatal; keep the original error
                # attached as the cause so its traceback is not lost.
                raise AssertionError(
                    f"Incorrect result from choice {choice}\n\n{e}"
                ) from e
            except Exception as e:
                # Only triton's OutOfResources is tolerated here; anything
                # else (or triton being unavailable) propagates.
                try:
                    from triton.runtime.autotuner import OutOfResources

                    if isinstance(e, OutOfResources):
                        log.warning(e)
                        timing = float("inf")
                    else:
                        raise e
                except ImportError:
                    raise e from None

            timings[choice] = timing

        return timings

    def benchmark_in_sub_process(choices):
        """Benchmark extern choices here and the remaining ones in a sub process."""
        from . import autotune_process

        # only benchmark triton kernel in sub process for now.
        # ATen/Extern kernel are still benchmarked in the current process.
        extern = [c for c in choices if isinstance(c, ExternKernelCaller)]
        triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]

        timings = benchmark_in_current_process(extern)
        timings.update(autotune_process.benchmark_in_sub_process(triton))
        return timings

    benchmark = (
        benchmark_in_sub_process
        if config.autotune_in_subproc
        else benchmark_in_current_process
    )

    return benchmark
1547
+
1548
@staticmethod
def log_results(
    name: str,
    input_nodes: List[ir.IRNode],
    timings: Dict[ChoiceCaller, float],
    elapse: float,
    precompile_elapse: float,
):
    """
    Report autotuning results.

    The results are always forwarded to ``V.debug.log_autotuning_results``.
    The human-readable summary on stderr (and the optional mm log entry) is
    only produced when max-autotune is enabled, ``PRINT_AUTOTUNE`` is set and
    there is at least one timing to report.

    Args:
        name: name of the op being tuned; shown in the summary header.
        input_nodes: IR nodes of the inputs; their hinted sizes are printed.
        timings: mapping from choice to its measured time.
        elapse: benchmarking time in seconds.
        precompile_elapse: precompilation time in seconds.
    """
    V.debug.log_autotuning_results(
        name, input_nodes, timings, elapse, precompile_elapse
    )
    if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE:
        return
    if not timings:
        # Nothing to rank or print; indexing the (empty) sorted list below
        # would otherwise raise IndexError.
        return
    sizes = ", ".join(
        [
            "x".join(
                map(
                    str,
                    V.graph.sizevars.size_hints(
                        n.get_size(), fallback=config.unbacked_symint_fallback
                    ),
                )
            )
            for n in input_nodes
        ]
    )

    # Show every choice at DEBUG level, otherwise only the ten fastest.
    n = None if log.getEffectiveLevel() == logging.DEBUG else 10
    top_k = sorted(timings, key=timings.__getitem__)[:n]
    best = top_k[0]

    # Local import: only needed for parsing the tile-shape string below.
    import ast

    def get_choice_info(choice):
        """Summarize one choice as a plain dict for the mm log file."""
        if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
            return {"type": "cublas", "time": timings[choice]}

        assert isinstance(
            choice, torch._inductor.select_algorithm.TritonTemplateCaller
        )

        info = choice.info_dict()
        tile = info["tile_shape"]

        # presumably the string form of a tuple of ints; a literal parser is
        # enough and avoids evaluating arbitrary expressions - verify against
        # TritonTemplateCaller.info_dict.
        tile_vals = ast.literal_eval(tile)  # type: ignore[arg-type]
        BLOCK_M = tile_vals[0]
        BLOCK_K = tile_vals[1]
        BLOCK_N = tile_vals[2]

        return {
            "type": "triton",
            "time": timings[choice],
            "BLOCK_M": BLOCK_M,
            "BLOCK_K": BLOCK_K,
            "BLOCK_N": BLOCK_N,
            "num_stages": info["num_stages"],
            "num_warps": info["num_warps"],
        }

    mm_filename = get_mm_log_filename()
    if mm_filename and "mm" in name:
        # The last two inputs are read as the matmul operands.
        M, K = input_nodes[-2].get_size()[:2]
        N = input_nodes[-1].get_size()[-1]

        out_dict = {
            str((M, K, N)): [get_choice_info(choice) for choice in timings]
        }

        append_to_log(mm_filename, out_dict)

    best_time = timings[best]
    sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
    for choice in top_k:
        result = timings[choice]
        if result:
            kernel_info = getattr(choice, "debug_extra", "")
            sys.stderr.write(
                f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_info}\n"
            )
        else:
            # A zero timing cannot be used as a divisor for the ratio.
            sys.stderr.write(
                f" {choice.name} {result:.4f} ms <DIVIDED BY ZERO ERROR>\n"
            )

    autotune_type_str = (
        "SubProcess" if config.autotune_in_subproc else "SingleProcess"
    )
    sys.stderr.write(
        f"{autotune_type_str} AUTOTUNE benchmarking takes {elapse:.4f} seconds and {precompile_elapse:.4f}"
        " seconds precompiling\n"
    )
1639
+
1640
@staticmethod
def benchmark_example_value(node):
    """
    Turn an ir.Buffer into a concrete torch.Tensor suitable for benchmarking.

    An ir.Layout is accepted as well; it is wrapped in a buffer named "fake"
    first. Sizes and strides are resolved through size hints, using
    ``config.unbacked_symint_fallback`` as the fallback.
    """
    buffer = ir.Buffer("fake", node) if isinstance(node, ir.Layout) else node
    # triton templates want the base tensor.
    if isinstance(buffer, ir.BaseView):
        buffer = buffer.unwrap_view()

    concrete_size = V.graph.sizevars.size_hints(
        buffer.get_size(),
        fallback=config.unbacked_symint_fallback,
    )
    concrete_stride = V.graph.sizevars.size_hints(
        buffer.get_stride(),
        fallback=config.unbacked_symint_fallback,
    )
    device = buffer.get_device()
    dtype = buffer.get_dtype()
    # The layout offset is handed over as the extra storage size.
    offset = buffer.layout.offset
    return AlgorithmSelectorCache.generate_example_value(
        concrete_size, concrete_stride, device, dtype, offset
    )
1664
+
1665
@staticmethod
def generate_example_value(size, stride, device, dtype, extra_size):
    """
    Create a random strided tensor via ``rand_strided`` with the given
    size, stride, device, dtype and extra storage size.
    """
    tensor_options = {
        "device": device,
        "dtype": dtype,
        "extra_size": extra_size,
    }
    # Run under preserve_rng_state so that generating the example tensor
    # does not change the rng states seen by the real model code.
    with preserve_rng_state():
        example = rand_strided(size, stride, **tensor_options)
    return example
1677
+
1678
@staticmethod
def key_of(node):
    """
    Build the tuple of ir.Buffer properties on which cached autotuning
    results should be invalidated: device type, dtype, hinted sizes,
    hinted strides and hinted layout offset.
    """
    sizevars = V.graph.sizevars
    device_type = node.get_device().type
    dtype_name = str(node.get_dtype())
    size_part = sizevars.size_hints(
        node.get_size(),
        fallback=config.unbacked_symint_fallback,
    )
    stride_part = sizevars.size_hints(
        node.get_stride(),
        fallback=config.unbacked_symint_fallback,
    )
    offset_part = sizevars.size_hint(
        node.get_layout().offset,
        fallback=config.unbacked_symint_fallback,
    )
    # Sizes and strides are flattened into the key rather than nested.
    return (device_type, dtype_name, *size_part, *stride_part, offset_part)
1701
+
1702
def add_feedback_saver(
    self,
    fn: Callable[
        [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
    ],
):
    """Register `fn` by appending it to this instance's feedback saver list."""
    savers = self.feedback_saver_fns
    savers.append(fn)
1709
+
1710
+
1711
# Shared selector instance; stays None until one of the module-level helpers
# below needs it.
_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None


def autotune_select_algorithm(*args, **kwargs):
    """
    Forward the call to the shared AlgorithmSelectorCache, creating it on
    first use.

    When the caller does not pass ``return_multi_template``, it is filled in
    from ``torch._inductor.config.benchmark_epilogue_fusion``.
    """
    global _ALGORITHM_SELECTOR_CACHE
    if _ALGORITHM_SELECTOR_CACHE is None:
        _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
    selector = _ALGORITHM_SELECTOR_CACHE

    option = "return_multi_template"
    if option not in kwargs:
        kwargs[option] = torch._inductor.config.benchmark_epilogue_fusion

    return selector(*args, **kwargs)
1725
+
1726
+
1727
def add_feedback_saver(
    fn: Callable[[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None]
):
    """
    Register `fn` as a feedback saver on the shared AlgorithmSelectorCache,
    creating that instance first if it does not exist yet.
    """
    global _ALGORITHM_SELECTOR_CACHE
    selector = _ALGORITHM_SELECTOR_CACHE
    if selector is None:
        selector = _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
    selector.add_feedback_saver(fn)
1734
+
1735
+
1736
def realize_inputs(*args):
    """
    Realize each argument and require stride 1 on the result.

    A single argument yields a single realized node; any other number of
    arguments yields a list with one realized node per argument.
    """
    if len(args) != 1:
        return [realize_inputs(arg) for arg in args]
    (node,) = args
    realized = ir.ExternKernel.realize_input(node)
    return ir.ExternKernel.require_stride1(realized)
1740
+
1741
+
1742
+ # ensure lowering is imported so that `extern_kernels.*` is populated
1743
+ from . import lowering # noqa: F401