Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py +124 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py +1755 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +212 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py +45 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py +242 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py +360 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +186 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py +18 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py +75 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py +130 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py +1204 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py +1100 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py +182 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py +202 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py +186 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py +114 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py +1537 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py +90 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py +53 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/jiterator.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nccl.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/random.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/streams.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py +626 -0
.gitattributes
CHANGED
|
@@ -74,3 +74,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/V
|
|
| 74 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 75 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text
|
| 76 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 74 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 75 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text
|
| 76 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 77 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:647373d0020a53c70bd44d2950f81f6c5edec206899855800a76aabe1ae27e02
|
| 3 |
+
size 745240
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc
ADDED
|
Binary file (29.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc
ADDED
|
Binary file (7.75 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc
ADDED
|
Binary file (34.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc
ADDED
|
Binary file (35.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc
ADDED
|
Binary file (86.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-311.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-311.pyc
ADDED
|
Binary file (64.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (79.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import operator
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Any, Callable, Dict
|
| 4 |
+
|
| 5 |
+
from sympy import Expr
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
| 9 |
+
from .ir import InterpreterShim, LoopBody, LoopBodyBlock
|
| 10 |
+
from .utils import cache_on_self, dominated_nodes
|
| 11 |
+
from .virtualized import V
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BoundVars:
    """Value-range analysis over a ``LoopBody``'s fx graph.

    Call ``get_bounds()`` to run the analysis; it returns a mapping from each
    fx node to the ``ValueRanges`` its value may take.

    NOTE: the analysis works on a per-loop basis only; bounds are not
    propagated across kernels, so a bounded value returned by one kernel and
    fed into another is not tracked.
    """

    def __init__(self, loop_body: LoopBody) -> None:
        self.loop_body = loop_body
        # Each index variable ranges over [0, size - 1]; symbolic sizes fall
        # back to whatever bound_sympy can derive for the size expression.
        self.replacement_vals = {
            var: ValueRanges[Expr](0, size - 1)
            if (isinstance(size, int) or size.is_number)
            else bound_sympy(size)
            for var, size in loop_body.var_ranges.items()
        }
        # Values we refuse to compute (loads, reductions, getitem results,
        # masked subblocks) are pessimistically treated as unbounded, along
        # with everything they dominate.
        self.unbounded_vars = dominated_nodes(
            n
            for n in self.loop_body.get_nodes()
            if n.target in ["load", "reduction", operator.getitem]
            or "masked_subblock" in n.target
        )
        # Populated lazily; read through get_bounds().
        self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}

    @cache_on_self
    def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
        """Run the analysis once and return node -> ValueRanges."""
        submodules = self.swap_submodules(self.loop_body.submodules)

        # Seed the environment: every unbounded node starts at "unknown".
        # masked_subblock / set_indirect nodes are excluded because they must
        # still be evaluated — the former to recurse into the subblock, the
        # latter to record indirect-indexing ranges.
        for node in self.unbounded_vars:
            if not isinstance(node.target, str) or (
                "masked_subblock" not in node.target
                and "set_indirect" not in node.target
            ):
                self._bounds[node] = ValueRanges[Expr].unknown()

        with V.set_ops_handler(ValueRangeAnalysis()):
            interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
            interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
        return self._bounds

    def swap_submodules(
        self, submodules: Dict[str, Callable[..., Any]]
    ) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
        """Replace each LoopBody submodule with a bounds-computing stand-in."""
        result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
        for key in submodules.keys():
            if key == "get_index":
                result[key] = self.get_index
            elif "masked_subblock" in key:
                subblock = self.loop_body.subblocks[key]

                # Bind ``subblock`` through a factory function: a lambda
                # written directly inside this loop would close over the loop
                # variable by reference, so every entry would see the final
                # subblock of the iteration.  ``result`` is intentionally
                # captured by reference — it refers to the finished mapping
                # by the time any lambda runs.
                def make_fn(subblock):
                    return lambda mask, value: self.masked_subblock(
                        subblock, self._bounds, mask, value, result
                    )

                result[key] = make_fn(subblock)

            elif "set_indirect" in key:
                idx = int(key[len("set_indirect") :])
                var = self.loop_body.indirect_vars[idx]
                indirect = partial(self.set_indirect, var)
                result[key] = indirect
            else:
                assert "scan" in key
                result[key] = submodules[key]

        return result

    def masked_subblock(
        self,
        subblock: LoopBodyBlock,
        env: Dict[torch.fx.Node, ValueRanges[Expr]],
        mask: Any,
        value: Any,
        submodules: Dict[str, Callable[..., Any]],
    ) -> ValueRanges[Expr]:
        """Recurse into a masked subblock and return the bound of its output."""
        interp = InterpreterShim(subblock.graph, submodules)
        interp.run(V.get_ops_handler(), initial_env=env)
        outputs = [n for n in subblock.graph.nodes if n.target == "output"]
        assert len(outputs) == 1
        # No need to union with ``value``: the load feeding the masked region
        # is already pessimistically assumed unbounded.
        return interp.env[outputs[0]]

    def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
        """Record the range of an indirect-indexing variable and return it."""
        assert isinstance(new, ValueRanges)
        self.replacement_vals[old] = new
        return new

    def get_index(self, name: Expr) -> ValueRanges[Expr]:
        """Bound the indexing expression registered under *name*."""
        expr = self.loop_body.indexing_exprs[name]
        bound = self.replacement_vals.get(expr)
        if bound is None:
            bound = bound_sympy(expr, self.replacement_vals)
        # A cached bound is assumed consistent with bound_sympy's answer; we
        # skip recomputing (and asserting) it when a cached entry exists.
        self.replacement_vals[name] = bound
        return bound
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc
ADDED
|
Binary file (46.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc
ADDED
|
Binary file (22.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc
ADDED
|
Binary file (9.21 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc
ADDED
|
Binary file (94.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py
ADDED
|
@@ -0,0 +1,1755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import dataclasses
|
| 3 |
+
import functools
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import operator
|
| 7 |
+
import re
|
| 8 |
+
from itertools import chain
|
| 9 |
+
from typing import (
|
| 10 |
+
Any,
|
| 11 |
+
Callable,
|
| 12 |
+
ClassVar,
|
| 13 |
+
Dict,
|
| 14 |
+
List,
|
| 15 |
+
NamedTuple,
|
| 16 |
+
Optional,
|
| 17 |
+
Set,
|
| 18 |
+
Tuple,
|
| 19 |
+
TYPE_CHECKING,
|
| 20 |
+
Union,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import sympy
|
| 24 |
+
from sympy.printing.printer import Printer
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.fx
|
| 28 |
+
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
| 29 |
+
from torch.utils import _pytree as pytree
|
| 30 |
+
from torch.utils._sympy.value_ranges import ValueRanges
|
| 31 |
+
|
| 32 |
+
from .. import config, metrics
|
| 33 |
+
from ..utils import (
|
| 34 |
+
DeferredLineBase,
|
| 35 |
+
do_bench,
|
| 36 |
+
free_symbol_startswith,
|
| 37 |
+
IndentedBuffer,
|
| 38 |
+
sympy_dot,
|
| 39 |
+
sympy_index_symbol,
|
| 40 |
+
sympy_subs,
|
| 41 |
+
unique,
|
| 42 |
+
)
|
| 43 |
+
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
|
| 44 |
+
|
| 45 |
+
if TYPE_CHECKING:
|
| 46 |
+
from ..ir import TensorBox
|
| 47 |
+
|
| 48 |
+
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def data_type_logger(msg):
    """Log a data-type-propagation message at DEBUG level."""
    # Skip early so the log record is never built when DEBUG is disabled.
    if not schedule_log.isEnabledFor(logging.DEBUG):
        return
    schedule_log.debug("Data type propagation: %s", msg)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclasses.dataclass
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.
    """

    # Size of the workspace in bytes (possibly a symbolic expression).
    nbytes: sympy.Expr
    # Whether the buffer must be zero-initialized before the kernel runs.
    zero_fill: bool
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclasses.dataclass
class TensorArg:
    """A kernel argument backed by a tensor buffer."""

    # Argument name as seen by the generated kernel.
    name: str
    # Name of the underlying buffer the argument points into.
    buffer: str
    dtype: torch.dtype
    # Element offset into the buffer; defaults to the buffer start.
    offset: sympy.Expr = sympy.Integer(0)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclasses.dataclass
class SizeArg:
    """A scalar size argument together with its defining expression."""

    name: str
    expr: sympy.Expr
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclasses.dataclass
class DeviceCodegen:
    """Pairs the Scheduling class and WrapperCodeGen class for one device backend."""

    scheduling: type
    wrapper_codegen: type
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Any argument a generated kernel can accept: a workspace buffer, a tensor
# buffer, or a scalar size.
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

# Registry mapping device name -> its codegen classes; populated via
# register_backend_for_device().
device_codegens: Dict[str, DeviceCodegen] = {}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class DeviceOpOverrides:
    """Per-device operation hooks used during code generation.

    Every method here is an abstract hook: subclasses registered through
    ``register_device_op_overrides`` provide the device-specific behavior,
    while this base class raises ``NotImplementedError`` for each.
    """

    def import_get_raw_stream_as(self, name):
        # Must be overridden by a device-specific subclass.
        raise NotImplementedError()

    def set_device(self, device_idx):
        # Must be overridden by a device-specific subclass.
        raise NotImplementedError()

    def synchronize(self):
        # Must be overridden by a device-specific subclass.
        raise NotImplementedError()

    def device_guard(self, device_idx):
        # Must be overridden by a device-specific subclass.
        raise NotImplementedError()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
|
| 111 |
+
# For any new backend looking to integrate with Inductor, customization of these two main
|
| 112 |
+
# parts are necessary to generate its specific code.
|
| 113 |
+
#
|
| 114 |
+
# Kernel code generation is determined by different Scheduling. Consequently, a new
|
| 115 |
+
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
|
| 116 |
+
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
|
| 117 |
+
#
|
| 118 |
+
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
|
| 119 |
+
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
|
| 120 |
+
# and override specific member functions to create backend-specific Python wrapper code.
|
| 121 |
+
#
|
| 122 |
+
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
|
| 123 |
+
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
|
| 124 |
+
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
|
| 125 |
+
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
|
| 126 |
+
# register_backend_for_device, to equip a new backend at runtime.
|
| 127 |
+
#
|
| 128 |
+
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
|
| 129 |
+
# This backend can be used as a reference:
|
| 130 |
+
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
|
| 131 |
+
def register_backend_for_device(
    device: str, device_scheduling: type, device_wrapper_codegen: type
):
    """Register the Scheduling and WrapperCodeGen classes used for ``device``."""
    entry = DeviceCodegen(device_scheduling, device_wrapper_codegen)
    device_codegens[device] = entry
|
| 136 |
+
|
| 137 |
+
def get_scheduling_for_device(device: str):
    """Return the Scheduling class registered for ``device``, or None."""
    try:
        return device_codegens[device].scheduling
    except KeyError:
        return None
|
| 140 |
+
|
| 141 |
+
def get_wrapper_codegen_for_device(device: str):
    """Return the WrapperCodeGen class registered for ``device``, or None."""
    try:
        return device_codegens[device].wrapper_codegen
    except KeyError:
        return None
|
| 146 |
+
|
| 147 |
+
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Append a contiguous index expression to ``index`` so downstream passes
    cannot reorder the iteration space."""
    from ..ir import FlexibleLayout

    # added contiguous index prevents reordering
    return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
|
| 153 |
+
|
| 154 |
+
def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Register backend-specific op overrides for ``device``."""
    device_op_overrides_dict[device] = device_op_overrides
|
| 157 |
+
|
| 158 |
+
def get_device_op_overrides(device: str):
    """Return the DeviceOpOverrides registered for ``device``.

    Lazily imports the built-in CUDA overrides on first use (importing the
    module registers them as a side effect), then falls back to a default
    ``DeviceOpOverrides`` when no backend-specific overrides exist.
    """
    assert isinstance(device, str)

    # Idiom fix: test the dict directly rather than via `.keys()`.
    if not device_op_overrides_dict:
        # Imported for its registration side effect, not the name itself.
        from .cuda import device_op_overrides  # noqa: F401

    if device in device_op_overrides_dict:
        return device_op_overrides_dict[device]

    return DeviceOpOverrides()
|
| 169 |
+
|
| 170 |
+
@functools.lru_cache(None)
def boolean_ops():
    """Return the fixed tuple of op names whose result dtype is always bool."""
    predicates = ("is_inf", "is_nan", "bitwise_xor", "logical_not", "signbit")
    comparisons = ("le", "lt", "ge", "gt", "eq", "ne")
    return predicates + comparisons
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# Maps a tensor dtype to the dtype used for intermediate computation:
# reduced-precision floats are upcast to float32; all other listed dtypes
# map to themselves.
DTYPE_TO_COMPUTATION_DTYPE = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class DataTypePropagation:
    """Propagates dtypes through the FX graphs of a LoopBody.

    Each node's deduced dtype is stored in ``node.meta`` under
    ``OptimizationContext.key``.
    """

    def __init__(self, body) -> None:
        # `body` is a LoopBody-like object with root_block/subblocks graphs.
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Promote the dtypes of all non-placeholder inputs; None if unknown."""
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        # Only deduce when every input already has a propagated dtype.
        all_input_nodes_propogated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propogated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Deduce dtype by fully propagating the node's referenced subgraph."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Return the dtype for a single node, or None when undecidable."""
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # we can infer output node if it only have 1 arg
            if len(node.args) != 1:
                return None

        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            # These ops carry their result dtype as the last argument.
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        if node.target in (
            "get_index",
            "index_expr",
        ):
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            # Buffer ops: dtype comes from the graph's buffer metadata.
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Annotate every node in ``graph`` with its dtype; return the output's."""
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        """Run propagation starting from the root graph."""
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        """Convenience constructor+run for a LoopBody."""
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Run dtype propagation over a SchedulerNode's LoopBody."""
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class ExprPrinter(Printer):
    """Base sympy Printer for emitting expressions as target-language source."""

    @staticmethod
    def paren(string):
        """Wrap ``string`` in parentheses unless that is provably unnecessary."""

        def all_in_parens(string):
            # True iff the first "(" matches the final ")" (i.e. one outer pair
            # encloses the entire string).
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                    if count == 0 and i != len(string) - 2:
                        return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloorDiv(self, expr):
        # Abstract: each target language renders floor division differently.
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_CleanDiv(self, expr):
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # GreaterThan: >=
        # StrictlyGreaterThan: >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"
| 396 |
+
|
| 397 |
+
|
| 398 |
+
class PythonPrinter(ExprPrinter):
    """Prints sympy expressions as plain Python source (math-module calls)."""

    def _print_ModularIndexing(self, expr):
        # ModularIndexing(x, div, mod) renders as (x // div) % mod.
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_Pow(self, expr):
        # Pow() confuses triton
        base, exp = expr.args
        # NB: Remember this is sizevar computation! You don't typically
        # expect to have to do floating point computation including exponents
        # in sizevar compute. Instead of adding support for floating
        # point pow, you should make upstream retranslate the Sympy expression
        # into Tensor expressions earlier and do that instead.
        if exp == 0.5:
            return self._helper_sqrt(base)
        elif exp == -0.5:
            return "1/" + self._helper_sqrt(base)
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        # Expand integer powers into explicit multiplication chains.
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        elif exp < 0:
            return "1/" + self.paren("*".join([self.paren(base)] * abs(exp)))
        else:  # exp == 0
            return "1"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
| 504 |
+
|
| 505 |
+
|
| 506 |
+
class OpOverrides:
    """Default op implementations; unknown attributes delegate to a parent handler."""

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # Fall through to the wrapped handler for ops not overridden here.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv("1", x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        # Adjusts C-style mod so the result takes the sign of `b`.
        r = ops.mod(a, b)
        return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install per-backend pointwise ops from pointwise_overrides_data onto cls."""
        assert target in {"triton", "cpp", "cppvec"}, target

        def pointwise_factory_1(impl):
            def func(x):
                return impl.format(x=x)

            return func

        def pointwise_factory_2(impl):
            def func(x, y):
                return impl.format(x=x, y=y)

            return func

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if isinstance(impl, str):
                # Arity is inferred from whether the template references {y}.
                nof_args = 2 if "{y}" in impl else 1
                # extend the following dictionary with factory
                # functions for a specific number of arguments as
                # needed:
                factory = {1: pointwise_factory_1, 2: pointwise_factory_2}[nof_args]
                setattr(cls, funcname, staticmethod(factory(impl)))
| 593 |
+
|
| 594 |
+
|
| 595 |
+
@dataclasses.dataclass
class OverridesData:
    """Per-op code templates for each backend, keyed off the op's name."""

    name: str  # aten/special op name this entry implements
    cpp: str  # C++ implementation template (uses {x} and optionally {y})
    triton: Optional[str] = None  # None when not impl in libdevice/triton
    cppvec: Optional[str] = None  # None when not impl in aten/.../vec
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
| 604 |
+
|
| 605 |
+
|
| 606 |
+
# Table of pointwise special-function ops keyed by op name; consumed by
# OpOverrides._initialize_pointwise_overrides to install backend-specific
# implementations as static methods.
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_j0_forward({x})",
        triton="libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_j1_forward({x})",
        triton="libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_y0_forward({x})",
        triton="libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_y1_forward({x})",
        triton="libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_digamma({x})",
        cppvec="{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_erfcx({x})",
        triton="libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i0({x})",
        triton="libdevice.cyl_bessel_i0({x})",
        cppvec="{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i0e({x})",
        cppvec="{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i1({x})",
        triton="libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_i0_forward({x})",
        triton="libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_i1_forward({x})",
        triton="libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)
| 823 |
+
|
| 824 |
+
|
| 825 |
+
# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    # Never called at runtime; exists purely so mypy verifies OpOverrides
    # satisfies the OpsHandler protocol.
    return h
| 828 |
+
|
| 829 |
+
|
| 830 |
+
class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        # Buffer name whose removal suppresses this line.
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        # Emit the line only if the buffer survived removal in both the
        # graph-level and kernel-level removed/inplaced sets.
        if all(
            self.name not in x
            for x in (
                V.graph.removed_buffers,
                V.kernel.removed_buffers,
                V.graph.inplaced_to_remove,
                V.kernel.inplaced_to_remove,
            )
        ):
            return self.line
        return None

    def _new_line(self, line):
        return DeferredLine(self.name, line)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
class BracesBuffer(IndentedBuffer):
    """IndentedBuffer that emits C-style braces when indenting/dedenting.

    ``indent`` accepts a negative ``offset`` to close scopes instead of
    opening them; the returned context manager mirrors its effect on exit.
    """

    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            # Positive offset: open `offset` brace scopes now.
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            # Negative offset: close `-offset` scopes now.
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            yield
            # On exit, undo whatever was done above (mirrored order).
            for _ in range(-offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(offset):
                self._indent -= 1
                self.writeline("}")

        return ctx()
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
class InplacedBuffer(NamedTuple):
    # Kernel-side argument name (e.g. "in_out_ptr0") shared by the aliases.
    inner_name: str
    # All outer buffer names that alias this in/out pointer.
    other_names: List[str]
|
| 879 |
+
|
| 880 |
+
|
| 881 |
+
class KernelArgs:
|
| 882 |
+
@staticmethod
def _lookup(prefix, odict, name):
    # Assign (once) and return a stable inner argument name like "in_ptr3";
    # the suffix is the dict size at first insertion, so names are sequential.
    assert isinstance(name, (str, sympy.Symbol))
    if name not in odict:
        odict[name] = f"{prefix}{len(odict)}"
    return odict[name]
| 888 |
+
|
| 889 |
+
def __init__(self, sizevars=None):
    # Maps of outer buffer name -> inner kernel argument name.
    self.input_buffers = dict()
    self.output_buffers = dict()
    # Maps outer name -> InplacedBuffer for in/out aliased arguments.
    self.inplace_buffers = dict()
    # Size-variable name mapping; caller may share one across kernels.
    self.sizevars = sizevars or dict()
    # Lazily-created WorkspaceArg describing scratch memory, if requested.
    self.workspace_arg = None
+
|
| 896 |
+
def __repr__(self):
    # Debug representation showing all four argument-name maps.
    return "KernelArgs({})".format(
        ", ".join(
            map(
                repr,
                [
                    self.input_buffers,
                    self.output_buffers,
                    self.inplace_buffers,
                    self.sizevars,
                ],
            )
        )
    )
|
| 910 |
+
|
| 911 |
+
def _buffer_is_marked_removed(self, name):
    # Removed buffers are tombstoned by rewriting their entry to a
    # string starting with "REMOVED".
    return isinstance(name, str) and name.startswith("REMOVED")
|
| 913 |
+
|
| 914 |
+
def input(self, name):
    """Return the inner kernel-argument name for input buffer ``name``."""
    if V.graph.scheduler:
        # Resolve mutations to the real underlying buffer name.
        name = V.graph.scheduler.mutation_real_name.get(name, name)
    assert name not in V.graph.removed_buffers, name
    # Reuse an existing output/in-place binding if one exists.
    if name in self.output_buffers:
        return self.output_buffers[name]
    if name in self.inplace_buffers:
        return self.inplace_buffers[name].inner_name
    if name.startswith("seed"):
        return self._lookup("seed", self.input_buffers, name)
    return self._lookup("in_ptr", self.input_buffers, name)
|
| 925 |
+
|
| 926 |
+
def output(self, name):
    """Return the inner kernel-argument name for output buffer ``name``."""
    if V.graph.scheduler:
        # Resolve mutations to the real underlying buffer name.
        name = V.graph.scheduler.mutation_real_name.get(name, name)
    assert name not in V.graph.removed_buffers, name
    if name in self.inplace_buffers:
        return self.inplace_buffers[name].inner_name
    return self._lookup("out_ptr", self.output_buffers, name)
|
| 933 |
+
|
| 934 |
+
def make_inplace(self, input_name, output_name):
    """Alias ``output_name`` onto ``input_name`` as one in/out pointer arg."""
    assert output_name not in self.inplace_buffers
    if input_name in self.inplace_buffers:
        # Input already aliased: extend the existing group.
        buf = self.inplace_buffers[input_name]
        buf.other_names.append(output_name)
        self.inplace_buffers[output_name] = buf
    else:
        # New group: number by count of distinct InplacedBuffer objects.
        buf = InplacedBuffer(
            f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
            [input_name, output_name],
        )
        self.inplace_buffers[input_name] = buf
        self.inplace_buffers[output_name] = buf
|
| 947 |
+
|
| 948 |
+
def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
    """Reserve ``nbytes`` of scratch space; returns (arg name, byte offset).

    Successive calls grow a single workspace allocation; ``zero_fill`` is
    sticky once any caller requests it.
    """
    if self.workspace_arg is None:
        self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
        return "ws_ptr", 0

    offset = self.workspace_arg.nbytes
    zero_fill = zero_fill or self.workspace_arg.zero_fill
    self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
    return "ws_ptr", offset
|
| 957 |
+
|
| 958 |
+
def seed_offset(self, name, value):
    """Map a seed-offset ``value`` to a unique sizevar name, memoized."""
    if value in self.sizevars:
        return self.sizevars[value]
    if name in self.sizevars.values():
        # Disambiguate by appending a count of existing names with this prefix.
        name = (
            f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
        )
    self.sizevars[value] = name
    return name
|
| 967 |
+
|
| 968 |
+
def size(self, name):
    """Return the inner sizevar argument name for ``name``."""
    # "seed" is passed through unchanged rather than renamed to ksN.
    if str(name) == "seed":
        self.sizevars["seed"] = "seed"
        return "seed"
    return self._lookup("ks", self.sizevars, name)
|
| 973 |
+
|
| 974 |
+
def call_names(self):
    """All caller-visible argument names: inputs, then outputs, then size vars."""
    return chain.from_iterable(
        (self.input_buffers, self.output_buffers, self.sizevars)
    )
|
| 978 |
+
|
| 979 |
+
def wrap_ptr_arg(self, buf, dtype):
    """Default: pass the buffer name through unchanged (C++ wrapper
    subclasses override this to emit a typed data_ptr cast)."""
    return buf
|
| 981 |
+
|
| 982 |
+
def wrap_size_arg(self, size):
    """Render a size expression in its textual form."""
    return f"{size}"
|
| 984 |
+
|
| 985 |
+
def cpp_argdefs(self):
    """Build the C++ kernel signature.

    Returns (arg_defs, call_args, arg_types): parameter declarations,
    the caller-side expressions, and the bare C types — emitted in a
    fixed order: in/out pointers, read-only inputs, outputs, size vars.
    """
    from .cpp import DTYPE_TO_CPP, INDEX_TYPE

    call_args = []
    arg_defs = []
    arg_types = []
    # in/out pointers for buffers reused in place (one per inplace group)
    for inplaced in unique(self.inplace_buffers.values()):
        if self._buffer_is_marked_removed(inplaced):
            continue
        # the last alias is the buffer name callers actually pass
        outer = inplaced.other_names[-1]
        inner = inplaced.inner_name
        dtype = V.graph.get_dtype(outer)
        cpp_dtype = DTYPE_TO_CPP[dtype]
        arg_defs.append(f"{cpp_dtype}* {inner}")
        call_args.append(self.wrap_ptr_arg(outer, dtype))
        arg_types.append(f"{cpp_dtype}*")
    # read-only inputs (skipping any aliased into an in/out pointer above)
    for outer, inner in self.input_buffers.items():
        if outer in self.inplace_buffers:
            continue
        dtype = V.graph.get_dtype(outer)
        cpp_dtype = DTYPE_TO_CPP[dtype]
        arg_defs.append(f"const {cpp_dtype}* {inner}")
        call_args.append(self.wrap_ptr_arg(outer, dtype))
        arg_types.append(f"const {cpp_dtype}*")
    # outputs (skipping inplaced and removed buffers)
    for outer, inner in self.output_buffers.items():
        if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
            continue
        dtype = V.graph.get_dtype(outer)
        cpp_dtype = DTYPE_TO_CPP[dtype]
        arg_defs.append(f"{cpp_dtype}* {inner}")
        call_args.append(self.wrap_ptr_arg(outer, dtype))
        arg_types.append(f"{cpp_dtype}*")
    # size variables, passed by value as the index type
    for outer, inner in self.sizevars.items():
        arg_defs.append(f"const {INDEX_TYPE} {inner}")
        call_args.append(self.wrap_size_arg(outer))
        arg_types.append(f"const {INDEX_TYPE}")
        if V.graph.wrapper_code:
            V.graph.wrapper_code.ensure_size_computed(outer)
    assert self.workspace_arg is None, "Workspace not supported on CPU "
    return arg_defs, call_args, arg_types
|
| 1025 |
+
|
| 1026 |
+
def python_argdefs(self):
    """Build the Python-side kernel signature.

    Returns (arg_defs, call_args, precompile_args): inner parameter
    names, caller-side buffer/size names, and the typed arg descriptors
    used at precompile time — in/out pointers first, then inputs and
    outputs, then size vars, then the optional workspace.
    """
    arg_defs = []
    call_args = []
    precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
    # one in/out pointer per live inplace group
    for inplaced in unique(self.inplace_buffers.values()):
        if self._buffer_is_marked_removed(inplaced):
            continue
        arg_defs.append(inplaced.inner_name)
        # the last alias is the buffer name callers actually pass
        call_args.append(inplaced.other_names[-1])
        precompile_args.append(
            TensorArg(
                name=inplaced.inner_name,
                buffer=inplaced.other_names[-1],
                dtype=V.graph.get_dtype(inplaced.other_names[-1]),
            )
        )
    # remaining inputs and outputs, skipping inplaced/removed buffers
    for outer, inner in chain(
        self.input_buffers.items(), self.output_buffers.items()
    ):
        if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
            continue
        arg_defs.append(inner)
        call_args.append(outer)
        precompile_args.append(
            TensorArg(
                name=inner,
                buffer=outer,
                dtype=V.graph.get_dtype(outer),
            )
        )
    for outer, inner in self.sizevars.items():
        arg_defs.append(inner)
        call_args.append(outer)
        precompile_args.append(SizeArg(inner, outer))
        if V.graph.wrapper_code:
            V.graph.wrapper_code.ensure_size_computed(outer)
    if self.workspace_arg is not None:
        # single shared scratch allocation, always appended last
        arg_defs.append("ws_ptr")
        call_args.append("workspace")
        precompile_args.append(self.workspace_arg)

    return arg_defs, call_args, precompile_args
|
| 1068 |
+
|
| 1069 |
+
def aliases(self):
    """Yield (per-buffer inner name, in_out_ptr name) pairs for every
    live alias of every live inplaced buffer group."""
    graph_removed = V.graph.inplaced_to_remove
    kernel_removed = V.kernel.inplaced_to_remove
    for group in unique(self.inplace_buffers.values()):
        if self._buffer_is_marked_removed(group):
            continue
        for alias in group.other_names:
            if alias in graph_removed or alias in kernel_removed:
                continue
            # an alias may appear as an input, an output, or both
            if alias in self.input_buffers:
                yield self.input_buffers[alias], group.inner_name
            if alias in self.output_buffers:
                yield self.output_buffers[alias], group.inner_name
|
| 1083 |
+
|
| 1084 |
+
def is_removed(self, name):
    """True iff `name` is absent or tombstoned in both the output map
    and the inplace map."""
    for table in (self.output_buffers, self.inplace_buffers):
        if name in table and not self._buffer_is_marked_removed(table[name]):
            # present and live in at least one map => not removed
            return False
    return True
|
| 1091 |
+
|
| 1092 |
+
# Includes inplace buffers, excludes removed buffers. Essentially,
# after you do a call into this kernel, which buffers actually contain
# updated data? Modeled off of python_argdefs.
def live_output_buffers(self):
    """Set of caller-visible buffer names this kernel actually writes."""
    result = set()
    for group in unique(self.inplace_buffers.values()):
        if not self._buffer_is_marked_removed(group):
            # the last alias of an inplace group is its live outer name
            result.add(group.other_names[-1])
    for outer, inner in self.output_buffers.items():
        if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
            continue
        result.add(outer)
    return result
|
| 1106 |
+
|
| 1107 |
+
|
| 1108 |
+
class CSEVariable:
    """A named handle for a CSE'd expression, annotatable per backend.

    Backends overload `Kernel.create_cse_var` to return their own
    subclass; `update_on_args` is the hook for backend-dependent
    annotation. See TritonCSEVariable in triton.py for an example.
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        # strict type match on purpose: vars from different backends
        # with the same name must not compare equal
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        # default: no annotation; subclasses may record dtype/mask info
        pass
|
| 1131 |
+
|
| 1132 |
+
|
| 1133 |
+
class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs variant for the C++ wrapper codegen path."""

    def wrap_ptr_arg(self, buf, dtype):
        from .cpp import DTYPE_TO_CPP

        if not config.abi_compatible:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
        # In the abi_compatible model, we just return the buf here.
        # We will form correct call args later in wrapper.generate_kernel_all.
        return buf

    def wrap_size_arg(self, size):
        return f"{size}"
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        # prefix/suffix are spliced around every generated assignment line
        self.prefix = prefix
        self.suffix = suffix
        # expression text -> CSEVariable already generated for it
        self.cache = {}
        self.name_prefix = name_prefix
        # buffer name -> CSEVariable last stored to that buffer
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        # counter shared with clones so variable names are never reused
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = set()
        # variable name -> CSEVariable, for every variable ever created
        self.varname_map = varname_map or {}

    def invalidate(self, keep_vars: Set[str]):
        # Forget cached stores/expressions whose variables are not in
        # keep_vars; a later load of an invalidated store must re-read.
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        # Shares the name counter, store cache and varname map; the
        # expression cache starts fresh in the copy.
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        """Return the CSEVariable holding `expr`, emitting the assignment
        into `buffer` only the first time this expression text is seen.

        With write=False nothing is emitted; with assignment=False the
        expression is written as a bare statement and None is cached.
        """
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            var = self.newvar(bounds) if assignment else None
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    # multi-line expression: assignment header, then the body
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            # cache hit: no code emitted, just refine the known bounds
            var.bounds = var.bounds.tighten(bounds)

        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        # allocate a fresh backend-specific variable (tmp0, tmp1, ...)
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var
|
| 1240 |
+
|
| 1241 |
+
|
| 1242 |
+
class IndirectAssertLine(DeferredLineBase):
    """A deferred bounds-check line for indirect indexing.

    The bound for (var, mask) may tighten as more uses are seen, so the
    final text is produced lazily from `size_map` when the line is
    rendered; returns None when static analysis proved both bounds.
    """

    def __init__(self, line, assert_fn, var, mask, size_map):
        self.var = var
        self.mask = mask
        # format template with {assert_fn}/{cond}/{cond_print} holes
        self.line = line
        self.assert_fn = assert_fn
        self.size_map = size_map

    def __call__(self):
        size, size_str = self.size_map[(self.var, self.mask)]

        # We assert if we've not been able to prove the bound
        # (sympy comparisons are tri-state: only `sympy.true` counts as proven)
        assert_min = (self.var.bounds.lower >= 0) != sympy.true
        assert_max = (self.var.bounds.upper < size) != sympy.true

        # emit only the checks static analysis could not discharge
        if not (assert_min or assert_max):
            return None
        elif assert_min and assert_max:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
            cond_print = f"0 <= {self.var} < {size_str}"
        elif assert_min:
            cond = f"0 <= {self.var}"
            cond_print = cond
        else:
            assert assert_max
            cond = f"{self.var} < {size_str}"
            cond_print = cond

        if self.mask:
            # masked-off lanes are exempt from the bounds check
            cond = f"({cond}) | ~{self.mask}"
        return self.line.format(
            assert_fn=self.assert_fn, cond=cond, cond_print=cond_print
        )

    def _new_line(self, line):
        return IndirectAssertLine(
            line, self.assert_fn, self.var, self.mask, self.size_map
        )
|
| 1283 |
+
|
| 1284 |
+
|
| 1285 |
+
class CodeGen:
    """Context-manager base for code generators: owns an ExitStack so
    subclasses can register handlers/cleanup for the codegen scope."""

    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # unwind everything registered on the stack; never suppress
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
|
| 1296 |
+
|
| 1297 |
+
|
| 1298 |
+
class Kernel(CodeGen):
    """Base class for backend kernel code generators.

    Collects generated code into three sections (loads/compute/stores),
    runs CSE over emitted expressions, and — while active as a context
    manager — installs a CSEProxy ops handler that routes every op
    through the CSE machinery.
    """

    newvar_prefix = ""
    suffix = ""
    # backend wraps the current ops handler with this (see __enter__)
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        # the three output sections the backend writes into
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()
        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = set()
        self.store_buffer_names = set()
        self._load_mask = None
        # set in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
        # Upper bounds for indirect_indexing and their str representation
        # NB: None, None is never stored in map, but it is the assumed
        # "not set" value for the dict
        self.indirect_max_sizes: Dict[
            Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]]
        ] = {}

        self.removed_buffers = set()
        self.inplaced_to_remove = set()

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        # the buffer specified by key
        self.inplace_update_buffers = dict()
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        """Scope `node` as the scheduler node being generated, making its
        per-fx-node value-range bounds available to CSEProxy."""
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        """Temporarily redirect loads/compute/stores into the given
        buffers, with a cloned CSE so the swap is side-effect free."""
        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = cse.clone()
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        raise NotImplementedError()

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load the depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        raise NotImplementedError()

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        raise NotImplementedError()

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        raise NotImplementedError()

    def scan(
        self,
        dtype: torch.dtype,
        combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
        value: CSEVariable,
        init: int,
    ) -> CSEVariable:
        raise NotImplementedError()

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError()

    @property
    def assert_function(self) -> str:
        # backend-specific name of the device-side assert function
        raise NotImplementedError()

    def index_to_str(self, index: sympy.Expr) -> str:
        raise NotImplementedError()

    def __enter__(self):
        # TODO: hoist this to top level
        class CSEProxy:
            # NOTE(review): `self` here is the enclosing Kernel (closure),
            # so this sets kernel.name at class-body execution time
            self.name = "CSEProxy"

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                def inner(*args, **kwargs):
                    # TritonTemplateKernel has no current_node
                    buf_bounds = ValueRanges.unknown()
                    if hasattr(V.interpreter, "current_node"):
                        fx_node = V.interpreter.current_node
                        assert isinstance(self.node_to_bounds, dict)
                        buf_bounds = self.node_to_bounds.get(
                            fx_node, ValueRanges.unknown()
                        )

                    # delegate the op to the wrapped handler, then CSE the result
                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    return pytree.tree_map(do_cse, value)

                return inner

            @staticmethod
            def indirect_indexing(
                var: CSEVariable, size: sympy.Expr, check: bool = True
            ):
                # Skip CSE since this doesn't return an expression

                if var.bounds.lower < 0:  # type: ignore[operator]
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg = var.bounds & ValueRanges(-sympy.oo, -1)
                        new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, sympy.oo)
                            new_bounds = new_bounds | pos

                    # wrap negative indices: index < 0 ? index + size : index
                    stm = ops.add(var, self.rename_indexing(size))
                    # Mixed negative and non-negative
                    if var.bounds.upper >= 0:  # type: ignore[operator]
                        lt = ops.lt(var, "0")
                        stm = ops.where(lt, stm, var)
                    new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                    new_var.update_on_args("index_wrap", (var,), {})
                    var = new_var

                if self.generate_assert(check):
                    mask = self.load_mask(var)

                    # An assertion line may have been written already, if so just
                    # update the max size.
                    map_key = (var, mask)
                    existing_size, _ = self.indirect_max_sizes.get(
                        map_key, (None, None)
                    )
                    if existing_size is not None:
                        size = sympy.Min(size, existing_size)
                    else:
                        line = (
                            '{assert_fn}({cond}, "index out of bounds: {cond_print}")'
                        )
                        self.compute.writeline(
                            IndirectAssertLine(
                                line,
                                self.assert_function,
                                var,
                                mask,
                                self.indirect_max_sizes,
                            )
                        )

                    self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
                return sympy_index_symbol(str(var))

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_startswith(index, "tmp"):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    # forward the value written earlier instead of re-reading
                    return store_cache[name]
                return self.load(name, index)

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    self.cse.store_cache[name] = value
                    if self.current_node:
                        for other_name in self.current_node.get_mutations():
                            self.cse.store_cache[other_name] = value
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                self.cse.store_cache[name] = value
                if self.current_node:
                    for other_name in self.current_node.get_mutations():
                        self.cse.store_cache[other_name] = value

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtype: torch.dtype,
                combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
                value: CSEVariable,
                init: int,
            ) -> CSEVariable:
                return self.scan(dtype, combine_fn, value, init)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =   [ 0,  1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def generate_assert(self, check):
        # debug_index_asserts forces checks; assert_indirect_indexing gates all
        return (check or config.debug_index_asserts) and config.assert_indirect_indexing

    def load_mask(self, var) -> str:
        # only the triton kernel requires mask
        return ""

    def rename_indexing(self, index) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        # sort for deterministic argument ordering
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if x.name.startswith(("s", "u", "ps"))
            or (x.name.startswith("i") and not x.name.startswith("idx"))
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        # backends override to return their CSEVariable subclass
        return CSEVariable(*args, **kwargs)
|
| 1642 |
+
|
| 1643 |
+
|
| 1644 |
+
@dataclasses.dataclass
class OptimizationContext:
    # tag for this context (presumably a metadata key used elsewhere — TODO confirm)
    key: ClassVar[str] = "opt_ctx"

    # Load value as mask
    is_load_as_mask: bool = False

    # resolved dtype for the op, if known
    dtype: Optional[torch.dtype] = None
    ops_name: str = ""

    # Load uint8/int8 value as float32
    is_load_int8_as_float: bool = False
|
| 1656 |
+
|
| 1657 |
+
|
| 1658 |
+
@functools.lru_cache(None)
def jinja2_env():
    """Return a shared jinja2 Environment (StrictUndefined), or None when
    jinja2 is not installed. Cached, so the import cost is paid once."""
    try:
        import jinja2

        env = jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        env = None
    return env
|
| 1668 |
+
|
| 1669 |
+
|
| 1670 |
+
PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]
|
| 1671 |
+
|
| 1672 |
+
|
| 1673 |
+
class ChoiceCaller:
    """One candidate implementation considered in autotune_process.py.

    During autotuning, benchmark() is called first to time the candidate;
    if this choice wins, output_node() supplies the node to use.

    Children classes: TritonTemplateCaller, CUDATemplateCaller.
    """

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes

    def benchmark(self, *args, out) -> float:
        # time one invocation of the compiled candidate
        algo = self.to_callable()
        return do_bench(lambda: algo(*args, out=out))

    def call_name(self) -> str:
        raise NotImplementedError()

    def to_callable(self):
        raise NotImplementedError()

    def hash_key(self) -> str:
        raise NotImplementedError()

    def output_node(self) -> "TensorBox":
        raise NotImplementedError()

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {}
|
| 1707 |
+
|
| 1708 |
+
|
| 1709 |
+
class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def _template_from_string(source):
        # returns None when jinja2 is unavailable
        env = jinja2_env()
        return env.from_string(source) if env is not None else None

    @staticmethod
    def _fake_get_dtype(fake_out):
        """Wrap V.graph.get_dtype so lookups of `fake_out`'s name resolve
        to its dtype instead of going to the graph."""
        real_get_dtype = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return real_get_dtype(name)

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """
        try:
            choices.append(self.generate(**kwargs))
        except NotImplementedError:
            # template doesn't apply here; leave choices untouched
            pass

    def generate(self, **kwargs) -> ChoiceCaller:
        """
        Generates a ChoiceCaller instance from the given arguments.
        """
        raise NotImplementedError()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc
ADDED
|
Binary file (2.29 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc
ADDED
|
Binary file (20.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc
ADDED
|
Binary file (30.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import cast, List
|
| 3 |
+
|
| 4 |
+
from ...._dynamo.utils import counters
|
| 5 |
+
|
| 6 |
+
from ... import config, ir
|
| 7 |
+
from ...codecache import code_hash, get_path
|
| 8 |
+
from ...ir import ComputedBuffer, CUDATemplateBuffer, Pointwise
|
| 9 |
+
from ...scheduler import (
|
| 10 |
+
BaseSchedulerNode,
|
| 11 |
+
BaseScheduling,
|
| 12 |
+
FusedSchedulerNode,
|
| 13 |
+
Scheduler,
|
| 14 |
+
SchedulerNode,
|
| 15 |
+
)
|
| 16 |
+
from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
|
| 17 |
+
from ...virtualized import V
|
| 18 |
+
from ..common import IndentedBuffer
|
| 19 |
+
|
| 20 |
+
from .cutlass_epilogue_gen import CUTLASSEVTOpNotImplementedError
|
| 21 |
+
|
| 22 |
+
log = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class CUDACPPScheduling(BaseScheduling):
    """
    Partial Scheduling implementation for CUDA C++ Kernels.
    This class is intended to be used in combination with TritonScheduling,
    and delegated to by CUDACombinedScheduling.

    It handles fusion decisions and CUDA C++ specific template code generation.
    """

    def __init__(self, scheduler: Scheduler):
        super().__init__()
        self.scheduler = scheduler

    def group_fn(self, sizes):
        # Collapse each size group into a single simplified product of its dims.
        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)

    def is_cuda_cpp_template(self, node: BaseSchedulerNode) -> bool:
        # A CUDA C++ template node is a plain SchedulerNode (not a fused node)
        # wrapping a CUDATemplateBuffer.
        return isinstance(node, SchedulerNode) and isinstance(
            node.node, CUDATemplateBuffer
        )

    def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
        # A fused template is a FusedSchedulerNode whose template member is a
        # CUDA C++ template node.
        return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(
            node.get_template_node()
        )

    def _can_fuse_epilogue_impl(
        self,
        cuda_template_buffer: CUDATemplateBuffer,
        epilogue_nodes: List[ir.IRNode],
        additional_node: ir.IRNode,
    ) -> bool:
        """
        Check if the given node can be fused with the epilogue. At the moment, Kernels
        support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.

        Args:
            cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer
            epilogue_nodes : List[ir.Buffer]: The list of already fused epilogue nodes.
            additional_node: The ir.Buffer node to be checked if it can be fused with the epilogue.
        Returns:
        - bool: True if the given node can be fused with the epilogue, False otherwise.

        """
        if not isinstance(cuda_template_buffer, CUDATemplateBuffer):
            return False
        if not cuda_template_buffer.template.can_fuse_epilogue:
            # The used GEMM op does not support fusing epilogues
            return False
        if not isinstance(additional_node, ComputedBuffer):
            return False
        if not isinstance(additional_node.data, Pointwise):
            return False
        # We can fuse a Pointwise op that depends on the last fused epilogue node
        # if any. If there is no epilogue node yet, it needs to depend on the template
        # node
        node_name = additional_node.get_computed_buffer_name()
        if node_name is None:
            return False

        if len(epilogue_nodes) == 0:
            # First epilogue: it must read the template's output directly.
            if cuda_template_buffer.name not in additional_node.get_read_names():
                return False
        else:
            # Subsequent epilogues must chain off the previously fused node.
            last_epilogue_node = epilogue_nodes[-1]
            assert isinstance(last_epilogue_node, ir.ComputedBuffer)  # for mypy
            last_epilogue_name = (
                last_epilogue_node.name
                if last_epilogue_node.name is not None
                else last_epilogue_node.data.name  # type: ignore[attr-defined]
            )
            if last_epilogue_name not in additional_node.get_read_names():
                return False
        if additional_node.layout != cuda_template_buffer.layout:
            return False
        # Dry-run the EVT code generation for the candidate node; if either
        # formatter raises CUTLASSEVTOpNotImplementedError, fusion is rejected.
        try:
            from torch._inductor.codegen.cuda.cutlass_epilogue_gen import (
                CutlassEVTEpilogueArgumentFormatter,
                CutlassEVTEpilogueTypeFormatter,
            )

            CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(
                cast(str, cuda_template_buffer.name), "anything", [additional_node]
            )
            CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(
                cast(str, cuda_template_buffer.name), [additional_node]
            )
        except CUTLASSEVTOpNotImplementedError as e:
            not_implemented_op = str(e)
            if not_implemented_op.startswith("_op_"):
                # Exception message names the unsupported V.ops handler method.
                not_implemented_op = not_implemented_op[4:]
                log.warning(
                    f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}, likely due to unsupported operation: {not_implemented_op}"  # noqa: G004, B950
                )
                return False
            else:
                # Likely due to unsupported dtype.
                log.warning(
                    f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}. Reason: {not_implemented_op}"  # noqa: G004, B950
                )
                return False
        return True

    @staticmethod
    def _unwrap_epilogue_nodes(fused_node: FusedSchedulerNode) -> List[ir.IRNode]:
        # Return the IR nodes of all epilogue members, i.e. everything in the
        # fused node except the template node itself.
        nodes = fused_node.get_nodes()
        template_node = fused_node.get_template_node()
        nodes.remove(template_node)
        return [n.node for n in nodes]

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        # Vertical fusion: node2 may become an (additional) epilogue of node1,
        # whether node1 is a bare template or an already-fused template.
        if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode):
            return self._can_fuse_epilogue_impl(
                cast(CUDATemplateBuffer, node1.node), [], node2.node
            )
        elif self.is_cuda_cpp_fused_template(node1) and isinstance(
            node2, SchedulerNode
        ):
            fnode1 = cast(FusedSchedulerNode, node1)
            return self._can_fuse_epilogue_impl(
                fnode1.get_template_node().node,
                self._unwrap_epilogue_nodes(fnode1),
                node2.node,
            )
        return False

    def define_kernel(self, src_code: str, node_schedule) -> str:
        """
        Register ``src_code`` with the wrapper as an async-compiled CUDA kernel
        and return its kernel name. Identical source is defined only once: the
        name is cached in ``wrapper.src_to_kernel`` keyed by the original source.
        """
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            fused_name = (
                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
                if config.triton.descriptive_names
                else ""
            )
            kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()])
            # use the original src_code as the key
            wrapper.src_to_kernel[src_code] = kernel_name
            src_code = src_code.replace("KERNEL_NAME", kernel_name)

            _, _, kernel_path = get_path(code_hash(src_code), "py")

            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline("async_compile.cuda(r'''")
            compile_wrapper.splice(src_code, strip=True)
            compile_wrapper.writeline("''', 'so')")

            metadata_comment = f"# kernel path: {kernel_path}"
            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment += "\n" + origins + "\n" + detailed_origins
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )
        return kernel_name

    def codegen_template(
        self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode]
    ):
        """
        Codegen a CUDA template, possibly with fused epilogues
        """
        counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
        assert self.is_cuda_cpp_template(
            template_node
        ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
        template_node = cast(SchedulerNode, template_node)
        _, (numel, rnumel) = template_node.group
        assert rnumel == 1
        ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
        epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes]
        assert all(
            isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
        ), "Epilogue nodes must all be instances of ir.ComputedBuffer"
        kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
        with kernel:
            # Mark every participating node as run before rendering the source.
            for node in [template_node, *epilogue_nodes]:
                node.mark_run()
            src_code = render()

        with V.set_kernel_handler(kernel):
            node_schedule = [template_node, *epilogue_nodes]
            kernel_name = self.define_kernel(src_code, node_schedule)
        kernel.call_kernel(kernel_name, ctb, epilogue_ir_nodes)
        V.graph.removed_buffers |= kernel.removed_buffers
        self.scheduler.free_buffers()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ... import config
|
| 8 |
+
|
| 9 |
+
log = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_cuda_arch() -> Optional[str]:
    """
    Return the target CUDA arch as a string (e.g. "90"), or None on failure.

    The configured value ``config.cuda.arch`` wins; otherwise the compute
    capability of the first visible device is queried via torch.
    """
    try:
        arch = config.cuda.arch
        if arch is not None:
            return str(arch)
        # Get Compute Capability of the first Visible device
        major, minor = torch.cuda.get_device_capability(0)
        return str(major * 10 + minor)
    except Exception as e:
        log.error("Error getting cuda arch: %s", e)
        return None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_cuda_version() -> Optional[str]:
    """
    Return the CUDA toolkit version string, or None on failure.

    Prefers the configured ``config.cuda.version``; falls back to the version
    torch was built against.
    """
    try:
        configured = config.cuda.version
        return configured if configured is not None else torch.version.cuda
    except Exception as e:
        log.error("Error getting cuda version: %s", e)
        return None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@functools.lru_cache(None)
def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool:
    """
    Return True if the given nvcc binary can be located on the PATH.

    Args:
        nvcc_path: Name or path of the nvcc executable; None always yields False.

    The result is cached (lru_cache) since PATH lookups are invariant for the
    process lifetime.
    """
    if nvcc_path is None:
        return False
    import shutil

    # shutil.which is portable (the previous implementation spawned a
    # `which` subprocess, which does not exist on Windows) and avoids the
    # cost of forking a process for a simple PATH lookup.
    return shutil.which(nvcc_path) is not None
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import itertools
|
| 3 |
+
import logging
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
from unittest.mock import patch
|
| 6 |
+
|
| 7 |
+
import sympy
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from ...autotune_process import CUDABenchmarkRequest, TensorMeta
|
| 11 |
+
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
|
| 12 |
+
|
| 13 |
+
from ...utils import IndentedBuffer, unique
|
| 14 |
+
from ...virtualized import V
|
| 15 |
+
from ..common import KernelTemplate
|
| 16 |
+
from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
|
| 17 |
+
|
| 18 |
+
log = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class CUDATemplate(KernelTemplate):
    """
    Baseclass for CUDA C++ Templates, derived from KernelTemplate.
    Not to be instantiated directly.
    """

    # Class-level counter shared across instances; used to give every
    # generated kernel a unique hash name.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes: List[Buffer],
        layout: Layout,
        input_reorder: Optional[List[int]] = None,
    ):
        """

        Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly.

        Args:
            name (str): The name of the CUDATemplate object.
            input_nodes (List[IRNode]): A list of input IRNodes.
            layout (Layout): The layout of the output buffer / tensor.
            input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes.

        """
        super().__init__(name)
        self.input_nodes = input_nodes
        self.output_node: Buffer = Buffer("buf_out", layout)
        self.input_reorder = input_reorder
        self.layout = layout

    def generate(  # type: ignore[override]
        self,
        **kwargs,
    ) -> CUDATemplateCaller:
        """
        Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller
        may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning.

        Args:
            kwargs: Additional keyword arguments.

        Returns:
            A CUDATemplateCaller object representing the generated CUDA template caller.
        """
        kernel_name = f"cuda_{self.name}"
        # First render pass with a throwaway kernel: collects the generated
        # source code and the kernel's call-argument list.
        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
        ), CUDATemplateKernel(
            kernel_name=kernel_name,
        ) as kernel:
            code = self.render(kernel=kernel, **kwargs)
            _, call_args, _ = kernel.args.python_argdefs()
            log.debug("Generated Code:\n%s", code)
            log.debug(
                "Args: cpp_argdefs: %s, python_argdefs: %s",
                kernel.args.cpp_argdefs(),
                kernel.args.python_argdefs(),
            )

        # Sanity check: the leading call args must be exactly the
        # (deduplicated, possibly reordered) input names followed by the output.
        input_reorder = (
            self.input_reorder
            if self.input_reorder is not None
            else list(range(len(self.input_nodes)))
        )
        expected_args = list(
            unique(self.input_nodes[idx].get_name() for idx in input_reorder)
        )
        expected_args.extend([self.output_node.get_name()])
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
        )
        # Remaining call args are size expressions; resolve to concrete
        # integer hints for standalone benchmarking.
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )

        kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}"

        # create the BenchmarkRequest
        bmreq = CUDABenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=extra_args,
            source_code=code,
        )

        def make_kernel_render(
            template_node: CUDATemplateBuffer,
            epilogue_nodes: Optional[List[IRNode]] = None,
        ):
            # Deferred second render, invoked at scheduling time once the real
            # template buffer (and any fused epilogue nodes) are known.
            kernel = CUDATemplateKernel(
                kernel_name="KERNEL_NAME",
            )
            render = functools.partial(
                self.render,
                kernel=kernel,
                template_buffer_node=template_node,
                epilogue_nodes=epilogue_nodes,
                **kwargs,  # includes "op" argument in case of CUTLASSGemmTemplate
            )
            return kernel, render

        return CUDATemplateCaller(
            kernel_hash_name,
            self.name,
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            self,
            kwargs,
        )

    def header(self) -> IndentedBuffer:
        """Common C++ standard-library includes shared by all CUDA templates."""
        res = IndentedBuffer()
        res.splice(
            """
                #include <exception>
                #include <iostream>
                #include <memory>
                #include <random>
                #include <vector>
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Global C++ declarations (symbol-export macro, dtype alias) prepended to generated code."""
        res = IndentedBuffer()
        res.splice(
            """
                // We compile all models with -fvisibility=hidden. Any symbols that need to be
                // exposed in the final shared library must be declared with PT_EXPORT to make
                // them visible.
                #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
                #define PT_EXPORT __attribute__((__visibility__("default")))
                #else
                #ifdef _WIN32
                #define PT_EXPORT __declspec(dllexport)
                #else
                #define PT_EXPORT
                #endif
                #endif
                using bfloat16 = nv_bfloat16;
            """
        )
        return res

    def render(self, **kwargs) -> str:
        """Render the kernel source code; must be implemented by subclasses."""
        raise NotImplementedError
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class CUTLASSTemplate(CUDATemplate):
    """
    CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
    CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
    """

    def header(self) -> IndentedBuffer:
        """Extend the base includes with the CUTLASS / CuTe headers."""
        res = super().header()
        res.splice(
            """
                #include "cute/tensor.hpp"
                #include "cutlass/cutlass.h"
                #include "cutlass/numeric_types.h"
                #include "cutlass/tensor_ref.h"
                #include "cutlass/util/host_tensor.h"
                #include "cutlass/util/reference/host/tensor_fill.h"
                #include "cutlass/util/reference/device/tensor_fill.h"
                #include "cutlass/util/device_memory.h"
            """
        )
        return res

    def globals(self) -> IndentedBuffer:
        """Extend the base globals with a CUTLASS status-check macro and a pass-through functor."""
        res = super().globals()
        res.splice(
            """
                using namespace cute;
                #define CUTLASS_CHECK(status)                                                      \\
                {                                                                                  \\
                  cutlass::Status error = status;                                                  \\
                  if (error != cutlass::Status::kSuccess) {                                        \\
                    auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +             \\
                        cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__);        \\
                    throw std::runtime_error(msg);                                                 \\
                  }                                                                                \\
                }

                // Used as pass-through functor in EVT just for type casting / rounding
                template <typename T>
                struct identity_op {
                  CUTLASS_HOST_DEVICE
                  T operator()(T val) const { return val; }
                };

            """
        )
        return res

    def cute_int(self, int_str: str, var_name: str) -> str:
        """
        Render an integer as a CuTe expression, annotated with its variable name.

        A static "1" becomes compile-time ``cute::Int<1>{}``; everything else is
        passed through as a runtime value.
        """
        res = ""
        if int_str in {"1", "1L"}:
            res = "cute::Int<1>{}"
        else:
            res = int_str

        return f"{res} /* {var_name} */"

    # Mapping from torch dtypes to the corresponding CUTLASS C++ type names.
    _DTYPE_TO_CUTLASS = {
        torch.float32: "float",
        torch.float64: "double",
        torch.float16: "cutlass::half_t",
        torch.int32: "int",
        torch.int8: "int8_t",
        torch.uint8: "uint8_t",
        torch.bool: "bool",
        torch.bfloat16: "cutlass::bfloat16_t",
    }

    def cutlass_type_cast(self, node: IRNode, ptr: str) -> str:
        """Cast pointer expression *ptr* to the CUTLASS element type of *node* (no-op when node is None)."""
        if node is None:
            return ptr
        else:
            return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})"
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
from unittest.mock import patch
|
| 3 |
+
|
| 4 |
+
import sympy
|
| 5 |
+
|
| 6 |
+
import torch._inductor.virtualized as virtualized
|
| 7 |
+
from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise
|
| 8 |
+
from torch._inductor.utils import IndentedBuffer, sympy_str
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Used as a magic string to indicate an unsupported sympy expression
# became part of generated C++ code. ir_to_evt_string scans its output for
# this marker and raises CUTLASSEVTOpNotImplementedError if it is present.
_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _arg_str(a):
    """Stringify an op argument, tagging sympy expressions as unsupported."""
    if not isinstance(a, sympy.Expr):
        return str(a)
    # If this return value containing the _MAGIC_SYMPY_ERROR_STRING
    # is used as part of the final generated C++ code,
    # a CUTLASSEVTOpNotImplementedError is raised to indicate that
    # the op could not be converted to a valid EVT expression.
    return f"{_MAGIC_SYMPY_ERROR_STRING}('{sympy_str(a)}')"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CUTLASSEVTOpNotImplementedError(NotImplementedError):
    """Raised when an IR operation cannot be translated into a CUTLASS EVT expression."""
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CutlassEVTEpilogueTypeFormatter:
|
| 31 |
+
"""
|
| 32 |
+
Codegen class, which provides an entry point to generate
|
| 33 |
+
Cutlass "Epilogue Visitor Tree" (EVT) functor declarations.
|
| 34 |
+
|
| 35 |
+
See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
|
| 36 |
+
for more about EVTs and how they are declared and used to generate.
|
| 37 |
+
|
| 38 |
+
Notes:
|
| 39 |
+
* Used by CUTLASSGemmTemplate.
|
| 40 |
+
* This class should not be instantiated by users, it is intended to be used
|
| 41 |
+
by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...)
|
| 42 |
+
which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
|
| 43 |
+
* Extend this with more _op_<whatever> nodes to add support for new pointwise operations.
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, accumulator_node_name, evt_type_name):
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
Initialize an instance of CutlassEVTEpilogueTypeFormatter.
|
| 52 |
+
|
| 53 |
+
Parameters:
|
| 54 |
+
- accumulator_node_name (str): The name of the output Buffer for the GEMM operation in the original (unfused)
|
| 55 |
+
IR graph.
|
| 56 |
+
- evt_type_name (str): The output name of the EVT type we are generating.
|
| 57 |
+
|
| 58 |
+
"""
|
| 59 |
+
self.accumulator_node_name = accumulator_node_name
|
| 60 |
+
self.output = IndentedBuffer(0)
|
| 61 |
+
self.var_counter = 0
|
| 62 |
+
self.evt_type_name = evt_type_name
|
| 63 |
+
self.aliases = dict()
|
| 64 |
+
|
| 65 |
+
@staticmethod
|
| 66 |
+
def ir_to_evt_string(
|
| 67 |
+
template_output_node_name: str,
|
| 68 |
+
evt_type_name: str,
|
| 69 |
+
epilogue_nodes: List[IRNode],
|
| 70 |
+
):
|
| 71 |
+
"""
|
| 72 |
+
Formats IR nodes into a string representation compatible with Cutlass EVT format.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
template_output_node_name (str): The name of the template output node.
|
| 76 |
+
evt_type_name (str): The name of the EVT type.
|
| 77 |
+
epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be
|
| 78 |
+
ComputedBuffer nodes wrapping Pointwise nodes.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
A string representation of the IR nodes formatted according to the Cutlass EVT format.
|
| 82 |
+
"""
|
| 83 |
+
formatter = CutlassEVTEpilogueTypeFormatter(
|
| 84 |
+
template_output_node_name, evt_type_name
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
with virtualized.V.set_ops_handler(formatter), patch.object(
|
| 88 |
+
FlexibleLayout, "allow_indexing", True
|
| 89 |
+
):
|
| 90 |
+
for node in epilogue_nodes:
|
| 91 |
+
if isinstance(node, ComputedBuffer):
|
| 92 |
+
pnode = node.data
|
| 93 |
+
else:
|
| 94 |
+
raise RuntimeError(
|
| 95 |
+
"Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer"
|
| 96 |
+
)
|
| 97 |
+
assert isinstance(pnode, Pointwise)
|
| 98 |
+
index = pnode._index(pnode.ranges)
|
| 99 |
+
result = pnode.inner_fn(index)
|
| 100 |
+
# each epilogue node results in a single "using" statement and may refer to the previous steps by name
|
| 101 |
+
formatter.aliases[node.name] = result
|
| 102 |
+
res = formatter.getvalue(result) # type: ignore[possibly-undefined]
|
| 103 |
+
if _MAGIC_SYMPY_ERROR_STRING in res:
|
| 104 |
+
raise CUTLASSEVTOpNotImplementedError(
|
| 105 |
+
"sympy / indexing expressions not yet supported in EVT fusion"
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
return res
|
| 109 |
+
|
| 110 |
+
def __getattr__(self, name):
|
| 111 |
+
"""
|
| 112 |
+
Resolve V.ops.<whatever> calls, after this instance has been installed as V.ops handler.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
def inner(*args, **kwargs):
|
| 116 |
+
fargs = [_arg_str(a) for a in args]
|
| 117 |
+
fkwargs = {key: _arg_str(a) for key, a in kwargs.items()}
|
| 118 |
+
fn = getattr(self, f"_op_{name}")
|
| 119 |
+
line = fn(*fargs, **fkwargs)
|
| 120 |
+
self.var_counter += 1
|
| 121 |
+
varname = f"EVT_expr_{self.var_counter}"
|
| 122 |
+
# replace line with a new variable name
|
| 123 |
+
self.output.writeline(f"using {varname} = {line};")
|
| 124 |
+
return varname
|
| 125 |
+
|
| 126 |
+
if name.startswith("_"):
|
| 127 |
+
raise CUTLASSEVTOpNotImplementedError(name)
|
| 128 |
+
if hasattr(self, f"_op_{name}"):
|
| 129 |
+
return inner
|
| 130 |
+
else:
|
| 131 |
+
raise CUTLASSEVTOpNotImplementedError(name)
|
| 132 |
+
|
| 133 |
+
def _op_load(self, name, index_expr):
|
| 134 |
+
# Load an input to an operation. Might be the output of the matmul, the result
|
| 135 |
+
# of a previous epilogue node, a constant or (TODO) an auxiliary input.
|
| 136 |
+
if name == self.accumulator_node_name:
|
| 137 |
+
return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */"
|
| 138 |
+
elif name in self.aliases:
|
| 139 |
+
return self.aliases[name]
|
| 140 |
+
else:
|
| 141 |
+
# return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */"
|
| 142 |
+
raise CUTLASSEVTOpNotImplementedError(
|
| 143 |
+
f"Operand {name} not found. Auxiliary inputs not supported yet."
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def _op_constant(self, value, dtype):
|
| 147 |
+
# Load a constant
|
| 148 |
+
if str(dtype) in ("torch.float16", "torch.float32"):
|
| 149 |
+
return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc> /* value={value}, dtype={dtype} */"
|
| 150 |
+
else:
|
| 151 |
+
raise CUTLASSEVTOpNotImplementedError(
|
| 152 |
+
f"Unsupported dtype for constant: {dtype}"
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
def _cutlass_binary_functional_op(self, op, a, b):
|
| 156 |
+
# Perform a named operation on two inputs
|
| 157 |
+
# see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops
|
| 158 |
+
return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::{op}, ElementAcc, ElementAcc, RoundStyle>,{a},{b}>" # noqa: B950
|
| 159 |
+
|
| 160 |
+
def _convert_to_output_dtype(self, a):
|
| 161 |
+
# Convert the final output to the dtype of the output buffer
|
| 162 |
+
return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,{a}>" # noqa: B950
|
| 163 |
+
|
| 164 |
+
    def _op_to_dtype(self, a, *args, **kwargs):
        # No-op in our case, since we convert to the output dtype at the end and
        # convert everything to the accumulator dtype.
        # It is asserted (and ascertained during the can_fuse decision) that the
        # dtype remains compatible throughout the fusion chain, so the requested
        # conversion arguments can be ignored here.
        return a  # noqa: B950
|
| 170 |
+
|
| 171 |
+
    def _op_mul(self, a, b):
        # Elementwise a * b (cutlass::multiplies).
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        # Elementwise a / b (cutlass::divides).
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        # True division maps to the same cutlass::divides functor as _op_div.
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        # Elementwise a >= b (cutlass::greater_equal).
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        # Elementwise a + b (cutlass::plus).
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        # Elementwise a - b (cutlass::minus).
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        # Elementwise min(a, b) (cutlass::minimum).
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        # Elementwise max(a, b) (cutlass::maximum).
        return self._cutlass_binary_functional_op("maximum", a, b)
|
| 194 |
+
|
| 195 |
+
    def _op_relu(self, a):
        # relu(a) == maximum(a, 0): broadcast a scalar zero and take the max.
        const_zero = self._op_constant(0.0, "torch.float32")
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,{a}, {const_zero}>"  # noqa: B950
|
| 198 |
+
|
| 199 |
+
    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions cannot be expressed in this EVT fusion scheme.
        raise CUTLASSEVTOpNotImplementedError()
|
| 201 |
+
|
| 202 |
+
# Add more ops here...
|
| 203 |
+
    def getvalue(self, result) -> str:
        # Finalize codegen: cast the most recently generated EVT expression to
        # the output dtype and emit the final "using" declaration under
        # self.evt_type_name (set in __init__, outside this view), then return
        # all generated code accumulated in the output buffer.
        dtype_converted_expr = self._convert_to_output_dtype(
            f"EVT_expr_{self.var_counter}"
        )
        self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};")
        return self.output.getvalue()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class CutlassEVTEpilogueArgumentFormatter:
    """
    Codegen class, which provides an entry point to generate
    Cutlass "Epilogue Visitor Tree" (EVT) Argument initializers

    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
    for more about EVTs and how they are declared and used to generate.

    Notes:
        * Used by CUTLASSGemmTemplate.
        * This class should not be instantiated by users, it is intended to be used
          by calling CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...)
          which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
        * Extend this with more _op_<whatever> nodes to add support for new pointwise operations.
    """

    def __init__(self, accumulator_node_name: str):
        """
        Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly.
        Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method.

        Args:
            accumulator_node_name (str): The name of the accumulator node which should contain
                                         the Matmul result before fusion according to the IR graph.
        """
        self.accumulator_node_name: str = accumulator_node_name
        self.output: IndentedBuffer = IndentedBuffer(0)  # The output buffer for codegen
        self.var_counter: int = (
            0  # used to generate variable names, incremented for each new variable
        )
        self.aliases: Dict[str, str] = dict()  # Aliases for subexpression functors

    @staticmethod
    def ir_to_evt_argument_string(
        template_output_node_name: str,
        epilogue_nodes: List[IRNode],
    ) -> str:
        """Walk the pointwise epilogue nodes and return the nested C++ brace
        initializer string for the corresponding EVT arguments.

        Raises:
            CUTLASSEVTOpNotImplementedError: for unsupported ops / operands,
                including an empty epilogue_nodes list.
        """
        formatter = CutlassEVTEpilogueArgumentFormatter(
            template_output_node_name,
        )

        # Fix: previously an empty epilogue_nodes list reached
        # formatter.getvalue(result) with `result` unbound, dying with an
        # obscure UnboundLocalError. Fail early with the expected exception.
        if not epilogue_nodes:
            raise CUTLASSEVTOpNotImplementedError(
                "EVT fusion requires at least one epilogue node"
            )

        with virtualized.V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            for node in epilogue_nodes:
                assert isinstance(node, ComputedBuffer)
                pnode = node.data
                assert isinstance(pnode, Pointwise)
                index = pnode._index(pnode.ranges)
                result = pnode.inner_fn(index)
                # each epilogue node results in a single "using" statement and may refer to the previous steps by name
                if node.name is not None:
                    formatter.aliases[node.name] = result

            res: str = formatter.getvalue(result)
            if _MAGIC_SYMPY_ERROR_STRING in res:
                raise CUTLASSEVTOpNotImplementedError(
                    "sympy / indexing expressions not yet supported in EVT fusion"
                )
            else:
                return res

    def __getattr__(self, name):
        """Resolve V.ops.<name> calls to self._op_<name>; raise for unsupported ops."""

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fkwargs = {key: _arg_str(a) for key, a in kwargs.items()}
            fn = getattr(self, f"_op_{name}")
            line = fn(*fargs, **fkwargs)
            return line

        if name.startswith("_"):
            raise CUTLASSEVTOpNotImplementedError(name)

        if hasattr(self, f"_op_{name}"):
            return inner
        else:
            raise CUTLASSEVTOpNotImplementedError(name)

    def _op_load(self, name, index_expr):
        # The matmul result needs no explicit argument: empty initializer.
        if name == self.accumulator_node_name:
            return "{}"
        elif name in self.aliases:
            return self.aliases[name]
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Operand {name} not found. Auxiliary inputs not supported yet."
            )

    def _op_constant(self, value, dtype):
        # Constants become a braced static_cast to the accumulator type.
        if str(dtype) in ("torch.float16", "torch.float32"):
            return "{ static_cast<ElementAcc>(" + str(value) + ") }"
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Unsupported dtype for constant: {dtype}"
            )

    def _cutlass_binary_functional_op(self, op, a, b):
        # Binary ops just nest their operands' initializers.
        return f"{{ /*{op}: */ {a}, {b} }}"

    def _op_mul(self, a, b):
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        return self._cutlass_binary_functional_op("maximum", a, b)

    def _op_relu(self, a):
        # relu = maximum(a, 0); nest the operand with the zero constant.
        const_zero = self._op_constant(0.0, "torch.float32")
        return "{" + str(a) + ", " + const_zero + "}"

    def _op_to_dtype(self, a, dtype, src_dtype=None):
        # No-op: it is asserted (and ascertained during the can_fuse decision)
        # that the dtype remains compatible throughout the fusion chain.
        assert dtype in (
            "torch.float32",
            "torch.float16",
        ), f"Unsupported dtype: {dtype}"
        assert src_dtype in (
            None,
            "torch.float32",
            "torch.float16",
        ), f"Unsupported source dtype: {src_dtype}"
        return a

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions cannot be expressed as EVT arguments.
        raise CUTLASSEVTOpNotImplementedError()

    def getvalue(self, result) -> str:
        """Return the final brace-wrapped argument initializer."""
        return "{" + str(result) + "}"
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..cutlass_utils import try_import_cutlass
|
| 2 |
+
|
| 3 |
+
if try_import_cutlass():
|
| 4 |
+
import enum
|
| 5 |
+
|
| 6 |
+
from cutlass_library.library import * # noqa: F401, F403
|
| 7 |
+
from cutlass_library.gemm_operation import * # noqa: F401, F403
|
| 8 |
+
|
| 9 |
+
# copied / modified from original at
|
| 10 |
+
# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658
|
| 11 |
+
# to support EVT similar to
|
| 12 |
+
# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L315C69-L315C69 # noqa: B950
|
| 13 |
+
    class EmitGemmUniversal3xInstanceWithEVT:
        """Responsible for emitting a CUTLASS 3.x template definition"""

        def __init__(self, operation_suffix=""):
            # Suffix appended to emitted operation names.
            self.operation_suffix = operation_suffix
            # Headers required by the emitted kernel source.
            self.includes = [
                "cutlass/cutlass.h",
                "cutlass/gemm/gemm.h",
                "cutlass/numeric_types.h",
                "cutlass/gemm/kernel/gemm_universal.hpp",
                "cutlass/gemm/collective/collective_builder.hpp",
                "cutlass/epilogue/collective/collective_builder.hpp",
            ]
            # Template used when epilogue_functor is a library-provided enum value.
            self.builtin_epilogue_functor_template = """
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >
"""
            # Main CUTLASS 3.x instance template; ${epilogue_functor} may expand
            # to a full EVT declaration named ${operation_name}_epilogue_functor.
            self.gemm_template = """
using EpilogueScheduleType = ${epilogue_schedule};
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
       cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
       "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = ${element_accumulator};
using ElementD = ${element_d};
${epilogue_functor};
using ${operation_name}_epilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ${arch}, ${opcode_class},
    cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
    cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ${element_accumulator}, ${element_epilogue},
    ${element_c}, ${layout_c}, ${align_c},
    ${element_d}, ${layout_d}, ${align_d},
    EpilogueScheduleType,
    ${operation_name}_epilogue_functor
  >::CollectiveOp;

using ${operation_name}_mainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ${arch}, ${opcode_class},
    ${element_a}, ${layout_a}, ${align_a},
    ${element_b}, ${layout_b}, ${align_b},
    ${element_accumulator},
    cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
    cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
    ${stages},
    ${kernel_schedule}
  >::CollectiveOp;

// Gemm operator ${operation_name}
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int,int,int,int>,
    ${operation_name}_mainloop,
    ${operation_name}_epilogue,
    ${tile_scheduler}>;

// Define named type
struct ${operation_name} :
  public ${operation_name}_base { };

"""

        #
        def instance_template(self):
            # Template for registering the emitted kernel in a manifest.
            return """
${compile_guard_start}
  using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
  manifest.append(
    new ${gemm_kind}<GemmKernel>("${operation_name}"));
${compile_guard_end}
"""

        #
        def emit(self, operation):
            # Render the full CUTLASS 3.x instance source for `operation` by
            # substituting its tile/cluster/dtype/layout description into
            # self.gemm_template.
            tile_shape = operation.tile_description.tile_shape
            warp_count = operation.tile_description.warp_count
            # stage count set to zero indicates builder automatic stage selection
            if operation.tile_description.stages > 0:
                stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
            else:
                stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage)>"  # noqa: B950
            warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)]

            (
                instance_layout_A,
                instance_layout_B,
                instance_layout_C,
                instance_layout_D,
            ) = (
                operation.A.layout,
                operation.B.layout,
                operation.C.layout,
                operation.D.layout,
            )

            # 3.0 profiler integration only supports trivial epilogues for now
            epilogue_vector_length = 1

            # Support built-in epilogue functors or user-defined functions
            if isinstance(operation.epilogue_functor, enum.Enum):
                values = {
                    "epilogue_vector_length": str(epilogue_vector_length),
                    "element_epilogue": str(DataTypeTag[operation.element_epilogue]),  # type: ignore[name-defined]
                    "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],  # type: ignore[name-defined]
                }
                epilogue_functor = SubstituteTemplate(  # type: ignore[name-defined]
                    self.builtin_epilogue_functor_template, values
                )

            elif callable(operation.epilogue_functor):
                # Callable functor (EVT path): it generates its own declaration,
                # given the name the template expects.
                epilogue_functor = operation.epilogue_functor(
                    operation.procedural_name() + "_epilogue_functor"
                )
            else:
                epilogue_functor = str(operation.epilogue_functor)
            #

            values = {
                "operation_name": operation.procedural_name(),
                "operation_suffix": self.operation_suffix,
                "element_a": DataTypeTag[operation.A.element],  # type: ignore[name-defined]
                "layout_a": LayoutTag[instance_layout_A],  # type: ignore[name-defined]
                "element_b": DataTypeTag[operation.B.element],  # type: ignore[name-defined]
                "layout_b": LayoutTag[instance_layout_B],  # type: ignore[name-defined]
                "element_c": DataTypeTag[operation.C.element],  # type: ignore[name-defined]
                "layout_c": LayoutTag[instance_layout_C],  # type: ignore[name-defined]
                "element_d": DataTypeTag[operation.D.element],  # type: ignore[name-defined]
                "layout_d": LayoutTag[instance_layout_D],  # type: ignore[name-defined]
                "element_accumulator": DataTypeTag[operation.accumulator_type()],  # type: ignore[name-defined]
                "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],  # type: ignore[name-defined] # noqa: B950
                "arch": "cutlass::arch::Sm%d" % operation.arch,
                "tile_shape_m": str(operation.tile_description.tile_shape[0]),
                "tile_shape_n": str(operation.tile_description.tile_shape[1]),
                "tile_shape_k": str(operation.tile_description.tile_shape[2]),
                "cluster_m": str(operation.tile_description.cluster_shape[0]),
                "cluster_n": str(operation.tile_description.cluster_shape[1]),
                "cluster_k": str(operation.tile_description.cluster_shape[2]),
                "warp_shape_m": str(warp_shape[0]),
                "warp_shape_n": str(warp_shape[1]),
                "warp_shape_k": str(warp_shape[2]),
                "instruction_shape_m": str(
                    operation.tile_description.math_instruction.instruction_shape[0]
                ),
                "instruction_shape_n": str(
                    operation.tile_description.math_instruction.instruction_shape[1]
                ),
                "instruction_shape_k": str(
                    operation.tile_description.math_instruction.instruction_shape[2]
                ),
                "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]),  # type: ignore[name-defined]
                "epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]),  # type: ignore[name-defined]
                "epilogue_functor": epilogue_functor,
                "stages": stage_count_string,
                "align_a": str(operation.A.alignment),
                "align_b": str(operation.B.alignment),
                "align_c": str(operation.C.alignment),
                # NOTE(review): align_d is taken from C's alignment, not D's —
                # looks like a copy-paste; confirm against operation.D.alignment.
                "align_d": str(operation.C.alignment),
                "transform_a": ComplexTransformTag[operation.A.complex_transform],  # type: ignore[name-defined]
                "transform_b": ComplexTransformTag[operation.B.complex_transform],  # type: ignore[name-defined]
                "math_operation": MathOperationTag[  # type: ignore[name-defined]
                    operation.tile_description.math_instruction.math_operation
                ],
                "epilogue_vector_length": str(epilogue_vector_length),
                "element_epilogue": str(DataTypeTag[operation.element_epilogue]),  # type: ignore[name-defined]
                "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]),  # type: ignore[name-defined]
            }

            return SubstituteTemplate(self.gemm_template, values)  # type: ignore[name-defined]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..common import DeviceOpOverrides, register_device_op_overrides
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CUDADeviceOpOverrides(DeviceOpOverrides):
    """CUDA implementations of the device-op codegen hooks.

    Each method returns a Python source-code string to be emitted into the
    generated wrapper, not an executed effect.
    """

    def import_get_raw_stream_as(self, name):
        # Alias the private CUDA raw-stream getter under the requested name.
        return "from torch._C import _cuda_getCurrentRawStream as " + name

    def set_device(self, device_idx):
        return "torch.cuda.set_device({})".format(device_idx)

    def synchronize(self):
        return "torch.cuda.synchronize()"

    def device_guard(self, device_idx):
        return "torch.cuda._DeviceGuard({})".format(device_idx)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Register the CUDA overrides so wrapper codegen can look them up by device type.
register_device_op_overrides("cuda", CUDADeviceOpOverrides())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from ..scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode
|
| 4 |
+
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
|
| 5 |
+
|
| 6 |
+
from .triton import TritonScheduling
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CUDACombinedScheduling(BaseScheduling):
    """
    Scheduler for CUDA kernels that dispatches each call to either the
    CUDA-C++ or the Triton backend; both target CUDA devices and share a
    unified wrapper for codegen.

    If scheduling code ever needs to be specialized for mixed
    Triton / CUDA C++ code, this is the place to do it.
    """

    def __init__(self, scheduler: Scheduler):
        super().__init__()
        self._scheduler = scheduler
        self._triton_scheduling = TritonScheduling(scheduler)
        self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)

    def _is_cuda_cpp_node(self, node: BaseSchedulerNode) -> bool:
        # A node belongs to the CUDA-C++ backend when it is a (possibly fused)
        # CUDA C++ template.
        backend = self._cuda_cpp_scheduling
        return backend.is_cuda_cpp_template(node) or backend.is_cuda_cpp_fused_template(
            node
        )

    def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
        return (
            self._cuda_cpp_scheduling
            if self._is_cuda_cpp_node(node)
            else self._triton_scheduling
        )

    def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2):
            return True
        return self._triton_scheduling.can_fuse_vertical(node1, node2)

    def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        if self._is_cuda_cpp_node(node1) or self._is_cuda_cpp_node(node2):
            return self._cuda_cpp_scheduling.can_fuse_horizontal(
                node1, node2
            )  # always False at the moment
        return self._triton_scheduling.can_fuse_horizontal(node1, node2)

    def group_fn(self, sizes):
        return self._triton_scheduling.group_fn(sizes)

    def codegen_template(
        self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]
    ):
        # CUDA C++ templates get their own codegen; everything else is Triton.
        if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
            return self._cuda_cpp_scheduling.codegen_template(
                template_node, epilogue_nodes
            )
        return self._triton_scheduling.codegen_template(template_node, epilogue_nodes)

    def codegen_nodes(self, nodes: List[SchedulerNode]):
        return self._triton_scheduling.codegen_nodes(nodes)

    def codegen_sync(self):
        return self._triton_scheduling.codegen_sync()

    def flush(self):
        return self._triton_scheduling.flush()

    def codegen_foreach(self, *args, **kwargs):
        return self._triton_scheduling.codegen_foreach(*args, **kwargs)

    def benchmark_fused_nodes(self, nodes):
        return self._triton_scheduling.benchmark_fused_nodes(nodes)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Set, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch._dynamo.utils import counters
