koichi12 commited on
Commit
ee1d2ef
·
verified ·
1 Parent(s): 466ab75

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 +3 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py +124 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py +1755 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc +0 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +212 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py +45 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py +242 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py +360 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +186 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py +18 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py +75 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py +130 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py +1204 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py +1100 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py +182 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py +202 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py +186 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py +114 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py +1537 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py +90 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py +0 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py +53 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/__init__.cpython-311.pyc +0 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/jiterator.cpython-311.pyc +0 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nccl.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/random.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/streams.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py +626 -0
.gitattributes CHANGED
@@ -74,3 +74,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/V
74
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
75
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text
76
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
74
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
75
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text
76
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
77
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:647373d0020a53c70bd44d2950f81f6c5edec206899855800a76aabe1ae27e02
3
+ size 745240
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-311.pyc ADDED
Binary file (29.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/bounds.cpython-311.pyc ADDED
Binary file (7.75 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/decomposition.cpython-311.pyc ADDED
Binary file (34.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-311.pyc ADDED
Binary file (12.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-311.pyc ADDED
Binary file (35.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-311.pyc ADDED
Binary file (86.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-311.pyc ADDED
Binary file (64.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/utils.cpython-311.pyc ADDED
Binary file (79.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/bounds.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+ from functools import partial
3
+ from typing import Any, Callable, Dict
4
+
5
+ from sympy import Expr
6
+
7
+ import torch
8
+ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
9
+ from .ir import InterpreterShim, LoopBody, LoopBodyBlock
10
+ from .utils import cache_on_self, dominated_nodes
11
+ from .virtualized import V
12
+
13
+
14
+ class BoundVars:
15
+ """
16
+ Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run()
17
+ It exposes the ranges of the nodes in the `bounds` variable
18
+
19
+ Note. A current limitation of this analysis is that it just works on a per-loop basis.
20
+ We should be able to propagate the bounds between across the whole graph. This may benefit
21
+ the case a bounded variable is returned by a kernel and fed into another.
22
+ """
23
+
24
+ def __init__(self, loop_body: LoopBody) -> None:
25
+ self.loop_body = loop_body
26
+ self.replacement_vals = {
27
+ k: ValueRanges[Expr](0, v - 1)
28
+ if (isinstance(v, int) or v.is_number)
29
+ else bound_sympy(v)
30
+ for k, v in loop_body.var_ranges.items()
31
+ }
32
+ # avoid computing these values, pessimistically assume that they are unbounded
33
+ self.unbounded_vars = dominated_nodes(
34
+ node
35
+ for node in self.loop_body.get_nodes()
36
+ if node.target in ["load", "reduction", operator.getitem]
37
+ or "masked_subblock" in node.target
38
+ )
39
+ # To access this variable call `get_bounds()`
40
+ self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}
41
+
42
+ @cache_on_self
43
+ def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
44
+ submodules = self.swap_submodules(self.loop_body.submodules)
45
+
46
+ # Initialize the environment with the unbounded variables
47
+ for node in self.unbounded_vars:
48
+ # we need to evaluate masked_subblock to recurse, and we need to set indirect values
49
+ if not isinstance(node.target, str) or (
50
+ "masked_subblock" not in node.target
51
+ and "set_indirect" not in node.target
52
+ ):
53
+ self._bounds[node] = ValueRanges[Expr].unknown()
54
+
55
+ with V.set_ops_handler(ValueRangeAnalysis()):
56
+ interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
57
+ interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
58
+ return self._bounds
59
+
60
+ def swap_submodules(
61
+ self, submodules: Dict[str, Callable[..., Any]]
62
+ ) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
63
+ result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
64
+ for key in submodules.keys():
65
+ if key == "get_index":
66
+ result[key] = self.get_index
67
+ elif "masked_subblock" in key:
68
+ subblock = self.loop_body.subblocks[key]
69
+ # The result within the lambda will reference to the final
70
+ # set of modules at the end of the for-loop as it stores a reference to it
71
+
72
+ # bind subblock in a function because python lambdas close over by reference
73
+ # moving the lambda out of make_fn would close over the reference to subblock,
74
+ # so all lambdas would have the same subblock reference that is the final
75
+ # subblock in the loop
76
+ def make_fn(subblock):
77
+ return lambda mask, value: self.masked_subblock(
78
+ subblock, self._bounds, mask, value, result
79
+ )
80
+
81
+ result[key] = make_fn(subblock)
82
+
83
+ elif "set_indirect" in key:
84
+ idx = int(key[len("set_indirect") :])
85
+ var = self.loop_body.indirect_vars[idx]
86
+ indirect = partial(self.set_indirect, var)
87
+ result[key] = indirect
88
+ else:
89
+ assert "scan" in key
90
+ result[key] = submodules[key]
91
+
92
+ return result
93
+
94
+ def masked_subblock(
95
+ self,
96
+ subblock: LoopBodyBlock,
97
+ env: Dict[torch.fx.Node, ValueRanges[Expr]],
98
+ mask: Any,
99
+ value: Any,
100
+ submodules: Dict[str, Callable[..., Any]],
101
+ ) -> ValueRanges[Expr]:
102
+ interp = InterpreterShim(subblock.graph, submodules)
103
+ interp.run(V.get_ops_handler(), initial_env=env)
104
+ output = [node for node in subblock.graph.nodes if node.target == "output"]
105
+ assert len(output) == 1
106
+ # dont bother unioning with value since the load from buffer will be
107
+ # pessimistically assumed to be inf anyway
108
+ return interp.env[output[0]]
109
+
110
+ def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
111
+ assert isinstance(new, ValueRanges)
112
+ self.replacement_vals[old] = new
113
+ return new
114
+
115
+ def get_index(self, name: Expr) -> ValueRanges[Expr]:
116
+ expr = self.loop_body.indexing_exprs[name]
117
+ bound = self.replacement_vals.get(expr)
118
+ if bound is None:
119
+ bound = bound_sympy(expr, self.replacement_vals)
120
+ # The following assertion is true at the time of this writing
121
+ # We don't assert is as to not execute bound_sympy when bound is not None
122
+ # assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
123
+ self.replacement_vals[name] = bound
124
+ return bound
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (224 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-311.pyc ADDED
Binary file (46.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-311.pyc ADDED
Binary file (22.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-311.pyc ADDED
Binary file (9.21 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-311.pyc ADDED
Binary file (94.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/common.py ADDED
@@ -0,0 +1,1755 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import dataclasses
3
+ import functools
4
+ import itertools
5
+ import logging
6
+ import operator
7
+ import re
8
+ from itertools import chain
9
+ from typing import (
10
+ Any,
11
+ Callable,
12
+ ClassVar,
13
+ Dict,
14
+ List,
15
+ NamedTuple,
16
+ Optional,
17
+ Set,
18
+ Tuple,
19
+ TYPE_CHECKING,
20
+ Union,
21
+ )
22
+
23
+ import sympy
24
+ from sympy.printing.printer import Printer
25
+
26
+ import torch
27
+ import torch.fx
28
+ from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
29
+ from torch.utils import _pytree as pytree
30
+ from torch.utils._sympy.value_ranges import ValueRanges
31
+
32
+ from .. import config, metrics
33
+ from ..utils import (
34
+ DeferredLineBase,
35
+ do_bench,
36
+ free_symbol_startswith,
37
+ IndentedBuffer,
38
+ sympy_dot,
39
+ sympy_index_symbol,
40
+ sympy_subs,
41
+ unique,
42
+ )
43
+ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
44
+
45
+ if TYPE_CHECKING:
46
+ from ..ir import TensorBox
47
+
48
+ schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
49
+
50
+
51
+ def data_type_logger(msg):
52
+ if schedule_log.isEnabledFor(logging.DEBUG):
53
+ schedule_log.debug("Data type propagation: %s", msg)
54
+
55
+
56
+ @dataclasses.dataclass
57
+ class WorkspaceArg:
58
+ """A temporary buffer used for a single kernel, then discarded.
59
+
60
+ Not registered as a traditional buffer since there are no users,
61
+ so it would be dead code eliminated.
62
+ """
63
+
64
+ nbytes: sympy.Expr
65
+ zero_fill: bool
66
+
67
+
68
+ @dataclasses.dataclass
69
+ class TensorArg:
70
+ name: str
71
+ buffer: str
72
+ dtype: torch.dtype
73
+ offset: sympy.Expr = sympy.Integer(0)
74
+
75
+
76
+ @dataclasses.dataclass
77
+ class SizeArg:
78
+ name: str
79
+ expr: sympy.Expr
80
+
81
+
82
+ @dataclasses.dataclass
83
+ class DeviceCodegen:
84
+ scheduling: type
85
+ wrapper_codegen: type
86
+
87
+
88
+ KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]
89
+
90
+ device_codegens: Dict[str, DeviceCodegen] = {}
91
+
92
+
93
+ class DeviceOpOverrides:
94
+ def import_get_raw_stream_as(self, name):
95
+ raise NotImplementedError()
96
+
97
+ def set_device(self, device_idx):
98
+ raise NotImplementedError()
99
+
100
+ def synchronize(self):
101
+ raise NotImplementedError()
102
+
103
+ def device_guard(self, device_idx):
104
+ raise NotImplementedError()
105
+
106
+
107
+ device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
108
+
109
+
110
+ # The code generated by Inductor consists of two main parts: kernel code and wrapper code.
111
+ # For any new backend looking to integrate with Inductor, customization of these two main
112
+ # parts are necessary to generate its specific code.
113
+ #
114
+ # Kernel code generation is determined by different Scheduling. Consequently, a new
115
+ # backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
116
+ # CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
117
+ #
118
+ # For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
119
+ # that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
120
+ # and override specific member functions to create backend-specific Python wrapper code.
121
+ #
122
+ # Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
123
+ # of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
124
+ # provide flexibility to the backend. A backend can choose to implement these classes from scratch,
125
+ # or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
126
+ # register_backend_for_device, to equip a new backend at runtime.
127
+ #
128
+ # Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
129
+ # This backend can be used as a reference:
130
+ # https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
131
+ def register_backend_for_device(
132
+ device: str, device_scheduling: type, device_wrapper_codegen: type
133
+ ):
134
+ device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen)
135
+
136
+
137
+ def get_scheduling_for_device(device: str):
138
+ return device_codegens[device].scheduling if device in device_codegens else None
139
+
140
+
141
+ def get_wrapper_codegen_for_device(device: str):
142
+ return (
143
+ device_codegens[device].wrapper_codegen if device in device_codegens else None
144
+ )
145
+
146
+
147
+ def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
148
+ from ..ir import FlexibleLayout
149
+
150
+ # added contiguous index prevents reordering
151
+ return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
152
+
153
+
154
+ def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
155
+ device_op_overrides_dict[device] = device_op_overrides
156
+
157
+
158
+ def get_device_op_overrides(device: str):
159
+ assert isinstance(device, str)
160
+
161
+ if not device_op_overrides_dict.keys():
162
+ from .cuda import device_op_overrides # noqa: F401
163
+
164
+ if device in device_op_overrides_dict.keys():
165
+ return device_op_overrides_dict[device]
166
+
167
+ return DeviceOpOverrides()
168
+
169
+
170
+ @functools.lru_cache(None)
171
+ def boolean_ops():
172
+ return (
173
+ "is_inf",
174
+ "is_nan",
175
+ "bitwise_xor",
176
+ "logical_not",
177
+ "signbit",
178
+ "le",
179
+ "lt",
180
+ "ge",
181
+ "gt",
182
+ "eq",
183
+ "ne",
184
+ )
185
+
186
+
187
+ DTYPE_TO_COMPUTATION_DTYPE = {
188
+ torch.bfloat16: torch.float,
189
+ torch.float16: torch.float,
190
+ **{
191
+ dtype: dtype
192
+ for dtype in [
193
+ torch.bool,
194
+ torch.float32,
195
+ torch.float64,
196
+ torch.int8,
197
+ torch.int16,
198
+ torch.int32,
199
+ torch.int64,
200
+ torch.uint8,
201
+ torch.uint16,
202
+ torch.uint32,
203
+ torch.uint64,
204
+ ]
205
+ },
206
+ }
207
+
208
+
209
+ class DataTypePropagation:
210
+ def __init__(self, body) -> None:
211
+ self.body = body
212
+ self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
213
+ "root": body.root_block.graph
214
+ }
215
+ for k, v in body.subblocks.items():
216
+ self.graphs[k] = v.graph
217
+
218
+ def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
219
+ inputs = node.all_input_nodes
220
+ input_nodes = [
221
+ n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
222
+ ]
223
+ if len(input_nodes) == 0:
224
+ return None
225
+
226
+ all_input_nodes_propogated = all(
227
+ OptimizationContext.key in n.meta
228
+ and n.meta[OptimizationContext.key].dtype is not None
229
+ for n in input_nodes
230
+ )
231
+ if not all_input_nodes_propogated:
232
+ return None
233
+
234
+ return functools.reduce(
235
+ torch.promote_types,
236
+ [n.meta[OptimizationContext.key].dtype for n in input_nodes],
237
+ )
238
+
239
+ def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
240
+ sub_graph = self.graphs[node.target]
241
+ dtype = self.propagate_graph(sub_graph)
242
+ assert dtype
243
+ return dtype
244
+
245
+ def deduce_node_dtype(self, node: torch.fx.Node):
246
+ if node.target in boolean_ops():
247
+ return torch.bool
248
+
249
+ if node.op == "placeholder":
250
+ return None
251
+
252
+ if node.target == "output":
253
+ # we can infer output node if it only have 1 arg
254
+ if len(node.args) != 1:
255
+ return None
256
+
257
+ if node.target in (
258
+ "to_dtype",
259
+ "index_expr",
260
+ ):
261
+ return node.args[-1]
262
+
263
+ if node.target in (
264
+ "rand",
265
+ "randn",
266
+ ):
267
+ return torch.float
268
+
269
+ if node.target in (
270
+ "get_index",
271
+ "index_expr",
272
+ ):
273
+ return torch.int64
274
+
275
+ if node.target in (
276
+ "load",
277
+ "store",
278
+ "store_reduction",
279
+ ):
280
+ buf_name = node.args[1]
281
+ return V.graph.get_dtype(buf_name) # type: ignore[arg-type]
282
+
283
+ if node.target == operator.getitem:
284
+ return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type]
285
+
286
+ assert isinstance(node.target, str)
287
+
288
+ if node.target == "reduction":
289
+ return node.args[1]
290
+
291
+ if node.target == "constant":
292
+ return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]] # type: ignore[index]
293
+
294
+ if node.target.startswith("masked_subblock"):
295
+ return self.deduce_node_dtype_by_subgraph(node)
296
+
297
+ return self.deduce_node_dtype_by_inputs(node)
298
+
299
+ def propagate_graph(self, graph: torch.fx.Graph):
300
+ assert graph.nodes
301
+ graph_dtype = None
302
+ # For masked_subblock, we use output's dtype to represent
303
+ # the dtype of this subgraph. For other cases, graph_dtype
304
+ # might be None
305
+ for node in graph.nodes:
306
+ if OptimizationContext.key in node.meta:
307
+ opt_ctx = node.meta[OptimizationContext.key]
308
+ else:
309
+ opt_ctx = OptimizationContext()
310
+
311
+ opt_ctx.dtype = self.deduce_node_dtype(node)
312
+ node.meta[OptimizationContext.key] = opt_ctx
313
+ if node.target == "output":
314
+ graph_dtype = opt_ctx.dtype
315
+ return graph_dtype
316
+
317
+ def propagate(self):
318
+ self.propagate_graph(self.graphs["root"])
319
+
320
+ @classmethod
321
+ def propagate_loopbody(cls, body):
322
+ return cls(body).propagate()
323
+
324
+ @classmethod
325
+ def propagate_scheduler_node(cls, node):
326
+ from ..ir import LoopBody
327
+ from ..scheduler import SchedulerNode
328
+
329
+ assert isinstance(node, SchedulerNode)
330
+ assert isinstance(node._body, LoopBody)
331
+ DataTypePropagation.propagate_loopbody(node._body)
332
+
333
+
334
+ class ExprPrinter(Printer):
335
+ @staticmethod
336
+ def paren(string):
337
+ def all_in_parens(string):
338
+ if string[0] != "(" or len(string) < 2:
339
+ return False
340
+ count = 1
341
+ for i, char in enumerate(string[1:]):
342
+ if char == "(":
343
+ count += 1
344
+ elif char == ")":
345
+ count -= 1
346
+ if count == 0 and i != len(string) - 2:
347
+ return False
348
+ assert count == 0
349
+ return True
350
+
351
+ if (
352
+ isinstance(string, CSEVariable)
353
+ or re.match(r"^[a-z0-9_.]+$", string, re.I)
354
+ or re.match(r"^\([^)]*\)$", string, re.I)
355
+ or string == ""
356
+ ):
357
+ return string
358
+ # don't put extra parens for strings that are already wrapped in parens
359
+ if all_in_parens(string):
360
+ return string
361
+ return f"({string})"
362
+
363
+ def _print_Infinity(self, expr):
364
+ return "math.inf"
365
+
366
+ def _print_NegativeInfinity(self, expr):
367
+ return "-math.inf"
368
+
369
+ def _print_Relational(self, expr):
370
+ return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
371
+
372
+ def _print_Mul(self, expr):
373
+ return "*".join(map(self.paren, map(self._print, expr.args)))
374
+
375
+ def _print_Add(self, expr):
376
+ return " + ".join(map(self.paren, map(self._print, expr.args)))
377
+
378
+ def _print_Mod(self, expr):
379
+ return " % ".join(map(self.paren, map(self._print, expr.args)))
380
+
381
+ def _print_FloorDiv(self, expr):
382
+ raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
383
+
384
+ def _print_CleanDiv(self, expr):
385
+ return self._print_FloorDiv(expr)
386
+
387
+ def _print_GreaterThan(self, expr):
388
+ # GreaterThan: >=
389
+ # StrictlyGreaterThan: >
390
+ # Go figure...
391
+ return " >= ".join(map(self.paren, map(self._print, expr.args)))
392
+
393
+ def _print_align(self, expr):
394
+ assert len(expr.args) == 1
395
+ return f"align({self._print(expr.args[0])})"
396
+
397
+
398
class PythonPrinter(ExprPrinter):
    """Prints sympy expressions as executable Python using the ``math`` module."""

    def _unary_call(self, fn, expr):
        """Shared helper: render a one-argument call ``fn(arg)``."""
        assert len(expr.args) == 1
        return f"{fn}({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        """ModularIndexing(x, div, mod) -> ``(x // div) % mod`` (div elided when 1)."""
        numerator, denominator, modulus = expr.args
        numerator = self.paren(self.doprint(numerator))
        denominator = self.paren(self.doprint(denominator))
        modulus = self.paren(self.doprint(modulus))
        if denominator != "1":
            numerator = f"({numerator} // {denominator})"
        return f"{numerator} % {modulus}"

    def _print_FloorDiv(self, expr):
        """FloorDiv maps directly onto Python's ``//`` operator."""
        numerator, denominator = expr.args
        numerator = self.paren(self.doprint(numerator))
        denominator = self.paren(self.doprint(denominator))
        return f"({numerator} // {denominator})"

    def _helper_sqrt(self, expr):
        """Shared sqrt spelling used by the ``Pow`` special cases."""
        return f"math.sqrt({self._print(expr)})"

    def _print_Pow(self, expr):
        # Pow() confuses triton, so expand powers into explicit multiplies.
        base, exp = expr.args
        # NB: Remember this is sizevar computation!  Floating-point powers are
        # not expected here beyond +/-0.5; anything else should be retranslated
        # into Tensor expressions upstream instead of being supported here.
        if exp == 0.5:
            return self._helper_sqrt(base)
        if exp == -0.5:
            return "1/" + self._helper_sqrt(base)
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        if exp < 0:
            return "1/" + self.paren("*".join([self.paren(base)] * abs(exp)))
        # exp == 0
        return "1"

    def _print_floor(self, expr):
        return self._unary_call("math.floor", expr)

    def _print_ceiling(self, expr):
        return self._unary_call("math.ceil", expr)

    def _print_Abs(self, expr):
        return self._unary_call("abs", expr)

    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        rendered = ", ".join(self._print(arg) for arg in expr.args)
        return f"max({rendered})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        rendered = ", ".join(self._print(arg) for arg in expr.args)
        return f"min({rendered})"

    def _print_cos(self, expr):
        return self._unary_call("math.cos", expr)

    def _print_cosh(self, expr):
        return self._unary_call("math.cosh", expr)

    def _print_acos(self, expr):
        return self._unary_call("math.acos", expr)

    def _print_sin(self, expr):
        return self._unary_call("math.sin", expr)

    def _print_sinh(self, expr):
        return self._unary_call("math.sinh", expr)

    def _print_asin(self, expr):
        return self._unary_call("math.asin", expr)

    def _print_tan(self, expr):
        return self._unary_call("math.tan", expr)

    def _print_tanh(self, expr):
        return self._unary_call("math.tanh", expr)

    def _print_atan(self, expr):
        return self._unary_call("math.atan", expr)

    def _print_Round(self, expr):
        return self._unary_call("round", expr)

    def _print_RoundDecimal(self, expr):
        """Round to a fixed number of decimal places: ``round(x, ndigits)``."""
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
504
+
505
+
506
class OpOverrides:
    """Layers pointwise-op string emission on top of a wrapped parent handler.

    Anything not implemented here is delegated to ``parent`` via ``__getattr__``.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # Fall through to the wrapped handler for unimplemented ops.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv("1", x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        # Python-style modulo: when the raw remainder is nonzero and its sign
        # differs from ``b``'s, shift it by ``b`` so the result follows ``b``.
        r = ops.mod(a, b)
        return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install staticmethods for every pointwise override ``target`` implements."""
        assert target in {"triton", "cpp", "cppvec"}, target

        def make_unary(template):
            def emit(x):
                return template.format(x=x)

            return emit

        def make_binary(template):
            def emit(x, y):
                return template.format(x=x, y=y)

            return emit

        for funcname, data in pointwise_overrides_data.items():
            template = getattr(data, target)
            if not isinstance(template, str):
                # None => this backend has no implementation for the op.
                continue
            # Extend with more factories if ops with other arities appear.
            make = make_binary if "{y}" in template else make_unary
            setattr(cls, funcname, staticmethod(make(template)))
593
+
594
+
595
@dataclasses.dataclass
class OverridesData:
    """Per-backend implementation templates for a single pointwise op."""

    name: str
    cpp: str
    # None => not implemented in libdevice/triton
    triton: Optional[str] = None
    # None => not implemented in aten/.../vec
    cppvec: Optional[str] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
604
+
605
+
606
def _int_to_float(name, cpp, *, triton=None, cppvec=None):
    """Shorthand constructor: every entry below promotes integer inputs to float."""
    return OverridesData(
        name=name,
        cpp=cpp,
        triton=triton,
        cppvec=cppvec,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )


pointwise_overrides_data: Dict[str, OverridesData] = {
    "airy_ai": _int_to_float("special_airy_ai", "airy_ai_forward({x})"),
    "bessel_j0": _int_to_float(
        "special_bessel_j0", "bessel_j0_forward({x})", triton="libdevice.j0({x})"
    ),
    "bessel_j1": _int_to_float(
        "special_bessel_j1", "bessel_j1_forward({x})", triton="libdevice.j1({x})"
    ),
    "bessel_y0": _int_to_float(
        "special_bessel_y0", "bessel_y0_forward({x})", triton="libdevice.y0({x})"
    ),
    "bessel_y1": _int_to_float(
        "special_bessel_y1", "bessel_y1_forward({x})", triton="libdevice.y1({x})"
    ),
    "digamma": _int_to_float("digamma", "calc_digamma({x})", cppvec="{x}.digamma()"),
    # no cpp nor triton implementation for entr; it is defined as a decomposition
    # erf, erfc: handled elsewhere
    "erfcx": _int_to_float(
        "special_erfcx", "calc_erfcx({x})", triton="libdevice.erfcx({x})"
    ),
    # erfinv, exp2, expit, gammaln: handled elsewhere
    "igamma": _int_to_float("igamma", "calc_igamma({x}, {y})"),
    "igammac": _int_to_float("igammac", "calc_igammac({x}, {y})"),
    "gammainc": _int_to_float("special_gammainc", "calc_igamma({x}, {y})"),
    "gammaincc": _int_to_float("special_gammaincc", "calc_igammac({x}, {y})"),
    "i0": _int_to_float(
        "i0", "calc_i0({x})", triton="libdevice.cyl_bessel_i0({x})", cppvec="{x}.i0()"
    ),
    "i0e": _int_to_float("special_i0e", "calc_i0e({x})", cppvec="{x}.i0e()"),
    "i1": _int_to_float(
        "special_i1", "calc_i1({x})", triton="libdevice.cyl_bessel_i1({x})"
    ),
    "i1e": _int_to_float("special_i1e", "calc_i1e({x})"),
    "log_ndtr": _int_to_float("special_log_ndtr", "calc_log_ndtr({x})"),
    # logit: handled elsewhere
    "modified_bessel_i0": _int_to_float(
        "special_modified_bessel_i0",
        "modified_bessel_i0_forward({x})",
        triton="libdevice.cyl_bessel_i0({x})",
    ),
    "modified_bessel_i1": _int_to_float(
        "special_modified_bessel_i1",
        "modified_bessel_i1_forward({x})",
        triton="libdevice.cyl_bessel_i1({x})",
    ),
    "modified_bessel_k0": _int_to_float(
        "special_modified_bessel_k0", "modified_bessel_k0_forward({x})"
    ),
    "modified_bessel_k1": _int_to_float(
        "special_modified_bessel_k1", "modified_bessel_k1_forward({x})"
    ),
    # multigamma: handled elsewhere
    "ndtr": _int_to_float("special_ndtr", "calc_ndtr({x})"),
    "ndtri": _int_to_float("special_ndtri", "calc_ndtri({x})"),
    "polygamma": _int_to_float("polygamma", "calc_polygamma({y}, {x})"),
    # psi - alias to digamma; round: handled elsewhere
    "scaled_modified_bessel_k0": _int_to_float(
        "special_scaled_modified_bessel_k0", "scaled_modified_bessel_k0_forward({x})"
    ),
    "scaled_modified_bessel_k1": _int_to_float(
        "special_scaled_modified_bessel_k1", "scaled_modified_bessel_k1_forward({x})"
    ),
    # sinc: handled elsewhere
    "spherical_bessel_j0": _int_to_float(
        "special_spherical_bessel_j0", "spherical_bessel_j0_forward({x})"
    ),
    "zeta": _int_to_float("special_zeta", "zeta({x}, {y})"),
    "chebyshev_polynomial_t": _int_to_float(
        "special_chebyshev_polynomial_t", "chebyshev_polynomial_t_forward({x}, {y})"
    ),
    "chebyshev_polynomial_u": _int_to_float(
        "special_chebyshev_polynomial_u", "chebyshev_polynomial_u_forward({x}, {y})"
    ),
    "chebyshev_polynomial_v": _int_to_float(
        "special_chebyshev_polynomial_v", "chebyshev_polynomial_v_forward({x}, {y})"
    ),
    "chebyshev_polynomial_w": _int_to_float(
        "special_chebyshev_polynomial_w", "chebyshev_polynomial_w_forward({x}, {y})"
    ),
    "legendre_polynomial_p": _int_to_float(
        "special_legendre_polynomial_p", "legendre_polynomial_p_forward({x}, {y})"
    ),
    "shifted_chebyshev_polynomial_t": _int_to_float(
        "special_shifted_chebyshev_polynomial_t",
        "shifted_chebyshev_polynomial_t_forward({x}, {y})",
    ),
    "shifted_chebyshev_polynomial_u": _int_to_float(
        "special_shifted_chebyshev_polynomial_u",
        "shifted_chebyshev_polynomial_u_forward({x}, {y})",
    ),
    "shifted_chebyshev_polynomial_v": _int_to_float(
        "special_shifted_chebyshev_polynomial_v",
        "shifted_chebyshev_polynomial_v_forward({x}, {y})",
    ),
    "shifted_chebyshev_polynomial_w": _int_to_float(
        "special_shifted_chebyshev_polynomial_w",
        "shifted_chebyshev_polynomial_w_forward({x}, {y})",
    ),
    "hermite_polynomial_h": _int_to_float(
        "special_hermite_polynomial_h", "hermite_polynomial_h_forward({x}, {y})"
    ),
    "hermite_polynomial_he": _int_to_float(
        "special_hermite_polynomial_he", "hermite_polynomial_he_forward({x}, {y})"
    ),
    "laguerre_polynomial_l": _int_to_float(
        "special_laguerre_polynomial_l", "laguerre_polynomial_l_forward({x}, {y})"
    ),
}
823
+
824
+
825
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    """Static-only check (via mypy) that OpOverrides implements the OpsHandler protocol."""
    return h
828
+
829
+
830
class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers."""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        # Nesting deferred lines is not supported.
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        """Return the line, or None if the named buffer was removed/inlined away."""
        removal_pools = (
            V.graph.removed_buffers,
            V.kernel.removed_buffers,
            V.graph.inplaced_to_remove,
            V.kernel.inplaced_to_remove,
        )
        if any(self.name in pool for pool in removal_pools):
            return None
        return self.line

    def _new_line(self, line):
        return DeferredLine(self.name, line)
853
+
854
+
855
class BracesBuffer(IndentedBuffer):
    """IndentedBuffer that opens/closes C-style ``{``/``}`` scopes on indent.

    A negative ``offset`` closes scopes on entry and reopens them on exit.
    """

    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            self._open(offset)
            self._shut(-offset)
            yield
            self._open(-offset)
            self._shut(offset)

        return ctx()

    def _open(self, count):
        # ``range`` of a negative count is empty, so this is a no-op then.
        for _ in range(count):
            self.writeline("{")
            self._indent += 1

    def _shut(self, count):
        for _ in range(count):
            self._indent -= 1
            self.writeline("}")
874
+
875
+
876
class InplacedBuffer(NamedTuple):
    """An in/out kernel argument whose storage is shared by several graph buffers."""

    # Name used inside the kernel (e.g. "in_out_ptr0").
    inner_name: str
    # Graph-level buffer names aliasing this argument's storage.
    other_names: List[str]
879
+
880
+
881
class KernelArgs:
    """Tracks a kernel's formal arguments: buffers, size vars, and workspace.

    Maps graph-level ("outer") buffer names to kernel-local ("inner") argument
    names such as ``in_ptr0`` / ``out_ptr0`` / ``in_out_ptr0`` / ``ks0``.
    """

    @staticmethod
    def _lookup(prefix, odict, name):
        """Return the inner name for ``name``, allocating ``prefix<N>`` on first use."""
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        self.input_buffers = {}
        self.output_buffers = {}
        self.inplace_buffers = {}
        self.sizevars = sizevars or {}
        # Single shared scratch allocation (see ``workspace``); None until requested.
        self.workspace_arg = None

    def __repr__(self):
        state = [
            self.input_buffers,
            self.output_buffers,
            self.inplace_buffers,
            self.sizevars,
        ]
        return "KernelArgs({})".format(", ".join(repr(part) for part in state))

    def _buffer_is_marked_removed(self, name):
        # Removed buffers are tombstoned with a "REMOVED" prefix, not deleted.
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        """Register ``name`` as an input and return its kernel-local name."""
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        """Register ``name`` as an output and return its kernel-local name."""
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Alias ``output_name`` onto ``input_name``'s storage as an in_out_ptr."""
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve ``nbytes`` of scratch space; returns ``("ws_ptr", byte_offset)``."""
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0
        # Grow the single shared allocation; zero_fill is sticky once requested.
        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        """Map a seed ``value`` to a size-var name, deduplicating ``name`` collisions."""
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Return the kernel-local size-var name for ``name`` ("seed" is special-cased)."""
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        """Outer names, in call order, of all non-inplace arguments."""
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        """Build ``(arg_defs, call_args, arg_types)`` for the C++ backend."""
        from .cpp import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Build ``(arg_defs, call_args, precompile_args)`` for the Python/Triton path."""
        arg_defs = []
        call_args = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)

        return arg_defs, call_args, precompile_args

    def aliases(self):
        """Yield ``(old_inner_name, in_out_name)`` pairs for live inplace aliases."""
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        """True if ``name`` is neither a live output nor a live inplace buffer."""

        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially:
    # after a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = set()
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs
1106
+
1107
+
1108
class CSEVariable:
    """A named value produced by common-subexpression elimination.

    A CSEVariable is just a name for an expression, but backends may annotate
    it by overloading ``Kernel.create_cse_var``; ``update_on_args`` is the hook
    invoked with the op that produced/used it.  See TritonCSEVariable in
    triton.py for an example.
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        # Strict type match: subclass instances never equal base-class ones.
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        # Backend-specific annotation hook; the default is a no-op.
        pass
1131
+
1132
+
1133
class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs variant that formats arguments for the C++ wrapper codegen."""

    def wrap_ptr_arg(self, buf, dtype):
        from .cpp import DTYPE_TO_CPP

        if config.abi_compatible:
            # In the abi_compatible model, we just return the buf here.
            # We will form correct call args later in wrapper.generate_kernel_all.
            return buf
        return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        return f"{size}"
1146
+
1147
+
1148
class CSE:
    """Common subexpression elimination: maps expression text to temp variables."""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        self.prefix = prefix
        self.suffix = suffix
        self.cache = {}
        self.name_prefix = name_prefix
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = set()
        self.varname_map = varname_map or {}

    def invalidate(self, keep_vars: Set[str]):
        """Drop cached stores/expressions whose variables are not in ``keep_vars``."""
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        """Copy sharing store_cache/varname_map and the name counter.

        Note(fdrocha): reduction_cache is not being cloned; unclear if intentional.
        """
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        """Return the CSE variable for ``expr``, emitting its assignment on first use."""
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            var = self.newvar(bounds) if assignment else None
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            # Cache hit: just merge in any newly-known bounds.
            var.bounds = var.bounds.tighten(bounds)

        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        """Allocate a fresh backend-specific variable (``tmp0``, ``tmp1``, ...)."""
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var
1240
+
1241
+
1242
class IndirectAssertLine(DeferredLineBase):
    """Deferred bounds-check line; materializes only for bounds we cannot prove."""

    def __init__(self, line, assert_fn, var, mask, size_map):
        self.var = var
        self.mask = mask
        self.line = line
        self.assert_fn = assert_fn
        self.size_map = size_map

    def __call__(self):
        size, size_str = self.size_map[(self.var, self.mask)]

        # We assert only on the side(s) we've not been able to prove statically.
        needs_lower = (self.var.bounds.lower >= 0) != sympy.true
        needs_upper = (self.var.bounds.upper < size) != sympy.true

        if not (needs_lower or needs_upper):
            return None
        if needs_lower and needs_upper:
            # The conditions need to be in parens because of Python's operator
            # precedence; and/or/not (supported by triton) would be less error-prone.
            cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
            cond_print = f"0 <= {self.var} < {size_str}"
        elif needs_lower:
            cond = f"0 <= {self.var}"
            cond_print = cond
        else:
            assert needs_upper
            cond = f"{self.var} < {size_str}"
            cond_print = cond

        if self.mask:
            # Masked-out lanes are exempt from the check.
            cond = f"({cond}) | ~{self.mask}"
        return self.line.format(
            assert_fn=self.assert_fn, cond=cond, cond_print=cond_print
        )

    def _new_line(self, line):
        return IndirectAssertLine(
            line, self.assert_fn, self.var, self.mask, self.size_map
        )
1283
+
1284
+
1285
class CodeGen:
    """Base codegen object; manages cleanup callbacks through an ExitStack."""

    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Unwind everything registered on the stack during codegen.
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
1296
+
1297
+
1298
+ class Kernel(CodeGen):
1299
+ newvar_prefix = ""
1300
+ suffix = ""
1301
+ overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
1302
+ # TODO: these look dead, but with all the getattr it's hard to tell...
1303
+ load_format: None = None
1304
+ store_format: None = None
1305
+
1306
def __init__(self, args=None, increase_kernel_count=True):
    """Set up per-kernel buffers, CSE state, and removal bookkeeping."""
    super().__init__()
    if increase_kernel_count:
        metrics.generated_kernel_count += 1
    self.args = args or KernelArgs()
    self.loads = IndentedBuffer()
    self.compute = IndentedBuffer()
    self.stores = IndentedBuffer()
    self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
    self.must_keep_buffers = set()
    self.store_buffer_names = set()
    self._load_mask = None
    # set in set_current_node
    self.current_node = None
    self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
    # Upper bounds for indirect_indexing and their str representation
    # NB: (None, None) is never stored in the map; it is merely the assumed
    # "not set" value for the dict
    self.indirect_max_sizes: Dict[
        Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]]
    ] = {}

    self.removed_buffers = set()
    self.inplaced_to_remove = set()

    # key: the buffer to write
    # value: the buffer to read, whose memory can be reused for the key buffer
    self.inplace_update_buffers = dict()
    # Set minimum number of elements processed per thread.
    self.min_elem_per_thread = 1
    self.kernel_name = None
1338
+
1339
@contextlib.contextmanager
def set_current_node(self, node):
    """Temporarily install ``node`` as the current node, caching its bounds."""
    saved = self.current_node
    self.current_node = node
    self.node_to_bounds = node._body.bounds().get_bounds()
    try:
        yield
    finally:
        self.current_node = saved
1348
+
1349
@contextlib.contextmanager
def swap_buffers(self, lb, cb=None, sb=None):
    """Temporarily swap in loads/compute/stores buffers and a cloned CSE."""
    if cb is None:
        cb = lb
    saved_loads = self.loads
    saved_compute = self.compute
    saved_stores = self.stores
    saved_cse = self.cse
    self.loads, self.compute, self.stores = lb, cb, sb
    self.cse = saved_cse.clone()
    try:
        yield
    finally:
        self.loads = saved_loads
        self.compute = saved_compute
        self.stores = saved_stores
        self.cse = saved_cse
1368
+
1369
+ def load(self, name: str, index: sympy.Expr) -> CSEVariable:
1370
+ raise NotImplementedError()
1371
+
1372
+ def indirect_load(self, name: str, index: sympy.Expr):
1373
+ """A load the depends on an index we have read"""
1374
+ prior = self.loads
1375
+ try:
1376
+ # put the load in the compute section as it might have deps
1377
+ self.loads = self.compute
1378
+ return self.load(name, index)
1379
+ finally:
1380
+ self.loads = prior
1381
+
1382
+ def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
1383
+ raise NotImplementedError()
1384
+
1385
+ def store(
1386
+ self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
1387
+ ) -> None:
1388
+ raise NotImplementedError()
1389
+
1390
+ def reduction(
1391
+ self,
1392
+ dtype: torch.dtype,
1393
+ src_dtype: torch.dtype,
1394
+ reduction_type: ReductionType,
1395
+ value: Union[CSEVariable, Tuple[CSEVariable, ...]],
1396
+ ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
1397
+ raise NotImplementedError()
1398
+
1399
+ def scan(
1400
+ self,
1401
+ dtype: torch.dtype,
1402
+ combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
1403
+ value: CSEVariable,
1404
+ init: int,
1405
+ ) -> CSEVariable:
1406
+ raise NotImplementedError()
1407
+
1408
+ def bucketize(
1409
+ self,
1410
+ values: CSEVariable,
1411
+ offsets_name: str,
1412
+ offsets_size: sympy.Expr,
1413
+ indexing_dtype: torch.dtype,
1414
+ right: bool,
1415
+ ) -> CSEVariable:
1416
+ """
1417
+ See [Note: Inductor bucketize op]
1418
+ """
1419
+ raise NotImplementedError()
1420
+
1421
+ @property
1422
+ def assert_function(self) -> str:
1423
+ raise NotImplementedError()
1424
+
1425
+ def index_to_str(self, index: sympy.Expr) -> str:
1426
+ raise NotImplementedError()
1427
+
1428
+ def __enter__(self):
1429
+ # TODO: hoist this to top level
1430
+ class CSEProxy:
1431
+ self.name = "CSEProxy"
1432
+
1433
+ @staticmethod
1434
+ def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
1435
+ def inner(*args, **kwargs):
1436
+ # TritonTemplateKernel has no current_node
1437
+ buf_bounds = ValueRanges.unknown()
1438
+ if hasattr(V.interpreter, "current_node"):
1439
+ fx_node = V.interpreter.current_node
1440
+ assert isinstance(self.node_to_bounds, dict)
1441
+ buf_bounds = self.node_to_bounds.get(
1442
+ fx_node, ValueRanges.unknown()
1443
+ )
1444
+
1445
+ value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
1446
+
1447
+ def do_cse(v):
1448
+ csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
1449
+ csevar.update_on_args(name, args, kwargs)
1450
+ return csevar
1451
+
1452
+ return pytree.tree_map(do_cse, value)
1453
+
1454
+ return inner
1455
+
1456
+ @staticmethod
1457
+ def indirect_indexing(
1458
+ var: CSEVariable, size: sympy.Expr, check: bool = True
1459
+ ):
1460
+ # Skip CSE since this doesn't return an expression
1461
+
1462
+ if var.bounds.lower < 0: # type: ignore[operator]
1463
+ new_bounds = ValueRanges.unknown()
1464
+ if var.bounds != ValueRanges.unknown() and isinstance(
1465
+ size, sympy.Number
1466
+ ):
1467
+ # Take the negative part of the bound and add size to it
1468
+ # Then take union of that and the positive part
1469
+ # This is a tighter bound than that of a generic ops.where, as we have info on the cond
1470
+ neg = var.bounds & ValueRanges(-sympy.oo, -1)
1471
+ new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
1472
+ # We don't have a good way of representing the empty range
1473
+ if var.bounds.upper >= 0: # type: ignore[operator]
1474
+ pos = var.bounds & ValueRanges(0, sympy.oo)
1475
+ new_bounds = new_bounds | pos
1476
+
1477
+ stm = ops.add(var, self.rename_indexing(size))
1478
+ # Mixed negative and non-negative
1479
+ if var.bounds.upper >= 0: # type: ignore[operator]
1480
+ lt = ops.lt(var, "0")
1481
+ stm = ops.where(lt, stm, var)
1482
+ new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
1483
+
1484
+ new_var.update_on_args("index_wrap", (var,), {})
1485
+ var = new_var
1486
+
1487
+ if self.generate_assert(check):
1488
+ mask = self.load_mask(var)
1489
+
1490
+ # An assertion line may have been written already, if so just
1491
+ # update the max size.
1492
+ map_key = (var, mask)
1493
+ existing_size, _ = self.indirect_max_sizes.get(
1494
+ map_key, (None, None)
1495
+ )
1496
+ if existing_size is not None:
1497
+ size = sympy.Min(size, existing_size)
1498
+ else:
1499
+ line = (
1500
+ '{assert_fn}({cond}, "index out of bounds: {cond_print}")'
1501
+ )
1502
+ self.compute.writeline(
1503
+ IndirectAssertLine(
1504
+ line,
1505
+ self.assert_function,
1506
+ var,
1507
+ mask,
1508
+ self.indirect_max_sizes,
1509
+ )
1510
+ )
1511
+
1512
+ self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
1513
+ return sympy_index_symbol(str(var))
1514
+
1515
+ @staticmethod
1516
+ def load(name: str, index: sympy.Expr) -> CSEVariable:
1517
+ if name in self.cse.invalidated_stores:
1518
+ # A load from an invalidated store requires us to
1519
+ # keep the actual buffer around
1520
+ V.kernel.must_keep_buffers.add(name)
1521
+ if free_symbol_startswith(index, "tmp"):
1522
+ return self.indirect_load(name, index)
1523
+ store_cache = self.cse.store_cache
1524
+ if name in store_cache:
1525
+ return store_cache[name]
1526
+ return self.load(name, index)
1527
+
1528
+ @staticmethod
1529
+ def store(
1530
+ name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
1531
+ ) -> None:
1532
+ self.store_buffer_names.add(name)
1533
+ if mode is None:
1534
+ self.cse.store_cache[name] = value
1535
+ if self.current_node:
1536
+ for other_name in self.current_node.get_mutations():
1537
+ self.cse.store_cache[other_name] = value
1538
+ if name not in V.graph.removed_buffers:
1539
+ return self.store(name, index, value, mode=mode)
1540
+ else:
1541
+ return None # type: ignore[return-value]
1542
+
1543
+ @staticmethod
1544
+ def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
1545
+ self.store_buffer_names.add(name)
1546
+ self.cse.store_cache[name] = value
1547
+ if self.current_node:
1548
+ for other_name in self.current_node.get_mutations():
1549
+ self.cse.store_cache[other_name] = value
1550
+
1551
+ if name not in V.graph.removed_buffers:
1552
+ return self.store_reduction(name, index, value)
1553
+
1554
+ @staticmethod
1555
+ def reduction(
1556
+ dtype: torch.dtype,
1557
+ src_dtype: torch.dtype,
1558
+ reduction_type: ReductionType,
1559
+ value: Union[CSEVariable, Tuple[CSEVariable, ...]],
1560
+ ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
1561
+ return self.reduction(dtype, src_dtype, reduction_type, value)
1562
+
1563
+ @staticmethod
1564
+ def scan(
1565
+ dtype: torch.dtype,
1566
+ combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
1567
+ value: CSEVariable,
1568
+ init: int,
1569
+ ) -> CSEVariable:
1570
+ return self.scan(dtype, combine_fn, value, init)
1571
+
1572
+ @staticmethod
1573
+ def bucketize(
1574
+ values: CSEVariable,
1575
+ offsets_name: str,
1576
+ offsets_size: sympy.Expr,
1577
+ indexing_dtype: torch.dtype,
1578
+ right: bool,
1579
+ ) -> CSEVariable:
1580
+ """
1581
+ [Note: Inductor bucketize op]
1582
+
1583
+ Given values (tensor) and offsets_name (reference to the name of a 1D
1584
+ tensor), calculate the bucket that each value belongs to.
1585
+
1586
+ e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
1587
+ return = [ 0, 1, 1, 1, 1, 3, 3, 4].
1588
+
1589
+ When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
1590
+ When right == True, bucket i refers to range [offsets[i], offsets[i+1]).
1591
+
1592
+ Offsets must be non-decreasing or the result is undefined.
1593
+ """
1594
+ return self.bucketize(
1595
+ values, offsets_name, offsets_size, indexing_dtype, right
1596
+ )
1597
+
1598
+ # Use mypy to check protocol implemented correctly
1599
+ def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
1600
+ return h
1601
+
1602
+ super().__enter__()
1603
+ assert self.overrides
1604
+ parent_handler = self.overrides(V.get_ops_handler())
1605
+ self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
1606
+ self.exit_stack.enter_context(V.set_kernel_handler(self))
1607
+ return self
1608
+
1609
+ def __exit__(self, exc_type, exc_val, exc_tb):
1610
+ """
1611
+ Note that V.graph.scheduler can be None when codegening triton template
1612
+ kernels.
1613
+ """
1614
+ if V.graph.scheduler:
1615
+ V.graph.scheduler.remove_kernel_local_buffers()
1616
+ super().__exit__(exc_type, exc_val, exc_tb)
1617
+
1618
+ def generate_assert(self, check):
1619
+ return (check or config.debug_index_asserts) and config.assert_indirect_indexing
1620
+
1621
+ def load_mask(self, var) -> str:
1622
+ # only the triton kernel requires mask
1623
+ return ""
1624
+
1625
+ def rename_indexing(self, index) -> sympy.Expr:
1626
+ # adds the necessary kernel args for index expressions
1627
+ # and renames variables in index expressions to kernel arg names
1628
+ if isinstance(index, (list, tuple)):
1629
+ return [self.rename_indexing(x) for x in index] # type: ignore[return-value]
1630
+ index = V.graph.sizevars.simplify(index)
1631
+ sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
1632
+ replacements = {
1633
+ x: self.args.size(x)
1634
+ for x in sorted_symbols
1635
+ if x.name.startswith(("s", "u", "ps"))
1636
+ or (x.name.startswith("i") and not x.name.startswith("idx"))
1637
+ }
1638
+ return sympy_subs(index, replacements)
1639
+
1640
+ def create_cse_var(self, *args, **kwargs):
1641
+ return CSEVariable(*args, **kwargs)
1642
+
1643
+
1644
+ @dataclasses.dataclass
1645
+ class OptimizationContext:
1646
+ key: ClassVar[str] = "opt_ctx"
1647
+
1648
+ # Load value as mask
1649
+ is_load_as_mask: bool = False
1650
+
1651
+ dtype: Optional[torch.dtype] = None
1652
+ ops_name: str = ""
1653
+
1654
+ # Load uint8/int8 value as float32
1655
+ is_load_int8_as_float: bool = False
1656
+
1657
+
1658
+ @functools.lru_cache(None)
1659
+ def jinja2_env():
1660
+ try:
1661
+ import jinja2
1662
+
1663
+ return jinja2.Environment(
1664
+ undefined=jinja2.StrictUndefined,
1665
+ )
1666
+ except ImportError:
1667
+ return None
1668
+
1669
+
1670
+ PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]
1671
+
1672
+
1673
+ class ChoiceCaller:
1674
+ """
1675
+ Represents a possible choice used in autotune_process.py.
1676
+ During autotuning, self.benchmark() is first called to get benchmark result,
1677
+ and if this choice is selected, self.output_node() is called to get the output_node.
1678
+
1679
+ Children classes: TritonTemplateCaller, CUDATemplateCaller.
1680
+ """
1681
+
1682
+ def __init__(self, name, input_nodes, layout):
1683
+ super().__init__()
1684
+ self.name = name
1685
+ self.layout = layout
1686
+ self.input_nodes = input_nodes
1687
+
1688
+ def benchmark(self, *args, out) -> float:
1689
+ algo = self.to_callable()
1690
+ return do_bench(lambda: algo(*args, out=out))
1691
+
1692
+ def call_name(self) -> str:
1693
+ raise NotImplementedError()
1694
+
1695
+ def to_callable(self):
1696
+ raise NotImplementedError()
1697
+
1698
+ def hash_key(self) -> str:
1699
+ raise NotImplementedError()
1700
+
1701
+ def output_node(self) -> "TensorBox":
1702
+ raise NotImplementedError()
1703
+
1704
+ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
1705
+ """Information returned here is logged to the autotune log file when that is enabled."""
1706
+ return {}
1707
+
1708
+
1709
+ class KernelTemplate:
1710
+ """
1711
+ Base class for defining kernel templates.
1712
+
1713
+ Children classes: TritonTemplate, CUDATemplate
1714
+ """
1715
+
1716
+ @staticmethod
1717
+ def _template_from_string(source):
1718
+ env = jinja2_env()
1719
+ if env is not None:
1720
+ return env.from_string(source)
1721
+ return None
1722
+
1723
+ @staticmethod
1724
+ def _fake_get_dtype(fake_out):
1725
+ _get_dtype_real = V.graph.get_dtype
1726
+
1727
+ def get_dtype(name):
1728
+ if name == fake_out.get_name():
1729
+ return fake_out.get_dtype()
1730
+ return _get_dtype_real(name)
1731
+
1732
+ return get_dtype
1733
+
1734
+ def __init__(self, name: str):
1735
+ self.name = name
1736
+
1737
+ def maybe_append_choice(self, choices, **kwargs):
1738
+ """
1739
+ Maybe generates a new ChoiceCaller and appends it into existing choices.
1740
+
1741
+ choices: A list of ChoiceCallers.
1742
+ kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
1743
+ """
1744
+
1745
+ try:
1746
+ choices.append(self.generate(**kwargs))
1747
+ except NotImplementedError:
1748
+ pass
1749
+
1750
+ def generate(self, **kwargs) -> ChoiceCaller:
1751
+ """
1752
+ Generates a ChoiceCaller instance from the given arguments.
1753
+ """
1754
+
1755
+ raise NotImplementedError()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-311.pyc ADDED
Binary file (2.29 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-311.pyc ADDED
Binary file (19.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-311.pyc ADDED
Binary file (20.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-311.pyc ADDED
Binary file (30.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import cast, List
3
+
4
+ from ...._dynamo.utils import counters
5
+
6
+ from ... import config, ir
7
+ from ...codecache import code_hash, get_path
8
+ from ...ir import ComputedBuffer, CUDATemplateBuffer, Pointwise
9
+ from ...scheduler import (
10
+ BaseSchedulerNode,
11
+ BaseScheduling,
12
+ FusedSchedulerNode,
13
+ Scheduler,
14
+ SchedulerNode,
15
+ )
16
+ from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product
17
+ from ...virtualized import V
18
+ from ..common import IndentedBuffer
19
+
20
+ from .cutlass_epilogue_gen import CUTLASSEVTOpNotImplementedError
21
+
22
+ log = logging.getLogger(__name__)
23
+
24
+
25
+ class CUDACPPScheduling(BaseScheduling):
26
+ """
27
+ Partial Scheduling implementation for CUDA C++ Kernels.
28
+ This class is intended to be used in combination with TritonScheduling,
29
+ and delegated to by CUDACombinedScheduling.
30
+
31
+ It handles fusion decisions and CUDA C++ specific template code generation.
32
+ """
33
+
34
+ def __init__(self, scheduler: Scheduler):
35
+ super().__init__()
36
+ self.scheduler = scheduler
37
+
38
+ def group_fn(self, sizes):
39
+ return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)
40
+
41
+ def is_cuda_cpp_template(self, node: BaseSchedulerNode) -> bool:
42
+ return isinstance(node, SchedulerNode) and isinstance(
43
+ node.node, CUDATemplateBuffer
44
+ )
45
+
46
+ def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool:
47
+ return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(
48
+ node.get_template_node()
49
+ )
50
+
51
+ def _can_fuse_epilogue_impl(
52
+ self,
53
+ cuda_template_buffer: CUDATemplateBuffer,
54
+ epilogue_nodes: List[ir.IRNode],
55
+ additional_node: ir.IRNode,
56
+ ) -> bool:
57
+ """
58
+ Check if the given node can be fused with the epilogue. At the moment, Kernels
59
+ support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes.
60
+
61
+ Args:
62
+ cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer
63
+ epilogue_nodes : List[ir.Buffer]: The list of already fused epilogue nodes.
64
+ additional_node: The ir.Buffer node to be checked if it can be fused with the epilogue.
65
+ Returns:
66
+ - bool: True if the given node can be fused with the epilogue, False otherwise.
67
+
68
+ """
69
+ if not isinstance(cuda_template_buffer, CUDATemplateBuffer):
70
+ return False
71
+ if not cuda_template_buffer.template.can_fuse_epilogue:
72
+ # The used GEMM op does not support fusing epilogues
73
+ return False
74
+ if not isinstance(additional_node, ComputedBuffer):
75
+ return False
76
+ if not isinstance(additional_node.data, Pointwise):
77
+ return False
78
+ # We can fuse a Pointwise op that depends on the last fused epilogue node
79
+ # if any. If there is no epilogue node yet, it needs to depend on the template
80
+ # node
81
+ node_name = additional_node.get_computed_buffer_name()
82
+ if node_name is None:
83
+ return False
84
+
85
+ if len(epilogue_nodes) == 0:
86
+ if cuda_template_buffer.name not in additional_node.get_read_names():
87
+ return False
88
+ else:
89
+ last_epilogue_node = epilogue_nodes[-1]
90
+ assert isinstance(last_epilogue_node, ir.ComputedBuffer) # for mypy
91
+ last_epilogue_name = (
92
+ last_epilogue_node.name
93
+ if last_epilogue_node.name is not None
94
+ else last_epilogue_node.data.name # type: ignore[attr-defined]
95
+ )
96
+ if last_epilogue_name not in additional_node.get_read_names():
97
+ return False
98
+ if additional_node.layout != cuda_template_buffer.layout:
99
+ return False
100
+ try:
101
+ from torch._inductor.codegen.cuda.cutlass_epilogue_gen import (
102
+ CutlassEVTEpilogueArgumentFormatter,
103
+ CutlassEVTEpilogueTypeFormatter,
104
+ )
105
+
106
+ CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(
107
+ cast(str, cuda_template_buffer.name), "anything", [additional_node]
108
+ )
109
+ CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(
110
+ cast(str, cuda_template_buffer.name), [additional_node]
111
+ )
112
+ except CUTLASSEVTOpNotImplementedError as e:
113
+ not_implemented_op = str(e)
114
+ if not_implemented_op.startswith("_op_"):
115
+ not_implemented_op = not_implemented_op[4:]
116
+ log.warning(
117
+ f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}, likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950
118
+ )
119
+ return False
120
+ else:
121
+ # Likely due to unsupported dtype.
122
+ log.warning(
123
+ f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}. Reason: {not_implemented_op}" # noqa: G004, B950
124
+ )
125
+ return False
126
+ return True
127
+
128
+ @staticmethod
129
+ def _unwrap_epilogue_nodes(fused_node: FusedSchedulerNode) -> List[ir.IRNode]:
130
+ nodes = fused_node.get_nodes()
131
+ template_node = fused_node.get_template_node()
132
+ nodes.remove(template_node)
133
+ return [n.node for n in nodes]
134
+
135
+ def can_fuse_vertical(
136
+ self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
137
+ ) -> bool:
138
+ if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode):
139
+ return self._can_fuse_epilogue_impl(
140
+ cast(CUDATemplateBuffer, node1.node), [], node2.node
141
+ )
142
+ elif self.is_cuda_cpp_fused_template(node1) and isinstance(
143
+ node2, SchedulerNode
144
+ ):
145
+ fnode1 = cast(FusedSchedulerNode, node1)
146
+ return self._can_fuse_epilogue_impl(
147
+ fnode1.get_template_node().node,
148
+ self._unwrap_epilogue_nodes(fnode1),
149
+ node2.node,
150
+ )
151
+ return False
152
+
153
+ def define_kernel(self, src_code: str, node_schedule) -> str:
154
+ wrapper = V.graph.wrapper_code
155
+ if src_code in wrapper.src_to_kernel:
156
+ kernel_name = wrapper.src_to_kernel[src_code]
157
+ else:
158
+ fused_name = (
159
+ get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
160
+ if config.triton.descriptive_names
161
+ else ""
162
+ )
163
+ kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()])
164
+ # use the original src_code as the key
165
+ wrapper.src_to_kernel[src_code] = kernel_name
166
+ src_code = src_code.replace("KERNEL_NAME", kernel_name)
167
+
168
+ _, _, kernel_path = get_path(code_hash(src_code), "py")
169
+
170
+ compile_wrapper = IndentedBuffer()
171
+ compile_wrapper.writeline("async_compile.cuda(r'''")
172
+ compile_wrapper.splice(src_code, strip=True)
173
+ compile_wrapper.writeline("''', 'so')")
174
+
175
+ metadata_comment = f"# kernel path: {kernel_path}"
176
+ origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
177
+ metadata_comment += "\n" + origins + "\n" + detailed_origins
178
+ wrapper.define_kernel(
179
+ kernel_name, compile_wrapper.getvalue(), metadata_comment
180
+ )
181
+ return kernel_name
182
+
183
+ def codegen_template(
184
+ self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode]
185
+ ):
186
+ """
187
+ Codegen a CUDA template, possibly with fused epilogues
188
+ """
189
+ counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
190
+ assert self.is_cuda_cpp_template(
191
+ template_node
192
+ ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
193
+ template_node = cast(SchedulerNode, template_node)
194
+ _, (numel, rnumel) = template_node.group
195
+ assert rnumel == 1
196
+ ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node)
197
+ epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes]
198
+ assert all(
199
+ isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
200
+ ), "Epilogue nodes must all be instances of ir.ComputedBuffer"
201
+ kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
202
+ with kernel:
203
+ for node in [template_node, *epilogue_nodes]:
204
+ node.mark_run()
205
+ src_code = render()
206
+
207
+ with V.set_kernel_handler(kernel):
208
+ node_schedule = [template_node, *epilogue_nodes]
209
+ kernel_name = self.define_kernel(src_code, node_schedule)
210
+ kernel.call_kernel(kernel_name, ctb, epilogue_ir_nodes)
211
+ V.graph.removed_buffers |= kernel.removed_buffers
212
+ self.scheduler.free_buffers()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_env.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ from ... import config
8
+
9
+ log = logging.getLogger(__name__)
10
+
11
+
12
+ def get_cuda_arch() -> Optional[str]:
13
+ try:
14
+ cuda_arch = config.cuda.arch
15
+ if cuda_arch is None:
16
+ # Get Compute Capability of the first Visible device
17
+ major, minor = torch.cuda.get_device_capability(0)
18
+ return str(major * 10 + minor)
19
+ return str(cuda_arch)
20
+ except Exception as e:
21
+ log.error("Error getting cuda arch: %s", e)
22
+ return None
23
+
24
+
25
+ def get_cuda_version() -> Optional[str]:
26
+ try:
27
+ cuda_version = config.cuda.version
28
+ if cuda_version is None:
29
+ cuda_version = torch.version.cuda
30
+ return cuda_version
31
+ except Exception as e:
32
+ log.error("Error getting cuda version: %s", e)
33
+ return None
34
+
35
+
36
+ @functools.lru_cache(None)
37
+ def nvcc_exist(nvcc_path: str = "nvcc") -> bool:
38
+ if nvcc_path is None:
39
+ return False
40
+ import subprocess
41
+
42
+ res = subprocess.call(
43
+ ["which", nvcc_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
44
+ )
45
+ return res == 0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_template.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import itertools
3
+ import logging
4
+ from typing import List, Optional
5
+ from unittest.mock import patch
6
+
7
+ import sympy
8
+
9
+ import torch
10
+ from ...autotune_process import CUDABenchmarkRequest, TensorMeta
11
+ from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
12
+
13
+ from ...utils import IndentedBuffer, unique
14
+ from ...virtualized import V
15
+ from ..common import KernelTemplate
16
+ from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
+
21
+ class CUDATemplate(KernelTemplate):
22
+ index_counter = itertools.count()
23
+
24
+ def __init__(
25
+ self,
26
+ name: str,
27
+ input_nodes: List[Buffer],
28
+ layout: Layout,
29
+ input_reorder: Optional[List[int]] = None,
30
+ ):
31
+ """
32
+
33
+ Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly.
34
+
35
+ Args:
36
+ name (str): The name of the CUDATemplate object.
37
+ input_nodes (List[IRNode]): A list of input IRNodes.
38
+ layout (Layout): The layout of the output buffer / tensor.
39
+ input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes.
40
+
41
+ """
42
+ super().__init__(name)
43
+ self.input_nodes = input_nodes
44
+ self.output_node: Buffer = Buffer("buf_out", layout)
45
+ self.input_reorder = input_reorder
46
+ self.layout = layout
47
+
48
+ def generate( # type: ignore[override]
49
+ self,
50
+ **kwargs,
51
+ ) -> CUDATemplateCaller:
52
+ """
53
+ Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller
54
+ may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning.
55
+
56
+ Args:
57
+ kwargs: Additional keyword arguments.
58
+
59
+ Returns:
60
+ A CUDATemplateCaller object representing the generated CUDA template caller.
61
+ """
62
+ kernel_name = f"cuda_{self.name}"
63
+ with patch.object(
64
+ V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
65
+ ), CUDATemplateKernel(
66
+ kernel_name=kernel_name,
67
+ ) as kernel:
68
+ code = self.render(kernel=kernel, **kwargs)
69
+ _, call_args, _ = kernel.args.python_argdefs()
70
+ log.debug("Generated Code:\n%s", code)
71
+ log.debug(
72
+ "Args: cpp_argdefs: %s, python_argdefs: %s",
73
+ kernel.args.cpp_argdefs(),
74
+ kernel.args.python_argdefs(),
75
+ )
76
+
77
+ input_reorder = (
78
+ self.input_reorder
79
+ if self.input_reorder is not None
80
+ else list(range(len(self.input_nodes)))
81
+ )
82
+ expected_args = list(
83
+ unique(self.input_nodes[idx].get_name() for idx in input_reorder)
84
+ )
85
+ expected_args.extend([self.output_node.get_name()])
86
+ assert list(call_args)[: len(expected_args)] == expected_args, (
87
+ call_args,
88
+ expected_args,
89
+ )
90
+ extra_args = V.graph.sizevars.size_hints(
91
+ map(sympy.expand, call_args[len(expected_args) :])
92
+ )
93
+
94
+ kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}"
95
+
96
+ # create the BenchmarkRequest
97
+ bmreq = CUDABenchmarkRequest(
98
+ kernel_name=kernel_name,
99
+ input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
100
+ output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
101
+ extra_args=extra_args,
102
+ source_code=code,
103
+ )
104
+
105
+ def make_kernel_render(
106
+ template_node: CUDATemplateBuffer,
107
+ epilogue_nodes: Optional[List[IRNode]] = None,
108
+ ):
109
+ kernel = CUDATemplateKernel(
110
+ kernel_name="KERNEL_NAME",
111
+ )
112
+ render = functools.partial(
113
+ self.render,
114
+ kernel=kernel,
115
+ template_buffer_node=template_node,
116
+ epilogue_nodes=epilogue_nodes,
117
+ **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate
118
+ )
119
+ return kernel, render
120
+
121
+ return CUDATemplateCaller(
122
+ kernel_hash_name,
123
+ self.name,
124
+ self.input_nodes,
125
+ self.output_node.get_layout(),
126
+ make_kernel_render,
127
+ bmreq,
128
+ self,
129
+ kwargs,
130
+ )
131
+
132
+ def header(self) -> IndentedBuffer:
133
+ res = IndentedBuffer()
134
+ res.splice(
135
+ """
136
+ #include <exception>
137
+ #include <iostream>
138
+ #include <memory>
139
+ #include <random>
140
+ #include <vector>
141
+ """
142
+ )
143
+ return res
144
+
145
+ def globals(self) -> IndentedBuffer:
146
+ res = IndentedBuffer()
147
+ res.splice(
148
+ """
149
+ // We compile all models with -fvisibility=hidden. Any symbols that need to be
150
+ // exposed in the final shared library must be declared with PT_EXPORT to make
151
+ // them visible.
152
+ #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
153
+ #define PT_EXPORT __attribute__((__visibility__("default")))
154
+ #else
155
+ #ifdef _WIN32
156
+ #define PT_EXPORT __declspec(dllexport)
157
+ #else
158
+ #define PT_EXPORT
159
+ #endif
160
+ #endif
161
+ using bfloat16 = nv_bfloat16;
162
+ """
163
+ )
164
+ return res
165
+
166
+ def render(self, **kwargs) -> str:
167
+ raise NotImplementedError
168
+
169
+
170
+ class CUTLASSTemplate(CUDATemplate):
171
+ """
172
+ CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the
173
+ CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels.
174
+ """
175
+
176
+ def header(self) -> IndentedBuffer:
177
+ res = super().header()
178
+ res.splice(
179
+ """
180
+ #include "cute/tensor.hpp"
181
+ #include "cutlass/cutlass.h"
182
+ #include "cutlass/numeric_types.h"
183
+ #include "cutlass/tensor_ref.h"
184
+ #include "cutlass/util/host_tensor.h"
185
+ #include "cutlass/util/reference/host/tensor_fill.h"
186
+ #include "cutlass/util/reference/device/tensor_fill.h"
187
+ #include "cutlass/util/device_memory.h"
188
+ """
189
+ )
190
+ return res
191
+
192
+ def globals(self) -> IndentedBuffer:
193
+ res = super().globals()
194
+ res.splice(
195
+ """
196
+ using namespace cute;
197
+ #define CUTLASS_CHECK(status) \\
198
+ { \\
199
+ cutlass::Status error = status; \\
200
+ if (error != cutlass::Status::kSuccess) { \\
201
+ auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\
202
+ cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\
203
+ throw std::runtime_error(msg); \\
204
+ } \\
205
+ }
206
+
207
+ // Used as pass-through functor in EVT just for type casting / rounding
208
+ template <typename T>
209
+ struct identity_op {
210
+ CUTLASS_HOST_DEVICE
211
+ T operator()(T val) const { return val; }
212
+ };
213
+
214
+ """
215
+ )
216
+ return res
217
+
218
+ def cute_int(self, int_str: str, var_name: str) -> str:
219
+ res = ""
220
+ if int_str in {"1", "1L"}:
221
+ res = "cute::Int<1>{}"
222
+ else:
223
+ res = int_str
224
+
225
+ return f"{res} /* {var_name} */"
226
+
227
+ _DTYPE_TO_CUTLASS = {
228
+ torch.float32: "float",
229
+ torch.float64: "double",
230
+ torch.float16: "cutlass::half_t",
231
+ torch.int32: "int",
232
+ torch.int8: "int8_t",
233
+ torch.uint8: "uint8_t",
234
+ torch.bool: "bool",
235
+ torch.bfloat16: "cutlass::bfloat16_t",
236
+ }
237
+
238
+ def cutlass_type_cast(self, node: IRNode, ptr: str) -> str:
239
+ if node is None:
240
+ return ptr
241
+ else:
242
+ return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})"
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ from unittest.mock import patch
3
+
4
+ import sympy
5
+
6
+ import torch._inductor.virtualized as virtualized
7
+ from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise
8
+ from torch._inductor.utils import IndentedBuffer, sympy_str
9
+
10
+
11
# Used as a magic string to indicate an unsupported sympy expression
# became part of generated C++ code.
_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]"


def _arg_str(a):
    """Render op argument ``a`` as a string for EVT codegen.

    sympy expressions are not supported: if this return value containing the
    _MAGIC_SYMPY_ERROR_STRING is used as part of the final generated C++ code,
    a CUTLASSEVTOpNotImplementedError is raised to indicate that the op could
    not be converted to a valid EVT expression.
    """
    if not isinstance(a, sympy.Expr):
        return str(a)
    return f"{_MAGIC_SYMPY_ERROR_STRING}('{sympy_str(a)}')"
24
+
25
+
26
class CUTLASSEVTOpNotImplementedError(NotImplementedError):
    """Raised when a pointwise op / expression cannot be converted to a CUTLASS EVT node."""
28
+
29
+
30
class CutlassEVTEpilogueTypeFormatter:
    """
    Codegen class, which provides an entry point to generate
    Cutlass "Epilogue Visitor Tree" (EVT) functor declarations.

    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
    for more about EVTs and how they are declared and used to generate.

    Notes:
        * Used by CUTLASSGemmTemplate.
        * This class should not be instantiated by users, it is intended to be used
          by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...)
          which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
        * Extend this with more _op_<whatever> nodes to add support for new pointwise operations.
    """

    def __init__(self, accumulator_node_name, evt_type_name):
        """
        Initialize an instance of CutlassEVTEpilogueTypeFormatter.

        Parameters:
        - accumulator_node_name (str): The name of the output Buffer for the GEMM operation in
          the original (unfused) IR graph.
        - evt_type_name (str): The output name of the EVT type we are generating.
        """
        self.accumulator_node_name = accumulator_node_name
        self.output = IndentedBuffer(0)  # collects one "using ..." declaration per op
        self.var_counter = 0  # monotonically increasing suffix for EVT_expr_<n> names
        self.evt_type_name = evt_type_name
        self.aliases = dict()  # buffer name -> EVT expression of an already-formatted node

    @staticmethod
    def ir_to_evt_string(
        template_output_node_name: str,
        evt_type_name: str,
        epilogue_nodes: List[IRNode],
    ):
        """
        Formats IR nodes into a string representation compatible with Cutlass EVT format.

        Args:
            template_output_node_name (str): The name of the template output node.
            evt_type_name (str): The name of the EVT type.
            epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be
                ComputedBuffer nodes wrapping Pointwise nodes.

        Returns:
            A string representation of the IR nodes formatted according to the Cutlass EVT format.

        Raises:
            RuntimeError: If an epilogue node is not a ComputedBuffer.
            CUTLASSEVTOpNotImplementedError: If the computation cannot be expressed as an EVT
                (unsupported op, or sympy / indexing expressions leaked into the result).
        """
        formatter = CutlassEVTEpilogueTypeFormatter(
            template_output_node_name, evt_type_name
        )

        with virtualized.V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            for node in epilogue_nodes:
                if isinstance(node, ComputedBuffer):
                    pnode = node.data
                else:
                    raise RuntimeError(
                        "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer"
                    )
                assert isinstance(pnode, Pointwise)
                index = pnode._index(pnode.ranges)
                result = pnode.inner_fn(index)
                # each epilogue node results in a single "using" statement and may refer
                # to the previous steps by name.
                # Guard against anonymous buffers, for consistency with
                # CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string.
                if node.name is not None:
                    formatter.aliases[node.name] = result
            res = formatter.getvalue(result)  # type: ignore[possibly-undefined]
            if _MAGIC_SYMPY_ERROR_STRING in res:
                raise CUTLASSEVTOpNotImplementedError(
                    "sympy / indexing expressions not yet supported in EVT fusion"
                )
            else:
                return res

    def __getattr__(self, name):
        """
        Resolve V.ops.<whatever> calls, after this instance has been installed as V.ops handler.

        Returns a callable that stringifies its arguments, dispatches to the
        matching _op_<name> method, and records the result as a new
        "using EVT_expr_<n> = ...;" declaration whose variable name is returned.
        """

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fkwargs = {key: _arg_str(a) for key, a in kwargs.items()}
            fn = getattr(self, f"_op_{name}")
            line = fn(*fargs, **fkwargs)
            self.var_counter += 1
            varname = f"EVT_expr_{self.var_counter}"
            # replace line with a new variable name
            self.output.writeline(f"using {varname} = {line};")
            return varname

        if name.startswith("_"):
            raise CUTLASSEVTOpNotImplementedError(name)
        if hasattr(self, f"_op_{name}"):
            return inner
        else:
            raise CUTLASSEVTOpNotImplementedError(name)

    def _op_load(self, name, index_expr):
        # Load an input to an operation. Might be the output of the matmul, the result
        # of a previous epilogue node, a constant or (TODO) an auxiliary input.
        if name == self.accumulator_node_name:
            return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */"
        elif name in self.aliases:
            return self.aliases[name]
        else:
            # return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */"
            raise CUTLASSEVTOpNotImplementedError(
                f"Operand {name} not found. Auxiliary inputs not supported yet."
            )

    def _op_constant(self, value, dtype):
        # Load a constant; only float16/float32 scalars are representable as
        # an Sm90ScalarBroadcast of the accumulator element type.
        if str(dtype) in ("torch.float16", "torch.float32"):
            return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc> /* value={value}, dtype={dtype} */"
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Unsupported dtype for constant: {dtype}"
            )

    def _cutlass_binary_functional_op(self, op, a, b):
        # Perform a named operation on two inputs
        # see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::{op}, ElementAcc, ElementAcc, RoundStyle>,{a},{b}>"  # noqa: B950

    def _convert_to_output_dtype(self, a):
        # Convert the final output to the dtype of the output buffer
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,{a}>"  # noqa: B950

    def _op_to_dtype(self, a, *args, **kwargs):
        # no-op in our case, since we convert to the output dtype at the end and convert everything to the accumulator
        # dtype.
        # It is asserted (and ascertained during the can_fuse decision) that the dtype remains compatible
        # throughout the fusion chain.
        return a  # noqa: B950

    def _op_mul(self, a, b):
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        return self._cutlass_binary_functional_op("maximum", a, b)

    def _op_relu(self, a):
        # relu(a) == maximum(a, 0.0)
        const_zero = self._op_constant(0.0, "torch.float32")
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,{a}, {const_zero}>"  # noqa: B950

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions cannot be expressed as EVT nodes.
        raise CUTLASSEVTOpNotImplementedError()

    # Add more ops here...
    def getvalue(self, result) -> str:
        # Return final result: cast the last EVT expression to the output dtype and
        # emit the final "using <evt_type_name> = ...;" declaration.
        dtype_converted_expr = self._convert_to_output_dtype(
            f"EVT_expr_{self.var_counter}"
        )
        self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};")
        return self.output.getvalue()
210
+
211
+
212
class CutlassEVTEpilogueArgumentFormatter:
    """
    Codegen class, which provides an entry point to generate
    Cutlass "Epilogue Visitor Tree" (EVT) Argument initializers

    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
    for more about EVTs and how they are declared and used to generate.

    Notes:
        * Used by CUTLASSGemmTemplate.
        * This class should not be instantiated by users, it is intended to be used
          by calling CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...)
          which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
        * Extend this with more _op_<whatever> nodes to add support for new pointwise operations.
    """

    def __init__(self, accumulator_node_name: str):
        """
        Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly.
        Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method.

        Args:
            accumulator_node_name (str): The name of the accumulator node which should contain
                the Matmul result before fusion according to the IR graph.
        """
        self.accumulator_node_name: str = accumulator_node_name
        self.output: IndentedBuffer = IndentedBuffer(0)  # The output buffer for codegen
        # Used to generate variable names, incremented for each new variable.
        self.var_counter: int = 0
        # Aliases for subexpression functors, keyed by buffer name.
        self.aliases: Dict[str, str] = {}

    @staticmethod
    def ir_to_evt_argument_string(
        template_output_node_name: str,
        epilogue_nodes: List[IRNode],
    ) -> str:
        """Format ``epilogue_nodes`` into a CUTLASS EVT argument-initializer string.

        Mirrors CutlassEVTEpilogueTypeFormatter.ir_to_evt_string, but produces the
        brace-initializer expression for the EVT arguments instead of the type.
        """
        fmt = CutlassEVTEpilogueArgumentFormatter(template_output_node_name)

        with virtualized.V.set_ops_handler(fmt), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            for node in epilogue_nodes:
                assert isinstance(node, ComputedBuffer)
                pw = node.data
                assert isinstance(pw, Pointwise)
                index = pw._index(pw.ranges)
                result = pw.inner_fn(index)
                # Each epilogue node may be referred to by later nodes via its buffer name.
                if node.name is not None:
                    fmt.aliases[node.name] = result

            res: str = fmt.getvalue(result)  # type: ignore[possibly-undefined]
            if _MAGIC_SYMPY_ERROR_STRING in res:
                raise CUTLASSEVTOpNotImplementedError(
                    "sympy / indexing expressions not yet supported in EVT fusion"
                )
            return res

    def __getattr__(self, name):
        """Dispatch V.ops.<name>(...) calls to the matching _op_<name> handler."""

        def handler(*args, **kwargs):
            str_args = [_arg_str(a) for a in args]
            str_kwargs = {k: _arg_str(v) for k, v in kwargs.items()}
            return getattr(self, f"_op_{name}")(*str_args, **str_kwargs)

        if name.startswith("_"):
            raise CUTLASSEVTOpNotImplementedError(name)

        if hasattr(self, f"_op_{name}"):
            return handler
        raise CUTLASSEVTOpNotImplementedError(name)

    def _op_load(self, name, index_expr):
        # The accumulator fetch takes no arguments; previous subexpressions are
        # substituted by their recorded initializer.
        if name == self.accumulator_node_name:
            return "{}"
        if name in self.aliases:
            return self.aliases[name]
        raise CUTLASSEVTOpNotImplementedError(
            f"Operand {name} not found. Auxiliary inputs not supported yet."
        )

    def _op_constant(self, value, dtype):
        # Scalar broadcast argument: brace-initialize the constant cast to ElementAcc.
        if str(dtype) not in ("torch.float16", "torch.float32"):
            raise CUTLASSEVTOpNotImplementedError(
                f"Unsupported dtype for constant: {dtype}"
            )
        return "{ static_cast<ElementAcc>(" + str(value) + ") }"

    def _cutlass_binary_functional_op(self, op, a, b):
        # Nested brace-initializer for a binary Sm90Compute node.
        return f"{{ /*{op}: */ {a}, {b} }}"

    def _op_mul(self, a, b):
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        return self._cutlass_binary_functional_op("maximum", a, b)

    def _op_relu(self, a):
        # relu(a) == maximum(a, 0.0): pair the operand with a zero constant.
        const_zero = self._op_constant(0.0, "torch.float32")
        return f"{{{a}, {const_zero}}}"

    def _op_to_dtype(self, a, dtype, src_dtype=None):
        # It is asserted (and ascertained during the can_fuse decision) that the dtype
        # remains compatible throughout the fusion chain, so the cast is a no-op here.
        assert dtype in (
            "torch.float32",
            "torch.float16",
        ), f"Unsupported dtype: {dtype}"
        assert src_dtype in (
            None,
            "torch.float32",
            "torch.float16",
        ), f"Unsupported source dtype: {src_dtype}"
        return a

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions cannot be expressed as EVT nodes.
        raise CUTLASSEVTOpNotImplementedError()

    def getvalue(self, result) -> str:
        """Wrap the final subexpression initializer in the outermost braces."""
        return f"{{{result}}}"
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from ..cutlass_utils import try_import_cutlass

# The emitter below needs the `cutlass_library` python package, which is only
# importable after try_import_cutlass() has located / installed it.
if try_import_cutlass():
    import enum

    from cutlass_library.library import *  # noqa: F401, F403
    from cutlass_library.gemm_operation import *  # noqa: F401, F403

    # copied / modified from original at
    # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658
    # to support EVT similar to
    # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L315C69-L315C69  # noqa: B950
    class EmitGemmUniversal3xInstanceWithEVT:
        """Responsible for emitting a CUTLASS 3.x template definition"""

        def __init__(self, operation_suffix=""):
            self.operation_suffix = operation_suffix
            # Headers required by the emitted kernel source.
            self.includes = [
                "cutlass/cutlass.h",
                "cutlass/gemm/gemm.h",
                "cutlass/numeric_types.h",
                "cutlass/gemm/kernel/gemm_universal.hpp",
                "cutlass/gemm/collective/collective_builder.hpp",
                "cutlass/epilogue/collective/collective_builder.hpp",
            ]
            # Template used when the epilogue functor is a built-in enum value.
            self.builtin_epilogue_functor_template = """
                ${epilogue_functor}<
                  ${element_c},
                  ${epilogue_vector_length},
                  ${element_accumulator},
                  ${element_epilogue}
                >
"""
            # Main kernel-instance template; the EVT functor is injected via
            # ${epilogue_functor} / ${operation_name}_epilogue_functor.
            self.gemm_template = """
using EpilogueScheduleType = ${epilogue_schedule};
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
       cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
       "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = ${element_accumulator};
using ElementD = ${element_d};
${epilogue_functor};
using ${operation_name}_epilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ${arch}, ${opcode_class},
    cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
    cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ${element_accumulator}, ${element_epilogue},
    ${element_c}, ${layout_c}, ${align_c},
    ${element_d}, ${layout_d}, ${align_d},
    EpilogueScheduleType,
    ${operation_name}_epilogue_functor
  >::CollectiveOp;

using ${operation_name}_mainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ${arch}, ${opcode_class},
    ${element_a}, ${layout_a}, ${align_a},
    ${element_b}, ${layout_b}, ${align_b},
    ${element_accumulator},
    cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
    cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
    ${stages},
    ${kernel_schedule}
  >::CollectiveOp;

// Gemm operator ${operation_name}
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int,int,int,int>,
    ${operation_name}_mainloop,
    ${operation_name}_epilogue,
    ${tile_scheduler}>;

// Define named type
struct ${operation_name} :
  public ${operation_name}_base { };

"""

        #
        def instance_template(self):
            """Template used to register the emitted kernel with the manifest."""
            return """
${compile_guard_start}
  using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
  manifest.append(
    new ${gemm_kind}<GemmKernel>("${operation_name}"));
${compile_guard_end}
"""

        #
        def emit(self, operation):
            """Render the CUTLASS 3.x kernel-instance source for `operation`."""
            tile_shape = operation.tile_description.tile_shape
            warp_count = operation.tile_description.warp_count
            # stage count set to zero indicates builder automatic stage selection
            if operation.tile_description.stages > 0:
                stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
            else:
                stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage)>"  # noqa: B950
            warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)]

            (
                instance_layout_A,
                instance_layout_B,
                instance_layout_C,
                instance_layout_D,
            ) = (
                operation.A.layout,
                operation.B.layout,
                operation.C.layout,
                operation.D.layout,
            )

            # 3.0 profiler integration only supports trivial epilogues for now
            epilogue_vector_length = 1

            # Support built-in epilogue functors or user-defined functions
            if isinstance(operation.epilogue_functor, enum.Enum):
                values = {
                    "epilogue_vector_length": str(epilogue_vector_length),
                    "element_epilogue": str(DataTypeTag[operation.element_epilogue]),  # type: ignore[name-defined]
                    "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],  # type: ignore[name-defined]
                }
                epilogue_functor = SubstituteTemplate(  # type: ignore[name-defined]
                    self.builtin_epilogue_functor_template, values
                )
            elif callable(operation.epilogue_functor):
                # A callable functor generates the (EVT) declaration itself,
                # given the name it must be declared under.
                epilogue_functor = operation.epilogue_functor(
                    operation.procedural_name() + "_epilogue_functor"
                )
            else:
                epilogue_functor = str(operation.epilogue_functor)
            #

            values = {
                "operation_name": operation.procedural_name(),
                "operation_suffix": self.operation_suffix,
                "element_a": DataTypeTag[operation.A.element],  # type: ignore[name-defined]
                "layout_a": LayoutTag[instance_layout_A],  # type: ignore[name-defined]
                "element_b": DataTypeTag[operation.B.element],  # type: ignore[name-defined]
                "layout_b": LayoutTag[instance_layout_B],  # type: ignore[name-defined]
                "element_c": DataTypeTag[operation.C.element],  # type: ignore[name-defined]
                "layout_c": LayoutTag[instance_layout_C],  # type: ignore[name-defined]
                "element_d": DataTypeTag[operation.D.element],  # type: ignore[name-defined]
                "layout_d": LayoutTag[instance_layout_D],  # type: ignore[name-defined]
                "element_accumulator": DataTypeTag[operation.accumulator_type()],  # type: ignore[name-defined]
                "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],  # type: ignore[name-defined] # noqa: B950
                "arch": "cutlass::arch::Sm%d" % operation.arch,
                "tile_shape_m": str(operation.tile_description.tile_shape[0]),
                "tile_shape_n": str(operation.tile_description.tile_shape[1]),
                "tile_shape_k": str(operation.tile_description.tile_shape[2]),
                "cluster_m": str(operation.tile_description.cluster_shape[0]),
                "cluster_n": str(operation.tile_description.cluster_shape[1]),
                "cluster_k": str(operation.tile_description.cluster_shape[2]),
                "warp_shape_m": str(warp_shape[0]),
                "warp_shape_n": str(warp_shape[1]),
                "warp_shape_k": str(warp_shape[2]),
                "instruction_shape_m": str(
                    operation.tile_description.math_instruction.instruction_shape[0]
                ),
                "instruction_shape_n": str(
                    operation.tile_description.math_instruction.instruction_shape[1]
                ),
                "instruction_shape_k": str(
                    operation.tile_description.math_instruction.instruction_shape[2]
                ),
                "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]),  # type: ignore[name-defined]
                "epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]),  # type: ignore[name-defined]
                "epilogue_functor": epilogue_functor,
                "stages": stage_count_string,
                "align_a": str(operation.A.alignment),
                "align_b": str(operation.B.alignment),
                "align_c": str(operation.C.alignment),
                # NOTE(review): mirrors operation.C.alignment (as in the original) —
                # confirm D is intended to share C's alignment here.
                "align_d": str(operation.C.alignment),
                "transform_a": ComplexTransformTag[operation.A.complex_transform],  # type: ignore[name-defined]
                "transform_b": ComplexTransformTag[operation.B.complex_transform],  # type: ignore[name-defined]
                "math_operation": MathOperationTag[  # type: ignore[name-defined]
                    operation.tile_description.math_instruction.math_operation
                ],
                "epilogue_vector_length": str(epilogue_vector_length),
                "element_epilogue": str(DataTypeTag[operation.element_epilogue]),  # type: ignore[name-defined]
                "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]),  # type: ignore[name-defined]
            }

            return SubstituteTemplate(self.gemm_template, values)  # type: ignore[name-defined]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..common import DeviceOpOverrides, register_device_op_overrides
2
+
3
+
4
class CUDADeviceOpOverrides(DeviceOpOverrides):
    """CUDA implementations of the device-agnostic wrapper-codegen hooks.

    Each method returns a line of Python source text to be emitted into the
    generated wrapper code (it is not executed here).
    """

    def import_get_raw_stream_as(self, name):
        # Bind the C-level raw-stream getter under the requested alias.
        return f"from torch._C import _cuda_getCurrentRawStream as {name}"

    def set_device(self, device_idx):
        return f"torch.cuda.set_device({device_idx})"

    def synchronize(self):
        return "torch.cuda.synchronize()"

    def device_guard(self, device_idx):
        return f"torch.cuda._DeviceGuard({device_idx})"


register_device_op_overrides("cuda", CUDADeviceOpOverrides())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from ..scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode
4
+ from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
5
+
6
+ from .triton import TritonScheduling
7
+
8
+
9
class CUDACombinedScheduling(BaseScheduling):
    """
    Scheduler for CUDA Kernels, which delegates calls as appropriate
    to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices
    and use a unified-wrapper for codegen.

    If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code,
    this would also be the place to do it.
    """

    def __init__(self, scheduler: Scheduler):
        super().__init__()
        self._scheduler = scheduler
        self._triton_scheduling = TritonScheduling(scheduler)
        self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)

    def _is_cuda_cpp(self, node: BaseSchedulerNode) -> bool:
        # A node is handled by the CUDA-C++ backend when it is a CUDA C++
        # template kernel or a fusion containing one.
        cpp = self._cuda_cpp_scheduling
        return cpp.is_cuda_cpp_template(node) or cpp.is_cuda_cpp_fused_template(node)

    def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
        """Pick the backend scheduler responsible for `node`."""
        if self._is_cuda_cpp(node):
            return self._cuda_cpp_scheduling
        return self._triton_scheduling

    def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        # Give the CUDA-C++ backend first say; fall back to Triton's rules.
        return self._cuda_cpp_scheduling.can_fuse_vertical(
            node1, node2
        ) or self._triton_scheduling.can_fuse_vertical(node1, node2)

    def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        # If either node belongs to the CUDA-C++ backend, defer to its rule
        # (which always answers False at the moment).
        if any(self._is_cuda_cpp(n) for n in (node1, node2)):
            return self._cuda_cpp_scheduling.can_fuse_horizontal(node1, node2)
        return self._triton_scheduling.can_fuse_horizontal(node1, node2)

    def group_fn(self, sizes):
        return self._triton_scheduling.group_fn(sizes)

    def codegen_template(
        self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]
    ):
        # Templates are dispatched by kind; epilogues follow their template.
        if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
            return self._cuda_cpp_scheduling.codegen_template(
                template_node, epilogue_nodes
            )
        return self._triton_scheduling.codegen_template(template_node, epilogue_nodes)

    def codegen_nodes(self, nodes: List[SchedulerNode]):
        return self._triton_scheduling.codegen_nodes(nodes)

    def codegen_sync(self):
        return self._triton_scheduling.codegen_sync()

    def flush(self):
        return self._triton_scheduling.flush()

    def codegen_foreach(self, *args, **kwargs):
        return self._triton_scheduling.codegen_foreach(*args, **kwargs)

    def benchmark_fused_nodes(self, nodes):
        return self._triton_scheduling.benchmark_fused_nodes(nodes)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/misc_patterns.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ from typing import Dict, Set, Tuple
4
+
5
+ import torch
6
+ from torch._dynamo.utils import counters
7
+
8
+ from torch._ops import OpOverload, OpOverloadPacket
9
+ from ..pattern_matcher import fwd_only, register_replacement
10
+
11
+ aten = torch.ops.aten
12
+
13
+
14
@functools.lru_cache(None)
def _misc_patterns_init():
    """Register the randperm/index fusion patterns with the pattern matchers.

    Runs at most once per process (memoized via lru_cache). The replacements
    rewrite index_add / index on a randperm-derived (hence unique, in-bounds)
    index into their _unsafe_* counterparts.
    """
    from .joint_graph import patterns as joint_graph_patterns
    from .post_grad import pass_patterns as post_grad_patterns_all

    post_grad_patterns = post_grad_patterns_all[1]  # medium priority

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # These patterns do 2 things
    # 1. Since we know that index is completely unique, we can codegen it using
    #    stores instead of atomic adds, which is quite a bit faster.
    # 2. Also, since we are guaranteed that they are completely within bounds,
    #    we can use unsafe indexing and skip debug asserts
    def randperm_index_add_pattern(x, y):
        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
        return torch.index_add(x, dim=0, source=y, index=index), index

    def randperm_index_add_replacement(x, y):
        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
        return (
            torch.ops.aten._unsafe_index_put(
                x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False
            ),
            index,
        )

    register_replacement(
        randperm_index_add_pattern,
        randperm_index_add_replacement,
        [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
        fwd_only,
        [post_grad_patterns, joint_graph_patterns],
    )

    def randperm_index_pattern(x, slice_shape):
        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
        return torch.ops.aten.index(x, (index,)), index

    def randperm_index_replacement(x, slice_shape):
        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
        return torch.ops.aten._unsafe_index(x, (index,)), index

    # Return value intentionally discarded: registration happens as a side effect.
    register_replacement(
        randperm_index_pattern,
        randperm_index_replacement,
        [torch.empty(4, 8, device=device)],
        fwd_only,
        [post_grad_patterns, joint_graph_patterns],
        scalar_workaround={"slice_shape": 42},
    )
69
+
70
+
71
class NumpyCompatNormalization:
    """FX graph pass that renames numpy-style kwargs (e.g. ``axis``) on torch
    callables to their torch-native names (e.g. ``dim``).

    Only plain torch ops are rewritten; ATen OpOverload / OpOverloadPacket
    targets are skipped (e.g. torch.stack(axis=1) works, but
    torch.ops.aten.stack(axis=1) doesn't).
    """

    # torch kwarg name -> numpy-compatible aliases it replaces
    numpy_compat: Dict[str, Tuple[str, ...]] = {
        "dim": ("axis",),
        "keepdim": ("keepdims",),
        "input": ("x", "a", "x1"),
        "other": ("x2",),
    }
    inverse_mapping: Dict[str, str]
    cache: Dict["torch.fx.graph.Target", Set[str]]

    def __init__(self):
        self.cache = {}  # callable -> set of replaceable kwarg names, e.g. {"axis"}
        self.inverse_mapping = {}
        for torch_kwarg, aliases in self.numpy_compat.items():
            for alias in aliases:
                assert alias not in self.inverse_mapping
                self.inverse_mapping[alias] = torch_kwarg

    def _replaceable_kwargs(self, target) -> Set[str]:
        """Compute (and memoize) which numpy-style kwarg names apply to `target`."""
        if target in self.cache:
            return self.cache[target]
        signatures = torch.fx.operator_schemas.get_signature_for_torch_op(target)
        found: Set[str] = set()
        for sig in signatures or ():
            for param_name in sig.parameters.keys():
                if param_name in self.numpy_compat:
                    found.update(self.numpy_compat[param_name])
        self.cache[target] = found
        return found

    def __call__(self, graph: torch.fx.Graph):
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if isinstance(node.target, (OpOverload, OpOverloadPacket)):
                # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't.
                continue

            replaceable_kwargs = self._replaceable_kwargs(node.target)
            if not replaceable_kwargs:
                continue

            new_kwargs = {}
            changed = False
            for key, value in node.kwargs.items():
                if key in replaceable_kwargs:
                    changed = True
                    new_kwargs[self.inverse_mapping[key]] = value
                else:
                    new_kwargs[key] = value

            if changed:
                node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs)
                counters["inductor"]["numpy_compat_normalization"] += 1


numpy_compat_normalization = NumpyCompatNormalization()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py ADDED
@@ -0,0 +1,1204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import operator
3
+ from functools import reduce
4
+ from typing import Any, Tuple
5
+
6
+ import torch
7
+
8
+ from torch.fx.experimental.symbolic_shapes import has_free_symbols
9
+
10
+ from .. import ir
11
+
12
+ from ..lowering import lowerings as L
13
+ from ..pattern_matcher import (
14
+ Arg,
15
+ CallFunction,
16
+ filter_nodes,
17
+ get_arg_value,
18
+ KeywordArg,
19
+ MULTIPLE,
20
+ )
21
+ from ..virtualized import ops
22
+ from .freezing_patterns import register_freezing_graph_pattern
23
+ from .post_grad import register_lowering_pattern
24
+ from .quantization import (
25
+ _register_quantization_lowerings,
26
+ _register_quantization_weight_pack_pass,
27
+ )
28
+
29
+ if torch._C._has_mkldnn:
30
+ aten = torch.ops.aten
31
+ mkldnn = torch.ops.mkldnn
32
+ prims = torch.ops.prims
33
+
34
+ _conv_args = [Arg() for _ in range(10)]
35
+ _linear_args = [Arg() for _ in range(6)]
36
+ _conv_transpose_args = [Arg() for _ in range(11)]
37
+
38
def _conv_call(users=1):
    """Pattern node matching mkldnn pointwise convolution with `users` consumers."""
    return CallFunction(
        mkldnn._convolution_pointwise.default, *_conv_args, _users=users
    )

def _linear_call(users=1):
    """Pattern node matching mkldnn pointwise linear with `users` consumers."""
    return CallFunction(
        mkldnn._linear_pointwise.default, *_linear_args, _users=users
    )

def _conv_transpose_call(users=1):
    """Pattern node matching mkldnn pointwise conv-transpose with `users` consumers."""
    return CallFunction(
        mkldnn._convolution_transpose_pointwise.default,
        *_conv_transpose_args,
        _users=users,
    )
54
+
55
def _to_float(input_call, users=1):
    """Match convert_element_type(input, dtype) capturing dtype as ``to_float``."""
    return CallFunction(
        prims.convert_element_type.default,
        input_call,
        KeywordArg("to_float"),
        _users=users,
    )

def _to_bf16(input_call):
    """Match a single-user convert_element_type capturing dtype as ``to_bf16``."""
    return CallFunction(
        prims.convert_element_type.default,
        input_call,
        KeywordArg("to_bf16"),
        _users=1,
    )

def _to_fp16(input_call):
    """Match a single-user convert_element_type capturing dtype as ``to_fp16``."""
    return CallFunction(
        prims.convert_element_type.default,
        input_call,
        KeywordArg("to_fp16"),
        _users=1,
    )
78
+
79
def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype):
    """Wrap `call_fn`'s pattern with `unary_fusion`, adding dtype casts when low-precision."""
    # Only insert the up-cast to float when matching a low-precision graph.
    if lowp_dtype:
        computation_call = _to_float(call_fn(), users=users)
    else:
        computation_call = call_fn(users=users)
    fused = unary_fusion(computation_call)
    # ...and a matching down-cast back to the low-precision dtype.
    if lowp_dtype == torch.bfloat16:
        return _to_bf16(fused)
    if lowp_dtype == torch.float16:
        return _to_fp16(fused)
    return fused
91
+
92
def _gelu_fusion_1(computation_call):
    """Pattern for erf-based gelu: x * 0.5 * (1 + erf(x * 0.7071...))."""
    erf_arg = CallFunction(aten.mul, computation_call, 0.7071067811865476)
    one_plus_erf = CallFunction(aten.add, CallFunction(aten.erf, erf_arg), 1)
    half_x = CallFunction(aten.mul, computation_call, 0.5)
    return CallFunction(aten.mul, half_x, one_plus_erf)
105
+
106
def _gelu_fusion_2(computation_call):
    """Pattern for tanh-based gelu: 0.5*x*(1 + tanh(0.79788456*(x + 0.044715*x^3)))."""
    x = computation_call
    x_cubed = CallFunction(aten.mul, CallFunction(aten.mul, x, x), x)
    inner = CallFunction(aten.add, x, CallFunction(aten.mul, x_cubed, 0.044715))
    tanh_term = CallFunction(
        aten.tanh, CallFunction(aten.mul, inner, 0.7978845608028654)
    )
    return CallFunction(
        aten.mul,
        CallFunction(aten.mul, x, 0.5),
        CallFunction(aten.add, tanh_term, 1),
    )
137
+
138
def _hardswish_fusion(computation_call):
    """Pattern for hardswish: x * clamp(x + 3, 0, 6) / 6."""
    relu6 = CallFunction(
        aten.clamp_max,
        CallFunction(aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0),
        6,
    )
    return CallFunction(aten.div, CallFunction(aten.mul, computation_call, relu6), 6)
154
+
155
def _silu_fusion(computation_call):
    """Pattern for silu/swish: x * sigmoid(x)."""
    sigmoid = CallFunction(aten.sigmoid, computation_call)
    return CallFunction(aten.mul, computation_call, sigmoid)
159
+
160
def _hardsigmoid_fusion(computation_call):
    """Pattern for hardsigmoid: clamp(x + 3, 0, 6) / 6."""
    relu6 = CallFunction(
        aten.clamp_max,
        CallFunction(aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0),
        6,
    )
    return CallFunction(aten.div, relu6, 6)
172
+
173
def _leaky_relu_fusion(computation_call):
    """Pattern for leaky_relu: where(x > 0, x, x * negative_slope)."""
    positive = CallFunction(aten.gt, computation_call, 0)
    scaled = CallFunction(aten.mul, computation_call, KeywordArg("negative_slope"))
    return CallFunction(aten.where, positive, computation_call, scaled)
180
+
181
def _hardtanh_fusion(computation_call):
    """Pattern for hardtanh: clamp(x, min_value, max_value)."""
    lower = CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value"))
    return CallFunction(aten.clamp_max, lower, KeywordArg("max_value"))
187
+
188
def _combined_fusion(computation_call, elementwise_op):
    """Pattern for unary_op(computation_op)."""
    return CallFunction(elementwise_op, computation_call)

# binary_op(other, computation_op)
def _binary_fusion_v1(computation_call, binary_fn):
    """Binary pattern with the extra input as the left operand."""
    return CallFunction(binary_fn, KeywordArg("other"), computation_call)

# binary_op(computation_op, other)
def _binary_fusion_v2(computation_call, binary_fn):
    """Binary pattern with the extra input as the right operand."""
    return CallFunction(binary_fn, computation_call, KeywordArg("other"))
198
+
199
def _is_single_computation_op(computation_op):
    """Extra-check factory: the match must contain `computation_op` nodes that have
    no unary post-op fused yet."""
    def fn(match):
        computation_nodes = filter_nodes(match.nodes, computation_op)
        if not computation_nodes:
            return False
        # args[-3] is the fused unary attr name; "none" means nothing fused yet.
        return all(n.args[-3] == "none" for n in computation_nodes)

    return fn
209
+
210
def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None):
    """Extra-check factory for computation + (optional casts) + unary-op fusions."""
    def fn(match):
        matched = _is_single_computation_op(computation_op)(match)
        computation_node = filter_nodes(match.nodes, computation_op)[0]
        if lowp_dtype:
            cast_nodes = filter_nodes(
                match.nodes, prims.convert_element_type.default
            )
            if len(cast_nodes) != 2:
                return False
            # Fusion pattern is always computation_op -> to_float32 -> unary ->
            # to_lowp; decide which cast is the up-cast by checking which one
            # consumes the computation node directly.
            if computation_node == cast_nodes[0].args[0]:
                up_dtype = cast_nodes[0].args[1]
                down_dtype = cast_nodes[1].args[1]
            else:
                up_dtype = cast_nodes[1].args[1]
                down_dtype = cast_nodes[0].args[1]
            matched = matched and up_dtype == torch.float and down_dtype == lowp_dtype
        return matched

    return fn
231
+
232
def _register_unary_fusion_lowering(
    pattern, unary_attr, computation_op, lowp_dtype=None
):
    """Register a lowering that folds the matched unary op into `computation_op`."""
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype),
    )
    def fn(match, *args, **kwargs):
        # Replace the trailing (attr, scalars, algorithm) triple with the fused unary.
        computation_args = [
            *args[:-3],
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        ]
        return L[computation_op](*computation_args)

    return fn
248
+
249
def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None):
    """Register a lowering folding leaky_relu into `computation_op` when possible,
    otherwise lowering the unfused decomposition."""
    @register_lowering_pattern(
        pattern, extra_check=_is_single_computation_op(computation_op)
    )
    def fn(match, *args, **kwargs):
        negative_slope = kwargs.get("negative_slope")
        # Only a scalar slope can be encoded into the fused op's attributes.
        matched = not isinstance(negative_slope, ir.TensorBox)
        if lowp_dtype:
            dtype1 = kwargs.get("to_float")
            if lowp_dtype == torch.bfloat16:
                dtype2 = kwargs.get("to_bf16")
            else:
                dtype2 = kwargs.get("to_fp16")
            matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
        computation_args = list(args)
        if matched:
            computation_args = computation_args[:-3] + [
                "leaky_relu",
                [negative_slope],
                "",
            ]
            return L[computation_op](*computation_args)
        # Fall back: unfused computation followed by an explicit leaky_relu.
        out = L[computation_op](*computation_args)
        if lowp_dtype:
            out = L[prims.convert_element_type.default](out, dtype=torch.float)
        out = L[aten.where](
            L[aten.gt](out, 0),
            out,
            L[aten.mul](out, negative_slope),
        )
        if lowp_dtype:
            out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
        return out

    return fn
290
+
291
def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None):
    """Register a lowering folding hardtanh into `computation_op` when possible,
    otherwise lowering the unfused decomposition."""
    @register_lowering_pattern(
        pattern, extra_check=_is_single_computation_op(computation_op)
    )
    def fn(match, *args, **kwargs):
        min_value = kwargs.get("min_value")
        max_value = kwargs.get("max_value")
        if isinstance(min_value, ir.TensorBox) or isinstance(
            max_value, ir.TensorBox
        ):
            matched = False
        else:
            # Scalar bounds: only fusable when the clamp range is non-empty.
            assert max_value is not None
            matched = min_value <= max_value
        if lowp_dtype:
            dtype1 = kwargs.get("to_float")
            if lowp_dtype == torch.bfloat16:
                dtype2 = kwargs.get("to_bf16")
            else:
                dtype2 = kwargs.get("to_fp16")
            matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
        computation_args = list(args)
        if matched:
            computation_args = computation_args[:-3] + [
                "hardtanh",
                [min_value, max_value],
                "",
            ]
            return L[computation_op](*computation_args)
        # Fall back: unfused computation followed by explicit clamps.
        out = L[computation_op](*computation_args)
        if lowp_dtype:
            out = L[prims.convert_element_type.default](out, dtype=torch.float)
        out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value)
        if lowp_dtype:
            out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
        return out

    return fn