|
| 7 |
+
|
| 8 |
+
from torch._ops import OpOverload, OpOverloadPacket
|
| 9 |
+
from ..pattern_matcher import fwd_only, register_replacement
|
| 10 |
+
|
| 11 |
+
aten = torch.ops.aten
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@functools.lru_cache(None)
def _misc_patterns_init():
    """Register the randperm/index pattern replacements.

    lru_cache(None) makes this an idempotent one-shot initializer.
    """
    from .joint_graph import patterns as joint_graph_patterns
    from .post_grad import pass_patterns as post_grad_patterns_all

    post_grad_patterns = post_grad_patterns_all[1]  # medium priority

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # These patterns do 2 things
    # 1. Since we know that index is completely unique, we can codegen it using
    # stores instead of atomic adds, which is quite a bit faster.
    # 2. Also, since we are guaranteed that they are completely within bounds,
    # we can use unsafe indexing and skip debug asserts
    def randperm_index_add_pattern(x, y):
        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
        return torch.index_add(x, dim=0, source=y, index=index), index

    def randperm_index_add_replacement(x, y):
        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
        return (
            torch.ops.aten._unsafe_index_put(
                x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False
            ),
            index,
        )

    register_replacement(
        randperm_index_add_pattern,
        randperm_index_add_replacement,
        [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
        fwd_only,
        [post_grad_patterns, joint_graph_patterns],
    )

    def randperm_index_pattern(x, slice_shape):
        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
        return torch.ops.aten.index(x, (index,)), index

    def randperm_index_replacement(x, slice_shape):
        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
        return torch.ops.aten._unsafe_index(x, (index,)), index

    # NOTE(review): the return value of register_replacement is bound but never
    # used; kept as-is to preserve behavior.
    pattern = register_replacement(
        randperm_index_pattern,
        randperm_index_replacement,
        [torch.empty(4, 8, device=device)],
        fwd_only,
        [post_grad_patterns, joint_graph_patterns],
        scalar_workaround={"slice_shape": 42},
    )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class NumpyCompatNormalization:
    """Rewrite numpy-style kwarg aliases (axis, keepdims, x, a, x1, x2) on
    torch function calls in an FX graph to the torch-native kwarg names."""

    # torch kwarg name -> accepted numpy-style aliases
    numpy_compat: Dict[str, Tuple[str, ...]] = {
        "dim": ("axis",),
        "keepdim": ("keepdims",),
        "input": ("x", "a", "x1"),
        "other": ("x2",),
    }
    inverse_mapping: Dict[str, str]
    cache: Dict["torch.fx.graph.Target", Set[str]]

    def __init__(self):
        self.cache = {}  # callable -> set of replaceable kwarg aliases, e.g. {"axis"}
        self.inverse_mapping = {}
        for torch_name, numpy_aliases in self.numpy_compat.items():
            for alias in numpy_aliases:
                assert alias not in self.inverse_mapping
                self.inverse_mapping[alias] = torch_name

    def _replaceable_kwargs_for(self, target) -> Set[str]:
        # Compute (and memoize) which numpy aliases are replaceable for
        # `target`, based on its torch op signatures.
        cached = self.cache.get(target)
        if cached is not None:
            return cached
        signatures = torch.fx.operator_schemas.get_signature_for_torch_op(target)
        replaceable: Set[str] = set()
        for sig in signatures or ():
            for param_name in sig.parameters.keys():
                if param_name in self.numpy_compat:
                    replaceable.update(self.numpy_compat[param_name])
        self.cache[target] = replaceable
        return replaceable

    def __call__(self, graph: torch.fx.Graph):
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if isinstance(node.target, (OpOverload, OpOverloadPacket)):
                # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't.
                continue

            replaceable_kwargs = self._replaceable_kwargs_for(node.target)
            if not replaceable_kwargs:
                continue
            if not any(k in replaceable_kwargs for k in node.kwargs):
                continue

            # Rebuild kwargs with aliases renamed to their torch equivalents.
            remapped = {
                (self.inverse_mapping[k] if k in replaceable_kwargs else k): v
                for k, v in node.kwargs.items()
            }
            node.kwargs = torch.fx.immutable_collections.immutable_dict(remapped)
            counters["inductor"]["numpy_compat_normalization"] += 1
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Module-level singleton instance of the normalization pass.
numpy_compat_normalization = NumpyCompatNormalization()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py
ADDED
|
@@ -0,0 +1,1204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import operator
|
| 3 |
+
from functools import reduce
|
| 4 |
+
from typing import Any, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from torch.fx.experimental.symbolic_shapes import has_free_symbols
|
| 9 |
+
|
| 10 |
+
from .. import ir
|
| 11 |
+
|
| 12 |
+
from ..lowering import lowerings as L
|
| 13 |
+
from ..pattern_matcher import (
|
| 14 |
+
Arg,
|
| 15 |
+
CallFunction,
|
| 16 |
+
filter_nodes,
|
| 17 |
+
get_arg_value,
|
| 18 |
+
KeywordArg,
|
| 19 |
+
MULTIPLE,
|
| 20 |
+
)
|
| 21 |
+
from ..virtualized import ops
|
| 22 |
+
from .freezing_patterns import register_freezing_graph_pattern
|
| 23 |
+
from .post_grad import register_lowering_pattern
|
| 24 |
+
from .quantization import (
|
| 25 |
+
_register_quantization_lowerings,
|
| 26 |
+
_register_quantization_weight_pack_pass,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
if torch._C._has_mkldnn:
|
| 30 |
+
aten = torch.ops.aten
|
| 31 |
+
mkldnn = torch.ops.mkldnn
|
| 32 |
+
prims = torch.ops.prims
|
| 33 |
+
|
| 34 |
+
_conv_args = [Arg() for _ in range(10)]
|
| 35 |
+
_linear_args = [Arg() for _ in range(6)]
|
| 36 |
+
_conv_transpose_args = [Arg() for _ in range(11)]
|
| 37 |
+
|
| 38 |
+
def _conv_call(users=1):
|
| 39 |
+
return CallFunction(
|
| 40 |
+
mkldnn._convolution_pointwise.default, *_conv_args, _users=users
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
def _linear_call(users=1):
|
| 44 |
+
return CallFunction(
|
| 45 |
+
mkldnn._linear_pointwise.default, *_linear_args, _users=users
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def _conv_transpose_call(users=1):
|
| 49 |
+
return CallFunction(
|
| 50 |
+
mkldnn._convolution_transpose_pointwise.default,
|
| 51 |
+
*_conv_transpose_args,
|
| 52 |
+
_users=users,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def _to_float(input_call, users=1):
|
| 56 |
+
return CallFunction(
|
| 57 |
+
prims.convert_element_type.default,
|
| 58 |
+
input_call,
|
| 59 |
+
KeywordArg("to_float"),
|
| 60 |
+
_users=users,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def _to_bf16(input_call):
|
| 64 |
+
return CallFunction(
|
| 65 |
+
prims.convert_element_type.default,
|
| 66 |
+
input_call,
|
| 67 |
+
KeywordArg("to_bf16"),
|
| 68 |
+
_users=1,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def _to_fp16(input_call):
|
| 72 |
+
return CallFunction(
|
| 73 |
+
prims.convert_element_type.default,
|
| 74 |
+
input_call,
|
| 75 |
+
KeywordArg("to_fp16"),
|
| 76 |
+
_users=1,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype):
|
| 80 |
+
# only insert to_dtype if lowp_dtype is True
|
| 81 |
+
computation_call = (
|
| 82 |
+
_to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users)
|
| 83 |
+
)
|
| 84 |
+
out = unary_fusion(computation_call)
|
| 85 |
+
if lowp_dtype == torch.bfloat16:
|
| 86 |
+
return _to_bf16(out)
|
| 87 |
+
elif lowp_dtype == torch.float16:
|
| 88 |
+
return _to_fp16(out)
|
| 89 |
+
else:
|
| 90 |
+
return out
|
| 91 |
+
|
| 92 |
+
    def _gelu_fusion_1(computation_call):
        """Pattern for decomposed 'erf' GELU:
        ``x * 0.5 * (1 + erf(x / sqrt(2)))`` (0.7071... == 1/sqrt(2))."""
        return CallFunction(
            aten.mul,
            CallFunction(aten.mul, computation_call, 0.5),
            CallFunction(
                aten.add,
                CallFunction(
                    aten.erf,
                    CallFunction(aten.mul, computation_call, 0.7071067811865476),
                ),
                1,
            ),
        )
    def _gelu_fusion_2(computation_call):
        """Pattern for decomposed 'tanh' GELU approximation:
        ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``
        (0.79788... == sqrt(2/pi); x**3 appears as x*x*x)."""
        return CallFunction(
            aten.mul,
            CallFunction(aten.mul, computation_call, 0.5),
            CallFunction(
                aten.add,
                CallFunction(
                    aten.tanh,
                    CallFunction(
                        aten.mul,
                        CallFunction(
                            aten.add,
                            computation_call,
                            CallFunction(
                                aten.mul,
                                CallFunction(
                                    aten.mul,
                                    CallFunction(
                                        aten.mul, computation_call, computation_call
                                    ),
                                    computation_call,
                                ),
                                0.044715,
                            ),
                        ),
                        0.7978845608028654,
                    ),
                ),
                1,
            ),
        )
    def _hardswish_fusion(computation_call):
        """Pattern for decomposed hardswish:
        ``x * clamp(x + 3, 0, 6) / 6``."""
        return CallFunction(
            aten.div,
            CallFunction(
                aten.mul,
                computation_call,
                CallFunction(
                    aten.clamp_max,
                    CallFunction(
                        aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
                    ),
                    6,
                ),
            ),
            6,
        )
def _silu_fusion(computation_call):
|
| 156 |
+
return CallFunction(
|
| 157 |
+
aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call)
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
    def _hardsigmoid_fusion(computation_call):
        """Pattern for decomposed hardsigmoid:
        ``clamp(x + 3, 0, 6) / 6``."""
        return CallFunction(
            aten.div,
            CallFunction(
                aten.clamp_max,
                CallFunction(
                    aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
                ),
                6,
            ),
            6,
        )
    def _leaky_relu_fusion(computation_call):
        """Pattern for decomposed leaky_relu:
        ``where(x > 0, x, x * negative_slope)``; the slope is captured as the
        keyword arg ``negative_slope``."""
        return CallFunction(
            aten.where,
            CallFunction(aten.gt, computation_call, 0),
            computation_call,
            CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")),
        )
    def _hardtanh_fusion(computation_call):
        """Pattern for decomposed hardtanh:
        ``clamp_max(clamp_min(x, min_value), max_value)``; bounds are captured
        as keyword args."""
        return CallFunction(
            aten.clamp_max,
            CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")),
            KeywordArg("max_value"),
        )
def _combined_fusion(computation_call, elementwise_op):
|
| 189 |
+
return CallFunction(elementwise_op, computation_call)
|
| 190 |
+
|
| 191 |
+
# binary_op(other, computation_op)
|
| 192 |
+
def _binary_fusion_v1(computation_call, binary_fn):
|
| 193 |
+
return CallFunction(binary_fn, KeywordArg("other"), computation_call)
|
| 194 |
+
|
| 195 |
+
# binary_op(computation_op, other)
|
| 196 |
+
def _binary_fusion_v2(computation_call, binary_fn):
|
| 197 |
+
return CallFunction(binary_fn, computation_call, KeywordArg("other"))
|
| 198 |
+
|
| 199 |
+
def _is_single_computation_op(computation_op):
|
| 200 |
+
def fn(match):
|
| 201 |
+
computation_nodes = filter_nodes(match.nodes, computation_op)
|
| 202 |
+
if len(computation_nodes) < 1:
|
| 203 |
+
return False
|
| 204 |
+
if any(n.args[-3] != "none" for n in computation_nodes):
|
| 205 |
+
return False
|
| 206 |
+
return True
|
| 207 |
+
|
| 208 |
+
return fn
|
| 209 |
+
|
| 210 |
+
    def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None):
        """Return an extra-check predicate for unary fusions.

        Besides the single-computation-op check, for low-precision graphs the
        match must contain exactly two dtype conversions forming an
        up-to-float / back-to-``lowp_dtype`` round trip around the unary op.
        """
        def fn(match):
            matched = _is_single_computation_op(computation_op)(match)
            computation_node = filter_nodes(match.nodes, computation_op)[0]
            if lowp_dtype:
                conversion_dtype_nodes = filter_nodes(
                    match.nodes, prims.convert_element_type.default
                )
                if len(conversion_dtype_nodes) != 2:
                    return False
                # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16
                # filter_nodes order is not guaranteed, so detect which
                # conversion consumes the computation node directly.
                if computation_node == conversion_dtype_nodes[0].args[0]:
                    to_float = conversion_dtype_nodes[0].args[1]
                    to_lp = conversion_dtype_nodes[1].args[1]
                else:
                    to_float = conversion_dtype_nodes[1].args[1]
                    to_lp = conversion_dtype_nodes[0].args[1]
                matched = matched and to_float == torch.float and to_lp == lowp_dtype
            return matched

        return fn
    def _register_unary_fusion_lowering(
        pattern, unary_attr, computation_op, lowp_dtype=None
    ):
        """Register a lowering that folds a matched unary post-op into
        ``computation_op`` by replacing its trailing
        (attr, scalars, algorithm) argument triple with ``unary_attr``."""
        @register_lowering_pattern(
            pattern,
            extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype),
        )
        def fn(match, *args, **kwargs):
            # last three positional args are (attr, scalars, algorithm)
            computation_args = list(args)[:-3] + [
                unary_attr.op_name,
                unary_attr.scalars_attr,
                unary_attr.algorithm_attr,
            ]
            return L[computation_op](*computation_args)

        return fn
    def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None):
        """Register a lowering for the decomposed leaky_relu pattern.

        Fuses into ``computation_op`` only when ``negative_slope`` is a
        Python number (and, for low precision, the captured dtypes round-trip
        correctly); otherwise re-emits the computation followed by the
        unfused where/mul graph.
        """
        @register_lowering_pattern(
            pattern, extra_check=_is_single_computation_op(computation_op)
        )
        def fn(match, *args, **kwargs):
            negative_slope = kwargs.get("negative_slope")
            if isinstance(negative_slope, ir.TensorBox):
                matched = False
            else:  # inp is a Number
                matched = True
            if lowp_dtype:
                dtype1 = kwargs.get("to_float")
                dtype2 = (
                    kwargs.get("to_bf16")
                    if lowp_dtype == torch.bfloat16
                    else kwargs.get("to_fp16")
                )
                matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
            computation_args = list(args)
            if matched:
                computation_args = computation_args[:-3] + [
                    "leaky_relu",
                    [negative_slope],
                    "",
                ]
                return L[computation_op](*computation_args)
            else:
                # computation_args += ["none", [], ""]
                out = L[computation_op](*computation_args)
                if lowp_dtype:
                    out = L[prims.convert_element_type.default](out, dtype=torch.float)
                out = L[aten.where](
                    L[aten.gt](out, 0),
                    out,
                    L[aten.mul](out, negative_slope),
                )
                # dtype2 is always bound here: lowp_dtype truthy implies the
                # branch above assigned it.
                if lowp_dtype:
                    out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
                return out

        return fn
    def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None):
        """Register a lowering for the decomposed hardtanh pattern.

        Fuses into ``computation_op`` only when both clamp bounds are Python
        numbers with ``min_value <= max_value`` (and, for low precision, the
        captured dtypes round-trip correctly); otherwise re-emits the
        computation followed by the unfused clamp graph.
        """
        @register_lowering_pattern(
            pattern, extra_check=_is_single_computation_op(computation_op)
        )
        def fn(match, *args, **kwargs):
            min_value = kwargs.get("min_value")
            max_value = kwargs.get("max_value")
            if isinstance(min_value, ir.TensorBox) or isinstance(
                max_value, ir.TensorBox
            ):
                matched = False
            else:  # inp is a Number
                # NOTE(review): min_value is not checked for None here —
                # presumably the pattern guarantees it; verify against the
                # pattern's KeywordArg capture.
                assert max_value is not None
                matched = min_value <= max_value
            if lowp_dtype:
                dtype1 = kwargs.get("to_float")
                dtype2 = (
                    kwargs.get("to_bf16")
                    if lowp_dtype == torch.bfloat16
                    else kwargs.get("to_fp16")
                )
                matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
            computation_args = list(args)
            if matched:
                computation_args = computation_args[:-3] + [
                    "hardtanh",
                    [min_value, max_value],
                    "",
                ]
                return L[computation_op](*computation_args)
            else:
                out = L[computation_op](*computation_args)
                if lowp_dtype:
                    out = L[prims.convert_element_type.default](out, dtype=torch.float)
                out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value)
                if lowp_dtype:
                    out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
                return out

        return fn
_binary_attr = {
|
| 333 |
+
aten.add: "add",
|
| 334 |
+
ops.add: "add",
|
| 335 |
+
aten.sub: "sub",
|
| 336 |
+
ops.sub: "sub",
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
    def _is_valid_binary(match, fn):
        """Check that every matched binary node is safe to fuse: both inputs
        are tensors with identical size/device/dtype, any ``alpha`` is 1,
        and the two inputs are distinct nodes."""
        binary_nodes = filter_nodes(match.nodes, fn)
        if len(binary_nodes) < 1:
            return False

        def get_meta_value(argument: torch.fx.node.Argument):
            # Only torch.fx.Node is expected to have meta.
            if isinstance(argument, torch.fx.Node):
                return argument.meta.get("val", None)
            return None

        if any(
            not isinstance(get_meta_value(n.args[0]), torch.Tensor)
            or not isinstance(get_meta_value(n.args[1]), torch.Tensor)
            for n in binary_nodes
        ):
            return False
        # check alpha is one.
        if any(
            get_arg_value(n, 2, kwarg_name="alpha") != 1.0
            and get_arg_value(n, 2, kwarg_name="alpha") is not None
            for n in binary_nodes
        ):
            return False
        # shapes/device/dtype must match exactly (no broadcasting in fused op)
        if any(
            get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size()
            or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device
            or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype
            for n in binary_nodes
        ):
            return False
        # check args[0] and args[1] is not same
        if any(n.args[0] == n.args[1] for n in binary_nodes):
            return False
        return True
def _is_valid_computation_binary(computation_op, binary_op, other_index=None):
|
| 376 |
+
def fn(match):
|
| 377 |
+
if not _is_single_computation_op(computation_op)(match):
|
| 378 |
+
return False
|
| 379 |
+
if not _is_valid_binary(match, binary_op):
|
| 380 |
+
return False
|
| 381 |
+
return True
|
| 382 |
+
|
| 383 |
+
return fn
|
| 384 |
+
|
| 385 |
+
    def _get_remaining_users(extra_input_node, compute_node):
        """Return the users of ``extra_input_node`` that are NOT ancestors of
        ``compute_node`` — i.e. the uses of its buffer still pending after
        the compute runs."""
        # Think about this pattern:
        #      ReLU
        #     /   \
        #  Conv1
        #   /      \
        # Conv2
        #   \      /
        #      Add
        # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add.
        # The Conv1 is the ancestor node of the current compute node (Conv2).
        # This indicates that the buffer of ReLU has completed all its usage,
        # So we can safely make changes to it now by doing Conv2->Add inplace fusion.
        # Take above case as example:
        # * extra_input_node: ReLU
        # * compute_node: Conv2
        # _get_remaining_users will return the users of extra_input_node which are not
        # ancestor node of compute_node.
        def _is_ancestor_node(_current_node, _ancestor_node):
            # Check whether _ancestor_node is the ancestor node of _current_node
            # BFS backwards through input edges, with a visited set to handle
            # shared subgraphs.
            _node_list = [_current_node]
            _visited_nodes = set()
            while len(_node_list) != 0:
                _current_node = _node_list.pop(0)
                if _current_node not in _visited_nodes:
                    _visited_nodes.add(_current_node)
                    if _current_node == _ancestor_node:
                        return True
                    elif isinstance(
                        _current_node, torch.fx.Node
                    ) and _current_node.op not in ["placeholder", "output", "get_attr"]:
                        for input in _current_node.all_input_nodes:
                            _node_list.append(input)  # noqa: PERF402
            return False

        return [
            user
            for user in list(extra_input_node.users)
            if not _is_ancestor_node(compute_node, user)
        ]
    def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index):
        """Return an extra-check predicate for in-place binary fusion.

        On top of the plain binary checks, the "other" operand's buffer must
        be fully consumed (at most one non-ancestor user), must not feed the
        computation's own first input, and must not be a graph placeholder or
        output.
        """
        def fn(match):
            if not _is_valid_computation_binary(computation_op, binary_op)(match):
                return False
            binary_nodes = filter_nodes(match.nodes, binary_op)

            def _get_compute_node(_binary_node, _other_index):
                assert (
                    len(_binary_node.all_input_nodes) == 2
                ), "Binary node should have 2 input nodes."
                # the computation op is whichever operand is not "other"
                _compute_index = 1 if (_other_index == 0) else 0
                return _binary_node.args[_compute_index]

            def _other_input_not_inplaceable(_binary_node, _other_index):
                _compute_node = _get_compute_node(_binary_node, _other_index)
                return (
                    len(
                        _get_remaining_users(
                            _binary_node.args[_other_index], _compute_node
                        )
                    )
                    > 1
                    or _binary_node.args[_other_index] == _compute_node.args[0]
                )

            if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes):
                return False
            if any(
                n.args[other_index].op in ["placeholder", "output"]
                for n in binary_nodes
            ):
                return False
            return True

        return fn
    def _register_binary_unary_fusion_lowering(
        pattern,
        computation_op,
        binary_op,
        fusion_op,
        unary_attr=None,
    ):
        """Register a lowering that rewrites computation + binary (+ optional
        unary) into one fused ``fusion_op`` call with ``other`` injected as
        the second argument."""
        @register_lowering_pattern(
            pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op)
        )
        def fn(match, *args, **kwargs):
            other = kwargs.get("other")
            assert isinstance(other, ir.TensorBox)
            binary_attr = _binary_attr[binary_op]
            args_list = list(args)
            computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
            # >6 args distinguishes the conv signature (_conv_args has 10)
            # from linear (_linear_args has 6); the conv fused op takes extra
            # (alpha, unary attr, scalars, algorithm) slots.
            if len(args_list) > 6:
                if unary_attr is not None:
                    computation_args += [
                        1.0,
                        unary_attr.op_name,
                        unary_attr.scalars_attr,
                        unary_attr.algorithm_attr,
                    ]
                else:
                    computation_args += [1.0, None, [], None]
            return L[fusion_op](*computation_args)

        return fn