331
+
332
# Map fx binary targets to the mkldnn fused-op attribute string.
_binary_attr = {
    aten.add: "add",
    ops.add: "add",
    aten.sub: "sub",
    ops.sub: "sub",
}
338
+
339
def _is_valid_binary(match, fn):
    """Check that every matched binary node is a tensor-tensor op mkldnn can fuse."""
    binary_nodes = filter_nodes(match.nodes, fn)
    if not binary_nodes:
        return False

    def get_meta_value(argument: torch.fx.node.Argument):
        # Only torch.fx.Node is expected to have meta.
        if isinstance(argument, torch.fx.Node):
            return argument.meta.get("val", None)
        return None

    for n in binary_nodes:
        lhs = get_meta_value(n.args[0])
        rhs = get_meta_value(n.args[1])
        # Both operands must be tensors.
        if not isinstance(lhs, torch.Tensor) or not isinstance(rhs, torch.Tensor):
            return False
        # alpha must be one (or absent).
        alpha = get_arg_value(n, 2, kwarg_name="alpha")
        if alpha is not None and alpha != 1.0:
            return False
        # Operands must agree in shape, device and dtype.
        if (
            lhs.size() != rhs.size()
            or lhs.device != rhs.device
            or lhs.dtype != rhs.dtype
        ):
            return False
        # The two operands must be distinct nodes.
        if n.args[0] == n.args[1]:
            return False
    return True