def _can_be_inplace(_other):
|
| 493 |
+
if isinstance(_other.data, ir.View):
|
| 494 |
+
return _can_be_inplace(_other.data)
|
| 495 |
+
else:
|
| 496 |
+
return not (
|
| 497 |
+
isinstance(_other.data, ir.ReinterpretView)
|
| 498 |
+
or isinstance(
|
| 499 |
+
_other.get_layout(), (ir.MutationLayout, ir.AliasedLayout)
|
| 500 |
+
)
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
    def _register_binary_unary_maybe_inplace_fusion_lowering(
        pattern,
        computation_op,
        binary_op,
        inplace_fusion_op,
        outplace_fusion_op,
        unary_attr=None,
        other_index=None,
    ):
        """Register a lowering like ``_register_binary_unary_fusion_lowering``
        but, when the ``other`` buffer can safely be written in place, emits
        ``inplace_fusion_op`` instead of ``outplace_fusion_op``."""
        @register_lowering_pattern(
            pattern,
            extra_check=_is_valid_computation_binary_inplace(
                computation_op, binary_op, other_index
            ),
        )
        def fn(match, *args, **kwargs):
            other = kwargs.get("other")
            assert isinstance(other, ir.TensorBox)
            binary_attr = _binary_attr[binary_op]
            args_list = list(args)
            computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
            # >6 args: conv signature — append (alpha, unary attr triple)
            if len(args_list) > 6:
                if unary_attr is not None:
                    computation_args += [
                        1.0,
                        unary_attr.op_name,
                        unary_attr.scalars_attr,
                        unary_attr.algorithm_attr,
                    ]
                else:
                    computation_args += [1.0, None, [], None]
            # Make sure the other is not an alias or mutation(fx side doesn't has such info).
            other.realize()
            if not _can_be_inplace(other):
                return L[outplace_fusion_op](*computation_args)
            return L[inplace_fusion_op](*computation_args)

        return fn
computation_ops = [
|
| 543 |
+
mkldnn._convolution_pointwise.default,
|
| 544 |
+
mkldnn._linear_pointwise.default,
|
| 545 |
+
mkldnn._convolution_transpose_pointwise.default,
|
| 546 |
+
]
|
| 547 |
+
|
| 548 |
+
class UnaryAttr:
|
| 549 |
+
def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
|
| 550 |
+
self.op_name = op_name
|
| 551 |
+
self.scalars_attr = scalars_attr if scalars_attr else []
|
| 552 |
+
self.algorithm_attr = algorithm_attr if algorithm_attr else ""
|
| 553 |
+
|
| 554 |
+
    def _register_unary_fusion():
        """Register all unary post-op fusions (gelu/hardswish/hardsigmoid/
        swish, plus relu/sigmoid/tanh for full precision, and leaky_relu/
        hardtanh) for conv, linear and conv_transpose, at bf16, fp16 and
        full precision."""
        computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call]

        def _unary_fusion_patterns(lowp_dtype):
            # The integer is the expected user count of the computation node
            # inside each decomposed pattern.
            replacement_unary_fusion_patterns = {
                UnaryAttr("gelu", algorithm_attr="tanh"): [
                    _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype)
                    for call_fn in computation_call_fns
                ],
                UnaryAttr("gelu", algorithm_attr="none"): [
                    _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype)
                    for call_fn in computation_call_fns
                ],
                UnaryAttr("hardswish"): [
                    _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype)
                    for call_fn in computation_call_fns
                ],
                UnaryAttr("hardsigmoid"): [
                    _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype)
                    for call_fn in computation_call_fns
                ],
                UnaryAttr("swish"): [
                    _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype)
                    for call_fn in computation_call_fns
                ],
            }
            if not lowp_dtype:
                # relu/sigmoid/tanh are single-node ops; only registered for
                # full precision here.
                call_user1 = [call_fn(users=1) for call_fn in computation_call_fns]
                replacement_unary_fusion_patterns.update(
                    {
                        UnaryAttr("relu"): [
                            _combined_fusion(u, aten.relu) for u in call_user1
                        ],
                        UnaryAttr("sigmoid"): [
                            _combined_fusion(u, aten.sigmoid) for u in call_user1
                        ],
                        UnaryAttr("tanh"): [
                            _combined_fusion(u, aten.tanh) for u in call_user1
                        ],
                    }
                )

            return replacement_unary_fusion_patterns

        for lowp_dtype in [torch.bfloat16, torch.float16, None]:
            replace_patterns = _unary_fusion_patterns(lowp_dtype)
            for unary_attr, patterns in replace_patterns.items():
                # patterns[i] pairs with computation_ops[i]
                _register_unary_fusion_lowering(
                    patterns[0], unary_attr, computation_ops[0], lowp_dtype
                )
                _register_unary_fusion_lowering(
                    patterns[1], unary_attr, computation_ops[1], lowp_dtype
                )
                _register_unary_fusion_lowering(
                    patterns[2], unary_attr, computation_ops[2], lowp_dtype
                )
            _leaky_relu_patterns = [
                _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype)
                for call_fn in computation_call_fns
            ]
            for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops):
                _register_leaky_relu_fusion_lowering(
                    pattern, computation_op, lowp_dtype
                )
            hardtanh_patterns = [
                _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype)
                for call_fn in computation_call_fns
            ]
            for pattern, computation_op in zip(hardtanh_patterns, computation_ops):
                _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype)
    def _register_inplace_fusion():
        """Register conv + add (+ relu) fusions that may write into the add's
        other operand in place, for both operand orders of add."""
        binary_ops = [aten.add, ops.add]
        inplace_fusion_op = mkldnn._convolution_pointwise_.binary
        outplace_fusion_op = mkldnn._convolution_pointwise.binary
        conv_call = _conv_call(users=1)
        conv_op = computation_ops[0]
        for binary_op in binary_ops:
            # v1: add(other, conv) — other is operand 0
            binary_v1 = _binary_fusion_v1(conv_call, binary_op)
            binary_unary_v1 = _combined_fusion(binary_v1, aten.relu)
            _register_binary_unary_maybe_inplace_fusion_lowering(
                binary_unary_v1,
                conv_op,
                binary_op,
                inplace_fusion_op,
                outplace_fusion_op,
                other_index=0,
                unary_attr=UnaryAttr("relu"),
            )
            _register_binary_unary_maybe_inplace_fusion_lowering(
                binary_v1,
                conv_op,
                binary_op,
                inplace_fusion_op,
                outplace_fusion_op,
                other_index=0,
            )
            # v2: add(conv, other) — other is operand 1
            binary_v2 = _binary_fusion_v2(conv_call, binary_op)
            binary_unary_v2 = _combined_fusion(binary_v2, aten.relu)
            _register_binary_unary_maybe_inplace_fusion_lowering(
                binary_unary_v2,
                conv_op,
                binary_op,
                inplace_fusion_op,
                outplace_fusion_op,
                other_index=1,
                unary_attr=UnaryAttr("relu"),
            )
            _register_binary_unary_maybe_inplace_fusion_lowering(
                binary_v2,
                conv_op,
                binary_op,
                inplace_fusion_op,
                outplace_fusion_op,
                other_index=1,
            )
def _register_binary_fusion():
|
| 672 |
+
binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
|
| 673 |
+
fusion_ops = [
|
| 674 |
+
mkldnn._convolution_pointwise.binary,
|
| 675 |
+
mkldnn._linear_pointwise.binary,
|
| 676 |
+
]
|
| 677 |
+
_computation_user_1 = [_conv_call(users=1), _linear_call(users=1)]
|
| 678 |
+
for computation_call, computation_op, fusion_op in zip(
|
| 679 |
+
_computation_user_1, computation_ops[:-1], fusion_ops
|
| 680 |
+
):
|
| 681 |
+
for binary_op in binary_ops:
|
| 682 |
+
pattern = _binary_fusion_v2(computation_call, binary_op)
|
| 683 |
+
_register_binary_unary_fusion_lowering(
|
| 684 |
+
pattern, computation_op, binary_op, fusion_op
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
for binary_op in [aten.add, ops.add]:
|
| 688 |
+
pattern = _binary_fusion_v1(computation_call, binary_op)
|
| 689 |
+
_register_binary_unary_fusion_lowering(
|
| 690 |
+
pattern, computation_op, binary_op, fusion_op
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
def _register_binary_unary_fusion():
|
| 694 |
+
binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
|
| 695 |
+
fusion_ops = [mkldnn._convolution_pointwise.binary]
|
| 696 |
+
_computation_user_1 = [_conv_call(users=1)]
|
| 697 |
+
for computation_call, computation_op, fusion_op in zip(
|
| 698 |
+
_computation_user_1, computation_ops[:-1], fusion_ops
|
| 699 |
+
):
|
| 700 |
+
for binary_op in binary_ops:
|
| 701 |
+
pattern_v1 = _combined_fusion(
|
| 702 |
+
_binary_fusion_v2(computation_call, binary_op), aten.relu
|
| 703 |
+
)
|
| 704 |
+
_register_binary_unary_fusion_lowering(
|
| 705 |
+
pattern_v1,
|
| 706 |
+
computation_op,
|
| 707 |
+
binary_op,
|
| 708 |
+
fusion_op,
|
| 709 |
+
unary_attr=UnaryAttr("relu"),
|
| 710 |
+
)
|
| 711 |
+
for binary_op in [aten.add, ops.add]:
|
| 712 |
+
pattern_v2 = _combined_fusion(
|
| 713 |
+
_binary_fusion_v1(computation_call, binary_op), aten.relu
|
| 714 |
+
)
|
| 715 |
+
_register_binary_unary_fusion_lowering(
|
| 716 |
+
pattern_v2,
|
| 717 |
+
computation_op,
|
| 718 |
+
binary_op,
|
| 719 |
+
fusion_op,
|
| 720 |
+
unary_attr=UnaryAttr("relu"),
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
def _recover_linear():
|
| 724 |
+
# convert reshape+linear+reshape to a single linear for applying fusion path.
|
| 725 |
+
@register_freezing_graph_pattern(
|
| 726 |
+
CallFunction(
|
| 727 |
+
aten.reshape.default,
|
| 728 |
+
CallFunction(
|
| 729 |
+
mkldnn._linear_pointwise.default,
|
| 730 |
+
CallFunction(
|
| 731 |
+
aten.reshape.default,
|
| 732 |
+
Arg(),
|
| 733 |
+
KeywordArg("reshape_1"),
|
| 734 |
+
_users=MULTIPLE,
|
| 735 |
+
),
|
| 736 |
+
Arg(),
|
| 737 |
+
Arg(),
|
| 738 |
+
Arg(),
|
| 739 |
+
Arg(),
|
| 740 |
+
Arg(),
|
| 741 |
+
),
|
| 742 |
+
KeywordArg("reshape_2"),
|
| 743 |
+
),
|
| 744 |
+
pass_number=1,
|
| 745 |
+
)
|
| 746 |
+
def reshape_linear_reshape_pattern(match, *args, **kwargs):
|
| 747 |
+
reshape_1 = kwargs.get("reshape_1")
|
| 748 |
+
reshape_2 = kwargs.get("reshape_2")
|
| 749 |
+
assert isinstance(reshape_1, list)
|
| 750 |
+
assert isinstance(reshape_2, list)
|
| 751 |
+
assert len(reshape_1) == 2
|
| 752 |
+
dynamic_shapes = not all(
|
| 753 |
+
isinstance(x, int) for x in ([reshape_1[0]] + reshape_2[:-1])
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
graph = match.graph
|
| 757 |
+
reshape_2_node = match.output_node()
|
| 758 |
+
linear_input_node = reshape_2_node.args[0].args[0].args[0]
|
| 759 |
+
# check linear's input's shape[:-1] == reshape_2[:-1]
|
| 760 |
+
# and check product(reshape_2[:-1]) == reshape_1[0]
|
| 761 |
+
if dynamic_shapes:
|
| 762 |
+
# TODO: Haozhe investigate how add guard here
|
| 763 |
+
return
|
| 764 |
+
else:
|
| 765 |
+
can_remove_reshape = linear_input_node.meta.get("val").shape[
|
| 766 |
+
:-1
|
| 767 |
+
] == torch.Size(reshape_2[:-1])
|
| 768 |
+
can_remove_reshape = can_remove_reshape and (
|
| 769 |
+
reduce(operator.mul, reshape_2[:-1]) == reshape_1[0]
|
| 770 |
+
)
|
| 771 |
+
|
| 772 |
+
if can_remove_reshape:
|
| 773 |
+
repl = graph.call_function(mkldnn._linear_pointwise.default, args)
|
| 774 |
+
repl.meta.update(reshape_2_node.meta)
|
| 775 |
+
reshape_2_node.replace_all_uses_with(repl)
|
| 776 |
+
old_linear_node = reshape_2_node.args[0]
|
| 777 |
+
reshape_1_node = old_linear_node.args[0]
|
| 778 |
+
graph.erase_node(reshape_2_node)
|
| 779 |
+
graph.erase_node(old_linear_node)
|
| 780 |
+
if len(reshape_1_node.users) == 0:
|
| 781 |
+
graph.erase_node(reshape_1_node)
|
| 782 |
+
|
| 783 |
+
def is_linear_add_bias(match):
|
| 784 |
+
add_node = match.output_node()
|
| 785 |
+
linear_node = add_node.args[0]
|
| 786 |
+
weight_meta = linear_node.args[1].meta.get("val")
|
| 787 |
+
bias_meta = add_node.args[1].meta.get("val")
|
| 788 |
+
if weight_meta is None or bias_meta is None:
|
| 789 |
+
return False
|
| 790 |
+
return (
|
| 791 |
+
linear_node.args[2] is None
|
| 792 |
+
and bias_meta.dim() == 1
|
| 793 |
+
and bias_meta.size(0) == weight_meta.size(0)
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
# convert linear+bias to a single linear for applying fusion path.
|
| 797 |
+
@register_freezing_graph_pattern(
|
| 798 |
+
CallFunction(
|
| 799 |
+
aten.add.Tensor,
|
| 800 |
+
CallFunction(mkldnn._linear_pointwise.default, *_linear_args),
|
| 801 |
+
Arg(),
|
| 802 |
+
),
|
| 803 |
+
pass_number=1,
|
| 804 |
+
extra_check=is_linear_add_bias,
|
| 805 |
+
)
|
| 806 |
+
def linear_bias_pattern(match, *args):
|
| 807 |
+
graph = match.graph
|
| 808 |
+
add_node = match.output_node()
|
| 809 |
+
linear_node = add_node.args[0]
|
| 810 |
+
new_args = list(linear_node.args)
|
| 811 |
+
new_args[2] = add_node.args[1]
|
| 812 |
+
repl = graph.call_function(
|
| 813 |
+
mkldnn._linear_pointwise.default, tuple(new_args)
|
| 814 |
+
)
|
| 815 |
+
repl.meta.update(add_node.meta)
|
| 816 |
+
add_node.replace_all_uses_with(repl)
|
| 817 |
+
match.erase_nodes(graph)
|
| 818 |
+
|
| 819 |
+
def _is_packable_mkldnn_rnn_layer(match):
|
| 820 |
+
lstm_node = match.output_node()
|
| 821 |
+
POS_WEIGHTS = [1, 2]
|
| 822 |
+
POS_INPUTS = [0, 5, 6]
|
| 823 |
+
POS_ARGS = POS_WEIGHTS + POS_INPUTS
|
| 824 |
+
# Weights should be Constant
|
| 825 |
+
if any(
|
| 826 |
+
lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS
|
| 827 |
+
):
|
| 828 |
+
return False
|
| 829 |
+
|
| 830 |
+
# Meta info for weights and inputs should be available
|
| 831 |
+
if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS):
|
| 832 |
+
return False
|
| 833 |
+
|
| 834 |
+
# Check device
|
| 835 |
+
if any(
|
| 836 |
+
lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu"
|
| 837 |
+
for POS_ARG in POS_ARGS
|
| 838 |
+
):
|
| 839 |
+
return False
|
| 840 |
+
|
| 841 |
+
# Check dtype
|
| 842 |
+
if any(
|
| 843 |
+
lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16
|
| 844 |
+
and not mkldnn._is_mkldnn_bf16_supported()
|
| 845 |
+
for POS_ARG in POS_ARGS
|
| 846 |
+
):
|
| 847 |
+
return False
|
| 848 |
+
if any(
|
| 849 |
+
lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16
|
| 850 |
+
and not mkldnn._is_mkldnn_fp16_supported()
|
| 851 |
+
for POS_ARG in POS_ARGS
|
| 852 |
+
):
|
| 853 |
+
return False
|
| 854 |
+
|
| 855 |
+
return True
|
| 856 |
+
|
| 857 |
+
def _is_packable_convolution(match):
|
| 858 |
+
"""
|
| 859 |
+
Check if the node is supported for MKLDNN convolution.
|
| 860 |
+
"""
|
| 861 |
+
conv_node = match.output_node()
|
| 862 |
+
input_meta_value = conv_node.args[0].meta.get("val")
|
| 863 |
+
weight_meta_value = conv_node.args[1].meta.get("val")
|
| 864 |
+
if input_meta_value is None or weight_meta_value is None:
|
| 865 |
+
return False
|
| 866 |
+
input_size = input_meta_value.shape
|
| 867 |
+
if conv_node.args[1].op != "get_attr":
|
| 868 |
+
return False
|
| 869 |
+
for meta_value in [input_meta_value, weight_meta_value]:
|
| 870 |
+
if (
|
| 871 |
+
meta_value is None
|
| 872 |
+
or meta_value.device.type != "cpu"
|
| 873 |
+
or meta_value.dim() != 4
|
| 874 |
+
):
|
| 875 |
+
return False
|
| 876 |
+
if (
|
| 877 |
+
input_meta_value.dtype == torch.bfloat16
|
| 878 |
+
or weight_meta_value.dtype == torch.bfloat16
|
| 879 |
+
):
|
| 880 |
+
if not mkldnn._is_mkldnn_bf16_supported():
|
| 881 |
+
return False
|
| 882 |
+
if (
|
| 883 |
+
input_meta_value.dtype == torch.float16
|
| 884 |
+
or weight_meta_value.dtype == torch.float16
|
| 885 |
+
):
|
| 886 |
+
if not mkldnn._is_mkldnn_fp16_supported():
|
| 887 |
+
return False
|
| 888 |
+
is_transposed = conv_node.args[-3]
|
| 889 |
+
if is_transposed:
|
| 890 |
+
# TODO: Support dynamic shape case for MKLDNN conv transpose.
|
| 891 |
+
if has_free_symbols(input_size):
|
| 892 |
+
return False
|
| 893 |
+
groups = conv_node.args[-1]
|
| 894 |
+
in_channels = weight_meta_value.size(0)
|
| 895 |
+
# doesn't support group_depthwise_conv_transpose.
|
| 896 |
+
if groups > 1 and groups == in_channels:
|
| 897 |
+
return False
|
| 898 |
+
# Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big
|
| 899 |
+
output_paddings = conv_node.args[-2]
|
| 900 |
+
strides = conv_node.args[3]
|
| 901 |
+
if any(
|
| 902 |
+
output_padding >= stride
|
| 903 |
+
for output_padding, stride in zip(output_paddings, strides)
|
| 904 |
+
):
|
| 905 |
+
return False
|
| 906 |
+
return True
|
| 907 |
+
|
| 908 |
+
def _is_packable_linear(match):
|
| 909 |
+
"""
|
| 910 |
+
Check if the node is supported for MKLDNN linear.
|
| 911 |
+
"""
|
| 912 |
+
linear_node = match.output_node()
|
| 913 |
+
# weight_idx is 1 for aten.mm and is 2 for aten.addmm
|
| 914 |
+
weight_idx = 2 if linear_node.target == aten.addmm.default else 1
|
| 915 |
+
if linear_node.args[weight_idx].op != "get_attr":
|
| 916 |
+
return False
|
| 917 |
+
input_meta_value = linear_node.args[weight_idx - 1].meta.get("val")
|
| 918 |
+
weight_meta_value = linear_node.args[weight_idx].meta.get("val")
|
| 919 |
+
if input_meta_value is None or weight_meta_value is None:
|
| 920 |
+
return False
|
| 921 |
+
batch_size = input_meta_value.shape[0]
|
| 922 |
+
is_lp_weight = weight_meta_value.dtype in (
|
| 923 |
+
torch.bfloat16,
|
| 924 |
+
torch.float16,
|
| 925 |
+
)
|
| 926 |
+
# on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
|
| 927 |
+
# on aarch64, use mkldnn op for fp32 as well if acl is enabled
|
| 928 |
+
if (
|
| 929 |
+
not is_lp_weight
|
| 930 |
+
and not mkldnn._is_mkldnn_acl_supported()
|
| 931 |
+
and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
|
| 932 |
+
):
|
| 933 |
+
return False
|
| 934 |
+
for meta_value in [input_meta_value, weight_meta_value]:
|
| 935 |
+
if (
|
| 936 |
+
meta_value is None
|
| 937 |
+
or meta_value.device.type != "cpu"
|
| 938 |
+
or meta_value.dim() != 2
|
| 939 |
+
):
|
| 940 |
+
return False
|
| 941 |
+
if weight_idx == 2:
|
| 942 |
+
bias_meta_value = linear_node.args[0].meta.get("val")
|
| 943 |
+
if (
|
| 944 |
+
bias_meta_value is None
|
| 945 |
+
or meta_value.device.type != "cpu"
|
| 946 |
+
or bias_meta_value.dim() != 1
|
| 947 |
+
or bias_meta_value.size(0) != weight_meta_value.size(1)
|
| 948 |
+
):
|
| 949 |
+
return False
|
| 950 |
+
|
| 951 |
+
if (
|
| 952 |
+
input_meta_value.dtype == torch.bfloat16
|
| 953 |
+
or weight_meta_value.dtype == torch.bfloat16
|
| 954 |
+
):
|
| 955 |
+
if not mkldnn._is_mkldnn_bf16_supported():
|
| 956 |
+
return False
|
| 957 |
+
if (
|
| 958 |
+
input_meta_value.dtype == torch.float16
|
| 959 |
+
or weight_meta_value.dtype == torch.float16
|
| 960 |
+
):
|
| 961 |
+
if not mkldnn._is_mkldnn_fp16_supported():
|
| 962 |
+
return False
|
| 963 |
+
return True
|
| 964 |
+
|
| 965 |
+
_aten_conv_args = (
|
| 966 |
+
Arg(),
|
| 967 |
+
Arg(),
|
| 968 |
+
Arg(),
|
| 969 |
+
Arg(),
|
| 970 |
+
Arg(),
|
| 971 |
+
Arg(),
|
| 972 |
+
KeywordArg("is_transposed"),
|
| 973 |
+
Arg(),
|
| 974 |
+
Arg(),
|
| 975 |
+
)
|
| 976 |
+
|
| 977 |
+
_aten_mkldnn_rnn_layer_args = (
|
| 978 |
+
Arg(), # input
|
| 979 |
+
Arg(), # weight0
|
| 980 |
+
Arg(), # weight1
|
| 981 |
+
Arg(), # weight2
|
| 982 |
+
Arg(), # weight3
|
| 983 |
+
Arg(), # hx_
|
| 984 |
+
Arg(), # cx_
|
| 985 |
+
KeywordArg("reverse"), # reverse
|
| 986 |
+
Arg(), # batch_sizes
|
| 987 |
+
Arg(), # mode
|
| 988 |
+
Arg(), # hidden_size
|
| 989 |
+
Arg(), # num_layers
|
| 990 |
+
Arg(), # has_biases
|
| 991 |
+
Arg(), # bidirectional
|
| 992 |
+
Arg(), # batch_first
|
| 993 |
+
Arg(), # train
|
| 994 |
+
)
|
| 995 |
+
|
| 996 |
+
def _register_weight_pack_pass():
|
| 997 |
+
@register_freezing_graph_pattern(
|
| 998 |
+
CallFunction(aten.convolution.default, *_aten_conv_args),
|
| 999 |
+
extra_check=_is_packable_convolution,
|
| 1000 |
+
)
|
| 1001 |
+
def convolution(match, *args, **kwargs):
|
| 1002 |
+
is_transposed = kwargs.get("is_transposed")
|
| 1003 |
+
assert isinstance(is_transposed, bool)
|
| 1004 |
+
graph = match.graph
|
| 1005 |
+
conv_node = match.output_node()
|
| 1006 |
+
input_size = conv_node.args[0].meta.get("val").shape
|
| 1007 |
+
with graph.inserting_before(conv_node):
|
| 1008 |
+
constant_args = [args[4], args[3], args[5], args[-1]]
|
| 1009 |
+
packed_weight_op = mkldnn._reorder_convolution_weight
|
| 1010 |
+
packed_conv_op = mkldnn._convolution_pointwise.default
|
| 1011 |
+
if is_transposed:
|
| 1012 |
+
constant_args.insert(1, args[-2]) # output_padding
|
| 1013 |
+
packed_weight_op = mkldnn._reorder_convolution_transpose_weight
|
| 1014 |
+
packed_conv_op = mkldnn._convolution_transpose_pointwise.default
|
| 1015 |
+
if not has_free_symbols(input_size):
|
| 1016 |
+
packed_weight_inputs = (
|
| 1017 |
+
(args[1],) + tuple(constant_args) + (input_size,)
|
| 1018 |
+
)
|
| 1019 |
+
packed_weight_node = graph.create_node(
|
| 1020 |
+
"call_function", packed_weight_op, args=packed_weight_inputs
|
| 1021 |
+
)
|
| 1022 |
+
else:
|
| 1023 |
+
assert not is_transposed
|
| 1024 |
+
# For dynamic shape case, we need to pack weight in runtime.
|
| 1025 |
+
packed_weight_node = args[1]
|
| 1026 |
+
packed_conv_inputs = (
|
| 1027 |
+
(args[0], packed_weight_node, args[2])
|
| 1028 |
+
+ tuple(constant_args)
|
| 1029 |
+
+ ("none", [], "")
|
| 1030 |
+
)
|
| 1031 |
+
packed_conv_node = graph.create_node(
|
| 1032 |
+
"call_function", packed_conv_op, tuple(packed_conv_inputs)
|
| 1033 |
+
)
|
| 1034 |
+
conv_node.replace_all_uses_with(packed_conv_node)
|
| 1035 |
+
packed_conv_node.meta.update(conv_node.meta)
|
| 1036 |
+
graph.erase_node(conv_node)
|
| 1037 |
+
|
| 1038 |
+
@register_freezing_graph_pattern(
|
| 1039 |
+
CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args),
|
| 1040 |
+
extra_check=_is_packable_mkldnn_rnn_layer,
|
| 1041 |
+
)
|
| 1042 |
+
def mkldnn_rnn_layer(match, *args, **kwargs):
|
| 1043 |
+
def get_item(graph, node, index):
|
| 1044 |
+
return graph.call_function(operator.getitem, (node, index))
|
| 1045 |
+
|
| 1046 |
+
graph = match.graph
|
| 1047 |
+
lstm_node = match.output_node()
|
| 1048 |
+
input = args[0]
|
| 1049 |
+
weight0, weight1 = args[1:3]
|
| 1050 |
+
reverse = kwargs.get("reverse")
|
| 1051 |
+
packed_lstm_op = aten.mkldnn_rnn_layer.default
|
| 1052 |
+
hidden_size = args[9]
|
| 1053 |
+
has_biases = args[11]
|
| 1054 |
+
batch_first = args[13]
|
| 1055 |
+
with graph.inserting_before(lstm_node):
|
| 1056 |
+
packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default
|
| 1057 |
+
packed_weight_inputs = (
|
| 1058 |
+
weight0,
|
| 1059 |
+
weight1,
|
| 1060 |
+
hidden_size,
|
| 1061 |
+
reverse,
|
| 1062 |
+
has_biases,
|
| 1063 |
+
batch_first,
|
| 1064 |
+
)
|
| 1065 |
+
packed_weight_node = graph.create_node(
|
| 1066 |
+
"call_function", packed_weight_op, packed_weight_inputs, {}, "name"
|
| 1067 |
+
)
|
| 1068 |
+
packed_weight_items = [
|
| 1069 |
+
get_item(graph, packed_weight_node, i) for i in range(2)
|
| 1070 |
+
]
|
| 1071 |
+
pack_lstm_inputs = (
|
| 1072 |
+
args[0],
|
| 1073 |
+
*packed_weight_items,
|
| 1074 |
+
args[3],
|
| 1075 |
+
args[4],
|
| 1076 |
+
args[5],
|
| 1077 |
+
args[6],
|
| 1078 |
+
reverse,
|
| 1079 |
+
*args[7:],
|
| 1080 |
+
)
|
| 1081 |
+
|
| 1082 |
+
packed_lstm_node = graph.create_node(
|
| 1083 |
+
"call_function", packed_lstm_op, args=pack_lstm_inputs
|
| 1084 |
+
)
|
| 1085 |
+
lstm_node.replace_all_uses_with(packed_lstm_node)
|
| 1086 |
+
packed_lstm_node.meta.update(lstm_node.meta)
|
| 1087 |
+
graph.erase_node(lstm_node)
|
| 1088 |
+
|
| 1089 |
+
@register_freezing_graph_pattern(
|
| 1090 |
+
CallFunction(aten.addmm.default, Arg(), Arg(), Arg()),
|
| 1091 |
+
extra_check=_is_packable_linear,
|
| 1092 |
+
)
|
| 1093 |
+
@register_freezing_graph_pattern(
|
| 1094 |
+
CallFunction(aten.mm.default, Arg(), Arg()),
|
| 1095 |
+
extra_check=_is_packable_linear,
|
| 1096 |
+
)
|
| 1097 |
+
def linear(match, *args, **kwargs):
|
| 1098 |
+
graph = match.graph
|
| 1099 |
+
linear_node = match.output_node()
|
| 1100 |
+
input = args[0] if linear_node.target == aten.mm.default else args[1]
|
| 1101 |
+
bias = None if linear_node.target == aten.mm.default else args[0]
|
| 1102 |
+
weight = args[1] if linear_node.target == aten.mm.default else args[2]
|
| 1103 |
+
with graph.inserting_before(linear_node):
|
| 1104 |
+
transpose_weight_node = graph.create_node(
|
| 1105 |
+
"call_function", aten.permute.default, (weight, (1, 0))
|
| 1106 |
+
)
|
| 1107 |
+
weight_dtype = weight.meta.get("val").dtype
|
| 1108 |
+
is_lp_weight = weight_dtype in (
|
| 1109 |
+
torch.bfloat16,
|
| 1110 |
+
torch.float16,
|
| 1111 |
+
)
|
| 1112 |
+
batch_size = input.meta.get("val").shape[0]
|
| 1113 |
+
if has_free_symbols(batch_size):
|
| 1114 |
+
assert (
|
| 1115 |
+
is_lp_weight or mkldnn._is_mkldnn_acl_supported()
|
| 1116 |
+
), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
|
| 1117 |
+
# For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
|
| 1118 |
+
packed_weight_inputs = (
|
| 1119 |
+
transpose_weight_node,
|
| 1120 |
+
batch_size.node.shape_env.size_hint(batch_size.node.expr)
|
| 1121 |
+
if has_free_symbols(batch_size)
|
| 1122 |
+
else batch_size,
|
| 1123 |
+
)
|
| 1124 |
+
packed_weight_op = (
|
| 1125 |
+
mkldnn._reorder_linear_weight
|
| 1126 |
+
if (is_lp_weight or mkldnn._is_mkldnn_acl_supported())
|
| 1127 |
+
else torch.ops.mkl._mkl_reorder_linear_weight
|
| 1128 |
+
)
|
| 1129 |
+
packed_weight_node = graph.create_node(
|
| 1130 |
+
"call_function", packed_weight_op, args=packed_weight_inputs
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
|
| 1134 |
+
if is_lp_weight or mkldnn._is_mkldnn_acl_supported():
|
| 1135 |
+
packed_linear_inputs += (bias, "none", [], "")
|
| 1136 |
+
packed_linear_op = mkldnn._linear_pointwise.default
|
| 1137 |
+
else:
|
| 1138 |
+
packed_linear_inputs += (transpose_weight_node, bias, batch_size)
|
| 1139 |
+
packed_linear_op = torch.ops.mkl._mkl_linear
|
| 1140 |
+
packed_linear_node = graph.create_node(
|
| 1141 |
+
"call_function", packed_linear_op, packed_linear_inputs
|
| 1142 |
+
)
|
| 1143 |
+
linear_node.replace_all_uses_with(packed_linear_node)
|
| 1144 |
+
packed_linear_node.meta.update(linear_node.meta)
|
| 1145 |
+
graph.erase_node(linear_node)
|
| 1146 |
+
|
| 1147 |
+
def _eliminate_duplicate_packed_nodes(gm):
|
| 1148 |
+
"""
|
| 1149 |
+
Combine packed weight nodes with the same inputs to reduce memory usage.
|
| 1150 |
+
for example:
|
| 1151 |
+
class Model(nn.Module):
|
| 1152 |
+
def __init__(self):
|
| 1153 |
+
super().__init__()
|
| 1154 |
+
self.linear = nn.Linear(32, 32, bias=True)
|
| 1155 |
+
|
| 1156 |
+
def forward(self, x):
|
| 1157 |
+
return self.linear(self.linear(x))
|
| 1158 |
+
|
| 1159 |
+
the above's packed weight nodes are duplicate if two linear calls have same input size.
|
| 1160 |
+
"""
|
| 1161 |
+
if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
|
| 1162 |
+
return gm
|
| 1163 |
+
|
| 1164 |
+
packed_weight_ops = [
|
| 1165 |
+
torch._C._nn.mkldnn_reorder_conv2d_weight,
|
| 1166 |
+
mkldnn._reorder_convolution_transpose_weight,
|
| 1167 |
+
mkldnn._reorder_linear_weight,
|
| 1168 |
+
mkldnn._reorder_mkldnn_rnn_layer_weight,
|
| 1169 |
+
]
|
| 1170 |
+
if torch._C.has_mkl:
|
| 1171 |
+
packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight)
|
| 1172 |
+
|
| 1173 |
+
for node in gm.graph.nodes:
|
| 1174 |
+
if node.target in packed_weight_ops and len(node.args[0].users) > 1:
|
| 1175 |
+
for user_node in list(node.args[0].users.keys()):
|
| 1176 |
+
if (
|
| 1177 |
+
user_node.target == node.target
|
| 1178 |
+
and user_node != node
|
| 1179 |
+
and user_node.args == node.args
|
| 1180 |
+
):
|
| 1181 |
+
user_node.replace_all_uses_with(node)
|
| 1182 |
+
gm.graph.erase_node(user_node)
|
| 1183 |
+
|
| 1184 |
+
@functools.lru_cache(None)
|
| 1185 |
+
def _mkldnn_fusion_init():
|
| 1186 |
+
# TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now.
|
| 1187 |
+
# Otherwise even the matmul or innerproduct can not be accelerated with acl
|
| 1188 |
+
if (
|
| 1189 |
+
torch.backends.mkldnn.enabled
|
| 1190 |
+
and torch.backends.mkldnn.is_available()
|
| 1191 |
+
and not torch.ops.mkldnn._is_mkldnn_acl_supported()
|
| 1192 |
+
):
|
| 1193 |
+
_register_unary_fusion()
|
| 1194 |
+
_register_inplace_fusion()
|
| 1195 |
+
_register_binary_unary_fusion()
|
| 1196 |
+
_register_binary_fusion()
|
| 1197 |
+
_register_quantization_lowerings()
|
| 1198 |
+
|
| 1199 |
+
@functools.lru_cache(None)
|
| 1200 |
+
def _mkldnn_weight_pack_init():
|
| 1201 |
+
if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available():
|
| 1202 |
+
_register_weight_pack_pass()
|
| 1203 |
+
_recover_linear()
|
| 1204 |
+
_register_quantization_weight_pack_pass()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py
ADDED
|
@@ -0,0 +1,1100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import functools
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
import operator
|
| 6 |
+
from collections import Counter, defaultdict
|
| 7 |
+
from typing import Any, Dict, List, Optional, Set, Union
|
| 8 |
+
|
| 9 |
+
from sympy import Expr
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch._inductor as inductor
|
| 13 |
+
import torch.utils._pytree as pytree
|
| 14 |
+
from torch import fx
|
| 15 |
+
from torch._decomp import register_decomposition
|
| 16 |
+
from torch._dynamo.utils import counters, optimus_scuba_log
|
| 17 |
+
|
| 18 |
+
from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
|
| 19 |
+
|
| 20 |
+
from torch._utils_internal import upload_graph
|
| 21 |
+
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
|
| 22 |
+
|
| 23 |
+
from .. import config, ir, pattern_matcher
|
| 24 |
+
from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
|
| 25 |
+
|
| 26 |
+
from ..lowering import lowerings as L
|
| 27 |
+
from ..pattern_matcher import (
|
| 28 |
+
_return_true,
|
| 29 |
+
Arg,
|
| 30 |
+
CallFunction,
|
| 31 |
+
CallFunctionVarArgs,
|
| 32 |
+
filter_nodes,
|
| 33 |
+
get_arg_value,
|
| 34 |
+
get_mutation_region_id,
|
| 35 |
+
Ignored,
|
| 36 |
+
init_once_fakemode,
|
| 37 |
+
KeywordArg,
|
| 38 |
+
ListOf,
|
| 39 |
+
Match,
|
| 40 |
+
MULTIPLE,
|
| 41 |
+
PatternMatcherPass,
|
| 42 |
+
register_graph_pattern,
|
| 43 |
+
stable_topological_sort,
|
| 44 |
+
)
|
| 45 |
+
from ..utils import decode_device, is_pointwise_use
|
| 46 |
+
from ..virtualized import V
|
| 47 |
+
from .group_batch_fusion import group_batch_fusion_passes
|
| 48 |
+
from .reinplace import reinplace_inplaceable_ops
|
| 49 |
+
|
| 50 |
+
log = logging.getLogger(__name__)
|
| 51 |
+
aten = torch.ops.aten
|
| 52 |
+
prims = torch.ops.prims
|
| 53 |
+
|
| 54 |
+
# First pass_patterns[0] are applied, then [1], then [2]
|
| 55 |
+
pass_patterns = [
|
| 56 |
+
PatternMatcherPass(),
|
| 57 |
+
PatternMatcherPass(),
|
| 58 |
+
PatternMatcherPass(),
|
| 59 |
+
]
|
| 60 |
+
# patterns applied only in inference
|
| 61 |
+
inference_patterns = PatternMatcherPass()
|
| 62 |
+
decompose_mm_pass = PatternMatcherPass()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
| 66 |
+
"""
|
| 67 |
+
Passes that run on after grad. This is called once on the forwards
|
| 68 |
+
graph and once on the backwards graph.
|
| 69 |
+
|
| 70 |
+
The IR here has been normalized and functionalized.
|
| 71 |
+
"""
|
| 72 |
+
if config.dce:
|
| 73 |
+
# has some issues with mutation in inference mode
|
| 74 |
+
gm.graph.eliminate_dead_code()
|
| 75 |
+
|
| 76 |
+
if is_inference and config.reorder_for_locality:
|
| 77 |
+
reorder_for_locality(gm.graph)
|
| 78 |
+
|
| 79 |
+
fake_tensor_updater = FakeTensorUpdater(gm.graph)
|
| 80 |
+
|
| 81 |
+
if config.post_grad_custom_pre_pass is not None:
|
| 82 |
+
config.post_grad_custom_pre_pass(gm.graph)
|
| 83 |
+
|
| 84 |
+
if config.pattern_matcher:
|
| 85 |
+
lazy_init()
|
| 86 |
+
inductor_before_change = copy.deepcopy(counters["inductor"])
|
| 87 |
+
group_batch_fusion_passes(gm.graph, pre_grad=False)
|
| 88 |
+
if counters["inductor"] != inductor_before_change:
|
| 89 |
+
optimus_scuba_log["group_batch_fusion_post_grad"] = upload_graph(gm.graph)
|
| 90 |
+
remove_noop_ops(gm.graph)
|
| 91 |
+
for patterns in pass_patterns:
|
| 92 |
+
patterns.apply(gm.graph) # type: ignore[arg-type]
|
| 93 |
+
if is_inference:
|
| 94 |
+
inference_patterns.apply(gm.graph) # type: ignore[arg-type]
|
| 95 |
+
decompose_mm_pass.apply(gm.graph) # type: ignore[arg-type]
|
| 96 |
+
|
| 97 |
+
if config.post_grad_custom_post_pass is not None:
|
| 98 |
+
config.post_grad_custom_post_pass(gm.graph)
|
| 99 |
+
|
| 100 |
+
stable_topological_sort(gm.graph)
|
| 101 |
+
|
| 102 |
+
move_constructors_to_cuda(gm.graph)
|
| 103 |
+
|
| 104 |
+
fake_tensor_updater.incremental_update()
|
| 105 |
+
|
| 106 |
+
# Keep these last, since they introduces mutation. Look at
|
| 107 |
+
# ./fx_passes/README.md for a discussion of mutation invariants.
|
| 108 |
+
reinplace_inplaceable_ops(gm.graph)
|
| 109 |
+
decompose_auto_functionalized(gm.graph)
|
| 110 |
+
|
| 111 |
+
gm.recompile()
|
| 112 |
+
gm.graph.lint()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@init_once_fakemode
def lazy_init():
    """One-time initialization of extra pattern-matcher passes.

    The mkldnn-dependent pattern modules are only imported/initialized when
    this torch build has mkldnn compiled in (``torch._C._has_mkldnn``).
    """
    if torch._C._has_mkldnn:
        from . import decompose_mem_bound_mm  # noqa: F401
        from .mkldnn_fusion import _mkldnn_fusion_init

        _mkldnn_fusion_init()
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def reorder_for_locality(graph: torch.fx.Graph):
    """Reorder nodes so producers sit directly before their consumers.

    Walks the graph in reverse; for each node, any producer whose users have
    all already been seen (i.e. no earlier consumer remains) and that lives in
    the same mutation region is moved to just before the node. This improves
    locality without changing data dependencies.
    """

    def visit(other_node):
        if (
            other_node.op == "call_function"
            and other_node.target != operator.getitem
            # all users of the producer have been visited already (we iterate
            # in reverse, so "seen" means "at or after the current node")
            and all((n in seen_nodes) for n in other_node.users)
            and get_mutation_region_id(graph, node)
            == get_mutation_region_id(graph, other_node)
        ):
            # move node's producers right before it
            node.prepend(other_node)

    seen_nodes = set()

    # only reorder nodes before the first copy_ in the graph.
    # copy_ will appear at the end of functionalized graphs when there is mutation on inputs,
    # and this reordering doesn't work well with mutation
    first_copy = next(
        (
            node
            for node in graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten.copy_.default
        ),
        None,
    )
    past_mutating_epilogue = True if first_copy is None else False

    for node in reversed(graph.nodes):
        seen_nodes.add(node)
        if not past_mutating_epilogue:
            # still inside the trailing copy_ epilogue; skip until we pass
            # the first copy_ (remember: reverse iteration)
            past_mutating_epilogue = node is first_copy
            continue

        torch.fx.map_arg((node.args, node.kwargs), visit)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1):
    """Register an aten -> inductor IR replacement pattern.

    Thin wrapper around ``pattern_matcher.register_lowering_pattern`` that
    picks the pass dictionary corresponding to ``pass_number``.
    """
    pass_dict = pass_patterns[pass_number]
    return pattern_matcher.register_lowering_pattern(
        pattern, extra_check, pass_dict=pass_dict
    )
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
################################################################################
|
| 171 |
+
# Actual patterns below this point.
|
| 172 |
+
# Priority of patterns is:
|
| 173 |
+
# - later output nodes first
|
| 174 |
+
# - order patterns are defined in
|
| 175 |
+
################################################################################
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def is_valid_mm_plus_mm(match: Match):
    """Shape check for fusing ``mm(mat1, mat2) + mm(mat3, mat4)``.

    Both matmuls must be individually valid (inner dims agree) and must
    produce the same [m, n] output so the addition is well defined.
    """

    def _shape(name):
        return match.kwargs[name].meta.get("tensor_meta").shape

    *_, m1, k1 = _shape("mat1")
    *_, k2, n1 = _shape("mat2")
    if k1 != k2:
        return False

    *_, m2, k3 = _shape("mat3")
    *_, k4, n2 = _shape("mat4")
    if k3 != k4:
        return False

    return m1 == m2 and n1 == n2
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@register_lowering_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")),
        CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")),
    ),
    extra_check=is_valid_mm_plus_mm,
)
def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
    # Lower mm(mat1, mat2) + mm(mat3, mat4) to a single autotuned fused kernel.
    return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def cuda_and_enabled_mixed_mm(match):
    """True when mixed-dtype mm lowering is enabled in config and the matched
    ``mat1`` is a CUDA tensor (per its fake-tensor ``val`` meta)."""
    if not (config.use_mixed_mm or config.force_mixed_mm):
        return False
    mat1_val = match.kwargs["mat1"].meta.get("val")
    return getattr(mat1_val, "is_cuda", False)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def cuda_and_enabled_mixed_mm_and_not_int8(match):
    """Like :func:`cuda_and_enabled_mixed_mm`, additionally requiring that
    ``mat2`` is not int8.

    bitshift numerics in triton and pytorch don't match for torch.int8, so
    the uint4x2 unpack pattern must not fire for an int8 packed tensor.
    If ``mat2`` has no ``val`` meta, we conservatively treat it as int8 and
    return False (same default as the original getattr fallback).
    """
    # NOTE: cuda_and_enabled_mixed_mm already verifies mat1 is on CUDA, so the
    # previously duplicated is_cuda check on mat1 has been removed here.
    return (
        cuda_and_enabled_mixed_mm(match)
        and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8)
        != torch.int8
    )
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
"""
|
| 223 |
+
this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor
|
| 224 |
+
(where the int4 and uint4x2 are represented with int8 and uint8 respectively)
|
| 225 |
+
where every other row of the int4 is packed with the row above it as:
|
| 226 |
+
uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4
|
| 227 |
+
|
| 228 |
+
unpack formulas:
|
| 229 |
+
int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8
|
| 230 |
+
int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8
|
| 231 |
+
|
| 232 |
+
thus matching on unpack formula:
|
| 233 |
+
torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8))
|
| 234 |
+
|
| 235 |
+
note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior
|
| 236 |
+
of the kernel matches the pytorch formula for all dtypes except torch.int8
|
| 237 |
+
where the bitwise numerics in triton do not match those in pytorch.
|
| 238 |
+
"""
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# Matches the uint4x2 unpack-then-mm sequence described in the string above:
# torch.mm(mat1, torch.cat((mat2 & 0xF, mat2 >> 4), 1)
#                     .reshape(mat2_mm_shape).to(mat2_dtype).sub(8))
@register_lowering_pattern(
    CallFunction(
        aten.mm.default,
        KeywordArg("mat1"),
        CallFunction(
            aten.sub.Tensor,
            CallFunction(
                prims.convert_element_type.default,
                CallFunction(
                    aten.reshape.default,
                    CallFunction(
                        aten.cat.default,
                        ListOf(
                            CallFunction(
                                aten.bitwise_and.Scalar,
                                KeywordArg("mat2"),
                                0xF,
                            ),
                            CallFunction(
                                aten.__rshift__.Scalar,
                                KeywordArg("mat2"),
                                4,
                            ),
                        ),
                        1,
                    ),
                    KeywordArg("mat2_mm_shape"),
                ),
                KeywordArg("mat2_dtype"),
            ),
            8,
        ),
    ),
    extra_check=cuda_and_enabled_mixed_mm_and_not_int8,
)
def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype):
    # Replace the explicit unpack + mm with a fused autotuned kernel that
    # unpacks the uint4x2-packed mat2 inline.
    return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm(
        mat1, mat2, mat2_mm_shape, mat2_dtype
    )
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
"""
|
| 283 |
+
torch.mm(mat1, mat2.to(mat2_dtype))
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
@register_lowering_pattern(
    CallFunction(
        aten.mm,
        KeywordArg("mat1"),
        CallFunction(
            prims.convert_element_type.default,
            KeywordArg("mat2"),
            KeywordArg("mat2_dtype"),
        ),
    ),
    extra_check=cuda_and_enabled_mixed_mm,
)
def mixed_mm(match: Match, mat1, mat2, mat2_dtype):
    # Lower mm(mat1, mat2.to(mat2_dtype)) to a fused mixed-dtype mm kernel
    # instead of materializing the converted mat2.
    return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
@register_graph_pattern(
    CallFunction(
        aten.cumsum.default,
        CallFunction(
            torch.ops.aten.full.default,
            KeywordArg("shape"),
            KeywordArg("fill_value"),
            dtype=KeywordArg("dtype"),
            layout=Ignored(),
            device=KeywordArg("device"),
            pin_memory=False,
            _users=MULTIPLE,
        ),
        KeywordArg("dim"),
        _users=MULTIPLE,
    ),
    pass_dict=pass_patterns[1],
)
def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim):
    """Based on a pattern in OPTForCausalLM"""

    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
        # cumsum promotes all integral types to int64
        dtype = torch.int64

    def repl(*shape):
        # cumsum of a constant-filled tensor along `dim` is just
        # arange(1..dim_size) * fill_value broadcast to the full shape.
        dim_size = shape[dim]
        idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)

        inter_shape = [1] * len(shape)
        inter_shape[dim] = dim_size
        return (idx * fill_value).view(inter_shape).expand(shape)

    # only replace the output node, not all nodes
    match.nodes = [match.output_node()]
    with V.fake_mode:
        match.replace_by_example(repl, list(shape))
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def shape_of_mm(a, b):
    """Output shape ``[m, n]`` of ``mm(a, b)`` for IR nodes a: [m, k], b: [k, n]."""
    rows, _ = a.get_size()
    _, cols = b.get_size()
    return [rows, cols]
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
@register_lowering_pattern(
    CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()),
)
def cat_mm(match, inputs, dim):
    # cat of several mm outputs: have each mm write directly into its slice of
    # the concatenated buffer (see cat_tuned_op).
    return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@register_lowering_pattern(
    CallFunction(
        aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg()
    ),
)
def cat_addmm(match, inputs, dim):
    # Same as cat_mm but for addmm(bias, a, b); the bias does not affect the
    # output shape, which is [m, n] from the matmul operands.
    def shape_of(bias, a, b):
        m, _ = a.get_size()
        _, n = b.get_size()
        return [m, n]

    return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def cat_tuned_op(match, inputs, dim, *, op, shape_of):
    """
    Memory planning to remove cat. We can't use the stock memory
    planner since autotuning matmuls needs to know the output layout.

    Each matched op (mm/addmm) is made to write directly into its slice of a
    single pre-allocated ConcatKernel buffer, eliminating the copy that a
    real cat would perform.
    """
    if len(inputs) == 1:
        # single input: no concatenation needed, just run the op
        return op(*inputs[0])

    # TODO(jansel): rewrite this as a bmm?
    if dim < 0:
        dim += len(shape_of(*inputs[0]))
    assert dim in (0, 1)
    notdim = 1 - dim

    new_size: Optional[Union[List[Expr], List[int]]] = None
    offsets_start = []
    offsets_end = []

    # compute output sizes
    for i in range(len(inputs)):
        shape = shape_of(*inputs[i])
        if new_size is None:
            new_size = shape
        else:
            # all inputs must agree on the non-concat dimension
            new_size[notdim] = V.graph.sizevars.guard_equals(  # type: ignore[call-overload]
                shape[notdim], new_size[notdim]
            )
            new_size[dim] += shape[dim]
        # running offsets of each input's slice within the output buffer
        offsets_start.append(new_size[dim] - shape[dim])
        offsets_end.append(new_size[dim])

    assert new_size is not None
    # output dtype follows the usual type-promotion of all operand dtypes
    dtype = functools.reduce(
        torch.promote_types,
        [x.get_dtype() for x in itertools.chain.from_iterable(inputs)],
    )
    device = inputs[0][0].get_device()
    kernel = ir.ConcatKernel(
        name=None,
        layout=ir.FixedLayout(device, dtype, new_size),
        inputs=[],
    )
    kernel_tensor = ir.TensorBox.create(kernel)

    for i in range(len(inputs)):
        # run each op with its output layout pinned to the matching slice view
        dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i])
        src = op(*inputs[i], layout=dst.get_layout()).data.data
        assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer))
        src.layout = ir.AliasedLayout(dst)
        kernel.inputs.append(src)

    kernel.name = V.graph.register_buffer(kernel)
    kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs)
    return kernel_tensor
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
# Shared sub-pattern: a cat along dim=1 consumed twice (used both directly and
# through a slice) inside the cat_slice_cat pattern.
_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
@register_lowering_pattern(
    CallFunction(
        aten.cat,
        [
            _cat_1,
            CallFunction(
                aten.slice,
                _cat_1,
                1,
                0,
                KeywordArg("size"),
            ),
        ],
        1,
    )
)
def cat_slice_cat(match, cat_input, size, dim=1):
    """
    This is an example of a more complex pattern where cat_1 is used
    multiple times inside the pattern. We fold 2 calls to cat into one.

    Matches:
        cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1)
        slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
        slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
        cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1)


    Rewrite to:
        slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19)
        cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1)
    """
    first, *rest = cat_input
    # Optimization is optional, because we can just not fold the cat
    # size should be within first.get_size()[dim] such that the optimization is valid.
    # For negative `end`, we currently fallback to not optimizing.
    if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]):
        # fold 2 cats into 1 cat
        return L[aten.cat](
            [
                first,
                *rest,
                L[aten.slice](first, dim, 0, size),
            ],
            dim,
        )
    else:
        # don't expect to hit this case, just fall back
        tmp = L[aten.cat](cat_input, dim)
        return L[aten.cat](
            [
                tmp,
                L[aten.slice](tmp, dim, 0, size),
            ],
            dim,
        )
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def is_valid_splitwithsizes_cat(match):
    """True if the matched split_with_sizes -> getitem* -> cat chain is a pure
    passthrough: same dim, every split piece used, and pieces concatenated in
    their original order."""
    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
    cat_nodes = filter_nodes(match.nodes, aten.cat)
    get_item_nodes = filter_nodes(match.nodes, operator.getitem)
    if len(split_nodes) != 1 or len(cat_nodes) != 1:
        return False
    split_node, cat_node = split_nodes[0], cat_nodes[0]
    # The dim of split and cat should match for passthrough
    if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"):
        return False
    get_item_args = {
        get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes
    }
    assert None not in get_item_args
    split_sizes = get_arg_value(split_node, 1, "split_sizes")
    # All parts of split should be included in the cat
    if get_item_args != set(range(len(split_sizes))):
        return False
    # The order of get_item_args should same with cat_node used.
    # For example, if the split_node like split_with_sizes(input, [2, 2, 3], 1),
    # the cat node should be like cat([get_item(0), get_item(1), get_item(2)], 1).
    cat_items_args_order = [
        get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0)
    ]
    if cat_items_args_order != list(range(len(split_sizes))):
        return False

    return True
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def same_meta(node1: torch.fx.Node, node2: torch.fx.Node):
    """True if two nodes have the same metadata (size, layout, dtype, device,
    and — for strided layouts — strides), judged statically."""
    val1 = node1.meta.get("val")
    val2 = node2.meta.get("val")
    if val1 is None or val2 is None:
        return False
    if not statically_known_true(sym_eq(val1.size(), val2.size())):
        return False
    if (val1.layout, val1.dtype, val1.device) != (val2.layout, val2.dtype, val2.device):
        return False
    # Strides only need to agree when the layout is strided.
    if val1.layout == torch.strided and not statically_known_true(
        sym_eq(val1.stride(), val2.stride())
    ):
        return False
    return True
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# Maps an operator to a (condition_fn, src_arg_index_or_fn) pair registered
# via register_noop_decomp; consumed by remove_noop_ops.
noop_registry: Dict[Any, Any] = {}
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def register_noop_decomp(targets, nop_arg=0):
    """Decorator registering ``targets`` as potential no-op operators.

    The decorated predicate decides, from the op's call args, whether the op
    is an identity. ``nop_arg`` identifies the argument the op is a no-op
    over: either an index into ``args`` or a callable mapping args -> node.
    """

    def register_fun(cond):
        entry = (cond, nop_arg)
        register_decomposition(targets, registry=noop_registry, unsafe=True)(entry)
        return cond

    return register_fun
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
@register_noop_decomp(aten.slice)
def slice_noop(self, dim=0, start=None, end=None, step=1):
    """A slice is a no-op when it covers the whole dimension with step 1."""
    if start is None or end is None:
        return False
    return start == 0 and end >= 2**63 - 1 and step == 1
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
@register_noop_decomp(aten.slice_scatter, 1)
def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1):
    """slice_scatter covering the entire dim with step 1 just returns src."""
    effective_start = 0 if start is None else start
    effective_end = 2**63 - 1 if end is None else end
    return effective_start == 0 and effective_end >= 2**63 - 1 and step == 1
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
@register_noop_decomp(aten.repeat)
def repeat_noop(self, repeats):
    """Repeating every dimension exactly once leaves the tensor unchanged."""
    return not any(r != 1 for r in repeats)
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
@register_noop_decomp(aten.constant_pad_nd)
def constant_pad_nd(x, padding, fill_value=0):
    """Padding by zero on every side is a no-op regardless of fill value."""
    for pad_amount in padding:
        if pad_amount != 0:
            return False
    return True
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
@register_noop_decomp(torch.ops.prims.convert_element_type)
def convert_element_type_noop(x, dtype: torch.dtype):
    # converting to the dtype the tensor already has is an identity
    return x.dtype == dtype
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
@register_noop_decomp(torch.ops.prims.device_put)
def device_put_noop(x, device):
    # moving a tensor to the device it already lives on is an identity
    return x.device == decode_device(device)
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc])
def int_noop(x):
    # rounding ops are identities on integer dtypes
    return is_integer_dtype(x.dtype)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
@register_noop_decomp([aten.pow])
def pow_noop(a, b):
    """``x ** 1`` is an identity (only when the exponent is a literal int)."""
    return b == 1 if isinstance(b, int) else False
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
# nop_arg is a callable here: the "source" of a single-input cat is the one
# tensor inside the list argument.
@register_noop_decomp([aten.cat], lambda args: args[0][0])
def cat_noop(inputs, dim=0):
    # cat of exactly one tensor is that tensor
    return len(inputs) == 1
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
@register_noop_decomp(aten.view)
def view_noop(arg, size):
    # view to the tensor's existing shape is an identity
    return arg.shape == size
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
# Note, we also always have a check for identical metadata, which is why these
# are safe
@register_noop_decomp([aten.copy], nop_arg=1)
@register_noop_decomp([aten.alias, aten.clone])
def true_noop(*args, **kwargs):
    # copy/alias/clone are unconditionally removable here because
    # remove_noop_ops only acts when same_meta(node, src) already holds
    return True
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def remove_noop_ops(graph: torch.fx.Graph):
    """
    Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph.

    An op registered in noop_registry is replaced by its source argument when
    (a) its registered condition holds, (b) the metadata matches exactly, and
    (c) the replacement would not create new input/output aliasing.
    """
    inputs = set()
    input_storages = set()
    output_storages = set()

    for node in graph.nodes:
        if node.op == "placeholder":
            inputs.add(node)
            input_storages.add(get_node_storage(node))
        else:
            # placeholders are grouped at the front of the graph
            break

    output_node = next(iter(reversed(graph.nodes)))
    assert output_node.op == "output"
    for out in output_node.args[0]:
        if isinstance(out, torch.fx.Node):
            output_storages.add(get_node_storage(out))

    for node in graph.nodes:
        if node.target in noop_registry:
            cond, src_index = noop_registry[node.target]
            # src_index is either an index into node.args or a callable
            # selecting the source node from the args
            if isinstance(src_index, int):
                src = node.args[src_index]
            else:
                src = src_index(node.args)
            if not isinstance(src, torch.fx.Node):
                continue
            # Don't introduce new aliasing between inputs and outputs.
            # See fx_passes/README.md for a discussion of why this is
            # necessary.
            node_storage = get_node_storage(node)
            src_storage = get_node_storage(src)
            node_is_view = node_storage == src_storage
            if (
                not node_is_view
                and node_storage in output_storages
                and (src_storage in input_storages or src_storage in output_storages)
            ):
                continue

            # Even if input and outputs are expected to alias,
            # don't make "node is src" True
            if (
                node_is_view
                and node in output_node.args
                and (src in inputs or src in output_node.args)
            ):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            if same_meta(node, src) and cond(*args, **kwargs):
                node.replace_all_uses_with(src)
                graph.erase_node(node)
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def decompose_auto_functionalized(graph):
    """Replace every higher_order.auto_functionalized node with its dense
    (mutating) decomposition, then assert that none remain."""
    graph_pass = PatternMatcherPass()

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized),
        pass_dict=graph_pass,
    )
    def replacement(match: Match, *args, **kwargs):
        from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense

        only_clone_these_tensors = tuple(
            match.nodes[0].meta.get("only_clone_these_tensors", [])
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: we combine (args, kwargs) into flat args for replacing.
        # This is replace_by_example uses make_fx which does not support
        # tracing a function with kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)

        with V.fake_mode:
            match.replace_by_example(decomp, flat_args, run_dce=False)

    graph_pass.apply(graph)
    # sanity check: the pass must have eliminated every auto_functionalized op
    for node in graph.nodes:
        if node.target is torch.ops.higher_order.auto_functionalized:
            raise AssertionError("auto_functionalized was not removed")
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
@register_lowering_pattern(
    CallFunction(
        aten.cat,
        ListOf(
            CallFunction(
                operator.getitem,
                CallFunction(
                    aten.split_with_sizes,
                    KeywordArg("input_"),
                    Ignored(),
                    Ignored(),
                    _users=MULTIPLE,
                ),
                Ignored(),
            ),
        ),
        Ignored(),
    ),
    pass_number=2,
    extra_check=is_valid_splitwithsizes_cat,
)
def splitwithsizes_cat_replace(match, input_):
    # cat over getitems of split_with_sizes that reassembles every piece in
    # order (verified by extra_check) is just the original input
    return input_
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def is_valid_cat_splitwithsizes(match):
    """True if the matched cat -> split_with_sizes is a pure passthrough:
    same dim, cat only used by the split, and each split size equals the
    corresponding cat input's size along that dim."""
    cat_nodes = filter_nodes(match.nodes, aten.cat)
    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
    if len(split_nodes) != 1 or len(cat_nodes) != 1:
        return False
    split_node, cat_node = split_nodes[0], cat_nodes[0]

    # the cat node has other users: can't eliminate
    if len(cat_node.users) > 1:
        return False

    # the dim of the cat and split should match
    dim = get_arg_value(split_node, 2, "dim")
    if dim != get_arg_value(cat_node, 1, "dim"):
        return False

    cat_inputs = list(get_arg_value(cat_node, 0))
    split_sizes = get_arg_value(split_node, 1, "split_sizes")
    # the number of input tensors in cat and the
    # length of the split sizes should match
    if len(cat_inputs) != len(split_sizes):
        return False

    for cat_input, split_size in zip(cat_inputs, split_sizes):
        # each cat input tensor's size along dim
        # should match the corresponding split size
        if "val" not in cat_input.meta:
            return False
        cat_input_size = cat_input.meta["val"].size(dim)
        if cat_input_size != split_size:
            return False

    return True
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
@register_lowering_pattern(
    CallFunction(
        aten.split_with_sizes,
        CallFunction(
            aten.cat,
            KeywordArg("input_"),
            Ignored(),
            _users=MULTIPLE,
        ),
        Ignored(),
        Ignored(),
    ),
    pass_number=2,
    extra_check=is_valid_cat_splitwithsizes,
)
def cat_splitwithsizes_replace(match, input_):
    # split_with_sizes that undoes a cat exactly (verified by extra_check)
    # yields the original list of cat inputs
    return input_
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
def view_to_reshape(gm):
    """
    Replace view ops in the GraphModule to reshape ops.
    """
    view_target = torch.ops.aten.view.default
    reshape_target = torch.ops.aten.reshape.default
    for node in gm.graph.nodes:
        if node.target == view_target:
            node.target = reshape_target
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
def should_prefer_unfused_addmm(match):
    """Return True when the matched addmm's input is CUDA and every use of the
    output is pointwise (so the bias add is kept separate from the mm)."""
    inp = match.kwargs["inp"]
    if not inp.meta["val"].is_cuda:
        return False

    out = match.output_node()
    return all(is_pointwise_use(use) for use in out.users)
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
@register_graph_pattern(
    CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
    pass_dict=pass_patterns[2],
    extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
    # Split addmm(inp, mat1, mat2) into mm + add; extra_check ensures every
    # user of the output is pointwise and the input is on CUDA.
    def repl(inp, x1, x2):
        return x1 @ x2 + inp

    with V.fake_mode:
        match.replace_by_example(repl, [inp, mat1, mat2])
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
def is_valid_addmm_fusion(match):
    """Check whether an add(mm(mat1, mat2), inp) match may become addmm:
    inp must be a tensor node, broadcastable to the mm output shape, and the
    unfused form must not be preferable."""
    mat1, mat2 = match.args
    inp = match.kwargs["inp"]

    if not (
        isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor)
    ):
        return False  # Input is a number

    mm_shape = (mat1.meta["val"].shape[0], mat2.meta["val"].shape[1])
    if not is_expandable_to(inp.meta["val"].shape, mm_shape):
        return False  # Shape mismatch

    return not should_prefer_unfused_addmm(match)
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
# Two registrations cover both operand orders of the commutative add.
@register_graph_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.mm, Arg(), Arg()),
        KeywordArg("inp"),
    ),
    pass_dict=pass_patterns[2],
    extra_check=is_valid_addmm_fusion,
)
@register_graph_pattern(
    CallFunction(
        aten.add,
        KeywordArg("inp"),
        CallFunction(aten.mm, Arg(), Arg()),
    ),
    pass_dict=pass_patterns[2],
    extra_check=is_valid_addmm_fusion,
)
def addmm(match, mat1, mat2, *, inp):
    # Fuse add(mm(mat1, mat2), inp) into a single aten.addmm call.
    def repl(inp, mat1, mat2):
        return aten.addmm(inp, mat1, mat2)

    with V.fake_mode:
        match.replace_by_example(repl, [inp, mat1, mat2])
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
def check_shape_cuda_and_fused_int_mm_mul_enabled(match):
    """Gate for the fused _int_mm + mul lowering: the config flag must be on
    and the mul operand (match.args[2]) must be a 2-D CUDA tensor."""
    if not config.force_fuse_int_mm_with_mul:
        return False
    mul_operand_val = match.args[2].meta.get("val")
    if len(getattr(mul_operand_val, "shape", [])) != 2:
        return False
    return getattr(mul_operand_val, "is_cuda", False)
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
# Two registrations: _int_mm * mat3 with and without a trailing dtype cast
# (the cast's target dtype is captured as out_dtype in the first form).
@register_lowering_pattern(
    CallFunction(
        prims.convert_element_type.default,
        CallFunction(
            aten.mul,
            CallFunction(
                aten._int_mm,
                Arg(),
                Arg(),
            ),
            Arg(),
        ),
        Arg(),
    ),
    check_shape_cuda_and_fused_int_mm_mul_enabled,
)
@register_lowering_pattern(
    CallFunction(
        aten.mul,
        CallFunction(
            aten._int_mm,
            Arg(),
            Arg(),
        ),
        Arg(),
    ),
    check_shape_cuda_and_fused_int_mm_mul_enabled,
)
def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None):
    # Fuse _int_mm(mat1, mat2) * mat3 (plus optional cast) into one kernel.
    return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