374
+
375
def _is_valid_computation_binary(computation_op, binary_op, other_index=None):
    """Extra-check factory combining the computation-op and binary-op validity tests.
    `other_index` is unused here but kept for signature parity with the inplace variant."""
    def fn(match):
        return _is_single_computation_op(computation_op)(
            match
        ) and _is_valid_binary(match, binary_op)

    return fn
384
+
385
def _get_remaining_users(extra_input_node, compute_node):
    # Think about this pattern:
    #      ReLU
    #     /    \
    #  Conv1
    #   /      \
    # Conv2
    #   \      /
    #      Add
    # Although the extra input node (ReLU) has more than 1 user (Conv1 and Add),
    # Conv1 is an ancestor node of the current compute node (Conv2). This
    # indicates that the buffer of ReLU has completed all its usage, so we can
    # safely make changes to it now by doing Conv2->Add inplace fusion.
    # Returns the users of extra_input_node that are NOT ancestors of compute_node.
    def _is_ancestor_node(_current_node, _ancestor_node):
        # Breadth-first walk of _current_node's inputs looking for _ancestor_node.
        pending = [_current_node]
        seen = set()
        while pending:
            candidate = pending.pop(0)
            if candidate in seen:
                continue
            seen.add(candidate)
            if candidate == _ancestor_node:
                return True
            if isinstance(candidate, torch.fx.Node) and candidate.op not in [
                "placeholder",
                "output",
                "get_attr",
            ]:
                pending.extend(candidate.all_input_nodes)
        return False

    return [
        user
        for user in list(extra_input_node.users)
        if not _is_ancestor_node(compute_node, user)
    ]