class ConstructorMoverPass:
|
| 902 |
+
def __init__(self, target: str, allow_outputs: bool = False) -> None:
|
| 903 |
+
"""
|
| 904 |
+
Move constructors from cpu to the target_device.
|
| 905 |
+
|
| 906 |
+
Sweeps through the module, looking for constructor nodes that can be moved
|
| 907 |
+
to the target_device.
|
| 908 |
+
|
| 909 |
+
A constructor node can be moved to the target_device iff all of its users
|
| 910 |
+
can also be moved (tested by cannot_be_moved). Otherwise, all dependent
|
| 911 |
+
constructor nodes won't be moved.
|
| 912 |
+
|
| 913 |
+
- target: target device type
|
| 914 |
+
- allow_outputs: allow outputs to be moved
|
| 915 |
+
"""
|
| 916 |
+
|
| 917 |
+
self.target = target
|
| 918 |
+
self.allow_outputs = allow_outputs
|
| 919 |
+
|
| 920 |
+
assert isinstance(target, str), (
|
| 921 |
+
"target should be a string representing the device type. "
|
| 922 |
+
f"Got: {type(target).__name__}"
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
def allow_cpu_device(self, node: fx.Node) -> bool:
|
| 926 |
+
"""
|
| 927 |
+
Returns whether a node that returns a tensor on the target device may have
|
| 928 |
+
cpu tensors as input.
|
| 929 |
+
"""
|
| 930 |
+
return node.target in (
|
| 931 |
+
torch.ops.aten.index.Tensor,
|
| 932 |
+
torch.ops.aten.index_put.default,
|
| 933 |
+
torch.ops.aten.index_put_.default,
|
| 934 |
+
torch.ops.aten.copy.default,
|
| 935 |
+
torch.ops.aten.copy_.default,
|
| 936 |
+
torch.ops.aten.slice_scatter.default,
|
| 937 |
+
)
|
| 938 |
+
|
| 939 |
+
def cannot_be_moved(self, node: fx.Node) -> bool:
|
| 940 |
+
"""
|
| 941 |
+
Returns whether a node can be moved to the target device.
|
| 942 |
+
|
| 943 |
+
If this function returns False, it means that this node and all of its users
|
| 944 |
+
won't be moved into the target device.
|
| 945 |
+
"""
|
| 946 |
+
if node.target == "output":
|
| 947 |
+
return not self.allow_outputs
|
| 948 |
+
|
| 949 |
+
if not (
|
| 950 |
+
isinstance(node.target, torch._ops.OpOverload)
|
| 951 |
+
and node.target.namespace in ("prims", "aten")
|
| 952 |
+
):
|
| 953 |
+
return True
|
| 954 |
+
|
| 955 |
+
return False
|
| 956 |
+
|
| 957 |
+
def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
|
| 958 |
+
"""
|
| 959 |
+
Get the device of a node.
|
| 960 |
+
"""
|
| 961 |
+
ten = node.meta.get("val")
|
| 962 |
+
return None if not isinstance(ten, torch.Tensor) else ten.device
|
| 963 |
+
|
| 964 |
+
def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]:
|
| 965 |
+
"""
|
| 966 |
+
Get the number of cpu inputs to a node
|
| 967 |
+
"""
|
| 968 |
+
cpu_indeg: Dict[fx.Node, int] = Counter()
|
| 969 |
+
|
| 970 |
+
for node in graph.nodes:
|
| 971 |
+
cpu_count = 0
|
| 972 |
+
|
| 973 |
+
def add_cpu_inp(node):
|
| 974 |
+
nonlocal cpu_count
|
| 975 |
+
device = self.get_node_device(node)
|
| 976 |
+
cpu_count += device is not None and device.type == "cpu"
|
| 977 |
+
|
| 978 |
+
pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))
|
| 979 |
+
|
| 980 |
+
if cpu_count:
|
| 981 |
+
cpu_indeg[node] = cpu_count
|
| 982 |
+
|
| 983 |
+
return cpu_indeg
|
| 984 |
+
|
| 985 |
+
def __call__(self, graph: fx.Graph) -> None:
|
| 986 |
+
target_devices = set()
|
| 987 |
+
constructors = []
|
| 988 |
+
|
| 989 |
+
for node in graph.nodes:
|
| 990 |
+
device = self.get_node_device(node)
|
| 991 |
+
if device and device.type == self.target:
|
| 992 |
+
target_devices.add(device)
|
| 993 |
+
|
| 994 |
+
if not (
|
| 995 |
+
isinstance(node.target, torch._ops.OpOverload)
|
| 996 |
+
and node.target.namespace in ("prims", "aten")
|
| 997 |
+
):
|
| 998 |
+
continue
|
| 999 |
+
|
| 1000 |
+
if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
|
| 1001 |
+
continue
|
| 1002 |
+
|
| 1003 |
+
if not node.kwargs.get("device") == torch.device("cpu"):
|
| 1004 |
+
continue
|
| 1005 |
+
|
| 1006 |
+
constructors.append(node)
|
| 1007 |
+
|
| 1008 |
+
# not handling multiple target devices initially
|
| 1009 |
+
if not constructors or len(target_devices) != 1:
|
| 1010 |
+
return
|
| 1011 |
+
|
| 1012 |
+
movable_constructors = self.find_movable_constructors(graph, constructors)
|
| 1013 |
+
|
| 1014 |
+
for node in movable_constructors:
|
| 1015 |
+
kwargs = node.kwargs.copy()
|
| 1016 |
+
kwargs["device"] = next(iter(target_devices))
|
| 1017 |
+
node.kwargs = kwargs
|
| 1018 |
+
|
| 1019 |
+
def find_movable_constructors(
|
| 1020 |
+
self, graph: fx.Graph, constructors: List[fx.Node]
|
| 1021 |
+
) -> Set[fx.Node]:
|
| 1022 |
+
"""
|
| 1023 |
+
Starting from the cpu constructors, iterate through the graph and test that all of their
|
| 1024 |
+
downstream uses can safely be moved to cpu.
|
| 1025 |
+
"""
|
| 1026 |
+
cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph)
|
| 1027 |
+
|
| 1028 |
+
# which constructors cannot be moved to cuda
|
| 1029 |
+
cannot_move_to_cuda: Set[fx.Node] = set()
|
| 1030 |
+
|
| 1031 |
+
# For any node in the graph, which constructors does it have a dependency on
|
| 1032 |
+
constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)
|
| 1033 |
+
|
| 1034 |
+
# if a cpu node has a dependency on two different cpu constructors,
|
| 1035 |
+
# then if either constructor cannot be moved to cuda, the other cannot as well.
|
| 1036 |
+
# In this case any node with a dependency on one will have a dependency on the other
|
| 1037 |
+
equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {
|
| 1038 |
+
c: {c} for c in constructors
|
| 1039 |
+
}
|
| 1040 |
+
|
| 1041 |
+
def make_dependencies_equivalent(
|
| 1042 |
+
set1: Set[fx.Node], set2: Set[fx.Node]
|
| 1043 |
+
) -> Set[fx.Node]:
|
| 1044 |
+
# could use union find but not worth complexity here
|
| 1045 |
+
set1.update(set2)
|
| 1046 |
+
for obj in set1:
|
| 1047 |
+
equal_constructor_sets[obj] = set1
|
| 1048 |
+
return set1
|
| 1049 |
+
|
| 1050 |
+
queue: List[fx.Node] = list(constructors)
|
| 1051 |
+
|
| 1052 |
+
for c in queue:
|
| 1053 |
+
constructor_dependencies[c].add(c)
|
| 1054 |
+
|
| 1055 |
+
while queue:
|
| 1056 |
+
node = queue.pop()
|
| 1057 |
+
dependencies = constructor_dependencies[node]
|
| 1058 |
+
|
| 1059 |
+
for user in node.users:
|
| 1060 |
+
if self.cannot_be_moved(user):
|
| 1061 |
+
cannot_move_to_cuda.update(dependencies)
|
| 1062 |
+
break
|
| 1063 |
+
|
| 1064 |
+
# this node was used on a op which takes in multiple devices and output a cuda
|
| 1065 |
+
# tensor. we can convert its cpu input to cuda without making further changes
|
| 1066 |
+
node_device = self.get_node_device(user)
|
| 1067 |
+
if (
|
| 1068 |
+
self.allow_cpu_device(user)
|
| 1069 |
+
and node_device
|
| 1070 |
+
and node_device.type == self.target
|
| 1071 |
+
):
|
| 1072 |
+
del cpu_indeg[user]
|
| 1073 |
+
else:
|
| 1074 |
+
# otherwise, we should continue look at its downstream uses
|
| 1075 |
+
cpu_indeg[user] -= 1
|
| 1076 |
+
if cpu_indeg[user] == 0:
|
| 1077 |
+
del cpu_indeg[user]
|
| 1078 |
+
queue.append(user)
|
| 1079 |
+
|
| 1080 |
+
unioned_set = make_dependencies_equivalent(
|
| 1081 |
+
dependencies, constructor_dependencies[user]
|
| 1082 |
+
)
|
| 1083 |
+
constructor_dependencies[user] = unioned_set
|
| 1084 |
+
|
| 1085 |
+
for node in cpu_indeg:
|
| 1086 |
+
if constructor_dependencies[node]:
|
| 1087 |
+
cannot_move_to_cuda.update(constructor_dependencies[node])
|
| 1088 |
+
|
| 1089 |
+
all_cannot_move_to_cuda = cannot_move_to_cuda.copy()
|
| 1090 |
+
for constructor in cannot_move_to_cuda:
|
| 1091 |
+
all_cannot_move_to_cuda.update(equal_constructor_sets[constructor])
|
| 1092 |
+
|
| 1093 |
+
return set(constructors) - all_cannot_move_to_cuda
|
| 1094 |
+
|
| 1095 |
+
|
| 1096 |
+
def move_constructors_to_cuda(graph: fx.Graph) -> None:
|
| 1097 |
+
"""
|
| 1098 |
+
Moves intermediary tensors which are constructed on the cpu to cuda when safe
|
| 1099 |
+
"""
|
| 1100 |
+
ConstructorMoverPass("cuda")(graph)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 35 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 36 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 37 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 38 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 39 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 40 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 41 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
|
| 42 |
+
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
| 43 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
| 44 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 45 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 46 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 47 |
+
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
| 48 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 49 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 50 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 51 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 52 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 53 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 54 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 55 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 56 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 57 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 58 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 59 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 60 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 61 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
|
| 62 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 63 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 64 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
|
| 65 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
|
| 66 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 67 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 68 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 69 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 70 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 71 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 72 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 73 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 74 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 75 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 76 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 77 |
+
_sfdp_pattern_1_training = MultiOutputPattern([view_default_5,
|
| 78 |
+
view_default_9,
|
| 79 |
+
permute_default_4,
|
| 80 |
+
view_default_11,
|
| 81 |
+
None
|
| 82 |
+
])
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 86 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 87 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 88 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 89 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 90 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 91 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 92 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
|
| 93 |
+
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
| 94 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
| 95 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 96 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 97 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 98 |
+
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
| 99 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 100 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 101 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 102 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 103 |
+
_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 107 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 108 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 109 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 110 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 111 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 112 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 113 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 114 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 115 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 116 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 117 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 118 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 119 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 120 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
|
| 121 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 122 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 123 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 124 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 125 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 126 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 127 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 128 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 129 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 130 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 131 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 132 |
+
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
| 133 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 134 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 135 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
| 136 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
| 137 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
|
| 138 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 139 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
|
| 140 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
|
| 141 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 142 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
|
| 143 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 144 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 145 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 146 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 147 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 148 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 149 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 150 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 151 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 152 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 153 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 154 |
+
_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5,
|
| 155 |
+
view_default_9,
|
| 156 |
+
permute_default_4,
|
| 157 |
+
view_default_11,
|
| 158 |
+
None
|
| 159 |
+
])
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 163 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 164 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 165 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 166 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 167 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 168 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 169 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 170 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 171 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 172 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 173 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 174 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 175 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 176 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 177 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 178 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 179 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 180 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 181 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 182 |
+
_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 35 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 36 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 37 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 38 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 39 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 40 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 41 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 42 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 43 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
|
| 44 |
+
amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
|
| 45 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
|
| 46 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 47 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 48 |
+
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 49 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
|
| 50 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
|
| 51 |
+
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
|
| 52 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 53 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 54 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 55 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 56 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 57 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 58 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 59 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 60 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 61 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 62 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
| 63 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
|
| 64 |
+
clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
|
| 65 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor)
|
| 66 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 67 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 68 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 69 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
| 70 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
|
| 71 |
+
mul_Tensor_6 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 72 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6)
|
| 73 |
+
mul_Tensor_7 = CallFunction(aten.mul.Tensor, sub_Tensor_1, KeywordArg('scale_factor'))
|
| 74 |
+
view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2)
|
| 75 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 76 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 77 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 78 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 79 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 80 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 81 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 82 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 83 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 84 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 85 |
+
_sfdp_pattern_4_training = MultiOutputPattern([view_default_5,
|
| 86 |
+
view_default_9,
|
| 87 |
+
permute_default_4,
|
| 88 |
+
view_default_11,
|
| 89 |
+
None,
|
| 90 |
+
None
|
| 91 |
+
])
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 95 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 96 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 97 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 98 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 99 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 100 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 101 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
|
| 102 |
+
amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
|
| 103 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
|
| 104 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 105 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 106 |
+
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 107 |
+
clone_default = CallFunction(aten.clone.default, div_Tensor)
|
| 108 |
+
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
| 109 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 110 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 111 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 112 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 113 |
+
_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 117 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 118 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 119 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 120 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 121 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 122 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 123 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 124 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 125 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
|
| 126 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
|
| 127 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 128 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 129 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 130 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 131 |
+
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 132 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 133 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
|
| 134 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
|
| 135 |
+
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
|
| 136 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 137 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 138 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 139 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 140 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 141 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 142 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 143 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 144 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 145 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 146 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
|
| 147 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
|
| 148 |
+
clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
|
| 149 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
| 150 |
+
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
| 151 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 152 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 153 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
| 154 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
| 155 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
|
| 156 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
|
| 157 |
+
mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
|
| 158 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6)
|
| 159 |
+
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 160 |
+
mul_Tensor_7 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor'))
|
| 161 |
+
view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2)
|
| 162 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 163 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 164 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 165 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 166 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 167 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 168 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 169 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 170 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 171 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 172 |
+
_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5,
|
| 173 |
+
view_default_9,
|
| 174 |
+
permute_default_4,
|
| 175 |
+
view_default_11,
|
| 176 |
+
None,
|
| 177 |
+
None
|
| 178 |
+
])
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 182 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 183 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 184 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 185 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 186 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 187 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 188 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
|
| 189 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
|
| 190 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 191 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 192 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 193 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 194 |
+
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 195 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
| 196 |
+
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
| 197 |
+
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
| 198 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 199 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 200 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 201 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 202 |
+
_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 35 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 36 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 37 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 38 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 39 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 40 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 41 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 42 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
|
| 43 |
+
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
| 44 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
| 45 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 46 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 47 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 48 |
+
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
| 49 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 50 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 51 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 52 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 53 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 54 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 55 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 56 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 57 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 58 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 59 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 60 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 61 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 62 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
|
| 63 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 64 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 65 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
|
| 66 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
|
| 67 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 68 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 69 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 70 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 71 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 72 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 73 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 74 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 75 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 76 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 77 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 78 |
+
_sfdp_pattern_5_training = MultiOutputPattern([view_default_5,
|
| 79 |
+
view_default_9,
|
| 80 |
+
permute_default_4,
|
| 81 |
+
view_default_11,
|
| 82 |
+
None
|
| 83 |
+
])
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 87 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 88 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 89 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 90 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 91 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 92 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 93 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 94 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
|
| 95 |
+
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
| 96 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
| 97 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 98 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 99 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 100 |
+
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
| 101 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 102 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 103 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 104 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 105 |
+
_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 109 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 110 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 111 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 112 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 113 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 114 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 115 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 116 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
|
| 117 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
|
| 118 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 119 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 120 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 121 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 122 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 123 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
|
| 124 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 125 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 126 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 127 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 128 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 129 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 130 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 131 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 132 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 133 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 134 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 135 |
+
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
| 136 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 137 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 138 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
| 139 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
| 140 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
|
| 141 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 142 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
|
| 143 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
|
| 144 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 145 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
|
| 146 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 147 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 148 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 149 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 150 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 151 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 152 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 153 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 154 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 155 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 156 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 157 |
+
_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5,
|
| 158 |
+
view_default_9,
|
| 159 |
+
permute_default_4,
|
| 160 |
+
view_default_11,
|
| 161 |
+
None
|
| 162 |
+
])
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 166 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 167 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 168 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 169 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 170 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 171 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 172 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
|
| 173 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
|
| 174 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
|
| 175 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 176 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 177 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 178 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 179 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 180 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 181 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 182 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 183 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 184 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 185 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 186 |
+
_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 4 |
+
# To re-generate, run:
|
| 5 |
+
# cd ~/pytorch && python
|
| 6 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 7 |
+
from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference, _sfdp_pattern_1_half_training, _sfdp_pattern_1_half_inference)
|
| 8 |
+
from ._sfdp_pattern_2 import (_sfdp_pattern_2_training, _sfdp_pattern_2_inference, _sfdp_pattern_2_half_training, _sfdp_pattern_2_half_inference)
|
| 9 |
+
from ._sfdp_pattern_3 import (_sfdp_pattern_3_training, _sfdp_pattern_3_inference, _sfdp_pattern_3_half_training, _sfdp_pattern_3_half_inference)
|
| 10 |
+
from ._sfdp_pattern_4 import (_sfdp_pattern_4_training, _sfdp_pattern_4_inference, _sfdp_pattern_4_half_training, _sfdp_pattern_4_half_inference)
|
| 11 |
+
from ._sfdp_pattern_5 import (_sfdp_pattern_5_training, _sfdp_pattern_5_inference, _sfdp_pattern_5_half_training, _sfdp_pattern_5_half_inference)
|
| 12 |
+
from ._sfdp_pattern_6 import (_sfdp_pattern_6_training, _sfdp_pattern_6_inference, _sfdp_pattern_6_half_training, _sfdp_pattern_6_half_inference)
|
| 13 |
+
from ._sfdp_pattern_7 import (_sfdp_pattern_7_training, _sfdp_pattern_7_inference, _sfdp_pattern_7_half_training, _sfdp_pattern_7_half_inference)
|
| 14 |
+
from ._sfdp_pattern_8 import (_sfdp_pattern_8_training, _sfdp_pattern_8_inference, _sfdp_pattern_8_half_training, _sfdp_pattern_8_half_inference)
|
| 15 |
+
from ._sfdp_pattern_9 import (_sfdp_pattern_9_training, _sfdp_pattern_9_inference, _sfdp_pattern_9_half_training, _sfdp_pattern_9_half_inference)
|
| 16 |
+
from ._sfdp_pattern_10 import (_sfdp_pattern_10_training, _sfdp_pattern_10_inference, _sfdp_pattern_10_half_training, _sfdp_pattern_10_half_inference)
|
| 17 |
+
from ._sfdp_pattern_11 import (_sfdp_pattern_11_training, _sfdp_pattern_11_inference, _sfdp_pattern_11_half_training, _sfdp_pattern_11_half_inference)
|
| 18 |
+
from ._sfdp_pattern_12 import (_sfdp_pattern_12_training, _sfdp_pattern_12_inference, _sfdp_pattern_12_half_training, _sfdp_pattern_12_half_inference)
|
| 19 |
+
from ._sfdp_pattern_13 import (_sfdp_pattern_13_training, _sfdp_pattern_13_inference, _sfdp_pattern_13_half_training, _sfdp_pattern_13_half_inference)
|
| 20 |
+
from ._sfdp_pattern_14 import (_sfdp_pattern_14_training, _sfdp_pattern_14_inference, _sfdp_pattern_14_half_training, _sfdp_pattern_14_half_inference)
|
| 21 |
+
from ._sfdp_pattern_15 import (_sfdp_pattern_15_training, _sfdp_pattern_15_inference, _sfdp_pattern_15_half_training, _sfdp_pattern_15_half_inference)
|
| 22 |
+
from ._sfdp_pattern_16 import (_sfdp_pattern_16_training, _sfdp_pattern_16_inference, _sfdp_pattern_16_bs1_training, _sfdp_pattern_16_bs1_inference, _sfdp_pattern_16_half_training, _sfdp_pattern_16_half_inference, _sfdp_pattern_16_half_bs1_training, _sfdp_pattern_16_half_bs1_inference, _sfdp_pattern_16_half_mask_fp32_training, _sfdp_pattern_16_half_mask_fp32_inference, _sfdp_pattern_16_half_mask_fp32_bs1_training, _sfdp_pattern_16_half_mask_fp32_bs1_inference)
|
| 23 |
+
from ._sfdp_pattern_17 import (_sfdp_pattern_17_training, _sfdp_pattern_17_inference, _sfdp_pattern_17_half_training, _sfdp_pattern_17_half_inference)
|
| 24 |
+
|
| 25 |
+
central_index = {
|
| 26 |
+
'_sfdp_pattern_1_training': _sfdp_pattern_1_training,
|
| 27 |
+
'_sfdp_pattern_1_inference': _sfdp_pattern_1_inference,
|
| 28 |
+
'_sfdp_pattern_2_training': _sfdp_pattern_2_training,
|
| 29 |
+
'_sfdp_pattern_2_inference': _sfdp_pattern_2_inference,
|
| 30 |
+
'_sfdp_pattern_3_training': _sfdp_pattern_3_training,
|
| 31 |
+
'_sfdp_pattern_3_inference': _sfdp_pattern_3_inference,
|
| 32 |
+
'_sfdp_pattern_4_training': _sfdp_pattern_4_training,
|
| 33 |
+
'_sfdp_pattern_4_inference': _sfdp_pattern_4_inference,
|
| 34 |
+
'_sfdp_pattern_5_training': _sfdp_pattern_5_training,
|
| 35 |
+
'_sfdp_pattern_5_inference': _sfdp_pattern_5_inference,
|
| 36 |
+
'_sfdp_pattern_6_training': _sfdp_pattern_6_training,
|
| 37 |
+
'_sfdp_pattern_6_inference': _sfdp_pattern_6_inference,
|
| 38 |
+
'_sfdp_pattern_7_training': _sfdp_pattern_7_training,
|
| 39 |
+
'_sfdp_pattern_7_inference': _sfdp_pattern_7_inference,
|
| 40 |
+
'_sfdp_pattern_8_training': _sfdp_pattern_8_training,
|
| 41 |
+
'_sfdp_pattern_8_inference': _sfdp_pattern_8_inference,
|
| 42 |
+
'_sfdp_pattern_9_training': _sfdp_pattern_9_training,
|
| 43 |
+
'_sfdp_pattern_9_inference': _sfdp_pattern_9_inference,
|
| 44 |
+
'_sfdp_pattern_10_training': _sfdp_pattern_10_training,
|
| 45 |
+
'_sfdp_pattern_10_inference': _sfdp_pattern_10_inference,
|
| 46 |
+
'_sfdp_pattern_11_training': _sfdp_pattern_11_training,
|
| 47 |
+
'_sfdp_pattern_11_inference': _sfdp_pattern_11_inference,
|
| 48 |
+
'_sfdp_pattern_12_training': _sfdp_pattern_12_training,
|
| 49 |
+
'_sfdp_pattern_12_inference': _sfdp_pattern_12_inference,
|
| 50 |
+
'_sfdp_pattern_13_training': _sfdp_pattern_13_training,
|
| 51 |
+
'_sfdp_pattern_13_inference': _sfdp_pattern_13_inference,
|
| 52 |
+
'_sfdp_pattern_14_training': _sfdp_pattern_14_training,
|
| 53 |
+
'_sfdp_pattern_14_inference': _sfdp_pattern_14_inference,
|
| 54 |
+
'_sfdp_pattern_15_training': _sfdp_pattern_15_training,
|
| 55 |
+
'_sfdp_pattern_15_inference': _sfdp_pattern_15_inference,
|
| 56 |
+
'_sfdp_pattern_16_training': _sfdp_pattern_16_training,
|
| 57 |
+
'_sfdp_pattern_16_inference': _sfdp_pattern_16_inference,
|
| 58 |
+
'_sfdp_pattern_16_bs1_training': _sfdp_pattern_16_bs1_training,
|
| 59 |
+
'_sfdp_pattern_16_bs1_inference': _sfdp_pattern_16_bs1_inference,
|
| 60 |
+
'_sfdp_pattern_17_training': _sfdp_pattern_17_training,
|
| 61 |
+
'_sfdp_pattern_17_inference': _sfdp_pattern_17_inference,
|
| 62 |
+
'_sfdp_pattern_1_half_training': _sfdp_pattern_1_half_training,
|
| 63 |
+
'_sfdp_pattern_1_half_inference': _sfdp_pattern_1_half_inference,
|
| 64 |
+
'_sfdp_pattern_2_half_training': _sfdp_pattern_2_half_training,
|
| 65 |
+
'_sfdp_pattern_2_half_inference': _sfdp_pattern_2_half_inference,
|
| 66 |
+
'_sfdp_pattern_3_half_training': _sfdp_pattern_3_half_training,
|
| 67 |
+
'_sfdp_pattern_3_half_inference': _sfdp_pattern_3_half_inference,
|
| 68 |
+
'_sfdp_pattern_4_half_training': _sfdp_pattern_4_half_training,
|
| 69 |
+
'_sfdp_pattern_4_half_inference': _sfdp_pattern_4_half_inference,
|
| 70 |
+
'_sfdp_pattern_5_half_training': _sfdp_pattern_5_half_training,
|
| 71 |
+
'_sfdp_pattern_5_half_inference': _sfdp_pattern_5_half_inference,
|
| 72 |
+
'_sfdp_pattern_6_half_training': _sfdp_pattern_6_half_training,
|
| 73 |
+
'_sfdp_pattern_6_half_inference': _sfdp_pattern_6_half_inference,
|
| 74 |
+
'_sfdp_pattern_7_half_training': _sfdp_pattern_7_half_training,
|
| 75 |
+
'_sfdp_pattern_7_half_inference': _sfdp_pattern_7_half_inference,
|
| 76 |
+
'_sfdp_pattern_8_half_training': _sfdp_pattern_8_half_training,
|
| 77 |
+
'_sfdp_pattern_8_half_inference': _sfdp_pattern_8_half_inference,
|
| 78 |
+
'_sfdp_pattern_9_half_training': _sfdp_pattern_9_half_training,
|
| 79 |
+
'_sfdp_pattern_9_half_inference': _sfdp_pattern_9_half_inference,
|
| 80 |
+
'_sfdp_pattern_10_half_training': _sfdp_pattern_10_half_training,
|
| 81 |
+
'_sfdp_pattern_10_half_inference': _sfdp_pattern_10_half_inference,
|
| 82 |
+
'_sfdp_pattern_11_half_training': _sfdp_pattern_11_half_training,
|
| 83 |
+
'_sfdp_pattern_11_half_inference': _sfdp_pattern_11_half_inference,
|
| 84 |
+
'_sfdp_pattern_12_half_training': _sfdp_pattern_12_half_training,
|
| 85 |
+
'_sfdp_pattern_12_half_inference': _sfdp_pattern_12_half_inference,
|
| 86 |
+
'_sfdp_pattern_13_half_training': _sfdp_pattern_13_half_training,
|
| 87 |
+
'_sfdp_pattern_13_half_inference': _sfdp_pattern_13_half_inference,
|
| 88 |
+
'_sfdp_pattern_14_half_training': _sfdp_pattern_14_half_training,
|
| 89 |
+
'_sfdp_pattern_14_half_inference': _sfdp_pattern_14_half_inference,
|
| 90 |
+
'_sfdp_pattern_15_half_training': _sfdp_pattern_15_half_training,
|
| 91 |
+
'_sfdp_pattern_15_half_inference': _sfdp_pattern_15_half_inference,
|
| 92 |
+
'_sfdp_pattern_16_half_training': _sfdp_pattern_16_half_training,
|
| 93 |
+
'_sfdp_pattern_16_half_inference': _sfdp_pattern_16_half_inference,
|
| 94 |
+
'_sfdp_pattern_16_half_bs1_training': _sfdp_pattern_16_half_bs1_training,
|
| 95 |
+
'_sfdp_pattern_16_half_bs1_inference': _sfdp_pattern_16_half_bs1_inference,
|
| 96 |
+
'_sfdp_pattern_17_half_training': _sfdp_pattern_17_half_training,
|
| 97 |
+
'_sfdp_pattern_17_half_inference': _sfdp_pattern_17_half_inference,
|
| 98 |
+
'_sfdp_pattern_16_half_mask_fp32_training': _sfdp_pattern_16_half_mask_fp32_training,
|
| 99 |
+
'_sfdp_pattern_16_half_mask_fp32_inference': _sfdp_pattern_16_half_mask_fp32_inference,
|
| 100 |
+
'_sfdp_pattern_16_half_mask_fp32_bs1_training': _sfdp_pattern_16_half_mask_fp32_bs1_training,
|
| 101 |
+
'_sfdp_pattern_16_half_mask_fp32_bs1_inference': _sfdp_pattern_16_half_mask_fp32_bs1_inference,
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_serialized_pattern(key):
|
| 106 |
+
import torch._inductor # noqa: F401
|
| 107 |
+
from torch._inductor import config
|
| 108 |
+
if config.fallback_random:
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
# TODO - could add more validation that the same set of decomps used when
|
| 112 |
+
# tracing SDPA are also used in current context. softmax, dropout, etc
|
| 113 |
+
# decomp use is stable so not an issue in practice.
|
| 114 |
+
return central_index.get(key)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py
ADDED
|
@@ -0,0 +1,1537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import logging
|
| 3 |
+
import operator
|
| 4 |
+
from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union
|
| 5 |
+
|
| 6 |
+
from typing_extensions import TypeAlias
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch._dynamo.utils import counters
|
| 10 |
+
|
| 11 |
+
from ..pattern_matcher import (
|
| 12 |
+
Arg,
|
| 13 |
+
CallFunction,
|
| 14 |
+
CallFunctionVarArgs,
|
| 15 |
+
CallMethodVarArgs,
|
| 16 |
+
config_flag,
|
| 17 |
+
FailedMatch,
|
| 18 |
+
get_arg_value,
|
| 19 |
+
Ignored,
|
| 20 |
+
KeywordArg,
|
| 21 |
+
ListOf,
|
| 22 |
+
Match,
|
| 23 |
+
MatchContext,
|
| 24 |
+
MULTIPLE,
|
| 25 |
+
PatternExpr,
|
| 26 |
+
register_graph_pattern,
|
| 27 |
+
RepeatedExpr,
|
| 28 |
+
)
|
| 29 |
+
from .group_batch_fusion import is_node_meta_valid
|
| 30 |
+
from .pre_grad import (
|
| 31 |
+
merge_getitem_cat_pass,
|
| 32 |
+
merge_splits_pass,
|
| 33 |
+
normalization_pass,
|
| 34 |
+
split_cat_pass,
|
| 35 |
+
unbind_stack_pass,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
log = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
_Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...]
|
| 41 |
+
_TransformParam: TypeAlias = Tuple[
|
| 42 |
+
Optional[_Arguments],
|
| 43 |
+
Optional[_Arguments],
|
| 44 |
+
Optional[_Arguments],
|
| 45 |
+
Optional[_Arguments],
|
| 46 |
+
]
|
| 47 |
+
_Range: TypeAlias = Tuple[int, int]
|
| 48 |
+
|
| 49 |
+
|
def _get_split_args_default(split_node):
    """Extract ``(input, split_size, dim)`` from a split node, defaulting dim to 0.

    Works for both ``torch.split(...)`` call_function nodes and
    ``Tensor.split(...)`` call_method nodes, which name the size argument
    differently.
    """
    size_kwarg = (
        "split_size" if split_node.op == "call_method" else "split_size_or_sections"
    )
    tensor = get_arg_value(split_node, 0, "tensor")
    size = get_arg_value(split_node, 1, size_kwarg)
    dim = get_arg_value(split_node, 2, "dim") or 0
    return tensor, size, dim
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# noqa: W605
|
| 65 |
+
# ############The pattern to be optimized is#########
|
| 66 |
+
# unbind (dim=0)
|
| 67 |
+
# / ... \
|
| 68 |
+
# getitem getitem -> user=1
|
| 69 |
+
# | |
|
| 70 |
+
# split split -> dim=1, user=1, split_section_size=1
|
| 71 |
+
# | |
|
| 72 |
+
# getitem getitem -> user=1
|
| 73 |
+
# \ /
|
| 74 |
+
# cat (dim=1) -> user=1
|
| 75 |
+
# |
|
| 76 |
+
|
| 77 |
+
# ################After transformation#############
|
| 78 |
+
# unbind (dim=0)
|
| 79 |
+
# / ... \
|
| 80 |
+
# getitem getitem -> user=1
|
| 81 |
+
# \ /
|
| 82 |
+
# cat (dim=1) -> user=1
|
| 83 |
+
# |
|
| 84 |
+
|
| 85 |
+
|
def remove_split_with_size_one(
    graph: torch.fx.Graph,
    node: torch.fx.Node,
    input: torch.fx.Node,
):
    """Eliminate a degenerate split that yields a single section.

    Such a split is an identity: its getitem just forwards ``input``.  Every
    consumer of the getitem is rewired to read ``input`` directly, then the
    getitem and the split node are erased.  NOTE(review): only the first user
    of ``node`` is rewired — assumes a single-section split has exactly one
    getitem user.
    """
    grandchildren = find_next_users(node)
    getitem = next(iter(node.users.keys()))
    for consumer in grandchildren:
        consumer.replace_input_with(getitem, input)
    graph.erase_node(getitem)
    graph.erase_node(node)

    counters["inductor"]["remove_split_with_size_one"] += 1
| 102 |
+
|
| 103 |
+
|
def normalize_split_base(
    match: Match,
    _get_split_args: Callable[
        [torch.fx.Node], Tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]]
    ],
):
    """
    Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in
    subsequent optimizations
    """
    split_node = match.nodes[0]
    graph = match.graph
    # Delegate argument extraction to the caller-supplied accessor so this
    # routine works for both call_function and call_method split variants.
    split_input, split_size, split_dim = _get_split_args(split_node)
    if split_input is None or split_dim is None or split_size is None:
        log.debug("couldn't find split args")
        return
    if "example_value" not in split_node.meta:
        log.debug("example value absent for node: %s", split_node)
        return
    # A split's example_value is the sequence of output tensors; read the
    # concrete per-section sizes along the split dimension from it.
    assert isinstance(split_node.meta["example_value"], (list, tuple))
    split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]]

    if any(isinstance(section, torch.SymInt) for section in split_sections):
        # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
        return
    # remove the dummy split whose split sections size is one
    if len(split_sections) == 1:
        remove_split_with_size_one(graph, split_node, split_input)
        return
    if split_dim < 0:  # Normalize split dim
        split_dim += split_input.meta["example_value"].dim()
    # Re-emit the split in canonical form: explicit section list and a
    # keyword dim, then transplant users and metadata onto the new node.
    with graph.inserting_after(split_node):
        new_split_node = graph.call_function(
            torch.split,
            args=(split_input, split_sections),
            kwargs={"dim": split_dim},
        )
    split_node.replace_all_uses_with(new_split_node)
    new_split_node.meta.update(split_node.meta)
    graph.erase_node(split_node)
    counters["inductor"]["split_cat_norm"] += 1