425
+
426
def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index):
    """Extra-check factory deciding whether the binary's extra input may be mutated."""
    def fn(match):
        if not _is_valid_computation_binary(computation_op, binary_op)(match):
            return False
        binary_nodes = filter_nodes(match.nodes, binary_op)

        def _get_compute_node(_binary_node, _other_index):
            assert (
                len(_binary_node.all_input_nodes) == 2
            ), "Binary node should have 2 input nodes."
            _compute_index = 1 if (_other_index == 0) else 0
            return _binary_node.args[_compute_index]

        def _other_input_not_inplaceable(_binary_node, _other_index):
            # Not inplaceable when the extra input still has live non-ancestor
            # users, or when it aliases the computation's own input.
            _compute_node = _get_compute_node(_binary_node, _other_index)
            other = _binary_node.args[_other_index]
            return (
                len(_get_remaining_users(other, _compute_node)) > 1
                or other == _compute_node.args[0]
            )

        if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes):
            return False
        # Graph inputs/outputs must never be mutated in place.
        if any(
            n.args[other_index].op in ["placeholder", "output"]
            for n in binary_nodes
        ):
            return False
        return True

    return fn
461
+
462
def _register_binary_unary_fusion_lowering(
    pattern,
    computation_op,
    binary_op,
    fusion_op,
    unary_attr=None,
):
    """Register a lowering folding a binary op (and optional unary) into `fusion_op`."""
    @register_lowering_pattern(
        pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op)
    )
    def fn(match, *args, **kwargs):
        other = kwargs.get("other")
        assert isinstance(other, ir.TensorBox)
        args_list = list(args)
        computation_args = (
            [args_list[0], other] + args_list[1:-3] + [_binary_attr[binary_op]]
        )
        # Convolution variants carry extra (alpha, unary_attr, scalars, algorithm).
        if len(args_list) > 6:
            if unary_attr is None:
                computation_args += [1.0, None, [], None]
            else:
                computation_args += [
                    1.0,
                    unary_attr.op_name,
                    unary_attr.scalars_attr,
                    unary_attr.algorithm_attr,
                ]
        return L[fusion_op](*computation_args)

    return fn
491
+
492
def _can_be_inplace(_other):
    """Return True when `_other`'s buffer looks safe to mutate in place."""
    # Look through views at the underlying buffer.
    if isinstance(_other.data, ir.View):
        return _can_be_inplace(_other.data)
    if isinstance(_other.data, ir.ReinterpretView):
        return False
    return not isinstance(_other.get_layout(), (ir.MutationLayout, ir.AliasedLayout))
502
+
503
def _register_binary_unary_maybe_inplace_fusion_lowering(
    pattern,
    computation_op,
    binary_op,
    inplace_fusion_op,
    outplace_fusion_op,
    unary_attr=None,
    other_index=None,
):
    """Register a lowering fusing a binary (and optional unary) op, preferring the
    in-place fused kernel when the extra input's buffer can be reused."""
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_computation_binary_inplace(
            computation_op, binary_op, other_index
        ),
    )
    def fn(match, *args, **kwargs):
        other = kwargs.get("other")
        assert isinstance(other, ir.TensorBox)
        args_list = list(args)
        computation_args = (
            [args_list[0], other] + args_list[1:-3] + [_binary_attr[binary_op]]
        )
        if len(args_list) > 6:
            if unary_attr is None:
                computation_args += [1.0, None, [], None]
            else:
                computation_args += [
                    1.0,
                    unary_attr.op_name,
                    unary_attr.scalars_attr,
                    unary_attr.algorithm_attr,
                ]
        # Make sure the other is not an alias or mutation (fx side doesn't have such info).
        other.realize()
        if _can_be_inplace(other):
            return L[inplace_fusion_op](*computation_args)
        return L[outplace_fusion_op](*computation_args)

    return fn
541
+
542
# Computation ops eligible for pointwise fusion, in pattern-registration order.
computation_ops = [
    mkldnn._convolution_pointwise.default,
    mkldnn._linear_pointwise.default,
    mkldnn._convolution_transpose_pointwise.default,
]
547
+
548
class UnaryAttr:
    """Bundle of (op_name, scalars, algorithm) describing a fused unary post-op."""

    def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
        self.op_name = op_name
        # Normalize falsy defaults to an empty scalar list / algorithm string.
        self.scalars_attr = scalars_attr if scalars_attr else []
        self.algorithm_attr = algorithm_attr if algorithm_attr else ""
553
+
554
def _register_unary_fusion():
    """Register all unary-op fusion lowerings for conv/linear/conv-transpose."""
    computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call]

    def _unary_fusion_patterns(lowp_dtype):
        # Each fused unary attr maps to (decomposed pattern builder, user count);
        # the user count is how often the computation output appears in the
        # decomposed expression.
        fusions = [
            (UnaryAttr("gelu", algorithm_attr="tanh"), _gelu_fusion_2, 4),
            (UnaryAttr("gelu", algorithm_attr="none"), _gelu_fusion_1, 2),
            (UnaryAttr("hardswish"), _hardswish_fusion, 2),
            (UnaryAttr("hardsigmoid"), _hardsigmoid_fusion, 1),
            (UnaryAttr("swish"), _silu_fusion, 2),
        ]
        replacement_unary_fusion_patterns = {
            attr: [
                _unary_fusion_pattern(fusion, call_fn, users, lowp_dtype)
                for call_fn in computation_call_fns
            ]
            for attr, fusion, users in fusions
        }
        if not lowp_dtype:
            # Simple single-node unaries are only matched at full precision.
            call_user1 = [call_fn(users=1) for call_fn in computation_call_fns]
            for attr, op in [
                (UnaryAttr("relu"), aten.relu),
                (UnaryAttr("sigmoid"), aten.sigmoid),
                (UnaryAttr("tanh"), aten.tanh),
            ]:
                replacement_unary_fusion_patterns[attr] = [
                    _combined_fusion(u, op) for u in call_user1
                ]

        return replacement_unary_fusion_patterns

    for lowp_dtype in [torch.bfloat16, torch.float16, None]:
        replace_patterns = _unary_fusion_patterns(lowp_dtype)
        for unary_attr, patterns in replace_patterns.items():
            for pattern, computation_op in zip(patterns, computation_ops):
                _register_unary_fusion_lowering(
                    pattern, unary_attr, computation_op, lowp_dtype
                )
        # leaky_relu and hardtanh carry scalar kwargs, so they use dedicated
        # registration helpers with an unfused fallback.
        _leaky_relu_patterns = [
            _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype)
            for call_fn in computation_call_fns
        ]
        for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops):
            _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype)
        hardtanh_patterns = [
            _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype)
            for call_fn in computation_call_fns
        ]
        for pattern, computation_op in zip(hardtanh_patterns, computation_ops):
            _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype)
624
+
625
def _register_inplace_fusion():
    """Register conv+add (and conv+add+relu) fusions with in-place variants."""
    binary_ops = [aten.add, ops.add]
    inplace_fusion_op = mkldnn._convolution_pointwise_.binary
    outplace_fusion_op = mkldnn._convolution_pointwise.binary
    conv_call = _conv_call(users=1)
    conv_op = computation_ops[0]
    for binary_op in binary_ops:
        # other_index says which binary argument is the extra (non-conv) input:
        # v1 = binary(other, conv), v2 = binary(conv, other).
        for make_pattern, other_index in [
            (_binary_fusion_v1, 0),
            (_binary_fusion_v2, 1),
        ]:
            binary_pattern = make_pattern(conv_call, binary_op)
            # binary followed by relu
            _register_binary_unary_maybe_inplace_fusion_lowering(
                _combined_fusion(binary_pattern, aten.relu),
                conv_op,
                binary_op,
                inplace_fusion_op,
                outplace_fusion_op,
                other_index=other_index,
                unary_attr=UnaryAttr("relu"),
            )
            # binary only
            _register_binary_unary_maybe_inplace_fusion_lowering(
                binary_pattern,
                conv_op,
                binary_op,
                inplace_fusion_op,
                outplace_fusion_op,
                other_index=other_index,
            )
670
+
671
def _register_binary_fusion():
    """Register conv/linear + binary-op fusions (no trailing unary)."""
    binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
    fusion_ops = [
        mkldnn._convolution_pointwise.binary,
        mkldnn._linear_pointwise.binary,
    ]
    _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)]
    for computation_call, computation_op, fusion_op in zip(
        _computation_user_1, computation_ops[:-1], fusion_ops
    ):
        # computation(x) +/- other
        for binary_op in binary_ops:
            _register_binary_unary_fusion_lowering(
                _binary_fusion_v2(computation_call, binary_op),
                computation_op,
                binary_op,
                fusion_op,
            )
        # other + computation(x): only the commutative add.
        for binary_op in [aten.add, ops.add]:
            _register_binary_unary_fusion_lowering(
                _binary_fusion_v1(computation_call, binary_op),
                computation_op,
                binary_op,
                fusion_op,
            )
692
+
693
def _register_binary_unary_fusion():
    """Register conv + binary + relu fusions (out-of-place only)."""
    binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
    fusion_ops = [mkldnn._convolution_pointwise.binary]
    _computation_user_1 = [_conv_call(users=1)]
    for computation_call, computation_op, fusion_op in zip(
        _computation_user_1, computation_ops[:-1], fusion_ops
    ):
        # relu(conv +/- other)
        for binary_op in binary_ops:
            _register_binary_unary_fusion_lowering(
                _combined_fusion(
                    _binary_fusion_v2(computation_call, binary_op), aten.relu
                ),
                computation_op,
                binary_op,
                fusion_op,
                unary_attr=UnaryAttr("relu"),
            )
        # relu(other + conv): only the commutative add.
        for binary_op in [aten.add, ops.add]:
            _register_binary_unary_fusion_lowering(
                _combined_fusion(
                    _binary_fusion_v1(computation_call, binary_op), aten.relu
                ),
                computation_op,
                binary_op,
                fusion_op,
                unary_attr=UnaryAttr("relu"),
            )
722
+
723
def _recover_linear():
    # Convert reshape+linear+reshape to a single linear for applying fusion path.
    @register_freezing_graph_pattern(
        CallFunction(
            aten.reshape.default,
            CallFunction(
                mkldnn._linear_pointwise.default,
                CallFunction(
                    aten.reshape.default,
                    Arg(),
                    KeywordArg("reshape_1"),
                    _users=MULTIPLE,
                ),
                Arg(),
                Arg(),
                Arg(),
                Arg(),
                Arg(),
            ),
            KeywordArg("reshape_2"),
        ),
        pass_number=1,
    )
    def reshape_linear_reshape_pattern(match, *args, **kwargs):
        reshape_1 = kwargs.get("reshape_1")
        reshape_2 = kwargs.get("reshape_2")
        assert isinstance(reshape_1, list)
        assert isinstance(reshape_2, list)
        assert len(reshape_1) == 2

        # Dynamic if any leading dim involved in the reshapes is symbolic.
        dynamic_shapes = not all(
            isinstance(x, int) for x in ([reshape_1[0]] + reshape_2[:-1])
        )

        graph = match.graph
        reshape_2_node = match.output_node()
        linear_input_node = reshape_2_node.args[0].args[0].args[0]
        if dynamic_shapes:
            # TODO: Haozhe investigate how add guard here
            return
        # The reshapes cancel iff the linear input keeps its leading shape
        # (input.shape[:-1] == reshape_2[:-1]) and the flattened batch matches
        # (prod(reshape_2[:-1]) == reshape_1[0]).
        can_remove_reshape = linear_input_node.meta.get("val").shape[
            :-1
        ] == torch.Size(reshape_2[:-1])
        can_remove_reshape = can_remove_reshape and (
            reduce(operator.mul, reshape_2[:-1]) == reshape_1[0]
        )

        if can_remove_reshape:
            repl = graph.call_function(mkldnn._linear_pointwise.default, args)
            repl.meta.update(reshape_2_node.meta)
            reshape_2_node.replace_all_uses_with(repl)
            old_linear_node = reshape_2_node.args[0]
            reshape_1_node = old_linear_node.args[0]
            graph.erase_node(reshape_2_node)
            graph.erase_node(old_linear_node)
            if len(reshape_1_node.users) == 0:
                graph.erase_node(reshape_1_node)

    def is_linear_add_bias(match):
        add_node = match.output_node()
        linear_node = add_node.args[0]
        weight_meta = linear_node.args[1].meta.get("val")
        bias_meta = add_node.args[1].meta.get("val")
        if weight_meta is None or bias_meta is None:
            return False
        # Linear must currently be bias-free, and the added tensor must be a
        # 1-D vector matching the output features.
        return (
            linear_node.args[2] is None
            and bias_meta.dim() == 1
            and bias_meta.size(0) == weight_meta.size(0)
        )

    # Convert linear+bias to a single linear for applying fusion path.
    @register_freezing_graph_pattern(
        CallFunction(
            aten.add.Tensor,
            CallFunction(mkldnn._linear_pointwise.default, *_linear_args),
            Arg(),
        ),
        pass_number=1,
        extra_check=is_linear_add_bias,
    )
    def linear_bias_pattern(match, *args):
        graph = match.graph
        add_node = match.output_node()
        linear_node = add_node.args[0]
        new_args = list(linear_node.args)
        # Move the added tensor into linear's bias slot.
        new_args[2] = add_node.args[1]
        repl = graph.call_function(
            mkldnn._linear_pointwise.default, tuple(new_args)
        )
        repl.meta.update(add_node.meta)
        add_node.replace_all_uses_with(repl)
        match.erase_nodes(graph)
818
+
819
def _is_packable_mkldnn_rnn_layer(match):
    """Check whether an mkldnn_rnn_layer node is eligible for weight prepacking."""
    lstm_node = match.output_node()
    POS_WEIGHTS = [1, 2]
    POS_INPUTS = [0, 5, 6]
    POS_ARGS = POS_WEIGHTS + POS_INPUTS

    # Weights should be Constant
    if any(lstm_node.args[pos].op != "get_attr" for pos in POS_WEIGHTS):
        return False

    # Meta info for weights and inputs should be available
    metas = [lstm_node.args[pos].meta.get("val") for pos in POS_ARGS]
    if any(meta is None for meta in metas):
        return False

    # Check device
    if any(meta.device.type != "cpu" for meta in metas):
        return False

    # Check dtype: low-precision packing needs hardware support.
    if any(
        meta.dtype == torch.bfloat16 and not mkldnn._is_mkldnn_bf16_supported()
        for meta in metas
    ):
        return False
    if any(
        meta.dtype == torch.float16 and not mkldnn._is_mkldnn_fp16_supported()
        for meta in metas
    ):
        return False

    return True
856
+
857
def _is_packable_convolution(match):
    """
    Check if the node is supported for MKLDNN convolution.
    """
    conv_node = match.output_node()
    input_meta_value = conv_node.args[0].meta.get("val")
    weight_meta_value = conv_node.args[1].meta.get("val")
    if input_meta_value is None or weight_meta_value is None:
        return False
    input_size = input_meta_value.shape
    # Weight must be a frozen constant so it can be prepacked.
    if conv_node.args[1].op != "get_attr":
        return False
    # Only 4-D CPU tensors (2d convolution) are handled.
    for meta_value in [input_meta_value, weight_meta_value]:
        if (
            meta_value is None
            or meta_value.device.type != "cpu"
            or meta_value.dim() != 4
        ):
            return False
    # Low-precision dtypes require hardware support.
    for lowp_dtype, is_supported in [
        (torch.bfloat16, mkldnn._is_mkldnn_bf16_supported),
        (torch.float16, mkldnn._is_mkldnn_fp16_supported),
    ]:
        if (
            input_meta_value.dtype == lowp_dtype
            or weight_meta_value.dtype == lowp_dtype
        ) and not is_supported():
            return False
    is_transposed = conv_node.args[-3]
    if is_transposed:
        # TODO: Support dynamic shape case for MKLDNN conv transpose.
        if has_free_symbols(input_size):
            return False
        groups = conv_node.args[-1]
        in_channels = weight_meta_value.size(0)
        # doesn't support group_depthwise_conv_transpose.
        if groups > 1 and groups == in_channels:
            return False
        # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big
        output_paddings = conv_node.args[-2]
        strides = conv_node.args[3]
        if any(
            output_padding >= stride
            for output_padding, stride in zip(output_paddings, strides)
        ):
            return False
    return True
907
+
908
def _is_packable_linear(match):
    """
    Check if the node is supported for MKLDNN linear.
    """
    linear_node = match.output_node()
    # weight_idx is 1 for aten.mm and is 2 for aten.addmm
    weight_idx = 2 if linear_node.target == aten.addmm.default else 1
    # Weight must be a frozen constant so it can be prepacked.
    if linear_node.args[weight_idx].op != "get_attr":
        return False
    input_meta_value = linear_node.args[weight_idx - 1].meta.get("val")
    weight_meta_value = linear_node.args[weight_idx].meta.get("val")
    if input_meta_value is None or weight_meta_value is None:
        return False
    batch_size = input_meta_value.shape[0]
    is_lp_weight = weight_meta_value.dtype in (
        torch.bfloat16,
        torch.float16,
    )
    # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
    # on aarch64, use mkldnn op for fp32 as well if acl is enabled
    if (
        not is_lp_weight
        and not mkldnn._is_mkldnn_acl_supported()
        and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
    ):
        return False
    # Only 2-D CPU tensors are handled.
    for meta_value in [input_meta_value, weight_meta_value]:
        if (
            meta_value is None
            or meta_value.device.type != "cpu"
            or meta_value.dim() != 2
        ):
            return False
    if weight_idx == 2:
        bias_meta_value = linear_node.args[0].meta.get("val")
        # Bugfix: check the *bias* tensor's device here. The previous code
        # re-used the stale loop variable `meta_value`, which still referred
        # to the weight, so a non-CPU bias slipped through this check.
        if (
            bias_meta_value is None
            or bias_meta_value.device.type != "cpu"
            or bias_meta_value.dim() != 1
            or bias_meta_value.size(0) != weight_meta_value.size(1)
        ):
            return False

    # Low-precision dtypes require hardware support.
    if (
        input_meta_value.dtype == torch.bfloat16
        or weight_meta_value.dtype == torch.bfloat16
    ):
        if not mkldnn._is_mkldnn_bf16_supported():
            return False
    if (
        input_meta_value.dtype == torch.float16
        or weight_meta_value.dtype == torch.float16
    ):
        if not mkldnn._is_mkldnn_fp16_supported():
            return False
    return True
964
+
965
# Argument layout of ``aten.convolution.default``:
# (input, weight, bias, stride, padding, dilation, is_transposed,
#  output_padding, groups).  Only ``is_transposed`` is captured by name;
# the rest are referenced positionally by the replacement callbacks.
_aten_conv_args = (
    Arg(),
    Arg(),
    Arg(),
    Arg(),
    Arg(),
    Arg(),
    KeywordArg("is_transposed"),
    Arg(),
    Arg(),
)

# Argument layout of ``aten.mkldnn_rnn_layer.default``; ``reverse`` is the
# only argument the replacement callback needs by keyword.
_aten_mkldnn_rnn_layer_args = (
    Arg(),  # input
    Arg(),  # weight0
    Arg(),  # weight1
    Arg(),  # weight2
    Arg(),  # weight3
    Arg(),  # hx_
    Arg(),  # cx_
    KeywordArg("reverse"),  # reverse
    Arg(),  # batch_sizes
    Arg(),  # mode
    Arg(),  # hidden_size
    Arg(),  # num_layers
    Arg(),  # has_biases
    Arg(),  # bidirectional
    Arg(),  # batch_first
    Arg(),  # train
)
996
def _register_weight_pack_pass():
    """Register freezing-graph patterns that rewrite aten convolution,
    mkldnn_rnn_layer and linear (mm/addmm) nodes into mkldnn ops operating
    on prepacked (reordered) weights."""

    @register_freezing_graph_pattern(
        CallFunction(aten.convolution.default, *_aten_conv_args),
        extra_check=_is_packable_convolution,
    )
    def convolution(match, *args, **kwargs):
        # Replace aten.convolution with mkldnn._convolution_pointwise (or the
        # transpose variant) over a weight reordered to mkldnn layout.
        is_transposed = kwargs.get("is_transposed")
        assert isinstance(is_transposed, bool)
        graph = match.graph
        conv_node = match.output_node()
        input_size = conv_node.args[0].meta.get("val").shape
        with graph.inserting_before(conv_node):
            # (padding, stride, dilation, groups) in the order the mkldnn
            # ops expect them.
            constant_args = [args[4], args[3], args[5], args[-1]]
            packed_weight_op = mkldnn._reorder_convolution_weight
            packed_conv_op = mkldnn._convolution_pointwise.default
            if is_transposed:
                constant_args.insert(1, args[-2])  # output_padding
                packed_weight_op = mkldnn._reorder_convolution_transpose_weight
                packed_conv_op = mkldnn._convolution_transpose_pointwise.default
            if not has_free_symbols(input_size):
                packed_weight_inputs = (
                    (args[1],) + tuple(constant_args) + (input_size,)
                )
                packed_weight_node = graph.create_node(
                    "call_function", packed_weight_op, args=packed_weight_inputs
                )
            else:
                assert not is_transposed
                # For dynamic shape case, we need to pack weight in runtime.
                packed_weight_node = args[1]
            # Trailing ("none", [], "") = no fused pointwise post-op.
            packed_conv_inputs = (
                (args[0], packed_weight_node, args[2])
                + tuple(constant_args)
                + ("none", [], "")
            )
            packed_conv_node = graph.create_node(
                "call_function", packed_conv_op, tuple(packed_conv_inputs)
            )
            conv_node.replace_all_uses_with(packed_conv_node)
            packed_conv_node.meta.update(conv_node.meta)
            graph.erase_node(conv_node)

    @register_freezing_graph_pattern(
        CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args),
        extra_check=_is_packable_mkldnn_rnn_layer,
    )
    def mkldnn_rnn_layer(match, *args, **kwargs):
        def get_item(graph, node, index):
            # Helper: emit getitem(node, index) into the graph.
            return graph.call_function(operator.getitem, (node, index))

        graph = match.graph
        lstm_node = match.output_node()
        input = args[0]
        weight0, weight1 = args[1:3]
        reverse = kwargs.get("reverse")
        packed_lstm_op = aten.mkldnn_rnn_layer.default
        hidden_size = args[9]
        has_biases = args[11]
        batch_first = args[13]
        with graph.inserting_before(lstm_node):
            packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default
            packed_weight_inputs = (
                weight0,
                weight1,
                hidden_size,
                reverse,
                has_biases,
                batch_first,
            )
            packed_weight_node = graph.create_node(
                "call_function", packed_weight_op, packed_weight_inputs, {}, "name"
            )
            # The reorder op returns two packed weights; unpack them.
            packed_weight_items = [
                get_item(graph, packed_weight_node, i) for i in range(2)
            ]
            pack_lstm_inputs = (
                args[0],
                *packed_weight_items,
                args[3],
                args[4],
                args[5],
                args[6],
                reverse,
                *args[7:],
            )

            packed_lstm_node = graph.create_node(
                "call_function", packed_lstm_op, args=pack_lstm_inputs
            )
            lstm_node.replace_all_uses_with(packed_lstm_node)
            packed_lstm_node.meta.update(lstm_node.meta)
            graph.erase_node(lstm_node)

    @register_freezing_graph_pattern(
        CallFunction(aten.addmm.default, Arg(), Arg(), Arg()),
        extra_check=_is_packable_linear,
    )
    @register_freezing_graph_pattern(
        CallFunction(aten.mm.default, Arg(), Arg()),
        extra_check=_is_packable_linear,
    )
    def linear(match, *args, **kwargs):
        graph = match.graph
        linear_node = match.output_node()
        # aten.mm(input, weight) vs aten.addmm(bias, input, weight)
        input = args[0] if linear_node.target == aten.mm.default else args[1]
        bias = None if linear_node.target == aten.mm.default else args[0]
        weight = args[1] if linear_node.target == aten.mm.default else args[2]
        with graph.inserting_before(linear_node):
            transpose_weight_node = graph.create_node(
                "call_function", aten.permute.default, (weight, (1, 0))
            )
            weight_dtype = weight.meta.get("val").dtype
            is_lp_weight = weight_dtype in (
                torch.bfloat16,
                torch.float16,
            )
            batch_size = input.meta.get("val").shape[0]
            if has_free_symbols(batch_size):
                assert (
                    is_lp_weight or mkldnn._is_mkldnn_acl_supported()
                ), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
            # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
            packed_weight_inputs = (
                transpose_weight_node,
                batch_size.node.shape_env.size_hint(batch_size.node.expr)
                if has_free_symbols(batch_size)
                else batch_size,
            )
            packed_weight_op = (
                mkldnn._reorder_linear_weight
                if (is_lp_weight or mkldnn._is_mkldnn_acl_supported())
                else torch.ops.mkl._mkl_reorder_linear_weight
            )
            packed_weight_node = graph.create_node(
                "call_function", packed_weight_op, args=packed_weight_inputs
            )

            packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
            if is_lp_weight or mkldnn._is_mkldnn_acl_supported():
                # ("none", [], "") = no fused pointwise post-op.
                packed_linear_inputs += (bias, "none", [], "")
                packed_linear_op = mkldnn._linear_pointwise.default
            else:
                packed_linear_inputs += (transpose_weight_node, bias, batch_size)
                packed_linear_op = torch.ops.mkl._mkl_linear
            packed_linear_node = graph.create_node(
                "call_function", packed_linear_op, packed_linear_inputs
            )
            linear_node.replace_all_uses_with(packed_linear_node)
            packed_linear_node.meta.update(linear_node.meta)
            graph.erase_node(linear_node)
1147
def _eliminate_duplicate_packed_nodes(gm):
    """
    Combine packed weight nodes with the same inputs to reduce memory usage.
    for example:
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(32, 32, bias=True)

        def forward(self, x):
            return self.linear(self.linear(x))

    the above's packed weight nodes are duplicate if two linear calls have same input size.
    """
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        return gm

    # All known weight-reorder ops whose results can be deduplicated.
    packed_weight_ops = [
        torch._C._nn.mkldnn_reorder_conv2d_weight,
        mkldnn._reorder_convolution_transpose_weight,
        mkldnn._reorder_linear_weight,
        mkldnn._reorder_mkldnn_rnn_layer_weight,
    ]
    if torch._C.has_mkl:
        packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight)

    for node in gm.graph.nodes:
        if node.target in packed_weight_ops and len(node.args[0].users) > 1:
            # Any other user of the same weight that performs the identical
            # reorder is redundant: redirect its uses to this node.
            # (list(...) snapshots users before we mutate the graph.)
            for user_node in list(node.args[0].users.keys()):
                if (
                    user_node.target == node.target
                    and user_node != node
                    and user_node.args == node.args
                ):
                    user_node.replace_all_uses_with(node)
                    gm.graph.erase_node(user_node)
1184
@functools.lru_cache(None)
def _mkldnn_fusion_init():
    """Register all mkldnn fusion patterns exactly once.

    Skipped entirely when mkldnn is unavailable, and — for now — when ACL is
    the backend, since ACL does not yet support fused operators (otherwise
    even plain matmul/innerproduct could not be accelerated with ACL).
    """
    if not torch.backends.mkldnn.enabled:
        return
    if not torch.backends.mkldnn.is_available():
        return
    # TODO: aarch64: enable op fusion for acl once it supports fused operators.
    if torch.ops.mkldnn._is_mkldnn_acl_supported():
        return
    _register_unary_fusion()
    _register_inplace_fusion()
    _register_binary_unary_fusion()
    _register_binary_fusion()
    _register_quantization_lowerings()
1199
@functools.lru_cache(None)
def _mkldnn_weight_pack_init():
    """Register the mkldnn weight-prepacking passes exactly once."""
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        # Nothing to register without a usable mkldnn backend.
        return
    _register_weight_pack_pass()
    _recover_linear()
    _register_quantization_weight_pack_pass()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/post_grad.py ADDED
@@ -0,0 +1,1100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import functools
3
+ import itertools
4
+ import logging
5
+ import operator
6
+ from collections import Counter, defaultdict
7
+ from typing import Any, Dict, List, Optional, Set, Union
8
+
9
+ from sympy import Expr
10
+
11
+ import torch
12
+ import torch._inductor as inductor
13
+ import torch.utils._pytree as pytree
14
+ from torch import fx
15
+ from torch._decomp import register_decomposition
16
+ from torch._dynamo.utils import counters, optimus_scuba_log
17
+
18
+ from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
19
+
20
+ from torch._utils_internal import upload_graph
21
+ from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
22
+
23
+ from .. import config, ir, pattern_matcher
24
+ from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
25
+
26
+ from ..lowering import lowerings as L
27
+ from ..pattern_matcher import (
28
+ _return_true,
29
+ Arg,
30
+ CallFunction,
31
+ CallFunctionVarArgs,
32
+ filter_nodes,
33
+ get_arg_value,
34
+ get_mutation_region_id,
35
+ Ignored,
36
+ init_once_fakemode,
37
+ KeywordArg,
38
+ ListOf,
39
+ Match,
40
+ MULTIPLE,
41
+ PatternMatcherPass,
42
+ register_graph_pattern,
43
+ stable_topological_sort,
44
+ )
45
+ from ..utils import decode_device, is_pointwise_use
46
+ from ..virtualized import V
47
+ from .group_batch_fusion import group_batch_fusion_passes
48
+ from .reinplace import reinplace_inplaceable_ops
49
+
50
+ log = logging.getLogger(__name__)
51
+ aten = torch.ops.aten
52
+ prims = torch.ops.prims
53
+
54
# First pass_patterns[0] are applied, then [1], then [2]
pass_patterns = [
    PatternMatcherPass(),
    PatternMatcherPass(),
    PatternMatcherPass(),
]
# patterns applied only in inference
inference_patterns = PatternMatcherPass()
# pass used by decompose_mm_pass.apply() in post_grad_passes; populated by
# pattern registrations elsewhere (presumably decompose_mem_bound_mm — confirm)
decompose_mm_pass = PatternMatcherPass()
+
64
+
65
def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
    """
    Passes that run on after grad. This is called once on the forwards
    graph and once on the backwards graph.

    The IR here has been normalized and functionalized.
    """
    if config.dce:
        # has some issues with mutation in inference mode
        gm.graph.eliminate_dead_code()

    if is_inference and config.reorder_for_locality:
        reorder_for_locality(gm.graph)

    fake_tensor_updater = FakeTensorUpdater(gm.graph)

    if config.post_grad_custom_pre_pass is not None:
        config.post_grad_custom_pre_pass(gm.graph)

    if config.pattern_matcher:
        lazy_init()
        # Snapshot counters so we can tell whether group-batch fusion
        # actually changed anything (for logging only).
        inductor_before_change = copy.deepcopy(counters["inductor"])
        group_batch_fusion_passes(gm.graph, pre_grad=False)
        if counters["inductor"] != inductor_before_change:
            optimus_scuba_log["group_batch_fusion_post_grad"] = upload_graph(gm.graph)
        remove_noop_ops(gm.graph)
        # pass_patterns are applied in order: [0], then [1], then [2].
        for patterns in pass_patterns:
            patterns.apply(gm.graph)  # type: ignore[arg-type]
        if is_inference:
            inference_patterns.apply(gm.graph)  # type: ignore[arg-type]
        decompose_mm_pass.apply(gm.graph)  # type: ignore[arg-type]

    if config.post_grad_custom_post_pass is not None:
        config.post_grad_custom_post_pass(gm.graph)

    stable_topological_sort(gm.graph)

    move_constructors_to_cuda(gm.graph)

    fake_tensor_updater.incremental_update()

    # Keep these last, since they introduces mutation. Look at
    # ./fx_passes/README.md for a discussion of mutation invariants.
    reinplace_inplaceable_ops(gm.graph)
    decompose_auto_functionalized(gm.graph)

    gm.recompile()
    gm.graph.lint()
+
114
+
115
@init_once_fakemode
def lazy_init():
    """Register mkldnn fusion patterns once; a no-op without mkldnn support."""
    if not torch._C._has_mkldnn:
        return

    from . import decompose_mem_bound_mm  # noqa: F401
    from .mkldnn_fusion import _mkldnn_fusion_init

    _mkldnn_fusion_init()
+
123
+
124
def reorder_for_locality(graph: torch.fx.Graph):
    """Improve locality by moving each node's producers directly before it.

    Walks the graph in reverse; a producer is moved only when all of its
    users have already been seen (i.e. it is no longer needed later) and it
    lives in the same mutation region as the consumer.  Nodes at or after
    the first ``copy_`` epilogue are never reordered.
    """

    def visit(other_node):
        if (
            other_node.op == "call_function"
            and other_node.target != operator.getitem
            and all((n in seen_nodes) for n in other_node.users)
            and get_mutation_region_id(graph, node)
            == get_mutation_region_id(graph, other_node)
        ):
            # move node's producers right before it
            node.prepend(other_node)

    seen_nodes = set()

    # only reorder nodes before the first copy_ in the graph.
    # copy_ will appear at the end of functionalized graphs when there is mutation on inputs,
    # and this reordering doesnt work well with mutation
    first_copy = next(
        (
            node
            for node in graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten.copy_.default
        ),
        None,
    )
    # (was the redundant `True if first_copy is None else False`)
    past_mutating_epilogue = first_copy is None

    for node in reversed(graph.nodes):
        seen_nodes.add(node)
        if not past_mutating_epilogue:
            past_mutating_epilogue = node is first_copy
            continue

        torch.fx.map_arg((node.args, node.kwargs), visit)
+
160
+
161
def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1):
    """
    Register an aten to inductor IR replacement pattern
    """
    target_pass = pass_patterns[pass_number]
    return pattern_matcher.register_lowering_pattern(
        pattern, extra_check, pass_dict=target_pass
    )
+
169
+
170
+ ################################################################################
171
+ # Actual patterns below this point.
172
+ # Priority of patterns is:
173
+ # - later output nodes first
174
+ # - order patterns are defined in
175
+ ################################################################################
176
+
177
+
178
def is_valid_mm_plus_mm(match: Match):
    """Shape check for fusing mm(mat1, mat2) + mm(mat3, mat4).

    Both products must have matching inner dimensions and produce
    identically shaped [m, n] outputs.
    """
    *_, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape
    *_, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape
    if k1 != k2:
        return False

    *_, m2, k3 = match.kwargs["mat3"].meta.get("tensor_meta").shape
    *_, k4, n2 = match.kwargs["mat4"].meta.get("tensor_meta").shape
    if k3 != k4:
        return False

    # Outputs must be the same shape to be added together.
    return m1 == m2 and n1 == n2
+
194
+
195
@register_lowering_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")),
        CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")),
    ),
    extra_check=is_valid_mm_plus_mm,
)
def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
    # Fuse mm(mat1, mat2) + mm(mat3, mat4) into one tuned kernel.
    return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4)
+
206
+
207
def cuda_and_enabled_mixed_mm(match):
    """True when mixed-dtype mm lowering is enabled and mat1 is on CUDA."""
    if not (config.use_mixed_mm or config.force_mixed_mm):
        return False
    mat1_val = match.kwargs["mat1"].meta.get("val")
    return getattr(mat1_val, "is_cuda", False)
+
212
+
213
def cuda_and_enabled_mixed_mm_and_not_int8(match):
    """Mixed-mm check that additionally excludes int8 mat2.

    Builds on :func:`cuda_and_enabled_mixed_mm` (which already verifies the
    config flags and that mat1 is CUDA — the previous duplicate ``is_cuda``
    re-check on mat1 has been removed).
    """
    return (
        cuda_and_enabled_mixed_mm(match)
        and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8)
        != torch.int8
    )  # bitshift numerics in triton and pytorch don't match for torch.int8
+
221
+
222
+ """
223
+ this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor
224
+ (where the int4 and uint4x2 are represented with int8 and uint8 respectively)
225
+ where every other row of the int4 is packed with the row above it as:
226
+ uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4
227
+
228
+ unpack formulas:
229
+ int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8
230
+ int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8
231
+
232
+ thus matching on unpack formula:
233
+ torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8))
234
+
235
+ note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior
236
+ of the kernel matches the pytorch formula for all dtypes except torch.int8
237
+ where the bitwise numerics in triton do not match those in pytorch.
238
+ """
239
+
240
+
241
+ @register_lowering_pattern(
242
+ CallFunction(
243
+ aten.mm.default,
244
+ KeywordArg("mat1"),
245
+ CallFunction(
246
+ aten.sub.Tensor,
247
+ CallFunction(
248
+ prims.convert_element_type.default,
249
+ CallFunction(
250
+ aten.reshape.default,
251
+ CallFunction(
252
+ aten.cat.default,
253
+ ListOf(
254
+ CallFunction(
255
+ aten.bitwise_and.Scalar,
256
+ KeywordArg("mat2"),
257
+ 0xF,
258
+ ),
259
+ CallFunction(
260
+ aten.__rshift__.Scalar,
261
+ KeywordArg("mat2"),
262
+ 4,
263
+ ),
264
+ ),
265
+ 1,
266
+ ),
267
+ KeywordArg("mat2_mm_shape"),
268
+ ),
269
+ KeywordArg("mat2_dtype"),
270
+ ),
271
+ 8,
272
+ ),
273
+ ),
274
+ extra_check=cuda_and_enabled_mixed_mm_and_not_int8,
275
+ )
276
+ def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype):
277
+ return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm(
278
+ mat1, mat2, mat2_mm_shape, mat2_dtype
279
+ )
280
+
281
+
282
+ """
283
+ torch.mm(mat1, mat2.to(mat2_dtype))
284
+ """
285
+
286
+
287
+ @register_lowering_pattern(
288
+ CallFunction(
289
+ aten.mm,
290
+ KeywordArg("mat1"),
291
+ CallFunction(
292
+ prims.convert_element_type.default,
293
+ KeywordArg("mat2"),
294
+ KeywordArg("mat2_dtype"),
295
+ ),
296
+ ),
297
+ extra_check=cuda_and_enabled_mixed_mm,
298
+ )
299
+ def mixed_mm(match: Match, mat1, mat2, mat2_dtype):
300
+ return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype)
301
+
302
+
303
@register_graph_pattern(
    CallFunction(
        aten.cumsum.default,
        CallFunction(
            torch.ops.aten.full.default,
            KeywordArg("shape"),
            KeywordArg("fill_value"),
            dtype=KeywordArg("dtype"),
            layout=Ignored(),
            device=KeywordArg("device"),
            pin_memory=False,
            _users=MULTIPLE,
        ),
        KeywordArg("dim"),
        _users=MULTIPLE,
    ),
    pass_dict=pass_patterns[1],
)
def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim):
    """Based on a pattern in OPTForCausalLM"""

    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
        # cumsum promotes all integral types to int64
        dtype = torch.int64

    def repl(*shape):
        # cumsum over a constant-filled tensor equals
        # arange(1, n+1) * fill_value broadcast along `dim`.
        dim_size = shape[dim]
        idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)

        inter_shape = [1] * len(shape)
        inter_shape[dim] = dim_size
        return (idx * fill_value).view(inter_shape).expand(shape)

    # only replace the output node, not all nodes
    match.nodes = [match.output_node()]
    with V.fake_mode:
        match.replace_by_example(repl, list(shape))
+
341
+
342
def shape_of_mm(a, b):
    """Return the [m, n] output shape of a 2D matmul ``a @ b``."""
    rows, _ = a.get_size()
    _, cols = b.get_size()
    return [rows, cols]
+
347
+
348
@register_lowering_pattern(
    CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()),
)
def cat_mm(match, inputs, dim):
    # cat of mm outputs: write each mm directly into its slice of the
    # concatenated buffer instead of materializing + copying (cat_tuned_op).
    return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm)
+
354
+
355
@register_lowering_pattern(
    CallFunction(
        aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg()
    ),
)
def cat_addmm(match, inputs, dim):
    # Same as cat_mm, for addmm; the bias argument does not affect the
    # output shape so only a/b are consulted.
    def shape_of(bias, a, b):
        m, _ = a.get_size()
        _, n = b.get_size()
        return [m, n]

    return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of)
+
368
+
369
def cat_tuned_op(match, inputs, dim, *, op, shape_of):
    """
    Memory planning to remove cat. We can't use the stock memory
    planner since autotuning matmuls needs to know the output layout.
    """
    if len(inputs) == 1:
        # Nothing to concatenate.
        return op(*inputs[0])

    # TODO(jansel): rewrite this as a bmm?
    if dim < 0:
        dim += len(shape_of(*inputs[0]))
    assert dim in (0, 1)
    notdim = 1 - dim

    new_size: Optional[Union[List[Expr], List[int]]] = None
    offsets_start = []
    offsets_end = []

    # compute output sizes
    for i in range(len(inputs)):
        shape = shape_of(*inputs[i])
        if new_size is None:
            new_size = shape
        else:
            # All inputs must agree on the non-concat dimension.
            new_size[notdim] = V.graph.sizevars.guard_equals(  # type: ignore[call-overload]
                shape[notdim], new_size[notdim]
            )
            new_size[dim] += shape[dim]
        offsets_start.append(new_size[dim] - shape[dim])
        offsets_end.append(new_size[dim])

    assert new_size is not None
    dtype = functools.reduce(
        torch.promote_types,
        [x.get_dtype() for x in itertools.chain.from_iterable(inputs)],
    )
    device = inputs[0][0].get_device()
    kernel = ir.ConcatKernel(
        name=None,
        layout=ir.FixedLayout(device, dtype, new_size),
        inputs=[],
    )
    kernel_tensor = ir.TensorBox.create(kernel)

    for i in range(len(inputs)):
        # Each op writes into its own slice of the concat output buffer.
        dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i])
        src = op(*inputs[i], layout=dst.get_layout()).data.data
        assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer))
        src.layout = ir.AliasedLayout(dst)
        kernel.inputs.append(src)

    kernel.name = V.graph.register_buffer(kernel)
    kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs)
    return kernel_tensor
+
424
+
425
# Shared sub-pattern: a cat along dim 1 that is consumed twice (once
# directly and once through a slice) by the outer cat below.
_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2)


@register_lowering_pattern(
    CallFunction(
        aten.cat,
        [
            _cat_1,
            CallFunction(
                aten.slice,
                _cat_1,
                1,
                0,
                KeywordArg("size"),
            ),
        ],
        1,
    )
)
def cat_slice_cat(match, cat_input, size, dim=1):
    """
    This is an example of a more complex pattern where cat_1 is used
    multiple times inside the pattern. We fold 2 calls to cat into one.

    Matches:
        cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1)
        slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
        slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
        cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1)


    Rewrite to:
        slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19)
        cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1)
    """
    first, *rest = cat_input
    # Optimization is optional, because we can just not fold the cat
    # size should be within first.get_size()[dim] such that the optimization is valid.
    # For negative `end`, we currently fallback to not optimizing.
    if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]):
        # fold 2 cats into 1 cat
        return L[aten.cat](
            [
                first,
                *rest,
                L[aten.slice](first, dim, 0, size),
            ],
            dim,
        )
    else:
        # don't expect to hit this case, just fall back
        tmp = L[aten.cat](cat_input, dim)
        return L[aten.cat](
            [
                tmp,
                L[aten.slice](tmp, dim, 0, size),
            ],
            dim,
        )
+
485
+
486
def is_valid_splitwithsizes_cat(match):
    """Check that a split_with_sizes -> getitem* -> cat chain is a pure
    passthrough: same dim, every split piece used exactly once, in order."""
    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
    cat_nodes = filter_nodes(match.nodes, aten.cat)
    get_item_nodes = filter_nodes(match.nodes, operator.getitem)
    if len(split_nodes) != 1 or len(cat_nodes) != 1:
        return False
    split_node, cat_node = split_nodes[0], cat_nodes[0]
    # The dim of split and cat should match for passthrough
    if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"):
        return False
    get_item_args = {
        get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes
    }
    assert None not in get_item_args
    split_sizes = get_arg_value(split_node, 1, "split_sizes")
    # All parts of split should be included in the cat
    if get_item_args != set(range(len(split_sizes))):
        return False
    # The order of get_item_args should same with cat_node used.
    # For example, if the split_node like split_with_sizes(input, [2, 2, 3], 1),
    # the cat node should be like cat([get_item(0), get_item(1), get_item(2)], 1).
    cat_items_args_order = [
        get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0)
    ]
    if cat_items_args_order != list(range(len(split_sizes))):
        return False

    return True
+
515
+
516
+ def same_meta(node1: torch.fx.Node, node2: torch.fx.Node):
517
+ """True if two nodes have the same metadata"""
518
+ val1 = node1.meta.get("val")
519
+ val2 = node2.meta.get("val")
520
+ return (
521
+ val1 is not None
522
+ and val2 is not None
523
+ and statically_known_true(sym_eq(val1.size(), val2.size()))
524
+ and val1.layout == val2.layout
525
+ and val1.dtype == val2.dtype
526
+ and val1.device == val2.device
527
+ and (
528
+ val1.layout != torch.strided
529
+ or statically_known_true(sym_eq(val1.stride(), val2.stride()))
530
+ )
531
+ )
532
+
533
+
534
# Registry of "no-op" decompositions: maps an aten target to a
# (predicate, nop_arg) pair consumed by remove_noop_ops below.
noop_registry: Dict[Any, Any] = {}


def register_noop_decomp(targets, nop_arg=0):
    """Decorator registering ``cond`` as the no-op predicate for ``targets``.

    ``nop_arg`` identifies which argument the op passes through unchanged —
    either an index into ``node.args`` or a callable extracting it.
    """

    def register_fun(cond):
        register_decomposition(targets, registry=noop_registry, unsafe=True)(
            (cond, nop_arg)
        )
        return cond

    return register_fun
+
546
+
547
@register_noop_decomp(aten.slice)
def slice_noop(self, dim=0, start=None, end=None, step=1):
    """A slice covering the whole dimension with step 1 is a no-op."""
    return (
        start is not None
        and end is not None
        and start == 0
        and end >= 2**63 - 1
        and step == 1
    )
+
555
+
556
@register_noop_decomp(aten.slice_scatter, 1)
def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1):
    """slice_scatter over the full dimension with step 1 just returns src."""
    effective_start = 0 if start is None else start
    effective_end = 2**63 - 1 if end is None else end
    return effective_start == 0 and effective_end >= 2**63 - 1 and step == 1
+
566
+
567
@register_noop_decomp(aten.repeat)
def repeat_noop(self, repeats):
    """Repeating by a factor of 1 along every dim leaves the tensor unchanged."""
    return all(factor == 1 for factor in repeats)
+
571
+
572
@register_noop_decomp(aten.constant_pad_nd)
def constant_pad_nd(x, padding, fill_value=0):
    """Padding by zero on every side is a no-op."""
    return all(amount == 0 for amount in padding)
+
576
+
577
@register_noop_decomp(torch.ops.prims.convert_element_type)
def convert_element_type_noop(x, dtype: torch.dtype):
    # Casting to the dtype the tensor already has is a no-op.
    return x.dtype == dtype
+
581
+
582
@register_noop_decomp(torch.ops.prims.device_put)
def device_put_noop(x, device):
    # Moving to the device the tensor already lives on is a no-op.
    return x.device == decode_device(device)
+
586
+
587
@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc])
def int_noop(x):
    # Rounding ops leave integer tensors unchanged.
    return is_integer_dtype(x.dtype)
+
591
+
592
@register_noop_decomp([aten.pow])
def pow_noop(a, b):
    """Raising to the integer power 1 is a no-op."""
    if not isinstance(b, int):
        return False
    return b == 1
+
596
+
597
@register_noop_decomp([aten.cat], lambda args: args[0][0])
def cat_noop(inputs, dim=0):
    # cat of a single tensor returns that tensor unchanged.
    return len(inputs) == 1
+
601
+
602
@register_noop_decomp(aten.view)
def view_noop(arg, size):
    # Viewing a tensor as its current shape is a no-op.
    return arg.shape == size
+
606
+
607
# Note, we also always have a check for identical metadata, which is why these
# are safe
@register_noop_decomp([aten.copy], nop_arg=1)
@register_noop_decomp([aten.alias, aten.clone])
def true_noop(*args, **kwargs):
    """Unconditionally a no-op; remove_noop_ops' same_meta check is the
    real guard for copy/alias/clone."""
    return True
+
614
+
615
def remove_noop_ops(graph: torch.fx.Graph):
    """
    Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph.
    """
    inputs = set()
    input_storages = set()
    output_storages = set()

    for node in graph.nodes:
        if node.op == "placeholder":
            inputs.add(node)
            input_storages.add(get_node_storage(node))
        else:
            # Placeholders are at the front of the graph, so stop at the
            # first non-placeholder node.
            break

    output_node = next(iter(reversed(graph.nodes)))
    assert output_node.op == "output"
    for out in output_node.args[0]:
        if isinstance(out, torch.fx.Node):
            output_storages.add(get_node_storage(out))

    for node in graph.nodes:
        if node.target in noop_registry:
            cond, src_index = noop_registry[node.target]
            # src is the argument the no-op would pass through unchanged.
            if isinstance(src_index, int):
                src = node.args[src_index]
            else:
                src = src_index(node.args)
            if not isinstance(src, torch.fx.Node):
                continue
            # Don't introduce new aliasing between inputs and outputs.
            # See fx_passes/README.md for a discussion of why this is
            # necessary.
            node_storage = get_node_storage(node)
            src_storage = get_node_storage(src)
            node_is_view = node_storage == src_storage
            if (
                not node_is_view
                and node_storage in output_storages
                and (src_storage in input_storages or src_storage in output_storages)
            ):
                continue

            # Even if input and outputs are expected to alias,
            # don't make "node is src" True
            if (
                node_is_view
                and node in output_node.args
                and (src in inputs or src in output_node.args)
            ):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            # Only elide when metadata matches AND the op-specific predicate
            # confirms this instance really is a no-op.
            if same_meta(node, src) and cond(*args, **kwargs):
                node.replace_all_uses_with(src)
                graph.erase_node(node)