| 145 |
+
|
| 146 |
+
|
@register_graph_pattern(
    CallFunctionVarArgs(torch.split, users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallMethodVarArgs("split", users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_split_default(match: Match, *args, **kwargs):
    # Registered for both torch.split(...) and Tensor.split(...); the
    # argument-naming differences are handled by _get_split_args_default.
    return normalize_split_base(match, _get_split_args_default)
| 159 |
+
|
| 160 |
+
|
@register_graph_pattern(
    CallFunctionVarArgs(torch.unbind, users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallMethodVarArgs("unbind", users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_unbind_default(match: Match, *args, **kwargs):
    """Rewrite unbind calls into canonical ``torch.unbind(input, dim=d)`` with d >= 0."""
    unbind_node = match.nodes[0]
    graph = match.graph
    unbind_input = get_arg_value(unbind_node, 0, "input")
    dim = get_arg_value(unbind_node, 1, "dim")
    if dim is None:
        # Some frontends spell the dimension "axis"; otherwise default to 0.
        axis = unbind_node.kwargs.get("axis")
        dim = axis if axis is not None else 0
    if unbind_input is None:
        log.debug("couldn't find unbind args")
        return
    if "example_value" not in unbind_input.meta:
        log.debug("example value absent for node: %s", unbind_input)
        return
    if dim < 0:  # normalize a negative dim against the input rank
        dim += unbind_input.meta["example_value"].ndim
    with graph.inserting_after(unbind_node):
        replacement = graph.call_function(
            torch.unbind,
            args=(unbind_input,),
            kwargs={"dim": dim},
        )
    unbind_node.replace_all_uses_with(replacement)
    replacement.meta.update(unbind_node.meta)
    graph.erase_node(unbind_node)
    counters["inductor"]["split_cat_norm"] += 1
| 201 |
+
|
| 202 |
+
|
@register_graph_pattern(
    CallFunctionVarArgs(torch.cat, users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_cat_default(match: Match, *args, **kwargs):
    """Canonicalize torch.cat calls to ``args=(tensors,), kwargs={"dim": d}`` with d >= 0.

    Resolves the numpy-style "axis" keyword and negative dims so downstream
    split/cat passes only ever see one calling convention.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    cat_node = match.nodes[0]
    graph = match.graph
    tensors = get_arg_value(cat_node, 0, "tensors")
    cat_dim = get_arg_value(cat_node, 1, "dim")
    if cat_dim is None:
        # Fall back to the numpy-style "axis" kwarg, else the default dim 0.
        cat_axis = cat_node.kwargs.get("axis")
        if cat_axis is not None:
            cat_dim = cat_axis
        else:
            cat_dim = 0
    if tensors is None or cat_dim is None:
        log.debug("couldn't find cat args")
        return
    assert isinstance(tensors, (list, tuple))
    # Rank checks below depend on "example_value" metadata; bail out when any
    # participating node is missing it.
    for tensor in itertools.chain([cat_node], tensors):
        if "example_value" not in tensor.meta:
            log.debug("example value absent for node: %s", tensor)
            return

    ndim = cat_node.meta["example_value"].dim()

    def is_empty_tensor(x):
        # special case where torch.cat supports cat'ing with an empty tensor
        x_shape = x.meta["example_value"].shape
        return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0)

    # Every input must share the output rank, except 1-D empty tensors.
    assert all(
        ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
    )

    if cat_dim < 0:  # Normalize cat dim
        cat_dim += ndim

    # Re-emit cat in canonical form and transplant users/metadata.
    with graph.inserting_after(cat_node):
        new_cat_node = graph.call_function(
            torch.cat,
            args=(tensors,),
            kwargs={"dim": cat_dim},
        )
    cat_node.replace_all_uses_with(new_cat_node)
    new_cat_node.meta.update(cat_node.meta)
    graph.erase_node(cat_node)
    counters["inductor"]["split_cat_norm"] += 1
| 254 |
+
|
| 255 |
+
|
@register_graph_pattern(
    CallFunctionVarArgs(torch.stack, users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_stack_default(match: Match, *args, **kwargs):
    """Canonicalize torch.stack calls to ``args=(tensors,), kwargs={"dim": d}`` with d >= 0."""
    stack_node = match.nodes[0]
    graph = match.graph
    tensors = get_arg_value(stack_node, 0, "tensors")
    dim = get_arg_value(stack_node, 1, "dim") or 0
    if tensors is None or dim is None:
        log.debug("couldn't find stack args")
        return
    assert isinstance(tensors, (list, tuple))

    # A bug in pytorch, some nodes miss the example_value metadata
    for candidate in itertools.chain([stack_node], tensors):
        if "example_value" not in candidate.meta:
            log.debug("example value absent for node: %s", candidate)
            return

    if dim < 0:  # normalize a negative dim against the output rank
        dim += stack_node.meta["example_value"].dim()

    with graph.inserting_after(stack_node):
        replacement = graph.call_function(
            stack_node.target,
            args=(tensors,),
            kwargs={"dim": dim},
        )
    stack_node.replace_all_uses_with(replacement)
    replacement.meta.update(stack_node.meta)
    graph.erase_node(stack_node)
    counters["inductor"]["split_cat_norm"] += 1
| 291 |
+
|
| 292 |
+
|
def find_next_users(split_node: torch.fx.Node) -> List[torch.fx.Node]:
    """Return the grand-children of ``split_node``: the users of its users
    (i.e. the consumers sitting behind the getitem layer), deduplicated
    while preserving first-seen order."""
    collected: List[torch.fx.Node] = []
    for getitem in split_node.users:
        for consumer in getitem.users:
            if consumer not in collected:
                collected.append(consumer)
    return collected
| 300 |
+
|
| 301 |
+
|
@register_graph_pattern(
    CallMethodVarArgs("squeeze", users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_squeeze_default(match: Match, *args, **kwargs):
    """Rewrite ``Tensor.squeeze`` method calls into ``torch.squeeze`` function calls."""
    squeeze_node = match.nodes[0]
    squeeze_input = get_arg_value(squeeze_node, 0)

    # Resolve the optional dim across the different call forms:
    #   x.squeeze(), x.squeeze(d), x.squeeze([d0, d1]),
    #   x.squeeze(d0, d1), x.squeeze(dim=...)
    if "dim" in squeeze_node.kwargs:
        assert len(squeeze_node.args) == 1
        dim = squeeze_node.kwargs["dim"]
    elif len(squeeze_node.args) == 1:
        dim = None
    elif len(squeeze_node.args) == 2:
        dim = squeeze_node.args[1]
    else:
        # varargs form: collect all trailing dims
        dim = squeeze_node.args[1:]

    # A one-element dim sequence collapses to a scalar dim.
    if isinstance(dim, Sequence) and len(dim) == 1:
        dim = dim[0]

    with match.graph.inserting_after(squeeze_node):
        if dim is None:
            replacement = match.graph.call_function(
                torch.squeeze, args=(squeeze_input,)
            )
        else:
            replacement = match.graph.call_function(
                torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim}
            )
    squeeze_node.replace_all_uses_with(replacement)
    match.graph.erase_node(squeeze_node)
| 339 |
+
|
| 340 |
+
|
class TorchSplit(CallFunction):
    """
    Matches a call to torch.split if it is in a normalized form. Ensures that all users of
    splits are unique getitems.
    """

    def __init__(self, arg, sizes, func=torch.split):
        # using KeywordArg("dim") for `dim` checks they all match
        super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim"))

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        # Run the generic CallFunction match first; the split-specific
        # structural checks below only apply when that succeeds.
        m = super()._match(node, ctx)
        if not m:
            return m
        split_sections = node.args[1]
        if not isinstance(split_sections, (list, tuple)):
            # A scalar split_size means the split has not been normalized
            # into an explicit per-section size list yet.
            return FailedMatch("split not normalized")
        # check users are all unique getitems
        seen_idxs = set()
        for user in node.users:
            if not CallFunction(operator.getitem, Arg(), Arg()).match(user):
                # This should ideally never happen. Split user should always be a getitem
                return FailedMatch(f"user of split not a getitem: {user}")
            if not isinstance(user.args[1], int):
                return FailedMatch("only integer getitems are handled")
            if user.args[1] in seen_idxs:
                return FailedMatch(f"duplicate getitem {user.args[1]}")
            if user.args[-1] < 0:  # type: ignore[operator]
                # This shouldn't ideally happen as dynamo normalizes indexes to positive
                return FailedMatch("negative index")
            seen_idxs.add(user.args[1])
        return m
| 373 |
+
|
| 374 |
+
|
@register_graph_pattern(
    TorchSplit(
        CallFunction(
            operator.getitem,
            TorchSplit(
                KeywordArg("first_split_input"),
                KeywordArg("first_split_sections"),
            ),
            Ignored(),
        ),
        KeywordArg("next_split_sections"),
    ),
    pass_dict=merge_splits_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def merge_splits(
    match: Match,
    first_split_input: torch.fx.Node,
    first_split_sections: List[int],
    next_split_sections: List[int],
    # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim
    dim: int,
):
    """Fuse a split whose input is a getitem of another split (same dim) into
    one split over the original input, splicing ``next_split_sections`` into
    ``first_split_sections`` at the consumed section's position and rewiring
    all getitems onto the merged split.
    """
    node = match.output_node()
    # it is possible that the split has no users,
    # we check the corner case and skip the pattern
    if len(node.users.keys()) == 0:
        return
    graph = match.graph
    first_split = node.args[0].args[0]  # type: ignore[union-attr]
    next_split_index = node.args[0].args[1]  # type: ignore[union-attr]

    # Replace the single consumed section with the inner split's sections.
    new_split_sections = list(first_split_sections)
    new_split_sections[next_split_index : next_split_index + 1] = next_split_sections  # type: ignore[operator, misc]

    first_split_dim = first_split.kwargs["dim"]  # type: ignore[union-attr]

    # Nodes to erase once rewiring is complete (children before parents).
    to_remove = []

    with graph.inserting_before(first_split):
        # Add the new split node
        new_split = graph.call_function(
            torch.split,
            args=(first_split_input, new_split_sections),
            kwargs={"dim": first_split_dim},
        )
        # Map each consumed section index of the outer split to its getitem.
        first_split_num_to_user = {
            user.args[1]: user for user in first_split.users.keys()  # type: ignore[union-attr]
        }

        # Walk the merged section list, keeping new_split_num in lockstep with
        # the merged split's output indices.
        new_split_num = 0
        for split_num in range(len(first_split_sections)):
            if split_num not in first_split_num_to_user:
                # Unused section: it still occupies one output slot.
                new_split_num += 1
                continue
            old_getitem = first_split_num_to_user[split_num]
            if split_num != next_split_index:
                # Pass-through section: retarget the existing getitem.
                old_getitem.update_arg(0, new_split)
                old_getitem.update_arg(1, new_split_num)
                new_split_num += 1
            else:
                # The section consumed by the inner split: expand it into the
                # inner split's getitems, pointed at the merged split.
                next_split_num_to_user = {
                    user.args[1]: user for user in node.users.keys()
                }
                # It is not necessary all getitems from the split node are used.
                # We use the num of users to check the getitems to be merged.
                for next_split_num in range(len(node.users.keys())):
                    with graph.inserting_after(new_split):
                        new_getitem = graph.call_function(
                            operator.getitem, args=(new_split, new_split_num)
                        )
                    new_split_num += 1
                    next_getitem = next_split_num_to_user[next_split_num]
                    new_getitem.meta.update(next_getitem.meta)
                    next_getitem.replace_all_uses_with(new_getitem)
                    to_remove.append(next_getitem)
                to_remove.append(node)
                to_remove.append(old_getitem)

        to_remove.append(first_split)  # type: ignore[arg-type]
    # NOTE(review): this loop rebinds `node`, shadowing the matched split node
    # above — harmless here since the original value is no longer needed.
    for node in to_remove:
        graph.erase_node(node)

    counters["inductor"]["consecutive_split_merged"] += 1
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class SplitCatSimplifier:
|
| 462 |
+
"""
|
| 463 |
+
Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat"
|
| 464 |
+
pattern. However, there are various cases where they can't and we need to simplify split/ add transforms before cat.
|
| 465 |
+
Some such cases are:
|
| 466 |
+
1. Final node has additional args (not coming from the initial split)
|
| 467 |
+
2. Shuffling of args between split/cat
|
| 468 |
+
3. Some final nodes are non-(cat/stack)
|
| 469 |
+
4. Split-dim != cat-dim (but equal split)
|
| 470 |
+
|
| 471 |
+
Note that any combination of the above cases can happen.
|
| 472 |
+
|
| 473 |
+
To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged.
|
| 474 |
+
Then, we simplify the split accordingly. In the best case, split can be entirely removed.
|
| 475 |
+
|
| 476 |
+
To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`).
|
| 477 |
+
|
| 478 |
+
Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added.
|
| 479 |
+
|
| 480 |
+
"""
|
| 481 |
+
|
| 482 |
+
def simplify(
|
| 483 |
+
self,
|
| 484 |
+
graph: torch.fx.Graph,
|
| 485 |
+
split_node: torch.fx.Node,
|
| 486 |
+
split_sections: List[int],
|
| 487 |
+
):
|
| 488 |
+
# Find the next users (i.e. users after the getitem)
|
| 489 |
+
next_users = find_next_users(split_node)
|
| 490 |
+
# Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by
|
| 491 |
+
# a tuple indicating the split ranges. See `get_user_input_list` for more details
|
| 492 |
+
user_inputs_list = self.get_user_input_list(split_node, next_users)
|
| 493 |
+
# Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and
|
| 494 |
+
# we can simply replace the split node. Otherwise, we simplify it.
|
| 495 |
+
simplified_split_ranges = self.get_simplified_split_ranges(
|
| 496 |
+
split_sections, next_users, user_inputs_list
|
| 497 |
+
)
|
| 498 |
+
if not simplified_split_ranges: # Simplification not possible
|
| 499 |
+
return
|
| 500 |
+
transform_params_list = self.get_transform_params(
|
| 501 |
+
split_node, next_users, user_inputs_list
|
| 502 |
+
)
|
| 503 |
+
if not transform_params_list:
|
| 504 |
+
return
|
| 505 |
+
|
| 506 |
+
# Start actual replacement
|
| 507 |
+
user_inputs_list_new = self.replace_split(
|
| 508 |
+
graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
|
| 509 |
+
)
|
| 510 |
+
self.replace_cat(
|
| 511 |
+
graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type]
|
| 512 |
+
)
|
| 513 |
+
self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type]
|
| 514 |
+
|
| 515 |
+
def get_user_input_list(
|
| 516 |
+
self, split_node: torch.fx.Node, next_users: List[torch.fx.Node]
|
| 517 |
+
) -> List[List[Union[torch.fx.Node, _Range]]]:
|
| 518 |
+
"""
|
| 519 |
+
Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner
|
| 520 |
+
list represents the inputs to that particular node. This list can either contain
|
| 521 |
+
- a tuple representing the ranges of get_items that should go into the cat (closed interval)
|
| 522 |
+
- torch.fx.Node representing "other" inputs (which are not coming from our split)
|
| 523 |
+
"""
|
| 524 |
+
user_inputs_list: List[List[Union[torch.fx.Node, _Range]]] = []
|
| 525 |
+
for user in next_users:
|
| 526 |
+
if user.target in {torch.cat, torch.stack}:
|
| 527 |
+
user_inputs_list.append(self.get_merged_user_inputs(split_node, user))
|
| 528 |
+
else:
|
| 529 |
+
user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type]
|
| 530 |
+
return user_inputs_list
|
| 531 |
+
|
| 532 |
+
def get_merged_user_inputs(
|
| 533 |
+
self, split_node: torch.fx.Node, cat_node: torch.fx.Node
|
| 534 |
+
) -> List[Union[torch.fx.Node, _Range]]:
|
| 535 |
+
user_inputs = get_arg_value(cat_node, 0, "tensors")
|
| 536 |
+
simplified_user_inputs = []
|
| 537 |
+
split_users = set(split_node.users.keys())
|
| 538 |
+
for user_input in user_inputs:
|
| 539 |
+
if user_input not in split_users:
|
| 540 |
+
simplified_user_inputs.append(user_input)
|
| 541 |
+
else:
|
| 542 |
+
# Add which "getitem" cat depends on
|
| 543 |
+
simplified_user_inputs.append(user_input.args[1])
|
| 544 |
+
return self.merge_consecutive_inputs(simplified_user_inputs)
|
    def get_non_cat_node_input(
        self, split_node: torch.fx.Node, node: torch.fx.Node
    ) -> List[_Range]:
        """
        Get input for a non cat node in the same format as `get_merged_user_inputs`
        """
        node_input = []
        split_users = set(split_node.users.keys())
        for node_arg in node.all_input_nodes:
            if node_arg in split_users:
                # Each getitem from the split becomes a singleton (i, i) range
                getitem_num = get_arg_value(node_arg, 1)
                node_input.append((getitem_num, getitem_num))
        return node_input
| 559 |
+
|
| 560 |
+
def merge_consecutive_inputs(
|
| 561 |
+
self, inputs: List[Union[torch.fx.Node, int]]
|
| 562 |
+
) -> List[Union[torch.fx.Node, _Range]]:
|
| 563 |
+
"""
|
| 564 |
+
Merge consecutive inputs going into a user node.
|
| 565 |
+
|
| 566 |
+
For e.g.
|
| 567 |
+
[arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1]
|
| 568 |
+
"""
|
| 569 |
+
merged_ranges = []
|
| 570 |
+
cur_range = None
|
| 571 |
+
for input_ in inputs:
|
| 572 |
+
if isinstance(input_, int):
|
| 573 |
+
if not cur_range:
|
| 574 |
+
cur_range = [input_, input_]
|
| 575 |
+
elif input_ == cur_range[1] + 1:
|
| 576 |
+
cur_range[1] += 1
|
| 577 |
+
else:
|
| 578 |
+
merged_ranges.append(tuple(cur_range))
|
| 579 |
+
cur_range = [input_, input_]
|
| 580 |
+
else:
|
| 581 |
+
if cur_range:
|
| 582 |
+
merged_ranges.append(tuple(cur_range))
|
| 583 |
+
cur_range = None
|
| 584 |
+
merged_ranges.append(input_) # type: ignore[arg-type]
|
| 585 |
+
if cur_range:
|
| 586 |
+
merged_ranges.append(tuple(cur_range))
|
| 587 |
+
return merged_ranges # type: ignore[return-value]
|
    def get_simplified_split_ranges(
        self,
        split_sections: List[int],
        next_users: List[torch.fx.Node],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[_Range]]:
        """
        Convert the per-user getitem ranges into element-offset ranges over the
        split input, and return the simplified split sections, or None when no
        simplification is possible (overlapping ranges, or nothing merged).
        """
        ranges = set()
        for user_node, user_inputs in zip(next_users, user_inputs_list):
            ranges |= {
                user_input
                for user_input in user_inputs
                if isinstance(user_input, tuple)
            }
        # Prefix sums translate section indices into element offsets
        cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
        split_ranges = sorted(
            [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges]
        )

        if not self.has_non_overlapping_ranges(
            split_ranges,
        ):  # This need not be a strict condition
            # However, we keep it now for simplicity.
            return None
        split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1])
        if len(split_sections) == len(split_ranges):  # Simplification not possible
            return None
        counters["inductor"]["scmerge_split_sections_removed"] = len(
            split_sections
        ) - len(split_ranges)
        return split_ranges
| 619 |
+
|
| 620 |
+
def has_non_overlapping_ranges(self, ranges: List[_Range]) -> bool:
|
| 621 |
+
for range_, next_range in zip(ranges, ranges[1:]):
|
| 622 |
+
if range_[1] > next_range[0]:
|
| 623 |
+
return False
|
| 624 |
+
return True
|
| 625 |
+
|
| 626 |
+
def fill_gaps(self, ranges: List[_Range], min_: int, max_: int) -> List[_Range]:
|
| 627 |
+
cur = min_
|
| 628 |
+
filled_ranges = []
|
| 629 |
+
for a, b in ranges:
|
| 630 |
+
if cur < a:
|
| 631 |
+
filled_ranges.append((cur, a))
|
| 632 |
+
filled_ranges.append((a, b))
|
| 633 |
+
cur = b
|
| 634 |
+
if filled_ranges[-1][1] < max_:
|
| 635 |
+
filled_ranges.append((filled_ranges[-1][1], max_))
|
| 636 |
+
return filled_ranges
|
    def get_transform_params(
        self,
        split_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[List[_TransformParam]]]:
        """
        Figure out what transforms are needed for each input to each cat node.

        We replace a split node with an unflatten followed by a movedim

        Returns None when a merged range does not correspond to an equal split
        (unflatten needs uniform section sizes).
        """
        split_dim = split_node.kwargs["dim"]
        split_sections = split_node.args[1]
        transform_params_list: List[List[_TransformParam]] = []

        for user_node, user_inputs in zip(next_users, user_inputs_list):
            if user_node.target not in {torch.cat, torch.stack}:
                # Non-cat users need no tensor transforms
                transform_params_list.append([])
                continue

            cat_dim = get_arg_value(user_node, 1, "dim")
            transform_params: List[_TransformParam] = []
            for user_input in user_inputs:
                if split_dim == cat_dim and user_node.target == torch.cat:
                    # No transform needed
                    transform_params.append((None, None, None, None))
                elif isinstance(user_input, tuple):  # Split being simplified
                    # Verify equal split
                    subset_split_sections = split_sections[  # type: ignore[index]
                        user_input[0] : user_input[1] + 1
                    ]
                    # All sections should be equal
                    if len(set(subset_split_sections)) != 1:
                        return None

                    num_splits = len(subset_split_sections)
                    unflatten_params = (split_dim, (num_splits, -1))
                    movedim_params = (
                        (split_dim, cat_dim) if split_dim != cat_dim else None
                    )
                    transform_params.append(
                        (unflatten_params, movedim_params, None, None)
                    )
                elif (
                    user_node.target == torch.stack or split_dim != cat_dim
                ):  # We need to unsqueeze inputs not coming through split
                    transform_params.append((None, None, (cat_dim,), None))
                else:  # Non-split inputs
                    transform_params.append((None, None, None, None))
            transform_params_list.append(transform_params)
        return transform_params_list
    def replace_split(
        self,
        graph: torch.fx.Graph,
        split_node: torch.fx.Node,
        split_sections: List[int],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
        split_ranges: List[_Range],
    ) -> List[List[torch.fx.Node]]:
        """
        Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it
        into a split with lesser sections if len(split_ranges) > 1.

        Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node.
        """
        split_input = split_node.args[0]
        split_dim = split_node.kwargs["dim"]
        if len(split_ranges) == 1:  # We can completely eliminate the split node
            split_items = [split_input]
        else:
            with graph.inserting_after(split_node):
                new_split = graph.call_function(
                    torch.split,
                    args=(
                        split_input,
                        # New section sizes are the widths of the merged ranges
                        [r[1] - r[0] for r in split_ranges],
                    ),
                    kwargs={"dim": split_dim},
                )
                new_split.meta.update(split_node.meta)
                counters["inductor"]["scmerge_split_added"] += 1
            with graph.inserting_after(new_split):
                split_items = [
                    graph.call_function(operator.getitem, args=(new_split, i))
                    for i in range(len(split_ranges))
                ]
        # Now assign the right getitem to the right input
        cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
        new_user_inputs_list = []
        for user_inputs in user_inputs_list:
            new_user_inputs = []
            for user_input in user_inputs:
                if isinstance(user_input, tuple):
                    # Find the correct new getitem (present in split_items)
                    new_user_inputs.append(
                        split_items[
                            split_ranges.index(
                                (
                                    cumulative_sizes[user_input[0]],
                                    cumulative_sizes[user_input[1] + 1],
                                )
                            )
                        ]
                    )
                else:
                    new_user_inputs.append(user_input)
            new_user_inputs_list.append(new_user_inputs)
        return new_user_inputs_list  # type: ignore[return-value]
    def replace_cat(
        self,
        graph: torch.fx.GraphModule,
        split_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
        user_inputs_list_new,
        transform_params_list: List[List[_TransformParam]],
    ):
        """
        Rewire each user of the split: non-cat users get the new getitems spliced
        into their inputs; cat/stack users are rebuilt from transformed inputs
        (unflatten/movedim/unsqueeze/flatten per `transform_params_list`).
        """
        split_dim = split_node.kwargs["dim"]

        split_users = split_node.users.keys()
        new_cats = []
        for user_node, user_inputs_new, transform_params in zip(
            next_users, user_inputs_list_new, transform_params_list
        ):
            if user_node.target not in {torch.cat, torch.stack}:
                # Change the args and kwargs of non-cat/stack nodes. Replace old getitems (belonging to
                # the original split node) with the newer getitems
                next_cat_input = 0
                for input_node in user_node.all_input_nodes:
                    if input_node in split_users:
                        user_node.replace_input_with(
                            input_node, user_inputs_new[next_cat_input]
                        )
                        next_cat_input += 1
                continue

            # Handle cat/stack user nodes
            cat_dim = get_arg_value(user_node, 1, "dim")
            user_inputs_new_transformed = []
            # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them
            to_stack = []
            stack_dim = None
            with graph.inserting_before(user_node):
                for user_input_new, transform_param in zip(
                    user_inputs_new, transform_params
                ):
                    # Apply transforms
                    (
                        unflatten_params,
                        movedim_params,
                        unsqueeze_params,
                        flatten_params,
                    ) = transform_param
                    if unsqueeze_params and (
                        stack_dim is None or stack_dim == unsqueeze_params[0]
                    ):
                        # Accumulate inputs sharing the same unsqueeze dim
                        to_stack.append(user_input_new)
                        stack_dim = unsqueeze_params[0]
                        continue
                    elif to_stack:
                        # Flush the accumulated run as a single stack node
                        stacked_input = graph.call_function(
                            torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
                        )
                        to_stack = []
                        stack_dim = None
                        user_inputs_new_transformed.append(stacked_input)
                        if unsqueeze_params:
                            # Current input starts a new run with a different dim
                            to_stack.append(user_input_new)
                            stack_dim = unsqueeze_params[0]
                            continue

                    if unflatten_params:
                        user_input_new = graph.call_function(
                            torch.unflatten, args=(user_input_new, *unflatten_params)
                        )
                    if movedim_params:
                        user_input_new = graph.call_function(
                            torch.movedim, args=(user_input_new, *movedim_params)
                        )
                    if flatten_params:
                        user_input_new = graph.call_function(
                            torch.flatten, args=(user_input_new, *flatten_params)
                        )
                    user_inputs_new_transformed.append(user_input_new)
                if to_stack:
                    # Flush a trailing run of unsqueeze inputs
                    stacked_input = graph.call_function(
                        torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
                    )
                    user_inputs_new_transformed.append(stacked_input)

            with graph.inserting_after(user_node):
                if len(user_inputs_new_transformed) > 1:
                    new_cat_node = graph.call_function(
                        torch.cat,
                        args=(user_inputs_new_transformed,),
                        kwargs={"dim": cat_dim},
                    )
                    new_cat_node.meta.update(user_node.meta)
                    counters["inductor"]["scmerge_cat_added"] += 1
                else:
                    # Single input: the cat is redundant, use the input directly
                    new_cat_node = user_inputs_new_transformed[-1]

            if (
                user_node.target == torch.cat
                and split_dim != cat_dim
                and split_node.target == torch.split
            ):
                # The unflatten+movedim replacement added an extra dim; fold it back
                with graph.inserting_after(new_cat_node):
                    new_cat_node = graph.call_function(
                        torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
                    )
            user_node.replace_all_uses_with(new_cat_node)
            new_cats.append(new_cat_node)
| 852 |
+
|
| 853 |
+
def erase_old_nodes(
|
| 854 |
+
self,
|
| 855 |
+
graph: torch.fx.GraphModule,
|
| 856 |
+
split_node: torch.fx.Node,
|
| 857 |
+
next_users: List[torch.fx.Node],
|
| 858 |
+
):
|
| 859 |
+
to_remove = [split_node]
|
| 860 |
+
counters["inductor"]["scmerge_split_removed"] += 1
|
| 861 |
+
to_remove.extend(split_node.users.keys())
|
| 862 |
+
for next_user in next_users:
|
| 863 |
+
if next_user.target not in {torch.cat, torch.stack}:
|
| 864 |
+
continue
|
| 865 |
+
counters["inductor"]["scmerge_cat_removed"] += 1
|
| 866 |
+
to_remove.append(next_user)
|
| 867 |
+
for node in reversed(to_remove):
|
| 868 |
+
graph.erase_node(node)
|
class UnbindCatRemover(SplitCatSimplifier):
    """
    Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier.

    Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this,
    other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`,
    hence we extend that class.
    """

    def remove_unbind(
        self,
        graph: torch.fx.Graph,
        unbind_node: torch.fx.Node,
    ):
        # Number of unbind outputs = highest observed getitem index + 1
        num_unbind = (  # type: ignore[operator]
            max(getitem_node.args[1] for getitem_node in unbind_node.users.keys()) + 1  # type: ignore[operator, union-attr, type-var]
        )
        # Model the unbind as a split into size-1 sections so the base-class
        # simplification machinery can be reused
        split_sections = [1 for _ in range(num_unbind)]  # type: ignore[operator, arg-type]

        super().simplify(graph, unbind_node, split_sections)

    def get_simplified_split_ranges(
        self,
        split_sections: List[int],
        next_users: List[torch.fx.Node],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[_Range]]:
        # Unbind can only be removed entirely, so require exactly one merged range
        simplified_split_ranges = super().get_simplified_split_ranges(
            split_sections, next_users, user_inputs_list
        )
        if not simplified_split_ranges or len(simplified_split_ranges) != 1:
            return None
        return simplified_split_ranges

    def get_transform_params(
        self,
        unbind_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[List[_TransformParam]]]:
        """
        Figure out what transforms are needed for each input to each cat node.

        Here is the rough transforms we apply:

        x -> unbind -> stack => x -> movedim

        x -> unbind -> cat => x -> movedim -> flatten

        When cat/stack nodes have additional args:

             addn ---|              addn -> unsqueeze ---|
        x -> unbind -> stack  =>           x -> movedim  -> cat

             addn ---|                            addn ---|
        x -> unbind -> cat  =>   x -> movedim -> flatten  -> cat

        (Note application of these depends on the dims as well)


        """
        split_dim = unbind_node.kwargs["dim"]
        transform_params_list: List[List[_TransformParam]] = []
        for user_node, user_inputs in zip(next_users, user_inputs_list):
            cat_dim = get_arg_value(user_node, 1, "dim") or 0
            transform_params: List[_TransformParam] = []
            for user_input in user_inputs:
                if isinstance(user_input, tuple):
                    # User input is coming from unbind
                    movedim_params = (
                        (split_dim, cat_dim) if split_dim != cat_dim else None
                    )
                    flatten_params = None
                    if user_node.target == torch.cat:
                        flatten_params = (cat_dim, cat_dim + 1)
                    transform_params.append(
                        (None, movedim_params, None, flatten_params)
                    )
                elif (
                    user_node.target == torch.stack
                ):  # We need to unsqueeze inputs not coming through unbind into cat
                    transform_params.append((None, None, (cat_dim,), None))
                else:  # Non-unbind inputs
                    transform_params.append((None, None, None, None))
            transform_params_list.append(transform_params)
        return transform_params_list
class GetItem(CallFunction):
    """Pattern expression matching ``operator.getitem(arg, index)`` calls."""

    def __init__(self, arg, index, _users=1):
        super().__init__(operator.getitem, arg, index, _users=_users)

    def find_anchor_nodes(self, ctx: MatchContext, searched: Set[torch.fx.Node]):
        # We generally match GetItem with arg being an Arg(). So, we never return the anchor
        # nodes as the stored node in ctx.pattern_to_node is returned. Here we override find_anchor_nodes
        # to not use ctx.pattern_to_node
        for pattern in self.flat_args_kwargs[0]:
            if isinstance(pattern, PatternExpr):
                for other_node in pattern.find_anchor_nodes(ctx, searched):
                    if not isinstance(other_node, torch.fx.Node):
                        continue
                    # Candidate getitems are the users of the matched inner node
                    for node in other_node.users:
                        if node not in searched:
                            if self._match_fns(node):
                                yield node
                                searched.add(node)
@register_graph_pattern(
    RepeatedExpr(
        CallFunction(
            torch.squeeze,
            GetItem(
                TorchSplit(
                    KeywordArg("split_input"),
                    KeywordArg("split_sizes"),
                ),
                Ignored(),
            ),
            KeywordArg("dim"),
            _users=MULTIPLE,
        ),
    ),
    pass_dict=split_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    RepeatedExpr(
        CallFunction(
            torch.squeeze,
            GetItem(
                TorchSplit(
                    KeywordArg("split_input"),
                    KeywordArg("split_sizes"),
                ),
                Ignored(),
            ),
            dim=KeywordArg("dim"),
            _users=MULTIPLE,
        )
    ),
    pass_dict=split_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def merge_split_squeeze(
    match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int
):
    """
    Rewrite split(x, [1, 1, ...], dim) followed by squeeze of every getitem into a
    single unbind(x, dim), wiring each squeeze's users to the matching unbind output.
    """
    graph = match.graph
    split = next(node for node in match.nodes if node.target == torch.split)
    # Only an all-ones split is equivalent to unbind
    if not all(s == 1 for s in split_sizes):
        return
    if isinstance(dim, Sequence):
        return
    next_users = find_next_users(split)
    if not all(node.target == torch.squeeze for node in next_users):
        return
    with graph.inserting_before(match.output_node()):
        unbind = graph.call_function(
            torch.unbind, args=(split_input,), kwargs={"dim": dim}
        )
        # Process getitems in index order so replacements are deterministic
        for item_index, getitem_node in sorted(
            [
                (getitem_node.args[1], getitem_node)
                for getitem_node in split.users.keys()
            ]
        ):
            squeeze = next(iter(getitem_node.users.keys()))
            new_get_item = graph.call_function(
                operator.getitem, args=(unbind, item_index)
            )
            squeeze.replace_all_uses_with(new_get_item)
            new_get_item.meta.update(squeeze.meta)
            graph.erase_node(squeeze)
            graph.erase_node(getitem_node)
        graph.erase_node(split)
    counters["inductor"]["split_squeeze_replaced"] += 1
# Matches a (possibly partial) list of `operator.getitem(torch.unbind(x, dim=...), i)`
# nodes, i.e. cat/stack inputs taken directly from an unbind's outputs.
getitem_unbind = ListOf(
    GetItem(
        CallFunction(
            torch.unbind,
            KeywordArg("unbind_input"),
            dim=KeywordArg("dim"),
            _users=MULTIPLE,
        ),
        Ignored(),
        _users=MULTIPLE,
    ),
    partial=True,
)
@register_graph_pattern(
    CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE),
    pass_dict=unbind_stack_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE
    ),
    pass_dict=unbind_stack_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE
    ),
    pass_dict=unbind_stack_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
    """Remove an unbind feeding cat/stack users by delegating to UnbindCatRemover."""
    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
    UnbindCatRemover().remove_unbind(match.graph, unbind_node)
# Matches a (possibly partial) list of `operator.getitem(torch.split(...), i)`
# nodes, i.e. cat/stack inputs taken directly from a split's outputs.
getitem_split = ListOf(
    CallFunction(
        operator.getitem,
        TorchSplit(
            Ignored(),
            KeywordArg("split_sections"),
        ),
        Ignored(),
        _users=MULTIPLE,
    ),
    partial=True,
)
@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat],
        tensors=getitem_split,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=split_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat],
        getitem_split,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=split_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallFunction(
        [torch.stack, torch.cat],
        getitem_split,
        Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=split_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def simplify_split_cat(match: Match, split_sections: List[int], dim: int):
    """Entry point for split->cat/stack simplification via SplitCatSimplifier."""
    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
        return
    split_node = next(node for node in match.nodes if node.target == torch.split)
    SplitCatSimplifier().simplify(match.graph, split_node, split_sections)
| 1137 |
+
|
| 1138 |
+
|
| 1139 |
+
# noqa: W605
|
| 1140 |
+
# ############pattern to be optimized is#########
|
| 1141 |
+
|
| 1142 |
+
# split_node(dim=1)
|
| 1143 |
+
# / \ ... / \
|
| 1144 |
+
# getitem getitem getitem getitem -> user=1
|
| 1145 |
+
# \ / \ /
|
| 1146 |
+
# cat (user=mul, dim=1) cat(user=mul, dim=1)
|
| 1147 |
+
# | \ | \
|
| 1148 |
+
|
| 1149 |
+
# ################after transformation#############
|
| 1150 |
+
|
| 1151 |
+
# split_node(dim=1)
|
| 1152 |
+
# / ... \
|
| 1153 |
+
# getitem getitem
|
| 1154 |
+
# | \ | \
|
| 1155 |
+
|
| 1156 |
+
|
def has_same_parent_node(node: torch.fx.Node):
    # the input nodes of the node should come from the same parent
    parents = []
    for candidate in node.args[0]:  # type: ignore[union-attr]
        if candidate.target != operator.getitem:  # type: ignore[union-attr]
            return False
        parents.append(candidate.args[0])  # type: ignore[union-attr]
    # every getitem must index into the same producer as the first one
    return all(parent == parents[0] for parent in parents)
| 1169 |
+
|
| 1170 |
+
|
def remove_zeros(split_sections: List[int]):
    """
    Remove zeros from the list and get the index mapping dict from getitem
    in split node to getitem in new split node
    """
    kept_sections = []
    index_mapping = {}
    for old_idx, section in enumerate(split_sections):
        if section > 0:
            # new position is however many sections we have kept so far
            index_mapping[old_idx] = len(kept_sections)
            kept_sections.append(section)

    return kept_sections, index_mapping
| 1185 |
+
|
| 1186 |
+
|
def is_sorted_and_consecutive(arr: List[int]) -> bool:
    """Return True iff `arr` is ascending with successive elements differing by 1."""
    if arr != sorted(arr):
        return False
    # adjacent differences must all be exactly one
    return all(later - earlier == 1 for earlier, later in zip(arr, arr[1:]))
| 1194 |
+
|
| 1195 |
+
|
def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: List[int]) -> int:
    """
    Calculate the fused tensor size in the indices
    """
    # split_node.args[1] holds the per-section sizes; sum the sections being fused
    return sum(
        section
        for position, section in enumerate(split_node.args[1])  # type: ignore[arg-type]
        if position in indices
    )
@register_graph_pattern(
    CallFunction(
        torch.cat,
        getitem_split,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=merge_getitem_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
    """
    Merge a cat over consecutive, single-use getitems of a split into a single
    (wider) split section, rebuilding the split with the fused sizes and
    retargeting the cat's users to the fused getitem.
    """
    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
        return
    graph = match.graph
    split_node = next(node for node in match.nodes if node.target == torch.split)
    split_input, split_size, split_dim = _get_split_args_default(split_node)
    # if the cat and split have different dims, return
    # Find the next users (i.e. users after the getitem)
    next_users = find_next_users(split_node)
    # 'immutable_list' object does not support mutation. Create a new copy of it
    split_sections = list(split_sections)
    for cat_user in next_users:
        if cat_user.target == torch.cat:
            cat_dim = get_arg_value(cat_user, 1, "dim")
            # check that all getitems in the cat_user come from the same node
            # check the input of the cat has all getitem from the split
            # check all getitem only has one single user
            if (
                split_dim != cat_dim
                or not has_same_parent_node(cat_user)
                or not all(len(arg.users) == 1 for arg in cat_user.args[0])  # type: ignore[union-attr]
            ):
                continue
            # find the index of getitems to be cated/stacked
            indices = []
            for arg in cat_user.args[0]:  # type: ignore[union-attr]
                indices.append(arg.args[1])  # type: ignore[union-attr]
            # the getitems to be merged must be consecutive, otherwise
            # returned sliced tensor could be wrong
            if not is_sorted_and_consecutive(indices):
                continue
            # update the arg of cat user, only keep the first getitem
            cat_user.update_arg(0, cat_user.args[0][0])  # type: ignore[index]
            # update the split sections: the first fused index absorbs the whole
            # merged size (previously this was also computed into an unused local)
            split_sections[indices[0]] = calculate_fused_tensor_size(
                split_node, indices
            )
            # padding others with zeros to keep the same dict size
            for i in indices[1:]:
                split_sections[i] = 0
            # remove all unused indexes in the split_node
            new_split_sections, index_mapping = remove_zeros(split_sections)
            with graph.inserting_after(split_node):
                new_split_node = graph.call_function(
                    torch.split,
                    args=(split_input, split_sections),
                    kwargs={"dim": split_dim},
                )
                split_node.replace_all_uses_with(new_split_node)
                new_split_node.meta.update(split_node.meta)
                # remove all unused getitem nodes
                to_remove = [cat_user]
                # dictionary keys changed during iteration
                new_split_getitem_nodes = list(new_split_node.users.keys())
                for getitem_node in new_split_getitem_nodes:
                    if getitem_node.args[1] in indices[1:]:
                        to_remove.append(getitem_node)
                    # update meta data of getitem
                    elif getitem_node.args[1] == indices[0]:
                        cat_user.replace_all_uses_with(getitem_node)
                        getitem_node.meta.update(cat_user.meta)
                    else:
                        # update getitem index for new split node
                        getitem_node.update_arg(1, index_mapping[getitem_node.args[1]])
                graph.erase_node(split_node)
                for getitem_node in to_remove:
                    graph.erase_node(getitem_node)
                # update the split sections of new split node
                new_split_node.update_arg(1, new_split_sections)
                split_node = new_split_node
                split_sections = new_split_sections

            counters["inductor"]["getitem_cat_merged"] += 1
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
# ############pattern to be optimized is#########
|
| 1298 |
+
|
| 1299 |
+
# split_node(dim=1) -> user=multiple
|
| 1300 |
+
# / \ ... / \
|
| 1301 |
+
# getitem getitem getitem getitem -> user=multiple
|
| 1302 |
+
# \ \ / \
|
| 1303 |
+
# other_op /cat(user=mul, dim=1) other_op
|
| 1304 |
+
# |
|
| 1305 |
+
|
| 1306 |
+
# ################after transformation#############
|
| 1307 |
+
|
| 1308 |
+
# split_node(dim=1) -> -> user=multiple
|
| 1309 |
+
# / \ ... / \
|
| 1310 |
+
# getitem getitem getitem getitem -> user=multiple
|
| 1311 |
+
# \ \ / \
|
| 1312 |
+
# other_op
|
| 1313 |
+
|
| 1314 |
+
|
| 1315 |
+
@register_graph_pattern(
    CallFunction(
        torch.cat,
        getitem_split,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=split_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
    """Rewrite a cat whose inputs are consecutive getitems of one split.

    If the cat consumes *all* of the split outputs (in order), the cat is
    equivalent to the split's input and is replaced by it.  If it consumes a
    consecutive *subset*, it is replaced by a single slice of the split's
    input along the split dimension.
    """
    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
        return
    graph = match.graph
    split_node = next(node for node in match.nodes if node.target == torch.split)
    split_input, split_size, split_dim = _get_split_args_default(split_node)
    # if the cat and split have different dims, return
    # Find the next users (i.e. users after the getitem)
    next_users = find_next_users(split_node)
    for cat_user in next_users:
        if cat_user.target == torch.cat:
            # Positional arg 1 or kwarg "dim"; defaults to 0 when absent.
            cat_dim = get_arg_value(cat_user, 1, "dim") or 0
            # check that all getitems in the cat_user from the same node
            # check the input of the cat has all getitem from the split
            if split_dim != cat_dim or not has_same_parent_node(cat_user):
                continue
            # find the index of getitems to be cat
            # NOTE(review): idx_to_getitem is built but never read below —
            # looks like leftover state; confirm before removing.
            indices, idx_to_getitem = [], {}
            for getitem in cat_user.args[0]:  # type: ignore[union-attr]
                indices.append(getitem.args[1])  # type: ignore[union-attr]
                idx_to_getitem[getitem.args[1]] = getitem  # type: ignore[union-attr]
            # the gettitems to be merged must be consecutive, otherwise
            # returned sliced tensor could be wrong
            if not is_sorted_and_consecutive(indices):
                continue
            # case 1: the cat uses all getitems from the split
            if len(split_sections) == len(cat_user.args[0]):  # type: ignore[arg-type]
                # replace the users of the cat node to be the input of the split node
                cat_user.replace_all_uses_with(split_node.args[0])
                # remove the cat node
                graph.erase_node(cat_user)
                counters["inductor"]["cat_mutated"] += 1
            # case 2: the cat uses some getitems from the split
            elif is_node_meta_valid(split_node.args[0]):  # type: ignore[arg-type]
                # check the split dim, and construct the slice tuple
                # Byte offsets of the fused region along split_dim: sizes of all
                # sections before the first consumed index, then of the consumed run.
                start_fused_size = calculate_fused_tensor_size(
                    split_node, list(range(indices[0]))
                )
                end_fused_size = start_fused_size + calculate_fused_tensor_size(
                    split_node, indices
                )
                slice_list = []
                for i in range(len(split_node.args[0].meta["example_value"].shape)):  # type: ignore[union-attr]
                    if i != split_dim:
                        slice_list.append(slice(None, None, None))
                    else:
                        slice_list.append(slice(start_fused_size, end_fused_size, None))
                with graph.inserting_after(split_node):
                    # operator.getitem with a tuple of slices == tensor[...] slicing.
                    slice_node = graph.call_function(
                        operator.getitem,
                        args=(split_node.args[0], tuple(slice_list)),
                    )
                    cat_user.replace_all_uses_with(slice_node)
                    slice_node.meta.update(cat_user.meta)

                # remove the cat node
                graph.erase_node(cat_user)
                counters["inductor"]["cat_mutated"] += 1
|
| 1383 |
+
|
| 1384 |
+
|
| 1385 |
+
# noqa: W605
|
| 1386 |
+
# ############The pattern to be optimized is#########
|
| 1387 |
+
# split_node (dim=1)
|
| 1388 |
+
# / ... \ ... / \
|
| 1389 |
+
# getitem getitem getitem getitem -> user=1
|
| 1390 |
+
# \ /
|
| 1391 |
+
# stack (dim=0) -> user=1, getitems to be consecutive
|
| 1392 |
+
# |
|
| 1393 |
+
# tahn -> user=1
|
| 1394 |
+
# |
|
| 1395 |
+
# unbind (dim=0)
|
| 1396 |
+
# |
|
| 1397 |
+
|
| 1398 |
+
# ################After transformation#############
|
| 1399 |
+
# split_node (dim=1)
|
| 1400 |
+
# / ... / \
|
| 1401 |
+
# getitem getitem getitem -> user=1
|
| 1402 |
+
# |
|
| 1403 |
+
# tahn
|
| 1404 |
+
# |
|
| 1405 |
+
# split
|
| 1406 |
+
# |
|
| 1407 |
+
|
| 1408 |
+
|
| 1409 |
+
@register_graph_pattern(
    CallFunction(
        torch.tanh,
        CallFunction(
            torch.stack,
            getitem_split,
            dim=Ignored(),
        ),
    ),
    pass_dict=merge_getitem_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallFunction(
        torch.tanh,
        CallFunction(
            torch.stack,
            tensors=getitem_split,
            dim=Ignored(),
        ),
    ),
    pass_dict=merge_getitem_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallFunction(
        torch.tanh,
        CallFunction(
            torch.stack,
            getitem_split,
            Ignored(),
        ),
    ),
    pass_dict=merge_getitem_cat_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int):
    """Fuse split->getitems->stack->tanh->unbind into split->tanh->split.

    Consecutive split outputs that are only consumed by one stack (whose sole
    user chain ends in an unbind on the same dim) are merged into a single,
    larger split section; the unbind is replaced with a second torch.split.
    ("tahn" is a historical typo kept for counter-name compatibility.)
    """
    if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
        return
    graph = match.graph
    split_node = next(node for node in match.nodes if node.target == torch.split)
    split_input, split_size, split_dim = _get_split_args_default(split_node)
    # Find the next users (i.e. users after the getitem)
    next_users = find_next_users(split_node)
    # 'immutable_list' object does not support mutation. Create a new copy of it
    split_sections = list(split_sections)
    for user in next_users:
        # stack user only has one user
        if user.target == torch.stack:
            stack_dim = get_arg_value(user, 1, "dim") or 0
            unbind_user = find_next_users(user)[0]
            if unbind_user.target != torch.unbind:
                continue
            unbind_dim = get_arg_value(unbind_user, 1, "dim") or 0
            # stack and unbind should have the same dim
            # check the all getitems in the user from the same node
            # check all the getitems only has single user
            if (
                stack_dim != unbind_dim
                or not has_same_parent_node(user)
                or not all(len(arg.users) == 1 for arg in user.args[0])  # type: ignore[union-attr]
            ):
                continue
            # find the index of getitems to be stacked
            indices = []
            split_sections_for_unbind = []
            for arg in user.args[0]:  # type: ignore[union-attr]
                indices.append(arg.args[1])  # type: ignore[union-attr]
                split_sections_for_unbind.append(split_sections[arg.args[1]])  # type: ignore[union-attr]
            # the gettitems to be merged must be consecutive, otherwise
            # returned sliced tensor could be wrong
            if not is_sorted_and_consecutive(indices):
                continue
            # update the arg of stack user, only keep the first getitem
            user.update_arg(0, user.args[0][0])  # type: ignore[index]
            # calculate the fused tensor sizes in the indices
            # NOTE(review): fused_tensor_size is accumulated but never read
            # afterwards — the fused size actually used comes from
            # calculate_fused_tensor_size below; confirm before removing.
            fused_tensor_size = 0
            for i in range(len(split_node.args[1])):  # type: ignore[arg-type]
                if i in indices:
                    fused_tensor_size += split_node.args[1][i]  # type: ignore[operator, index, assignment]
            # update the split sections
            split_sections[indices[0]] = calculate_fused_tensor_size(
                split_node, indices
            )
            # padding others with zeros to keep the same dict size
            for i in indices[1:]:
                split_sections[i] = 0
            # remove all unused indexes in the split_node
            new_split_sections, index_mapping = remove_zeros(split_sections)
            with graph.inserting_after(split_node):
                # New split with the merged sections (zeros still present here;
                # arg 1 is corrected to new_split_sections further down).
                new_split_node = graph.call_function(
                    torch.split,
                    args=(split_input, split_sections),
                    kwargs={"dim": split_dim},
                )
                replace_unbind_with_split = graph.call_function(
                    torch.split,
                    args=(unbind_user.args[0], split_sections_for_unbind),
                    kwargs={"dim": split_dim},
                )
                unbind_user.replace_all_uses_with(replace_unbind_with_split)
                replace_unbind_with_split.meta.update(unbind_user.meta)
                # remove getitem and split, stack
                split_node.replace_all_uses_with(new_split_node)
                new_split_node.meta.update(split_node.meta)
                # remove all unused getitem nodes
                to_remove = [unbind_user]
                # dictionary keys changed during iteration
                new_split_getitem_nodes = list(new_split_node.users.keys())
                for getitem_node in new_split_getitem_nodes:
                    if getitem_node.args[1] in indices[1:]:
                        to_remove.append(getitem_node)
                    # update meta data of getitem
                    elif getitem_node.args[1] == indices[0]:
                        user.replace_all_uses_with(getitem_node)
                        getitem_node.meta.update(user.meta)
                    else:
                        # update getitem index for new split node
                        getitem_node.update_arg(1, index_mapping[getitem_node.args[1]])
                graph.erase_node(split_node)
                graph.erase_node(user)
                for getitem_node in to_remove:
                    graph.erase_node(getitem_node)
                # update the split sections of new split node
                new_split_node.update_arg(1, new_split_sections)
                # Continue scanning with the rewritten split as the anchor.
                split_node = new_split_node
                split_sections = new_split_sections

            counters["inductor"]["stack_tahn_unbind_merged"] += 1
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Optional, Sequence
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import _prims, Tensor
|
| 8 |
+
|
| 9 |
+
log = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def make_prim(
    schema: str,
    impl_aten,
    return_type=_prims.RETURN_TYPE.NEW,
    doc: str = "",
    tags: Optional[Sequence[torch.Tag]] = None,
):
    """Register an inductor-private prim from *schema* and its eager impl.

    The meta function is derived from the eager implementation: it runs
    ``impl_aten`` and wraps the result in a ``TensorMeta``.
    """

    def _meta(*args, **kwargs):
        # Derive output metadata by running the eager implementation.
        return _prims.TensorMeta(impl_aten(*args, **kwargs))

    return _prims._make_prim(
        schema=schema,
        meta=_meta,
        impl_aten=impl_aten,
        return_type=return_type,
        doc=doc,
        tags=tags,
    )
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def eager_force_stride(input_tensor: Tensor, stride) -> Tensor:
|
| 33 |
+
if input_tensor.stride() == stride:
|
| 34 |
+
return input_tensor
|
| 35 |
+
new_tensor = input_tensor.clone().as_strided(
|
| 36 |
+
input_tensor.shape,
|
| 37 |
+
stride,
|
| 38 |
+
)
|
| 39 |
+
new_tensor.copy_(input_tensor)
|
| 40 |
+
return new_tensor
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Custom prims used for handling randomness
# Each prim below is registered via make_prim with an eager fallback impl.
seed = make_prim(
    "inductor_seed(Device device) -> Tensor",
    lambda device: torch.randint(2**63 - 1, [], device=device),
    doc="create a fresh seed (one per call) for use with inductor_rand",
    tags=(torch.Tag.nondeterministic_seeded,),
)
# Batched variant of `seed`: one tensor holding `count` seeds.
seeds = make_prim(
    "inductor_seeds(int count, Device device) -> Tensor",
    lambda count, device: torch.randint(2**63 - 1, [count], device=device),
    doc="Horizontal fusion of many inductor_seed() calls",
    tags=(torch.Tag.nondeterministic_seeded,),
)
lookup_seed = make_prim(
    # if inductor_lookup_seed changes, update partitioners.py
    "inductor_lookup_seed(Tensor seeds, int index) -> Tensor",
    lambda seeds, index: seeds[index],
    doc="Extract a single seed from the result of inductor_seeds()",
)
# `mode` selects the torch factory by name (e.g. "rand" or "randn").
random = make_prim(
    "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor",
    lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device),
    doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused",
)
randint = make_prim(
    "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor",
    lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device),
    doc="torch.randint() using backend-specific RNG that can be fused",
)
force_stride_order = make_prim(
    "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor",
    eager_force_stride,
    doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise",
)
# Eager fallback ignores the precomputed index and defers to masked_scatter.
masked_scatter_with_index = make_prim(
    "inductor_masked_scatter_with_index(Tensor input, Tensor mask, Tensor source_idx, Tensor source) -> Tensor",
    lambda input_tensor, mask, index, source: torch.masked_scatter(
        input_tensor, mask, source
    ),
    doc="masked_scatter with precomputed indices",
)
_unsafe_index_put_ = make_prim(
    "_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)",
    lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_(
        self, indices, values, accumulate
    ),
    doc="Unsafe index_put_ (doesn't issue device asserts)",
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import tempfile
|
| 3 |
+
import unittest
|
| 4 |
+
|
| 5 |
+
from torch._dynamo.test_case import (
|
| 6 |
+
run_tests as dynamo_run_tests,
|
| 7 |
+
TestCase as DynamoTestCase,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
from torch._inductor import config
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def run_tests(needs=()):
    """Run inductor tests by delegating to the dynamo test runner."""
    dynamo_run_tests(needs)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestCase(DynamoTestCase):
    """
    A base TestCase for inductor tests. Enables FX graph caching and isolates
    the cache directory for each test.
    """

    # ExitStack holding the class-scoped config patch.
    _stack: contextlib.ExitStack

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(config.patch({"fx_graph_cache": True}))

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        cls._stack.close()

    def setUp(self):
        super().setUp()

        # FIX: `unittest.mock` is a submodule; `import unittest` alone does not
        # guarantee it is reachable as an attribute of `unittest`. Import it
        # explicitly instead of relying on another module having done so.
        from unittest import mock

        # For all tests, mock the tmp directory populated by the inductor
        # FxGraphCache, both for test isolation and to avoid filling disk.
        self._inductor_cache_tmp_dir = tempfile.TemporaryDirectory()
        self._inductor_cache_get_tmp_dir_patch = mock.patch(
            "torch._inductor.codecache.FxGraphCache._get_tmp_dir"
        )
        mock_get_dir = self._inductor_cache_get_tmp_dir_patch.start()
        mock_get_dir.return_value = self._inductor_cache_tmp_dir.name

    def tearDown(self):
        super().tearDown()

        # Clean up the FxGraphCache tmp dir.
        self._inductor_cache_get_tmp_dir_patch.stop()
        self._inductor_cache_tmp_dir.cleanup()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (80.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/jiterator.cpython-311.pyc
ADDED
|
Binary file (7.99 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nccl.cpython-311.pyc
ADDED
|
Binary file (6.46 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/random.cpython-311.pyc
ADDED
|
Binary file (8.62 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/streams.cpython-311.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py
ADDED
|
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import io
|
| 5 |
+
import subprocess
|
| 6 |
+
import json
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
from typing import Any
|
| 9 |
+
from itertools import groupby
|
| 10 |
+
import base64
|
| 11 |
+
import warnings
|
| 12 |
+
|
| 13 |
+
cache = lru_cache(None)
|
| 14 |
+
|
| 15 |
+
__all__ = ["format_flamegraph", "segments", "memory", "compare"]
|
| 16 |
+
|
| 17 |
+
def _frame_fmt(f, full_filename=False):
    """Format one frame dict as 'filename:line:function'.

    By default only the basename of the file is shown; pass
    full_filename=True to keep the whole path.
    """
    path = f['filename']
    if not full_filename:
        path = path.split('/')[-1]
    return f"{path}:{f['line']}:{f['name']}"
|
| 24 |
+
|
| 25 |
+
@cache
def _frame_filter(name, filename):
    """Return False for interpreter/allocator-internal frames to hide them.

    Memoized per (name, filename) pair via the module-level `cache`.
    """
    omit_functions = (
        "unwind::unwind",
        "CapturedTraceback::gather",
        "gather_with_cpp",
        "_start",
        "__libc_start_main",
        "PyEval_",
        "PyObject_",
        "PyFunction_",
    )
    omit_filenames = (
        "core/boxing",
        "/Register",
        "/Redispatch",
        "pythonrun.c",
        "Modules/main.c",
        "Objects/call.c",
        "Objects/methodobject.c",
        "pycore_ceval.h",
        "ceval.c",
        "cpython/abstract.h",
    )
    if any(token in name for token in omit_functions):
        return False
    if any(token in filename for token in omit_filenames):
        return False
    return True
|
| 56 |
+
|
| 57 |
+
def _frames_fmt(frames, full_filename=False, reverse=False):
    """Format a stack of frame dicts, dropping internal frames.

    reverse=True walks the stack outermost-first (flamegraph order).
    """
    ordered = reversed(frames) if reverse else frames
    return [
        _frame_fmt(fr, full_filename)
        for fr in ordered
        if _frame_filter(fr['name'], fr['filename'])
    ]
|
| 61 |
+
|
| 62 |
+
def _block_extra_legacy(b):
    """Extract (frames, real_size) from an old-format block dict.

    Old snapshots carry per-allocation info inside a 'history' list; blocks
    without history fall back to requested_size (or size) with no frames.
    """
    if 'history' in b:
        first_event = b['history'][0]
        return first_event.get('frames', []), first_event['real_size']
    return [], b.get('requested_size', b['size'])
|
| 70 |
+
|
| 71 |
+
def _block_extra(b):
    """Return (frames, requested_size) for a block in either snapshot format."""
    if 'frames' in b:
        return b['frames'], b['requested_size']
    # old snapshot format made it more complicated to get frames/allocated size
    return _block_extra_legacy(b)
|
| 76 |
+
|
| 77 |
+
def format_flamegraph(flamegraph_lines, flamegraph_script=None):
    """Pipe flamegraph input lines through flamegraph.pl and return its output.

    Downloads flamegraph.pl to a per-user path under /tmp on first use when no
    script path is provided.  Raises AssertionError if the script exits
    non-zero.  NOTE: fetches an executable script over the network — callers
    in restricted environments should pass flamegraph_script explicitly.
    """
    if flamegraph_script is None:
        flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl'
    if not os.path.exists(flamegraph_script):
        import urllib.request
        print(f"Downloading flamegraph.pl to: {flamegraph_script}")
        urllib.request.urlretrieve(
            'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script)
        subprocess.check_call(['chmod', '+x', flamegraph_script])
    args = [flamegraph_script, '--countname', 'bytes']
    p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8')
    assert p.stdin is not None
    assert p.stdout is not None
    p.stdin.write(flamegraph_lines)
    # Close stdin so the script sees EOF and produces its output.
    p.stdin.close()
    result = p.stdout.read()
    p.stdout.close()
    # FIX: the original called p.wait() twice back to back; one call suffices
    # (the return code is cached by Popen after the first wait).
    assert p.wait() == 0
    return result
|
| 97 |
+
|
| 98 |
+
def _write_blocks(f, prefix, blocks):
    """Write one flamegraph line per allocation in *blocks* to stream *f*.

    Lines are 'prefix;state;stack size'.  Legacy blocks (with 'history')
    emit one line per history event plus a '<gaps>' line for unaccounted
    bytes.
    """
    def frames_fragment(frames):
        if not frames:
            return "<non-python>"
        return ';'.join(_frames_fmt(frames, reverse=True))

    for b in blocks:
        if 'history' not in b:
            frames, accounted_for_size = _block_extra(b)
            f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n')
            continue
        # Legacy format: sum per-event sizes, then report the remainder.
        accounted_for_size = 0
        for h in b['history']:
            sz = h['real_size']
            accounted_for_size += sz
            if 'frames' in h:
                frames = h['frames']
                f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n')
            else:
                f.write(f'{prefix};{b["state"]};<no-context> {sz}\n')
        gaps = b['size'] - accounted_for_size
        if gaps:
            f.write(f'{prefix};{b["state"]};<gaps> {gaps}\n')
|
| 120 |
+
|
| 121 |
+
def segments(snapshot, format_flamegraph=format_flamegraph):
    """Flamegraph of all blocks, grouped by stream and segment address."""
    buf = io.StringIO()
    for seg in snapshot['segments']:
        prefix = f'stream_{seg["stream"]};seg_{seg["address"]}'
        _write_blocks(buf, prefix, seg['blocks'])
    return format_flamegraph(buf.getvalue())
|
| 127 |
+
|
| 128 |
+
def memory(snapshot, format_flamegraph=format_flamegraph):
    """Flamegraph of all blocks, grouped only by stream."""
    buf = io.StringIO()
    for seg in snapshot['segments']:
        _write_blocks(buf, f'stream_{seg["stream"]}', seg['blocks'])
    return format_flamegraph(buf.getvalue())
|
| 134 |
+
|
| 135 |
+
def compare(before, after, format_flamegraph=format_flamegraph):
    """Flamegraph of segments that exist in only one of two snapshots.

    Segments are matched by (address, total_size); matched segments are
    omitted.  Also prints the addresses unique to each side.
    """
    def _seg_key(seg):
        return (seg['address'], seg['total_size'])

    def _seg_info(seg):
        return f'stream_{seg["stream"]};seg_{seg["address"]}'

    buf = io.StringIO()

    before_segs = {_seg_key(seg) for seg in before}
    after_segs = {_seg_key(seg) for seg in after}

    print(f'only_before = {[a for a,_ in (before_segs - after_segs)]}')
    print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}')

    for seg in before:
        if _seg_key(seg) not in after_segs:
            _write_blocks(buf, f'only_before;{_seg_info(seg)}', seg['blocks'])

    for seg in after:
        if _seg_key(seg) not in before_segs:
            _write_blocks(buf, f'only_after;{_seg_info(seg)}', seg['blocks'])

    return format_flamegraph(buf.getvalue())
|
| 159 |
+
|
| 160 |
+
def _format_size(num):
    """Render a byte count with binary prefixes, e.g. 1536 -> '1.5KiB'."""
    # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
    units = ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi")
    for unit in units:
        if abs(num) < 1024.0:
            return f"{num:3.1f}{unit}B"
        num /= 1024.0
    return f"{num:.1f}YiB"
|
| 167 |
+
|
| 168 |
+
class Bytes:
    """Wrapper around an integer byte count that repr()s as a human-readable size."""

    def __init__(self, value):
        # Raw byte count.
        self.value = value

    def __add__(self, rhs):
        # rhs is a plain number (added to .value), not another Bytes.
        return Bytes(self.value + rhs)

    def __repr__(self):
        return _format_size(self.value)
|
| 177 |
+
|
| 178 |
+
def calc_active(seg):
    """Total bytes in *seg* held by blocks in the 'active_allocated' state."""
    total = 0
    for blk in seg['blocks']:
        if blk['state'] == 'active_allocated':
            total += blk['size']
    return total
|
| 180 |
+
|
| 181 |
+
def _report_free(free_external, free_internal):
    """Render total free bytes, annotating the internal-fragmentation share."""
    total = free_external + free_internal
    if total == 0:
        return f'{Bytes(total)}'
    pct = (free_internal / total) * 100
    return f'{Bytes(total)} ({pct:.1f}% internal)'
|
| 188 |
+
|
| 189 |
+
PAGE_SIZE = 1024 * 1024 * 20
|
| 190 |
+
legend = f"""\
|
| 191 |
+
|
| 192 |
+
Legend:
|
| 193 |
+
[a ] - a segment in the allocator
|
| 194 |
+
^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
|
| 195 |
+
a-z: pages filled with a single block's content
|
| 196 |
+
' ': page is completely free
|
| 197 |
+
*: page if completely full with multiple blocks
|
| 198 |
+
0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
|
| 199 |
+
(X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
def segsum(data):
    r"""Visually reports how the allocator has filled its segments.

    This printout can help debug fragmentation issues since free fragments
    will appear as gaps in this printout. The amount of free space is reported
    for each segment.
    We distinguish between internal free memory which occurs because the
    allocator rounds the allocation size, and external free memory, which are
    the gaps between allocations in a segment.
    Args:
        data: snapshot dictionary created from _snapshot()
    """
    out = io.StringIO()
    out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
    total_reserved = 0
    total_allocated = 0
    free_external = 0
    free_internal = 0
    for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))):
        total_reserved += seg['total_size']

        # Per-segment accounting: walk the blocks in address order,
        # classifying each as allocated / internal-free / external-free.
        seg_free_external = 0
        seg_free_internal = 0
        seg_allocated = 0
        all_ranges = []
        boffset = 0
        for b in seg['blocks']:
            active = b['state'] == 'active_allocated'
            if active:
                _, allocated_size = _block_extra(b)
                all_ranges.append((boffset, allocated_size, True))
                seg_allocated += allocated_size
                # Rounding slack inside an active block is "internal" free.
                seg_free_internal += b['size'] - allocated_size
            else:
                seg_free_external += b['size']

            boffset += b['size']

        total_allocated += seg_allocated
        free_external += seg_free_external
        free_internal += seg_free_internal

        # Paint a one-character-per-page picture of the segment.
        nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1
        occupied = [' ' for _ in range(nseg)]
        frac = [0.0 for _ in range(nseg)]
        # FIX: removed dead code from the original: an unused `segments = []`
        # local, an unused `active_size` accumulator, an unused
        # `internal_external` string, and a first `stream = ...` assignment
        # that was immediately overwritten before use.
        for i, (start_, size, active) in enumerate(all_ranges):
            finish_ = (start_ + size)
            start = start_ // PAGE_SIZE
            finish = (finish_ - 1) // PAGE_SIZE + 1
            # Letter per block, lowercase for active; cycles through a-z.
            m = chr(ord('a' if active else 'A') + (i % 26))
            for j in range(start, finish):
                s = max(start_, j * PAGE_SIZE)
                e = min(finish_, (j + 1) * PAGE_SIZE)
                frac[j] += (e - s) / PAGE_SIZE
                if occupied[j] != ' ':
                    # Multiple blocks share this page: show fill fraction.
                    occupied[j] = '0123456789*'[int(frac[j] * 10)]
                else:
                    occupied[j] = m
        body = ''.join(occupied)
        assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size']
        stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else ''
        if seg['total_size'] >= PAGE_SIZE:
            out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, '
                      f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n')
    out.write(f'segments: {len(data["segments"])}\n')
    out.write(f'total_reserved: {Bytes(total_reserved)}\n')
    out.write(f'total_allocated: {Bytes(total_allocated)}\n')
    out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
    out.write(legend)
    assert free_internal + free_external + total_allocated == total_reserved
    return out.getvalue()
|
| 278 |
+
|
| 279 |
+
def trace(data):
    """Render the allocator trace events embedded in *data* in a Pythonic style.

    Args:
        data: Memory snapshot as generated from torch.cuda.memory._snapshot()

    Returns:
        str: one line per trace event, grouped per device
    """
    out = io.StringIO()

    def format(entries):
        segment_intervals : list = []
        segment_addr_to_name = {}
        allocation_addr_to_name = {}

        # short names ('a'..'z', then 'a1', 'b1', ...) are recycled after free
        free_names : list = []
        next_name = 0

        def _name():
            nonlocal next_name
            if free_names:
                return free_names.pop()
            r, m = next_name // 26, next_name % 26
            next_name += 1
            return f'{chr(ord("a") + m)}{"" if r == 0 else r}'

        def find_segment(addr):
            # first the segments observed live via segment_alloc events,
            # then fall back to the snapshot's segment list
            for name, saddr, size in segment_intervals:
                if addr >= saddr and addr < saddr + size:
                    return name, saddr
            for i, seg in enumerate(data['segments']):
                saddr = seg['address']
                size = seg['allocated_size']
                if addr >= saddr and addr < saddr + size:
                    return f'seg_{i}', saddr
            return None, None

        out.write(f'{len(entries)} entries\n')

        # BUG FIX: the running byte total previously shared the variable
        # `count` with the enumerate() index, so every iteration clobbered
        # the total and the final "TOTAL MEM" line was meaningless.
        total_mem = 0
        for count, e in enumerate(entries):
            if e['action'] == 'alloc':
                addr, size = e['addr'], e['size']
                n = _name()
                seg_name, seg_addr = find_segment(addr)
                if seg_name is None:
                    seg_name = "MEM"
                    offset = addr
                else:
                    offset = addr - seg_addr
                out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n')
                # remember name, size and the trace index of the allocation
                allocation_addr_to_name[addr] = (n, size, count)
                total_mem += size
            elif e['action'] == 'free_requested':
                addr, size = e['addr'], e['size']
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'del {name} # {Bytes(size)}\n')
            elif e['action'] == 'free_completed':
                addr, size = e['addr'], e['size']
                total_mem -= size
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'# free completed for {name} {Bytes(size)}\n')
                # BUG FIX: the table is keyed by address, not by name; the old
                # `name in allocation_addr_to_name` check never fired, so
                # entries were never removed and names never recycled.
                if addr in allocation_addr_to_name:
                    free_names.append(name)
                    del allocation_addr_to_name[addr]
            elif e['action'] == 'segment_alloc':
                addr, size = e['addr'], e['size']
                name = _name()
                out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n')
                segment_intervals.append((name, addr, size))
                segment_addr_to_name[addr] = name
            elif e['action'] == 'segment_free':
                addr, size = e['addr'], e['size']
                name = segment_addr_to_name.get(addr, addr)
                out.write(f'cudaFree({name}) # {Bytes(size)}\n')
                # BUG FIX: keyed by address (same issue as free_completed)
                if addr in segment_addr_to_name:
                    free_names.append(name)
                    del segment_addr_to_name[addr]
            elif e['action'] == 'oom':
                size = e['size']
                free = e['device_free']
                out.write(f'raise OutOfMemoryError() # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
            else:
                out.write(f'{e}\n')
        out.write(f"TOTAL MEM: {Bytes(total_mem)}")

    for i, d in enumerate(data['device_traces']):
        if d:
            out.write(f'Device {i} ----------------\n')
            format(d)
    return out.getvalue()

# HTML shell for the interactive memory visualizations.  The rendering is done
# client-side by MemoryViz.js (fetched from a CDN); `$SNAPSHOT` is replaced
# with the JSON-wrapped, base64-encoded pickled snapshot and `$VIZ_KIND` with
# the plot-type string (see `_format_viz`).
_memory_viz_template = r"""
<!DOCTYPE html>
<html>
<head>
</head>
<body>
<script type="module">
import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
const local_files = $SNAPSHOT
add_local_files(local_files, $VIZ_KIND)
</script>
</body>
"""

def _format_viz(data, viz_kind, device):
    """Embed *data* into the standalone HTML viewer page.

    Args:
        data: snapshot (or other pickleable payload) to embed in the page
        viz_kind (str): plot-type string passed through to MemoryViz.js
        device: deprecated; plots now contain all devices

    Returns:
        str: self-contained HTML document
    """
    if device is not None:
        warnings.warn('device argument is deprecated, plots now contain all device')
    buffer = pickle.dumps(data)
    # Pad to a multiple of 3 so the base64 output needs no '=' padding.
    # BUG FIX: the old `3 - len(buffer) % 3` appended three useless zero bytes
    # when the length was already a multiple of 3.
    buffer += b'\x00' * (-len(buffer) % 3)
    # Encode the buffer with base64
    encoded_buffer = base64.b64encode(buffer).decode('utf-8')

    json_format = json.dumps([{"name": 'snapshot.pickle', "base64": encoded_buffer}])
    return _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind)) \
                               .replace('$SNAPSHOT', json_format)

def trace_plot(data, device=None, plot_segments=False):
    """Generate a visualization over time of the memory usage recorded by the trace as an html file.

    Args:
        data: Memory snapshot as generated from torch.cuda.memory._snapshot()
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
        plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
            Defaults to False.

    Returns:
        str: HTML of visualization
    """
    # choose the plot type understood by the JS viewer
    if plot_segments:
        viz_kind = 'Active Cached Memory Timeline'
    else:
        viz_kind = 'Active Memory Timeline'
    return _format_viz(data, viz_kind, device)

def _profile_to_snapshot(profile):
    """Convert a kineto memory profile into the `_snapshot()`-style dict the viewers expect.

    Builds one synthetic "segment" per CUDA device (plus one extra slot for
    non-CUDA allocations) and replays the profiler's memory timeline into
    per-device trace events and a final block layout.
    """
    import torch
    from torch.profiler._memory_profiler import Action, TensorKey
    from torch._C._profiler import _EventType
    memory_profile = profile._memory_profile()

    # Map each allocated tensor to the Python call stack active at allocation
    # time (only Py/PyC call frames are kept).
    allocation_stacks = {}
    for event in memory_profile._op_tree.sorted_nodes:
        if event.tag == _EventType.Allocation:
            parent = event.parent
            python_parents = []
            while parent:
                if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
                    python_parents.append(parent)
                parent = parent.parent
            key = TensorKey.from_allocation(event.extra_fields)

            # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
            # key will be None. I should add some way to identify these, I just haven't yet.
            if key and event.extra_fields.alloc_size > 0:
                allocation_stacks[key] = python_parents

    # One trace list / segment per CUDA device, plus index `device_count`
    # for allocations on non-CUDA devices.
    device_count = torch.cuda.device_count()
    snapshot = {
        'device_traces': [[] for _ in range(device_count + 1)],
        'segments': [{'device': device,
                      'address': None,
                      'total_size': 0,
                      'stream': 0,
                      'blocks': []} for device in range(device_count + 1)]
    }

    def to_device(device):
        # CUDA devices keep their index; everything else shares the last slot.
        if device.type == 'cuda':
            return device.index
        else:
            return device_count

    def allocate(size, tensor_key, version, during_trace=True):
        # Record an 'alloc' event and grow the synthetic segment to cover it.
        device = to_device(tensor_key.device)
        addr = tensor_key.storage.ptr

        seg = snapshot['segments'][device]  # type: ignore[index]
        if seg['address'] is None or seg['address'] > addr:
            seg['address'] = addr
        seg['total_size'] = max(seg['total_size'], addr + size)  # record max addr for now, we will make it the size later
        category = memory_profile._categories.get(tensor_key, version)
        category = category.name.lower() if category is not None else "unknown"
        stack = allocation_stacks.get(tensor_key, ())
        # real file/line info is not available here; only the frame name is kept
        stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
        r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
        if during_trace:
            snapshot['device_traces'][device].append(r)  # type: ignore[index]
        return r

    def free(alloc, device):
        # Emit the request/complete pair the snapshot format uses for frees.
        for e in ('free_requested', 'free_completed'):
            snapshot['device_traces'][device].append({'action': e,  # type: ignore[index]
                                                      'addr': alloc['addr'],
                                                      'size': alloc['size'],
                                                      'stream': 0,
                                                      'frames': alloc['frames']})

    # live allocations, keyed by (tensor_key, version)
    kv_to_elem = {}

    # create the device trace
    for time, action, (tensor_key, version), size in memory_profile.timeline:
        if not isinstance(tensor_key, TensorKey):
            continue
        if action == Action.CREATE:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
        elif action == Action.DESTROY:
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
        elif action == Action.INCREMENT_VERSION:
            # a version bump is modeled as free-then-realloc of the storage
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
            kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1)
        elif action == Action.PREEXISTING:
            # existed before tracing started: counted live, but not in the trace
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False)

    # create the final snapshot state
    blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
                     for (tensor_key, version), event in kv_to_elem.items()]
    for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]):
        seg = snapshot['segments'][device]  # type: ignore[index]
        last_addr = seg['address']
        for _, addr, size, frames in blocks:
            # fill address gaps between live blocks with 'inactive' filler
            if last_addr < addr:
                seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'})
            seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames})
            last_addr = addr + size
        if last_addr < seg['total_size']:
            seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})

    # drop devices that saw no allocations at all
    snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]  # type: ignore[attr-defined]
    for seg in snapshot['segments']:  # type: ignore[attr-defined, name-defined, no-redef]
        # convert the recorded max address into an actual size
        seg['total_size'] -= seg['address']
        # NOTE(review): this branch looks unreachable — empty-block segments
        # were filtered out just above; confirm before relying on it.
        if not seg['blocks']:
            seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})

    return snapshot