+
674
+
675
+ def decompose_auto_functionalized(graph):
676
+ graph_pass = PatternMatcherPass()
677
+
678
+ @register_graph_pattern(
679
+ CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized),
680
+ pass_dict=graph_pass,
681
+ )
682
+ def replacement(match: Match, *args, **kwargs):
683
+ from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense
684
+
685
+ only_clone_these_tensors = tuple(
686
+ match.nodes[0].meta.get("only_clone_these_tensors", [])
687
+ )
688
+
689
+ flat_args, spec = pytree.tree_flatten((args, kwargs))
690
+
691
+ # NB: we combine (args, kwargs) into flat args for replacing.
692
+ # This is replace_by_example uses make_fx which does not support
693
+ # tracing a function with kwargs.
694
+ def decomp(*flat_args):
695
+ args, kwargs = pytree.tree_unflatten(flat_args, spec)
696
+ return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)
697
+
698
+ with V.fake_mode:
699
+ match.replace_by_example(decomp, flat_args, run_dce=False)
700
+
701
+ graph_pass.apply(graph)
702
+ for node in graph.nodes:
703
+ if node.target is torch.ops.higher_order.auto_functionalized:
704
+ raise AssertionError("auto_functionalized was not removed")
705
+
706
+
707
+ @register_lowering_pattern(
708
+ CallFunction(
709
+ aten.cat,
710
+ ListOf(
711
+ CallFunction(
712
+ operator.getitem,
713
+ CallFunction(
714
+ aten.split_with_sizes,
715
+ KeywordArg("input_"),
716
+ Ignored(),
717
+ Ignored(),
718
+ _users=MULTIPLE,
719
+ ),
720
+ Ignored(),
721
+ ),
722
+ ),
723
+ Ignored(),
724
+ ),
725
+ pass_number=2,
726
+ extra_check=is_valid_splitwithsizes_cat,
727
+ )
728
+ def splitwithsizes_cat_replace(match, input_):
729
+ return input_
730
+
731
+
732
+ def is_valid_cat_splitwithsizes(match):
733
+ cat_nodes = filter_nodes(match.nodes, aten.cat)
734
+ split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
735
+ if len(split_nodes) != 1 or len(cat_nodes) != 1:
736
+ return False
737
+ split_node, cat_node = split_nodes[0], cat_nodes[0]
738
+
739
+ # the cat node has other users: can't eliminate
740
+ if len(cat_node.users) > 1:
741
+ return False
742
+
743
+ # the dim of the cat and split should match
744
+ dim = get_arg_value(split_node, 2, "dim")
745
+ if dim != get_arg_value(cat_node, 1, "dim"):
746
+ return False
747
+
748
+ cat_inputs = list(get_arg_value(cat_node, 0))
749
+ split_sizes = get_arg_value(split_node, 1, "split_sizes")
750
+ # the number of input tensors in cat and the
751
+ # length of the split sizes should match
752
+ if len(cat_inputs) != len(split_sizes):
753
+ return False
754
+
755
+ for cat_input, split_size in zip(cat_inputs, split_sizes):
756
+ # each cat input tensor's size along dim
757
+ # should match the corresponding split size
758
+ if "val" not in cat_input.meta:
759
+ return False
760
+ cat_input_size = cat_input.meta["val"].size(dim)
761
+ if cat_input_size != split_size:
762
+ return False
763
+
764
+ return True
765
+
766
+
767
+ @register_lowering_pattern(
768
+ CallFunction(
769
+ aten.split_with_sizes,
770
+ CallFunction(
771
+ aten.cat,
772
+ KeywordArg("input_"),
773
+ Ignored(),
774
+ _users=MULTIPLE,
775
+ ),
776
+ Ignored(),
777
+ Ignored(),
778
+ ),
779
+ pass_number=2,
780
+ extra_check=is_valid_cat_splitwithsizes,
781
+ )
782
+ def cat_splitwithsizes_replace(match, input_):
783
+ return input_
784
+
785
+
786
+ def view_to_reshape(gm):
787
+ """
788
+ Replace view ops in the GraphModule to reshape ops.
789
+ """
790
+ for nd in gm.graph.nodes:
791
+ if nd.target == torch.ops.aten.view.default:
792
+ nd.target = torch.ops.aten.reshape.default
793
+
794
+
795
+ def should_prefer_unfused_addmm(match):
796
+ inp = match.kwargs["inp"]
797
+ if not inp.meta["val"].is_cuda:
798
+ return False
799
+
800
+ output = match.output_node()
801
+ return all(is_pointwise_use(use) for use in output.users)
802
+
803
+
804
+ @register_graph_pattern(
805
+ CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
806
+ pass_dict=pass_patterns[2],
807
+ extra_check=should_prefer_unfused_addmm,
808
+ )
809
+ def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
810
+ def repl(inp, x1, x2):
811
+ return x1 @ x2 + inp
812
+
813
+ with V.fake_mode:
814
+ match.replace_by_example(repl, [inp, mat1, mat2])
815
+
816
+
817
+ def is_valid_addmm_fusion(match):
818
+ mat1, mat2 = match.args
819
+ inp = match.kwargs["inp"]
820
+
821
+ if not (
822
+ isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor)
823
+ ):
824
+ return False # Input is a number
825
+
826
+ in_shape = inp.meta["val"].shape
827
+ mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1]
828
+ matched = is_expandable_to(in_shape, mm_shape)
829
+ if not matched:
830
+ return False # Shape mismatch
831
+
832
+ return not should_prefer_unfused_addmm(match)
833
+
834
+
835
+ @register_graph_pattern(
836
+ CallFunction(
837
+ aten.add,
838
+ CallFunction(aten.mm, Arg(), Arg()),
839
+ KeywordArg("inp"),
840
+ ),
841
+ pass_dict=pass_patterns[2],
842
+ extra_check=is_valid_addmm_fusion,
843
+ )
844
+ @register_graph_pattern(
845
+ CallFunction(
846
+ aten.add,
847
+ KeywordArg("inp"),
848
+ CallFunction(aten.mm, Arg(), Arg()),
849
+ ),
850
+ pass_dict=pass_patterns[2],
851
+ extra_check=is_valid_addmm_fusion,
852
+ )
853
+ def addmm(match, mat1, mat2, *, inp):
854
+ def repl(inp, mat1, mat2):
855
+ return aten.addmm(inp, mat1, mat2)
856
+
857
+ with V.fake_mode:
858
+ match.replace_by_example(repl, [inp, mat1, mat2])
859
+
860
+
861
+ def check_shape_cuda_and_fused_int_mm_mul_enabled(match):
862
+ return (
863
+ config.force_fuse_int_mm_with_mul
864
+ and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2
865
+ and getattr(match.args[2].meta.get("val"), "is_cuda", False)
866
+ )
867
+
868
+
869
+ @register_lowering_pattern(
870
+ CallFunction(
871
+ prims.convert_element_type.default,
872
+ CallFunction(
873
+ aten.mul,
874
+ CallFunction(
875
+ aten._int_mm,
876
+ Arg(),
877
+ Arg(),
878
+ ),
879
+ Arg(),
880
+ ),
881
+ Arg(),
882
+ ),
883
+ check_shape_cuda_and_fused_int_mm_mul_enabled,
884
+ )
885
+ @register_lowering_pattern(
886
+ CallFunction(
887
+ aten.mul,
888
+ CallFunction(
889
+ aten._int_mm,
890
+ Arg(),
891
+ Arg(),
892
+ ),
893
+ Arg(),
894
+ ),
895
+ check_shape_cuda_and_fused_int_mm_mul_enabled,
896
+ )
897
+ def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None):
898
+ return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
899
+
900
+
901
+ class ConstructorMoverPass:
902
+ def __init__(self, target: str, allow_outputs: bool = False) -> None:
903
+ """
904
+ Move constructors from cpu to the target_device.
905
+
906
+ Sweeps through the module, looking for constructor nodes that can be moved
907
+ to the target_device.
908
+
909
+ A constructor node can be moved to the target_device iff all of its users
910
+ can also be moved (tested by cannot_be_moved). Otherwise, all dependent
911
+ constructor nodes won't be moved.
912
+
913
+ - target: target device type
914
+ - allow_outputs: allow outputs to be moved
915
+ """
916
+
917
+ self.target = target
918
+ self.allow_outputs = allow_outputs
919
+
920
+ assert isinstance(target, str), (
921
+ "target should be a string representing the device type. "
922
+ f"Got: {type(target).__name__}"
923
+ )
924
+
925
+ def allow_cpu_device(self, node: fx.Node) -> bool:
926
+ """
927
+ Returns whether a node that returns a tensor on the target device may have
928
+ cpu tensors as input.
929
+ """
930
+ return node.target in (
931
+ torch.ops.aten.index.Tensor,
932
+ torch.ops.aten.index_put.default,
933
+ torch.ops.aten.index_put_.default,
934
+ torch.ops.aten.copy.default,
935
+ torch.ops.aten.copy_.default,
936
+ torch.ops.aten.slice_scatter.default,
937
+ )
938
+
939
+ def cannot_be_moved(self, node: fx.Node) -> bool:
940
+ """
941
+ Returns whether a node can be moved to the target device.
942
+
943
+ If this function returns False, it means that this node and all of its users
944
+ won't be moved into the target device.
945
+ """
946
+ if node.target == "output":
947
+ return not self.allow_outputs
948
+
949
+ if not (
950
+ isinstance(node.target, torch._ops.OpOverload)
951
+ and node.target.namespace in ("prims", "aten")
952
+ ):
953
+ return True
954
+
955
+ return False
956
+
957
+ def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
958
+ """
959
+ Get the device of a node.
960
+ """
961
+ ten = node.meta.get("val")
962
+ return None if not isinstance(ten, torch.Tensor) else ten.device
963
+
964
+ def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]:
965
+ """
966
+ Get the number of cpu inputs to a node
967
+ """
968
+ cpu_indeg: Dict[fx.Node, int] = Counter()
969
+
970
+ for node in graph.nodes:
971
+ cpu_count = 0
972
+
973
+ def add_cpu_inp(node):
974
+ nonlocal cpu_count
975
+ device = self.get_node_device(node)
976
+ cpu_count += device is not None and device.type == "cpu"
977
+
978
+ pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))
979
+
980
+ if cpu_count:
981
+ cpu_indeg[node] = cpu_count
982
+
983
+ return cpu_indeg
984
+
985
+ def __call__(self, graph: fx.Graph) -> None:
986
+ target_devices = set()
987
+ constructors = []
988
+
989
+ for node in graph.nodes:
990
+ device = self.get_node_device(node)
991
+ if device and device.type == self.target:
992
+ target_devices.add(device)
993
+
994
+ if not (
995
+ isinstance(node.target, torch._ops.OpOverload)
996
+ and node.target.namespace in ("prims", "aten")
997
+ ):
998
+ continue
999
+
1000
+ if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
1001
+ continue
1002
+
1003
+ if not node.kwargs.get("device") == torch.device("cpu"):
1004
+ continue
1005
+
1006
+ constructors.append(node)
1007
+
1008
+ # not handling multiple target devices initially
1009
+ if not constructors or len(target_devices) != 1:
1010
+ return
1011
+
1012
+ movable_constructors = self.find_movable_constructors(graph, constructors)
1013
+
1014
+ for node in movable_constructors:
1015
+ kwargs = node.kwargs.copy()
1016
+ kwargs["device"] = next(iter(target_devices))
1017
+ node.kwargs = kwargs
1018
+
1019
+ def find_movable_constructors(
1020
+ self, graph: fx.Graph, constructors: List[fx.Node]
1021
+ ) -> Set[fx.Node]:
1022
+ """
1023
+ Starting from the cpu constructors, iterate through the graph and test that all of their
1024
+ downstream uses can safely be moved to cpu.
1025
+ """
1026
+ cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph)
1027
+
1028
+ # which constructors cannot be moved to cuda
1029
+ cannot_move_to_cuda: Set[fx.Node] = set()
1030
+
1031
+ # For any node in the graph, which constructors does it have a dependency on
1032
+ constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)
1033
+
1034
+ # if a cpu node has a dependency on two different cpu constructors,
1035
+ # then if either constructor cannot be moved to cuda, the other cannot as well.
1036
+ # In this case any node with a dependency on one will have a dependency on the other
1037
+ equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {
1038
+ c: {c} for c in constructors
1039
+ }
1040
+
1041
+ def make_dependencies_equivalent(
1042
+ set1: Set[fx.Node], set2: Set[fx.Node]
1043
+ ) -> Set[fx.Node]:
1044
+ # could use union find but not worth complexity here
1045
+ set1.update(set2)
1046
+ for obj in set1:
1047
+ equal_constructor_sets[obj] = set1
1048
+ return set1
1049
+
1050
+ queue: List[fx.Node] = list(constructors)
1051
+
1052
+ for c in queue:
1053
+ constructor_dependencies[c].add(c)
1054
+
1055
+ while queue:
1056
+ node = queue.pop()
1057
+ dependencies = constructor_dependencies[node]
1058
+
1059
+ for user in node.users:
1060
+ if self.cannot_be_moved(user):
1061
+ cannot_move_to_cuda.update(dependencies)
1062
+ break
1063
+
1064
+ # this node was used on a op which takes in multiple devices and output a cuda
1065
+ # tensor. we can convert its cpu input to cuda without making further changes
1066
+ node_device = self.get_node_device(user)
1067
+ if (
1068
+ self.allow_cpu_device(user)
1069
+ and node_device
1070
+ and node_device.type == self.target
1071
+ ):
1072
+ del cpu_indeg[user]
1073
+ else:
1074
+ # otherwise, we should continue look at its downstream uses
1075
+ cpu_indeg[user] -= 1
1076
+ if cpu_indeg[user] == 0:
1077
+ del cpu_indeg[user]
1078
+ queue.append(user)
1079
+
1080
+ unioned_set = make_dependencies_equivalent(
1081
+ dependencies, constructor_dependencies[user]
1082
+ )
1083
+ constructor_dependencies[user] = unioned_set
1084
+
1085
+ for node in cpu_indeg:
1086
+ if constructor_dependencies[node]:
1087
+ cannot_move_to_cuda.update(constructor_dependencies[node])
1088
+
1089
+ all_cannot_move_to_cuda = cannot_move_to_cuda.copy()
1090
+ for constructor in cannot_move_to_cuda:
1091
+ all_cannot_move_to_cuda.update(equal_constructor_sets[constructor])
1092
+
1093
+ return set(constructors) - all_cannot_move_to_cuda
1094
+
1095
+
1096
+ def move_constructors_to_cuda(graph: fx.Graph) -> None:
1097
+ """
1098
+ Moves intermediary tensors which are constructed on the cpu to cuda when safe
1099
+ """
1100
+ ConstructorMoverPass("cuda")(graph)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
35
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
36
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
37
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
38
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
39
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
40
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
41
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
42
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
43
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
44
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
45
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
46
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
47
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
48
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
49
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
50
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
51
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
52
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
53
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
54
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
55
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
56
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
57
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
58
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
59
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
60
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
61
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
62
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
63
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
64
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
65
+ div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
66
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
67
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
68
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
69
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
70
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
71
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
72
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
73
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
74
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
75
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
76
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
77
+ _sfdp_pattern_1_training = MultiOutputPattern([view_default_5,
78
+ view_default_9,
79
+ permute_default_4,
80
+ view_default_11,
81
+ None
82
+ ])
83
+
84
+
85
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
86
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
87
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
88
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
89
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
90
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
91
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
92
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
93
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
94
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
95
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
96
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
97
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
98
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
99
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
100
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
101
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
102
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
103
+ _sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
104
+
105
+
106
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
107
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
108
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
109
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
110
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
111
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
112
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
113
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
114
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
115
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
116
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
117
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
118
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
119
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
120
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
121
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
122
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
123
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
124
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
125
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
126
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
127
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
128
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
129
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
130
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
131
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
132
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
133
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
134
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
135
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
136
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
137
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
138
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
139
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
140
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
141
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
142
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
143
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
144
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
145
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
146
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
147
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
148
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
149
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
150
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
151
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
152
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
153
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
154
+ _sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5,
155
+ view_default_9,
156
+ permute_default_4,
157
+ view_default_11,
158
+ None
159
+ ])
160
+
161
+
162
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
163
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
164
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
165
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
166
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
167
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
168
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
169
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
170
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
171
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
172
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
173
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
174
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
175
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
176
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
177
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
178
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
179
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
180
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
181
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
182
+ _sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
35
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
36
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
37
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
38
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
39
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
40
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
41
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
42
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
43
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
44
+ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
45
+ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
46
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
47
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
48
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
49
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
50
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
51
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
52
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
53
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
54
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
55
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
56
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
57
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
58
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
59
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
60
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
61
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
62
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
63
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
64
+ clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
65
+ alias_default = CallFunction(aten.alias.default, div_Tensor)
66
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
67
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
68
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
69
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
70
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
71
+ mul_Tensor_6 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
72
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6)
73
+ mul_Tensor_7 = CallFunction(aten.mul.Tensor, sub_Tensor_1, KeywordArg('scale_factor'))
74
+ view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2)
75
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
76
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
77
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
78
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
79
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
80
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
81
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
82
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
83
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
84
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
85
+ _sfdp_pattern_4_training = MultiOutputPattern([view_default_5,
86
+ view_default_9,
87
+ permute_default_4,
88
+ view_default_11,
89
+ None,
90
+ None
91
+ ])
92
+
93
+
94
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
95
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
96
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
97
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
98
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
99
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
100
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
101
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
102
+ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
103
+ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
104
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
105
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
106
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
107
+ clone_default = CallFunction(aten.clone.default, div_Tensor)
108
+ expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
109
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
110
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
111
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
112
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
113
+ _sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
114
+
115
+
116
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
117
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
118
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
119
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
120
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
121
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
122
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
123
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
124
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
125
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
126
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
127
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
128
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
129
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
130
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
131
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
132
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
133
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
134
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
135
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
136
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
137
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
138
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
139
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
140
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
141
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
142
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
143
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
144
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
145
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
146
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
147
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
148
+ clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
149
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
150
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
151
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
152
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
153
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
154
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
155
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
156
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
157
+ mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
158
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6)
159
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
160
+ mul_Tensor_7 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor'))
161
+ view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2)
162
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
163
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
164
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
165
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
166
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
167
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
168
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
169
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
170
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
171
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
172
+ _sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5,
173
+ view_default_9,
174
+ permute_default_4,
175
+ view_default_11,
176
+ None,
177
+ None
178
+ ])
179
+
180
+
181
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
182
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
183
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
184
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
185
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
186
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
187
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
188
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
189
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
190
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
191
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
192
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
193
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
194
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
195
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
196
+ clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
197
+ expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
198
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
199
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
200
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
201
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
202
+ _sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
35
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
36
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
37
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
38
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
39
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
40
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
41
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
42
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
43
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
44
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
45
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
46
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
47
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
48
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
49
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
50
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
51
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
52
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
53
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
54
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
55
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
56
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
57
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
58
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
59
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
60
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
61
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
62
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
63
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
64
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
65
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
66
+ div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
67
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
68
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
69
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
70
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
71
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
72
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
73
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
74
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
75
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
76
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
77
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
78
+ _sfdp_pattern_5_training = MultiOutputPattern([view_default_5,
79
+ view_default_9,
80
+ permute_default_4,
81
+ view_default_11,
82
+ None
83
+ ])
84
+
85
+
86
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
87
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
88
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
89
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
90
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
91
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
92
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
93
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
94
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
95
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
96
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
97
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
98
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
99
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
100
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
101
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
102
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
103
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
104
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
105
+ _sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
106
+
107
+
108
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
109
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
110
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
111
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
112
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
113
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
114
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
115
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
116
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
117
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
118
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
119
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
120
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
121
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
122
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
123
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
124
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
125
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
126
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
127
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
128
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
129
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
130
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
131
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
132
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
133
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
134
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
135
+ alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
136
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
137
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
138
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
139
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
140
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
141
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
142
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
143
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
144
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
145
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
146
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
147
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
148
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
149
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
150
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
151
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
152
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
153
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
154
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
155
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
156
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
157
+ _sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5,
158
+ view_default_9,
159
+ permute_default_4,
160
+ view_default_11,
161
+ None
162
+ ])
163
+
164
+
165
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
166
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
167
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
168
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
169
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
170
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
171
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
172
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
173
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
174
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
175
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
176
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
177
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
178
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
179
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
180
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
181
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
182
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
183
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
184
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
185
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
186
+ _sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # This is an auto-generated file. Please do not modify it by hand.
4
+ # To re-generate, run:
5
+ # cd ~/pytorch && python
6
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
7
+ from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference, _sfdp_pattern_1_half_training, _sfdp_pattern_1_half_inference)
8
+ from ._sfdp_pattern_2 import (_sfdp_pattern_2_training, _sfdp_pattern_2_inference, _sfdp_pattern_2_half_training, _sfdp_pattern_2_half_inference)
9
+ from ._sfdp_pattern_3 import (_sfdp_pattern_3_training, _sfdp_pattern_3_inference, _sfdp_pattern_3_half_training, _sfdp_pattern_3_half_inference)
10
+ from ._sfdp_pattern_4 import (_sfdp_pattern_4_training, _sfdp_pattern_4_inference, _sfdp_pattern_4_half_training, _sfdp_pattern_4_half_inference)
11
+ from ._sfdp_pattern_5 import (_sfdp_pattern_5_training, _sfdp_pattern_5_inference, _sfdp_pattern_5_half_training, _sfdp_pattern_5_half_inference)
12
+ from ._sfdp_pattern_6 import (_sfdp_pattern_6_training, _sfdp_pattern_6_inference, _sfdp_pattern_6_half_training, _sfdp_pattern_6_half_inference)
13
+ from ._sfdp_pattern_7 import (_sfdp_pattern_7_training, _sfdp_pattern_7_inference, _sfdp_pattern_7_half_training, _sfdp_pattern_7_half_inference)
14
+ from ._sfdp_pattern_8 import (_sfdp_pattern_8_training, _sfdp_pattern_8_inference, _sfdp_pattern_8_half_training, _sfdp_pattern_8_half_inference)
15
+ from ._sfdp_pattern_9 import (_sfdp_pattern_9_training, _sfdp_pattern_9_inference, _sfdp_pattern_9_half_training, _sfdp_pattern_9_half_inference)
16
+ from ._sfdp_pattern_10 import (_sfdp_pattern_10_training, _sfdp_pattern_10_inference, _sfdp_pattern_10_half_training, _sfdp_pattern_10_half_inference)
17
+ from ._sfdp_pattern_11 import (_sfdp_pattern_11_training, _sfdp_pattern_11_inference, _sfdp_pattern_11_half_training, _sfdp_pattern_11_half_inference)
18
+ from ._sfdp_pattern_12 import (_sfdp_pattern_12_training, _sfdp_pattern_12_inference, _sfdp_pattern_12_half_training, _sfdp_pattern_12_half_inference)
19
+ from ._sfdp_pattern_13 import (_sfdp_pattern_13_training, _sfdp_pattern_13_inference, _sfdp_pattern_13_half_training, _sfdp_pattern_13_half_inference)
20
+ from ._sfdp_pattern_14 import (_sfdp_pattern_14_training, _sfdp_pattern_14_inference, _sfdp_pattern_14_half_training, _sfdp_pattern_14_half_inference)
21
+ from ._sfdp_pattern_15 import (_sfdp_pattern_15_training, _sfdp_pattern_15_inference, _sfdp_pattern_15_half_training, _sfdp_pattern_15_half_inference)
22
+ from ._sfdp_pattern_16 import (_sfdp_pattern_16_training, _sfdp_pattern_16_inference, _sfdp_pattern_16_bs1_training, _sfdp_pattern_16_bs1_inference, _sfdp_pattern_16_half_training, _sfdp_pattern_16_half_inference, _sfdp_pattern_16_half_bs1_training, _sfdp_pattern_16_half_bs1_inference, _sfdp_pattern_16_half_mask_fp32_training, _sfdp_pattern_16_half_mask_fp32_inference, _sfdp_pattern_16_half_mask_fp32_bs1_training, _sfdp_pattern_16_half_mask_fp32_bs1_inference)
23
+ from ._sfdp_pattern_17 import (_sfdp_pattern_17_training, _sfdp_pattern_17_inference, _sfdp_pattern_17_half_training, _sfdp_pattern_17_half_inference)
24
+
25
+ central_index = {
26
+ '_sfdp_pattern_1_training': _sfdp_pattern_1_training,
27
+ '_sfdp_pattern_1_inference': _sfdp_pattern_1_inference,
28
+ '_sfdp_pattern_2_training': _sfdp_pattern_2_training,
29
+ '_sfdp_pattern_2_inference': _sfdp_pattern_2_inference,
30
+ '_sfdp_pattern_3_training': _sfdp_pattern_3_training,
31
+ '_sfdp_pattern_3_inference': _sfdp_pattern_3_inference,
32
+ '_sfdp_pattern_4_training': _sfdp_pattern_4_training,
33
+ '_sfdp_pattern_4_inference': _sfdp_pattern_4_inference,
34
+ '_sfdp_pattern_5_training': _sfdp_pattern_5_training,
35
+ '_sfdp_pattern_5_inference': _sfdp_pattern_5_inference,
36
+ '_sfdp_pattern_6_training': _sfdp_pattern_6_training,
37
+ '_sfdp_pattern_6_inference': _sfdp_pattern_6_inference,
38
+ '_sfdp_pattern_7_training': _sfdp_pattern_7_training,
39
+ '_sfdp_pattern_7_inference': _sfdp_pattern_7_inference,
40
+ '_sfdp_pattern_8_training': _sfdp_pattern_8_training,
41
+ '_sfdp_pattern_8_inference': _sfdp_pattern_8_inference,
42
+ '_sfdp_pattern_9_training': _sfdp_pattern_9_training,
43
+ '_sfdp_pattern_9_inference': _sfdp_pattern_9_inference,
44
+ '_sfdp_pattern_10_training': _sfdp_pattern_10_training,
45
+ '_sfdp_pattern_10_inference': _sfdp_pattern_10_inference,
46
+ '_sfdp_pattern_11_training': _sfdp_pattern_11_training,
47
+ '_sfdp_pattern_11_inference': _sfdp_pattern_11_inference,
48
+ '_sfdp_pattern_12_training': _sfdp_pattern_12_training,
49
+ '_sfdp_pattern_12_inference': _sfdp_pattern_12_inference,
50
+ '_sfdp_pattern_13_training': _sfdp_pattern_13_training,
51
+ '_sfdp_pattern_13_inference': _sfdp_pattern_13_inference,
52
+ '_sfdp_pattern_14_training': _sfdp_pattern_14_training,
53
+ '_sfdp_pattern_14_inference': _sfdp_pattern_14_inference,
54
+ '_sfdp_pattern_15_training': _sfdp_pattern_15_training,
55
+ '_sfdp_pattern_15_inference': _sfdp_pattern_15_inference,
56
+ '_sfdp_pattern_16_training': _sfdp_pattern_16_training,
57
+ '_sfdp_pattern_16_inference': _sfdp_pattern_16_inference,
58
+ '_sfdp_pattern_16_bs1_training': _sfdp_pattern_16_bs1_training,
59
+ '_sfdp_pattern_16_bs1_inference': _sfdp_pattern_16_bs1_inference,
60
+ '_sfdp_pattern_17_training': _sfdp_pattern_17_training,
61
+ '_sfdp_pattern_17_inference': _sfdp_pattern_17_inference,
62
+ '_sfdp_pattern_1_half_training': _sfdp_pattern_1_half_training,
63
+ '_sfdp_pattern_1_half_inference': _sfdp_pattern_1_half_inference,
64
+ '_sfdp_pattern_2_half_training': _sfdp_pattern_2_half_training,
65
+ '_sfdp_pattern_2_half_inference': _sfdp_pattern_2_half_inference,
66
+ '_sfdp_pattern_3_half_training': _sfdp_pattern_3_half_training,
67
+ '_sfdp_pattern_3_half_inference': _sfdp_pattern_3_half_inference,
68
+ '_sfdp_pattern_4_half_training': _sfdp_pattern_4_half_training,
69
+ '_sfdp_pattern_4_half_inference': _sfdp_pattern_4_half_inference,
70
+ '_sfdp_pattern_5_half_training': _sfdp_pattern_5_half_training,
71
+ '_sfdp_pattern_5_half_inference': _sfdp_pattern_5_half_inference,
72
+ '_sfdp_pattern_6_half_training': _sfdp_pattern_6_half_training,
73
+ '_sfdp_pattern_6_half_inference': _sfdp_pattern_6_half_inference,
74
+ '_sfdp_pattern_7_half_training': _sfdp_pattern_7_half_training,
75
+ '_sfdp_pattern_7_half_inference': _sfdp_pattern_7_half_inference,
76
+ '_sfdp_pattern_8_half_training': _sfdp_pattern_8_half_training,
77
+ '_sfdp_pattern_8_half_inference': _sfdp_pattern_8_half_inference,
78
+ '_sfdp_pattern_9_half_training': _sfdp_pattern_9_half_training,
79
+ '_sfdp_pattern_9_half_inference': _sfdp_pattern_9_half_inference,
80
+ '_sfdp_pattern_10_half_training': _sfdp_pattern_10_half_training,
81
+ '_sfdp_pattern_10_half_inference': _sfdp_pattern_10_half_inference,
82
+ '_sfdp_pattern_11_half_training': _sfdp_pattern_11_half_training,
83
+ '_sfdp_pattern_11_half_inference': _sfdp_pattern_11_half_inference,
84
+ '_sfdp_pattern_12_half_training': _sfdp_pattern_12_half_training,
85
+ '_sfdp_pattern_12_half_inference': _sfdp_pattern_12_half_inference,
86
+ '_sfdp_pattern_13_half_training': _sfdp_pattern_13_half_training,
87
+ '_sfdp_pattern_13_half_inference': _sfdp_pattern_13_half_inference,
88
+ '_sfdp_pattern_14_half_training': _sfdp_pattern_14_half_training,
89
+ '_sfdp_pattern_14_half_inference': _sfdp_pattern_14_half_inference,
90
+ '_sfdp_pattern_15_half_training': _sfdp_pattern_15_half_training,
91
+ '_sfdp_pattern_15_half_inference': _sfdp_pattern_15_half_inference,
92
+ '_sfdp_pattern_16_half_training': _sfdp_pattern_16_half_training,
93
+ '_sfdp_pattern_16_half_inference': _sfdp_pattern_16_half_inference,
94
+ '_sfdp_pattern_16_half_bs1_training': _sfdp_pattern_16_half_bs1_training,
95
+ '_sfdp_pattern_16_half_bs1_inference': _sfdp_pattern_16_half_bs1_inference,
96
+ '_sfdp_pattern_17_half_training': _sfdp_pattern_17_half_training,
97
+ '_sfdp_pattern_17_half_inference': _sfdp_pattern_17_half_inference,
98
+ '_sfdp_pattern_16_half_mask_fp32_training': _sfdp_pattern_16_half_mask_fp32_training,
99
+ '_sfdp_pattern_16_half_mask_fp32_inference': _sfdp_pattern_16_half_mask_fp32_inference,
100
+ '_sfdp_pattern_16_half_mask_fp32_bs1_training': _sfdp_pattern_16_half_mask_fp32_bs1_training,
101
+ '_sfdp_pattern_16_half_mask_fp32_bs1_inference': _sfdp_pattern_16_half_mask_fp32_bs1_inference,
102
+ }
103
+
104
+
105
+ def get_serialized_pattern(key):
106
+ import torch._inductor # noqa: F401
107
+ from torch._inductor import config
108
+ if config.fallback_random:
109
+ return None
110
+
111
+ # TODO - could add more validation that the same set of decomps used when
112
+ # tracing SDPA are also used in current context. softmax, dropout, etc
113
+ # decomp use is stable so not an issue in practice.
114
+ return central_index.get(key)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/split_cat.py ADDED
@@ -0,0 +1,1537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import logging
3
+ import operator
4
+ from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union
5
+
6
+ from typing_extensions import TypeAlias
7
+
8
+ import torch
9
+ from torch._dynamo.utils import counters
10
+
11
+ from ..pattern_matcher import (
12
+ Arg,
13
+ CallFunction,
14
+ CallFunctionVarArgs,
15
+ CallMethodVarArgs,
16
+ config_flag,
17
+ FailedMatch,
18
+ get_arg_value,
19
+ Ignored,
20
+ KeywordArg,
21
+ ListOf,
22
+ Match,
23
+ MatchContext,
24
+ MULTIPLE,
25
+ PatternExpr,
26
+ register_graph_pattern,
27
+ RepeatedExpr,
28
+ )
29
+ from .group_batch_fusion import is_node_meta_valid
30
+ from .pre_grad import (
31
+ merge_getitem_cat_pass,
32
+ merge_splits_pass,
33
+ normalization_pass,
34
+ split_cat_pass,
35
+ unbind_stack_pass,
36
+ )
37
+
38
+ log = logging.getLogger(__name__)
39
+
40
+ _Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...]
41
+ _TransformParam: TypeAlias = Tuple[
42
+ Optional[_Arguments],
43
+ Optional[_Arguments],
44
+ Optional[_Arguments],
45
+ Optional[_Arguments],
46
+ ]
47
+ _Range: TypeAlias = Tuple[int, int]
48
+
49
+
50
+ def _get_split_args_default(split_node):
51
+ input_kwarg = "tensor"
52
+ split_size_kwarg = "split_size_or_sections"
53
+ dim_kwarg = "dim"
54
+ default_dim_value = 0
55
+ if split_node.op == "call_method":
56
+ split_size_kwarg = "split_size"
57
+ return (
58
+ get_arg_value(split_node, 0, input_kwarg),
59
+ get_arg_value(split_node, 1, split_size_kwarg),
60
+ get_arg_value(split_node, 2, dim_kwarg) or default_dim_value,
61
+ )
62
+
63
+
64
+ # noqa: W605
65
+ # ############The pattern to be optimized is#########
66
+ # unbind (dim=0)
67
+ # / ... \
68
+ # getitem getitem -> user=1
69
+ # | |
70
+ # split split -> dim=1, user=1, split_section_size=1
71
+ # | |
72
+ # getitem getitem -> user=1
73
+ # \ /
74
+ # cat (dim=1) -> user=1
75
+ # |
76
+
77
+ # ################After transformation#############
78
+ # unbind (dim=0)
79
+ # / ... \
80
+ # getitem getitem -> user=1
81
+ # \ /
82
+ # cat (dim=1) -> user=1
83
+ # |
84
+
85
+
86
def remove_split_with_size_one(
    graph: torch.fx.Graph,
    node: torch.fx.Node,
    input: torch.fx.Node,
):
    """Bypass a split that produces a single section.

    ``node`` is the split and ``input`` its original tensor.  A one-section
    split is an identity, so every consumer of the split's lone getitem is
    rewired to read ``input`` directly, then the getitem and the split are
    erased from ``graph``.
    """
    grandchildren = find_next_users(node)
    # Single-section split -> exactly one getitem user to bypass.
    getitem = next(iter(node.users.keys()))
    for consumer in grandchildren:
        consumer.replace_input_with(getitem, input)
    graph.erase_node(getitem)
    graph.erase_node(node)

    counters["inductor"]["remove_split_with_size_one"] += 1
102
+
103
+
104
def normalize_split_base(
    match: Match,
    _get_split_args: Callable[
        [torch.fx.Node], Tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]]
    ],
):
    """
    Rewrite any split variant into a canonical ``torch.split`` with an explicit
    list of section sizes and a non-negative ``dim`` kwarg, so subsequent
    passes only have to handle one form.  Single-section splits are removed
    outright via ``remove_split_with_size_one``.
    """
    split_node = match.nodes[0]
    graph = match.graph
    split_input, split_size, split_dim = _get_split_args(split_node)
    if split_input is None or split_dim is None or split_size is None:
        log.debug("couldn't find split args")
        return
    if "example_value" not in split_node.meta:
        log.debug("example value absent for node: %s", split_node)
        return
    assert isinstance(split_node.meta["example_value"], (list, tuple))
    # Recover explicit per-section sizes from the traced example outputs.
    split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]]

    if any(isinstance(section, torch.SymInt) for section in split_sections):
        # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
        return
    # A split producing one section is a no-op; bypass it entirely.
    if len(split_sections) == 1:
        remove_split_with_size_one(graph, split_node, split_input)
        return
    if split_dim < 0:  # normalize negative dims against the input rank
        split_dim += split_input.meta["example_value"].dim()
    with graph.inserting_after(split_node):
        new_split_node = graph.call_function(
            torch.split,
            args=(split_input, split_sections),
            kwargs={"dim": split_dim},
        )
    split_node.replace_all_uses_with(new_split_node)
    new_split_node.meta.update(split_node.meta)
    graph.erase_node(split_node)
    counters["inductor"]["split_cat_norm"] += 1
145
+
146
+
147
@register_graph_pattern(
    CallFunctionVarArgs(torch.split, users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallMethodVarArgs("split", users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_split_default(match: Match, *args, **kwargs):
    """Normalize both ``torch.split`` and ``Tensor.split`` matches."""
    return normalize_split_base(match, _get_split_args_default)
159
+
160
+
161
@register_graph_pattern(
    CallFunctionVarArgs(torch.unbind, users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
@register_graph_pattern(
    CallMethodVarArgs("unbind", users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_unbind_default(match: Match, *args, **kwargs):
    """Rewrite unbind variants into ``torch.unbind`` with an explicit,
    non-negative ``dim`` kwarg (accepting the ``axis`` alias)."""
    node = match.nodes[0]
    graph = match.graph
    input = get_arg_value(node, 0, "input")
    dim = get_arg_value(node, 1, "dim")
    if dim is None:
        # numpy-style alias; default to the leading dim when absent.
        axis = node.kwargs.get("axis")
        dim = axis if axis is not None else 0
    if input is None:
        log.debug("couldn't find unbind args")
        return
    if "example_value" not in input.meta:
        log.debug("example value absent for node: %s", input)
        return
    ndim = input.meta["example_value"].ndim
    if dim < 0:  # normalize negative dims against the input rank
        dim += ndim
    with graph.inserting_after(node):
        new_node = graph.call_function(
            torch.unbind,
            args=(input,),
            kwargs={"dim": dim},
        )
    node.replace_all_uses_with(new_node)
    new_node.meta.update(node.meta)
    graph.erase_node(node)
    counters["inductor"]["split_cat_norm"] += 1
201
+
202
+
203
@register_graph_pattern(
    CallFunctionVarArgs(torch.cat, users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_cat_default(match: Match, *args, **kwargs):
    """Rewrite ``torch.cat`` into a canonical form: tensors as a single
    positional arg and an explicit, non-negative ``dim`` kwarg (accepting
    the ``axis`` alias)."""
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    cat_node = match.nodes[0]
    graph = match.graph
    tensors = get_arg_value(cat_node, 0, "tensors")
    cat_dim = get_arg_value(cat_node, 1, "dim")
    if cat_dim is None:
        # numpy-style alias; default to the leading dim when absent.
        cat_axis = cat_node.kwargs.get("axis")
        cat_dim = cat_axis if cat_axis is not None else 0
    if tensors is None or cat_dim is None:
        log.debug("couldn't find cat args")
        return
    assert isinstance(tensors, (list, tuple))
    # A bug in pytorch: some nodes can miss the example_value metadata.
    for tensor in itertools.chain([cat_node], tensors):
        if "example_value" not in tensor.meta:
            log.debug("example value absent for node: %s", tensor)
            return

    ndim = cat_node.meta["example_value"].dim()

    def is_empty_tensor(x):
        # special case where torch.cat supports cat'ing with an empty tensor
        x_shape = x.meta["example_value"].shape
        return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0)

    assert all(
        ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
    )

    if cat_dim < 0:  # normalize negative dims against the output rank
        cat_dim += ndim

    with graph.inserting_after(cat_node):
        new_cat_node = graph.call_function(
            torch.cat,
            args=(tensors,),
            kwargs={"dim": cat_dim},
        )
    cat_node.replace_all_uses_with(new_cat_node)
    new_cat_node.meta.update(cat_node.meta)
    graph.erase_node(cat_node)
    counters["inductor"]["split_cat_norm"] += 1
254
+
255
+
256
@register_graph_pattern(
    CallFunctionVarArgs(torch.stack, users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_stack_default(match: Match, *args, **kwargs):
    """Rewrite ``torch.stack`` into a canonical form with tensors as a single
    positional arg and an explicit, non-negative ``dim`` kwarg."""
    node = match.nodes[0]
    graph = match.graph
    tensors = get_arg_value(node, 0, "tensors")
    dim = get_arg_value(node, 1, "dim") or 0
    if tensors is None or dim is None:
        log.debug("couldn't find stack args")
        return
    assert isinstance(tensors, (list, tuple))

    # A bug in pytorch, some nodes miss the example_value metadata
    for tensor in itertools.chain([node], tensors):
        if "example_value" not in tensor.meta:
            log.debug("example value absent for node: %s", tensor)
            return

    ndim = node.meta["example_value"].dim()
    if dim < 0:  # normalize negative dims against the output rank
        dim += ndim

    with graph.inserting_after(node):
        new_node = graph.call_function(
            node.target,
            args=(tensors,),
            kwargs={"dim": dim},
        )
    node.replace_all_uses_with(new_node)
    new_node.meta.update(node.meta)
    graph.erase_node(node)
    counters["inductor"]["split_cat_norm"] += 1
291
+
292
+
293
def find_next_users(split_node: torch.fx.Node) -> List[torch.fx.Node]:
    """Return the distinct consumers of ``split_node``'s getitem outputs.

    Walks two levels of users (split -> getitem -> consumer), preserving
    first-encounter order and de-duplicating consumers that read several
    getitems.
    """
    consumers: List[torch.fx.Node] = []
    for item_node in split_node.users.keys():
        for consumer in item_node.users.keys():
            if consumer not in consumers:
                consumers.append(consumer)
    return consumers
300
+
301
+
302
@register_graph_pattern(
    CallMethodVarArgs("squeeze", users=MULTIPLE),
    pass_dict=normalization_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def normalize_squeeze_default(match: Match, *args, **kwargs):
    """Rewrite ``Tensor.squeeze`` method calls into ``torch.squeeze`` with an
    explicit ``dim`` kwarg (a single-element dim sequence is unwrapped).

    Fix over the original: the replacement node now inherits the original
    node's meta (notably ``example_value``), matching the other normalizers
    in this file — dropping it left downstream passes, which bail on missing
    ``example_value``, unable to see through normalized squeezes.
    """
    squeeze_node = match.nodes[0]
    squeeze_input = get_arg_value(squeeze_node, 0)

    # Recover `dim` from whichever calling convention was used.
    if "dim" in squeeze_node.kwargs:
        assert len(squeeze_node.args) == 1
        dim = squeeze_node.kwargs["dim"]
    elif len(squeeze_node.args) == 1:
        # squeeze(Tensor)
        dim = None
    elif len(squeeze_node.args) == 2:
        # squeeze(Tensor self, int dim)
        # squeeze(Tensor self, int[] dim)
        dim = squeeze_node.args[1]
    else:
        # squeeze(Tensor self, int[] dim) (called with varargs)
        dim = squeeze_node.args[1:]

    if isinstance(dim, Sequence) and len(dim) == 1:
        dim = dim[0]

    with match.graph.inserting_after(squeeze_node):
        if dim is None:
            new_squeeze_node = match.graph.call_function(
                torch.squeeze, args=(squeeze_input,)
            )
        else:
            new_squeeze_node = match.graph.call_function(
                torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim}
            )
    # Preserve meta (example_value etc.) like the other normalization passes.
    new_squeeze_node.meta.update(squeeze_node.meta)
    squeeze_node.replace_all_uses_with(new_squeeze_node)
    match.graph.erase_node(squeeze_node)
339
+
340
+
341
class TorchSplit(CallFunction):
    """
    Matches a call to torch.split in its normalized form, additionally
    verifying that every user of the split is a distinct ``operator.getitem``
    with a non-negative integer index.
    """

    def __init__(self, arg, sizes, func=torch.split):
        # KeywordArg("dim") forces every matched split to agree on `dim`.
        super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim"))

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        m = super()._match(node, ctx)
        if not m:
            return m
        sections = node.args[1]
        if not isinstance(sections, (list, tuple)):
            return FailedMatch("split not normalized")
        # Every user must be a unique integer getitem.
        getitem_pattern = CallFunction(operator.getitem, Arg(), Arg())
        seen_indices: Set[int] = set()
        for user in node.users:
            if not getitem_pattern.match(user):
                # This should ideally never happen. Split user should always be a getitem
                return FailedMatch(f"user of split not a getitem: {user}")
            index = user.args[1]
            if not isinstance(index, int):
                return FailedMatch("only integer getitems are handled")
            if index in seen_indices:
                return FailedMatch(f"duplicate getitem {index}")
            if user.args[-1] < 0:  # type: ignore[operator]
                # This shouldn't ideally happen as dynamo normalizes indexes to positive
                return FailedMatch("negative index")
            seen_indices.add(index)
        return m
373
+
374
+
375
@register_graph_pattern(
    TorchSplit(
        CallFunction(
            operator.getitem,
            TorchSplit(
                KeywordArg("first_split_input"),
                KeywordArg("first_split_sections"),
            ),
            Ignored(),
        ),
        KeywordArg("next_split_sections"),
    ),
    pass_dict=merge_splits_pass,
    extra_check=config_flag("split_cat_fx_passes"),
)
def merge_splits(
    match: Match,
    first_split_input: torch.fx.Node,
    first_split_sections: List[int],
    next_split_sections: List[int],
    # Note: dim is implicitly bound by TorchSplit's internal KeywordArg("dim")
    dim: int,
):
    """Fuse a split-of-a-split (both on the same dim) into a single split.

    The inner split's section at ``next_split_index`` is replaced by the
    outer split's sections, getitems are re-indexed to point at the merged
    split, and all superseded nodes are erased.
    """
    node = match.output_node()
    # A dangling outer split with no users has nothing to rewire; skip it.
    if len(node.users.keys()) == 0:
        return
    graph = match.graph
    first_split = node.args[0].args[0]  # type: ignore[union-attr]
    next_split_index = node.args[0].args[1]  # type: ignore[union-attr]

    # Splice the outer sections in place of the one section they came from.
    new_split_sections = list(first_split_sections)
    new_split_sections[next_split_index : next_split_index + 1] = next_split_sections  # type: ignore[operator, misc]

    first_split_dim = first_split.kwargs["dim"]  # type: ignore[union-attr]

    to_remove = []

    with graph.inserting_before(first_split):
        # Add the merged split node
        new_split = graph.call_function(
            torch.split,
            args=(first_split_input, new_split_sections),
            kwargs={"dim": first_split_dim},
        )
        first_split_num_to_user = {
            user.args[1]: user for user in first_split.users.keys()  # type: ignore[union-attr]
        }

        new_split_num = 0
        for split_num in range(len(first_split_sections)):
            if split_num not in first_split_num_to_user:
                # Unused section: its slot still shifts the merged index.
                new_split_num += 1
                continue
            old_getitem = first_split_num_to_user[split_num]
            if split_num != next_split_index:
                # Untouched getitem: retarget it at the merged split.
                old_getitem.update_arg(0, new_split)
                old_getitem.update_arg(1, new_split_num)
                new_split_num += 1
            else:
                next_split_num_to_user = {
                    user.args[1]: user for user in node.users.keys()
                }
                # Not every getitem of the inner split is necessarily used,
                # so iterate by user count rather than by section count.
                for next_split_num in range(len(node.users.keys())):
                    with graph.inserting_after(new_split):
                        new_getitem = graph.call_function(
                            operator.getitem, args=(new_split, new_split_num)
                        )
                    new_split_num += 1
                    next_getitem = next_split_num_to_user[next_split_num]
                    new_getitem.meta.update(next_getitem.meta)
                    next_getitem.replace_all_uses_with(new_getitem)
                    to_remove.append(next_getitem)
                to_remove.append(node)
                to_remove.append(old_getitem)

    to_remove.append(first_split)  # type: ignore[arg-type]
    for stale in to_remove:
        graph.erase_node(stale)

    counters["inductor"]["consecutive_split_merged"] += 1
459
+
460
+
461
class SplitCatSimplifier:
    """
    Helper class to simplify split-cat patterns. In simple cases both the
    split and cat nodes can be removed in a "split->cat" pattern; in others
    the split must merely be simplified and/or transforms inserted before the
    cat. Cases that prevent outright removal include:
    1. The final node has additional args not coming from the initial split
    2. Args are shuffled between split and cat
    3. Some final nodes are not cat/stack
    4. Split-dim != cat-dim (but equal split)

    Any combination of the above can occur.

    For 1, 2 & 3 we iterate over all users of the split and figure out common
    "ranges" that can be merged, then simplify the split accordingly (in the
    best case removing it entirely). For 4 we insert unflatten + movedim
    transforms (see ``get_transform_params``). Finally, depending on whether
    the final node is cat or stack, unsqueeze/flatten transforms are added.
    """

    def simplify(
        self,
        graph: torch.fx.Graph,
        split_node: torch.fx.Node,
        split_sections: List[int],
    ):
        """Entry point: attempt to simplify ``split_node`` and its cat users."""
        # Users one level past the getitems.
        next_users = find_next_users(split_node)
        # Inputs of those users; inputs originating from `split_node` are
        # represented as (start, end) getitem ranges (see get_user_input_list).
        user_inputs_list = self.get_user_input_list(split_node, next_users)
        # Simplify split_sections. If len(simplified_split_ranges) == 1 the
        # split can be removed outright; otherwise it is only narrowed.
        simplified_split_ranges = self.get_simplified_split_ranges(
            split_sections, next_users, user_inputs_list
        )
        if not simplified_split_ranges:  # Simplification not possible
            return
        transform_params_list = self.get_transform_params(
            split_node, next_users, user_inputs_list
        )
        if not transform_params_list:
            return

        # Start actual replacement
        user_inputs_list_new = self.replace_split(
            graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
        )
        self.replace_cat(
            graph, split_node, next_users, user_inputs_list_new, transform_params_list  # type: ignore[arg-type]
        )
        self.erase_old_nodes(graph, split_node, next_users)  # type: ignore[arg-type]

    def get_user_input_list(
        self, split_node: torch.fx.Node, next_users: List[torch.fx.Node]
    ) -> List[List[Union[torch.fx.Node, _Range]]]:
        """
        Return, per user node (outer list), the inputs to that node (inner
        list). Each inner entry is either
        - a tuple: closed-interval range of getitem indices feeding a cat, or
        - a torch.fx.Node: an "other" input not coming from our split.
        """
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]] = []
        for user in next_users:
            if user.target in {torch.cat, torch.stack}:
                user_inputs_list.append(self.get_merged_user_inputs(split_node, user))
            else:
                user_inputs_list.append(self.get_non_cat_node_input(split_node, user))  # type: ignore[arg-type]
        return user_inputs_list

    def get_merged_user_inputs(
        self, split_node: torch.fx.Node, cat_node: torch.fx.Node
    ) -> List[Union[torch.fx.Node, _Range]]:
        """Inputs of a cat/stack user, with split getitems collapsed to ranges."""
        user_inputs = get_arg_value(cat_node, 0, "tensors")
        simplified_user_inputs = []
        split_users = set(split_node.users.keys())
        for user_input in user_inputs:
            if user_input not in split_users:
                simplified_user_inputs.append(user_input)
            else:
                # Record which getitem index this cat input corresponds to.
                simplified_user_inputs.append(user_input.args[1])
        return self.merge_consecutive_inputs(simplified_user_inputs)

    def get_non_cat_node_input(
        self, split_node: torch.fx.Node, node: torch.fx.Node
    ) -> List[_Range]:
        """
        Inputs of a non-cat user in the same format as
        ``get_merged_user_inputs`` (only split-derived inputs are recorded,
        each as a degenerate single-index range).
        """
        node_input = []
        split_users = set(split_node.users.keys())
        for node_arg in node.all_input_nodes:
            if node_arg in split_users:
                getitem_num = get_arg_value(node_arg, 1)
                node_input.append((getitem_num, getitem_num))
        return node_input

    def merge_consecutive_inputs(
        self, inputs: List[Union[torch.fx.Node, int]]
    ) -> List[Union[torch.fx.Node, _Range]]:
        """
        Collapse runs of consecutive getitem indices into closed ranges.

        e.g. [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1]
        """
        merged: List[Union[torch.fx.Node, _Range]] = []
        run = None
        for entry in inputs:
            if isinstance(entry, int):
                if not run:
                    run = [entry, entry]
                elif entry == run[1] + 1:
                    run[1] += 1
                else:
                    # Non-consecutive index: close the current run.
                    merged.append(tuple(run))
                    run = [entry, entry]
            else:
                # Non-split input interrupts any open run.
                if run:
                    merged.append(tuple(run))
                    run = None
                merged.append(entry)  # type: ignore[arg-type]
        if run:
            merged.append(tuple(run))
        return merged  # type: ignore[return-value]

    def get_simplified_split_ranges(
        self,
        split_sections,
        next_users,
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[_Range]]:
        """Convert getitem-index ranges into element-offset split ranges,
        returning None when no simplification is possible."""
        ranges = set()
        for user_node, user_inputs in zip(next_users, user_inputs_list):
            ranges |= {
                user_input
                for user_input in user_inputs
                if isinstance(user_input, tuple)
            }
        cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
        split_ranges = sorted(
            [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges]
        )

        if not self.has_non_overlapping_ranges(
            split_ranges,
        ):  # This need not be a strict condition
            # However, we keep it now for simplicity.
            return None
        split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1])
        if len(split_sections) == len(split_ranges):  # Simplification not possible
            return None
        counters["inductor"]["scmerge_split_sections_removed"] = len(
            split_sections
        ) - len(split_ranges)
        return split_ranges

    def has_non_overlapping_ranges(self, ranges: List[_Range]) -> bool:
        """True iff the sorted ranges never overlap each other."""
        for range_, next_range in zip(ranges, ranges[1:]):
            if range_[1] > next_range[0]:
                return False
        return True

    def fill_gaps(self, ranges: List[_Range], min_: int, max_: int) -> List[_Range]:
        """Insert filler ranges so [min_, max_) is fully covered.

        NOTE(review): assumes ``ranges`` is non-empty — callers derive it from
        at least one split-consuming user; verify if reusing elsewhere.
        """
        cur = min_
        filled_ranges = []
        for a, b in ranges:
            if cur < a:
                filled_ranges.append((cur, a))
            filled_ranges.append((a, b))
            cur = b
        if filled_ranges[-1][1] < max_:
            filled_ranges.append((filled_ranges[-1][1], max_))
        return filled_ranges

    def get_transform_params(
        self,
        split_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
    ) -> Optional[List[List[_TransformParam]]]:
        """
        Figure out what transforms are needed for each input to each cat node.

        A simplified split range is replaced by an unflatten followed by a
        movedim; each _TransformParam is
        (unflatten_params, movedim_params, unsqueeze_params, flatten_params).
        Returns None when a range does not correspond to an equal split.
        """
        split_dim = split_node.kwargs["dim"]
        split_sections = split_node.args[1]
        transform_params_list: List[List[_TransformParam]] = []

        for user_node, user_inputs in zip(next_users, user_inputs_list):
            if user_node.target not in {torch.cat, torch.stack}:
                transform_params_list.append([])
                continue

            cat_dim = get_arg_value(user_node, 1, "dim")
            transform_params: List[_TransformParam] = []
            for user_input in user_inputs:
                if split_dim == cat_dim and user_node.target == torch.cat:
                    # No transform needed
                    transform_params.append((None, None, None, None))
                elif isinstance(user_input, tuple):  # Split being simplified
                    # Verify equal split
                    subset_split_sections = split_sections[  # type: ignore[index]
                        user_input[0] : user_input[1] + 1
                    ]
                    # All sections should be equal
                    if len(set(subset_split_sections)) != 1:
                        return None

                    num_splits = len(subset_split_sections)
                    unflatten_params = (split_dim, (num_splits, -1))
                    movedim_params = (
                        (split_dim, cat_dim) if split_dim != cat_dim else None
                    )
                    transform_params.append(
                        (unflatten_params, movedim_params, None, None)
                    )
                elif (
                    user_node.target == torch.stack or split_dim != cat_dim
                ):  # We need to unsqueeze inputs not coming through split
                    transform_params.append((None, None, (cat_dim,), None))
                else:  # Non-split inputs
                    transform_params.append((None, None, None, None))
            transform_params_list.append(transform_params)
        return transform_params_list

    def replace_split(
        self,
        graph: torch.fx.Graph,
        split_node: torch.fx.Node,
        split_sections: List[int],
        user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
        split_ranges: List[_Range],
    ) -> List[List[torch.fx.Node]]:
        """
        Replace the split node: remove it when len(split_ranges) == 1, or
        narrow it to fewer sections when len(split_ranges) > 1.

        Returns the new ``user_inputs_list`` with range tuples replaced by the
        corresponding getitems of the new split.
        """
        split_input = split_node.args[0]
        split_dim = split_node.kwargs["dim"]
        if len(split_ranges) == 1:  # We can completely eliminate the split node
            split_items = [split_input]
        else:
            with graph.inserting_after(split_node):
                new_split = graph.call_function(
                    torch.split,
                    args=(
                        split_input,
                        [r[1] - r[0] for r in split_ranges],
                    ),
                    kwargs={"dim": split_dim},
                )
                new_split.meta.update(split_node.meta)
                counters["inductor"]["scmerge_split_added"] += 1
            with graph.inserting_after(new_split):
                split_items = [
                    graph.call_function(operator.getitem, args=(new_split, i))
                    for i in range(len(split_ranges))
                ]
        # Map each range back to the matching new getitem.
        cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist()
        new_user_inputs_list = []
        for user_inputs in user_inputs_list:
            new_user_inputs = []
            for user_input in user_inputs:
                if isinstance(user_input, tuple):
                    # Find the correct new getitem (present in split_items)
                    new_user_inputs.append(
                        split_items[
                            split_ranges.index(
                                (
                                    cumulative_sizes[user_input[0]],
                                    cumulative_sizes[user_input[1] + 1],
                                )
                            )
                        ]
                    )
                else:
                    new_user_inputs.append(user_input)
            new_user_inputs_list.append(new_user_inputs)
        return new_user_inputs_list  # type: ignore[return-value]

    def replace_cat(
        self,
        graph: torch.fx.GraphModule,
        split_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
        user_inputs_list_new,
        transform_params_list: List[List[_TransformParam]],
    ):
        """Rebuild each user node from the new inputs, applying the
        unflatten/movedim/unsqueeze/flatten transforms computed earlier."""
        split_dim = split_node.kwargs["dim"]

        split_users = split_node.users.keys()
        new_cats = []
        for user_node, user_inputs_new, transform_params in zip(
            next_users, user_inputs_list_new, transform_params_list
        ):
            if user_node.target not in {torch.cat, torch.stack}:
                # Non-cat/stack users: just swap old getitems (belonging to the
                # original split) for the newer getitems, in order.
                next_cat_input = 0
                for input_node in user_node.all_input_nodes:
                    if input_node in split_users:
                        user_node.replace_input_with(
                            input_node, user_inputs_new[next_cat_input]
                        )
                        next_cat_input += 1
                continue

            # Handle cat/stack user nodes
            cat_dim = get_arg_value(user_node, 1, "dim")
            user_inputs_new_transformed = []
            # Consecutive inputs needing the same `unsqueeze` are combined
            # into one torch.stack instead of unsqueezing individually.
            to_stack = []
            stack_dim = None
            with graph.inserting_before(user_node):
                for user_input_new, transform_param in zip(
                    user_inputs_new, transform_params
                ):
                    (
                        unflatten_params,
                        movedim_params,
                        unsqueeze_params,
                        flatten_params,
                    ) = transform_param
                    if unsqueeze_params and (
                        stack_dim is None or stack_dim == unsqueeze_params[0]
                    ):
                        # Keep accumulating into the pending stack.
                        to_stack.append(user_input_new)
                        stack_dim = unsqueeze_params[0]
                        continue
                    elif to_stack:
                        # Flush the pending stack before handling this input.
                        stacked_input = graph.call_function(
                            torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
                        )
                        to_stack = []
                        stack_dim = None
                        user_inputs_new_transformed.append(stacked_input)
                        if unsqueeze_params:
                            to_stack.append(user_input_new)
                            stack_dim = unsqueeze_params[0]
                            continue

                    if unflatten_params:
                        user_input_new = graph.call_function(
                            torch.unflatten, args=(user_input_new, *unflatten_params)
                        )
                    if movedim_params:
                        user_input_new = graph.call_function(
                            torch.movedim, args=(user_input_new, *movedim_params)
                        )
                    if flatten_params:
                        user_input_new = graph.call_function(
                            torch.flatten, args=(user_input_new, *flatten_params)
                        )
                    user_inputs_new_transformed.append(user_input_new)
                if to_stack:
                    stacked_input = graph.call_function(
                        torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
                    )
                    user_inputs_new_transformed.append(stacked_input)

            with graph.inserting_after(user_node):
                if len(user_inputs_new_transformed) > 1:
                    new_cat_node = graph.call_function(
                        torch.cat,
                        args=(user_inputs_new_transformed,),
                        kwargs={"dim": cat_dim},
                    )
                    new_cat_node.meta.update(user_node.meta)
                    counters["inductor"]["scmerge_cat_added"] += 1
                else:
                    new_cat_node = user_inputs_new_transformed[-1]

            if (
                user_node.target == torch.cat
                and split_dim != cat_dim
                and split_node.target == torch.split
            ):
                # Undo the extra dim introduced by unflatten+movedim.
                with graph.inserting_after(new_cat_node):
                    new_cat_node = graph.call_function(
                        torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
                    )
            user_node.replace_all_uses_with(new_cat_node)
            new_cats.append(new_cat_node)

    def erase_old_nodes(
        self,
        graph: torch.fx.GraphModule,
        split_node: torch.fx.Node,
        next_users: List[torch.fx.Node],
    ):
        """Delete the original split, its getitems, and replaced cat/stack
        users (in reverse order so users are removed before producers)."""
        to_remove = [split_node]
        counters["inductor"]["scmerge_split_removed"] += 1
        to_remove.extend(split_node.users.keys())
        for next_user in next_users:
            if next_user.target not in {torch.cat, torch.stack}:
                continue
            counters["inductor"]["scmerge_cat_removed"] += 1
            to_remove.append(next_user)
        for node in reversed(to_remove):
            graph.erase_node(node)
869
+
870
+
871
+ class UnbindCatRemover(SplitCatSimplifier):
872
+ """
873
+ Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier.
874
+
875
+ Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this,
876
+ other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`,
877
+ hence we extend that class.
878
+ """
879
+
880
+ def remove_unbind(
881
+ self,
882
+ graph: torch.fx.Graph,
883
+ unbind_node: torch.fx.Node,
884
+ ):
885
+ num_unbind = ( # type: ignore[operator]
886
+ max(getitem_node.args[1] for getitem_node in unbind_node.users.keys()) + 1 # type: ignore[operator, union-attr, type-var]
887
+ )
888
+ split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type]
889
+
890
+ super().simplify(graph, unbind_node, split_sections)
891
+
892
+ def get_simplified_split_ranges(
893
+ self,
894
+ split_sections: List[int],
895
+ next_users: List[torch.fx.Node],
896
+ user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
897
+ ) -> Optional[List[_Range]]:
898
+ simplified_split_ranges = super().get_simplified_split_ranges(
899
+ split_sections, next_users, user_inputs_list
900
+ )
901
+ if not simplified_split_ranges or len(simplified_split_ranges) != 1:
902
+ return None
903
+ return simplified_split_ranges
904
+
905
+ def get_transform_params(
906
+ self,
907
+ unbind_node: torch.fx.Node,
908
+ next_users: List[torch.fx.Node],
909
+ user_inputs_list: List[List[Union[torch.fx.Node, _Range]]],
910
+ ) -> Optional[List[List[_TransformParam]]]:
911
+ """
912
+ Figure out what transforms are needed for each input to each cat node.
913
+
914
+ Here is the rough transforms we apply:
915
+
916
+ x -> unbind -> stack => x -> movedim
917
+
918
+ x -> unbind -> cat => x -> movedim -> flatten
919
+
920
+ When cat/stack nodes have additional args:
921
+
922
+ addn ---| addn -> unsqueeze ---|
923
+ x -> unbind -> stack => x -> movedim -> cat
924
+
925
+ addn ---| addn ---|
926
+ x -> unbind -> cat => x -> movedim -> flatten -> cat
927
+
928
+ (Note application of these depends on the dims as well)
929
+
930
+
931
+ """
932
+ split_dim = unbind_node.kwargs["dim"]
933
+ transform_params_list: List[List[_TransformParam]] = []
934
+ for user_node, user_inputs in zip(next_users, user_inputs_list):
935
+ cat_dim = get_arg_value(user_node, 1, "dim") or 0
936
+ transform_params: List[_TransformParam] = []
937
+ for user_input in user_inputs:
938
+ if isinstance(user_input, tuple):
939
+ # User input is coming from unbind
940
+ movedim_params = (
941
+ (split_dim, cat_dim) if split_dim != cat_dim else None
942
+ )
943
+ flatten_params = None
944
+ if user_node.target == torch.cat:
945
+ flatten_params = (cat_dim, cat_dim + 1)
946
+ transform_params.append(
947
+ (None, movedim_params, None, flatten_params)
948
+ )
949
+ elif (
950
+ user_node.target == torch.stack
951
+ ): # We need to unsqueeze inputs not coming through unbind into cat
952
+ transform_params.append((None, None, (cat_dim,), None))
953
+ else: # Non-unbind inputs
954
+ transform_params.append((None, None, None, None))
955
+ transform_params_list.append(transform_params)
956
+ return transform_params_list
957
+
958
+
959
+ class GetItem(CallFunction):
960
+ def __init__(self, arg, index, _users=1):
961
+ super().__init__(operator.getitem, arg, index, _users=_users)
962
+
963
+ def find_anchor_nodes(self, ctx: MatchContext, searched: Set[torch.fx.Node]):
964
+ # We generally match GetItem with arg being an Arg(). So, we never return the anchor
965
+ # nodes as the stored node in ctx.pattern_to_node is returned. Here we override find_anchor_nodes
966
+ # to not use ctx.pattern_to_node
967
+ for pattern in self.flat_args_kwargs[0]:
968
+ if isinstance(pattern, PatternExpr):
969
+ for other_node in pattern.find_anchor_nodes(ctx, searched):
970
+ if not isinstance(other_node, torch.fx.Node):
971
+ continue
972
+ for node in other_node.users:
973
+ if node not in searched:
974
+ if self._match_fns(node):
975
+ yield node
976
+ searched.add(node)
977
+
978
+
979
+ @register_graph_pattern(
980
+ RepeatedExpr(
981
+ CallFunction(
982
+ torch.squeeze,
983
+ GetItem(
984
+ TorchSplit(
985
+ KeywordArg("split_input"),
986
+ KeywordArg("split_sizes"),
987
+ ),
988
+ Ignored(),
989
+ ),
990
+ KeywordArg("dim"),
991
+ _users=MULTIPLE,
992
+ ),
993
+ ),
994
+ pass_dict=split_cat_pass,
995
+ extra_check=config_flag("split_cat_fx_passes"),
996
+ )
997
+ @register_graph_pattern(
998
+ RepeatedExpr(
999
+ CallFunction(
1000
+ torch.squeeze,
1001
+ GetItem(
1002
+ TorchSplit(
1003
+ KeywordArg("split_input"),
1004
+ KeywordArg("split_sizes"),
1005
+ ),
1006
+ Ignored(),
1007
+ ),
1008
+ dim=KeywordArg("dim"),
1009
+ _users=MULTIPLE,
1010
+ )
1011
+ ),
1012
+ pass_dict=split_cat_pass,
1013
+ extra_check=config_flag("split_cat_fx_passes"),
1014
+ )
1015
+ def merge_split_squeeze(
1016
+ match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int
1017
+ ):
1018
+ graph = match.graph
1019
+ split = next(node for node in match.nodes if node.target == torch.split)
1020
+ if not all(s == 1 for s in split_sizes):
1021
+ return
1022
+ if isinstance(dim, Sequence):
1023
+ return
1024
+ next_users = find_next_users(split)
1025
+ if not all(node.target == torch.squeeze for node in next_users):
1026
+ return
1027
+ with graph.inserting_before(match.output_node()):
1028
+ unbind = graph.call_function(
1029
+ torch.unbind, args=(split_input,), kwargs={"dim": dim}
1030
+ )
1031
+ for item_index, getitem_node in sorted(
1032
+ [
1033
+ (getitem_node.args[1], getitem_node)
1034
+ for getitem_node in split.users.keys()
1035
+ ]
1036
+ ):
1037
+ squeeze = next(iter(getitem_node.users.keys()))
1038
+ new_get_item = graph.call_function(
1039
+ operator.getitem, args=(unbind, item_index)
1040
+ )
1041
+ squeeze.replace_all_uses_with(new_get_item)
1042
+ new_get_item.meta.update(squeeze.meta)
1043
+ graph.erase_node(squeeze)
1044
+ graph.erase_node(getitem_node)
1045
+ graph.erase_node(split)
1046
+ counters["inductor"]["split_squeeze_replaced"] += 1
1047
+
1048
+
1049
+ getitem_unbind = ListOf(
1050
+ GetItem(
1051
+ CallFunction(
1052
+ torch.unbind,
1053
+ KeywordArg("unbind_input"),
1054
+ dim=KeywordArg("dim"),
1055
+ _users=MULTIPLE,
1056
+ ),
1057
+ Ignored(),
1058
+ _users=MULTIPLE,
1059
+ ),
1060
+ partial=True,
1061
+ )
1062
+
1063
+
1064
+ @register_graph_pattern(
1065
+ CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE),
1066
+ pass_dict=unbind_stack_pass,
1067
+ extra_check=config_flag("split_cat_fx_passes"),
1068
+ )
1069
+ @register_graph_pattern(
1070
+ CallFunction(
1071
+ [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE
1072
+ ),
1073
+ pass_dict=unbind_stack_pass,
1074
+ extra_check=config_flag("split_cat_fx_passes"),
1075
+ )
1076
+ @register_graph_pattern(
1077
+ CallFunction(
1078
+ [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE
1079
+ ),
1080
+ pass_dict=unbind_stack_pass,
1081
+ extra_check=config_flag("split_cat_fx_passes"),
1082
+ )
1083
+ def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
1084
+ unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
1085
+ UnbindCatRemover().remove_unbind(match.graph, unbind_node)
1086
+
1087
+
1088
+ getitem_split = ListOf(
1089
+ CallFunction(
1090
+ operator.getitem,
1091
+ TorchSplit(
1092
+ Ignored(),
1093
+ KeywordArg("split_sections"),
1094
+ ),
1095
+ Ignored(),
1096
+ _users=MULTIPLE,
1097
+ ),
1098
+ partial=True,
1099
+ )
1100
+
1101
+
1102
+ @register_graph_pattern(
1103
+ CallFunction(
1104
+ [torch.stack, torch.cat],
1105
+ tensors=getitem_split,
1106
+ dim=Ignored(),
1107
+ _users=MULTIPLE,
1108
+ ),
1109
+ pass_dict=split_cat_pass,
1110
+ extra_check=config_flag("split_cat_fx_passes"),
1111
+ )
1112
+ @register_graph_pattern(
1113
+ CallFunction(
1114
+ [torch.stack, torch.cat],
1115
+ getitem_split,
1116
+ dim=Ignored(),
1117
+ _users=MULTIPLE,
1118
+ ),
1119
+ pass_dict=split_cat_pass,
1120
+ extra_check=config_flag("split_cat_fx_passes"),
1121
+ )
1122
+ @register_graph_pattern(
1123
+ CallFunction(
1124
+ [torch.stack, torch.cat],
1125
+ getitem_split,
1126
+ Ignored(),
1127
+ _users=MULTIPLE,
1128
+ ),
1129
+ pass_dict=split_cat_pass,
1130
+ extra_check=config_flag("split_cat_fx_passes"),
1131
+ )
1132
+ def simplify_split_cat(match: Match, split_sections: List[int], dim: int):
1133
+ if not isinstance(split_sections, (list, tuple)): # Unnormalized split
1134
+ return
1135
+ split_node = next(node for node in match.nodes if node.target == torch.split)
1136
+ SplitCatSimplifier().simplify(match.graph, split_node, split_sections)
1137
+
1138
+
1139
+ # noqa: W605
1140
+ # ############pattern to be optimized is#########
1141
+
1142
+ # split_node(dim=1)
1143
+ # / \ ... / \
1144
+ # getitem getitem getitem getitem -> user=1
1145
+ # \ / \ /
1146
+ # cat (user=mul, dim=1) cat(user=mul, dim=1)
1147
+ # | \ | \
1148
+
1149
+ # ################after transformation#############
1150
+
1151
+ # split_node(dim=1)
1152
+ # / ... \
1153
+ # getitem getitem
1154
+ # | \ | \
1155
+
1156
+
1157
+ def has_same_parent_node(node: torch.fx.Node):
1158
+ # the input nodes of the node should come from the same parent
1159
+ prev_node = None
1160
+ for getitem in node.args[0]: # type: ignore[union-attr]
1161
+ if getitem.target != operator.getitem: # type: ignore[union-attr]
1162
+ return False
1163
+ if prev_node is None:
1164
+ prev_node = getitem.args[0] # type: ignore[union-attr]
1165
+ else:
1166
+ if getitem.args[0] != prev_node:
1167
+ return False
1168
+ return True
1169
+
1170
+
1171
+ def remove_zeros(split_sections: List[int]):
1172
+ """
1173
+ Remove zeros from the list and get the index mapping dict from getitem
1174
+ in split node to getitem in new split node
1175
+ """
1176
+ new_split_sections, index_mapping = [], {}
1177
+ idx = 0
1178
+ for i in range(len(split_sections)):
1179
+ if split_sections[i] > 0:
1180
+ new_split_sections.append(split_sections[i])
1181
+ index_mapping[i] = idx
1182
+ idx += 1
1183
+
1184
+ return new_split_sections, index_mapping
1185
+
1186
+
1187
+ def is_sorted_and_consecutive(arr: List[int]) -> bool:
1188
+ # check if the array is sorted
1189
+ if arr == sorted(arr):
1190
+ # check if the differences between adjacent elements are all 1
1191
+ return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:]))
1192
+ else:
1193
+ return False
1194
+
1195
+
1196
+ def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: List[int]) -> int:
1197
+ """
1198
+ Calculate the fused tensor size in the indices
1199
+ """
1200
+ fused_tensor_size = 0
1201
+ for i in range(len(split_node.args[1])): # type: ignore[arg-type]
1202
+ if i in indices:
1203
+ fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
1204
+ return fused_tensor_size
1205
+
1206
+
1207
+ @register_graph_pattern(
1208
+ CallFunction(
1209
+ torch.cat,
1210
+ getitem_split,
1211
+ dim=Ignored(),
1212
+ _users=MULTIPLE,
1213
+ ),
1214
+ pass_dict=merge_getitem_cat_pass,
1215
+ extra_check=config_flag("split_cat_fx_passes"),
1216
+ )
1217
+ def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
1218
+ if not isinstance(split_sections, (list, tuple)): # Unnormalized split
1219
+ return
1220
+ graph = match.graph
1221
+ split_node = next(node for node in match.nodes if node.target == torch.split)
1222
+ split_input, split_size, split_dim = _get_split_args_default(split_node)
1223
+ # if the cat and split have different dims, return
1224
+ # Find the next users (i.e. users after the getitem)
1225
+ next_users = find_next_users(split_node)
1226
+ # 'immutable_list' object does not support mutation. Create a new copy of it
1227
+ split_sections = list(split_sections)
1228
+ for cat_user in next_users:
1229
+ if cat_user.target == torch.cat:
1230
+ cat_dim = get_arg_value(cat_user, 1, "dim")
1231
+ # check the all getitems in the cat_user from the same node
1232
+ # check the input of the cat has all getitem from the split
1233
+ # check all getitem only has one single user
1234
+ if (
1235
+ split_dim != cat_dim
1236
+ or not has_same_parent_node(cat_user)
1237
+ or not all(len(arg.users) == 1 for arg in cat_user.args[0]) # type: ignore[union-attr]
1238
+ ):
1239
+ continue
1240
+ # find the index of getitems to be cated/stacked
1241
+ indices = []
1242
+ for arg in cat_user.args[0]: # type: ignore[union-attr]
1243
+ indices.append(arg.args[1]) # type: ignore[union-attr]
1244
+ # the gettitems to be merged must be consecutive, otherwise
1245
+ # returned sliced tensor could be wrong
1246
+ if not is_sorted_and_consecutive(indices):
1247
+ continue
1248
+ # update the arg of cat user, only keep the first getitem
1249
+ cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index]
1250
+ # calculate the fused tensor sizes in the indices
1251
+ fused_tensor_size = 0
1252
+ for i in range(len(split_node.args[1])): # type: ignore[arg-type]
1253
+ if i in indices:
1254
+ fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
1255
+ # update the split sections
1256
+ split_sections[indices[0]] = calculate_fused_tensor_size(
1257
+ split_node, indices
1258
+ )
1259
+ # padding others with zeros to keep the same dict size
1260
+ for i in indices[1:]:
1261
+ split_sections[i] = 0
1262
+ # remove all unused indexes in the split_node
1263
+ new_split_sections, index_mapping = remove_zeros(split_sections)
1264
+ with graph.inserting_after(split_node):
1265
+ new_split_node = graph.call_function(
1266
+ torch.split,
1267
+ args=(split_input, split_sections),
1268
+ kwargs={"dim": split_dim},
1269
+ )
1270
+ split_node.replace_all_uses_with(new_split_node)
1271
+ new_split_node.meta.update(split_node.meta)
1272
+ # remove all unused getitem nodes
1273
+ to_remove = [cat_user]
1274
+ # dictionary keys changed during iteration
1275
+ new_split_getitem_nodes = list(new_split_node.users.keys())
1276
+ for getitem_node in new_split_getitem_nodes:
1277
+ if getitem_node.args[1] in indices[1:]:
1278
+ to_remove.append(getitem_node)
1279
+ # update meta data of getitem
1280
+ elif getitem_node.args[1] == indices[0]:
1281
+ cat_user.replace_all_uses_with(getitem_node)
1282
+ getitem_node.meta.update(cat_user.meta)
1283
+ else:
1284
+ # update getitem index for new split node
1285
+ getitem_node.update_arg(1, index_mapping[getitem_node.args[1]])
1286
+ graph.erase_node(split_node)
1287
+ for getitem_node in to_remove:
1288
+ graph.erase_node(getitem_node)
1289
+ # update the split sections of new split node
1290
+ new_split_node.update_arg(1, new_split_sections)
1291
+ split_node = new_split_node
1292
+ split_sections = new_split_sections
1293
+
1294
+ counters["inductor"]["getitem_cat_merged"] += 1
1295
+
1296
+
1297
+ # ############pattern to be optimized is#########
1298
+
1299
+ # split_node(dim=1) -> user=multiple
1300
+ # / \ ... / \
1301
+ # getitem getitem getitem getitem -> user=multiple
1302
+ # \ \ / \
1303
+ # other_op /cat(user=mul, dim=1) other_op
1304
+ # |
1305
+
1306
+ # ################after transformation#############
1307
+
1308
+ # split_node(dim=1) -> -> user=multiple
1309
+ # / \ ... / \
1310
+ # getitem getitem getitem getitem -> user=multiple
1311
+ # \ \ / \
1312
+ # other_op
1313
+
1314
+
1315
+ @register_graph_pattern(
1316
+ CallFunction(
1317
+ torch.cat,
1318
+ getitem_split,
1319
+ dim=Ignored(),
1320
+ _users=MULTIPLE,
1321
+ ),
1322
+ pass_dict=split_cat_pass,
1323
+ extra_check=config_flag("split_cat_fx_passes"),
1324
+ )
1325
+ def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
1326
+ if not isinstance(split_sections, (list, tuple)): # Unnormalized split
1327
+ return
1328
+ graph = match.graph
1329
+ split_node = next(node for node in match.nodes if node.target == torch.split)
1330
+ split_input, split_size, split_dim = _get_split_args_default(split_node)
1331
+ # if the cat and split have different dims, return
1332
+ # Find the next users (i.e. users after the getitem)
1333
+ next_users = find_next_users(split_node)
1334
+ for cat_user in next_users:
1335
+ if cat_user.target == torch.cat:
1336
+ cat_dim = get_arg_value(cat_user, 1, "dim") or 0
1337
+ # check that all getitems in the cat_user from the same node
1338
+ # check the input of the cat has all getitem from the split
1339
+ if split_dim != cat_dim or not has_same_parent_node(cat_user):
1340
+ continue
1341
+ # find the index of getitems to be cat
1342
+ indices, idx_to_getitem = [], {}
1343
+ for getitem in cat_user.args[0]: # type: ignore[union-attr]
1344
+ indices.append(getitem.args[1]) # type: ignore[union-attr]
1345
+ idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr]
1346
+ # the gettitems to be merged must be consecutive, otherwise
1347
+ # returned sliced tensor could be wrong
1348
+ if not is_sorted_and_consecutive(indices):
1349
+ continue
1350
+ # case 1: the cat uses all getitems from the split
1351
+ if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type]
1352
+ # replace the users of the cat node to be the input of the split node
1353
+ cat_user.replace_all_uses_with(split_node.args[0])
1354
+ # remove the cat node
1355
+ graph.erase_node(cat_user)
1356
+ counters["inductor"]["cat_mutated"] += 1
1357
+ # case 2: the cat uses some getitems from the split
1358
+ elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type]
1359
+ # check the split dim, and construct the slice tuple
1360
+ start_fused_size = calculate_fused_tensor_size(
1361
+ split_node, list(range(indices[0]))
1362
+ )
1363
+ end_fused_size = start_fused_size + calculate_fused_tensor_size(
1364
+ split_node, indices
1365
+ )
1366
+ slice_list = []
1367
+ for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr]
1368
+ if i != split_dim:
1369
+ slice_list.append(slice(None, None, None))
1370
+ else:
1371
+ slice_list.append(slice(start_fused_size, end_fused_size, None))
1372
+ with graph.inserting_after(split_node):
1373
+ slice_node = graph.call_function(
1374
+ operator.getitem,
1375
+ args=(split_node.args[0], tuple(slice_list)),
1376
+ )
1377
+ cat_user.replace_all_uses_with(slice_node)
1378
+ slice_node.meta.update(cat_user.meta)
1379
+
1380
+ # remove the cat node
1381
+ graph.erase_node(cat_user)
1382
+ counters["inductor"]["cat_mutated"] += 1
1383
+
1384
+
1385
+ # noqa: W605
1386
+ # ############The pattern to be optimized is#########
1387
+ # split_node (dim=1)
1388
+ # / ... \ ... / \
1389
+ # getitem getitem getitem getitem -> user=1
1390
+ # \ /
1391
+ # stack (dim=0) -> user=1, getitems to be consecutive
1392
+ # |
1393
+ # tahn -> user=1
1394
+ # |
1395
+ # unbind (dim=0)
1396
+ # |
1397
+
1398
+ # ################After transformation#############
1399
+ # split_node (dim=1)
1400
+ # / ... / \
1401
+ # getitem getitem getitem -> user=1
1402
+ # |
1403
+ # tahn
1404
+ # |
1405
+ # split
1406
+ # |
1407
+
1408
+
1409
+ @register_graph_pattern(
1410
+ CallFunction(
1411
+ torch.tanh,
1412
+ CallFunction(
1413
+ torch.stack,
1414
+ getitem_split,
1415
+ dim=Ignored(),
1416
+ ),
1417
+ ),
1418
+ pass_dict=merge_getitem_cat_pass,
1419
+ extra_check=config_flag("split_cat_fx_passes"),
1420
+ )
1421
+ @register_graph_pattern(
1422
+ CallFunction(
1423
+ torch.tanh,
1424
+ CallFunction(
1425
+ torch.stack,
1426
+ tensors=getitem_split,
1427
+ dim=Ignored(),
1428
+ ),
1429
+ ),
1430
+ pass_dict=merge_getitem_cat_pass,
1431
+ extra_check=config_flag("split_cat_fx_passes"),
1432
+ )
1433
+ @register_graph_pattern(
1434
+ CallFunction(
1435
+ torch.tanh,
1436
+ CallFunction(
1437
+ torch.stack,
1438
+ getitem_split,
1439
+ Ignored(),
1440
+ ),
1441
+ ),
1442
+ pass_dict=merge_getitem_cat_pass,
1443
+ extra_check=config_flag("split_cat_fx_passes"),
1444
+ )
1445
+ def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int):
1446
+ if not isinstance(split_sections, (list, tuple)): # Unnormalized split
1447
+ return
1448
+ graph = match.graph
1449
+ split_node = next(node for node in match.nodes if node.target == torch.split)
1450
+ split_input, split_size, split_dim = _get_split_args_default(split_node)
1451
+ # Find the next users (i.e. users after the getitem)
1452
+ next_users = find_next_users(split_node)
1453
+ # 'immutable_list' object does not support mutation. Create a new copy of it
1454
+ split_sections = list(split_sections)
1455
+ for user in next_users:
1456
+ # stack user only has one user
1457
+ if user.target == torch.stack:
1458
+ stack_dim = get_arg_value(user, 1, "dim") or 0
1459
+ unbind_user = find_next_users(user)[0]
1460
+ if unbind_user.target != torch.unbind:
1461
+ continue
1462
+ unbind_dim = get_arg_value(unbind_user, 1, "dim") or 0
1463
+ # stack and unbind should have the same dim
1464
+ # check the all getitems in the user from the same node
1465
+ # check all the getitems only has single user
1466
+ if (
1467
+ stack_dim != unbind_dim
1468
+ or not has_same_parent_node(user)
1469
+ or not all(len(arg.users) == 1 for arg in user.args[0]) # type: ignore[union-attr]
1470
+ ):
1471
+ continue
1472
+ # find the index of getitems to be stacked
1473
+ indices = []
1474
+ split_sections_for_unbind = []
1475
+ for arg in user.args[0]: # type: ignore[union-attr]
1476
+ indices.append(arg.args[1]) # type: ignore[union-attr]
1477
+ split_sections_for_unbind.append(split_sections[arg.args[1]]) # type: ignore[union-attr]
1478
+ # the gettitems to be merged must be consecutive, otherwise
1479
+ # returned sliced tensor could be wrong
1480
+ if not is_sorted_and_consecutive(indices):
1481
+ continue
1482
+ # update the arg of stack user, only keep the first getitem
1483
+ user.update_arg(0, user.args[0][0]) # type: ignore[index]
1484
+ # calculate the fused tensor sizes in the indices
1485
+ fused_tensor_size = 0
1486
+ for i in range(len(split_node.args[1])): # type: ignore[arg-type]
1487
+ if i in indices:
1488
+ fused_tensor_size += split_node.args[1][i] # type: ignore[operator, index, assignment]
1489
+ # update the split sections
1490
+ split_sections[indices[0]] = calculate_fused_tensor_size(
1491
+ split_node, indices
1492
+ )
1493
+ # padding others with zeros to keep the same dict size
1494
+ for i in indices[1:]:
1495
+ split_sections[i] = 0
1496
+ # remove all unused indexes in the split_node
1497
+ new_split_sections, index_mapping = remove_zeros(split_sections)
1498
+ with graph.inserting_after(split_node):
1499
+ new_split_node = graph.call_function(
1500
+ torch.split,
1501
+ args=(split_input, split_sections),
1502
+ kwargs={"dim": split_dim},
1503
+ )
1504
+ replace_unbind_with_split = graph.call_function(
1505
+ torch.split,
1506
+ args=(unbind_user.args[0], split_sections_for_unbind),
1507
+ kwargs={"dim": split_dim},
1508
+ )
1509
+ unbind_user.replace_all_uses_with(replace_unbind_with_split)
1510
+ replace_unbind_with_split.meta.update(unbind_user.meta)
1511
+ # remove getitem and split, stack
1512
+ split_node.replace_all_uses_with(new_split_node)
1513
+ new_split_node.meta.update(split_node.meta)
1514
+ # remove all unused getitem nodes
1515
+ to_remove = [unbind_user]
1516
+ # dictionary keys changed during iteration
1517
+ new_split_getitem_nodes = list(new_split_node.users.keys())
1518
+ for getitem_node in new_split_getitem_nodes:
1519
+ if getitem_node.args[1] in indices[1:]:
1520
+ to_remove.append(getitem_node)
1521
+ # update meta data of getitem
1522
+ elif getitem_node.args[1] == indices[0]:
1523
+ user.replace_all_uses_with(getitem_node)
1524
+ getitem_node.meta.update(user.meta)
1525
+ else:
1526
+ # update getitem index for new split node
1527
+ getitem_node.update_arg(1, index_mapping[getitem_node.args[1]])
1528
+ graph.erase_node(split_node)
1529
+ graph.erase_node(user)
1530
+ for getitem_node in to_remove:
1531
+ graph.erase_node(getitem_node)
1532
+ # update the split sections of new split node
1533
+ new_split_node.update_arg(1, new_split_sections)
1534
+ split_node = new_split_node
1535
+ split_sections = new_split_sections
1536
+
1537
+ counters["inductor"]["stack_tahn_unbind_merged"] += 1
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/inductor_prims.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Optional, Sequence
5
+
6
+ import torch
7
+ from torch import _prims, Tensor
8
+
9
+ log = logging.getLogger(__name__)
10
+
11
+
12
+ def make_prim(
13
+ schema: str,
14
+ impl_aten,
15
+ return_type=_prims.RETURN_TYPE.NEW,
16
+ doc: str = "",
17
+ tags: Optional[Sequence[torch.Tag]] = None,
18
+ ):
19
+ def meta(*args, **kwargs):
20
+ return _prims.TensorMeta(impl_aten(*args, **kwargs))
21
+
22
+ return _prims._make_prim(
23
+ schema=schema,
24
+ return_type=return_type,
25
+ meta=meta,
26
+ impl_aten=impl_aten,
27
+ doc=doc,
28
+ tags=tags,
29
+ )
30
+
31
+
32
+ def eager_force_stride(input_tensor: Tensor, stride) -> Tensor:
33
+ if input_tensor.stride() == stride:
34
+ return input_tensor
35
+ new_tensor = input_tensor.clone().as_strided(
36
+ input_tensor.shape,
37
+ stride,
38
+ )
39
+ new_tensor.copy_(input_tensor)
40
+ return new_tensor
41
+
42
+
43
+ # Custom prims used for handling randomness
44
+ seed = make_prim(
45
+ "inductor_seed(Device device) -> Tensor",
46
+ lambda device: torch.randint(2**63 - 1, [], device=device),
47
+ doc="create a fresh seed (one per call) for use with inductor_rand",
48
+ tags=(torch.Tag.nondeterministic_seeded,),
49
+ )
50
+ seeds = make_prim(
51
+ "inductor_seeds(int count, Device device) -> Tensor",
52
+ lambda count, device: torch.randint(2**63 - 1, [count], device=device),
53
+ doc="Horizontal fusion of many inductor_seed() calls",
54
+ tags=(torch.Tag.nondeterministic_seeded,),
55
+ )
56
+ lookup_seed = make_prim(
57
+ # if inductor_lookup_seed changes, update partitioners.py
58
+ "inductor_lookup_seed(Tensor seeds, int index) -> Tensor",
59
+ lambda seeds, index: seeds[index],
60
+ doc="Extract a single seed from the result of inductor_seeds()",
61
+ )
62
+ random = make_prim(
63
+ "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor",
64
+ lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device),
65
+ doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused",
66
+ )
67
+ randint = make_prim(
68
+ "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor",
69
+ lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device),
70
+ doc="torch.randint() using backend-specific RNG that can be fused",
71
+ )
72
+ force_stride_order = make_prim(
73
+ "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor",
74
+ eager_force_stride,
75
+ doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise",
76
+ )
77
+ masked_scatter_with_index = make_prim(
78
+ "inductor_masked_scatter_with_index(Tensor input, Tensor mask, Tensor source_idx, Tensor source) -> Tensor",
79
+ lambda input_tensor, mask, index, source: torch.masked_scatter(
80
+ input_tensor, mask, source
81
+ ),
82
+ doc="masked_scatter with precomputed indices",
83
+ )
84
+ _unsafe_index_put_ = make_prim(
85
+ "_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)",
86
+ lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_(
87
+ self, indices, values, accumulate
88
+ ),
89
+ doc="Unsafe index_put_ (doesn't issue device asserts)",
90
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_case.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import tempfile
3
+ import unittest
4
+
5
+ from torch._dynamo.test_case import (
6
+ run_tests as dynamo_run_tests,
7
+ TestCase as DynamoTestCase,
8
+ )
9
+
10
+ from torch._inductor import config
11
+
12
+
13
+ def run_tests(needs=()):
14
+ dynamo_run_tests(needs)
15
+
16
+
17
+ class TestCase(DynamoTestCase):
18
+ """
19
+ A base TestCase for inductor tests. Enables FX graph caching and isolates
20
+ the cache directory for each test.
21
+ """
22
+
23
+ _stack: contextlib.ExitStack
24
+
25
+ @classmethod
26
+ def setUpClass(cls):
27
+ super().setUpClass()
28
+ cls._stack = contextlib.ExitStack()
29
+ cls._stack.enter_context(config.patch({"fx_graph_cache": True}))
30
+
31
+ @classmethod
32
+ def tearDownClass(cls):
33
+ super().tearDownClass()
34
+ cls._stack.close()
35
+
36
+ def setUp(self):
37
+ super().setUp()
38
+
39
+ # For all tests, mock the tmp directory populated by the inductor
40
+ # FxGraphCache, both for test isolation and to avoid filling disk.
41
+ self._inductor_cache_tmp_dir = tempfile.TemporaryDirectory()
42
+ self._inductor_cache_get_tmp_dir_patch = unittest.mock.patch(
43
+ "torch._inductor.codecache.FxGraphCache._get_tmp_dir"
44
+ )
45
+ mock_get_dir = self._inductor_cache_get_tmp_dir_patch.start()
46
+ mock_get_dir.return_value = self._inductor_cache_tmp_dir.name
47
+
48
+ def tearDown(self):
49
+ super().tearDown()
50
+
51
+ # Clean up the FxGraphCache tmp dir.
52
+ self._inductor_cache_get_tmp_dir_patch.stop()
53
+ self._inductor_cache_tmp_dir.cleanup()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (80.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/jiterator.cpython-311.pyc ADDED
Binary file (7.99 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nccl.cpython-311.pyc ADDED
Binary file (6.46 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/random.cpython-311.pyc ADDED
Binary file (8.62 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/streams.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import sys
3
+ import os
4
+ import io
5
+ import subprocess
6
+ import json
7
+ from functools import lru_cache
8
+ from typing import Any
9
+ from itertools import groupby
10
+ import base64
11
+ import warnings
12
+
13
# Unbounded memoization decorator (spelled with the explicit keyword form).
cache = lru_cache(maxsize=None)
14
+
15
# Names exported by `from torch.cuda._memory_viz import *`.
__all__ = ["format_flamegraph", "segments", "memory", "compare"]
16
+
17
+ def _frame_fmt(f, full_filename=False):
18
+ i = f['line']
19
+ fname = f['filename']
20
+ if not full_filename:
21
+ fname = fname.split('/')[-1]
22
+ func = f['name']
23
+ return f'{fname}:{i}:{func}'
24
+
25
+ @cache
26
+ def _frame_filter(name, filename):
27
+ omit_functions = [
28
+ "unwind::unwind",
29
+ "CapturedTraceback::gather",
30
+ "gather_with_cpp",
31
+ "_start",
32
+ "__libc_start_main",
33
+ "PyEval_",
34
+ "PyObject_",
35
+ "PyFunction_",
36
+ ]
37
+ omit_filenames = [
38
+ "core/boxing",
39
+ "/Register",
40
+ "/Redispatch",
41
+ "pythonrun.c",
42
+ "Modules/main.c",
43
+ "Objects/call.c",
44
+ "Objects/methodobject.c",
45
+ "pycore_ceval.h",
46
+ "ceval.c",
47
+ "cpython/abstract.h",
48
+ ]
49
+ for of in omit_functions:
50
+ if of in name:
51
+ return False
52
+ for of in omit_filenames:
53
+ if of in filename:
54
+ return False
55
+ return True
56
+
57
def _frames_fmt(frames, full_filename=False, reverse=False):
    """Format a list of frame dicts, dropping frames rejected by _frame_filter."""
    ordered = reversed(frames) if reverse else frames
    return [
        _frame_fmt(f, full_filename)
        for f in ordered
        if _frame_filter(f['name'], f['filename'])
    ]
61
+
62
+ def _block_extra_legacy(b):
63
+ if 'history' in b:
64
+ frames = b['history'][0].get('frames', [])
65
+ real_size = b['history'][0]['real_size']
66
+ else:
67
+ real_size = b.get('requested_size', b['size'])
68
+ frames = []
69
+ return frames, real_size
70
+
71
+ def _block_extra(b):
72
+ if 'frames' not in b:
73
+ # old snapshot format made it more complicated to get frames/allocated size
74
+ return _block_extra_legacy(b)
75
+ return b['frames'], b['requested_size']
76
+
77
def format_flamegraph(flamegraph_lines, flamegraph_script=None):
    """Run Brendan Gregg's flamegraph.pl over collapsed-stack text and return the SVG.

    Downloads flamegraph.pl to a per-user path under /tmp on first use.

    Args:
        flamegraph_lines: collapsed-stack text, one ``stack count`` entry per line.
        flamegraph_script: optional path to an existing flamegraph.pl script.

    Returns:
        str: the SVG emitted by flamegraph.pl.
    """
    if flamegraph_script is None:
        flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl'
    if not os.path.exists(flamegraph_script):
        import urllib.request
        print(f"Downloading flamegraph.pl to: {flamegraph_script}")
        urllib.request.urlretrieve(
            'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script)
        subprocess.check_call(['chmod', '+x', flamegraph_script])
    args = [flamegraph_script, '--countname', 'bytes']
    p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8')
    assert p.stdin is not None
    assert p.stdout is not None
    # communicate() feeds stdin and drains stdout concurrently.  The previous
    # write-all-then-read-all sequence could deadlock once the child's stdout
    # filled the pipe buffer while we were still writing stdin.
    result, _ = p.communicate(flamegraph_lines)
    assert p.returncode == 0
    return result
97
+
98
+ def _write_blocks(f, prefix, blocks):
99
+ def frames_fragment(frames):
100
+ if not frames:
101
+ return "<non-python>"
102
+ return ';'.join(_frames_fmt(frames, reverse=True))
103
+ for b in blocks:
104
+ if 'history' not in b:
105
+ frames, accounted_for_size = _block_extra(b)
106
+ f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n')
107
+ else:
108
+ accounted_for_size = 0
109
+ for h in b['history']:
110
+ sz = h['real_size']
111
+ accounted_for_size += sz
112
+ if 'frames' in h:
113
+ frames = h['frames']
114
+ f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n')
115
+ else:
116
+ f.write(f'{prefix};{b["state"]};<no-context> {sz}\n')
117
+ gaps = b['size'] - accounted_for_size
118
+ if gaps:
119
+ f.write(f'{prefix};{b["state"]};<gaps> {gaps}\n')
120
+
121
def segments(snapshot, format_flamegraph=format_flamegraph):
    """Flamegraph of blocks, grouped per stream and per segment."""
    out = io.StringIO()
    for seg in snapshot['segments']:
        _write_blocks(out, f'stream_{seg["stream"]};seg_{seg["address"]}', seg['blocks'])
    return format_flamegraph(out.getvalue())
127
+
128
def memory(snapshot, format_flamegraph=format_flamegraph):
    """Flamegraph of blocks grouped only by stream (segment boundaries ignored)."""
    out = io.StringIO()
    for seg in snapshot['segments']:
        _write_blocks(out, f'stream_{seg["stream"]}', seg['blocks'])
    return format_flamegraph(out.getvalue())
134
+
135
def compare(before, after, format_flamegraph=format_flamegraph):
    """Flamegraph of segments present in only one of two snapshots' segment lists."""
    def seg_key(seg):
        return (seg['address'], seg['total_size'])

    def seg_info(seg):
        return f'stream_{seg["stream"]};seg_{seg["address"]}'

    out = io.StringIO()

    before_keys = {seg_key(seg) for seg in before}
    after_keys = {seg_key(seg) for seg in after}

    print(f'only_before = {[a for a,_ in (before_keys - after_keys)]}')
    print(f'only_after = {[a for a,_ in (after_keys - before_keys)]}')

    for seg in before:
        if seg_key(seg) not in after_keys:
            _write_blocks(out, f'only_before;{seg_info(seg)}', seg['blocks'])

    for seg in after:
        if seg_key(seg) not in before_keys:
            _write_blocks(out, f'only_after;{seg_info(seg)}', seg['blocks'])

    return format_flamegraph(out.getvalue())
159
+
160
+ def _format_size(num):
161
+ # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
162
+ for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
163
+ if abs(num) < 1024.0:
164
+ return f"{num:3.1f}{unit}B"
165
+ num /= 1024.0
166
+ return f"{num:.1f}YiB"
167
+
168
class Bytes:
    """An integer byte count that prints itself human-readably (via _format_size)."""

    def __init__(self, value):
        self.value = value

    def __add__(self, rhs):
        # Adding a plain number yields a new Bytes; the operand is not mutated.
        return Bytes(self.value + rhs)

    def __repr__(self):
        return _format_size(self.value)
177
+
178
def calc_active(seg):
    """Total bytes of blocks in *seg* that are currently active allocations."""
    return sum(blk['size'] for blk in seg['blocks'] if blk['state'] == 'active_allocated')
180
+
181
def _report_free(free_external, free_internal):
    """Format total free memory, noting what fraction is internal fragmentation."""
    total = free_external + free_internal
    suffix = ''
    if total != 0:
        pct = (free_internal / total) * 100
        suffix = f' ({pct:.1f}% internal)'
    return f'{Bytes(total)}{suffix}'
188
+
189
# Granularity of one character cell in the segsum() visualization.
PAGE_SIZE = 1024 * 1024 * 20
legend = f"""\

Legend:
    [a ] - a segment in the allocator
    ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
    a-z: pages filled with a single block's content
    ' ': page is completely free
    *: page if completely full with multiple blocks
    0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
    (X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
"""
201
+
202
def segsum(data):
    r"""Visually reports how the allocator has filled its segments.

    This printout can help debug fragmentation issues since free fragments
    will appear as gaps in this printout. The amount of free space is reported
    for each segment.
    We distinguish between internal free memory which occurs because the
    allocator rounds the allocation size, and external free memory, which are
    the gaps between allocations in a segment.

    Args:
        data: snapshot dictionary created from _snapshot()

    Returns:
        str: the rendered per-segment summary plus the legend.
    """
    # Fix: removed dead locals from the original (`segments = []`,
    # `active_size`, an immediately-overwritten `stream` assignment, and the
    # unused `internal_external` string).  Output is unchanged.
    out = io.StringIO()
    out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
    total_reserved = 0
    total_allocated = 0
    free_external = 0
    free_internal = 0
    for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))):
        total_reserved += seg['total_size']

        seg_free_external = 0
        seg_free_internal = 0
        seg_allocated = 0
        all_ranges = []
        boffset = 0
        for b in seg['blocks']:
            active = b['state'] == 'active_allocated'
            if active:
                _, allocated_size = _block_extra(b)
                all_ranges.append((boffset, allocated_size, True))
                seg_allocated += allocated_size
                # Internal fragmentation: the block is larger than requested.
                seg_free_internal += b['size'] - allocated_size
            else:
                seg_free_external += b['size']

            boffset += b['size']

        total_allocated += seg_allocated
        free_external += seg_free_external
        free_internal += seg_free_internal

        nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1
        occupied = [' ' for _ in range(nseg)]
        frac = [0.0 for _ in range(nseg)]
        for i, (start_, size, active) in enumerate(all_ranges):
            finish_ = (start_ + size)
            start = start_ // PAGE_SIZE
            finish = (finish_ - 1) // PAGE_SIZE + 1
            m = chr(ord('a' if active else 'A') + (i % 26))
            for j in range(start, finish):
                s = max(start_, j * PAGE_SIZE)
                e = min(finish_, (j + 1) * PAGE_SIZE)
                frac[j] += (e - s) / PAGE_SIZE
                if occupied[j] != ' ':
                    # Page shared by multiple blocks: show fill fraction instead.
                    occupied[j] = '0123456789*'[int(frac[j] * 10)]
                else:
                    occupied[j] = m
        body = ''.join(occupied)
        assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size']
        stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else ''
        if seg['total_size'] >= PAGE_SIZE:
            out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, '
                      f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n')
    out.write(f'segments: {len(data["segments"])}\n')
    out.write(f'total_reserved: {Bytes(total_reserved)}\n')
    out.write(f'total_allocated: {Bytes(total_allocated)}\n')
    out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
    out.write(legend)
    assert free_internal + free_external + total_allocated == total_reserved
    return out.getvalue()
278
+
279
def trace(data):
    """Render the allocator trace events in *data* as pseudo-Python source text.

    Each allocation is given a short name ('a', 'b', ..., 'a1', ...) and shown
    as an assignment; frees are shown as `del`; segment events as
    cudaMalloc/cudaFree calls.

    Args:
        data: snapshot dictionary created from torch.cuda.memory._snapshot().

    Returns:
        str: one text section per device that has trace entries.
    """
    out = io.StringIO()

    def format(entries):
        segment_intervals : list = []
        segment_addr_to_name = {}
        allocation_addr_to_name = {}

        free_names : list = []
        next_name = 0

        def _name():
            # Reuse names of freed allocations; otherwise mint 'a'..'z', 'a1', ...
            nonlocal next_name
            if free_names:
                return free_names.pop()
            r, m = next_name // 26, next_name % 26
            next_name += 1
            return f'{chr(ord("a") + m)}{"" if r == 0 else r}'

        def find_segment(addr):
            # Check segments seen via segment_alloc events first, then the
            # segments recorded in the final snapshot state.
            for name, saddr, size in segment_intervals:
                if addr >= saddr and addr < saddr + size:
                    return name, saddr
            for i, seg in enumerate(data['segments']):
                saddr = seg['address']
                size = seg['allocated_size']
                if addr >= saddr and addr < saddr + size:
                    return f'seg_{i}', saddr
            return None, None

        count = 0
        # Fix: the original reused `count` (the enumerate index) as the byte
        # accumulator, so `count += size` was clobbered on every iteration and
        # the final TOTAL MEM figure was meaningless.  Track bytes separately.
        mem = 0
        out.write(f'{len(entries)} entries\n')

        for count, e in enumerate(entries):
            if e['action'] == 'alloc':
                addr, size = e['addr'], e['size']
                n = _name()
                seg_name, seg_addr = find_segment(addr)
                if seg_name is None:
                    seg_name = "MEM"
                    offset = addr
                else:
                    offset = addr - seg_addr
                out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n')
                allocation_addr_to_name[addr] = (n, size, count)
                mem += size
            elif e['action'] == 'free_requested':
                addr, size = e['addr'], e['size']
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'del {name} # {Bytes(size)}\n')
            elif e['action'] == 'free_completed':
                addr, size = e['addr'], e['size']
                mem -= size
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'# free completed for {name} {Bytes(size)}\n')
                # Fix: the map is keyed by address, but the original checked and
                # deleted by `name`, so entries were never removed and names
                # were never recycled.
                if addr in allocation_addr_to_name:
                    free_names.append(name)
                    del allocation_addr_to_name[addr]
            elif e['action'] == 'segment_alloc':
                addr, size = e['addr'], e['size']
                name = _name()
                out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n')
                segment_intervals.append((name, addr, size))
                segment_addr_to_name[addr] = name
            elif e['action'] == 'segment_free':
                addr, size = e['addr'], e['size']
                name = segment_addr_to_name.get(addr, addr)
                out.write(f'cudaFree({name}) # {Bytes(size)}\n')
                # Fix: keyed by address (same bug as free_completed above).
                if addr in segment_addr_to_name:
                    free_names.append(name)
                    del segment_addr_to_name[addr]
            elif e['action'] == 'oom':
                size = e['size']
                free = e['device_free']
                out.write(f'raise OutOfMemoryError() # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
            else:
                out.write(f'{e}\n')
        out.write(f"TOTAL MEM: {Bytes(mem)}")

    for i, d in enumerate(data['device_traces']):
        if d:
            out.write(f'Device {i} ----------------\n')
            format(d)
    return out.getvalue()
366
+
367
+
368
+ _memory_viz_template = r"""
369
+ <!DOCTYPE html>
370
+ <html>
371
+ <head>
372
+ </head>
373
+ <body>
374
+ <script type="module">
375
+ import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
376
+ const local_files = $SNAPSHOT
377
+ add_local_files(local_files, $VIZ_KIND)
378
+ </script>
379
+ </body>
380
+ """
381
+
382
def _format_viz(data, viz_kind, device):
    """Pickle *data*, base64-embed it in the HTML template, and return the page."""
    if device is not None:
        warnings.warn('device argument is deprecated, plots now contain all device')
    payload = pickle.dumps(data)
    # Pad to a multiple of 3 bytes so the base64 text carries no '=' padding.
    payload += b'\x00' * (3 - len(payload) % 3)
    # Encode the buffer with base64
    encoded = base64.b64encode(payload).decode('utf-8')

    files_json = json.dumps([{"name": 'snapshot.pickle', "base64": encoded}])
    page = _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind))
    return page.replace('$SNAPSHOT', files_json)