def profile_plot(profile, device=None):
    """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.

    Args:
        profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.

    Returns:
        str: HTML of visualization
    """
    # convert the kineto profile into snapshot form, then render it
    return _format_viz(_profile_to_snapshot(profile), 'Active Memory Timeline', device)

def segment_plot(data: Any, device=None):
    """Generate a visualization of allocator segment state over time as an html file.

    Args:
        data: Memory snapshot as generated from torch.cuda.memory._snapshot()
        device (torch.device, optional): deprecated; plots now contain all devices

    Returns:
        str: HTML of visualization
    """
    return _format_viz(data, 'Allocator State History', device)

if __name__ == "__main__":
    import os.path
    thedir = os.path.realpath(os.path.dirname(__file__))
    if thedir in sys.path:
        # otherwise we find cuda/random.py as random...
        sys.path.remove(thedir)
    import argparse

    fn_name = 'torch.cuda.memory._snapshot()'
    pickled = f'pickled memory statistics from {fn_name}'
    parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}')

    subparsers = parser.add_subparsers(dest='action')

    def _output(p):
        # shared -o/--output option for the flamegraph subcommands
        p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)')

    description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.'
    stats_a = subparsers.add_parser('stats', description=description)
    stats_a.add_argument('input', help=pickled)

    description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.'
    trace_a = subparsers.add_parser('trace', description=description)
    trace_a.add_argument('input', help=pickled)

    description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)'
    segments_a = subparsers.add_parser('segments', description=description)
    segments_a.add_argument('input', help=pickled)
    _output(segments_a)

    description = "Generate a flamegraph the program locations contributing to CUDA memory usage."
    memory_a = subparsers.add_parser('memory', description=description)
    memory_a.add_argument('input', help=pickled)
    _output(memory_a)

    description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \
        'or removed between two different memorys snapshots.'
    compare_a = subparsers.add_parser('compare', description=description)
    compare_a.add_argument('before', help=pickled)
    compare_a.add_argument('after', help=pickled)
    _output(compare_a)

    plots = (
        ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."),
        ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.")
    )
    for cmd, description in plots:
        trace_plot_a = subparsers.add_parser(cmd, description=description)
        trace_plot_a.add_argument('input', help=pickled)
        help = 'visualize trace from this device (default: chooses the only device with trace info or errors)'
        trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help)
        help = 'path to save the visualization(default: output.html)'
        trace_plot_a.add_argument('-o', '--output', default='output.html', help=help)
        if cmd == "trace_plot":
            help = 'visualize change to segments rather than individual allocations'
            trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help)

    args = parser.parse_args()

    def _read(name):
        # Load a pickled snapshot from a file, or from stdin when name is '-'.
        # SECURITY NOTE: pickle.load can execute arbitrary code; only feed this
        # tool snapshots from a trusted source (i.e. ones you produced).
        if name == '-':
            data = pickle.load(sys.stdin.buffer)
        else:
            # BUG FIX: the file handle used to be opened without ever being
            # closed; use a context manager so it is released promptly.
            with open(name, 'rb') as f:
                data = pickle.load(f)
        if isinstance(data, list):  # segments only...
            data = {'segments': data, 'traces': []}
        return data

    def _write(name, data):
        with open(name, 'w') as f:
            f.write(data)

    if args.action == 'segments':
        data = _read(args.input)
        _write(args.output, segments(data))
    elif args.action == 'memory':
        data = _read(args.input)
        _write(args.output, memory(data))
    elif args.action == 'stats':
        data = _read(args.input)
        print(segsum(data))
    elif args.action == 'trace':
        data = _read(args.input)
        print(trace(data))
    elif args.action == 'compare':
        before = _read(args.before)
        after = _read(args.after)
        _write(args.output, compare(before, after))
    elif args.action == 'trace_plot':
        data = _read(args.input)
        _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments))
    elif args.action == 'segment_plot':
        data = _read(args.input)
        _write(args.output, segment_plot(data, device=args.device))