393
+
394
def trace_plot(data, device=None, plot_segments=False):
    """Generate a visualization over time of the memory usage recorded by the trace as an html file.

    Args:
        data: Memory snapshot as generated from torch.cuda.memory._snapshot()
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
        plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
            Defaults to False.

    Returns:
        str: HTML of visualization
    """
    kind = 'Active Cached Memory Timeline' if plot_segments else 'Active Memory Timeline'
    return _format_viz(data, kind, device)
407
+
408
+
409
def _profile_to_snapshot(profile):
    """Convert a ``torch.profiler`` profile (run with ``profile_memory=True``)
    into a dict shaped like ``torch.cuda.memory._snapshot()`` so the same
    visualization tooling can render it.

    NOTE(review): allocations on non-cuda devices are folded into one synthetic
    extra device at index ``device_count`` — confirm against callers before
    relying on device indices.
    """
    import torch
    from torch.profiler._memory_profiler import Action, TensorKey
    from torch._C._profiler import _EventType
    memory_profile = profile._memory_profile()

    # Map each allocation's TensorKey to the Python call nodes that were on the
    # stack when it happened (innermost first), for use as 'frames' below.
    allocation_stacks = {}
    for event in memory_profile._op_tree.sorted_nodes:
        if event.tag == _EventType.Allocation:
            parent = event.parent
            python_parents = []
            while parent:
                if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
                    python_parents.append(parent)
                parent = parent.parent
            key = TensorKey.from_allocation(event.extra_fields)

            # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
            # key will be None. I should add some way to identify these, I just haven't yet.
            if key and event.extra_fields.alloc_size > 0:
                allocation_stacks[key] = python_parents


    device_count = torch.cuda.device_count()
    # One pseudo-segment per device; index device_count holds non-cuda allocations.
    snapshot = {
        'device_traces': [[] for _ in range(device_count + 1)],
        'segments': [{'device': device,
                      'address': None,
                      'total_size': 0,
                      'stream': 0,
                      'blocks': []} for device in range(device_count + 1)]
    }

    def to_device(device):
        # Map a torch.device to an index into snapshot's per-device lists.
        if device.type == 'cuda':
            return device.index
        else:
            return device_count

    def allocate(size, tensor_key, version, during_trace=True):
        # Record an allocation event; also grow the pseudo-segment bounds.
        device = to_device(tensor_key.device)
        addr = tensor_key.storage.ptr

        seg = snapshot['segments'][device]  # type: ignore[index]
        if seg['address'] is None or seg['address'] > addr:
            seg['address'] = addr
        seg['total_size'] = max(seg['total_size'], addr + size)  # record max addr for now, we will make it the size later
        category = memory_profile._categories.get(tensor_key, version)
        category = category.name.lower() if category is not None else "unknown"
        stack = allocation_stacks.get(tensor_key, ())
        stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
        r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
        if during_trace:
            snapshot['device_traces'][device].append(r)  # type: ignore[index]
        return r

    def free(alloc, device):
        # Emit the matching free_requested/free_completed pair for `alloc`.
        for e in ('free_requested', 'free_completed'):
            snapshot['device_traces'][device].append({'action': e,  # type: ignore[index]
                                                      'addr': alloc['addr'],
                                                      'size': alloc['size'],
                                                      'stream': 0,
                                                      'frames': alloc['frames']})

    # Allocations that are still live; keyed by (TensorKey, version).
    kv_to_elem = {}



    # create the device trace
    for time, action, (tensor_key, version), size in memory_profile.timeline:
        if not isinstance(tensor_key, TensorKey):
            continue
        if action == Action.CREATE:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
        elif action == Action.DESTROY:
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
        elif action == Action.INCREMENT_VERSION:
            # An in-place mutation is modeled as free-then-realloc at version+1.
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
            kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1)
        elif action == Action.PREEXISTING:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False)


    # create the final snapshot state
    blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
                     for (tensor_key, version), event in kv_to_elem.items()]
    for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]):
        seg = snapshot['segments'][device]  # type: ignore[index]
        last_addr = seg['address']
        for _, addr, size, frames in blocks:
            # Represent the space between live allocations as inactive blocks.
            if last_addr < addr:
                seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'})
            seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames})
            last_addr = addr + size
        if last_addr < seg['total_size']:
            seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})

    # Drop devices with no blocks, then convert max-addr into a real size.
    snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]  # type: ignore[attr-defined]
    for seg in snapshot['segments']:  # type: ignore[attr-defined, name-defined, no-redef]
        seg['total_size'] -= seg['address']
        if not seg['blocks']:
            seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})

    return snapshot
513
+
514
def profile_plot(profile, device=None):
    """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.

    Args:
        profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.

    Returns:
        str: HTML of visualization
    """
    return _format_viz(_profile_to_snapshot(profile), 'Active Memory Timeline', device)
526
+
527
+
528
def segment_plot(data: Any, device=None):
    """Render the allocator's segment state history from *data* as an html page."""
    return _format_viz(data, 'Allocator State History', device)
530
+
531
if __name__ == "__main__":
    import os.path
    thedir = os.path.realpath(os.path.dirname(__file__))
    if thedir in sys.path:
        # otherwise we find cuda/random.py as random...
        sys.path.remove(thedir)
    import argparse

    fn_name = 'torch.cuda.memory._snapshot()'
    pickled = f'pickled memory statistics from {fn_name}'
    parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}')

    subparsers = parser.add_subparsers(dest='action')

    def _output(p):
        # Shared -o/--output flag for the flamegraph subcommands.
        p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)')

    description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.'
    stats_a = subparsers.add_parser('stats', description=description)
    stats_a.add_argument('input', help=pickled)

    description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.'
    trace_a = subparsers.add_parser('trace', description=description)
    trace_a.add_argument('input', help=pickled)

    description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)'
    segments_a = subparsers.add_parser('segments', description=description)
    segments_a.add_argument('input', help=pickled)
    _output(segments_a)

    description = "Generate a flamegraph the program locations contributing to CUDA memory usage."
    memory_a = subparsers.add_parser('memory', description=description)
    memory_a.add_argument('input', help=pickled)
    _output(memory_a)

    description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \
        'or removed between two different memorys snapshots.'
    compare_a = subparsers.add_parser('compare', description=description)
    compare_a.add_argument('before', help=pickled)
    compare_a.add_argument('after', help=pickled)
    _output(compare_a)

    plots = (
        ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."),
        ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.")
    )
    for cmd, description in plots:
        trace_plot_a = subparsers.add_parser(cmd, description=description)
        trace_plot_a.add_argument('input', help=pickled)
        help = 'visualize trace from this device (default: chooses the only device with trace info or errors)'
        trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help)
        help = 'path to save the visualization(default: output.html)'
        trace_plot_a.add_argument('-o', '--output', default='output.html', help=help)
        if cmd == "trace_plot":
            help = 'visualize change to segments rather than individual allocations'
            trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help)


    args = parser.parse_args()

    def _read(name):
        # Load a pickled snapshot from a file path, or stdin when name is '-'.
        if name == '-':
            data = pickle.load(sys.stdin.buffer)
        else:
            # Fix: close the file handle when done (the original leaked it).
            with open(name, 'rb') as f:
                data = pickle.load(f)
        if isinstance(data, list):  # segments only...
            data = {'segments': data, 'traces': []}
        return data

    def _write(name, data):
        with open(name, 'w') as f:
            f.write(data)

    if args.action == 'segments':
        data = _read(args.input)
        _write(args.output, segments(data))
    elif args.action == 'memory':
        data = _read(args.input)
        _write(args.output, memory(data))
    elif args.action == 'stats':
        data = _read(args.input)
        print(segsum(data))
    elif args.action == 'trace':
        data = _read(args.input)
        print(trace(data))
    elif args.action == 'compare':
        before = _read(args.before)
        after = _read(args.after)
        _write(args.output, compare(before, after))
    elif args.action == 'trace_plot':
        data = _read(args.input)
        _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments))
    elif args.action == 'segment_plot':
        data = _read(args.input)
        _write(args.output, segment_plot(data, device=args.device))