Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py +363 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/config.py +752 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py +264 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py +1324 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py +1524 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py +2445 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py +1156 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py +643 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py +1428 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py +299 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DynamicLibrary.h +34 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h +324 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h +86 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h +9 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContext.h +9 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h +115 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Exceptions.h +174 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h +85 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h +189 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h +263 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h +20 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h +229 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h +371 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h +18 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h +88 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h +130 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h +37 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h +12 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h +33 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h +394 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h +238 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h +22 -0
.gitattributes
CHANGED
|
@@ -76,3 +76,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/
|
|
| 76 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 77 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
|
| 78 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 76 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 77 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
|
| 78 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
|
| 79 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 filter=lfs diff=lfs merge=lfs -text
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:26a7288b7315d658acab1073f02c4f18cd1d27eeadde102958f0317dad6656e0
|
| 3 |
+
size 150200
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc
ADDED
|
Binary file (8.74 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc
ADDED
|
Binary file (6.13 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc
ADDED
|
Binary file (7.37 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc
ADDED
|
Binary file (67.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc
ADDED
|
Binary file (4.85 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pyre-strict
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from . import config, ir, scheduler
|
| 8 |
+
from .dependencies import WeakDep
|
| 9 |
+
from .utils import tuple_sorted
|
| 10 |
+
|
| 11 |
+
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def sink_waits(
|
| 15 |
+
snodes: List["scheduler.BaseSchedulerNode"],
|
| 16 |
+
) -> List["scheduler.BaseSchedulerNode"]:
|
| 17 |
+
"""
|
| 18 |
+
Greedily moves waits as late as possible (i.e. until we reach a use). Optimal in terms of
|
| 19 |
+
communication overlap.
|
| 20 |
+
"""
|
| 21 |
+
new_order = []
|
| 22 |
+
cur_waits = set()
|
| 23 |
+
for snode in snodes:
|
| 24 |
+
if isinstance(snode.node, ir.Wait):
|
| 25 |
+
cur_waits.add(snode)
|
| 26 |
+
else:
|
| 27 |
+
for wait in tuple_sorted(cur_waits):
|
| 28 |
+
if snode in wait.node_users:
|
| 29 |
+
new_order.append(wait)
|
| 30 |
+
cur_waits.remove(wait)
|
| 31 |
+
new_order.append(snode)
|
| 32 |
+
new_order.extend(tuple_sorted(cur_waits))
|
| 33 |
+
return new_order
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def raise_comms(
|
| 37 |
+
snodes: List["scheduler.BaseSchedulerNode"],
|
| 38 |
+
) -> List["scheduler.BaseSchedulerNode"]:
|
| 39 |
+
"""
|
| 40 |
+
Greedily moves comms as early as possible (i.e. until we reach an input).
|
| 41 |
+
Optimal in terms of communication overlap.
|
| 42 |
+
|
| 43 |
+
TODO: We might want to adjust this in the future to account for memory limitations.
|
| 44 |
+
e.g. when we are compiling FSDP, this heuristics will cause the all-gathers to be prefetched as soon as possible,
|
| 45 |
+
which is the beginning of the forwards pass. We'll have to either do a special pass for FSDP,
|
| 46 |
+
or we'll want to redo this pass with memory considerations so we handle the FSDP case in a general way.
|
| 47 |
+
"""
|
| 48 |
+
new_order_reversed: List["scheduler.BaseSchedulerNode"] = []
|
| 49 |
+
cur_comms: List["scheduler.BaseSchedulerNode"] = []
|
| 50 |
+
for snode in reversed(snodes):
|
| 51 |
+
if isinstance(snode.node, ir.CollectiveKernel):
|
| 52 |
+
cur_comms.append(snode)
|
| 53 |
+
else:
|
| 54 |
+
for comm in cur_comms:
|
| 55 |
+
assert len(comm.inverse_users) > 0
|
| 56 |
+
while len(cur_comms) > 0 and any(
|
| 57 |
+
snode in comm.inverse_users for comm in cur_comms
|
| 58 |
+
):
|
| 59 |
+
comm = cur_comms.pop(0)
|
| 60 |
+
new_order_reversed.append(comm)
|
| 61 |
+
new_order_reversed.append(snode)
|
| 62 |
+
assert len(cur_comms) <= 1
|
| 63 |
+
new_order_reversed.extend(tuple_sorted(cur_comms))
|
| 64 |
+
return new_order_reversed[::-1]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_ancestors(node):
|
| 68 |
+
ancestors = set()
|
| 69 |
+
cur_nodes = [node]
|
| 70 |
+
while len(cur_nodes) > 0:
|
| 71 |
+
new_nodes = []
|
| 72 |
+
for node in cur_nodes:
|
| 73 |
+
for inp in node.inverse_users:
|
| 74 |
+
if inp not in ancestors:
|
| 75 |
+
ancestors.add(inp)
|
| 76 |
+
new_nodes.append(inp)
|
| 77 |
+
cur_nodes = new_nodes
|
| 78 |
+
return ancestors
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_descendants(node):
|
| 82 |
+
descendants = set()
|
| 83 |
+
cur_nodes = [node]
|
| 84 |
+
while len(cur_nodes) > 0:
|
| 85 |
+
new_nodes = []
|
| 86 |
+
for node in cur_nodes:
|
| 87 |
+
for inp in node.node_users:
|
| 88 |
+
if inp not in descendants:
|
| 89 |
+
descendants.add(inp)
|
| 90 |
+
new_nodes.append(inp)
|
| 91 |
+
cur_nodes = new_nodes
|
| 92 |
+
return descendants
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def decide_global_ordering_of_comms(nodes: List["scheduler.BaseSchedulerNode"]):
|
| 96 |
+
"""
|
| 97 |
+
Decide global ordering of comms, by just enforcing the ordering that's in the input graph
|
| 98 |
+
(might not be the same ordering as the eager mode program).
|
| 99 |
+
TODO: Come up with a better approach
|
| 100 |
+
"""
|
| 101 |
+
comm_nodes = [n for n in nodes if isinstance(n.node, ir.CollectiveKernel)]
|
| 102 |
+
for i in range(1, len(comm_nodes)):
|
| 103 |
+
# Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
|
| 104 |
+
comm_nodes[i].add_fake_dep(WeakDep(comm_nodes[i - 1].get_name()))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def assert_no_comm_nodes(snodes: List["scheduler.BaseSchedulerNode"]) -> None:
|
| 108 |
+
assert not any(isinstance(snode.node, ir.CollectiveKernel) for snode in snodes)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def estimate_op_runtime(snode: "scheduler.BaseSchedulerNode") -> float:
|
| 112 |
+
"""
|
| 113 |
+
Returns estimated op runtime in nanoseconds (ns)
|
| 114 |
+
"""
|
| 115 |
+
if config.estimate_op_runtime == "default":
|
| 116 |
+
runtime = snode.get_estimated_runtime()
|
| 117 |
+
else:
|
| 118 |
+
assert callable(config.estimate_op_runtime)
|
| 119 |
+
runtime = config.estimate_op_runtime(snode)
|
| 120 |
+
return runtime
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def reorder_compute_for_overlap(
|
| 124 |
+
snodes: List["scheduler.BaseSchedulerNode"],
|
| 125 |
+
) -> List["scheduler.BaseSchedulerNode"]:
|
| 126 |
+
"""
|
| 127 |
+
Decides a global ordering of all compute and communication nodes,
|
| 128 |
+
assuming that we already have a global ordering of communication nodes.
|
| 129 |
+
|
| 130 |
+
Overall scheduling procedure is:
|
| 131 |
+
Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
|
| 132 |
+
that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N.
|
| 133 |
+
Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
|
| 134 |
+
Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
|
| 135 |
+
We prioritize compute nodes that are needed sooner.
|
| 136 |
+
Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
|
| 137 |
+
Step 4: We schedule comm N + 1.
|
| 138 |
+
Repeat this for subsequent comm nodes.
|
| 139 |
+
"""
|
| 140 |
+
final_order = []
|
| 141 |
+
|
| 142 |
+
comm_nodes = []
|
| 143 |
+
for snode in snodes:
|
| 144 |
+
if isinstance(snode.node, ir.CollectiveKernel):
|
| 145 |
+
comm_nodes.append(snode)
|
| 146 |
+
if len(comm_nodes) == 0:
|
| 147 |
+
# if there is no comm nodes, return the current order
|
| 148 |
+
return snodes
|
| 149 |
+
|
| 150 |
+
comm_ancestors = {node: get_ancestors(node) for node in comm_nodes}
|
| 151 |
+
comm_descendants = {node: get_descendants(node) for node in comm_nodes}
|
| 152 |
+
|
| 153 |
+
indeg = dict.fromkeys(snodes, 0)
|
| 154 |
+
for snode in snodes:
|
| 155 |
+
for user in snode.node_users:
|
| 156 |
+
if user in indeg:
|
| 157 |
+
indeg[user] += 1
|
| 158 |
+
ready_to_schedule_nodes = {node for node in snodes if indeg[node] == 0}
|
| 159 |
+
|
| 160 |
+
unscheduled_nodes = set()
|
| 161 |
+
unscheduled_nodes = set(snodes)
|
| 162 |
+
|
| 163 |
+
def schedule_node(snode):
|
| 164 |
+
"""
|
| 165 |
+
Schedule a single node.
|
| 166 |
+
"""
|
| 167 |
+
assert snode in unscheduled_nodes
|
| 168 |
+
assert snode in ready_to_schedule_nodes
|
| 169 |
+
ready_to_schedule_nodes.remove(snode)
|
| 170 |
+
unscheduled_nodes.remove(snode)
|
| 171 |
+
final_order.append(snode)
|
| 172 |
+
for user in tuple_sorted(snode.node_users):
|
| 173 |
+
if user in indeg:
|
| 174 |
+
indeg[user] -= 1
|
| 175 |
+
if indeg[user] == 0:
|
| 176 |
+
ready_to_schedule_nodes.add(user)
|
| 177 |
+
|
| 178 |
+
def schedule_nodes(snodes):
|
| 179 |
+
"""
|
| 180 |
+
Schedules all nodes in `snodes` in an arbitrary topologically valid order.
|
| 181 |
+
"""
|
| 182 |
+
all_nodes = set(snodes)
|
| 183 |
+
assert all(node in unscheduled_nodes for node in all_nodes)
|
| 184 |
+
while len(all_nodes) > 0:
|
| 185 |
+
# NOTE: since model graph is always a DAG and does not have circular dependency inside,
|
| 186 |
+
# there should be at least one node that is a "free node" (i.e. indeg == 0),
|
| 187 |
+
# hence infinite loop is not possible. But we check here just to be safe.
|
| 188 |
+
progress = False
|
| 189 |
+
for node in tuple_sorted(all_nodes):
|
| 190 |
+
if node in ready_to_schedule_nodes:
|
| 191 |
+
schedule_node(node)
|
| 192 |
+
all_nodes.remove(node)
|
| 193 |
+
progress = True
|
| 194 |
+
if not progress:
|
| 195 |
+
raise Exception(
|
| 196 |
+
"Unable to find a free node (indeg == 0). This is an impossible state to reach. "
|
| 197 |
+
"Please report a bug to PyTorch."
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# First, schedule all compute nodes that are required by first comm node,
|
| 201 |
+
# as well as the first comm node itself.
|
| 202 |
+
assert len(comm_nodes) > 0
|
| 203 |
+
schedule_nodes(
|
| 204 |
+
list(comm_ancestors[comm_nodes[0]]) + [comm_nodes[0]],
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
rolled_over_compute_cost = 0
|
| 208 |
+
for idx in range(1, len(comm_ancestors)):
|
| 209 |
+
# Step 1: Given that we've currently scheduled comm `idx-1`, we now schedule
|
| 210 |
+
# all compute nodes that are required for comm `idx` but do not depend on comm `idx-1`,
|
| 211 |
+
# to run at the same time with comm `idx-1`.
|
| 212 |
+
needed_by_next_comm_and_ready_compute_nodes = unscheduled_nodes & (
|
| 213 |
+
comm_ancestors[comm_nodes[idx]] - comm_descendants[comm_nodes[idx - 1]]
|
| 214 |
+
)
|
| 215 |
+
assert_no_comm_nodes(needed_by_next_comm_and_ready_compute_nodes)
|
| 216 |
+
|
| 217 |
+
total_compute_runtime_cost = rolled_over_compute_cost + sum(
|
| 218 |
+
[
|
| 219 |
+
estimate_op_runtime(node)
|
| 220 |
+
for node in needed_by_next_comm_and_ready_compute_nodes
|
| 221 |
+
]
|
| 222 |
+
)
|
| 223 |
+
prev_comm_runtime_cost = estimate_op_runtime(comm_nodes[idx - 1])
|
| 224 |
+
schedule_nodes(tuple_sorted(needed_by_next_comm_and_ready_compute_nodes))
|
| 225 |
+
|
| 226 |
+
# Step 2: If all those compute nodes are sufficient to overlap comm `idx-1`, we're done.
|
| 227 |
+
# Otherwise, we now need to look elsewhere to find compute that overlaps with comm `idx`.
|
| 228 |
+
# We prioritize compute nodes that are needed sooner.
|
| 229 |
+
step1_runtime_cost = total_compute_runtime_cost
|
| 230 |
+
if step1_runtime_cost >= prev_comm_runtime_cost:
|
| 231 |
+
pass
|
| 232 |
+
else:
|
| 233 |
+
# Find all ready to schedule compute nodes that do not depend on comm `idx-1`.
|
| 234 |
+
ready_to_schedule_compute_nodes = tuple_sorted(
|
| 235 |
+
ready_to_schedule_nodes - comm_descendants[comm_nodes[idx - 1]]
|
| 236 |
+
)
|
| 237 |
+
assert_no_comm_nodes(ready_to_schedule_compute_nodes)
|
| 238 |
+
|
| 239 |
+
def earliest_comm_descendant(node):
|
| 240 |
+
for idx in range(len(comm_nodes)):
|
| 241 |
+
if node in comm_ancestors[comm_nodes[idx]]:
|
| 242 |
+
return idx
|
| 243 |
+
return len(comm_nodes)
|
| 244 |
+
|
| 245 |
+
# Prioritize compute nodes that are needed sooner.
|
| 246 |
+
ready_to_schedule_compute_nodes = sorted(
|
| 247 |
+
ready_to_schedule_compute_nodes, key=earliest_comm_descendant
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
for snode in ready_to_schedule_compute_nodes:
|
| 251 |
+
if total_compute_runtime_cost >= prev_comm_runtime_cost:
|
| 252 |
+
# If accumulated compute runtime cost is greater than comm `idx-1` runtime cost,
|
| 253 |
+
# it means we have maximized overlap for comm `idx-1`, and hence we stop looking
|
| 254 |
+
# for more compute to schedule.
|
| 255 |
+
break
|
| 256 |
+
compute_runtime_cost = estimate_op_runtime(snode)
|
| 257 |
+
# If we're not able to leverage more than half of this
|
| 258 |
+
# node's compute to overlap, we skip it.
|
| 259 |
+
# TODO: Smarter heuristics here
|
| 260 |
+
if (
|
| 261 |
+
prev_comm_runtime_cost - total_compute_runtime_cost
|
| 262 |
+
) <= compute_runtime_cost / 2:
|
| 263 |
+
continue
|
| 264 |
+
schedule_node(snode)
|
| 265 |
+
total_compute_runtime_cost += compute_runtime_cost
|
| 266 |
+
rollable_compute_cost = total_compute_runtime_cost - step1_runtime_cost
|
| 267 |
+
|
| 268 |
+
# Step 3: We schedule the compute nodes dependent on comm `idx-1` and required for comm `idx`.
|
| 269 |
+
needed_by_next_comm_nodes = unscheduled_nodes & comm_ancestors[comm_nodes[idx]]
|
| 270 |
+
schedule_nodes(list(needed_by_next_comm_nodes))
|
| 271 |
+
|
| 272 |
+
# Step 4: We schedule comm `idx`.
|
| 273 |
+
schedule_nodes([comm_nodes[idx]])
|
| 274 |
+
|
| 275 |
+
is_prev_comm_blocking_next_comm = len(needed_by_next_comm_nodes) > 0
|
| 276 |
+
# The idea here is that if there are no compute nodes from Step 3
|
| 277 |
+
# (i.e. if prev comm is not blocking next comm), we can roll over the compute nodes
|
| 278 |
+
# in Step 2 to overlap with the next comm, since they're not required to finish
|
| 279 |
+
# before the next comm starts.
|
| 280 |
+
if is_prev_comm_blocking_next_comm:
|
| 281 |
+
rolled_over_compute_cost = 0
|
| 282 |
+
else:
|
| 283 |
+
rolled_over_compute_cost = rollable_compute_cost # type: ignore[assignment]
|
| 284 |
+
|
| 285 |
+
schedule_nodes(unscheduled_nodes)
|
| 286 |
+
return final_order
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def node_summary(snode):
|
| 290 |
+
detail = ""
|
| 291 |
+
if isinstance(snode.node, ir.ExternKernelOut):
|
| 292 |
+
detail = f" ({snode.node.python_kernel_name})"
|
| 293 |
+
out_tensor_info = ""
|
| 294 |
+
if (
|
| 295 |
+
hasattr(snode.node, "layout")
|
| 296 |
+
and hasattr(snode.node.layout, "size")
|
| 297 |
+
and hasattr(snode.node.layout, "stride")
|
| 298 |
+
):
|
| 299 |
+
out_tensor_info = (
|
| 300 |
+
f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})"
|
| 301 |
+
)
|
| 302 |
+
node_name = ""
|
| 303 |
+
if hasattr(snode.node, "name"):
|
| 304 |
+
node_name = snode.node.name
|
| 305 |
+
return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def visualize_overlap(order):
|
| 309 |
+
total_est_runtime: float = 0.0
|
| 310 |
+
cur_comm_node = None
|
| 311 |
+
for snode in order:
|
| 312 |
+
if cur_comm_node is None:
|
| 313 |
+
if isinstance(snode.node, ir.CollectiveKernel):
|
| 314 |
+
total_est_runtime += estimate_op_runtime(snode)
|
| 315 |
+
cur_comm_node = snode.node
|
| 316 |
+
elif isinstance(snode.node, ir.Wait):
|
| 317 |
+
raise Exception(
|
| 318 |
+
"Wait is not expected when there is no collective running"
|
| 319 |
+
)
|
| 320 |
+
else: # exposed compute op
|
| 321 |
+
total_est_runtime += estimate_op_runtime(snode)
|
| 322 |
+
overlap_log.debug(f"{node_summary(snode)}") # noqa: G004
|
| 323 |
+
else: # cur_comm_node is not None
|
| 324 |
+
if isinstance(snode.node, ir.CollectiveKernel):
|
| 325 |
+
raise Exception(
|
| 326 |
+
"Found two collectives running at the same time. "
|
| 327 |
+
"`visualize_overlap` needs to be updated to handle this case"
|
| 328 |
+
)
|
| 329 |
+
elif isinstance(snode.node, ir.Wait): # end of this comm op
|
| 330 |
+
overlap_log.debug(f"{node_summary(snode)}") # noqa: G004
|
| 331 |
+
cur_comm_node = None
|
| 332 |
+
else: # overlapped compute op
|
| 333 |
+
overlap_log.debug(f"| {node_summary(snode)}") # noqa: G004
|
| 334 |
+
overlap_log.debug(
|
| 335 |
+
f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def reorder_compute_and_comm_for_overlap(
|
| 340 |
+
snodes: List["scheduler.BaseSchedulerNode"],
|
| 341 |
+
) -> List["scheduler.BaseSchedulerNode"]:
|
| 342 |
+
order = snodes
|
| 343 |
+
for p in config.reorder_for_compute_comm_overlap_passes:
|
| 344 |
+
if isinstance(p, str) and p in globals():
|
| 345 |
+
p = globals()[p] # it is a builtin pass
|
| 346 |
+
if torch.distributed.get_rank() == 0:
|
| 347 |
+
overlap_log.debug(
|
| 348 |
+
f"==== Visualize overlap before reordering pass {p} ====" # noqa: G004
|
| 349 |
+
)
|
| 350 |
+
try:
|
| 351 |
+
visualize_overlap(order)
|
| 352 |
+
except Exception as e:
|
| 353 |
+
overlap_log.debug(str(e))
|
| 354 |
+
order = p(order) # type: ignore[operator]
|
| 355 |
+
if torch.distributed.get_rank() == 0:
|
| 356 |
+
overlap_log.debug(
|
| 357 |
+
f"==== Visualize overlap after reordering pass {p} ====" # noqa: G004
|
| 358 |
+
)
|
| 359 |
+
try:
|
| 360 |
+
visualize_overlap(order)
|
| 361 |
+
except Exception as e:
|
| 362 |
+
overlap_log.debug(str(e))
|
| 363 |
+
return order
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/config.py
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os # noqa: C101
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def is_fbcode():
    """Return True when running inside Meta's internal (fbcode) build.

    OSS builds of PyTorch expose ``torch.version.git_version``; internal
    builds do not, so the attribute's absence is used as the indicator.
    """
    is_oss_build = hasattr(torch.version, "git_version")
    return not is_oss_build
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# add some debug printouts
debug = False

# add inf and NaN checkers
debug_check_inf_and_nan = False

# Whether to disable a progress bar for autotuning
disable_progress = True

# Whether to enable printing the source code for each future
verbose_progress = False

# use fx aot graph codegen cache
fx_graph_cache = os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE") == "1"

# use cpp wrapper instead of python wrapper
cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"

# codegen cpp wrapper code in an ABI compatible mode
# (defaults to on for fbcode builds, off for OSS)
abi_compatible = (
    os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1"
)

# which C shim ABI version to target (fbcode defaults to "1", OSS to "2")
c_shim_version = os.environ.get(
    "TORCHINDUCTOR_C_SHIM_VERSION", "1" if is_fbcode() else "2"
)

# dead code elimination
dce = False

# assume weight tensors are fixed size
static_weight_shapes = True

# put correctness assertions in generated code
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"

# enable loop reordering based on input orders
pick_loop_orders = True

# reuse a kernel input as the output
inplace_buffers = True

# reuse a buffer for an unrelated purpose
allow_buffer_reuse = True

# Enable pooled allocations for non-output tensors
memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1"

# How to organize memory under memory_planning=True:
# - "none": do not try to pool storage, just reuse
# - "intermediates": all non-outputs share storage, outputs each get unique storage
# - "outputs": two pools, one for intermediates (freed on return) and one for outputs
# - "combined": a single pool for both intermediates and outputs
memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates")

# codegen benchmark harness
benchmark_harness = True

# fuse pointwise into templates
epilogue_fusion = True

# do epilogue fusions before other fusions
epilogue_fusion_first = False

# enable pattern match+replace optimizations
pattern_matcher = True

# register custom graph optimization pass hook. so far, pre/post passes are
# only applied before/after pattern_matcher in post_grad_passes.
#
# def my_custom_pre_pass(graph: torch.fx.graph.Graph):
#     # my custom graph optimization pass
#     ...
#
# def my_custom_post_pass(graph: torch.fx.graph.Graph):
#     # my custom graph optimization pass
#     ...
#
# torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass
# torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass
post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None

# Registers a custom pregrad pass. Note that the pre-grad IR is 1.
# non-functional, 2. non-normalized, and 3. prone to change. Ideally we should
# use post-grad passes.
pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None

# Optimize away split cat patterns (Experimental)
split_cat_fx_passes = True

# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability.
efficient_conv_bn_eval_fx_passes = False

# Enable predispatch aten IR for export
is_predispatch = False

# Deprecated
group_fusion = False

# Deprecated
batch_fusion = True

# Pre grad group/batch fusion and options in order, set to empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions.
pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {
    "batch_linear": {},
    "batch_linear_lhs": {},
    "batch_layernorm": {},
    "batch_tanh": {},
    "batch_relu": {},
    "batch_sigmoid": {},
}

# Post grad group/batch fusion and options, set to empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}

# enable reordering pass for improving memory locality
reorder_for_locality = True

# Scale down RBLOCK for better occupancy
dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1"

# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32
# but the mul gets fused with other pointwise ops instead.
force_fuse_int_mm_with_mul = False

# for pattern torch.mm(a, b.to(dtype)) with cuda tensors,
# enable torch._inductor.kernel.mm.tuned_mixed_mm fused kernel.
# Autotune will compare perf with normal cast->then->mm option
use_mixed_mm = False

# enable runtime numeric check for pre/post grad fx passes
# floating point provides limited accuracy (about 7 decimal digits for single precision
# floating point numbers, about 16 decimal digits for double precision floating point numbers)
# according to PyTorch documentation.
# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
fx_passes_numeric_check: Dict[str, Any] = {
    "pre_grad": False,
    "precision": 1e-4,
    "num_iterations": 1,
    "requires_optimizer": True,
}

# for pattern torch.mm(a, b.to(dtype)) with cuda tensors, always use
# torch._inductor.kernel.mm.tuned_mixed_mm's fused kernel.
# Autotune will not compare with normal cast->then->mm option.
# (if force_mixed_mm is true, the use_mixed_mm flag will be ignored)
force_mixed_mm = False

# enable reordering pass for increasing overlap between compute and communication
reorder_for_compute_comm_overlap = False

# passes (in execution order) for increasing overlap between compute and communication
# for built-in passes, use string name; for user-defined passes, pass in the function handle
reorder_for_compute_comm_overlap_passes = [
    "reorder_compute_for_overlap",
    "sink_waits",
    "raise_comms",
]

# runtime estimation function for ops
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default"

# unit: GB/s, uni-directional P2P bandwidth per card
# default value is NVLink
intra_node_bw = 300

# unit: GB/s, uni-directional P2P bandwidth per node
# default value is InfiniBand
inter_node_bw = 25

# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"

# enable slow autotuning passes to select pointwise/reductions algorithms
max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"

# enable slow autotuning passes to select gemm algorithms
max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"

# enable autotune local cache
use_autotune_local_cache = True

# enable autotune remote cache
# NOTE(review): this env var uses a TORCH_INDUCTOR_ prefix, unlike the
# TORCHINDUCTOR_ prefix used elsewhere in this file — kept as-is since
# renaming it would break existing deployments.
use_autotune_remote_cache = (
    os.environ.get("TORCH_INDUCTOR_AUTOTUNE_REMOTE_CACHE") == "1"
)

# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations
# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations
# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure
# that triton does not use TF32 wherever cublas would not use TF32
force_same_precision = (
    True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1"
)
# Specify candidate backends for gemm autotune.
# Possible choices are combinations of: ATen, Triton, CUTLASS.
# ATen: default Pytorch ATen kernels.
# Triton: Triton templates defined in torch inductor.
# CUTLASS: Cutlass templates and kernels.
max_autotune_gemm_backends = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON"
).upper()

# the value used as a fallback for the unbacked SymInts
# that can appear in the input shapes (e.g., in autotuning)
unbacked_symint_fallback = 8192

# enable searching global and local cache regardless of `max_autotune`
search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"

# NOTE(review): appears to dump args for debugging when set — confirm the
# consumer before relying on it.
save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"

# We will disable creating subprocess for autotuning if this is False
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"

# If autotuning in subprocess, whether to use multiple devices
autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"

# coordinate-descent style tuning knobs (presumably consumed by the kernel
# autotuner — confirm in triton_heuristics)
coordinate_descent_tuning = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
)
coordinate_descent_check_all_directions = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1"
)
coordinate_descent_search_radius = int(
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1")
)

# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions
layout_opt_default = "1" if not torch.version.hip else "0"
layout_optimization = (
    os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1"
)

# NOTE(review): presumably applies layout optimization unconditionally,
# bypassing heuristics — confirm at the use site.
force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1"


# Whether to keep the output strides the same as eager after layout optimization.
keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"

# Enabling this will let compiler print warning messages if a generated triton
# kernel has inputs with mixed layouts. This is helpful for perf debugging
# since kernel with mixed layout inputs may run much slower than one whose inputs
# have uniform layouts.
warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"

# control store vs recompute heuristic
# For fanouts, rematerialization can lead to exponential blowup. So, have
# smaller threshold
realize_reads_threshold = 4
realize_opcount_threshold = 30

# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold = 8

# fallback to eager for random/dropout, this is slow but useful for debugging
fallback_random = False

# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks = True

# fuse even in cases without common reads
aggressive_fusion = False

# For each fused kernel in the wrapper, comment with the nodes that get fused.
# Useful for debugging fusion.
debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")

# how many nodes to allow into a single fusion
max_fusion_size = 64

# max number of inputs to generate cat as a pointwise op with masked loads
max_pointwise_cat_inputs = 8

# replace small reductions with pointwise, disable with `= 1`
unroll_reductions_threshold = 8

# Add extra comments to output code (causes compile cache misses)
comment_origin = False

# Convert 1x1 convs into matmuls
conv_1x1_as_mm = False

# Enable split reductions for better utilization when the dimension
# being reduced over is large (by splitting it)
split_reductions = True

# NOTE(review): looks like this enables per-kernel benchmarking codegen —
# confirm in the wrapper codegen.
benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"

# Enable constant and index_expr folding
constant_and_index_propagation = True

# we always add constants into graph.constants without
# performing any constant-inlining optimization
always_keep_tensor_constants = False

# assert that indirect indexing does not read / write out of bounds
assert_indirect_indexing = True

# constant folding on the joint graph
joint_graph_constant_folding = True

# Enable indirect_indexing asserts for decompositions and lowerings
debug_index_asserts = False

# warnings intended for PyTorch developers, disable for point releases
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings = is_fbcode() or is_nightly_or_source

# The multiprocessing start method to use for inductor workers in the codecache.
# TODO: fork is not safe in a multithreaded environment, we should evaluate changing
# the default to spawn.
worker_start_method = "fork"
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def decide_compile_threads():
    """
    Decide how many compile worker threads to use, in precedence order:

    1. TORCHINDUCTOR_COMPILE_THREADS overrides everything. Set it to 1 to
       disable async compiling (e.g. to make pdb usable).
    2. Force 1 on win32 and on fbcode builds.
    3. Otherwise derive the count from the usable CPU cores, capped at 32.
    """
    env_override = os.environ.get("TORCHINDUCTOR_COMPILE_THREADS")
    if env_override is not None:
        return int(env_override)

    if sys.platform == "win32" or is_fbcode():
        return 1

    # sched_getaffinity respects CPU masks (containers, taskset); fall back
    # to the raw core count on platforms where it is unavailable.
    if hasattr(os, "sched_getaffinity"):
        cpu_count = len(os.sched_getaffinity(0))
    else:
        cpu_count = os.cpu_count()
    assert cpu_count
    return min(32, cpu_count)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
compile_threads = decide_compile_threads()

# gemm autotuning global cache dir
if is_fbcode():
    # libfb is only available in Meta-internal builds; this branch never
    # runs in OSS (guarded by is_fbcode()).
    from libfb.py import parutil

    try:
        # The cache ships inside the package archive; resolve it relative
        # to this package's path when available.
        if __package__:
            global_cache_dir = parutil.get_dir_path(
                os.path.join(__package__.replace(".", os.sep), "fb/cache")
            )
        else:
            global_cache_dir = parutil.get_dir_path("fb/cache")
    except ValueError:
        # No packaged cache resource found — disable the global cache.
        global_cache_dir = None
else:
    global_cache_dir = None
|
| 373 |
+
|
| 374 |
+
# If kernel is fused, the name is generated from the origin node op names
# for larger kernels limit this
kernel_name_max_ops = 10

# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"

# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"

# Mark the wrapper call in PyTorch profiler
profiler_mark_wrapper_call = False

# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for
# every intermediate for which we can correlate it with an intermediate
# from the original FX graph
generate_intermediate_hooks = False

# Populate traceback field on IRNode; good for debugging why origin_node is
# not populated, or finding out where an IRNode was constructed
debug_ir_traceback = False

# used for debugging to make sure config is properly set
_raise_error_for_testing = False

# TORCHINDUCTOR_PROFILE: "" disables bandwidth profiling, "1" profiles
# everything, any other value is used as a regex selecting what to profile.
_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth = _profile_var != ""
profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var
# Specify a file where we print out the profiling results.
# None means we do not dump results to a file.
profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None)

# TODO: remove later
disable_cpp_codegen = False


# Freezing will attempt to inline weights as constants in optimization
# and run constant folding and other optimizations on them. After freezing, weights
# can no longer be updated.
freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"

# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
# of potentially keeping multiple copies of weights.
freezing_discard_parameters: bool = False

# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
# should be run with this flag both on and off to make sure we have coverage.
allow_stack_allocation: bool = (
    os.environ.get("TORCHINDUCTOR_STACK_ALLOCATION", "1") == "1"
)

# Enables an alternate DSO interface (the "minimal ArrayRef interface") intended
# to maximize performance for use cases that it can accommodate at the expense of
# generality. In brief:
# - inputs and outputs are ArrayRefTensor<T> (note that strides are required, but the
#   tensor must be contiguous)
# - constant handling is unchanged because it is not a per-inference-iteration bottleneck
#
# When the DSO is generated in this mode, the usual interface will also be supported,
# but performance for that interface may be degraded.
use_minimal_arrayref_interface: bool = False

# decompose some memory bound matmul/bmm to mul
decompose_mem_bound_mm: bool = False
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# config specific to codegen/cpp.py
|
| 441 |
+
class cpp:
    """Config specific to the C++ (CPU) codegen backend (codegen/cpp.py)."""

    # set to torch.get_num_threads()
    threads = -1

    # Do not generate loops when the condition doesn't hold, like:
    # for(long i0=4096; i0<4096; i0+=1)
    no_redundant_loops = True

    # Assume number of threads is dynamic, don't specialize thread number.
    # Kernels don't recompile on thread number changes with this flag on.
    # For single-threaded workload, turning it on would incur a slight
    # performance degradation.
    dynamic_threads = False

    # SIMD vector width; None means autodetect.
    simdlen: Optional[int] = None
    # minimum chunk size per thread when parallelizing loops
    min_chunk_size = 4096
    # candidate C++ compilers, tried in order; None means "try the
    # conda-provided gcc12 first".
    cxx = (
        None,  # download gcc12 from conda-forge if conda is installed
        # "g++-12",
        # "g++-11",
        # "g++-10",
        # "clang++",
        os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
        # "g++.par",
    )
    # Allow kernel performance profiling via PyTorch profiler
    enable_kernel_profile = False

    # enable weight prepacking to get a better performance; may lead to large memory footprint
    weight_prepack = True

    # Inject a bug into our relu implementation; useful for testing our repro
    # extraction and minification functionality.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY: Optional[str] = None
    inject_log1p_bug_TESTING_ONLY: Optional[str] = None

    # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise,
    # force usage as specified, without testing.
    vec_isa_ok: Optional[bool] = None

    # similar to config.triton.descriptive_names
    descriptive_names = "original_aten"

    # how many nodes to allow into a single horizontal fusion
    max_horizontal_fusion_size = 16

    # Make scatter_reduce fallback when reduce is sum to avoid performance regression
    # using atomic_add.
    fallback_scatter_reduce_sum = True

    # Use funsafe-math-optimizations when compiling
    enable_unsafe_math_opt_flag = False

    # Use ffp-contract when compiling
    enable_floating_point_contract_flag = False
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
# config specific to codegen/triton.py
|
| 500 |
+
class triton:
    """Config specific to the Triton (GPU) codegen backend (codegen/triton.py)."""

    # Use cudagraphs on output code
    cudagraphs = False

    # Use cudagraph trees for memory pooling if `cudagraphs` is True
    cudagraph_trees = True

    # assertions not on the fast path, steady state
    slow_path_cudagraph_asserts = True

    # TODO - need to debug why this prevents cleanup
    cudagraph_trees_history_recording = False

    # assertions on the fast path
    fast_path_cudagraph_asserts = False

    # skip warmup for cudagraph trees
    skip_cudagraph_warmup = False

    # Synchronize before and after every compiled graph.
    debug_sync_graph = False

    # Synchronize after every kernel launch, to help pinpoint bugs
    debug_sync_kernel = False

    # Always load full blocks (rather than broadcasting inside the block)
    dense_indexing = False

    # limit tiling dimensions
    max_tiles = 2

    # use triton.autotune for pointwise ops with complex layouts
    # this should only be disabled for debugging/testing
    autotune_pointwise = True

    # max autotune gemm with cublasLt
    autotune_cublasLt = True

    # should we stop a fusion to allow better tiling?
    tiling_prevents_pointwise_fusion = True
    tiling_prevents_reduction_fusion = True

    # should we give different names to kernels
    # Note: This is orthogonal to descriptive_names - this is deciding whether
    # our triton kernel names should all be `triton_` (to maximize caching) or
    # whether they should be unique.
    unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"

    # should we put op names in kernel names
    # False: No special names (just triton__1, triton__2, etc.)
    # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
    # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
    # "inductor_node": Maps to the node name in the FX graph passed to Inductor
    descriptive_names = "original_aten"

    # use alternate codegen for smaller reductions
    persistent_reductions = (
        os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
    )

    # 0/False: disable
    # 1/True: enable, use tuning to pick between different subkernels
    # 2: enable, force using persistent reduction (for debugging)
    # 3: enable, force using non-persistent reduction (for debugging)
    multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0"))

    # hint to Triton when arguments are divisible by 16
    divisible_by_16 = True

    # these are not enforced, but they are used by asserts in triton_heuristics.py
    # NOTE: mobilevit_s in timm_models required X to be set to the higher value 2048

    # Max RBLOCK will be large for multi-kernel since we do more aggressive
    # persistent reduction.
    max_block = {
        "X": 2048,
        "Y": 1024,
        "Z": 1024,
        "R": 4096 * (16 if multi_kernel else 1),
    }

    # Minimum RBLOCK to be used for a TritonSplitScanKernel
    # NOTE: This also indirectly controls the size of workspace buffer required
    min_split_scan_rblock = 256

    # Store the generated cubin files for cpp wrapper code to load
    store_cubin = False

    # the max number of spills we allow for the configs we benchmark.
    # Setting this to 0 means we skip a config if it spills even a single
    # register.
    # Setting it to a larger value allows a config spilling a small amount
    # of registers being benchmarked.
    #
    # NOTE: triton will always report >0 register spills for kernels using sin/cos.
    # (check this issue https://github.com/openai/triton/issues/1756 )
    # So far we see a fixed 8 spilled registers for kernels using sin/cos.
    # Raise the threshold to 16 to be safe.
    # We should revisit this once we understand more of the source of register spills.
    spill_threshold: int = 16

    # Generate code containing the newer tl.make_block_ptr() API for loads/store
    use_block_ptr = False

    # Inject a bug into our relu implementation; useful for testing our repro
    # extraction and minification functionality.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY: Optional[str] = None
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
class aot_inductor:
    """Config for AOTInductor (ahead-of-time compiled shared-library output)."""

    # AOTInductor output path
    # If an absolute path is specified, the generated lib files will be stored under the directory;
    # If a relative path is specified, it will be used as a subdirectory under the default caching path;
    # If not specified, a temp directory will be created under the default caching path.
    # If the specified path contains something like "model.so", the sub-string will be used
    # to name the generated library.
    output_path = ""

    # NOTE(review): presumably enables debug flags when compiling the generated
    # code — confirm in the codecache/compile path.
    debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1"

    # Serialized tree spec for flattening inputs
    serialized_in_spec = ""

    # Serialized tree spec for flattening outputs
    serialized_out_spec = ""

    # flag to decide whether to create a submodule for constant graph.
    use_runtime_constant_folding: bool = False
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
class cuda:
    """Config for the CUDA / CUTLASS template kernel backend."""

    # CUDA arch to use for CUDA template kernel compilation.
    # e.g. "70", "75", "80", "90", etc.
    # When arch is None, Inductor uses torch.cuda.get_device_capability(0).
    arch: Optional[str] = None

    # CUDA version to use for CUDA template kernel compilation.
    # e.g. "11.4", "12.1", etc.
    # When version is None, Inductor uses torch.version.cuda.
    version: Optional[str] = None

    # Optimization level for the host compiler.
    compile_opt_level = "-O1"

    # Whether to enable device LTO (link-time-optimization).
    enable_cuda_lto = False

    # Whether to keep intermediate files during compilation.
    enable_ptxas_info = False

    # Whether to enable debug info, e.g. line number, cutlass debug info.
    enable_debug_info = False

    # Whether to use fast math.
    use_fast_math = False

    # Path to the CUTLASS repo root directory.
    # The default path only works under PyTorch local development environment.
    cutlass_dir = os.environ.get(
        "TORCHINDUCTOR_CUTLASS_DIR",
        os.path.abspath(
            os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
        ),
    )

    # Configures the maximum number of CUTLASS configs to profile in max_autotune.
    # By default it's None, so that all CUTLASS configs are tuned.
    # This is mainly used to reduce test time in CI.
    cutlass_max_profiling_configs: Optional[int] = None

    # Path to CUDA NVCC.
    # NVCC search order:
    # 1) cuda_cxx set in this config
    # 2) CUDACXX environment variable
    # 3) CUDA_HOME environment variable
    # 4) default system search PATH.
    cuda_cxx: Optional[str] = None

    # If set to True, it will ensure that only GEMM ops capable of
    # epilogue fusion via CUTLASS Epilogue Visitor Trees ( EVT )
    # are enabled for the CUTLASS backend.
    cutlass_only_evt_capable_ops: bool = False
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
# create a directory containing lots of debug information
|
| 686 |
+
class trace:
|
| 687 |
+
# master switch for all debugging flags below
|
| 688 |
+
enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
|
| 689 |
+
|
| 690 |
+
# Save debug information to a temporary directory
|
| 691 |
+
# If not specified, a temp directory will be created by system
|
| 692 |
+
debug_dir: Optional[str] = None
|
| 693 |
+
|
| 694 |
+
# Save python logger call >=logging.DEBUG
|
| 695 |
+
debug_log = False
|
| 696 |
+
|
| 697 |
+
# Save python logger call >=logging.INFO
|
| 698 |
+
info_log = False
|
| 699 |
+
|
| 700 |
+
# Save input FX graph (post decomps, pre optimization)
|
| 701 |
+
fx_graph = True
|
| 702 |
+
|
| 703 |
+
# Save FX graph after transformations
|
| 704 |
+
fx_graph_transformed = True
|
| 705 |
+
|
| 706 |
+
# Save TorchInductor IR before fusion pass
|
| 707 |
+
ir_pre_fusion = True
|
| 708 |
+
|
| 709 |
+
# Save TorchInductor IR after fusion pass
|
| 710 |
+
ir_post_fusion = True
|
| 711 |
+
|
| 712 |
+
# Copy generated code to trace dir
|
| 713 |
+
output_code = True
|
| 714 |
+
|
| 715 |
+
# SVG figure showing post-fusion graph
|
| 716 |
+
graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1"
|
| 717 |
+
|
| 718 |
+
# SVG figure showing fx with fusion
|
| 719 |
+
draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1"
|
| 720 |
+
|
| 721 |
+
# We draw our fx graphs with the "record" shape attribute by default.
|
| 722 |
+
# Sometimes, when the graph is very complex, we may hit dot errors like below:
|
| 723 |
+
# "flat edge between adjacent nodes one of which has a record shape -
|
| 724 |
+
# replace records with HTML-like labels"
|
| 725 |
+
# and thus fail to generate a graph. So, let's give the user an option
|
| 726 |
+
# to specify the shape attribute for the dot graph. For example, passing
|
| 727 |
+
# INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like lables
|
| 728 |
+
# to workaround the above failure.
|
| 729 |
+
dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None)
|
| 730 |
+
|
| 731 |
+
# Store cProfile (see snakeviz to view)
|
| 732 |
+
compile_profile = False
|
| 733 |
+
|
| 734 |
+
# Upload the .tar.gz file
|
| 735 |
+
# Needs to be overriden based on specific environment needs
|
| 736 |
+
upload_tar: Optional[Callable[[str], None]] = None
|
| 737 |
+
|
| 738 |
+
log_autotuning_results: bool = False
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
_save_config_ignore = {
|
| 742 |
+
# workaround: "Can't pickle <function ...>"
|
| 743 |
+
"trace.upload_tar",
|
| 744 |
+
}
|
| 745 |
+
|
| 746 |
+
if TYPE_CHECKING:
|
| 747 |
+
from torch.utils._config_typing import * # noqa: F401, F403
|
| 748 |
+
|
| 749 |
+
from torch.utils._config_module import install_config_module
|
| 750 |
+
|
| 751 |
+
# adds patch, save_config, etc
|
| 752 |
+
install_config_module(sys.modules[__name__])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
from typing import Any, Callable, Dict, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils._pytree as pytree
|
| 6 |
+
|
| 7 |
+
aten = torch.ops.aten
|
| 8 |
+
|
| 9 |
+
# We would like to split modules into two subgraphs for runtime weight updates to work correctly.
|
| 10 |
+
# The use case and more information could be found at:
|
| 11 |
+
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
|
| 12 |
+
META_TAG = "MODULE_TYPE"
|
| 13 |
+
MODULE_TAG = "_MAIN_MODULE"
|
| 14 |
+
CONST_MODULE_TAG = "_CONST_MODULE"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def replace_node_with_constant(gm, node, constant, name=None):
|
| 18 |
+
g = gm.graph
|
| 19 |
+
|
| 20 |
+
if name:
|
| 21 |
+
qualname = name
|
| 22 |
+
else:
|
| 23 |
+
if not hasattr(gm, "_frozen_param_count"):
|
| 24 |
+
gm._frozen_param_count = 0
|
| 25 |
+
i = gm._frozen_param_count
|
| 26 |
+
|
| 27 |
+
while True:
|
| 28 |
+
qualname = f"_frozen_param{i}"
|
| 29 |
+
if not hasattr(gm, qualname):
|
| 30 |
+
break
|
| 31 |
+
i += 1
|
| 32 |
+
|
| 33 |
+
gm._frozen_param_count = i + 1
|
| 34 |
+
|
| 35 |
+
with g.inserting_before(node):
|
| 36 |
+
new_input_node = g.create_node("get_attr", qualname, (), {})
|
| 37 |
+
node.replace_all_uses_with(new_input_node)
|
| 38 |
+
new_input_node.meta.update(node.meta)
|
| 39 |
+
g.erase_node(node)
|
| 40 |
+
|
| 41 |
+
# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
|
| 42 |
+
gm.register_buffer(qualname, constant)
|
| 43 |
+
setattr(gm, qualname, constant)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ConstantFolder(torch.fx.Interpreter):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
gm,
|
| 50 |
+
skip_constructors=False,
|
| 51 |
+
):
|
| 52 |
+
super().__init__(gm)
|
| 53 |
+
self.node_replacements: Dict[torch.fx.Node, Any] = {}
|
| 54 |
+
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
|
| 55 |
+
self.unknown_value = object()
|
| 56 |
+
self.skip_constructors: bool = skip_constructors
|
| 57 |
+
|
| 58 |
+
# overwrite this to deallocate env values if their only remaining use
|
| 59 |
+
# is the output
|
| 60 |
+
self.user_to_last_uses = self.node_to_last_non_output_use()
|
| 61 |
+
|
| 62 |
+
def is_impure(self, node: torch.fx.node.Node):
|
| 63 |
+
if node.target in [
|
| 64 |
+
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
| 65 |
+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
| 66 |
+
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
| 67 |
+
]:
|
| 68 |
+
# For the pattern fp32_weight -> q -> dq
|
| 69 |
+
# We only folding fp32_weight -> q
|
| 70 |
+
# int8_weight and leave dq in graph to be fused
|
| 71 |
+
return True
|
| 72 |
+
return False
|
| 73 |
+
|
| 74 |
+
def node_to_last_non_output_use(self):
|
| 75 |
+
last_non_output_use = collections.defaultdict(list)
|
| 76 |
+
seen_uses = set()
|
| 77 |
+
output_node = next(iter(reversed(self.module.graph.nodes)))
|
| 78 |
+
|
| 79 |
+
for node in reversed(self.module.graph.nodes):
|
| 80 |
+
if node.target == "output":
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
def add_use(inp):
|
| 84 |
+
if inp in seen_uses:
|
| 85 |
+
return
|
| 86 |
+
|
| 87 |
+
seen_uses.add(inp)
|
| 88 |
+
last_non_output_use[node].append(inp)
|
| 89 |
+
|
| 90 |
+
pytree.tree_map_only(torch.fx.Node, add_use, (node.args, node.kwargs))
|
| 91 |
+
|
| 92 |
+
# if this node is only used in output, we want to gc it right away
|
| 93 |
+
if len(node.users) == 1 and output_node in node.users:
|
| 94 |
+
last_non_output_use[node].append(node)
|
| 95 |
+
|
| 96 |
+
return last_non_output_use
|
| 97 |
+
|
| 98 |
+
def run_node(self, node):
|
| 99 |
+
if node.target == "output":
|
| 100 |
+
# because we remove nodes from env on last non output use,
|
| 101 |
+
# re-define them now or we'll get error in interpreter
|
| 102 |
+
def set_env(arg):
|
| 103 |
+
self.env[arg] = self.unknown_value
|
| 104 |
+
|
| 105 |
+
pytree.tree_map_only(torch.fx.Node, set_env, node.args)
|
| 106 |
+
return super().run_node(node)
|
| 107 |
+
|
| 108 |
+
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
| 109 |
+
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
|
| 110 |
+
|
| 111 |
+
if self.unknown_value in flattened_inputs:
|
| 112 |
+
return self.unknown_value
|
| 113 |
+
|
| 114 |
+
# TODO - fix errors with this
|
| 115 |
+
if (
|
| 116 |
+
node.op == "call_function"
|
| 117 |
+
and node.target == aten._efficientzerotensor.default
|
| 118 |
+
):
|
| 119 |
+
return self.unknown_value
|
| 120 |
+
|
| 121 |
+
# TODO - constant folding triton kernel returns the inputs -- fix this
|
| 122 |
+
if (
|
| 123 |
+
node.op == "call_function"
|
| 124 |
+
and node.name == "triton_kernel_wrapper_functional_proxy"
|
| 125 |
+
):
|
| 126 |
+
return self.unknown_value
|
| 127 |
+
|
| 128 |
+
# skip constructors, since inductor generates optimal code for them already
|
| 129 |
+
# and turning into tensor would result in an additional global memory read
|
| 130 |
+
# TODO - more complicated strategy
|
| 131 |
+
if (
|
| 132 |
+
self.skip_constructors
|
| 133 |
+
and node.op != "get_attr"
|
| 134 |
+
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
|
| 135 |
+
):
|
| 136 |
+
return self.unknown_value
|
| 137 |
+
|
| 138 |
+
# All mutations should either be removed or on inputs which we did not make constant
|
| 139 |
+
if (
|
| 140 |
+
isinstance(node.target, torch._ops.OpOverload)
|
| 141 |
+
and torch.Tag.nondeterministic_seeded in node.target.tags
|
| 142 |
+
):
|
| 143 |
+
return self.unknown_value
|
| 144 |
+
|
| 145 |
+
out = super().run_node(node)
|
| 146 |
+
|
| 147 |
+
if node.op != "get_attr" and isinstance(out, torch.Tensor):
|
| 148 |
+
if not self.insertable_tensor_check(out):
|
| 149 |
+
return out
|
| 150 |
+
|
| 151 |
+
if self.is_impure(node):
|
| 152 |
+
return self.unknown_value
|
| 153 |
+
|
| 154 |
+
self.add_node_replacement(node, out)
|
| 155 |
+
|
| 156 |
+
flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
|
| 157 |
+
|
| 158 |
+
for n in flattened_node_inps:
|
| 159 |
+
if not isinstance(n, torch.fx.Node):
|
| 160 |
+
continue
|
| 161 |
+
|
| 162 |
+
self.replaced_uses[n] += 1
|
| 163 |
+
|
| 164 |
+
for to_delete in self.user_to_last_uses.get(node, []):
|
| 165 |
+
if self.replaced_uses[to_delete] == len(to_delete.users):
|
| 166 |
+
self.node_replacements.pop(to_delete, None)
|
| 167 |
+
|
| 168 |
+
return out
|
| 169 |
+
|
| 170 |
+
def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
|
| 171 |
+
return True
|
| 172 |
+
|
| 173 |
+
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
|
| 174 |
+
self.node_replacements[node] = tensor
|
| 175 |
+
|
| 176 |
+
def run(self):
|
| 177 |
+
env = {}
|
| 178 |
+
for n in self.module.graph.nodes:
|
| 179 |
+
if n.op == "placeholder":
|
| 180 |
+
env[n] = self.unknown_value
|
| 181 |
+
return super().run(initial_env=env)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    """Fold constant-valued subgraphs of ``gm`` into frozen attributes.

    ``constraint_fn``, when given, can veto folding of individual nodes.
    Mutates ``gm`` in place and recompiles it.
    """
    folder = ConstantFolder(gm, skip_constructors=True)
    folder.run()

    for node, constant in folder.node_replacements.items():
        if constraint_fn is not None and not constraint_fn(node):
            continue
        replace_node_with_constant(gm, node, constant)

    # Remove get_attr nodes (and the attributes behind them) that no longer
    # have any users after folding.
    erased_params = []
    for node in gm.graph.nodes:
        if node.op == "get_attr" and len(node.users) == 0:
            if hasattr(gm, node.target):
                delattr(gm, node.target)
            erased_params.append(node)

    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@torch.utils._python_dispatch._disable_current_modes()
def constant_graph_tag(gm: torch.fx.GraphModule):
    """Tag every node's meta with whether it belongs to the constant subgraph.

    Nodes that are attributes, fold to constants, or feed folded constants are
    tagged CONST_MODULE_TAG; everything else is tagged MODULE_TAG.
    """
    folder = ConstantFolder(gm, skip_constructors=True)
    folder.run()

    for node in gm.graph.nodes:
        is_const = (
            node.op == "get_attr"
            or node in folder.node_replacements
            or node in folder.replaced_uses
        )
        node.meta[META_TAG] = CONST_MODULE_TAG if is_const else MODULE_TAG
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Construct a GraphModule which corresponds to the part which could be
    constant folded in the provided gm.
    """
    constant_graph_tag(gm)

    # Rewrite the tags: a constant consumed directly, with no folding
    # opportunity at all, stays in the main gm.
    for node in gm.graph.nodes:
        if node.op == "get_attr" and not any(
            user.meta[META_TAG] == CONST_MODULE_TAG for user in node.users
        ):
            node.meta[META_TAG] = MODULE_TAG

    new_graph = torch.fx.Graph()
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    output_nodes = []

    for node in gm.graph.nodes:
        if node.meta[META_TAG] == MODULE_TAG:
            continue

        new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
        node_remapping[node] = new_node

        # Anything the main module still consumes becomes a graph output.
        if any(user.meta[META_TAG] == MODULE_TAG for user in node.users):
            output_nodes.append(new_node)

    new_graph.output(tuple(output_nodes))
    new_graph.lint()
    return torch.fx.GraphModule(gm, new_graph)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py
ADDED
|
@@ -0,0 +1,1324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import logging
|
| 3 |
+
import operator
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple
|
| 11 |
+
|
| 12 |
+
import sympy
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch._logging
|
| 16 |
+
import torch.fx
|
| 17 |
+
from torch._decomp import get_decompositions
|
| 18 |
+
from torch._dynamo.utils import defake, dynamo_timed
|
| 19 |
+
from torch._logging import LazyString, trace_structured
|
| 20 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 21 |
+
from torch.fx.experimental._backward_state import BackwardState
|
| 22 |
+
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
|
| 23 |
+
from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
|
| 24 |
+
from torch.utils._mode_utils import no_dispatch
|
| 25 |
+
|
| 26 |
+
from . import config, ir
|
| 27 |
+
from .codegen.common import (
|
| 28 |
+
DeviceOpOverrides,
|
| 29 |
+
get_device_op_overrides,
|
| 30 |
+
get_scheduling_for_device,
|
| 31 |
+
get_wrapper_codegen_for_device,
|
| 32 |
+
register_backend_for_device,
|
| 33 |
+
)
|
| 34 |
+
from .codegen.cpp_wrapper_cpu import CppWrapperCpu
|
| 35 |
+
from .codegen.cpp_wrapper_cuda import CppWrapperCuda
|
| 36 |
+
from .codegen.wrapper import WrapperCodeGen
|
| 37 |
+
from .exc import (
|
| 38 |
+
CppWrapperCodeGenError,
|
| 39 |
+
LoweringException,
|
| 40 |
+
MissingOperatorWithDecomp,
|
| 41 |
+
MissingOperatorWithoutDecomp,
|
| 42 |
+
)
|
| 43 |
+
from .ir import (
|
| 44 |
+
Constant,
|
| 45 |
+
FixedLayout,
|
| 46 |
+
InputBuffer,
|
| 47 |
+
Pointwise,
|
| 48 |
+
Reduction,
|
| 49 |
+
StorageBox,
|
| 50 |
+
TensorBox,
|
| 51 |
+
)
|
| 52 |
+
from .lowering import (
|
| 53 |
+
constrain_to_fx_strides,
|
| 54 |
+
FALLBACK_ALLOW_LIST,
|
| 55 |
+
fallback_handler,
|
| 56 |
+
fallback_node_due_to_unsupported_type,
|
| 57 |
+
layout_constraints,
|
| 58 |
+
lowerings,
|
| 59 |
+
make_fallback,
|
| 60 |
+
needs_realized_inputs,
|
| 61 |
+
unsupported_output_tensor,
|
| 62 |
+
)
|
| 63 |
+
from .sizevars import SizeVarAllocator
|
| 64 |
+
from .utils import convert_shape_to_inductor, gather_origins, get_sympy_Expr_dtype
|
| 65 |
+
from .virtualized import V
|
| 66 |
+
|
| 67 |
+
log = logging.getLogger(__name__)
|
| 68 |
+
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
|
| 69 |
+
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# fbcode builds ship a real implementation; elsewhere fall back to a no-op.
if config.is_fbcode():
    from torch._inductor.fb.utils import log_module_code
else:

    def log_module_code(*args, **kwargs):
        # No-op stand-in outside fbcode builds.
        return None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def supported_dtype_of_cpp_wrapper(dtype, cuda):
|
| 81 |
+
supported_dtype = {
|
| 82 |
+
torch.float32,
|
| 83 |
+
torch.float64,
|
| 84 |
+
torch.int64,
|
| 85 |
+
torch.int32,
|
| 86 |
+
torch.int16,
|
| 87 |
+
torch.int8,
|
| 88 |
+
torch.uint8,
|
| 89 |
+
torch.bool,
|
| 90 |
+
torch.bfloat16,
|
| 91 |
+
torch.complex32,
|
| 92 |
+
torch.complex64,
|
| 93 |
+
torch.complex128,
|
| 94 |
+
torch.float16,
|
| 95 |
+
}
|
| 96 |
+
if cuda:
|
| 97 |
+
supported_dtype.add(torch.float8_e4m3fn)
|
| 98 |
+
supported_dtype.add(torch.float8_e5m2)
|
| 99 |
+
supported_dtype.add(torch.float8_e4m3fnuz)
|
| 100 |
+
supported_dtype.add(torch.float8_e5m2fnuz)
|
| 101 |
+
|
| 102 |
+
return dtype in supported_dtype
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def may_get_constant_buffer_dtype(constant_buffer):
    """Map a sympy constant/expression to the torch dtype it would produce.

    Returns None when the dtype cannot be determined.
    """
    assert isinstance(
        constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
    ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"

    if isinstance(constant_buffer, sympy.core.numbers.Integer):
        return torch.int64

    if isinstance(constant_buffer, sympy.Expr):
        # NOTE(review): sympy.Symbol is itself a sympy.Expr, so symbols take
        # this branch too and the fallback below looks unreachable — confirm.
        return get_sympy_Expr_dtype(constant_buffer)

    if constant_buffer.is_integer:
        return torch.int64
    if constant_buffer.is_float:
        return torch.float32
    return None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def is_magic_method(op):
|
| 124 |
+
magic_ops = {method_to_operator(m) for m in magic_methods}
|
| 125 |
+
return op in magic_ops
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def getattr_recursive(obj, target):
    """Resolve a dotted attribute path (e.g. ``"a.b.c"``) starting at ``obj``.

    Raises:
        RuntimeError: naming the first missing attribute along the path.
    """
    target_atoms = target.split(".")
    attr_itr = obj
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Include the missing atom itself (i + 1): previously the message
            # joined target_atoms[:i], which named the last *existing* parent
            # rather than the attribute that is actually absent.
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class GraphLowering(torch.fx.Interpreter):
|
| 141 |
+
graph_outputs: List[ir.IRNode]
|
| 142 |
+
|
| 143 |
+
def symbolic_sizes_strides(self, ex: torch.Tensor):
|
| 144 |
+
"""
|
| 145 |
+
Support dynamic shapes and dynamic strides by assigning variables
|
| 146 |
+
to each dimension. We duck-shape tensors, so if two tensors
|
| 147 |
+
have the same size they get assigned the same symbolic variable.
|
| 148 |
+
"""
|
| 149 |
+
if self.reuse_shape_env:
|
| 150 |
+
return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
|
| 151 |
+
ex.stride()
|
| 152 |
+
)
|
| 153 |
+
else:
|
| 154 |
+
from torch._dynamo.source import ConstantSource
|
| 155 |
+
|
| 156 |
+
# TODO: this should not be needed once #93059 lands
|
| 157 |
+
# https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
|
| 158 |
+
# TODO: make a dedicated UnknownSource for this?
|
| 159 |
+
# NB: This is using the legacy default behavior from
|
| 160 |
+
# create_symbolic_sizes_strides_storage_offset but we hope we can
|
| 161 |
+
# just delete this entirely
|
| 162 |
+
source = ConstantSource(
|
| 163 |
+
f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
|
| 164 |
+
)
|
| 165 |
+
(
|
| 166 |
+
size,
|
| 167 |
+
stride,
|
| 168 |
+
_,
|
| 169 |
+
) = self._shape_env.create_symbolic_sizes_strides_storage_offset(
|
| 170 |
+
ex,
|
| 171 |
+
source,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
|
| 175 |
+
stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
|
| 176 |
+
return size, stride
|
| 177 |
+
|
| 178 |
+
def static_sizes_strides(self, ex: torch.Tensor):
|
| 179 |
+
"""
|
| 180 |
+
Primarily used to weights
|
| 181 |
+
"""
|
| 182 |
+
size = [sympy.Integer(i) for i in ex.size()]
|
| 183 |
+
stride = [sympy.Integer(i) for i in ex.stride()]
|
| 184 |
+
return size, stride
|
| 185 |
+
|
| 186 |
+
def init_backend_registration(self):
|
| 187 |
+
if get_scheduling_for_device("cpu") is None:
|
| 188 |
+
from .codegen.cpp import CppScheduling
|
| 189 |
+
|
| 190 |
+
register_backend_for_device("cpu", CppScheduling, WrapperCodeGen)
|
| 191 |
+
|
| 192 |
+
if get_scheduling_for_device("cuda") is None:
|
| 193 |
+
from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
|
| 194 |
+
|
| 195 |
+
# CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
|
| 196 |
+
register_backend_for_device("cuda", CUDACombinedScheduling, WrapperCodeGen)
|
| 197 |
+
|
| 198 |
+
def __init__(
|
| 199 |
+
self,
|
| 200 |
+
gm: torch.fx.GraphModule,
|
| 201 |
+
example_inputs: Optional[List[torch.Tensor]] = None,
|
| 202 |
+
shape_env=None,
|
| 203 |
+
num_static_inputs=None,
|
| 204 |
+
graph_id=None,
|
| 205 |
+
cpp_wrapper=False,
|
| 206 |
+
aot_mode=False,
|
| 207 |
+
user_visible_outputs=frozenset(),
|
| 208 |
+
layout_opt=None,
|
| 209 |
+
extern_node_serializer=None,
|
| 210 |
+
is_inference=False,
|
| 211 |
+
is_const_graph=False,
|
| 212 |
+
const_output_index=None,
|
| 213 |
+
const_code=None,
|
| 214 |
+
const_module=None,
|
| 215 |
+
name=None,
|
| 216 |
+
):
|
| 217 |
+
super().__init__(gm)
|
| 218 |
+
|
| 219 |
+
self.example_inputs = example_inputs
|
| 220 |
+
self.layout_opt = (
|
| 221 |
+
layout_opt
|
| 222 |
+
if layout_opt is not None
|
| 223 |
+
else self.decide_layout_opt(gm, is_inference=is_inference)
|
| 224 |
+
)
|
| 225 |
+
self.num_channels_last_conv = 0
|
| 226 |
+
self.is_inference = is_inference
|
| 227 |
+
self.is_const_graph = is_const_graph
|
| 228 |
+
self.const_code = const_code
|
| 229 |
+
self.const_module = const_module
|
| 230 |
+
|
| 231 |
+
self.extra_traceback = False # we do our own error wrapping
|
| 232 |
+
if shape_env is None:
|
| 233 |
+
shape_env = ShapeEnv()
|
| 234 |
+
self.reuse_shape_env = False
|
| 235 |
+
else:
|
| 236 |
+
self._shape_env = shape_env
|
| 237 |
+
self.reuse_shape_env = True
|
| 238 |
+
self._shape_env = shape_env
|
| 239 |
+
self.sizevars = SizeVarAllocator(shape_env)
|
| 240 |
+
self.graph_input_names: List[str] = []
|
| 241 |
+
self.graph_inputs: Dict[str, TensorBox] = {}
|
| 242 |
+
self.graph_inputs_original: Dict[str, InputBuffer] = {}
|
| 243 |
+
self.device_types: Set[str] = (
|
| 244 |
+
const_module.device_types if const_module else set()
|
| 245 |
+
)
|
| 246 |
+
self.device_idxs: Set[int] = const_module.device_idxs if const_module else set()
|
| 247 |
+
self.cuda = False
|
| 248 |
+
self.buffers: List[ir.Buffer] = []
|
| 249 |
+
self.const_output_index: Dict[str, int] = (
|
| 250 |
+
const_output_index if const_output_index else {}
|
| 251 |
+
)
|
| 252 |
+
self.folded_constants: Set[str] = (
|
| 253 |
+
set(const_output_index.keys()) if const_output_index else set()
|
| 254 |
+
)
|
| 255 |
+
self.constants: Dict[str, torch.Tensor] = (
|
| 256 |
+
const_module.constants if const_module else {}
|
| 257 |
+
)
|
| 258 |
+
self.constant_reprs: Dict[str, str] = {}
|
| 259 |
+
self.removed_buffers: Set[str] = set()
|
| 260 |
+
self.removed_inplace_buffers: Set[str] = set()
|
| 261 |
+
self.mutated_buffers: Set[str] = set()
|
| 262 |
+
self.never_reuse_buffers: Set[str] = set()
|
| 263 |
+
self.inplaced_to_remove: Set[str] = set()
|
| 264 |
+
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
|
| 265 |
+
self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment]
|
| 266 |
+
# See `ProxyExecutor Design Note` in ir.py for more details
|
| 267 |
+
self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
|
| 268 |
+
self.extern_node_serializer: Optional[
|
| 269 |
+
Callable[[List[ir.ExternKernelNode]], Any]
|
| 270 |
+
] = extern_node_serializer
|
| 271 |
+
self.current_node: torch.fx.Node = None # type: ignore[assignment]
|
| 272 |
+
self.num_static_inputs = num_static_inputs
|
| 273 |
+
self.lists: Dict[str, List[str]] = {}
|
| 274 |
+
self.mutated_inputs: Set[str] = set()
|
| 275 |
+
self.mutated_input_idxs: List[int] = []
|
| 276 |
+
self.name_to_buffer: Dict[str, ir.Buffer] = {}
|
| 277 |
+
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
|
| 278 |
+
self.creation_time = time.time()
|
| 279 |
+
self.name = name
|
| 280 |
+
self.cpp_wrapper = cpp_wrapper
|
| 281 |
+
|
| 282 |
+
# record multi_kernel choice for cpp_wrapper so the second pass knows
|
| 283 |
+
# which sub-kernel is picked. Copy cpp_wrapper to another variable
|
| 284 |
+
# since cpp_wrapper flag is set to false for the first pass of codegen.
|
| 285 |
+
self.record_multi_kernel_choice = cpp_wrapper
|
| 286 |
+
self.multi_kernel_to_choice: Dict[str, int] = {}
|
| 287 |
+
|
| 288 |
+
self.aot_mode = aot_mode
|
| 289 |
+
self.graph_id = graph_id
|
| 290 |
+
self.scheduler: "torch._inductor.scheduler.Scheduler" = None # type: ignore[assignment]
|
| 291 |
+
self.nodes_prefer_channels_last = (
|
| 292 |
+
self.find_nodes_prefer_channels_last() if self.layout_opt else set()
|
| 293 |
+
)
|
| 294 |
+
self._warned_fallback = {"aten.convolution_backward"}
|
| 295 |
+
self.user_visible_outputs = user_visible_outputs
|
| 296 |
+
self.cache_key: str = "" # This is the cache key for the compiled artifact
|
| 297 |
+
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
|
| 298 |
+
self.cache_linemap: List[
|
| 299 |
+
Tuple[int, str]
|
| 300 |
+
] = (
|
| 301 |
+
[]
|
| 302 |
+
) # This is the linemap used by the profiler to mark custom compiled kernels getting run
|
| 303 |
+
# Used if lowering encounters cases where cudagraphs are not supported
|
| 304 |
+
self.disable_cudagraphs_reason: Optional[str] = None
|
| 305 |
+
|
| 306 |
+
# only keeping one node per device for stack trace purposes
|
| 307 |
+
self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
|
| 308 |
+
self.orig_gm: torch.fx.GraphModule = gm.__copy__()
|
| 309 |
+
self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
|
| 310 |
+
"dynamo_flat_name_to_original_fqn", {}
|
| 311 |
+
)
|
| 312 |
+
self.allocated_constant_name = (
|
| 313 |
+
const_module.allocated_constant_name if const_module is not None else {}
|
| 314 |
+
)
|
| 315 |
+
self.init_backend_registration()
|
| 316 |
+
|
| 317 |
+
@staticmethod
|
| 318 |
+
def decide_layout_opt(gm, *, is_inference) -> bool:
|
| 319 |
+
"""
|
| 320 |
+
Decide if we should enable layout optimization for this graph based on
|
| 321 |
+
heuristics.
|
| 322 |
+
"""
|
| 323 |
+
if not config.layout_optimization:
|
| 324 |
+
return False
|
| 325 |
+
|
| 326 |
+
if config.force_layout_optimization:
|
| 327 |
+
return True
|
| 328 |
+
|
| 329 |
+
conv_nodes = [
|
| 330 |
+
n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
|
| 331 |
+
]
|
| 332 |
+
nconv = len(conv_nodes)
|
| 333 |
+
|
| 334 |
+
if nconv == 0:
|
| 335 |
+
return False
|
| 336 |
+
|
| 337 |
+
# For cpu backend and mkldnn enabled, we always use channels_last for better performance.
|
| 338 |
+
if (
|
| 339 |
+
torch.backends.mkldnn.enabled
|
| 340 |
+
and torch.backends.mkldnn.is_available()
|
| 341 |
+
and all(
|
| 342 |
+
n.args[idx].meta["val"].device == torch.device("cpu")
|
| 343 |
+
for n in conv_nodes
|
| 344 |
+
for idx in [0, 1]
|
| 345 |
+
)
|
| 346 |
+
):
|
| 347 |
+
return True
|
| 348 |
+
|
| 349 |
+
# Following models are skipped due to this:
|
| 350 |
+
# jx_nest_base
|
| 351 |
+
# volo_d1_224
|
| 352 |
+
if len(list(gm.graph.nodes)) >= 300 * nconv:
|
| 353 |
+
log.debug("Skipped layout opt because only a few conv")
|
| 354 |
+
return False
|
| 355 |
+
|
| 356 |
+
if any(
|
| 357 |
+
has_free_symbols(n.args[idx].meta["val"])
|
| 358 |
+
for n in conv_nodes
|
| 359 |
+
for idx in [0, 1]
|
| 360 |
+
):
|
| 361 |
+
log.debug(
|
| 362 |
+
"See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
|
| 363 |
+
)
|
| 364 |
+
return False
|
| 365 |
+
|
| 366 |
+
def is_grouped(n):
|
| 367 |
+
return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1
|
| 368 |
+
|
| 369 |
+
def is_in_out_channel(n):
|
| 370 |
+
return (
|
| 371 |
+
n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)
|
| 372 |
+
and n.args[1].meta["val"].size(2) > 1
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
def is_small_channel(n):
|
| 376 |
+
return (
|
| 377 |
+
n.args[1].meta["val"].size(0) <= 64
|
| 378 |
+
and n.args[1].meta["val"].size(1) <= 64
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
# only grouped convolutions benchmarked as slower in conv samples for inference only
|
| 382 |
+
if is_inference:
|
| 383 |
+
from torch.utils.flop_counter import FlopCounterMode
|
| 384 |
+
|
| 385 |
+
flop_counts: Dict[str, float] = defaultdict(float)
|
| 386 |
+
for node in conv_nodes:
|
| 387 |
+
success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
|
| 388 |
+
node
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
if success:
|
| 392 |
+
with FlopCounterMode(display=False) as flop_counter_mode:
|
| 393 |
+
with V.fake_mode:
|
| 394 |
+
node.target(*args, **kwargs)
|
| 395 |
+
|
| 396 |
+
counted_flops = flop_counter_mode.get_total_flops()
|
| 397 |
+
if is_grouped(node):
|
| 398 |
+
node_type = "grouped"
|
| 399 |
+
elif is_small_channel(node):
|
| 400 |
+
node_type = "small"
|
| 401 |
+
elif is_in_out_channel(node):
|
| 402 |
+
node_type = "in_out"
|
| 403 |
+
else:
|
| 404 |
+
node_type = "default"
|
| 405 |
+
|
| 406 |
+
flop_counts[node_type] += counted_flops
|
| 407 |
+
else:
|
| 408 |
+
log.debug("Conv inputs meta not found")
|
| 409 |
+
|
| 410 |
+
# average benchmarked channels last speedup / slowdown, < 1 is speedup.
|
| 411 |
+
# taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
|
| 412 |
+
# To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
|
| 413 |
+
GROUPED_MULTIPLIER = 1.358
|
| 414 |
+
DEFAULT_MULTIPLIER = 0.823
|
| 415 |
+
IN_OUT_MULTIPLIER = 0.725
|
| 416 |
+
SMALL_MULTIPLIER = 0.783
|
| 417 |
+
|
| 418 |
+
total_flops = sum(flop_counts.values())
|
| 419 |
+
# TODO - get different values per hardware
|
| 420 |
+
weighted_flops = (
|
| 421 |
+
flop_counts["grouped"] * GROUPED_MULTIPLIER
|
| 422 |
+
+ flop_counts["small"] * SMALL_MULTIPLIER
|
| 423 |
+
+ flop_counts["in_out"] * IN_OUT_MULTIPLIER
|
| 424 |
+
+ flop_counts["default"] * DEFAULT_MULTIPLIER
|
| 425 |
+
)
|
| 426 |
+
do_layout_opt = weighted_flops <= total_flops
|
| 427 |
+
if not do_layout_opt:
|
| 428 |
+
log.debug(
|
| 429 |
+
"Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
|
| 430 |
+
total_flops,
|
| 431 |
+
weighted_flops,
|
| 432 |
+
)
|
| 433 |
+
return do_layout_opt
|
| 434 |
+
|
| 435 |
+
# Channels last layout can dramatically hurt grouped conv perf. E.g.
|
| 436 |
+
# Conv with arguments like
|
| 437 |
+
# {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
|
| 438 |
+
# "stride": [2, 2], "padding": [1, 1], "groups": 2}
|
| 439 |
+
# slows down 31x using channels last..
|
| 440 |
+
|
| 441 |
+
# But a lot of timm models use depthwise separable convolution which will
|
| 442 |
+
# result in grouped convolution with in-channel size == 1.
|
| 443 |
+
# For those grouped convolution, channels last still helps a lot.
|
| 444 |
+
# E.g.
|
| 445 |
+
# Conv with arguments
|
| 446 |
+
# {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
|
| 447 |
+
# "stride": [2, 2], "padding": [1, 1], "groups": 58}
|
| 448 |
+
# get 1.86x speedup with channels last layout.
|
| 449 |
+
#
|
| 450 |
+
# The following heuristics skip using channels-last if the model contains
|
| 451 |
+
# grouped convolution with in-channels > 1.
|
| 452 |
+
if any(map(is_grouped, conv_nodes)):
|
| 453 |
+
log.debug(
|
| 454 |
+
"Skip layout opt because found grouped convolution with >1 in_channels!"
|
| 455 |
+
)
|
| 456 |
+
return False
|
| 457 |
+
|
| 458 |
+
# For some models that contain convolution with larger in-channel than out-channel, applying
|
| 459 |
+
# channels last hurts performance.
|
| 460 |
+
# Following models are skipped due to this:
|
| 461 |
+
# - pytorch_unet
|
| 462 |
+
# - phlippe_densenet (slightly worse)
|
| 463 |
+
# - Background_Matting (1.22x -> 0.821x)
|
| 464 |
+
# - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
|
| 465 |
+
if any(map(is_in_out_channel, conv_nodes)):
|
| 466 |
+
log.debug(
|
| 467 |
+
"Skip layout opt because some convolutions have smaller out_channel"
|
| 468 |
+
)
|
| 469 |
+
return False
|
| 470 |
+
|
| 471 |
+
# Following models are skipped due to this:
|
| 472 |
+
# - functorch_maml_omniglot
|
| 473 |
+
if all(map(is_small_channel, conv_nodes)):
|
| 474 |
+
log.debug("Skip layout opt because all convolution channels are too small")
|
| 475 |
+
return False
|
| 476 |
+
|
| 477 |
+
return True
|
| 478 |
+
|
| 479 |
+
def qualify_name(self, name: str) -> str:
|
| 480 |
+
"""Prepend the given name with the graph name if any."""
|
| 481 |
+
if self.name is not None:
|
| 482 |
+
return f"{self.name}_{name}"
|
| 483 |
+
return name
|
| 484 |
+
|
| 485 |
+
def make_subgraph(
|
| 486 |
+
self,
|
| 487 |
+
gm: torch.fx.GraphModule,
|
| 488 |
+
example_inputs: List[torch.Tensor],
|
| 489 |
+
subgraph_name: str,
|
| 490 |
+
) -> "GraphLowering":
|
| 491 |
+
"""
|
| 492 |
+
Make a subgraph of the current graph with all inherited
|
| 493 |
+
parts, except the graph module (`gm`) and `example_inputs`.
|
| 494 |
+
The subgraphs are lowered separately, but intended to be
|
| 495 |
+
inlined in the parent graph's codegening. Hence the need
|
| 496 |
+
for maintaining the same `shape_env` and other properties.
|
| 497 |
+
The subgraph name is qualified by the parent graph's name.
|
| 498 |
+
"""
|
| 499 |
+
return GraphLowering(
|
| 500 |
+
gm=gm,
|
| 501 |
+
example_inputs=example_inputs,
|
| 502 |
+
shape_env=self._shape_env,
|
| 503 |
+
cpp_wrapper=self.cpp_wrapper,
|
| 504 |
+
aot_mode=self.aot_mode,
|
| 505 |
+
extern_node_serializer=self.extern_node_serializer,
|
| 506 |
+
is_inference=self.is_inference,
|
| 507 |
+
name=self.qualify_name(subgraph_name),
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
def find_nodes_prefer_channels_last(self):
|
| 511 |
+
"""
|
| 512 |
+
The rule to decide if an node prefer channels last is simple.
|
| 513 |
+
1. if it's input/output of a convolution
|
| 514 |
+
2. if one of its user prefers channels last
|
| 515 |
+
|
| 516 |
+
We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs;
|
| 517 |
+
Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers
|
| 518 |
+
channels last.
|
| 519 |
+
|
| 520 |
+
Consider the scenario: conv -> batch-norm -> relu -> conv
|
| 521 |
+
Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies:
|
| 522 |
+
1. the output of batch-norm should be channels last initially since its input is a conv's output.
|
| 523 |
+
Forcing the batch-norm's output to be contiguous results in the first copy
|
| 524 |
+
2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output.
|
| 525 |
+
We need convert it to channels last layout which results in the second copy.
|
| 526 |
+
With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies
|
| 527 |
+
can be saved.
|
| 528 |
+
"""
|
| 529 |
+
output_set = set()
|
| 530 |
+
for n in reversed(self.module.graph.nodes):
|
| 531 |
+
if n.target == torch.ops.aten.convolution.default:
|
| 532 |
+
output_set.add(n)
|
| 533 |
+
continue
|
| 534 |
+
|
| 535 |
+
for user in n.users:
|
| 536 |
+
if user in output_set:
|
| 537 |
+
output_set.add(n)
|
| 538 |
+
break
|
| 539 |
+
|
| 540 |
+
# need a second pass to add downstream nodes of those channel last nodes to the sets.
|
| 541 |
+
# This pass is especially needed to avoid mix-layout kernel inputs in backward pass.
|
| 542 |
+
#
|
| 543 |
+
# Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned
|
| 544 |
+
# from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
|
| 545 |
+
# Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last
|
| 546 |
+
# tensors and passed to a kernel.
|
| 547 |
+
#
|
| 548 |
+
# This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x.
|
| 549 |
+
# It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x .
|
| 550 |
+
# This also helps the following models:
|
| 551 |
+
# - res2net101_26w_4s
|
| 552 |
+
# - res2net50_14w_8s
|
| 553 |
+
# - sebotnet33ts_256
|
| 554 |
+
for n in self.module.graph.nodes:
|
| 555 |
+
if n in output_set:
|
| 556 |
+
for child in n.users:
|
| 557 |
+
output_set.add(child)
|
| 558 |
+
|
| 559 |
+
return output_set
|
| 560 |
+
|
| 561 |
+
def warn_fallback(self, name):
|
| 562 |
+
if name not in self._warned_fallback:
|
| 563 |
+
self._warned_fallback.add(name)
|
| 564 |
+
perf_hint_log.info("Using FallbackKernel: %s", name)
|
| 565 |
+
|
| 566 |
+
def add_device_info(self, device: torch.device):
|
| 567 |
+
self.device_types.add(device.type)
|
| 568 |
+
if device.index is not None:
|
| 569 |
+
self.device_idxs.add(device.index)
|
| 570 |
+
if V.graph.current_node and device not in self.device_node_mapping:
|
| 571 |
+
self.device_node_mapping[device] = V.graph.current_node
|
| 572 |
+
|
| 573 |
+
@property
|
| 574 |
+
def fake_mode(self):
|
| 575 |
+
return V.fake_mode
|
| 576 |
+
|
| 577 |
+
def get_buffer(self, buffer_name: str):
|
| 578 |
+
if buffer_name in self.name_to_buffer:
|
| 579 |
+
return self.name_to_buffer[buffer_name]
|
| 580 |
+
if buffer_name in self.graph_inputs:
|
| 581 |
+
return self.graph_inputs[buffer_name]
|
| 582 |
+
return None
|
| 583 |
+
|
| 584 |
+
def get_dtype(self, buffer_name: str):
|
| 585 |
+
if buffer_name in self.constants:
|
| 586 |
+
return self.constants[buffer_name].dtype
|
| 587 |
+
if buffer_name in self.name_to_buffer:
|
| 588 |
+
return self.name_to_buffer[buffer_name].get_dtype()
|
| 589 |
+
if buffer_name in self.graph_inputs:
|
| 590 |
+
return self.graph_inputs[buffer_name].get_dtype()
|
| 591 |
+
m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
|
| 592 |
+
if m:
|
| 593 |
+
return self.get_dtype(m.group(1))
|
| 594 |
+
raise KeyError(f"could not find {buffer_name}")
|
| 595 |
+
|
| 596 |
+
def get_numel(self, buffer_name: str):
|
| 597 |
+
from .ir import MultiOutputLayout
|
| 598 |
+
|
| 599 |
+
if buffer_name in self.constants:
|
| 600 |
+
return self.constants[buffer_name].numel()
|
| 601 |
+
if buffer_name in self.name_to_buffer:
|
| 602 |
+
buf = self.name_to_buffer[buffer_name]
|
| 603 |
+
if isinstance(getattr(buf, "layout", None), MultiOutputLayout):
|
| 604 |
+
return 1
|
| 605 |
+
return buf.get_numel()
|
| 606 |
+
if buffer_name in self.graph_inputs:
|
| 607 |
+
return self.graph_inputs[buffer_name].get_numel()
|
| 608 |
+
raise KeyError(f"could not find {buffer_name}")
|
| 609 |
+
|
| 610 |
+
@dynamo_timed
|
| 611 |
+
def run(self, *args):
|
| 612 |
+
return super().run(*args)
|
| 613 |
+
|
| 614 |
+
def register_buffer(self, buffer: ir.Buffer):
|
| 615 |
+
name = self.qualify_name(f"buf{len(self.buffers)}")
|
| 616 |
+
self.buffers.append(buffer)
|
| 617 |
+
self.name_to_buffer[name] = buffer
|
| 618 |
+
# Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
|
| 619 |
+
if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements():
|
| 620 |
+
self.add_device_info(buffer.get_device())
|
| 621 |
+
return name
|
| 622 |
+
|
| 623 |
+
def register_list(self, buffer_names: List[str]):
|
| 624 |
+
name = self.qualify_name("list_" + "_".join(buffer_names))
|
| 625 |
+
self.lists[name] = buffer_names
|
| 626 |
+
return name
|
| 627 |
+
|
| 628 |
+
def register_users_of(self, node_output):
|
| 629 |
+
def register(value):
|
| 630 |
+
if isinstance(value, (list, tuple)):
|
| 631 |
+
for x in value:
|
| 632 |
+
register(x)
|
| 633 |
+
if isinstance(value, ir.IRNode):
|
| 634 |
+
if (
|
| 635 |
+
not hasattr(value, "data")
|
| 636 |
+
or not isinstance(value.data, ir.IRNode)
|
| 637 |
+
or not (
|
| 638 |
+
hasattr(value.data, "data")
|
| 639 |
+
and isinstance(value.data.data, ir.IRNode)
|
| 640 |
+
)
|
| 641 |
+
):
|
| 642 |
+
return
|
| 643 |
+
|
| 644 |
+
for read_name in value.get_read_names():
|
| 645 |
+
self.name_to_users[read_name].append(value)
|
| 646 |
+
|
| 647 |
+
register(node_output)
|
| 648 |
+
|
| 649 |
+
def mark_buffer_mutated(self, name: str):
|
| 650 |
+
"""
|
| 651 |
+
When a buffer is mutated we need to make sure all the reads to
|
| 652 |
+
the old version are realized before the mutation happens.
|
| 653 |
+
"""
|
| 654 |
+
assert isinstance(name, str)
|
| 655 |
+
self.mutated_buffers.add(name)
|
| 656 |
+
|
| 657 |
+
if name not in self.name_to_users:
|
| 658 |
+
return
|
| 659 |
+
|
| 660 |
+
for user in self.name_to_users[name]:
|
| 661 |
+
user.realize()
|
| 662 |
+
|
| 663 |
+
def add_tensor_constant(self, data, name=None):
|
| 664 |
+
def allocate(name):
|
| 665 |
+
if not config.aot_inductor.use_runtime_constant_folding:
|
| 666 |
+
for constant_name, value in self.constants.items():
|
| 667 |
+
if (
|
| 668 |
+
not data.is_mkldnn
|
| 669 |
+
and data.size() == value.size()
|
| 670 |
+
and data.stride() == value.stride()
|
| 671 |
+
and data.dtype == value.dtype
|
| 672 |
+
and data.device == value.device
|
| 673 |
+
and torch.eq(data, value).all()
|
| 674 |
+
):
|
| 675 |
+
return constant_name
|
| 676 |
+
|
| 677 |
+
if name is None:
|
| 678 |
+
name = f"constant{len(self.constants)}"
|
| 679 |
+
if name[0].isdigit():
|
| 680 |
+
name = f"constant_{name}"
|
| 681 |
+
name = self.qualify_name(name)
|
| 682 |
+
# We may generate a var name for each constant in the codegen.
|
| 683 |
+
# Let's only keep sane characters.
|
| 684 |
+
prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
| 685 |
+
name = prefix
|
| 686 |
+
cnt = 0
|
| 687 |
+
while name in self.constants:
|
| 688 |
+
name = f"{prefix}_{cnt}"
|
| 689 |
+
cnt += 1
|
| 690 |
+
self.constants[name] = data
|
| 691 |
+
self.constant_reprs[name] = (
|
| 692 |
+
f"{data.device!r} {data.dtype!r} "
|
| 693 |
+
f"{tuple(data.size())!r} {tuple(data.stride())!r} "
|
| 694 |
+
f"{hash(data):x}"
|
| 695 |
+
)
|
| 696 |
+
return name
|
| 697 |
+
|
| 698 |
+
new_name = allocate(name)
|
| 699 |
+
self.allocated_constant_name[new_name] = name
|
| 700 |
+
|
| 701 |
+
return TensorBox.create(
|
| 702 |
+
ir.ConstantBuffer(
|
| 703 |
+
new_name,
|
| 704 |
+
FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
|
| 705 |
+
)
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
def constant_name(self, name: str, device_override: Optional[torch.device]):
|
| 709 |
+
"""
|
| 710 |
+
We AOT copy constants to the devices they are needed on.
|
| 711 |
+
If device_override doesn't match the constant's device, then
|
| 712 |
+
copy it and return a different name.
|
| 713 |
+
"""
|
| 714 |
+
if self.constants[name].device == device_override or device_override is None:
|
| 715 |
+
return name
|
| 716 |
+
alt_name = f"{name}_{device_override.type}{device_override.index or 0}"
|
| 717 |
+
if alt_name not in self.constants:
|
| 718 |
+
self.constants[alt_name] = self.constants[name].to(device_override)
|
| 719 |
+
return alt_name
|
| 720 |
+
|
| 721 |
+
def placeholder(self, target: str, args, kwargs):
|
| 722 |
+
example = super().placeholder(target, args, kwargs)
|
| 723 |
+
self.graph_input_names.append(target)
|
| 724 |
+
if isinstance(example, SymTypes):
|
| 725 |
+
expr = example.node.expr
|
| 726 |
+
self.graph_inputs[target] = expr
|
| 727 |
+
return expr
|
| 728 |
+
elif isinstance(example, (int, bool, float)):
|
| 729 |
+
expr = sympy.sympify(example)
|
| 730 |
+
self.graph_inputs[target] = expr
|
| 731 |
+
return expr
|
| 732 |
+
if isinstance(example, BackwardState):
|
| 733 |
+
# Ignored arg, must be unused
|
| 734 |
+
# Alternately we could filter this out in AotAutograd
|
| 735 |
+
return None
|
| 736 |
+
assert isinstance(example, torch.Tensor), example
|
| 737 |
+
# todo(chilli): We can remove the last check once we turn buffers into
|
| 738 |
+
# static shape tensors. That's a hack to workaround Inductor believing
|
| 739 |
+
# the buffer should be static but us passing in a fake tensor with
|
| 740 |
+
# symbolic shapes.
|
| 741 |
+
if not example._has_symbolic_sizes_strides:
|
| 742 |
+
# the first N inputs are weights
|
| 743 |
+
sizes, strides = self.static_sizes_strides(example)
|
| 744 |
+
else:
|
| 745 |
+
sizes, strides = self.symbolic_sizes_strides(example)
|
| 746 |
+
# TODO(jansel): handle input aliasing
|
| 747 |
+
target = self.qualify_name(target)
|
| 748 |
+
tensor = TensorBox.create(
|
| 749 |
+
InputBuffer(
|
| 750 |
+
target,
|
| 751 |
+
FixedLayout(example.device, example.dtype, sizes, strides),
|
| 752 |
+
)
|
| 753 |
+
)
|
| 754 |
+
self.graph_inputs[target] = tensor
|
| 755 |
+
self.graph_inputs_original[target] = tensor.data.data
|
| 756 |
+
self.add_device_info(example.device)
|
| 757 |
+
return tensor
|
| 758 |
+
|
| 759 |
+
def call_function(self, target, args, kwargs):
|
| 760 |
+
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
|
| 761 |
+
return super().call_function(target, args, kwargs)
|
| 762 |
+
|
| 763 |
+
if hasattr(target, "_inductor_lowering_function"):
|
| 764 |
+
# passthrough lowerings from .pattern_matcher
|
| 765 |
+
return target(*args, **kwargs)
|
| 766 |
+
|
| 767 |
+
def get_custom_op_layout_constraints(target, args, kwargs):
|
| 768 |
+
# Custom operations that require preserving stride order
|
| 769 |
+
# which run through implicit fallback must constrain their
|
| 770 |
+
# arguments' fx strides
|
| 771 |
+
layout_constraint = None
|
| 772 |
+
if torch._C.Tag.needs_fixed_stride_order in target.tags:
|
| 773 |
+
# We have to set the current args because call_function will immediately
|
| 774 |
+
# evaluate this lowering after creating the fallback, without evaluating
|
| 775 |
+
# the layout constraint
|
| 776 |
+
args, kwargs = constrain_to_fx_strides(
|
| 777 |
+
self.current_node, *args, **kwargs
|
| 778 |
+
)
|
| 779 |
+
# Also register the layout constraint so when the fallback
|
| 780 |
+
# is used again, we can constrain the args to the same layout
|
| 781 |
+
layout_constraint = constrain_to_fx_strides
|
| 782 |
+
return layout_constraint, args, kwargs
|
| 783 |
+
|
| 784 |
+
if target not in lowerings:
|
| 785 |
+
assert isinstance(
|
| 786 |
+
target, torch._ops.OpOverload
|
| 787 |
+
), f"{target} is not an OpOverload"
|
| 788 |
+
base_name = target.name().split(".")[0]
|
| 789 |
+
if base_name in FALLBACK_ALLOW_LIST:
|
| 790 |
+
make_fallback(target)
|
| 791 |
+
elif config.implicit_fallbacks:
|
| 792 |
+
layout_constraint, args, kwargs = get_custom_op_layout_constraints(
|
| 793 |
+
target, args, kwargs
|
| 794 |
+
)
|
| 795 |
+
error = (
|
| 796 |
+
MissingOperatorWithDecomp
|
| 797 |
+
if get_decompositions([target])
|
| 798 |
+
else MissingOperatorWithoutDecomp
|
| 799 |
+
)
|
| 800 |
+
log.info(
|
| 801 |
+
"Creating implicit fallback for:\n%s",
|
| 802 |
+
error.operator_str(target, args, kwargs),
|
| 803 |
+
)
|
| 804 |
+
make_fallback(target, layout_constraint)
|
| 805 |
+
|
| 806 |
+
elif get_decompositions([target]):
|
| 807 |
+
# There isn't a good way to dynamically patch this in
|
| 808 |
+
# since AOT Autograd already ran. The error message tells
|
| 809 |
+
# the user how to fix it.
|
| 810 |
+
raise MissingOperatorWithDecomp(target, args, kwargs)
|
| 811 |
+
else:
|
| 812 |
+
raise MissingOperatorWithoutDecomp(target, args, kwargs)
|
| 813 |
+
|
| 814 |
+
try:
|
| 815 |
+
log.debug(" via %s", lowerings[target])
|
| 816 |
+
out = lowerings[target](*args, **kwargs)
|
| 817 |
+
return out
|
| 818 |
+
except Exception as e:
|
| 819 |
+
raise LoweringException(e, target, args, kwargs).with_traceback(
|
| 820 |
+
e.__traceback__
|
| 821 |
+
) from None
|
| 822 |
+
|
| 823 |
+
@staticmethod
|
| 824 |
+
def can_inline_constant(t: torch.Tensor) -> bool:
|
| 825 |
+
"""
|
| 826 |
+
True if this is a small constant attr that will be inlined.
|
| 827 |
+
"""
|
| 828 |
+
return len(t.shape) == 1 and t.shape[0] <= 8
|
| 829 |
+
|
| 830 |
+
def get_attr(self, target, args, kwargs):
|
| 831 |
+
# this is a constant
|
| 832 |
+
value = getattr_recursive(self.module, target)
|
| 833 |
+
|
| 834 |
+
if isinstance(value, torch.fx.GraphModule):
|
| 835 |
+
return ir.Subgraph(name=target, graph_module=value)
|
| 836 |
+
|
| 837 |
+
if (
|
| 838 |
+
config.aot_inductor.use_runtime_constant_folding
|
| 839 |
+
or config.always_keep_tensor_constants
|
| 840 |
+
or unsupported_output_tensor(value)
|
| 841 |
+
):
|
| 842 |
+
return self.add_tensor_constant(value, target)
|
| 843 |
+
|
| 844 |
+
with no_dispatch():
|
| 845 |
+
if value.shape == ():
|
| 846 |
+
return Constant(value.item(), value.dtype, value.device)
|
| 847 |
+
if self.can_inline_constant(value):
|
| 848 |
+
# tensor lowering has constant inlining logic
|
| 849 |
+
from .lowering import tensor
|
| 850 |
+
|
| 851 |
+
return tensor(value.tolist(), dtype=value.dtype, device=value.device)
|
| 852 |
+
|
| 853 |
+
return self.add_tensor_constant(value, target)
|
| 854 |
+
|
| 855 |
+
def call_module(self, target, args, kwargs):
|
| 856 |
+
raise AssertionError()
|
| 857 |
+
|
| 858 |
+
def call_method(self, target, args, kwargs):
|
| 859 |
+
raise AssertionError()
|
| 860 |
+
|
| 861 |
+
def output(self, target, args, kwargs):
|
| 862 |
+
result = super().output(target, args, kwargs)
|
| 863 |
+
assert isinstance(result, (tuple, list)), type(result)
|
| 864 |
+
assert all(
|
| 865 |
+
isinstance(
|
| 866 |
+
x,
|
| 867 |
+
(
|
| 868 |
+
TensorBox,
|
| 869 |
+
ir.Constant,
|
| 870 |
+
type(None),
|
| 871 |
+
ir.ConstantBuffer,
|
| 872 |
+
sympy.Expr,
|
| 873 |
+
sympy.logic.boolalg.Boolean,
|
| 874 |
+
int,
|
| 875 |
+
),
|
| 876 |
+
)
|
| 877 |
+
for x in result
|
| 878 |
+
), result
|
| 879 |
+
self.graph_outputs = [ir.ExternKernel.realize_input(x) for x in result]
|
| 880 |
+
value: ir.IRNode
|
| 881 |
+
for name, value in self.graph_inputs.items():
|
| 882 |
+
assert isinstance(
|
| 883 |
+
value, (TensorBox, sympy.Expr)
|
| 884 |
+
), f"Unsupported inductor graph input type: {type(value)}"
|
| 885 |
+
if not isinstance(value, TensorBox):
|
| 886 |
+
continue
|
| 887 |
+
value.realize()
|
| 888 |
+
assert isinstance(value, TensorBox)
|
| 889 |
+
value = value.data
|
| 890 |
+
assert isinstance(value, ir.StorageBox)
|
| 891 |
+
value_storage_box = value
|
| 892 |
+
value = value.data
|
| 893 |
+
if not isinstance(value, InputBuffer) or value.get_name() != name:
|
| 894 |
+
# one of our inputs was mutated, need to turn that into a copy
|
| 895 |
+
ir.MutationLayout.realize_into(value, self.graph_inputs_original[name])
|
| 896 |
+
# replace output with mutated input
|
| 897 |
+
try:
|
| 898 |
+
ind = self.graph_outputs.index(value_storage_box)
|
| 899 |
+
self.graph_outputs[ind] = self.graph_inputs_original[name]
|
| 900 |
+
except ValueError:
|
| 901 |
+
pass
|
| 902 |
+
|
| 903 |
+
self.finalize()
|
| 904 |
+
log.debug(
|
| 905 |
+
"Force channels last inputs for %d conv for the current graph with id %d",
|
| 906 |
+
self.num_channels_last_conv,
|
| 907 |
+
self.graph_id if self.graph_id is not None else -1,
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
def finalize(self):
|
| 911 |
+
for buf in self.buffers:
|
| 912 |
+
buf.decide_layout()
|
| 913 |
+
|
| 914 |
+
@contextmanager
|
| 915 |
+
def set_current_node(self, node: torch.fx.Node):
|
| 916 |
+
old = self.current_node
|
| 917 |
+
try:
|
| 918 |
+
self.current_node = node
|
| 919 |
+
yield
|
| 920 |
+
finally:
|
| 921 |
+
self.current_node = old
|
| 922 |
+
|
| 923 |
+
def run_node(self, n: torch.fx.Node):
|
| 924 |
+
def debug(msg):
|
| 925 |
+
log.debug("lowering %s %s", LazyString(n.format_node), msg)
|
| 926 |
+
|
| 927 |
+
origins = {n}
|
| 928 |
+
if n.op == "call_function":
|
| 929 |
+
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
| 930 |
+
origins |= gather_origins(args, kwargs)
|
| 931 |
+
with ir.IRNode.current_origins(origins), self.set_current_node(
|
| 932 |
+
n
|
| 933 |
+
), V.set_current_node(n):
|
| 934 |
+
if (
|
| 935 |
+
n.op == "call_function"
|
| 936 |
+
and n.target is not operator.getitem
|
| 937 |
+
and fallback_node_due_to_unsupported_type(n)
|
| 938 |
+
):
|
| 939 |
+
debug("fallback_handler")
|
| 940 |
+
result = fallback_handler(n.target, add_to_fallback_set=False)(
|
| 941 |
+
*args, **kwargs # type: ignore[possibly-undefined]
|
| 942 |
+
)
|
| 943 |
+
elif n.op == "call_function" and n.target in layout_constraints:
|
| 944 |
+
debug("layout_constraints")
|
| 945 |
+
args, kwargs = layout_constraints[n.target](n, *args, **kwargs) # type: ignore[index]
|
| 946 |
+
result = self.call_function(n.target, args, kwargs)
|
| 947 |
+
elif is_magic_method(n.target):
|
| 948 |
+
# TODO: this is sus, it probably should be handled in the
|
| 949 |
+
# lowerings themselves similarly to sym_size/sym-stride
|
| 950 |
+
debug("is_magic_method")
|
| 951 |
+
if isinstance(n.meta["val"], torch.SymInt):
|
| 952 |
+
result = n.meta["val"].node.expr
|
| 953 |
+
else:
|
| 954 |
+
result = super().run_node(n)
|
| 955 |
+
else:
|
| 956 |
+
debug("")
|
| 957 |
+
result = super().run_node(n)
|
| 958 |
+
|
| 959 |
+
# require the same stride order for dense outputs,
|
| 960 |
+
# 1. user-land view() will not throw because inductor
|
| 961 |
+
# output different strides than eager
|
| 962 |
+
# long term the solution is to make view() always succeed
|
| 963 |
+
# with infallible strides.
|
| 964 |
+
# 2: as_strided ops, we need make sure its input has same size/stride with
|
| 965 |
+
# eager model to align with eager behavior.
|
| 966 |
+
as_strided_ops = [
|
| 967 |
+
torch.ops.aten.as_strided.default,
|
| 968 |
+
torch.ops.aten.as_strided_.default,
|
| 969 |
+
torch.ops.aten.as_strided_scatter.default,
|
| 970 |
+
]
|
| 971 |
+
is_output = any(user.op == "output" for user in n.users)
|
| 972 |
+
is_input_for_as_strided = any(
|
| 973 |
+
user.target in as_strided_ops for user in n.users
|
| 974 |
+
)
|
| 975 |
+
if (
|
| 976 |
+
is_output
|
| 977 |
+
and isinstance(result, TensorBox)
|
| 978 |
+
and isinstance(result.data, ir.BaseView)
|
| 979 |
+
):
|
| 980 |
+
# Realize so that outputs are correctly aliased
|
| 981 |
+
result.realize()
|
| 982 |
+
|
| 983 |
+
if (is_output or is_input_for_as_strided) and isinstance(
|
| 984 |
+
n.meta["val"], torch.Tensor
|
| 985 |
+
):
|
| 986 |
+
strides = n.meta["val"].stride()
|
| 987 |
+
dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
|
| 988 |
+
# requiring a stride order for a non-dense output wouldn't
|
| 989 |
+
# recreate the same strides, and would fail with view, defer for now.
|
| 990 |
+
if dense and len(strides):
|
| 991 |
+
stride_order = ir.get_stride_order(strides)
|
| 992 |
+
if (
|
| 993 |
+
len(result.get_size()) == 4
|
| 994 |
+
and n in self.nodes_prefer_channels_last
|
| 995 |
+
and n.name not in self.user_visible_outputs
|
| 996 |
+
and not is_input_for_as_strided
|
| 997 |
+
):
|
| 998 |
+
stride_order = ir.NHWC_STRIDE_ORDER
|
| 999 |
+
result = ir.ExternKernel.require_stride_order(result, stride_order)
|
| 1000 |
+
|
| 1001 |
+
# Realize if (1) any user need inputs realized, or (2) there is
|
| 1002 |
+
# already too many reads and rematerializing can be bad.
|
| 1003 |
+
num_users = len(set(n.users))
|
| 1004 |
+
if num_users > 1 and isinstance(result, TensorBox):
|
| 1005 |
+
for user in n.users:
|
| 1006 |
+
if user.target in needs_realized_inputs:
|
| 1007 |
+
result.realize_hint()
|
| 1008 |
+
# This inclusion is somewhat controversial (from
|
| 1009 |
+
# discussion between Horace, Natalia, and Elias).
|
| 1010 |
+
# Currently, it's not very clear why this is helpful.
|
| 1011 |
+
# The general idea here is that even though a node may
|
| 1012 |
+
# have FlexibleLayout, we still often *treat* it as if
|
| 1013 |
+
# it was contiguous. This appears to sometimes result in
|
| 1014 |
+
# suboptimal behavior.
|
| 1015 |
+
#
|
| 1016 |
+
# When we do a better job selecting layout, we should
|
| 1017 |
+
# revisit this.
|
| 1018 |
+
need_fixed_layout = [
|
| 1019 |
+
torch.ops.aten.convolution_backward.default,
|
| 1020 |
+
torch.ops.aten.mm.default,
|
| 1021 |
+
torch.ops.aten._int_mm.default,
|
| 1022 |
+
]
|
| 1023 |
+
if not self.layout_opt:
|
| 1024 |
+
need_fixed_layout.append(torch.ops.aten.convolution.default)
|
| 1025 |
+
if torch._C._has_mkldnn:
|
| 1026 |
+
need_fixed_layout += [
|
| 1027 |
+
torch.ops.mkldnn._convolution_pointwise.default,
|
| 1028 |
+
torch.ops.mkldnn._convolution_pointwise.binary,
|
| 1029 |
+
torch.ops.mkldnn._convolution_pointwise_.binary,
|
| 1030 |
+
torch.ops.mkldnn._convolution_transpose_pointwise.default,
|
| 1031 |
+
torch.ops.mkldnn._linear_pointwise.default,
|
| 1032 |
+
torch.ops.mkldnn._linear_pointwise.binary,
|
| 1033 |
+
torch.ops.aten.mkldnn_rnn_layer.default,
|
| 1034 |
+
torch.ops.onednn.qconv2d_pointwise.default,
|
| 1035 |
+
torch.ops.onednn.qconv2d_pointwise.binary,
|
| 1036 |
+
torch.ops.onednn.qlinear_pointwise.default,
|
| 1037 |
+
torch.ops.onednn.qlinear_pointwise.tensor,
|
| 1038 |
+
]
|
| 1039 |
+
if torch._C.has_mkl:
|
| 1040 |
+
need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
|
| 1041 |
+
if user.target in need_fixed_layout:
|
| 1042 |
+
result = ir.ExternKernel.require_stride_order(
|
| 1043 |
+
result, ir.get_stride_order(n.meta["val"].stride())
|
| 1044 |
+
)
|
| 1045 |
+
if user.op == "output":
|
| 1046 |
+
if isinstance(result.data.data, (Pointwise, Reduction)):
|
| 1047 |
+
result.realize()
|
| 1048 |
+
|
| 1049 |
+
# TODO(jansel): introduce a store vs inline choice
|
| 1050 |
+
result.mark_reuse(len(n.users))
|
| 1051 |
+
|
| 1052 |
+
# Realize if the IRNode already has accumulated lots of reads
|
| 1053 |
+
if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
|
| 1054 |
+
# Prevent excessive accumulation in a computed buffer, when
|
| 1055 |
+
# there are multiple branches each with small number of memory
|
| 1056 |
+
# reads, but they converge to a user.
|
| 1057 |
+
result.realize_hint()
|
| 1058 |
+
|
| 1059 |
+
# Realize if a Pointwise has too much stuff to be inlined.
|
| 1060 |
+
# As this may cause RecursionError during Inductor's evaluation.
|
| 1061 |
+
if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
|
| 1062 |
+
curr = result.data.data
|
| 1063 |
+
if isinstance(curr, Pointwise):
|
| 1064 |
+
# Use inner fn as a rough proxy. Good enough.
|
| 1065 |
+
if curr.has_large_inner_fn():
|
| 1066 |
+
result.realize()
|
| 1067 |
+
|
| 1068 |
+
# This is not complete, but it doesn't have to be: origin_node
|
| 1069 |
+
# tracking is best effort. The logic here critically relies on direct
|
| 1070 |
+
# TensorBox -> StorageBox denoting a non-view; we don't bother trying
|
| 1071 |
+
# to get views to work. Feel free to add any extra cases as needed.
|
| 1072 |
+
#
|
| 1073 |
+
# Note: we can't YOLO tree_map over this result, because if there are
|
| 1074 |
+
# buffers or a view involved, we might not be able to validly assign
|
| 1075 |
+
# the origin_node here.
|
| 1076 |
+
if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
|
| 1077 |
+
if isinstance(result.data.data, ir.Loops):
|
| 1078 |
+
result.data.data.origin_node = n
|
| 1079 |
+
elif isinstance(result.data.data, ir.Buffer):
|
| 1080 |
+
result.data.data.origin_node = n
|
| 1081 |
+
if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
|
| 1082 |
+
result.data.data.data, ir.Loops
|
| 1083 |
+
):
|
| 1084 |
+
result.data.data.data.origin_node = n
|
| 1085 |
+
# Not really multi-output, can straightforwardly recurse in
|
| 1086 |
+
elif (
|
| 1087 |
+
isinstance(result.data.data, ir.MultiOutput)
|
| 1088 |
+
and not result.data.data.indices
|
| 1089 |
+
):
|
| 1090 |
+
if isinstance(result.data.data.inputs[0], ir.Buffer):
|
| 1091 |
+
result.data.data.inputs[0].origin_node = n
|
| 1092 |
+
|
| 1093 |
+
self.register_users_of(result)
|
| 1094 |
+
|
| 1095 |
+
return result
|
| 1096 |
+
|
| 1097 |
+
def validate_can_generate_cpp_wrapper(self):
|
| 1098 |
+
if config.disable_cpp_codegen:
|
| 1099 |
+
raise CppWrapperCodeGenError("C++ codegen is disabled")
|
| 1100 |
+
|
| 1101 |
+
if sys.platform not in ["linux", "darwin"]:
|
| 1102 |
+
raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}")
|
| 1103 |
+
|
| 1104 |
+
for value in self.graph_inputs.values():
|
| 1105 |
+
dtype = None
|
| 1106 |
+
if isinstance(value, TensorBox):
|
| 1107 |
+
dtype = value.get_dtype()
|
| 1108 |
+
elif isinstance(
|
| 1109 |
+
value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
|
| 1110 |
+
):
|
| 1111 |
+
dtype = may_get_constant_buffer_dtype(value)
|
| 1112 |
+
|
| 1113 |
+
if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
|
| 1114 |
+
raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")
|
| 1115 |
+
|
| 1116 |
+
def init_wrapper_code(self):
|
| 1117 |
+
self.cuda = "cuda" in self.device_types
|
| 1118 |
+
if self.cpp_wrapper:
|
| 1119 |
+
self.validate_can_generate_cpp_wrapper()
|
| 1120 |
+
self.wrapper_code = CppWrapperCuda() if self.cuda else CppWrapperCpu()
|
| 1121 |
+
else:
|
| 1122 |
+
device_types = self.device_types.copy()
|
| 1123 |
+
device_types.discard("cpu")
|
| 1124 |
+
# TODO(Eikan): Only support mixing cpu and other device now.
|
| 1125 |
+
assert len(device_types) <= 1, "Does not support mixing {}".format(
|
| 1126 |
+
"+".join(device_types)
|
| 1127 |
+
)
|
| 1128 |
+
only_cpu = len(device_types) == 0
|
| 1129 |
+
device_type = "cpu" if only_cpu else device_types.pop()
|
| 1130 |
+
|
| 1131 |
+
self.device_ops = get_device_op_overrides(device_type)
|
| 1132 |
+
wrapper_code_gen_cls = get_wrapper_codegen_for_device(device_type)
|
| 1133 |
+
assert (
|
| 1134 |
+
wrapper_code_gen_cls is not None
|
| 1135 |
+
), f"Device {device_type} not supported"
|
| 1136 |
+
self.wrapper_code = wrapper_code_gen_cls()
|
| 1137 |
+
|
| 1138 |
+
if self.const_module:
|
| 1139 |
+
# If we have const module, we could reuse the kernels
|
| 1140 |
+
# This could avoid duplication and save time on doing recompilation (if Triton.)
|
| 1141 |
+
self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
|
| 1142 |
+
self.wrapper_code.src_to_kernel = (
|
| 1143 |
+
self.const_module.wrapper_code.src_to_kernel
|
| 1144 |
+
)
|
| 1145 |
+
|
| 1146 |
+
def codegen_with_cpp_wrapper(self):
|
| 1147 |
+
"""
|
| 1148 |
+
For CPU, the cpp wrapper codegen is done in one pass.
|
| 1149 |
+
For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
|
| 1150 |
+
wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
|
| 1151 |
+
generate cpp wrapper code and compile it to a dynamic library in the second pass.
|
| 1152 |
+
"""
|
| 1153 |
+
if "cuda" in self.device_types:
|
| 1154 |
+
# first pass
|
| 1155 |
+
self.cpp_wrapper = False
|
| 1156 |
+
compiled = self.compile_to_module().call
|
| 1157 |
+
|
| 1158 |
+
def materialize(x):
|
| 1159 |
+
if isinstance(x, (torch.SymInt, torch.SymFloat)):
|
| 1160 |
+
# Need concrete value to run dynamic shapes and tune the result
|
| 1161 |
+
return x.node.hint
|
| 1162 |
+
elif isinstance(x, FakeTensor):
|
| 1163 |
+
return defake(x)
|
| 1164 |
+
else:
|
| 1165 |
+
assert isinstance(
|
| 1166 |
+
x, torch.Tensor
|
| 1167 |
+
), "Unknown type when creating real inputs" + str(type(x))
|
| 1168 |
+
return x
|
| 1169 |
+
|
| 1170 |
+
if tracing_context := torch._guards.TracingContext.try_get():
|
| 1171 |
+
if tracing_context.output_strides:
|
| 1172 |
+
tracing_context.output_strides.clear()
|
| 1173 |
+
|
| 1174 |
+
params_flat = [
|
| 1175 |
+
param
|
| 1176 |
+
for param in tracing_context.params_flat # type: ignore[union-attr]
|
| 1177 |
+
if param is not None
|
| 1178 |
+
]
|
| 1179 |
+
real_inputs = [
|
| 1180 |
+
materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
|
| 1181 |
+
]
|
| 1182 |
+
else:
|
| 1183 |
+
real_inputs = [materialize(x) for x in V.real_inputs]
|
| 1184 |
+
|
| 1185 |
+
with torch.utils._python_dispatch._disable_current_modes():
|
| 1186 |
+
assert self.example_inputs is not None
|
| 1187 |
+
compiled(real_inputs)
|
| 1188 |
+
del real_inputs
|
| 1189 |
+
|
| 1190 |
+
# second pass
|
| 1191 |
+
# TODO: reuse self.scheduler from the first pass to speed up the second pass
|
| 1192 |
+
self.cpp_wrapper = True
|
| 1193 |
+
self.removed_buffers.clear()
|
| 1194 |
+
self.inplaced_to_remove.clear()
|
| 1195 |
+
return self.codegen()
|
| 1196 |
+
else:
|
| 1197 |
+
# cpu
|
| 1198 |
+
return self.codegen()
|
| 1199 |
+
|
| 1200 |
+
def codegen(self):
|
| 1201 |
+
from .scheduler import Scheduler
|
| 1202 |
+
|
| 1203 |
+
self.init_wrapper_code()
|
| 1204 |
+
|
| 1205 |
+
self.scheduler = Scheduler(self.buffers)
|
| 1206 |
+
V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
|
| 1207 |
+
|
| 1208 |
+
self.scheduler.codegen()
|
| 1209 |
+
return self.wrapper_code.generate(self.is_inference)
|
| 1210 |
+
|
| 1211 |
+
def codegen_subgraph(self, parent_graph):
|
| 1212 |
+
"""
|
| 1213 |
+
This is a more compact version of the `codegen()` above
|
| 1214 |
+
where we codegen this graph as a subgraph of some parent
|
| 1215 |
+
graph. The parent graph is passed as an argument: the
|
| 1216 |
+
intention is to inline codegening of the subgraph in
|
| 1217 |
+
the parent graph's wrapper code (including the generated
|
| 1218 |
+
kerenls). The wrapper code is not finalized (via `.generate()`
|
| 1219 |
+
call), as this will be done in the parent graph's `codegen()`.
|
| 1220 |
+
"""
|
| 1221 |
+
from .scheduler import Scheduler
|
| 1222 |
+
|
| 1223 |
+
self.wrapper_code = parent_graph.wrapper_code
|
| 1224 |
+
self.device_ops = parent_graph.device_ops
|
| 1225 |
+
self.cpp_wrapper = parent_graph.cpp_wrapper
|
| 1226 |
+
|
| 1227 |
+
self.scheduler = Scheduler(self.buffers)
|
| 1228 |
+
self.scheduler.codegen()
|
| 1229 |
+
|
| 1230 |
+
def count_bytes(self):
|
| 1231 |
+
from .scheduler import Scheduler
|
| 1232 |
+
|
| 1233 |
+
scheduler = Scheduler(self.buffers)
|
| 1234 |
+
|
| 1235 |
+
total_bytes = 0
|
| 1236 |
+
node_counts = []
|
| 1237 |
+
node_runtimes = []
|
| 1238 |
+
for node in scheduler.nodes:
|
| 1239 |
+
num_bytes = node.get_read_write_buffers_sizes()
|
| 1240 |
+
total_bytes += num_bytes
|
| 1241 |
+
node_counts.append((node, num_bytes // 4))
|
| 1242 |
+
node_runtimes.append((node, node.get_estimated_runtime()))
|
| 1243 |
+
return total_bytes, node_counts, node_runtimes
|
| 1244 |
+
|
| 1245 |
+
@dynamo_timed(phase_name="code_gen")
|
| 1246 |
+
def compile_to_module(self):
|
| 1247 |
+
from .codecache import PyCodeCache
|
| 1248 |
+
|
| 1249 |
+
code, linemap = (
|
| 1250 |
+
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
|
| 1251 |
+
)
|
| 1252 |
+
linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
|
| 1253 |
+
key, path = PyCodeCache.write(code)
|
| 1254 |
+
mod = PyCodeCache.load_by_key_path(
|
| 1255 |
+
key, path, linemap=linemap, attrs=self.constants
|
| 1256 |
+
)
|
| 1257 |
+
self.cache_key = key
|
| 1258 |
+
self.cache_path = path
|
| 1259 |
+
self.cache_linemap = linemap
|
| 1260 |
+
|
| 1261 |
+
# Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
|
| 1262 |
+
# TODO. Revisit this once the logging API is more mature
|
| 1263 |
+
assert mod.__file__ is not None
|
| 1264 |
+
|
| 1265 |
+
log_module_code(mod.__file__)
|
| 1266 |
+
log.debug("Output code written to: %s", mod.__file__)
|
| 1267 |
+
output_code_log.debug("Output code: \n%s", code)
|
| 1268 |
+
trace_structured(
|
| 1269 |
+
"inductor_output_code",
|
| 1270 |
+
lambda: {"filename": mod.__file__},
|
| 1271 |
+
payload_fn=lambda: code,
|
| 1272 |
+
)
|
| 1273 |
+
output_code_log.info("Output code written to: %s", mod.__file__)
|
| 1274 |
+
if config.benchmark_kernel:
|
| 1275 |
+
print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
|
| 1276 |
+
V.debug.output_code(mod.__file__)
|
| 1277 |
+
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
|
| 1278 |
+
return mod
|
| 1279 |
+
|
| 1280 |
+
def compile_to_fn(self):
|
| 1281 |
+
if self.aot_mode:
|
| 1282 |
+
from .codecache import AotCodeCompiler
|
| 1283 |
+
|
| 1284 |
+
assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
|
| 1285 |
+
code, linemap = self.codegen_with_cpp_wrapper()
|
| 1286 |
+
output_code_log.debug("Output code: \n%s", code)
|
| 1287 |
+
|
| 1288 |
+
serialized_extern_kernel_nodes = None
|
| 1289 |
+
if (
|
| 1290 |
+
config.is_fbcode()
|
| 1291 |
+
and self.extern_kernel_nodes
|
| 1292 |
+
and self.extern_node_serializer
|
| 1293 |
+
):
|
| 1294 |
+
serialized_extern_kernel_nodes = self.extern_node_serializer(
|
| 1295 |
+
self.extern_kernel_nodes
|
| 1296 |
+
)
|
| 1297 |
+
output_code_log.debug(
|
| 1298 |
+
"Serialized Extern Kernel Nodes: \n%s",
|
| 1299 |
+
serialized_extern_kernel_nodes,
|
| 1300 |
+
)
|
| 1301 |
+
|
| 1302 |
+
# Directly return the file path with the compiled code
|
| 1303 |
+
return AotCodeCompiler.compile(
|
| 1304 |
+
self, code, serialized_extern_kernel_nodes, cuda=self.cuda
|
| 1305 |
+
)
|
| 1306 |
+
else:
|
| 1307 |
+
return self.compile_to_module().call
|
| 1308 |
+
|
| 1309 |
+
def get_output_names(self):
|
| 1310 |
+
return [
|
| 1311 |
+
node.get_name()
|
| 1312 |
+
for node in self.graph_outputs
|
| 1313 |
+
if not isinstance(node, ir.NoneAsConstantBuffer)
|
| 1314 |
+
and not isinstance(node, ir.ShapeAsConstantBuffer)
|
| 1315 |
+
]
|
| 1316 |
+
|
| 1317 |
+
def is_unspec_arg(self, name: str):
|
| 1318 |
+
# dynamo wraps unspec variable as 0d CPU tensor,
|
| 1319 |
+
# need to convert to scalar during codegen (triton only)
|
| 1320 |
+
return (
|
| 1321 |
+
name in self.graph_inputs.keys()
|
| 1322 |
+
and self.graph_inputs[name].get_numel() == 1
|
| 1323 |
+
and self.graph_inputs[name].get_device().type == "cpu"
|
| 1324 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py
ADDED
|
@@ -0,0 +1,1524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import functools
|
| 5 |
+
import inspect
|
| 6 |
+
import itertools
|
| 7 |
+
import logging
|
| 8 |
+
import operator
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
from typing import (
|
| 13 |
+
Any,
|
| 14 |
+
Callable,
|
| 15 |
+
DefaultDict,
|
| 16 |
+
Dict,
|
| 17 |
+
Iterable,
|
| 18 |
+
List,
|
| 19 |
+
NoReturn,
|
| 20 |
+
Optional,
|
| 21 |
+
Set,
|
| 22 |
+
Union,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from typing_extensions import TypeGuard
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch._guards
|
| 29 |
+
import torch.fx
|
| 30 |
+
import torch.utils._pytree as pytree
|
| 31 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 32 |
+
from torch._dynamo.utils import counters
|
| 33 |
+
from torch._prims_common import is_integer_dtype
|
| 34 |
+
from torch.fx import Node
|
| 35 |
+
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
|
| 36 |
+
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
| 37 |
+
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
| 38 |
+
|
| 39 |
+
from .._functorch import config as functorch_config
|
| 40 |
+
from .._functorch.aot_autograd import aot_function, make_boxed_func
|
| 41 |
+
from .._functorch.partitioners import default_partition
|
| 42 |
+
from .._subclasses import FakeTensorMode
|
| 43 |
+
from ..fx import Transformer
|
| 44 |
+
from . import config
|
| 45 |
+
from .decomposition import select_decomp_table
|
| 46 |
+
from .lowering import fallback_node_due_to_unsupported_type
|
| 47 |
+
|
| 48 |
+
log = logging.getLogger(__name__)
|
| 49 |
+
aten = torch.ops.aten
|
| 50 |
+
prims = torch.ops.prims
|
| 51 |
+
|
| 52 |
+
Constant = Any
|
| 53 |
+
NodeOrConstant = Union[Constant, torch.fx.Node]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Multiple:
|
| 57 |
+
pass
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Sentinel indicating multiple quantities can be matched
|
| 61 |
+
MULTIPLE = Multiple()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Match:
|
| 65 |
+
"""
|
| 66 |
+
Represents a successfully matched pattern.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def __init__(self, pattern: PatternExpr, args=None, kwargs=None):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.pattern = pattern
|
| 72 |
+
# The input nodes that must be passed in to the result
|
| 73 |
+
self.args = args or []
|
| 74 |
+
self.kwargs = kwargs or {}
|
| 75 |
+
# The nodes matched in this expression
|
| 76 |
+
self.nodes: List[torch.fx.Node] = []
|
| 77 |
+
# Mapping CallFunction to the node.target
|
| 78 |
+
self.targets: Dict[_TargetExpr, torch.fx.node.Target] = {}
|
| 79 |
+
self.ctx: Optional[MatchContext] = None
|
| 80 |
+
self.replacement_graph: Optional[torch.fx.Graph] = None
|
| 81 |
+
|
| 82 |
+
@property
|
| 83 |
+
def graph(self) -> torch.fx.Graph:
|
| 84 |
+
assert self.ctx
|
| 85 |
+
return self.ctx.graph
|
| 86 |
+
|
| 87 |
+
def extend(self, other: Match):
|
| 88 |
+
if self.kwargs:
|
| 89 |
+
for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
|
| 90 |
+
if self.kwargs[key] != other.kwargs[key]:
|
| 91 |
+
raise FailedMatch("kwarg mismatch: {}", key)
|
| 92 |
+
self.args.extend(other.args)
|
| 93 |
+
self.nodes.extend(other.nodes)
|
| 94 |
+
self.kwargs.update(other.kwargs)
|
| 95 |
+
self.targets.update(other.targets)
|
| 96 |
+
|
| 97 |
+
def bundle(self) -> Match:
|
| 98 |
+
# Wrap args in an extra list
|
| 99 |
+
self.args = [tuple(self.args)] if self.args else []
|
| 100 |
+
return self
|
| 101 |
+
|
| 102 |
+
def __repr__(self):
|
| 103 |
+
return f"Match(..., {self.args}, {self.kwargs})"
|
| 104 |
+
|
| 105 |
+
def erase_nodes(self, graph: torch.fx.Graph):
|
| 106 |
+
for n in reversed(self.nodes):
|
| 107 |
+
if not n._erased:
|
| 108 |
+
graph.erase_node(n)
|
| 109 |
+
|
| 110 |
+
def output_nodes(self) -> List[Optional[torch.fx.Node]]:
|
| 111 |
+
assert self.ctx
|
| 112 |
+
return [
|
| 113 |
+
(self.ctx.pattern_to_node[p] if p is not None else None)
|
| 114 |
+
for p in self.ctx.outputs
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
def output_node(self) -> torch.fx.Node:
|
| 118 |
+
return next(p for p in self.output_nodes() if p)
|
| 119 |
+
|
| 120 |
+
    def replace_with_graph(self, replacement_graph, args):
        # Splice `replacement_graph` into the matched graph in place of this
        # match, wiring `args` up to its placeholders.
        assert self.ctx
        ReplacementPatternEntry.replace_with_graph(
            self, self.ctx.graph, replacement_graph, args
        )
|
| 125 |
+
|
| 126 |
+
    def replace_by_example(self, replacement_fn, args, trace_fn=None, run_dce=True):
        """
        Trace `replacement_fn` with fake values taken from `args` (each
        node's .meta["val"]) and splice the traced graph in place of this
        match.  `trace_fn` defaults to a forward-only trace.
        """
        assert self.ctx
        if trace_fn is None:
            trace_fn = functools.partial(fwd_only, run_dce=run_dce)
        replacement = trace_fn(
            replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
        )
        ReplacementPatternEntry.replace_with_graph(
            self,
            self.ctx.graph,
            replacement,
            args,
        )
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class FailedMatch(RuntimeError):
    """
    Falsy result object for an unsuccessful pattern match.

    The message is stored as a format string plus its arguments and is only
    rendered when ``__str__`` is invoked; failures are extremely common during
    pattern search, so eager formatting would measurably slow compilation.
    ``__bool__`` returns False so callers can write ``if not m: ...``.
    """

    def __init__(self, format_string, *args, **kwargs):
        self.format_string = format_string
        # We want to construct error messages lazily instead of eagerly, as
        # constructing them eagerly can significantly worsen compile times.
        if len(format_string) > 200:
            raise RuntimeError(
                f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}"
            )
        self.args = args
        self.kwargs = kwargs

    def __str__(self):
        # Render the deferred message on demand.
        return self.format_string.format(*self.args, **self.kwargs)

    def __bool__(self):
        return False
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def is_match(m: Union[Match, FailedMatch]) -> TypeGuard[Match]:
    """
    TypeGuards cannot act on `self`. Thus this function exists to let mypy
    recognize FailedMatch.__bool__ as a TypeGuard.
    """
    # FailedMatch.__bool__ returns False; Match is truthy by default.
    return bool(m)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class MatchContext:
    """
    State needed while running PatternExpr._match().
    """

    def __init__(
        self,
        outputs: List[Optional[PatternExpr]],
        pattern_to_node: Optional[Dict[PatternExpr, Node]] = None,
        *,
        graph: torch.fx.Graph,
    ):
        # Output patterns of the (possibly multi-output) match attempt.
        self.outputs = outputs
        # Memo of pattern -> matched node; a None value records a failed
        # attempt for that pattern.
        self.pattern_to_node = {} if pattern_to_node is None else pattern_to_node
        self.graph = graph
        # Nodes claimed by ExclusiveKeywordArg; each may appear only once.
        self.exclusive_node_set: List[NodeOrConstant] = []

    def match(self, pattern, node):
        """wrapper to check reused nodes in patterns"""
        if pattern in self.pattern_to_node:
            if self.pattern_to_node[pattern] == node:
                return Match(pattern)  # already checked this node
            else:
                # The same sub-pattern matched a different node earlier.
                return FailedMatch("repeated pattern differs")
        m = pattern._match(node, self)
        assert pattern not in self.pattern_to_node
        # Record success (node) or failure (None) for reuse checks above.
        self.pattern_to_node[pattern] = node if m else None
        m.ctx = self
        return m

    def filter_multi_user_patterns(self):
        # Keep only successfully matched patterns that may be shared between
        # uses; used to seed child contexts (see ListOf._match).
        return {
            pattern: node
            for pattern, node in self.pattern_to_node.items()
            if pattern.has_multiple_users() and node is not None
        }
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class PatternExpr:
    """
    Base class for types of patterns
    """

    def _match(
        self, node: torch.fx.Node, ctx: MatchContext
    ) -> Union[Match, FailedMatch]:
        # Subclasses implement the actual matching logic.
        raise NotImplementedError()

    def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]:
        # Public entry point: try to match this pattern rooted at `node`.
        # FailedMatch may be raised (e.g. by Match.extend) or returned;
        # either way callers get a falsy FailedMatch back.
        try:
            return MatchContext([self], graph=node.graph).match(self, node)
        except FailedMatch as e:
            return e

    def has_multiple_users(self) -> bool:
        # Overridden by _TargetExpr when _users permits more than one user.
        return False

    def __repr__(self):
        return self.__class__.__name__ + "()"

    def find_anchor_nodes(self, ctx: MatchContext, searched):
        # Yield candidate graph nodes for multi-output matching; the base
        # implementation only reuses a node this pattern already matched.
        if self in ctx.pattern_to_node:
            yield ctx.pattern_to_node[self]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class Arg(PatternExpr):
    """
    Capture an arg which will become an input to the handler. Args are
    passed in depth first order.
    """

    def _match(self, node: NodeOrConstant, ctx: MatchContext):
        # Unconditional match; the captured value goes into Match.args.
        return Match(self, args=[node])  # matches anything
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class Ignored(PatternExpr):
    """
    Match an arg, but don't pass it to handler
    """

    def _match(self, node: NodeOrConstant, ctx: MatchContext):
        # Unconditional match; nothing is captured.
        return Match(self)  # matches anything

    def __repr__(self):
        return "*"

    def pretty_print(self, pp: PatternPrettyPrinter):
        # Serialized form must be valid python, unlike __repr__'s "*".
        return "Ignored()"
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class KeywordArg(PatternExpr):
    """
    Capture a kwarg which will become an input to the handler.
    """

    def __init__(self, name: str):
        super().__init__()
        # Name under which the captured value appears in Match.kwargs.
        self.name = name

    def __repr__(self):
        return f"KeywordArg({self.name!r})"

    def _match(self, node: NodeOrConstant, ctx: MatchContext):
        return Match(self, kwargs={self.name: node})  # matches anything
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class ExclusiveKeywordArg(PatternExpr):
    """
    Capture a kwarg which will become an input to the handler.

    Unlike KeywordArg, a given node may be captured by at most one
    ExclusiveKeywordArg within a single match attempt.
    """

    def __init__(self, name):
        super().__init__()
        # Name under which the captured value appears in Match.kwargs.
        self.name = name

    def __repr__(self):
        return f"ExclusiveKeywordArg({self.name!r})"

    def _match(self, node: NodeOrConstant, ctx: MatchContext):
        if node in ctx.exclusive_node_set:
            # Another ExclusiveKeywordArg already claimed this node.
            return FailedMatch("exclusive arg appears twice")

        ctx.exclusive_node_set.append(node)
        return Match(self, kwargs={self.name: node})  # matches anything
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class _TargetExpr(PatternExpr):
    """
    Base class for filtering match by node.target
    """

    # FX opcode this pattern matches ("call_function", "call_method",
    # "call_module"); set by concrete subclasses.
    op: Optional[str] = None

    def __init__(self, fns, users=1):
        if not self.op:
            # Fixed: the message previously referenced a nonexistent class
            # name ("_BaseNodeMatch"), which made the error misleading.
            raise NotImplementedError("Shouldn't directly use _TargetExpr")
        super().__init__()
        fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
        for fn in list(fns):
            # An OpOverloadPacket matches any of its overloads.
            if isinstance(fn, torch._ops.OpOverloadPacket):
                fns.extend([getattr(fn, overload) for overload in fn.overloads()])

        self.fns: List[Union[Callable[..., Any], str]] = fns
        self.fns_set: Set[Union[Callable[..., Any], str]] = set(fns)
        # Required number of users of the matched node (or MULTIPLE).
        self.users: Union[int, Multiple] = users

    def fns_repr(self) -> str:
        """Short printable name for the first target fn."""
        first_repr = self.fns[0]
        if not isinstance(first_repr, str):
            first_repr = first_repr.__name__

        if len(self.fns) > 1:
            return f"[{first_repr}, ...]"
        elif self.fns[0] is getattr(torch, first_repr, None):
            return f"torch.{first_repr}"
        elif isinstance(self.fns[0], torch._ops.OpOverload):
            return str(self.fns[0])
        else:
            return first_repr

    def __repr__(self):
        return f"{self.__class__.__name__}({self.fns_repr()})"

    def has_multiple_users(self) -> bool:
        return isinstance(self.users, Multiple) or self.users > 1

    def find_anchor_nodes(self, ctx: MatchContext, searched):
        raise NotImplementedError()

    def _match_fns(self, node: torch.fx.Node):
        # Does `node` have the right opcode and one of our targets?
        return (
            isinstance(node, torch.fx.Node)
            and node.op == self.op
            and extract_target(node) in self.fns_set
        )

    def _match_users(self, node: torch.fx.Node, ctx: MatchContext):
        # Output nodes are exempt from the user-count constraint.
        return (
            self in ctx.outputs
            or self.users is MULTIPLE
            or len(node.users) == self.users
        )
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class _TargetArgsExpr(_TargetExpr):
    """
    Base class for filtering match by node.{target,args,kwargs}
    """

    def __init__(self, fns, *args, _users=1, **kwargs):
        super().__init__(fns, _users)
        self.args = tuple(args)
        self.kwargs = dict(kwargs)
        # Use pytree flattening only when some arg is a container; the
        # simple path is cheaper and covers the common case.
        if any(
            isinstance(x, (dict, list, tuple))
            for x in itertools.chain(args, kwargs.values())
        ):
            self.flatten = self.pytree_flatten
        else:
            self.flatten = self.simple_flatten
        self.flat_args_kwargs = self.flatten(self.args, self.kwargs)

    @staticmethod
    def simple_flatten(args, kwargs: Dict[Any, Any]):
        # Flat values plus a cheap "spec" (arg count + kwarg keys).
        return (*args, *kwargs.values()), (len(args), *kwargs.keys())

    @staticmethod
    def pytree_flatten(args, kwargs: Dict[Any, Any]):
        def norm_spec(s: pytree.TreeSpec):
            # Normalize immutable_list/tuple/immutable_dict so specs built
            # from patterns compare equal to specs built from graph nodes.
            if s.type is None:
                return s
            mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
            return pytree.TreeSpec(
                mapping.get(s.type, s.type),
                s.context,
                list(map(norm_spec, s.children_specs)),
            )

        flat, spec = pytree.tree_flatten([args, kwargs])
        spec = norm_spec(spec)
        return flat, spec

    def __repr__(self):
        args = [
            self.fns_repr(),
            *map(repr, self.args),
            *[f"{k}={v}" for k, v in self.kwargs.items()],
        ]
        return f"{self.__class__.__name__}({', '.join(args)})"

    def pretty_print(self, pp: PatternPrettyPrinter):
        # Like __repr__ but emits executable python (recursing through pp).
        args = [
            self.fns_repr(),
            *(pp.pretty_print(x) for x in self.args),
            *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
        ]
        if isinstance(self.users, Multiple):
            args.append("_users=MULTIPLE")
        elif self.users > 1:
            args.append(f"_users={self.users}")

        joiner_str = ", "
        return f"{self.__class__.__name__}({joiner_str.join(args)})"

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        if not self._match_fns(node) or len(node.args) != len(self.args):
            return FailedMatch("function_mismatch: node={}, pattern={}", node, self)

        if not self._match_users(node, ctx):
            return FailedMatch("multiple_users {}", self)

        _args = node.args
        _kwargs = node.kwargs
        if len(_kwargs) < len(self.kwargs):
            # The pattern expects kwargs the node passed positionally (or
            # by default); normalize the node's call to recover them.
            from torch.fx.operator_schemas import normalize_function

            normalized_args_and_kwargs = normalize_function(
                node.target, node.args, node.kwargs
            )

            if normalized_args_and_kwargs is None:
                return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
            else:
                _args, _kwargs = normalized_args_and_kwargs
                if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs):
                    # Drop kwargs the pattern does not constrain.
                    _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
                else:
                    return FailedMatch(
                        "function_mismatch: node={}, pattern={}", node, self
                    )
        else:
            _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}

        node_items, node_spec = self.flatten(_args, _kwargs)
        self_items, self_spec = self.flat_args_kwargs
        if node_spec != self_spec:
            return FailedMatch("args_structure {} {}", node_spec, self_spec)
        assert len(node_items) == len(self_items)

        m = Match(self)
        for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
            if isinstance(pattern, PatternExpr):
                child_match = ctx.match(pattern, child_node)
                if not child_match:
                    return child_match
                m.extend(child_match)
            elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
                # Fixed: the format string previously used the named field
                # {pattern!r} without passing `pattern` as a kwarg, so
                # rendering this FailedMatch raised KeyError('pattern').
                return FailedMatch(
                    "constant_args: {} {!r}!={!r}", node, child_node, pattern
                )
        m.nodes.append(node)
        m.targets[self] = node.target
        return m

    def find_anchor_nodes(self, ctx: MatchContext, searched):
        """
        This is used when we are matching a pattern with multiple outputs.
        There is a partial match (stored in ctx) and we want to walk
        this pattern to find a connection to an already-matched node.

        Yields candidate nodes that `self._match` might like.
        """
        if self in ctx.pattern_to_node:
            yield ctx.pattern_to_node[self]
            return

        for pattern in self.flat_args_kwargs[0]:
            if isinstance(pattern, PatternExpr):
                for other_node in pattern.find_anchor_nodes(ctx, searched):
                    if not isinstance(other_node, torch.fx.Node):
                        continue
                    # Any unvisited user of an already-matched node whose
                    # target fits is a candidate anchor.
                    for node in other_node.users:
                        if node not in searched:
                            if self._match_fns(node):
                                yield node
                                searched.add(node)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class CallFunction(_TargetArgsExpr):
    """
    Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)`
    """

    # Opcode checked by _TargetExpr._match_fns.
    op = "call_function"
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
class CallMethod(_TargetArgsExpr):
    """
    Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)`
    """

    # Opcode checked by _TargetExpr._match_fns.
    op = "call_method"
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class CallModule(_TargetArgsExpr):
    """
    Matches a call_module node in the FX graphs: `module(*args, **kwargs)`
    """

    # Opcode checked by _TargetExpr._match_fns.
    op = "call_module"
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class _TargetExprVarArgs(_TargetExpr):
    """
    Matches a call_function node with any arguments which are passed into the pattern
    """

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        if not self._match_fns(node):
            return FailedMatch("function_mismatch")

        if not self._match_users(node, ctx):
            return FailedMatch("multiple_users")

        m = Match(self)
        m.nodes.append(node)
        m.targets[self] = node.target
        # Forward the node's args/kwargs to the handler unchanged.
        m.args.extend(node.args)
        m.kwargs.update(node.kwargs)
        return m
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class CallFunctionVarArgs(_TargetExprVarArgs):
    # Variadic version of CallFunction: target must match, args/kwargs free.
    op = "call_function"
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class CallMethodVarArgs(_TargetExprVarArgs):
    # Variadic version of CallMethod: target must match, args/kwargs free.
    op = "call_method"
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class CallModuleVarArgs(_TargetExprVarArgs):
    # Variadic version of CallModule: target must match, args/kwargs free.
    op = "call_module"
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class ListOf(PatternExpr):
    """
    Matches a repeated pattern
    """

    def __init__(self, pattern: PatternExpr, partial=False):
        super().__init__()
        assert isinstance(pattern, PatternExpr)
        self.pattern = pattern
        # partial=True tolerates elements that fail to match (at least one
        # element must still succeed).
        self.partial = partial

    def __repr__(self):
        return f"{self.__class__.__name__}({self.pattern})"

    def _match(self, node: List[torch.fx.Node], ctx: MatchContext):  # type: ignore[override]
        if not isinstance(node, (list, tuple)) or len(node) == 0:
            return FailedMatch("non_list")
        m = Match(self)
        # Propagating patterns with multiple users will ensure we don't revisit
        # the same nodes
        pattern_to_node = ctx.filter_multi_user_patterns()
        matched = False
        for i, child_node in enumerate(node):
            # Fresh context per element so single-use sub-patterns can match
            # each element independently; shared (multi-user) matches carry
            # over via pattern_to_node.
            child_ctx = MatchContext(
                ctx.outputs, pattern_to_node, graph=child_node.graph
            )
            child_match = child_ctx.match(self.pattern, child_node)
            pattern_to_node = child_ctx.filter_multi_user_patterns()
            if not child_match:
                if not self.partial:
                    return FailedMatch("list[{}]: {}", i, child_match)
                continue
            matched = True
            m.extend(child_match.bundle())
        if not matched:
            return FailedMatch("list: no_match")
        return m.bundle()
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
class MultiOutputPattern(PatternExpr):
    """
    Matches a subgraph with several output patterns.  The first output is
    matched directly; the remaining outputs are located by walking from
    already-matched nodes (anchors).
    """

    def __init__(self, outputs):
        super().__init__()
        assert all(isinstance(x, (PatternExpr, type(None))) for x in outputs), outputs
        self.outputs: List[Optional[PatternExpr]] = outputs

    @property
    def fns(self):
        # Registration keys come from the first output's targets.
        assert self.outputs[0] and hasattr(self.outputs[0], "fns")
        return self.outputs[0].fns

    def __repr__(self):
        return f"{self.__class__.__name__}({self.outputs})"

    def pretty_print(self, pp: PatternPrettyPrinter):
        args = [pp.pretty_print(x) for x in self.outputs]
        joiner_str = f",\n{' '}"
        str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}"
        str_out = f"{str_out}\n])"
        return str_out

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        # Match the primary output at `node`, then the rest via anchors.
        m = ctx.match(self.outputs[0], node)
        if not m:
            return m

        for pattern in self.outputs[1:]:
            if pattern is None:
                continue
            child_match = self._match_from_anchors(pattern, ctx)
            if not child_match:
                return child_match
            m.extend(child_match)

        return m

    def _match_from_anchors(self, pattern, ctx):
        # Try each candidate anchor; roll back ctx state between attempts.
        prior = dict(ctx.pattern_to_node)
        m = FailedMatch("no anchor found")
        for node in pattern.find_anchor_nodes(ctx, set()):
            m = ctx.match(pattern, node)
            if m:
                return m
            # revert any partial matches
            ctx.pattern_to_node = dict(prior)
        return m

    def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]:
        # Override: all outputs (not just self) are exempt from user checks.
        try:
            return MatchContext(self.outputs, graph=node.graph).match(self, node)
        except FailedMatch as e:
            return e
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class RepeatedExpr(PatternExpr):
    """
    Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind`
    """

    def __init__(self, inner_pattern: PatternExpr):
        super().__init__()
        assert hasattr(inner_pattern, "fns")
        self.inner_pattern = inner_pattern

    @property
    def fns(self):
        # Registration keys come from the repeated inner pattern.
        return self.inner_pattern.fns

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        m = ctx.match(self.inner_pattern, node)
        if not m:
            return m
        # Forget the memoized match so each sibling occurrence is rechecked
        # independently below.
        ctx.pattern_to_node.pop(
            self.inner_pattern,
        )
        # Check all anchor nodes match the pattern
        for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()):
            anchor_m = MatchContext([self], graph=node.graph).match(
                self.inner_pattern, anchor_node
            )
            if not anchor_m:
                return anchor_m
            m.extend(anchor_m)
        return m
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
class PatternPrettyPrinter:
    """
    Serializes Patterns to executable python.
    XXX: currently only used and tested for fuse attention patterns. May not cover
    all patterns.
    """

    def __init__(self):
        # Generates unique variable names for memoized sub-patterns.
        self.namespace = torch.fx.graph._Namespace()
        # Memoized sub-pattern -> assigned variable name.
        self.memoized_objs_names: Dict[PatternExpr, str] = {}
        # Memoized sub-pattern -> its serialized python expression.
        self.memoized_objs_pp: Dict[PatternExpr, str] = {}

    @staticmethod
    def run(obj: PatternExpr, output_name="output"):
        """
        Serializes obj to python code with obj written out to `output_name`
        """

        pp = PatternPrettyPrinter()
        assert hasattr(obj, "pretty_print")
        out_str = obj.pretty_print(pp=pp)

        # Emit memoized sub-pattern assignments first, then the output.
        output = []
        for key in pp.memoized_objs_names:
            output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")

        output.append(f"{output_name} = {out_str}")

        return "\n".join(output)

    def pretty_print(self, obj):
        # Target-call sub-patterns are memoized so repeated sub-expressions
        # serialize to a shared variable instead of being duplicated.
        if isinstance(obj, _TargetArgsExpr):
            if memoized_name := self.memoized_objs_names.get(obj):
                return memoized_name
            else:
                return self.memoize(obj)
        if hasattr(obj, "pretty_print"):
            return obj.pretty_print(self)

        return repr(obj)

    def memoize(self, obj):
        obj_str = obj.pretty_print(self)
        obj_name = obj.fns_repr()
        # Strip namespace prefixes to get a readable variable name.
        for prefix in ("aten.", "torch.", "prims."):
            obj_name = obj_name.replace(prefix, "")

        tmp_name = self.namespace.create_name(obj_name, None)
        self.memoized_objs_names[obj] = tmp_name
        self.memoized_objs_pp[obj] = obj_str
        return tmp_name
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
@dataclasses.dataclass
class PatternEntry:
    """
    A pattern plus the action to take when it matches.  Subclasses implement
    apply(); extra_check runs on each Match before the action fires.
    """

    pattern: PatternExpr
    extra_check: Callable[[Match], bool]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
        raise NotImplementedError()

    def register(self, pass_dicts, target=None, prepend=False):
        # Register this entry under each of the pattern's target fns, into
        # one pass dict or recursively into a collection of them.
        if target is None:
            assert hasattr(self.pattern, "fns")
            for fn in self.pattern.fns:
                self.register(pass_dicts, fn, prepend=prepend)
        elif isinstance(pass_dicts, (dict, PatternMatcherPass)):
            if prepend:
                pass_dicts[target].insert(0, self)
            else:
                pass_dicts[target].append(self)
        else:
            for x in pass_dicts:
                self.register(x, target, prepend=prepend)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
@dataclasses.dataclass
class LoweringPatternEntry(PatternEntry):
    """
    Replaces the matched subgraph with a single call_function node invoking
    `handler` (partially applied with the match) at lowering time.
    """

    handler: Callable[..., Any]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
        # functools.wraps keeps the handler's name/docs on the partial so the
        # graph node prints usefully.
        handler = functools.wraps(self.handler)(functools.partial(self.handler, match))
        with graph.inserting_before(node):
            replacement = graph.call_function(handler, tuple(match.args), match.kwargs)
            replacement.meta.update(node.meta)
            node.replace_all_uses_with(replacement)
        assert match.nodes[-1] is node
        match.erase_nodes(graph)
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
@dataclasses.dataclass
class GraphPatternEntry(PatternEntry):
    """
    A pattern that runs a function on the FX graph
    """

    handler: Callable[..., Any]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
        # The handler performs any mutation itself; insertion point is set
        # to just before the matched node.
        with graph.inserting_before(node):
            self.handler(match, *match.args, **match.kwargs)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
@dataclasses.dataclass
class ReplacementPatternEntry(PatternEntry):
    """
    Replaces a matched subgraph with a traced replacement graph.
    normalize_args maps the match's captures to the replacement graph's
    placeholder order.
    """

    normalize_args: Callable[..., List[Any]]

    @staticmethod
    def replace_with_graph(
        match: Match,
        graph: torch.fx.Graph,
        replacement_graph: torch.fx.Graph,
        args: List[Any],
    ):
        output_nodes = match.output_nodes()
        first_node = output_nodes[0]

        class Replacer(torch.fx.Interpreter):
            # Only call_function (plus placeholder/output) is supported in
            # replacement graphs; disable the other opcodes.
            call_method = None  # type: ignore[assignment]
            call_module = None  # type: ignore[assignment]
            get_attr = None  # type: ignore[assignment]

            def run_node(self, node) -> Any:
                if node.op in ("placeholder", "output"):
                    return super().run_node(node)
                if node.op == "call_function":
                    target = node.target
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    # Re-emit the call into the target graph, carrying over
                    # fake-value metadata from the replacement graph.
                    result = graph.call_function(target, args, kwargs)
                    if "val" in node.meta and "val" not in result.meta:
                        result.meta["val"] = node.meta["val"]
                        if isinstance(node.meta["val"], torch.Tensor):
                            assert "tensor_meta" in node.meta
                            result.meta["tensor_meta"] = node.meta["tensor_meta"]
                    return result
                raise NotImplementedError(f"unhandled {node}")

        output_nodes = match.output_nodes()

        if len(output_nodes) == 1:
            last_node = output_nodes[0]
        else:
            # Insert before the earliest output node in graph order so every
            # replacement node dominates all uses.
            assert output_nodes[0]
            nodes = list(output_nodes[0].graph.nodes)
            indices = [
                (nodes.index(n), n)
                for n in output_nodes
                if isinstance(n, torch.fx.Node)
            ]
            last_node = min(indices, key=lambda tup: tup[0])[1]

        def percolate_tags(node, recompute_tag, input_stops):
            # Walk backwards from `node`, tagging everything up to (but not
            # including) the original inputs with the recompute tag.
            queue = [node]
            visited = set()

            while queue:
                arg = queue.pop()
                if (
                    arg not in visited
                    and arg not in input_stops
                    and hasattr(arg, "meta")
                ):
                    visited.add(arg)
                    arg.meta["recompute"] = recompute_tag
                    queue.extend(arg.all_input_nodes)

        with graph.inserting_before(last_node):
            replacement = Replacer(replacement_graph).run(*args)
            if isinstance(replacement, torch.fx.Node):
                replacement = [replacement]

            def maybe_getitem(node):
                # Return the getitem index if `node` is `x[i]`, else None.
                if node.op != "call_function":
                    return None
                if node.target != operator.getitem:
                    return None
                assert len(node.args) == 2
                return node.args[1]

            def replace(old, new):
                if old is None:
                    assert new is None
                    return
                assert isinstance(old, torch.fx.Node)
                if new is None:
                    old.replace_all_uses_with(None)
                    graph.erase_node(old)
                    return
                if isinstance(new, torch.fx.Node):
                    if "val" not in new.meta:
                        new.meta.update(old.meta)

                    # Preserve the recompute tags in the replacement graph. We
                    # look at the recompute tags of the original output node to
                    # propagate the tag from the output all the way to the input
                    # args (named as args in the replace_with_graph).
                    # Note that this is best effort. Since patterns are from
                    # many to many, there is no easy way to correctly map the
                    # recomputable tags. It is possible in some scenarios that we
                    # incorrectly tag some nodes as recomputables.
                    if "recompute" in old.meta:
                        percolate_tags(new, old.meta["recompute"], args)

                    old.replace_all_uses_with(new)
                    graph.erase_node(old)
                    return

                # `new` is not a node: it's a list of nodes.
                #
                # This happens when we want to replace a node that has a single
                # packed return with multiple unpacked returns. We need to do
                # some graph surgery here.
                #
                # Example:
                #   def original_graph(x):
                #      a = op(x)
                #      b = a[0]
                #      c = a[1]
                #      ...
                #
                # Assume that we want to replace op(x) with the graph
                #   def new_op(x):
                #      w = x + 1
                #      z = x + 2
                #      return (w, z)
                #
                # We need to replace `op` with the contents of `new_op`,
                # and then rewrite a[0] to be w and a[1] to be z, as so:
                #   def new_graph(x):
                #     w = x + 1
                #     z = x + 2
                #     b = w
                #     c = z
                #     ...
                old_uses = list(old.users.keys())
                for user in old_uses:
                    idx = maybe_getitem(user)
                    if idx is None:
                        raise AssertionError("can't handle")
                    replace(user, new[idx])
                graph.erase_node(old)

            if len(output_nodes) == len(replacement):
                for old, new in zip(output_nodes, replacement):
                    replace(old, new)
            else:
                assert len(output_nodes) == 1
                replace(output_nodes[0], replacement)

        match.erase_nodes(graph)

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
        self.replace_with_graph(
            match,
            graph,
            match.replacement_graph,  # type: ignore[arg-type]
            self.normalize_args(*match.args, **match.kwargs),
        )
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def _return_true(match):
|
| 928 |
+
return True
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
def log_trace_failure(search_fn, e):
    # Shape-mismatch trace failures are expected for shape-specialized
    # patterns, so record at info level instead of raising.
    log.info(
        "Replacement pattern %s failed to apply due to shape mismatch: %s",
        search_fn.__name__,
        e,
    )
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
def register_replacement(
    search_fn,
    replace_fn,
    example_inputs: Iterable[Any],
    trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule],
    pass_dicts,
    extra_check=_return_true,
    scalar_workaround=(),
    exclusive_arg_names=(),
    search_fn_pattern=None,
):
    """
    Create a replacement rule based on example functions that get traced
    to create patterns. This supports both training and inference when
    run on a joint forward+backward graph.

    Args:
        search_fn: traced to give original pattern
        replace_fn: traced to give replacement graph
        example_inputs: example inputs for initial trace
        trace_fn: fwd_only or joint_fwd_bwd
        pass_dict: dict of passes to register to
        extra_check: additional check to run on match(using real shapes)
    """
    # Parameter names of the search function become the keyword-arg names
    # captured by the pattern.
    argnames_static = [*inspect.signature(search_fn).parameters.keys()]

    def check_fn(match: Match):
        """
        Often shapes get burned into the pattern, so our initial match ran with
        `ignore_types=(int, ...)`.

        Recheck the match with the correct shapes.
        """
        argnames = list(argnames_static)
        for name in argnames:
            if name not in match.kwargs:
                raise RuntimeError(
                    f"Not all inputs to pattern found in match.kwargs. Perhaps one "
                    f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}"
                )

        # Pull the fake-tensor values recorded on the matched nodes.
        args = list(
            torch.fx.map_arg(
                [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
            )
        )
        sym_args: List[torch.SymInt] = []
        with torch._dynamo.utils.detect_fake_mode(args):
            for i, grad in enumerate(requires_grad):
                if isinstance(args[i], torch.Tensor):
                    # Integer tensors cannot require grad; reject the match.
                    if grad and is_integer_dtype(args[i].dtype):
                        return False

                    # Rebuild an empty tensor with the matched shape/stride so
                    # we can re-trace with the real (possibly symbolic) sizes.
                    args[i] = torch.empty_strided(
                        args[i].size(),
                        args[i].stride(),
                        dtype=args[i].dtype,
                        device=args[i].device,
                        requires_grad=grad,
                    )
                    # Collect each distinct SymInt appearing in shapes/strides.
                    for v in itertools.chain(args[i].shape, args[i].stride()):
                        if isinstance(v, torch.SymInt) and all(
                            guard_size_oblivious(v != a) for a in sym_args
                        ):
                            sym_args.append(v)

            if sym_args:
                # AOT Autograd and make fx will dedupe symbolic shape size
                # accesses of sym ints that appear as inputs
                # We don't want the sym_size uses to interfere with pattern matching
                # so we provide them as inputs.
                # Later, when we actually do the replacement, the symbolic shape
                # sizes will get re-traced and added to the graph.

                def search_fn_new(*args_new):
                    # Drop the leading SymInt inputs before calling search_fn.
                    return search_fn(*args_new[len(args_new) - len(args) :])

                try:
                    specific_graph = trace_fn(search_fn_new, sym_args + args)
                except RuntimeError as e:
                    log_trace_failure(search_fn, e)
                    return False

                # correct argnames in the graph
                sym_arg_names = []
                for i, placeholder in zip(
                    range(len(sym_args) + len(args)),
                    specific_graph.graph.nodes,
                ):
                    if i < len(sym_args):
                        sym_arg_names.append(placeholder.target)
                        continue

                    # Rename the tensor placeholders back to the search_fn
                    # parameter names so the specific pattern captures kwargs
                    # under the expected names.
                    with specific_graph.graph.inserting_after(placeholder):
                        new_node = specific_graph.graph.placeholder(
                            argnames[i - len(sym_args)]
                        )
                        new_node.target = new_node.name
                    placeholder.replace_all_uses_with(new_node)
                    specific_graph.graph.erase_node(placeholder)

                argnames = sym_arg_names + argnames
            else:
                try:
                    specific_graph = trace_fn(search_fn, args)
                except RuntimeError as e:
                    log_trace_failure(search_fn, e)
                    return False

        # Re-run the match against a pattern traced with the concrete shapes.
        specific_pattern = fx_to_pattern(
            specific_graph,
            argnames=argnames,
            exclusive_arg_names=exclusive_arg_names,
            scalar_workaround=scalar_workaround,
        )
        specific_pattern_match = specific_pattern.match(match.output_nodes()[0])  # type: ignore[arg-type]
        if specific_pattern_match and extra_check(specific_pattern_match):
            # trace the pattern using the shapes from the user program
            match.replacement_graph = trace_fn(replace_fn, args)  # type: ignore[assignment]
            return True
        return False

    def normalize_args(**kwargs):
        # Convert captured kwargs back into positional order:
        # search_fn parameters first, then any tangents_N (backward inputs)
        # in increasing order.
        args = []
        for name in argnames_static:
            args.append(kwargs.pop(name))
        for i in range(1, len(kwargs) + 1):
            if f"tangents_{i}" not in kwargs:
                break
            args.append(kwargs.pop(f"tangents_{i}"))
        assert not kwargs, f"leftover kwargs: {kwargs!r}"
        return args

    if trace_fn is joint_fwd_bwd:
        # If inference mode is enabled during compilation, assume that we don't
        # want to match on any training graph patterns
        if torch.is_inference_mode_enabled():
            return False

    # TODO: Revisit the functionalize_rng_ops for lowmem dropout
    with functorch_config.patch(functionalize_rng_ops=False):
        requires_grad: List[bool] = [
            isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
        ]
        if search_fn_pattern is None:
            pattern = gen_pattern(
                search_fn,
                example_inputs,
                trace_fn,
                scalar_workaround,
                exclusive_arg_names,
            )
        else:
            pattern = search_fn_pattern

        # Guard against registering the same pattern twice.
        pattern_repr = PatternPrettyPrinter.run(pattern)
        assert pattern_repr not in _seen_patterns
        _seen_patterns.add(pattern_repr)
        pattern = ReplacementPatternEntry(
            pattern=pattern,
            extra_check=check_fn,
            normalize_args=normalize_args,
        )
        pattern.register(pass_dicts)
        return pattern.pattern
|
| 1104 |
+
|
| 1105 |
+
|
| 1106 |
+
@functorch_config.patch(functionalize_rng_ops=False)
def gen_pattern(
    search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=()
) -> PatternExpr:
    """Trace ``search_fn`` on ``example_inputs`` and convert the resulting
    FX graph into a ``PatternExpr``.

    Arguments named in ``scalar_workaround`` take their value from that
    mapping instead of consuming a positional example input.
    """
    argnames = list(inspect.signature(search_fn).parameters)

    if scalar_workaround == ():
        scalar_workaround = {}

    # Build the flat argument list: scalar-workaround names come from the
    # mapping; every other parameter consumes the next example input.
    positional = iter(example_inputs)
    flat_inputs = [
        scalar_workaround[name] if name in scalar_workaround else next(positional)
        for name in argnames
    ]

    search_gm = trace_fn(search_fn, flat_inputs)
    return fx_to_pattern(
        search_gm,
        ignore_types=(int, float, list, torch.device, torch.dtype),
        argnames=argnames,
        scalar_workaround=scalar_workaround,
        exclusive_arg_names=exclusive_arg_names,
    )
|
| 1132 |
+
|
| 1133 |
+
|
| 1134 |
+
def register_lowering_pattern(
    pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False
):
    """
    Register an aten to inductor IR replacement pattern. The decorated
    function is saved and then called a lowering time allowing direct
    pattern to inductor IR conversion.

    Args:
        pattern: the PatternExpr to match.
        extra_check: extra predicate run on the Match before applying.
        pass_dict: the pass dictionary to register into.
        prepend: if True, register ahead of existing entries for the target.
    """

    def decorator(handler):
        assert callable(handler)
        LoweringPatternEntry(
            pattern=pattern, extra_check=extra_check, handler=handler
        ).register(pass_dict, prepend=prepend)
        # Mark the handler so other code can recognize it as a lowering fn.
        handler._inductor_lowering_function = True
        return handler

    return decorator
|
| 1152 |
+
|
| 1153 |
+
|
| 1154 |
+
def register_graph_pattern(
    pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False
):
    """
    Register a pattern that runs a function on the FX graph, allowing
    custom transformation code.

    Args:
        pattern: the PatternExpr to match.
        extra_check: extra predicate run on the Match before applying.
        pass_dict: the pass dictionary to register into.
        prepend: if True, register ahead of existing entries for the target.
    """

    def decorator(handler):
        assert callable(handler)
        GraphPatternEntry(
            pattern=pattern, extra_check=extra_check, handler=handler
        ).register(pass_dict, prepend=prepend)
        return handler

    return decorator
|
| 1170 |
+
|
| 1171 |
+
|
| 1172 |
+
def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
    """Return True iff ``node`` is the very first node of ``graph``."""
    first_node = next(iter(graph.nodes))
    return node is first_node
|
| 1175 |
+
|
| 1176 |
+
|
| 1177 |
+
# match: copy_, relu_, _set_grad_enabled, manual_seed, enter_functional_autocast, etc
# A trailing "_" or "_." names an in-place ATen op; set/enter/exit/seed catch
# global-state mutators.  Heuristic by design — see is_mutation_op().
_mutation_op_re = re.compile(r"_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_)")
|
| 1179 |
+
|
| 1180 |
+
|
| 1181 |
+
def is_mutation_op(node: torch.fx.Node) -> bool:
    """Heuristically decide whether ``node`` mutates state, based on the
    target name (see ``_mutation_op_re``) or an explicit ``out=`` kwarg."""
    op = node.op
    if op == "call_function" and _mutation_op_re.search(node.target.__name__):  # type: ignore[union-attr]
        return True
    if op == "call_method" and _mutation_op_re.search(node.target):  # type: ignore[union-attr, arg-type]
        return True
    # An explicit ``out=`` kwarg also writes into an existing tensor.
    return node.kwargs.get("out") is not None
|
| 1189 |
+
|
| 1190 |
+
|
| 1191 |
+
def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
    """Return the mutation-region id for ``node``, lazily computing and
    caching ids in ``meta["mutation_region_id"]`` for intervening nodes.

    Nodes separated by a mutation op belong to different regions; callers
    use this to refuse pattern matches that span a mutation barrier.
    """
    n = node
    # Walk backwards to the nearest node that already carries a cached id
    # (or to the start of the graph, which implicitly has id 0).
    while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
        n = n.prev
    mutation_region_id = n.meta.get("mutation_region_id", 0)
    # Walk forward again to ``node``, bumping the id at each mutation op and
    # caching the result on every node we pass.
    while n is not node:
        n = n.next
        if is_mutation_op(n):
            mutation_region_id += 1
        n.meta["mutation_region_id"] = mutation_region_id
    return mutation_region_id
|
| 1202 |
+
|
| 1203 |
+
|
| 1204 |
+
def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
    """True when mutation-region ids have not yet been assigned, detected by
    inspecting the first node's metadata."""
    first_node = next(iter(graph.nodes))
    return "mutation_region_id" not in first_node.meta
|
| 1206 |
+
|
| 1207 |
+
|
| 1208 |
+
def compute_mutation_region_ids(graph: torch.fx.GraphModule):
    """Assign ``meta["mutation_region_id"]`` to every node in program order.

    The id starts at 0 and increments each time a mutation op is seen, so a
    mutation op starts a new region that includes the op itself.
    """
    mutation_region_id = 0
    for nd in graph.nodes:
        if is_mutation_op(nd):
            mutation_region_id += 1
        nd.meta["mutation_region_id"] = mutation_region_id
|
| 1214 |
+
|
| 1215 |
+
|
| 1216 |
+
class PatternMatcherPass:
    """A collection of pattern entries, indexed by their root target op, that
    can be applied to an FX graph as a single pass."""

    def __init__(
        self, prevent_match_across_mutations=False, pass_name: Optional[str] = None
    ):
        super().__init__()
        # target -> candidate pattern entries rooted at that target
        self.patterns: DefaultDict[
            torch.fx.node.Target, List[PatternEntry]
        ] = defaultdict(list)
        # When True, discard matches whose nodes span a mutation barrier.
        self.prevent_match_across_mutations = prevent_match_across_mutations
        self.pass_name = pass_name

    def __getitem__(self, item: torch.fx.node.Target) -> List[PatternEntry]:
        # defaultdict: unknown targets yield an (inserted) empty list.
        return self.patterns[item]

    def apply(self, graph: torch.fx.GraphModule) -> int:
        """Run all registered patterns over ``graph`` (a Graph or a
        GraphModule); returns the number of matches applied."""
        if not self.patterns:
            return 0
        if isinstance(graph, torch.fx.GraphModule):
            graph = graph.graph
        if self.prevent_match_across_mutations:
            if should_compute_mutation_region_ids(graph):
                compute_mutation_region_ids(graph)
            get_mutation_region_id_partial = functools.partial(
                get_mutation_region_id, graph
            )
        count = 0
        # Iterate in reverse so applied replacements don't disturb nodes we
        # have not yet visited.
        for node in reversed(graph.nodes):
            target = extract_target(node)
            if (
                node.op in ["call_function", "call_method", "call_module"]
                and target in self.patterns
            ):
                # conservatively not applying pattern for cpu input,
                # since some of the patterns induce codegen and split nodes.
                # Note: we will only skip cpu compute if disable_cpp_codegen=True
                if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False):
                    continue

                for entry in self.patterns[target]:
                    # A previous entry's apply() may have erased this node.
                    if node._erased:
                        break
                    m = entry.pattern.match(node)
                    # pattern match crosses mutation barrier - discard
                    if (
                        self.prevent_match_across_mutations
                        and is_match(m)
                        and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1  # type: ignore[possibly-undefined]
                    ):
                        continue
                    # Debug hook: set the env var to a node name to log every
                    # match attempt against that node.
                    if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
                        log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
                    if is_match(m) and entry.extra_check(m):
                        count += 1
                        entry.apply(m, graph, node)  # type: ignore[arg-type]
                        counters["inductor"]["pattern_matcher_count"] += 1
                        counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
        return count

    def clear(self):
        # Remove every registered pattern.
        self.patterns.clear()
|
| 1276 |
+
|
| 1277 |
+
|
| 1278 |
+
def _not_implemented(*args, **kwargs) -> NoReturn:
|
| 1279 |
+
raise NotImplementedError()
|
| 1280 |
+
|
| 1281 |
+
|
| 1282 |
+
def fx_to_pattern(
    gm,
    ignore_types=(),
    argnames=(),
    scalar_workaround=(),
    exclusive_arg_names=(),
) -> PatternExpr:
    """
    Convert an FX graph into a PatternExpr. This is useful for simple
    patterns that can only match single functions and fixed-length lists.

    Args:
        gm: the traced GraphModule to convert.
        ignore_types: literal argument types replaced by Ignored() wildcards.
        argnames: names to assign to the graph's placeholders, in order.
        scalar_workaround: mapping of name -> scalar literal; occurrences of
            the literal become KeywordArg(name) captures.
        exclusive_arg_names: placeholder names emitted as ExclusiveKeywordArg.
    """
    # scalar_workaround is a hack to capture dropout_p
    # see https://github.com/pytorch/pytorch/issues/97894
    scalar_workaround = scalar_workaround or {}
    inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()}
    # Scalar values must be unique or the inverse mapping is ambiguous.
    assert len(inv_scalar_workaround) == len(scalar_workaround)

    def process_arg(x):
        # Map a literal argument to its pattern representation.
        if isinstance(x, (float, int)) and x in inv_scalar_workaround:
            return KeywordArg(inv_scalar_workaround[x])
        if type(x) in ignore_types:
            return Ignored()
        # A non-empty list that became all-Ignored collapses to one Ignored().
        if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x:
            return Ignored()
        return x

    argnum = itertools.count()

    class Converter(torch.fx.Interpreter):
        # Patterns may only contain call_function nodes.
        call_method = _not_implemented
        call_module = _not_implemented
        get_attr = _not_implemented

        def placeholder(self, target, args, kwargs):
            n = next(argnum)
            if n < len(argnames):
                name = argnames[n]
            elif argnames:
                # Extra placeholders beyond argnames must be backward tangents.
                assert target.startswith("tangent")
                name = target
            else:
                target = re.sub(r"_\d+$", "", target)  # de-mangle arg name
                name = target
            if name in exclusive_arg_names:
                return ExclusiveKeywordArg(name)
            else:
                return KeywordArg(name)

        def call_function(self, target, args, kwargs):
            args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
            if list in ignore_types:
                # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...]
                args = [process_arg(a) for a in args]
                kwargs = {k: process_arg(a) for k, a in kwargs.items()}
            return CallFunction(target, *args, **kwargs)

        def run_node(self, n):
            rv = super().run_node(n)
            # Record the original node's user count on the pattern expr so
            # matching can require the same fan-out.
            if n.op == "output" and isinstance(rv, tuple):
                assert len(rv) == len(n.args[0])
                for r, arg in zip(rv, n.args[0]):
                    r.users = len(arg.users)
            else:
                rv.users = len(n.users)
            return rv

    pattern = Converter(gm).run()
    if not isinstance(pattern, PatternExpr):
        # Multiple outputs: wrap the flattened exprs in a MultiOutputPattern.
        return MultiOutputPattern(pytree.tree_leaves(pattern))
    return pattern
|
| 1352 |
+
|
| 1353 |
+
|
| 1354 |
+
@torch.no_grad()
def fwd_only(fn, args, *, run_dce=True) -> torch.fx.GraphModule:
    """Build a normalized inference graph, for use with fx_to_pattern"""
    # TODO - look into using aot autograd, asserting no mutating ops here
    with enable_python_dispatcher():
        # Trace symbolically only when any input already carries symbolic
        # shapes; otherwise trace with concrete ("real") values.
        mode = (
            "real" if not torch._inductor.utils.any_is_symbolic(*args) else "symbolic"
        )
        gm = make_fx(fn, select_decomp_table(), tracing_mode=mode)(*args)
    if run_dce:
        gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
|
| 1367 |
+
|
| 1368 |
+
|
| 1369 |
+
@torch.enable_grad()
def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule:
    """Build a normalized training graph, for use with fx_to_pattern"""
    gm: Optional[torch.fx.GraphModule] = None

    def record_joint_graph(joint_graph, inputs, **kwargs):
        # Capture a clone of the joint fwd+bwd graph from inside AOT
        # Autograd's partitioner; partitioning itself proceeds normally.
        nonlocal gm
        assert not gm
        gm = clone_graph(joint_graph)
        return default_partition(joint_graph, inputs, **kwargs)

    with torch._guards.tracing(None):
        aot_function(
            fn,
            lambda g, i: make_boxed_func(g),
            partition_fn=record_joint_graph,
            decompositions=select_decomp_table(),
            keep_inference_input_mutations=True,
            enable_log=False,
        )(*args)
    assert gm

    # Local import to avoid a circular dependency with fx_passes.
    from .fx_passes.joint_graph import pointless_view

    matcher_pass = PatternMatcherPass()

    # Normalize away no-op aten.view calls so they don't appear in patterns.
    pattern = CallFunction(
        torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
    )
    GraphPatternEntry(
        pattern=pattern, handler=pointless_view, extra_check=_return_true
    ).register(matcher_pass.patterns)
    matcher_pass.apply(gm.graph)  # type: ignore[arg-type]

    # remove in/out specs
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
|
| 1408 |
+
|
| 1409 |
+
|
| 1410 |
+
def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
|
| 1411 |
+
args: List[torch.fx.node.Argument] = list()
|
| 1412 |
+
torch.fx.map_arg((n.args, n.kwargs), args.append)
|
| 1413 |
+
return args
|
| 1414 |
+
|
| 1415 |
+
|
| 1416 |
+
def stable_topological_sort(graph: torch.fx.Graph):
    """Topologically re-order ``graph`` in place.

    "Stable" here means nodes that are already correctly ordered keep their
    relative order; only nodes appearing before one of their inputs are
    moved (to just after that input).
    """
    # Nodes are in exactly one of these three collections:

    # - Nodes in `pending` are waiting to be processed (in reverse order):
    pending = list(reversed(graph.nodes))

    # - Nodes in `ready` have been processed and are already in the correct
    #   order.
    ready = set()

    # - `waiting` is a mapping from a dependency to nodes which depend on that
    #   dependency.
    waiting = defaultdict(list)

    # The cursor indicates the last processed node so we can add new nodes
    # after it.
    cursor = None
    while pending:
        node = pending.pop()
        waiting_for = [x for x in _args(node) if x not in ready]
        if waiting_for:
            # We have unprocessed input nodes. Might as well wait for the last
            # arg so an already sorted list will only recheck this node once.
            waiting[waiting_for[-1]].append(node)
        else:
            ready.add(node)
            # Move the node only if it is not already in position.
            if cursor and cursor.next is not node:
                cursor.append(node)
            cursor = node
            # Mark the nodes that have been waiting for this node to finish as
            # ready to check again.
            pending.extend(reversed(waiting.pop(node, ())))

    # Every node must have been placed; leftovers imply a dependency cycle.
    assert not waiting and len(ready) == len(graph.nodes)
|
| 1450 |
+
|
| 1451 |
+
|
| 1452 |
+
def init_once_fakemode(fn: Callable[..., Any]):
    """Wrapper around lazy init functions in fx_passes/

    Wraps ``fn`` so it runs at most once (lru_cache) under a fresh
    FakeTensorMode with tracing disabled, and restores the inductor
    counters afterwards so init-time matches are not counted.
    """

    @functools.lru_cache(None)
    @functools.wraps(fn)
    def lazy_init():
        # Snapshot counters so pattern matches during init don't leak into
        # the user-visible totals.
        counters_ref = counters["inductor"].copy()

        with torch._guards.tracing(
            None
        ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
            result = fn()

        # clear view matches encountered during tracing
        counters["inductor"] = counters_ref

        return result

    return lazy_init
|
| 1471 |
+
|
| 1472 |
+
|
| 1473 |
+
def config_flag(name):
    """Function for extra_check to put pass behind a flag

    Returns a predicate that ignores its ``match`` argument and reads the
    named attribute from the inductor ``config`` module at call time, so
    config changes after registration are honored.
    """

    def flag_check(match):
        return getattr(config, name)

    return flag_check
|
| 1480 |
+
|
| 1481 |
+
|
| 1482 |
+
def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Deep-copy ``input_graph`` via an fx Transformer, preserving each
    node's metadata and (as closely as the namespace allows) its name."""

    class CopyGraph(Transformer):
        def run_node(self, old_node):
            new_node = super().run_node(old_node)
            if isinstance(new_node, torch.fx.Proxy):
                new_node.node.meta.update(old_node.meta)
                # Re-register the old name in the new graph's namespace so
                # name-based references keep working on the clone.
                new_node.node.name = self.new_graph._graph_namespace.create_name(
                    old_node.name, None
                )
            return new_node

    return CopyGraph(input_graph).transform()
|
| 1494 |
+
|
| 1495 |
+
|
| 1496 |
+
# Pretty-printed repr of every pattern registered so far; used by
# register_replacement() to reject accidental duplicate registrations.
_seen_patterns: Set[str] = set()
|
| 1497 |
+
|
| 1498 |
+
|
| 1499 |
+
def get_arg_value(
    node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
):
    """Fetch the ``arg_number``-th positional argument of ``node``, falling
    back to ``node.kwargs[kwarg_name]`` (or None when absent)."""
    if len(node.args) > arg_number:
        return node.args[arg_number]
    return node.kwargs.get(kwarg_name)  # type: ignore[arg-type]
|
| 1507 |
+
|
| 1508 |
+
|
| 1509 |
+
def filter_nodes(nodes: Iterable[torch.fx.Node], fn) -> List[torch.fx.Node]:
    """Keep only the nodes whose ``target`` is ``fn`` — or, when ``fn`` is an
    ATen overload packet, any of that packet's overloads."""
    targets = [fn]
    if isinstance(fn, torch._ops.OpOverloadPacket):
        for overload in fn.overloads():
            targets.append(getattr(fn, overload))

    result = []
    for node in nodes:
        if node.target in targets:
            result.append(node)
    return result
|
| 1515 |
+
|
| 1516 |
+
|
| 1517 |
+
def extract_target(node: Node):
    """For call_function and call_method, we directly use the target function;
    For call_module, the target is string, and we treat the module class
    as a function.
    """
    if node.op != "call_module":
        return node.target
    # The target of a call_module node is the submodule's qualified name;
    # resolve it on the owning module and report the module's class.
    submodule = getattr(node.graph.owning_module, node.target)  # type: ignore[arg-type]
    return submodule.__class__
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py
ADDED
|
@@ -0,0 +1,2445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import dataclasses
|
| 3 |
+
import functools
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import math
|
| 7 |
+
import operator
|
| 8 |
+
import os
|
| 9 |
+
import pprint
|
| 10 |
+
import textwrap
|
| 11 |
+
from typing import (
|
| 12 |
+
Any,
|
| 13 |
+
Counter,
|
| 14 |
+
DefaultDict,
|
| 15 |
+
Dict,
|
| 16 |
+
Generic,
|
| 17 |
+
List,
|
| 18 |
+
Optional,
|
| 19 |
+
Sequence,
|
| 20 |
+
Set,
|
| 21 |
+
Tuple,
|
| 22 |
+
TypeVar,
|
| 23 |
+
Union,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
import sympy
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
from torch._dynamo.utils import dynamo_timed
|
| 30 |
+
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
|
| 31 |
+
from torch.utils._triton import has_triton
|
| 32 |
+
|
| 33 |
+
from . import comms, config, dependencies, ir, metrics
|
| 34 |
+
from .codegen.common import get_scheduling_for_device, Kernel
|
| 35 |
+
from .comm_analysis import estimate_nccl_collective_runtime
|
| 36 |
+
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
|
| 37 |
+
from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout
|
| 38 |
+
from .sizevars import SimplifyIndexing
|
| 39 |
+
from .utils import (
|
| 40 |
+
cache_on_self,
|
| 41 |
+
cmp,
|
| 42 |
+
free_symbol_has,
|
| 43 |
+
get_device_tflops,
|
| 44 |
+
get_dtype_size,
|
| 45 |
+
get_gpu_dram_gbps,
|
| 46 |
+
green_text,
|
| 47 |
+
is_collective,
|
| 48 |
+
is_wait,
|
| 49 |
+
red_text,
|
| 50 |
+
sympy_product,
|
| 51 |
+
)
|
| 52 |
+
from .virtualized import V
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Module-level loggers: `log` for general scheduler diagnostics, and
# `fusion_log` bound to the "fusion" logging artifact (used by WhyNoFuse
# below to record why two nodes were not fused).
log = logging.getLogger(__name__)
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class WhyNoFuse:
    """
    Records and logs the reason a pair of scheduler nodes could not be fused.

    Construct with the two candidate nodes, then call the instance with a
    printf-style ``reason`` format string and its args; the formatted message
    is emitted at debug level on the "fusion" artifact logger.
    """

    # TODO when we drop support for Python < 3.10, we can use
    # @dataclass(slots=True) instead of manually specifying __slots__.
    __slots__ = ["node1", "node2", "reason", "args"]
    reason: str
    args: Tuple[Any, ...]

    def __init__(self, node1: "BaseSchedulerNode", node2: "BaseSchedulerNode"):
        self.node1 = node1
        self.node2 = node2

    def __call__(self, reason, *args):
        # Store the most recent reason/args (formatting is deferred to
        # __str__, so no work is done unless the log line is rendered).
        self.reason = reason
        self.args = args
        fusion_log.debug(self)

    def __str__(self):
        return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + (
            self.reason % self.args
        )
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def pformat(obj):
    """
    Pretty-format ``obj`` for debug/trace output.

    Sets are first converted to a list sorted by string representation,
    because pprint has trouble with sets of sympy exprs.  A multi-line
    result is returned indented by four spaces and prefixed with a
    newline so it nests cleanly under the label that precedes it.
    """
    if isinstance(obj, set):
        obj = sorted(obj, key=str)
    formatted = pprint.pformat(obj, indent=4)
    if "\n" not in formatted:
        return formatted
    return "\n" + textwrap.indent(formatted, "    ")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class OutputNode:
    """
    Sentinel user node representing a graph output.

    Wraps a single dependency so that output buffers look like they have
    one more (never-satisfied) consumer; ``can_free`` checks for this type
    to avoid freeing buffers that escape the graph.
    """

    def __init__(self, dep):
        self.unmet_dependencies = set((dep,))
        self.inverse_users = []

    def is_reduction(self):
        return False

    def get_alias_names(self):
        return ()

    def get_name(self):
        return "OUTPUT"

    __repr__ = get_name
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _prune_redundant_deps(node, name_to_fused_node):
    """
    Drop WeakDeps (mutation-ordering edges) that have become redundant:
    either `node` already has a regular (non-weak) dependency on the same
    fused upstream node, or the weak dep points back at `node`'s own fused
    group.

    This effectively enforces an ordering on fusions: as fusions occur,
    weak deps are incrementally removed, enabling further fusions while
    preserving mutation ordering.
    """
    # Count the non-weak unmet deps landing on each fused upstream node.
    strong_dep_count: Counter[str] = collections.Counter()
    for dep in node.unmet_dependencies:
        if not isinstance(dep, WeakDep):
            strong_dep_count[name_to_fused_node[dep.name].get_name()] += 1

    def is_prunable(dep):
        if not isinstance(dep, WeakDep):
            return False
        fused_upstream = name_to_fused_node[dep.name]
        # Redundant: an ordinary dep on the same fused node already orders us.
        # Self-dep: fused nodes gather deps from their snodes, so if B (with a
        # weak dep on A) gets fused with C, the weak dep can reappear pointing
        # at the combined node itself.
        return strong_dep_count[fused_upstream.get_name()] > 0 or fused_upstream == node

    deps_to_drop = {dep for dep in node.unmet_dependencies if is_prunable(dep)}

    if deps_to_drop:
        node.unmet_dependencies = node.unmet_dependencies - deps_to_drop
        node.set_read_writes(node.read_writes.remove_reads(deps_to_drop))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel
# Maps an extern kernel's `python_kernel_name` string to the aten op it wraps;
# used by BaseSchedulerNode.get_estimated_runtime() to dry-run the op under
# FakeTensorMode and count FLOPs.
kernel_name_to_op = {
    "extern_kernels.convolution": torch.ops.aten.convolution,
    "extern_kernels.mm": torch.ops.aten.mm,
    "extern_kernels.bmm": torch.ops.aten.bmm,
    "extern_kernels.addmm": torch.ops.aten.addmm,
}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class BaseSchedulerNode:
    """
    Base class for a single unit of scheduling: wraps one ir.Buffer and
    tracks its read/write dependencies, users, and ordering metadata.

    Subclasses visible in this file (SchedulerNode, ExternKernelSchedulerNode,
    NopKernelSchedulerNode, and fused variants referenced below) specialize
    fusion, inplace-reuse, and codegen behavior.
    """

    def __init__(self, scheduler: "Scheduler", node: ir.Buffer):
        self.scheduler: Scheduler = scheduler
        self.node: ir.Buffer = node
        # Downstream consumers, deduplicated/merged in set_users().
        self.users: List[NodeUser] = []
        self.inverse_users: List[BaseSchedulerNode] = []
        self.node_users: List[BaseSchedulerNode] = []
        self.set_read_writes(node.get_read_writes())
        # NOTE(review): populated outside this chunk — presumably names of
        # transitive predecessors; confirm against the Scheduler code.
        self.ancestors: Set[str] = set()
        # Declared (unset) here; assigned during scheduling outside this chunk.
        self.min_order: int
        self.max_order: int
        self.last_usage: Set[
            str
        ] = set()  # buffers that won't be used after this kernel
        # True once codegen_originating_info() has emitted origin comments
        # (used to honor its only_once flag).
        self.written = False

    def __repr__(self):
        return f"{type(self).__name__}(name={self.get_name()!r})"

    def debug_str(self) -> str:
        """Longer form printout for trace logs"""
        name = self.get_name()
        lines = [
            f"{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})",
            f"{name}.writes = {pformat(self.read_writes.writes)}",
            f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}",
            f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}",
            f"{name}.users = {self.users}",
        ]
        try:
            lines += [
                self.debug_str_extra(),
            ]
        except Exception:
            # debug output must never break compilation; log and continue
            log.warning("Ignoring error in debug_str()", exc_info=True)

        return "\n".join(lines).rstrip()

    def debug_str_extra(self) -> str:
        # Subclass hook for extra per-node-type debug lines.
        return ""

    def log_details(self):
        log.info(
            "%s: unmet_dependencies = %s, writes = %s",
            self,
            self.unmet_dependencies,
            self.read_writes.writes,
        )

    def update_mutated_names(self, renames: Dict[str, str]):
        self.set_read_writes(self.read_writes.rename(renames))

    def add_mutation_dep(self, dep):
        self.set_read_writes(self.read_writes.with_read(dep))

    def add_fake_dep(self, dep):
        # Identical body to add_mutation_dep; kept as a separate entry point
        # for call-site clarity — TODO confirm whether they should diverge.
        self.set_read_writes(self.read_writes.with_read(dep))

    def set_users(self, users: List["NodeUser"]):
        # deduplicate by identity of the consuming node, merging duplicates
        result: Dict[int, NodeUser] = {}
        for use in users:
            if id(use.node) in result:
                result[id(use.node)] = use.merge(result[id(use.node)])
            else:
                result[id(use.node)] = use
        self.users = list(result.values())

    def set_last_usage(
        self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str]
    ):
        # last_usage = buffers this node touches (after resolving mutation
        # aliases) that no later node will use.
        used_buffers = self.used_or_aliased_buffer_names()
        used_buffers = {mutation_real_name.get(k, k) for k in used_buffers}
        self.last_usage = used_buffers - future_used_buffers

    def get_aliases(self):
        return self.node.get_alias_names()

    def get_mutations(self):
        return self.node.get_mutation_names()

    def has_aliasing_or_mutation(self):
        return bool(self.get_aliases() or self.get_mutations())

    def set_read_writes(self, rw: dependencies.ReadWrites):
        # All reads start out "unmet"; prune_deps() immediately clears the
        # ones already available in the scheduler.
        self.read_writes: dependencies.ReadWrites = rw
        self.unmet_dependencies = self.read_writes.reads
        self.prune_deps()

    def op_counts(self):
        return self.read_writes.op_counts

    def used_buffer_names(self) -> Set[str]:
        return {
            dep.name
            for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
        }

    def used_or_aliased_buffer_names(self) -> Set[str]:
        # Like used_buffer_names(), but also includes the underlying buffer
        # behind any AliasedLayout view this node touches.
        used_names = set()

        for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes):
            used_names.add(dep.name)
            if V.graph.name_to_buffer.get(dep.name):
                layout = V.graph.name_to_buffer[dep.name].get_layout()
                # needed to avoid deallocating aliased buffer
                # if there are still uses of aliases ahead
                if isinstance(layout, ir.AliasedLayout):
                    used_names.add(layout.view.data.get_name())
        return used_names

    def prune_deps(self):
        # Drop deps on buffers the scheduler has already made available.
        self.unmet_dependencies = {
            dep
            for dep in self.unmet_dependencies
            if dep.name not in self.scheduler.available_buffer_names
        }

    def prune_weak_deps(self):
        # Prune weak dependencies on buffers that have been removed
        def should_prune(dep):
            return isinstance(dep, WeakDep) and dep.name in V.graph.removed_buffers

        to_remove = {dep for dep in self.read_writes.reads if should_prune(dep)}
        self.set_read_writes(self.read_writes.remove_reads(to_remove))

    def prune_redundant_deps(self, name_to_fused_node):
        _prune_redundant_deps(self, name_to_fused_node)

    def get_name(self) -> str:
        return self.node.get_name()

    def get_first_name(self) -> str:
        return self.get_name()

    def get_names(self) -> Set[str]:
        return {self.get_name()}

    def get_nodes(self) -> Sequence["BaseSchedulerNode"]:
        return [self]

    def get_device(self):
        return self.node.get_device()

    # --- predicates with conservative defaults; subclasses override ---

    def is_reduction(self):
        return False

    def is_split_scan(self):
        return False

    def is_template(self):
        return False

    def is_extern(self):
        return False

    def is_foreach(self):
        return False

    def can_inplace(self, read_dep: dependencies.MemoryDep):
        return False

    def has_side_effects(self):
        return False

    def decide_inplace_update(self):
        """
        Decide if there should be inplace updates for the node
        and record the decision in the active kernel.
        """
        if not self.node.should_allocate():
            return

        if isinstance(self, (SchedulerNode,)) and (
            self.node.get_alias_names() or self.node.get_mutation_names()
        ):
            return

        # Only attempt inplace reuse for SchedulerNodes (or the special-cased
        # collective extern kernels), when enabled by config, and when the
        # active kernel can track mutations.
        if (
            (
                isinstance(self, (SchedulerNode,))
                # o what have i done. lets make this an api
                or (
                    isinstance(self, ExternKernelSchedulerNode)
                    and isinstance(self.node, (ir.AllReduce, ir.InPlaceHint))
                )
            )
            and config.inplace_buffers
            and (
                not isinstance(V.kernel, torch._inductor.codegen.triton.TritonKernel)
                or getattr(V.kernel, "mutations", None) is not None
            )
        ):
            from .codegen.wrapper import buffer_reuse_key

            # sort for determinism across runs
            ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name)

            for read in ordered_reads:
                input_node: Optional[
                    BaseSchedulerNode
                ] = self.scheduler.name_to_node.get(read.name)
                if input_node and V.graph.wrapper_code.can_reuse(input_node, self):
                    assert input_node.users is not None
                    remaining_uses = [
                        x
                        for x in input_node.users
                        if x.node.get_name()
                        not in self.scheduler.available_buffer_names
                    ]
                    # Reuse is legal only if we are the sole remaining user,
                    # the input's layout doesn't alias/mutate other storage,
                    # and both buffers have the same reuse key (size/layout).
                    if (
                        len(remaining_uses) == 1
                        and remaining_uses[0].can_inplace
                        and remaining_uses[0].node is self
                        and not isinstance(
                            input_node.node.get_layout(),
                            (
                                ir.MultiOutputLayout,
                                ir.MutationLayout,
                                ir.AliasedLayout,
                            ),
                        )
                        and not (
                            isinstance(
                                input_node.node, (ir.FallbackKernel, ir.MultiOutput)
                            )
                            and len(input_node.node.get_alias_names()) > 0
                        )
                        and buffer_reuse_key(input_node.node)
                        == buffer_reuse_key(self.node)
                    ):
                        # hacky check for if V.kernel is a real kernel or NullHandler
                        if hasattr(V.kernel, "args"):
                            # if there isn't a triton kernel, then we don't need to call triton-specific things.
                            # but TODO this might be a convenient place to signal to the Collective kernels to inplace
                            # (and, can we make "kernel" less generic of a name?)
                            V.kernel.args.make_inplace(
                                input_node.get_name(), self.get_name()
                            )
                            # mutations not tracked in cpp kernels
                            if isinstance(
                                V.kernel, torch._inductor.codegen.triton.TritonKernel
                            ):
                                V.kernel.mutations.add(input_node.get_name())
                                V.kernel.mutations.add(self.get_name())

                            # update last usage of reused node
                            self.last_usage.discard(input_node.get_name())

                            V.kernel.inplace_update_buffers[
                                self.get_name()
                            ] = input_node.get_name()
                        break

    def allocate(self):
        """Emit allocation (or inplace-reuse) codegen for this node's buffer."""
        if not self.node.should_allocate():
            return

        if isinstance(self, (SchedulerNode,)) and (
            self.node.get_alias_names() or self.node.get_mutation_names()
        ):
            V.graph.wrapper_code.codegen_allocation(self.node)
            return

        # hacky check for if V.kernel is a real kernel or NullHandler
        if (
            hasattr(V.kernel, "args")
            and self.get_name() in V.kernel.inplace_update_buffers
        ):
            # decide_inplace_update() chose a buffer to reuse; emit reuse
            # instead of a fresh allocation.
            V.graph.wrapper_code.codegen_inplace_reuse(
                self.scheduler.name_to_node[
                    V.kernel.inplace_update_buffers[self.get_name()]
                ].node,
                self.node,
            )
        else:
            V.graph.wrapper_code.codegen_allocation(self.node)

    def can_free(self):
        """Whether this node's buffer may be freed after its last use."""
        # There's no real allocated buffer, no need to free it
        if isinstance(self.node.layout, ir.NoneLayout):
            return False
        for use in self.users:
            # graph outputs escape; never free them
            if isinstance(use.node, OutputNode):
                return False
        return True

    def codegen_originating_info(self, buffer, only_once=True):
        """
        Write '#pragma CMT' origin comments (op, target, seq_nr, stack trace)
        for this node into `buffer`.  Gated by config.comment_origin; emits
        at most once per node when only_once is set.
        """
        if not config.comment_origin:
            return

        if only_once and self.written:
            return
        origins = self.node.origins
        out_lines = []

        for o in origins:
            if o.op == "output":
                # These are boring and samey
                continue

            out_lines.append("")
            # TODO(voz): Should the pragma be constant somewhere?
            out_lines.append("#pragma CMT ORIGIN:")
            op_info_str = f"#pragma CMT {o.op} {o.target}"
            if "seq_nr" in o.meta:
                op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}"
            out_lines.append(op_info_str)
            if "stack_trace" in o.meta:
                stack_trace = f"{o.meta['stack_trace']}"
                stack_trace_last_line = stack_trace.split("|")[-1]
                # escape braces/newlines so the comment survives later
                # str.format-style processing of the generated code
                out_lines.append(
                    "#pragma CMT "
                    + stack_trace_last_line.replace("{", "{{")
                    .replace("}", "}}")
                    .replace("\n", "\\")
                )
            out_lines.append("#pragma CMT END ORIGIN")
            out_lines.append("")

        if len(out_lines) == 0:
            return

        # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
        # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
        buffer.writelines(out_lines)
        self.written = True

    def get_read_write_buffers_sizes(self) -> int:
        """
        Counting the number of bytes accessed for a kernel is
        surprisingly tricky. In particular, there is a differentiation
        between 'theoretical' memory accesses and practical memory
        accesses. For example, a layernorm kernel may actually access an
        input 3 times, but in theory, it only needs to access its input
        once (and may be optimized to do so through say, persistent
        reductions)

        Another example is that even though a buffer is passed in, we may
        not access the entire buffer. This may occur if we are accessing
        a slice of the buffer. Another tricky case is for indirect
        indexing, where the amount of bytes accessed depends on the
        values of the input.

        What this function aims to compute is the memory accesses for
        worst-case inputs, best-case optimization. What this means is
        that for each buffer we compute the amount of potential accesses in two ways and take the minimum.

        1. Numel in ranges multiplied by number of deps the buffer has
        2. The buffer size
        """
        if isinstance(self, NopKernelSchedulerNode):
            return 0
        if isinstance(self, ExternKernelSchedulerNode) and isinstance(
            self.node, MultiOutput
        ):
            return 0

        if isinstance(self, SchedulerNode):
            node_numel = V.graph.sizevars.size_hint(
                sympy_product(self.get_ranges()[0])
                * sympy_product(self.get_ranges()[1])
            )
        else:
            # no iteration ranges available: effectively "unbounded" so the
            # min() below falls back to the buffer size
            node_numel = int(1e9)
        buf_accesses = collections.defaultdict(list)
        for dep in self.read_writes.reads | self.read_writes.writes:
            buf_accesses[dep.name].append(dep)

        reads = {dep.name for dep in self.read_writes.reads}
        writes = {dep.name for dep in self.read_writes.writes}

        def is_materialized(buf, snodes):
            # a write is materialized only if some user outside this fused
            # group consumes it
            users = self.scheduler.name_to_node[buf].users
            buf_uses = {user.node for user in users}
            return len(buf_uses - set(snodes)) > 0

        if isinstance(self, FusedSchedulerNode):
            removed_buffers = {
                dep for dep in writes if not is_materialized(dep, self.snodes)
            }
            writes = writes - removed_buffers
            reads = reads - removed_buffers
        node_bytes = 0

        for buf_name in reads | writes:
            buf_accessed_elems = sum([node_numel for dep in buf_accesses[buf_name]])
            buf: Union[ir.Buffer, ir.TensorBox]
            if buf_name in V.graph.name_to_buffer:
                buf = V.graph.name_to_buffer[buf_name]
            elif buf_name in V.graph.graph_inputs:
                buf = V.graph.graph_inputs[buf_name]
            else:
                continue

            def get_buf_elems(buf):
                return V.graph.sizevars.size_hint(sympy_product(buf.get_size()))

            # Kind of a lazy way to get the MultiOutput nodes corresponding to
            # a MultiOutputLayout
            if isinstance(buf.layout, MultiOutputLayout):
                users = self.scheduler.name_to_node[buf.get_name()].users
                buf_elems = sum(get_buf_elems(user.node.node) for user in users)
            else:
                buf_elems = get_buf_elems(buf)

            node_bytes += min(buf_elems, buf_accessed_elems) * get_dtype_size(
                buf.get_dtype()
            )

        return node_bytes

    def get_estimated_runtime(self) -> float:
        """
        Returns estimated op runtime in nanoseconds (ns)
        """
        layout = None
        dtype = None
        if not hasattr(self, "node") or not self.node:
            # fused/foreach nodes have no single `node`; borrow layout/dtype
            # from their first sub-node
            assert isinstance(
                self, (FusedSchedulerNode, ForeachKernelSchedulerNode)
            ), f"{type(self)=}"
            assert self.snodes
            if not self.snodes[0].node:
                return 0
            layout = self.snodes[0].node.get_layout()
            dtype = self.snodes[0].node.get_dtype()
        else:
            layout = self.node.get_layout()
            dtype = self.node.get_dtype()

        if "cuda" != layout.device.type:
            # default to no reordering based on runtime
            return 0

        # Collective kernels
        if is_collective(self.node):
            return estimate_nccl_collective_runtime(self.node)
        elif is_wait(self.node):
            # ir.Wait is only used for collective ops.
            # The time needed for the collective op is already estimated and considered
            # when we are processing the collective op IR node, so ir.Wait takes 0 time
            # since it doesn't take extra time to get the result after the collective is completed.
            return 0

        try:
            gpu_memory_bandwidth = get_gpu_dram_gbps()
            gpu_flops = get_device_tflops(dtype) * 10**12
        except Exception:
            return 0

        if isinstance(self, ExternKernelSchedulerNode):
            assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}"
            op = kernel_name_to_op.get(
                getattr(self.node, "python_kernel_name", ""), None
            )

            # if there is a resolved op, dry-run using fake mode and record flop count
            if op is not None:
                from torch._subclasses.fake_tensor import FakeTensorMode
                from torch.utils.flop_counter import FlopCounterMode

                with FakeTensorMode(), FlopCounterMode(
                    display=False
                ) as flop_counter_mode:
                    from .ir import ir_node_to_tensor

                    fake_inputs = [
                        ir_node_to_tensor(input, guard_shape=False)
                        for input in self.node.inputs
                    ]
                    cls = self.node.__class__
                    cls.process_kernel(op, *fake_inputs, **self.node.kwargs)

                # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
                factor = 1.0
                counted_flops = flop_counter_mode.get_total_flops()
                counted_bytes = self.get_read_write_buffers_sizes()
                compute_time = (factor * counted_flops / gpu_flops) * 1e9
                transfer_time = counted_bytes / gpu_memory_bandwidth

                # Return estimated runtime in nanoseconds
                return max(compute_time, transfer_time)

        elif isinstance(self, FusedSchedulerNode) or isinstance(
            self.node, ComputedBuffer
        ):
            # Return estimated runtime in nanoseconds (bytes / gbps)
            return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth

        return 0
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
class ExternKernelSchedulerNode(BaseSchedulerNode):
    """Scheduler node wrapping an extern (library/fallback) kernel call."""

    def debug_str_extra(self) -> str:
        kernel_name = getattr(self.node, "python_kernel_name", None)
        return f"{self.get_name()}.node.kernel = {kernel_name}"

    def is_extern(self):
        return True

    def has_side_effects(self):
        if not hasattr(self.node, "has_side_effects"):
            return False
        return self.node.has_side_effects()

    def can_inplace(self, read_dep: dependencies.MemoryDep):
        # Aliasing outputs and template kernels can never reuse an input.
        if self.get_aliases() or self.is_template():
            return False

        # don't allow reuse of an 'input' buffer, we don't own it
        # (would this have been fixed if I tracked mutations properly above?)
        if read_dep.name not in self.scheduler.name_to_node:
            return False

        # TODO make this a property of the IR
        if not isinstance(
            self.node, (torch._inductor.ir.AllReduce, torch._inductor.ir.InPlaceHint)
        ):
            return False

        if len(self.read_writes.writes) != 1:
            return False

        # Inplace is legal only when the single write covers exactly as many
        # elements as the candidate read.
        (write_dep,) = self.read_writes.writes
        numel_diff = read_dep.get_numel() - write_dep.get_numel()
        return V.graph.sizevars.simplify(numel_diff) == 0
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
class NopKernelSchedulerNode(BaseSchedulerNode):
    """
    Marker subclass with no specialized behavior of its own; treated as
    accessing zero bytes by BaseSchedulerNode.get_read_write_buffers_sizes().
    """

    pass
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
class SchedulerNode(BaseSchedulerNode):
    # Scheduler wrapper for buffers whose kernels Inductor generates itself:
    # ComputedBuffers (pointwise/reduction loop bodies) and TemplateBuffers.

    def __init__(
        self,
        scheduler: "Scheduler",
        node: Union[ir.ComputedBuffer, ir.TemplateBuffer],
    ):
        super().__init__(scheduler, node)
        self._compute_attrs()

    def _compute_attrs(
        self,
        extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None,
    ):
        # Derive iteration sizes and the loop body from the IR node, compute
        # the backend-specific fusion "group" key, and extract read/write deps.
        assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer))
        self._sizes, self._body = self.node.simplify_and_reorder(
            extra_indexing_constraints=extra_indexing_constraints
        )

        # group is (device, group_fn(sizes)); nodes with equal groups are
        # candidates for fusion into one kernel.
        group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn
        self.group = (self.node.get_device(), group_fn(self._sizes))

        if isinstance(self.node, ir.TemplateBuffer):
            self.set_read_writes(self.node.normalized_read_writes())
        else:
            self.set_read_writes(
                dependencies.extract_read_writes(
                    self._body, *self._sizes, normalize=True
                )
            )

    def recompute_size_and_body(
        self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]]
    ):
        # Re-run _compute_attrs with additional indexing constraints
        # (e.g. imposed by a template this node is being fused with).
        self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints)

    def debug_str_extra(self) -> str:
        # Extra per-node detail appended to debug_str(): group, sizes,
        # aliases/mutations if present, and the loop body IR.
        name = self.get_name()
        lines = [
            f"{name}.group.device = {self.group[0]}",
            f"{name}.group.iteration = {self.group[1]}",
            f"{name}.sizes = {self._sizes}",
        ]
        if self.get_aliases():
            lines.append(f"{name}.aliases = {pformat(self.get_aliases())}")
        if self.get_mutations():
            lines.append(f"{name}.mutations = {pformat(self.get_mutations())}")
        if isinstance(self._body, ir.LoopBody):
            lines.append(f"class {name}_loop_body:")
            lines.append(textwrap.indent(self._body.debug_str(), " "))
        return "\n".join(lines)

    def get_ranges(self):
        # (iteration sizes, reduction sizes) as produced by simplify_and_reorder.
        return self._sizes

    def is_reduction(self):
        assert isinstance(
            self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
        ), f"{type(self.node)=}"
        return bool(self.node.get_reduction_type())

    def is_split_scan(self):
        assert isinstance(
            self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
        ), f"{type(self.node)=}"
        return isinstance(self.node, ir.ComputedBuffer) and isinstance(
            self.node.data, ir.SplitScan
        )

    def is_template(self):
        return isinstance(self.node, ir.TemplateBuffer)

    def get_template_node(self):
        return self.node if self.is_template() else None

    def run(self, *index_vars):
        # Full codegen sequence: decide in-place reuse, allocate the output,
        # then emit the loop body for the given index variables.
        self.decide_inplace_update()
        self.mark_run()
        self.codegen(index_vars)

    def mark_run(self):
        self.allocate()

    def ranges_from_index_vars(self, index_vars):
        # Zip flattened index variables with their corresponding sizes into a
        # {var: range} mapping; counts must line up exactly.
        sizes = self._sizes
        assert sum(map(len, sizes)) == sum(map(len, index_vars))
        var_ranges = dict(
            zip(
                itertools.chain.from_iterable(index_vars),
                itertools.chain.from_iterable(sizes),
            )
        )
        return var_ranges

    def codegen(self, index_vars):
        # Emit this node's body through the current kernel's ops handler,
        # wrapped in SimplifyIndexing so indices are simplified w.r.t. ranges.
        var_ranges = self.ranges_from_index_vars(index_vars)
        try:
            with V.set_ops_handler(
                SimplifyIndexing(V.get_ops_handler(), var_ranges)
            ), V.kernel.set_current_node(self):
                self._body(*index_vars)
        except Exception:
            # NOTE(review): log.fatal is a deprecated alias of log.critical.
            log.fatal("Error in codegen for %s", self.node)
            raise

    def pointwise_read_writes(self):
        """
        Get the memory dependencies in the non-reduction axis.
        """
        sizes, reduction_sizes = self._sizes

        # Evaluate the body with all reduction indices pinned to zero so only
        # the pointwise-axis accesses are recorded.
        def fn(index):
            return self._body(index, [sympy.Integer(0) for _ in reduction_sizes])

        return dependencies.extract_read_writes(fn, sizes)

    def can_inplace(self, read_dep: dependencies.MemoryDep):
        # In-place is allowed only for a single write whose index/size exactly
        # match the read (no aliasing, not a template).
        if self.get_aliases() or self.is_template():
            return False
        if len(self.read_writes.writes) == 1 and isinstance(
            read_dep, dependencies.MemoryDep
        ):
            write_dep = next(iter(self.read_writes.writes))
            assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}"
            return read_dep.index == write_dep.index and read_dep.size == write_dep.size
        return False

    @cache_on_self
    def _get_atomic_add_buffers(self) -> Set[str]:
        # Scan the loop body FX nodes for stores in "atomic_add" mode and
        # collect the buffer names they target. Mode may arrive either as a
        # kwarg or as the 5th positional arg of the store call.
        buffers_store_as_atomic_add = set()
        if isinstance(self._body, ir.LoopBody):
            for node in self._body.get_nodes():
                if (
                    node.op == "call_method"
                    and node.target == "store"
                    and (
                        ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add")
                        or (len(node.args) == 5 and node.args[4] == "atomic_add")
                    )
                ):
                    buffers_store_as_atomic_add.add(
                        node.kwargs["name"]
                        if "name" in node.kwargs
                        else (node.args[1] if len(node.args) >= 2 else "")
                    )
        return buffers_store_as_atomic_add

    def has_atomic_add(self, check_buf):
        return check_buf in self._get_atomic_add_buffers()
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
class FusedSchedulerNode(BaseSchedulerNode):
    """
    This is a "fake" scheduler node that represents a group of scheduler nodes
    that are meant to be fused together. The way it does this is by maintaining
    its unmet dependencies as the union of its constituent nodes.
    """

    @classmethod
    def fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        # Flatten both operands' constituent nodes into a single fused node.
        assert node1.scheduler is node2.scheduler
        assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(
            node2, (SchedulerNode, FusedSchedulerNode)
        )
        return cls(node1.scheduler, list(node1.get_nodes()) + list(node2.get_nodes()))  # type: ignore[arg-type]

    def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]):
        # NB: No need to call super().__init__() because we don't need to re-use any of its logic.
        self.snodes = snodes
        self.scheduler = scheduler
        # A fused node has no single IR buffer of its own.
        self.node: ir.Buffer = None  # type: ignore[assignment]
        self.users: List[NodeUser] = []
        self.inverse_users = []
        self.node_users = []
        # Prefer a reduction member's group so the fused kernel keeps the
        # reduction iteration structure.
        self.group = max(snodes, key=lambda x: int(x.is_reduction())).group
        self.ancestors = set.union(
            *[x.ancestors for x in snodes if x.ancestors is not None]
        )

        self.set_read_writes(
            dependencies.ReadWrites.merge_list([x.read_writes for x in snodes])
        )

        # Union of members' unmet deps, minus deps satisfied internally
        # (produced by a member or written by the fused node itself).
        self.unmet_dependencies = {
            dep
            for dep in set.union(*[x.unmet_dependencies for x in snodes])
            if dep.name not in self.get_names()
        } - self.read_writes.writes
        self.min_order = min([x.min_order for x in self.snodes])
        self.max_order = max([x.max_order for x in self.snodes])

    @cache_on_self
    def get_name(self) -> str:
        # Synthetic name: member names joined with underscores.
        return "_".join([x.get_name() for x in self.snodes])

    def get_first_name(self) -> str:
        return self.snodes[0].get_name()

    @cache_on_self
    def get_names(self) -> Set[str]:
        return set.union(*[x.get_names() for x in self.snodes])

    def debug_str_extra(self) -> str:
        lines = [
            f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}"
            for i, node in enumerate(self.snodes)
        ]
        return textwrap.indent("\n".join(lines).rstrip(), " ")

    def set_last_usage(
        self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str]
    ):
        # Set self.last_usage using the global information
        # This will be used for inter-kernel optimisations
        super().set_last_usage(future_used_buffers, mutation_real_name)
        # Set self.last_usage on the snodes
        # This will be used for optimisations within the kernel
        future_used_buffers: Set[str] = set()
        for node in reversed(self.snodes):
            node.set_last_usage(future_used_buffers, mutation_real_name)
            future_used_buffers.update(node.last_usage)  # type: ignore[arg-type]

    @cache_on_self
    def used_buffer_names(self) -> Set[str]:
        return set.union(*[x.used_buffer_names() for x in self.snodes])

    @cache_on_self
    def used_or_aliased_buffer_names(self) -> Set[str]:
        return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes])

    def get_nodes(self) -> List[SchedulerNode]:
        return self.snodes

    def __repr__(self):
        return f"{type(self).__name__}(nodes={self.get_name()})"

    @cache_on_self
    def is_reduction(self):
        return any(x.is_reduction() for x in self.snodes)

    @cache_on_self
    def is_split_scan(self):
        return any(x.is_split_scan() for x in self.snodes)

    @cache_on_self
    def is_template(self):
        return any(x.is_template() for x in self.snodes)

    @cache_on_self
    def get_template_node(self):
        # First template member, if any (a fused group contains at most one
        # template in practice — not asserted here).
        for node in self.snodes:
            if node.is_template():
                return node
        return None

    def get_device(self):
        return self.group[0]

    @cache_on_self
    def has_aliasing_or_mutation(self):
        return any(x.has_aliasing_or_mutation() for x in self.snodes)

    @cache_on_self
    def op_counts(self):
        op_counts: Counter[str] = collections.Counter()
        for node in self.snodes:
            op_counts.update(node.op_counts())
        return op_counts

    def has_atomic_add(self, check_buf):
        # Only SchedulerNode members can contain atomic-add stores.
        return any(
            (
                isinstance(sub_schedule_node1, SchedulerNode)
                and sub_schedule_node1.has_atomic_add(check_buf)
            )
            for sub_schedule_node1 in self.get_nodes()
        )

    # None of these need to be implemented, as a FusedSchedulerNode is just an
    # abstraction for scheduling purposes
    def update_mutated_names(self, renames: Dict[str, str]):
        raise NotImplementedError

    def add_mutation_dep(self, name):
        raise NotImplementedError

    def set_users(self, users: List["NodeUser"]):
        raise NotImplementedError

    def get_aliases(self):
        raise NotImplementedError

    def get_mutations(self):
        raise NotImplementedError

    def can_inplace(self, read_dep: dependencies.MemoryDep):
        raise NotImplementedError

    def allocate(self):
        raise NotImplementedError

    def can_free(self):
        raise NotImplementedError

    def debug_str(self) -> str:
        """Longer form printout for trace logs"""
        name = self.get_name()
        node_typestr = ",".join(type(n).__name__ for n in self.snodes)
        lines = [
            f"{name}: {type(self).__name__}({node_typestr})",
            f"{name}.writes = {pformat(self.read_writes.writes)}",
            f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}",
            f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}",
            f"{name}.users = {self.users}",
        ]
        try:
            lines += [
                self.debug_str_extra(),
            ]
        except Exception:
            # Debug output must never break compilation.
            log.warning("Ignoring error in debug_str()", exc_info=True)

        return "\n".join(lines).rstrip()
|
| 1002 |
+
|
| 1003 |
+
|
| 1004 |
+
class ForeachKernelSchedulerNode(FusedSchedulerNode):
    """Scheduler node which consists of a list of scheduler nodes that each operate on a
    distinct tensor in a list of tensors."""

    def get_consumer_subnode_for(self, producer):
        # Sub-node (if any) that reads one of producer's outputs.
        if producer.get_name() in self.read_to_node:
            return self.read_to_node[producer.get_name()]

        return None

    def get_producer_subnode_for(self, consumer):
        # Sub-node (if any) that writes one of consumer's reads.
        for rd in consumer.read_writes.reads:
            if rd.name in self.name_to_node:
                return self.name_to_node[rd.name]

        return None

    @classmethod
    def can_fuse(cls, producer, consumer):
        # Foreach-foreach fusion requires equal list lengths and pairwise
        # fusibility; foreach-scalar fusion requires the scalar node to be a
        # dep of exactly one sub-node (then defer to the scheduler's check).
        why = WhyNoFuse(producer, consumer)
        if producer.is_foreach() and consumer.is_foreach():
            foreach_match = len(producer.snodes) == len(consumer.snodes)
            if not foreach_match:
                why("foreach do not have same length")
            return foreach_match and all(
                producer.scheduler.can_fuse(l, r)
                for l, r in zip(producer.snodes, consumer.snodes)
            )
        elif consumer.is_foreach():
            consumer_subnode = consumer.get_consumer_subnode_for(producer)
            if consumer_subnode is not None:
                return consumer.scheduler.can_fuse(producer, consumer_subnode)

            why("candidate producer is not dep of any foreach consumer")
            return False

        elif producer.is_foreach():
            producer_subnode = producer.get_producer_subnode_for(consumer)
            if producer_subnode is not None:
                return producer.scheduler.can_fuse(producer_subnode, consumer)

            why("candidate consumer has no dep in any foreach producer")
            return False

        raise AssertionError(
            "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node"
        )

    @classmethod
    def fuse(cls, producer, consumer):
        # Pairwise-fuse two foreach lists, or vertically fuse a scalar node
        # into the matching sub-node of a foreach list. prev_node_1/2 let
        # __init__ take an incremental-update fast path.
        assert producer.is_foreach() or consumer.is_foreach()
        prev_node_1 = None
        prev_node_2 = None
        if producer.is_foreach() and consumer.is_foreach():
            fused_nodes = [
                FusedSchedulerNode.fuse(l, r)
                for l, r in zip(producer.snodes, consumer.snodes)
            ]
        elif producer.is_foreach():
            producer_subnode = producer.get_producer_subnode_for(consumer)
            fused_nodes = []
            prev_node_1 = producer
            prev_node_2 = None
            for node in producer.snodes:
                if node is producer_subnode:
                    new_node = FusedSchedulerNode.fuse(node, consumer)
                    prev_node_2 = new_node
                    fused_nodes.append(new_node)
                else:
                    fused_nodes.append(node)

        elif consumer.is_foreach():
            consumer_subnode = consumer.get_consumer_subnode_for(producer)
            fused_nodes = []
            prev_node_1 = consumer
            prev_node_2 = None

            for node in consumer.snodes:
                if node is consumer_subnode:
                    new_node = FusedSchedulerNode.fuse(producer, node)
                    prev_node_2 = new_node
                    fused_nodes.append(new_node)
                else:
                    fused_nodes.append(node)

        return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2)  # type: ignore[possibly-undefined]

    def __init__(
        self,
        scheduler: "Scheduler",
        nodes: List[SchedulerNode],
        prev_node_1=None,
        prev_node_2=None,
    ):
        # read_to_node: dep name -> sub-node that reads it
        # name_to_node: buffer name -> sub-node that produces it
        self.read_to_node = {}
        self.name_to_node = {}

        if prev_node_1 is None or prev_node_2 is None:
            # Fresh construction: full FusedSchedulerNode init, then build
            # the lookup tables from scratch.
            super().__init__(scheduler, nodes)

            for node in nodes:
                for read in node.read_writes.reads:
                    self.read_to_node[read.name] = node

                for name in node.get_names():
                    self.name_to_node[name] = node
        else:
            # Incremental path: derive state from the two pre-fusion nodes
            # instead of recomputing over every sub-node.
            self.scheduler = scheduler
            self.snodes = nodes
            self.node: ir.Buffer = None  # type: ignore[assignment]
            self.users: List[NodeUser] = []

            self.set_read_writes(
                dependencies.ReadWrites.merge_list(
                    [prev_node_1.read_writes, prev_node_2.read_writes]
                )
            )

            self.unmet_dependencies = {
                dep
                for dep in set.union(
                    prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies
                )
                if dep.name not in self.get_names()
            } - self.read_writes.writes

            self.min_order = min([prev_node_1.min_order, prev_node_2.min_order])
            self.max_order = max([prev_node_1.max_order, prev_node_2.max_order])

            foreach_node = prev_node_1 if prev_node_1.is_foreach() else prev_node_2
            other_node = prev_node_2 if prev_node_1.is_foreach() else prev_node_1

            self.ancestors = foreach_node.ancestors
            self.ancestors.update(other_node.ancestors)

            self.name_to_node = foreach_node.name_to_node
            for name in other_node.get_names():
                self.name_to_node[name] = other_node

        self.group = (nodes[0].get_device(), "foreach")

        self.origins: Set[torch.fx.Node] = set()

    def mark_run(self):
        raise NotImplementedError

    def codegen(self):
        assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}"
        self.node.get_store_function()(self.node.make_loader()())

    def can_free(self):
        # NOTE(review): this *returns* NotImplementedError (a truthy class)
        # rather than raising it — likely a bug, but callers may rely on the
        # truthy result, so left unchanged; confirm before fixing.
        return NotImplementedError

    def is_foreach(self):
        return True

    def get_subkernel_nodes(self):
        """Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists.
        These nodes may be vertically fused."""
        return list(self.snodes)

    def get_nodes(self):
        """Returns all nodes contained in this kernel, unpacking fused nodes into their constituent scheduler nodes."""
        return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes))

    def get_first_name(self):
        return self.snodes[0].get_first_name()

    def prune_redundant_deps(self, name_to_fused_node):
        # Prune on the container first, then recurse into each sub-node.
        _prune_redundant_deps(self, name_to_fused_node)

        for node in self.snodes:
            node.prune_redundant_deps(name_to_fused_node)
|
| 1177 |
+
|
| 1178 |
+
|
| 1179 |
+
def pick_loop_order(stride_lengths, sizes, priority_idx=()):
    """
    A heuristic to decide loop iteration orders. This has not been well
    tuned and may be something we should autotune.
    """
    ndims = len(stride_lengths[0])

    if len(priority_idx) > 0:
        # if we have priority node, only use that node's order
        stride_lengths = [stride_lengths[pi] for pi in priority_idx]

    @functools.cmp_to_key
    def dim_cmp(a, b):
        # 1-sizes don't matter, just move them to the end
        if sizes[a] == 1 or sizes[b] == 1:
            return cmp(sizes[a] == 1, sizes[b] == 1)

        strides_a = [sl[a] for sl in stride_lengths]
        strides_b = [sl[b] for sl in stride_lengths]

        # Vote per node on which dim should be outer; equivalent to
        # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all()
        votes_a = sum(sb == 0 or sa < sb for sa, sb in zip(strides_a, strides_b))
        votes_b = sum(sa == 0 or sb < sa for sa, sb in zip(strides_a, strides_b))
        if votes_a > votes_b:
            return -1
        if votes_b > votes_a:
            return 1

        # otherwise contiguous
        return cmp(b, a)

    order = list(reversed(range(ndims)))
    if config.pick_loop_orders:
        order.sort(key=dim_cmp)
    return order
|
| 1217 |
+
|
| 1218 |
+
|
| 1219 |
+
@dataclasses.dataclass
class NodeUser:
    """Records that ``node`` consumes some buffer, and how it uses it."""

    node: BaseSchedulerNode
    # Whether the user permits reusing the read buffer as its own output.
    can_inplace: bool = False

    # A weak user must be scheduled after a given node, but doesn't actually
    # use the result
    is_weak: bool = False

    def __hash__(self):
        # Hash by node *name* (not identity) to stay consistent with __eq__.
        return hash((self.node.get_name(), self.can_inplace, self.is_weak))

    def __eq__(self, other):
        # Robustness fix: comparing against a non-NodeUser previously raised
        # AttributeError; return NotImplemented per the data model so Python
        # can fall back to the other operand (and ultimately to False).
        if not isinstance(other, NodeUser):
            return NotImplemented
        return (
            self.get_name() == other.get_name()
            and self.can_inplace == other.can_inplace
            and self.is_weak == other.is_weak
        )

    def get_name(self):
        return self.node.get_name()

    def merge(self, other: "NodeUser") -> "NodeUser":
        """Combine two users of the same node, ANDing the flags conservatively."""
        assert self.node is other.node
        return NodeUser(
            self.node,
            self.can_inplace and other.can_inplace,
            self.is_weak and other.is_weak,
        )
|
| 1248 |
+
|
| 1249 |
+
|
| 1250 |
+
# Monotonically increasing id handed to each Scheduler instance
# (read in Scheduler.__init__ via next()) for logging/metrics keys.
_post_grad_graph_counter = itertools.count()
|
| 1251 |
+
|
| 1252 |
+
|
| 1253 |
+
class Scheduler:
|
| 1254 |
+
    @dynamo_timed
    def __init__(self, nodes):
        # Orchestrates the whole scheduling pipeline; statement order below is
        # load-bearing (deps -> sort -> DCE -> ancestors -> foreach -> fusion).
        super().__init__()
        self.backends = {}
        self.fuse_cache = {}
        self.post_grad_graph_id = next(_post_grad_graph_counter)

        self.nodes = []
        # Buffers that already exist before any kernel runs.
        self.available_buffer_names = {
            *V.graph.graph_inputs.keys(),
            *V.graph.constants.keys(),
        }

        self.nodes = [self.create_scheduler_node(n) for n in nodes]

        # some new constants could have been created above
        self.available_buffer_names.update(V.graph.constants.keys())
        for node in self.nodes:
            node.prune_deps()

        self.name_to_node: Dict[str, BaseSchedulerNode] = {
            n.get_name(): n for n in self.nodes
        }
        self.name_to_fused_node: Dict[
            str, BaseSchedulerNode
        ] = dict()  # set in fuse_nodes()

        # mutation_real_name: Maps back to the original name for codegen
        # Example:
        # If you mutate buf0 inside of buf1's kernel, then:
        # mutation_real_name = {"buf0" : "buf1"}
        # all subsequent uses of buf0 become buf1's usage in dependency graph
        self.mutation_real_name = {}

        # We handle mutation by renaming modified versions of the same
        # buffer in the dependency graph to prevent cycles.
        # mutation_renames: tracks the current name for a given buffer
        # (changed once per mutation)
        # Example:
        # If you mutate buf0 inside of buf1's kernel, then:
        # mutation_renames = {"buf1" : "buf0"}
        # in codegen we only use buf0, never buf1
        self.mutation_renames = {}

        self.compute_dependencies()
        self.topological_sort_schedule()
        self.dead_node_elimination()
        if config.reorder_for_compute_comm_overlap:
            comms.decide_global_ordering_of_comms(self.nodes)
        self.compute_ancestors()

        metrics.ir_nodes_pre_fusion += len(self.nodes)
        V.debug.ir_pre_fusion(self.nodes)
        self.num_orig_nodes = len(self.nodes)
        self.name_to_fused_node = {n.get_name(): n for n in self.nodes}
        self.create_foreach_nodes()
        self.topological_sort_schedule()
        self.logged_slow_fusion = set()
        self.fuse_nodes()
        if config.reorder_for_compute_comm_overlap:
            # Refresh node_users and inverse_users to reflect fused nodes
            self.compute_node_users()
            self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
        self.compute_last_usage()
        V.debug.ir_post_fusion(self.nodes)
        V.debug.graph_diagram(self.nodes)
        self.debug_draw_graph()

        # used during codegen:
        self.current_device: torch.device = None  # type: ignore[assignment]
        self.buffer_names_to_free = set()

        # fx graph node to the position it appears in the graph
        # for debug attribution
        self.origin_to_index = {}

        # Lambda defers row construction until the metric table is rendered.
        get_metric_table("graph_stats").add_row(
            lambda: {
                "graph_id": self.post_grad_graph_id,
                "num_nodes_before_fusion": self.num_orig_nodes,
                "num_nodes_after_fusion": len(self.nodes),
            }
        )
|
| 1337 |
+
|
| 1338 |
+
def debug_draw_graph(self):
|
| 1339 |
+
"""Generate an image of the graph for debugging"""
|
| 1340 |
+
if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1":
|
| 1341 |
+
from .debug import draw_buffers
|
| 1342 |
+
|
| 1343 |
+
draw_buffers(self.nodes, print_graph=True)
|
| 1344 |
+
|
| 1345 |
+
def debug_print_nodes(self, label):
|
| 1346 |
+
if log.isEnabledFor(logging.INFO):
|
| 1347 |
+
log.info("%s:", label)
|
| 1348 |
+
for node in self.nodes:
|
| 1349 |
+
node.log_details()
|
| 1350 |
+
|
| 1351 |
+
def create_scheduler_node(self, node):
|
| 1352 |
+
assert (
|
| 1353 |
+
node.origins is not None
|
| 1354 |
+
), "All nodes passed to scheduling must have an origin"
|
| 1355 |
+
if node.is_no_op():
|
| 1356 |
+
return NopKernelSchedulerNode(self, node)
|
| 1357 |
+
elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
|
| 1358 |
+
return SchedulerNode(self, node)
|
| 1359 |
+
elif isinstance(node, ir.ExternKernel):
|
| 1360 |
+
return ExternKernelSchedulerNode(self, node)
|
| 1361 |
+
else:
|
| 1362 |
+
raise NotImplementedError(node)
|
| 1363 |
+
|
| 1364 |
+
    def create_foreach_nodes(self):
        # Group the scheduler nodes belonging to each recorded tensor list
        # (V.graph.lists) into ForeachKernelSchedulerNodes, replacing the
        # originals in self.nodes.
        removed_node_names = set()
        fe_nodes = []
        kept_node_names = self.name_to_fused_node.keys()

        for names in V.graph.lists.values():
            # Drop list members that were eliminated or are no-ops.
            names = [
                name
                for name in names
                if name in kept_node_names
                and not isinstance(self.name_to_node[name], NopKernelSchedulerNode)
            ]
            if not names:
                # All nodes eliminated
                continue

            removed_node_names.update(names)
            snodes = [self.name_to_node[name] for name in names]

            fe_node = ForeachKernelSchedulerNode(self, snodes)  # type: ignore[arg-type]

            fe_nodes.append(fe_node)

            # Each member now resolves to the containing foreach node.
            for name in names:
                self.name_to_fused_node[name] = fe_node

        # Foreach nodes are appended after the remaining (ungrouped) nodes;
        # topological_sort_schedule() is re-run by the caller afterwards.
        self.nodes = [
            node for node in self.nodes if node.get_name() not in removed_node_names
        ] + fe_nodes
|
| 1393 |
+
|
| 1394 |
+
def compute_dependencies(self):
|
| 1395 |
+
"""
|
| 1396 |
+
Create dependency edges between nodes, handling aliasing and
|
| 1397 |
+
mutation properly.
|
| 1398 |
+
"""
|
| 1399 |
+
|
| 1400 |
+
T = TypeVar("T")
|
| 1401 |
+
|
| 1402 |
+
class DedupList(Generic[T]):
|
| 1403 |
+
"""
|
| 1404 |
+
This data structure behaves like a list except it makes sure the
|
| 1405 |
+
elements remain unique.
|
| 1406 |
+
Normally one could use a set/dict for this purpose however
|
| 1407 |
+
the list in question gets elements appended as it is being
|
| 1408 |
+
iterated over which means that we need to keep the list
|
| 1409 |
+
semantics.
|
| 1410 |
+
"""
|
| 1411 |
+
|
| 1412 |
+
def __init__(self, items=None, membership=None):
|
| 1413 |
+
self.items = items or list()
|
| 1414 |
+
self.membership = membership or set()
|
| 1415 |
+
|
| 1416 |
+
def append(self, node_user: T) -> None:
|
| 1417 |
+
if node_user in self.membership:
|
| 1418 |
+
return
|
| 1419 |
+
self.items.append(node_user)
|
| 1420 |
+
self.membership.add(node_user)
|
| 1421 |
+
|
| 1422 |
+
def __add__(self, other: "DedupList[T]") -> "DedupList[T]":
|
| 1423 |
+
new_membership = set.union(self.membership, other.membership)
|
| 1424 |
+
new_items = self.items + [
|
| 1425 |
+
x for x in other.items if x not in self.membership
|
| 1426 |
+
]
|
| 1427 |
+
return DedupList(new_items, new_membership)
|
| 1428 |
+
|
| 1429 |
+
name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict(
|
| 1430 |
+
DedupList
|
| 1431 |
+
)
|
| 1432 |
+
|
| 1433 |
+
# handle aliasing by using python aliasing in name_to_users
|
| 1434 |
+
# if foo aliases bar then we will make name_to_users["foo"] point
|
| 1435 |
+
# to the same python list as name_to_users["bar"]
|
| 1436 |
+
for node1 in self.nodes:
|
| 1437 |
+
node1_name = node1.get_name()
|
| 1438 |
+
for node2_name in node1.get_aliases():
|
| 1439 |
+
if node1_name in name_to_users and node2_name in name_to_users:
|
| 1440 |
+
# merge the two
|
| 1441 |
+
list1 = name_to_users[node1_name]
|
| 1442 |
+
list2 = name_to_users[node2_name]
|
| 1443 |
+
combined = list1 + list2
|
| 1444 |
+
for key in name_to_users.keys():
|
| 1445 |
+
if name_to_users[key] is list1 or name_to_users[key] is list2:
|
| 1446 |
+
name_to_users[key] = combined
|
| 1447 |
+
elif node1_name in name_to_users:
|
| 1448 |
+
name_to_users[node2_name] = name_to_users[node1_name]
|
| 1449 |
+
else:
|
| 1450 |
+
name_to_users[node1_name] = name_to_users[node2_name]
|
| 1451 |
+
|
| 1452 |
+
def rename(n):
|
| 1453 |
+
if n in self.mutation_renames:
|
| 1454 |
+
return rename(self.mutation_renames[n])
|
| 1455 |
+
return n
|
| 1456 |
+
|
| 1457 |
+
def dep_closure(node_name):
|
| 1458 |
+
reachable_names = {node_name}
|
| 1459 |
+
node = self.name_to_node[node_name]
|
| 1460 |
+
write_dep = next(iter(node.read_writes.writes))
|
| 1461 |
+
for read_dep in node.read_writes.reads:
|
| 1462 |
+
if (
|
| 1463 |
+
read_dep.name in self.name_to_node
|
| 1464 |
+
and isinstance(read_dep, dependencies.MemoryDep)
|
| 1465 |
+
and isinstance(write_dep, dependencies.MemoryDep)
|
| 1466 |
+
and read_dep.index == write_dep.index
|
| 1467 |
+
and read_dep.size == write_dep.size
|
| 1468 |
+
):
|
| 1469 |
+
reachable_names.update(dep_closure(read_dep.name))
|
| 1470 |
+
return reachable_names
|
| 1471 |
+
|
| 1472 |
+
def add_user(used_by_name, user_node, can_inplace=False, is_weak=False):
|
| 1473 |
+
name_to_users[rename(used_by_name)].append(
|
| 1474 |
+
NodeUser(user_node, can_inplace, is_weak)
|
| 1475 |
+
)
|
| 1476 |
+
|
| 1477 |
+
unbacked_symbol_to_origin_node = {}
|
| 1478 |
+
|
| 1479 |
+
for node in self.nodes:
|
| 1480 |
+
log.debug("scheduling %s", node.node)
|
| 1481 |
+
|
| 1482 |
+
# unbacked symbols don't follow ordinary buffer dependencies, so
|
| 1483 |
+
# we track their def/uses separately
|
| 1484 |
+
unbacked_symbol_defs = sorted(
|
| 1485 |
+
node.node.get_unbacked_symbol_defs(), key=lambda x: x.name
|
| 1486 |
+
)
|
| 1487 |
+
for s in unbacked_symbol_defs:
|
| 1488 |
+
assert isinstance(s, sympy.Symbol)
|
| 1489 |
+
# Pick the first definer as canonical. There may be multiple
|
| 1490 |
+
# because if a MultiOutputLayout buffer propagates an unbacked
|
| 1491 |
+
# symint to multiple outputs, they will all claim to def it.
|
| 1492 |
+
if s not in unbacked_symbol_to_origin_node:
|
| 1493 |
+
unbacked_symbol_to_origin_node[s] = node
|
| 1494 |
+
|
| 1495 |
+
unbacked_symbol_uses = sorted(
|
| 1496 |
+
node.node.get_unbacked_symbol_uses(), key=lambda x: x.name
|
| 1497 |
+
)
|
| 1498 |
+
# if a kernel takes unbacked symints, register dependencies
|
| 1499 |
+
for s in unbacked_symbol_uses:
|
| 1500 |
+
assert (
|
| 1501 |
+
s in unbacked_symbol_to_origin_node
|
| 1502 |
+
), f"{s} not in {unbacked_symbol_to_origin_node}"
|
| 1503 |
+
node.add_fake_dep(StarDep(unbacked_symbol_to_origin_node[s].get_name()))
|
| 1504 |
+
|
| 1505 |
+
# a node will mutate either 0 or 1 buffers
|
| 1506 |
+
assert len(node.get_mutations()) <= 1
|
| 1507 |
+
for alt_name in node.get_mutations():
|
| 1508 |
+
alt_name = rename(alt_name)
|
| 1509 |
+
# this node must run after the prior writer
|
| 1510 |
+
add_user(alt_name, node)
|
| 1511 |
+
node.add_mutation_dep(StarDep(alt_name))
|
| 1512 |
+
for other_node in name_to_users[alt_name].items:
|
| 1513 |
+
# this node must run after all prior readers
|
| 1514 |
+
other_name = rename(other_node.get_name())
|
| 1515 |
+
known_dep_node_names = dep_closure(node.get_name())
|
| 1516 |
+
if other_name not in known_dep_node_names:
|
| 1517 |
+
# If this node already directly or indirectly depends on other_node,
|
| 1518 |
+
# we don't need to insert an extra dep.
|
| 1519 |
+
node.add_mutation_dep(WeakDep(other_name))
|
| 1520 |
+
add_user(other_name, node, is_weak=True)
|
| 1521 |
+
|
| 1522 |
+
# add normal non-mutation dependencies
|
| 1523 |
+
for read in node.read_writes.reads:
|
| 1524 |
+
is_weak = isinstance(read, WeakDep)
|
| 1525 |
+
add_user(read.name, node, node.can_inplace(read), is_weak)
|
| 1526 |
+
|
| 1527 |
+
node.update_mutated_names(self.mutation_renames)
|
| 1528 |
+
|
| 1529 |
+
# update our renaming scheme for the next iteration
|
| 1530 |
+
for alt_name in node.get_mutations():
|
| 1531 |
+
self.mutation_renames[rename(alt_name)] = node.get_name()
|
| 1532 |
+
self.mutation_renames[alt_name] = node.get_name()
|
| 1533 |
+
self.mutation_real_name[node.get_name()] = self.mutation_real_name.get(
|
| 1534 |
+
alt_name, alt_name
|
| 1535 |
+
)
|
| 1536 |
+
|
| 1537 |
+
# make sure outputs aren't dead-code-eliminated
|
| 1538 |
+
for node_name in V.graph.get_output_names():
|
| 1539 |
+
log.debug("scheduling output %s", node_name)
|
| 1540 |
+
add_user(node_name, OutputNode(StarDep(node_name)))
|
| 1541 |
+
|
| 1542 |
+
# make sure unbacked symints aren't dead-code-eliminated
|
| 1543 |
+
for node in V.graph.graph_outputs:
|
| 1544 |
+
for s in node.get_unbacked_symbol_uses():
|
| 1545 |
+
assert (
|
| 1546 |
+
s in unbacked_symbol_to_origin_node
|
| 1547 |
+
), f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
|
| 1548 |
+
node_name = unbacked_symbol_to_origin_node[s].node.name
|
| 1549 |
+
log.debug("scheduling output %s for unbacked symint %s", node_name, s)
|
| 1550 |
+
add_user(node_name, OutputNode(StarDep(node_name)))
|
| 1551 |
+
|
| 1552 |
+
# make sure input mutation isn't dead-code-eliminated
|
| 1553 |
+
for name in self.mutation_renames:
|
| 1554 |
+
if name in V.graph.graph_inputs:
|
| 1555 |
+
add_user(name, OutputNode(StarDep(name)))
|
| 1556 |
+
V.graph.mutated_inputs.add(name)
|
| 1557 |
+
|
| 1558 |
+
inp_names = {
|
| 1559 |
+
name: index for index, name in enumerate(V.graph.graph_inputs.keys())
|
| 1560 |
+
}
|
| 1561 |
+
V.graph.mutated_input_idxs = [
|
| 1562 |
+
inp_names[name] for name in V.graph.mutated_inputs
|
| 1563 |
+
]
|
| 1564 |
+
|
| 1565 |
+
# copy users information onto the nodes
|
| 1566 |
+
for node in self.nodes:
|
| 1567 |
+
node.set_users(name_to_users[node.get_name()].items)
|
| 1568 |
+
|
| 1569 |
+
# populate inverse_users
|
| 1570 |
+
for node in self.nodes:
|
| 1571 |
+
for user in node.users:
|
| 1572 |
+
user.node.inverse_users.append(node)
|
| 1573 |
+
|
| 1574 |
+
    def compute_node_users(self):
        """
        Populate node.inverse_users (the scheduler nodes this node reads
        from) and node.node_users (the scheduler nodes that read this
        node's outputs), resolving each buffer name to its enclosing
        FusedSchedulerNode when one exists.
        """
        # set up buffer name to (fused)snode mapping
        buf_to_snode = {}
        for node in self.nodes:
            if isinstance(node, FusedSchedulerNode):
                # inner snodes resolve to the enclosing fused node
                for x in node.snodes:
                    buf_to_snode[x.get_name()] = node
            buf_to_snode[node.get_name()] = node

        # reset both user lists before recomputing so stale entries
        # from a previous pass don't accumulate
        for node in self.nodes:
            node.node_users = []
            node.inverse_users = []

        # compute inverse_users: one entry per unmet dependency
        for node in self.nodes:
            inverse_users = []
            for dep in node.unmet_dependencies:
                assert dep.name in buf_to_snode
                dep_node = buf_to_snode[dep.name]
                inverse_users.append(dep_node)
            node.inverse_users = inverse_users

        # compute node_users (the inverse mapping of inverse_users)
        # TODO: ideally, we should deduplicate .users and .node_users,
        # but currently .users contains extra information that's difficult to
        # extract into a standalone container.
        node_to_users: Dict[BaseSchedulerNode, List[BaseSchedulerNode]] = {}
        for node in self.nodes:
            for inverse_user in node.inverse_users:
                node_to_users.setdefault(inverse_user, []).append(node)
        for node, users in node_to_users.items():
            node.node_users = users
|
| 1606 |
+
|
| 1607 |
+
    def dead_node_elimination(self):
        """
        Remove any nodes without users.

        A node is dead when it has no side effects and every user is
        either a weak user or writes into a buffer that has already been
        removed. Elimination repeats until no more nodes can be removed,
        since removing a node can make its producers dead in turn.
        """
        again = True  # repeat until a fixed point
        while again:
            updated_nodes = []
            for node in self.nodes:

                def can_eliminate_user(user: NodeUser):
                    # weak users and users whose buffer was already removed
                    # do not keep a node alive
                    return user.is_weak or user.get_name() in V.graph.removed_buffers

                can_eliminate = not node.has_side_effects() and all(
                    can_eliminate_user(u) for u in node.users
                )

                if not can_eliminate:
                    updated_nodes.append(node)
                else:
                    # dead code
                    log.debug("removed dead node: %s", node.get_name())
                    V.graph.removed_buffers.add(node.get_name())

            # loop again iff something was removed this pass
            again = len(self.nodes) > len(updated_nodes)
            self.nodes = updated_nodes

        # Prune any WeakDeps no longer needed
        for node in self.nodes:
            node.prune_weak_deps()
|
| 1636 |
+
|
| 1637 |
+
def topological_sort_schedule(self):
|
| 1638 |
+
"""
|
| 1639 |
+
Ensure self.nodes is in topologically sorted order
|
| 1640 |
+
"""
|
| 1641 |
+
seen: Set[ir.Buffer] = set()
|
| 1642 |
+
name_to_node: Dict[str, ir.Buffer] = dict()
|
| 1643 |
+
result: List[ir.Buffer] = []
|
| 1644 |
+
|
| 1645 |
+
def visit(n):
|
| 1646 |
+
if n not in seen:
|
| 1647 |
+
seen.add(n)
|
| 1648 |
+
for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
|
| 1649 |
+
visit(name_to_node[dep.name])
|
| 1650 |
+
result.append(n)
|
| 1651 |
+
|
| 1652 |
+
for node in self.nodes:
|
| 1653 |
+
for name in node.get_names():
|
| 1654 |
+
name_to_node[name] = node
|
| 1655 |
+
for node in self.nodes:
|
| 1656 |
+
visit(node)
|
| 1657 |
+
self.nodes = result
|
| 1658 |
+
|
| 1659 |
+
def compute_ancestors(self):
|
| 1660 |
+
"""
|
| 1661 |
+
Populate each node.ancestors
|
| 1662 |
+
"""
|
| 1663 |
+
# note self.nodes is topologically sorted
|
| 1664 |
+
name_to_ancestors: Dict[str, Set[str]] = {}
|
| 1665 |
+
for node in self.nodes:
|
| 1666 |
+
ancestors = set()
|
| 1667 |
+
for dep in node.unmet_dependencies:
|
| 1668 |
+
ancestors.add(dep.name)
|
| 1669 |
+
ancestors |= name_to_ancestors[dep.name]
|
| 1670 |
+
name_to_ancestors[node.get_name()] = ancestors
|
| 1671 |
+
node.ancestors = ancestors
|
| 1672 |
+
|
| 1673 |
+
for order, node in enumerate(self.nodes):
|
| 1674 |
+
node.min_order = order
|
| 1675 |
+
node.max_order = order
|
| 1676 |
+
|
| 1677 |
+
def fuse_nodes(self):
|
| 1678 |
+
"""
|
| 1679 |
+
Mutates self.nodes to combine nodes into FusedSchedulerNodes.
|
| 1680 |
+
"""
|
| 1681 |
+
for i in range(10):
|
| 1682 |
+
old_len = len(self.nodes)
|
| 1683 |
+
fusion_log.debug(
|
| 1684 |
+
"===== attempting fusion (%d/10): %d nodes =====", i + 1, old_len
|
| 1685 |
+
)
|
| 1686 |
+
self.fuse_nodes_once()
|
| 1687 |
+
new_len = len(self.nodes)
|
| 1688 |
+
fusion_log.debug(
|
| 1689 |
+
"completed fusion round (%d/10): fused %d nodes into %d nodes\n",
|
| 1690 |
+
i + 1,
|
| 1691 |
+
old_len,
|
| 1692 |
+
new_len,
|
| 1693 |
+
)
|
| 1694 |
+
if new_len == old_len or new_len == 1:
|
| 1695 |
+
fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1)
|
| 1696 |
+
break
|
| 1697 |
+
|
| 1698 |
+
def benchmark_fused_nodes(self, nodes):
|
| 1699 |
+
"""
|
| 1700 |
+
Benchmark fused list of nodes and return the execution time
|
| 1701 |
+
in milliseconds on randomly generated inputs.
|
| 1702 |
+
"""
|
| 1703 |
+
assert len(nodes) > 0
|
| 1704 |
+
device = nodes[0].get_device()
|
| 1705 |
+
V.graph.scheduler = self
|
| 1706 |
+
self.current_device = device
|
| 1707 |
+
backend = self.get_backend(device)
|
| 1708 |
+
return backend.benchmark_fused_nodes(nodes)
|
| 1709 |
+
|
| 1710 |
+
    def speedup_by_fusion(self, node1, node2):
        """
        If config.benchmark_fusion is False, always return True.
        Otherwise, return True only when benchmarking shows the fused
        kernel is faster than running the two kernels separately.
        """
        if not config.benchmark_fusion:
            return True

        if (
            node1.is_template()
            and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer)
            or node1.is_foreach()
            or node2.is_foreach()
        ):
            # TODO support benchmarking epilogue fusion
            return True

        node_list_1 = node1.get_nodes()
        device = node_list_1[0].get_device()

        # don't support benchmark fusion for CPU right now.
        if device.type == "cpu":
            return True

        node_list_2 = node2.get_nodes()
        node_list_fused = node_list_1 + node_list_2

        # We can not accurately benchmark kernel using atomic_add
        # due to how we generate random integer inputs.
        # Skip benchmarking them by allowing fusion.
        if any(
            hasattr(n.node, "data")
            and hasattr(n.node.data, "scatter_mode")
            and n.node.data.scatter_mode == "atomic_add"
            for n in node_list_fused
        ):
            return True

        from triton.compiler.errors import CompilationError

        why = WhyNoFuse(node1, node2)

        try:
            # an infinite latency is the backend's signal for register
            # spilling; refuse the fusion in that case
            ms1, path1 = self.benchmark_fused_nodes(node_list_1)
            if math.isinf(ms1):
                why("register spilling of the first kernel")
                return False
            ms2, path2 = self.benchmark_fused_nodes(node_list_2)
            if math.isinf(ms2):
                why("register spilling of the second kernel")
                return False
            ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused)
            if math.isinf(ms_fused):
                why("register spilling of the fused kernel")
                return False
        except CompilationError as e:
            # workaround triton issue: https://github.com/openai/triton/issues/2151
            if "Loop-carried variable" in str(e):
                return True  # allow fusion
            else:
                raise

        if fusion_log.isEnabledFor(logging.DEBUG):
            if ms_fused < ms1 + ms2:
                fusion_log.debug(
                    "can fuse (benchmark): fusing %s with %s cause %sx speedup",
                    node1.get_names(),
                    node2.get_names(),
                    green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
                )
            else:
                fusion_log.debug(
                    "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown",
                    node1.get_names(),
                    node2.get_names(),
                    red_text(f"{ms_fused / (ms1 + ms2)::.3f}" if False else f"{ms_fused / (ms1 + ms2):.3f}"),
                )

        if (
            is_metric_table_enabled("slow_fusion")
            and ms_fused >= ms1 + ms2
            and (path1, path2) not in self.logged_slow_fusion
        ):
            # record each slow kernel pair at most once
            self.logged_slow_fusion.add((path1, path2))
            get_metric_table("slow_fusion").add_row(
                lambda: {
                    "kernel1_path": path1,
                    "kernel1_latency": ms1,
                    "kernel2_path": path2,
                    "kernel2_latency": ms2,
                    "fused_kernel_path": path_fused,
                    "fused_kernel_latency": ms_fused,
                    "slow_down_ratio": ms_fused / (ms1 + ms2),
                }
            )
        return ms_fused < ms1 + ms2
|
| 1806 |
+
|
| 1807 |
+
    def fuse_nodes_once(self):
        """
        Mutates self.nodes to combine nodes into FusedSchedulerNodes.

        This relies on two key functions to control the logic:
            - self.can_fuse(): checks if a fusion is legal
            - self.score_fusion(): assigns priority to a given fusion
        """
        fused_nodes = set(self.nodes)
        for node1, node2 in self.get_possible_fusions():
            # re-resolve through name_to_fused_node: an earlier fusion in
            # this loop may already have merged node1/node2 into a larger node
            node1 = self.name_to_fused_node[node1.get_first_name()]
            node2 = self.name_to_fused_node[node2.get_first_name()]
            if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle(
                node1, node2
            ):
                if not self.speedup_by_fusion(node1, node2):
                    continue
                fusion_log.debug(
                    "fusing %s with %s", node1.get_name(), node2.get_name()
                )

                # above can_fuse asserts that node2 has the same device
                device = node1.get_device()
                node3 = self.get_backend(device).fuse(node1, node2)
                fused_nodes.remove(node1)
                fused_nodes.remove(node2)
                fused_nodes.add(node3)
                # every inner node now maps to the new fused node
                self.name_to_fused_node.update(
                    {n.get_name(): node3 for n in node3.get_nodes()}
                )
        self.nodes = sorted(fused_nodes, key=lambda x: x.min_order)
        self.topological_sort_schedule()
        self.prune_redundant_deps()
|
| 1840 |
+
|
| 1841 |
+
def prune_redundant_deps(self):
|
| 1842 |
+
for node in self.nodes:
|
| 1843 |
+
node.prune_redundant_deps(self.name_to_fused_node)
|
| 1844 |
+
|
| 1845 |
+
    def get_possible_fusions(self):
        """
        Helper to find all legal fusion opportunities, sorted by self.score_fusion()
        """
        possible_fusions = []
        seen = set()

        def check_all_pairs(nodes):
            # try every pair within a grouping, preserving schedule order
            for node1_index, node1 in enumerate(nodes):
                for node2 in nodes[node1_index + 1 :]:
                    key = (node1, node2)
                    if key in seen:
                        continue
                    seen.add(key)

                    if self.can_fuse(node1, node2):
                        possible_fusions.append(key)
                    elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
                        node2, node1
                    ):
                        # foreach fusions and epilogue fusions are order dependent
                        possible_fusions.append((node2, node1))

        # group candidate pairs by buffers they both touch
        buffer_names_grouping = collections.defaultdict(list)
        for node in self.nodes:
            for buf in node.used_buffer_names():
                buffer_names_grouping[buf].append(node)
        for node_grouping in buffer_names_grouping.values():
            check_all_pairs(node_grouping)

        if config.aggressive_fusion:
            # additionally consider nodes sharing the same `group` attribute
            group_grouping = collections.defaultdict(list)
            for node in self.nodes:
                group = getattr(node, "group", None)
                if group:
                    group_grouping[group].append(node)
            for node_grouping in group_grouping.values():
                check_all_pairs(node_grouping)

        # highest-scoring fusions first
        possible_fusions.sort(key=self.score_fusion_key, reverse=True)
        fusion_log.debug("found %d possible fusions", len(possible_fusions))
        return possible_fusions
|
| 1887 |
+
|
| 1888 |
+
    def will_fusion_create_cycle(self, node1, node2):
        """
        Finds whether there's a path from node1 to node2 (or vice-versa)
        caused indirectly by other fusions. Returns True if fusing the
        two nodes would create a cycle in the dependency graph.
        """

        def found_path(node):
            # only fused nodes can introduce new ancestors.
            if isinstance(node, FusedSchedulerNode) and node not in visited:
                visited.add(node)
                if node.get_names().issubset(combined_ancestors):
                    # All fusion outputs are in ancestors of node1 and node2, thus
                    # cannot introduce new path:
                    #
                    # 1. if output is neither descendent of node1 or node2, the
                    # output cannot introduce a path
                    # 2. due to [can_fuse]: if WLOG output is descendent of node1, it cannot be
                    # on path(node1->node2), hence it cannot be ancestor of node2
                    # 3. due to [acyclic]: if WLOG output is descendent of node1, it cannot be
                    # ancestor of node1
                    return False
                else:
                    # continue DFS of new ancestors introduced by the fusion
                    return bool(combined_names & node.ancestors) or any(
                        found_path(self.name_to_fused_node[n])
                        for n in node.ancestors - combined_ancestors
                    )
            return False

        visited = set()
        combined_names = node1.get_names() | node2.get_names()
        combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names
        cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors)
        if cycle:
            # record the rejection reason for fusion debugging
            WhyNoFuse(node1, node2)("will create cycle")
        return cycle
|
| 1924 |
+
|
| 1925 |
+
def can_fusion_increase_peak_memory(
|
| 1926 |
+
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
|
| 1927 |
+
):
|
| 1928 |
+
"""
|
| 1929 |
+
This function prevents fusion for nodes that can increase memory
|
| 1930 |
+
footprint. This problem is more common in horizontal fusion, where nodes
|
| 1931 |
+
that are far apart in the original order get fused, lengthening the live
|
| 1932 |
+
intervals of tensors. This is very evident in models with activation
|
| 1933 |
+
checkpointing, where the recomputed nodes from different checkpointed
|
| 1934 |
+
regions get fused and significantly increase the memory footprint.
|
| 1935 |
+
|
| 1936 |
+
The current attempt is a quick, possibly hacky, heuristic to prevent the
|
| 1937 |
+
fusion of nodes that are far away in the original order.
|
| 1938 |
+
|
| 1939 |
+
A better but difficult to implement heurisitic would be to use live
|
| 1940 |
+
intervals of the buffers, find region of peak pressure in the original
|
| 1941 |
+
program and prevent fusion that crosses that peak region. We might need
|
| 1942 |
+
special care or good approximation in this implementation, as fusion of
|
| 1943 |
+
node changes live intervals, and re-computing live intervals and peak
|
| 1944 |
+
memory after each fusion can introduce large compilation overhead.
|
| 1945 |
+
"""
|
| 1946 |
+
proximity_score = max(
|
| 1947 |
+
abs(node1.min_order - node2.max_order),
|
| 1948 |
+
abs(node2.min_order - node1.max_order),
|
| 1949 |
+
)
|
| 1950 |
+
return proximity_score > 64
|
| 1951 |
+
|
| 1952 |
+
    def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        """
        Determine if it is possible to combine node1 and node2 into a
        single fused node.

        Each rejection is recorded through WhyNoFuse for fusion debugging.
        """

        if node1 is node2:
            return False

        why = WhyNoFuse(node1, node2)

        # extern/nop kernels cannot participate in fusion (unless templates)
        if (
            isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
            and not node1.is_template()
        ):
            why("node1 is extern or nop")
            return False
        if (
            isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
            and not node2.is_template()
        ):
            why("node2 is extern or nop")
            return False

        # node2 must not already be an ancestor of node1
        if node2.get_names() & node1.ancestors:
            why("node1 must go before node2")
            return False

        if (
            isinstance(node1, (FusedSchedulerNode, SchedulerNode))
            and isinstance(node2, SchedulerNode)
            and isinstance(node2._body, ir.LoopBody)
        ):
            # Fix issue: https://github.com/pytorch/pytorch/issues/108963
            # Check:
            # If node2 reads a buf which is a mutation buf of node1(SchedulerNode) or among nodes in node1(FusedSchedulerNode),
            # we will get the corresponding mutation buf and check if this mutation buf is stored by atomic_add mode.
            # If True, we will disable the fusion of node1 and node2.
            if any(
                (
                    node2_used_buf in self.mutation_renames
                    and node1.has_atomic_add(self.mutation_renames[node2_used_buf])
                )
                for node2_used_buf in node2._body.reads_name2expr.keys()
            ):
                return False

        if node2.is_template():
            why("templates can only fuse epilogues")
            return False
        if node1.is_template() and (
            node2.has_aliasing_or_mutation()
            or node2.is_reduction()
            or not config.epilogue_fusion
        ):
            why("template epilogue not satisfied")
            return False

        # fused nodes must live on the same device
        device = node1.get_device()
        device2 = node2.get_device()
        if device != device2:
            why("device mismatch (%s vs %s)", device, device2)
            return False
        del device2

        no_shared_data = self.score_fusion_memory(node1, node2) == 0
        if no_shared_data and (
            not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction()
        ):
            why("no shared data")
            return False  # heuristic not needed for correctness

        if (
            not node1.is_foreach()
            and not node2.is_foreach()
            and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size
        ):
            why("exceeds max fusion")
            return False  # heuristic not needed for correctness

        if node1.get_names() & node2.ancestors:
            # node2 depends on node1 outputs
            if not self.can_fuse_vertical(node1, node2):
                return False
            return self.get_backend(device).can_fuse_vertical(node1, node2)
        else:  # nodes don't depend on each other, but may have common reads
            if self.can_fusion_increase_peak_memory(node1, node2):
                why("will increase peak memory")
                return False
            return self.get_backend(device).can_fuse_horizontal(node1, node2)
|
| 2042 |
+
|
| 2043 |
+
    def can_fuse_vertical(self, node1, node2):
        """
        Check if it is legal to fuse a consumer (node2) into a producer (node1).

        We can fuse them if all the reads of node2 either match
        corresponding writes in node1, or are written by nodes that can
        be scheduled before the fusion of node1 and node2.

        We also disable fusion of a write subsequent to a read if the reads
        and writes do not align.
        """
        node1_names = node1.get_names()
        computed_deps = set()
        why = WhyNoFuse(node1, node2)

        # StarDep doesn't match MemoryDep, different indices don't match
        # However, broadcasting sometimes strips dimensions, and if that's the case
        # we still can match unmet dep
        # if there's indirect indexing, don't match it
        def fusable_read_and_write(read: Dep, write: Dep):
            return (
                self.mutation_renames.get(read.name, read.name) == write.name
                and (isinstance(read, MemoryDep) and isinstance(write, MemoryDep))
                and not free_symbol_has(read.index, "tmp")
                and not free_symbol_has(write.index, "tmp")
                and read.index == write.index
                and len(read.size) >= len(write.size)
                and read.size[: len(write.size)] == write.size
            )

        # mark every node2 read that is satisfied by a node1 write
        for rd in node2.unmet_dependencies:
            for cd in node1.read_writes.writes:
                if fusable_read_and_write(rd, cd):
                    computed_deps.add(rd)

        remaining_deps = {dep.name for dep in node2.unmet_dependencies - computed_deps}
        if remaining_deps & node1_names:
            # MemoryDeps didn't match and read different locations of the same buffer.
            # Examples here include:
            #   - MemoryDep("foo", x) != MemoryDep("foo", x + 1)
            #   - MemoryDep("foo", x) != StarDep("foo")
            why("memory deps did not match")
            return False
        for name in remaining_deps:
            # a producer of a remaining dep depends on node1, so it would
            # have to be scheduled between node1 and node2 — fusion illegal
            if node1_names & self.name_to_fused_node[name].ancestors:
                why("intermediate nodes between node1 & node2")
                return False

        # similar to can_inplace, if we are going to fuse a write subsequent to a read
        # require that the indexing and size is the same
        for write in node2.read_writes.writes:
            for read in node1.read_writes.reads:
                if write.name != self.mutation_renames.get(read.name, read.name):
                    continue

                # bail on StarDep
                if not fusable_read_and_write(read=read, write=write):
                    why("fusing a write into a read with different indexing formula")
                    return False

        return True
|
| 2104 |
+
|
| 2105 |
+
def score_fusion(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
|
| 2106 |
+
"""
|
| 2107 |
+
Assign a score (higher comes first) to the fusion of node1
|
| 2108 |
+
and node2. When different fusions conflict with each other,
|
| 2109 |
+
this is the way we decide what order to run them in.
|
| 2110 |
+
|
| 2111 |
+
Our current score is based on:
|
| 2112 |
+
- Estimate of the saved memory operations
|
| 2113 |
+
- Fusions closer together in original order
|
| 2114 |
+
"""
|
| 2115 |
+
memory_score = self.score_fusion_memory(node1, node2)
|
| 2116 |
+
proximity_score = -max(
|
| 2117 |
+
abs(node1.min_order - node2.max_order),
|
| 2118 |
+
abs(node2.min_order - node1.max_order),
|
| 2119 |
+
)
|
| 2120 |
+
return (
|
| 2121 |
+
node1.is_template() == config.epilogue_fusion_first and memory_score > 0,
|
| 2122 |
+
node1.is_reduction() == node2.is_reduction() and memory_score > 0,
|
| 2123 |
+
memory_score,
|
| 2124 |
+
proximity_score,
|
| 2125 |
+
)
|
| 2126 |
+
|
| 2127 |
+
def score_fusion_memory(self, node1, node2):
|
| 2128 |
+
"""
|
| 2129 |
+
The first term in our fusion score that estimates number of saved memory operations.
|
| 2130 |
+
"""
|
| 2131 |
+
common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
|
| 2132 |
+
node2.read_writes.reads | node2.read_writes.writes
|
| 2133 |
+
)
|
| 2134 |
+
common_memory_deps = {
|
| 2135 |
+
dep for dep in common_memory_deps if not dep.has_unbacked_symbols()
|
| 2136 |
+
}
|
| 2137 |
+
return sum(dep.numbytes_hint() for dep in common_memory_deps)
|
| 2138 |
+
|
| 2139 |
+
def score_fusion_key(self, nodes):
|
| 2140 |
+
"""
|
| 2141 |
+
Shim for list.sort(key=...)
|
| 2142 |
+
"""
|
| 2143 |
+
node1, node2 = nodes
|
| 2144 |
+
return self.score_fusion(node1, node2)
|
| 2145 |
+
|
| 2146 |
+
def compute_last_usage(self):
|
| 2147 |
+
"""
|
| 2148 |
+
Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
|
| 2149 |
+
"""
|
| 2150 |
+
|
| 2151 |
+
future_used_buffers = set()
|
| 2152 |
+
for node_name in V.graph.get_output_names():
|
| 2153 |
+
future_used_buffers.add(node_name)
|
| 2154 |
+
|
| 2155 |
+
for node in reversed(self.nodes):
|
| 2156 |
+
node.set_last_usage(future_used_buffers, self.mutation_real_name)
|
| 2157 |
+
future_used_buffers.update(node.last_usage)
|
| 2158 |
+
|
| 2159 |
+
    def free_buffers(self):
        """Free any buffers that are no longer needed"""
        # skip buffers already removed or already freed by the wrapper
        for name in sorted(
            self.buffer_names_to_free
            - V.graph.removed_buffers
            - V.graph.wrapper_code.freed
        ):
            if name in self.name_to_node:
                node = self.name_to_node[name]
                if node.can_free():
                    V.graph.wrapper_code.codegen_free(node.node)
            elif name in V.graph.graph_inputs:
                # freeing a graph input: only valid for plain input storage
                storage = V.graph.graph_inputs[name].data
                assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer()
                V.graph.wrapper_code.codegen_free(storage.data)

        self.buffer_names_to_free.clear()
|
| 2176 |
+
|
| 2177 |
+
    def remove_kernel_local_buffers(self):
        """
        Any buffers that are both created and have a last use in the
        same kernel can be removed.
        """

        # V.kernel.store_buffer_names should represent the set of nodes
        # get fused
        fused_node_names = V.kernel.store_buffer_names
        names_to_remove = []
        for out_buf in V.kernel.store_buffer_names:
            users = self.name_to_node[out_buf].users
            assert users is not None
            # weak users don't count — only strong readers keep a buffer alive
            users = {user.get_name() for user in users if not user.is_weak}
            if users.issubset(fused_node_names):
                names_to_remove.append(out_buf)

        def remove_filter(n):
            # never remove buffers the kernel must keep, kernel inputs,
            # or buffers involved in mutation renaming
            return (
                n not in V.kernel.must_keep_buffers
                and n not in V.kernel.args.input_buffers
                and n not in self.mutation_renames
                and n not in self.mutation_real_name
            )

        names_to_remove = list(filter(remove_filter, names_to_remove))

        for name in names_to_remove:
            if name in V.kernel.args.inplace_buffers:
                buf = V.kernel.args.inplace_buffers[name]
                if isinstance(buf, str) and buf.startswith("REMOVED"):
                    continue
                # only drop an inplace buffer if all of its aliases are removed too
                remove = all(n in names_to_remove for n in buf.other_names)
                if remove:
                    self.remove_inplace_buffer(name)
                V.kernel.inplaced_to_remove.add(name)
            else:
                self.remove_buffer(name)
|
| 2215 |
+
|
| 2216 |
+
def remove_buffer(self, name):
|
| 2217 |
+
# Assign a special value instead of deleting the entry
|
| 2218 |
+
# because we still rely on output_buffers's length to
|
| 2219 |
+
# generate unique arg name.
|
| 2220 |
+
log.debug("remove_buffer(%r)", name)
|
| 2221 |
+
V.kernel.args.output_buffers[name] = "REMOVED"
|
| 2222 |
+
V.kernel.removed_buffers.add(name)
|
| 2223 |
+
|
| 2224 |
+
def remove_inplace_buffer(self, name):
|
| 2225 |
+
log.debug("removing_inplace_buffer(%r)", name)
|
| 2226 |
+
inner_name = V.kernel.args.inplace_buffers[name].inner_name
|
| 2227 |
+
V.kernel.args.inplace_buffers[name] = inner_name.replace(
|
| 2228 |
+
"in_out_ptr", "REMOVED"
|
| 2229 |
+
)
|
| 2230 |
+
V.kernel.removed_buffers.add(name)
|
| 2231 |
+
|
| 2232 |
+
def flush(self):
|
| 2233 |
+
for backend in self.backends.values():
|
| 2234 |
+
backend.flush()
|
| 2235 |
+
self.free_buffers()
|
| 2236 |
+
|
| 2237 |
+
def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode):
|
| 2238 |
+
assert isinstance(scheduler_node, ExternKernelSchedulerNode)
|
| 2239 |
+
# 'decide_inplace_update' stores the inplace update decisions in
|
| 2240 |
+
# the current kernel from where 'allocate' retrieve those decisions.
|
| 2241 |
+
# We have to make sure there is a non-NULL kernel handler to store
|
| 2242 |
+
# those inplace update decisions.
|
| 2243 |
+
with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
|
| 2244 |
+
scheduler_node.decide_inplace_update()
|
| 2245 |
+
scheduler_node.allocate()
|
| 2246 |
+
node = scheduler_node.node
|
| 2247 |
+
assert isinstance(node, ir.ExternKernel), f"{type(node)=}"
|
| 2248 |
+
node.codegen(V.graph.wrapper_code)
|
| 2249 |
+
self.free_buffers()
|
| 2250 |
+
|
| 2251 |
+
def create_backend(self, device: torch.device):
|
| 2252 |
+
assert (
|
| 2253 |
+
device.type != "cuda" or device.index is not None
|
| 2254 |
+
), f"{device} should have been normalized in lowering"
|
| 2255 |
+
V.graph.add_device_info(device)
|
| 2256 |
+
|
| 2257 |
+
device_scheduling = get_scheduling_for_device(device.type)
|
| 2258 |
+
if device_scheduling is None:
|
| 2259 |
+
raise RuntimeError(f"Unsupported device type: {device.type}")
|
| 2260 |
+
|
| 2261 |
+
if device.type == "cuda" and not has_triton():
|
| 2262 |
+
device_props = torch.cuda.get_device_properties(device)
|
| 2263 |
+
if device_props.major < 7:
|
| 2264 |
+
raise RuntimeError(
|
| 2265 |
+
f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950
|
| 2266 |
+
)
|
| 2267 |
+
else:
|
| 2268 |
+
raise RuntimeError(
|
| 2269 |
+
"Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950
|
| 2270 |
+
)
|
| 2271 |
+
|
| 2272 |
+
return device_scheduling(self)
|
| 2273 |
+
|
| 2274 |
+
def get_backend(self, device: torch.device):
|
| 2275 |
+
if device not in self.backends:
|
| 2276 |
+
self.backends[device] = self.create_backend(device)
|
| 2277 |
+
return self.backends[device]
|
| 2278 |
+
|
| 2279 |
+
def enter_context(self, node):
|
| 2280 |
+
def get_order(n):
|
| 2281 |
+
if n not in self.origin_to_index:
|
| 2282 |
+
self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
|
| 2283 |
+
return self.origin_to_index[n]
|
| 2284 |
+
|
| 2285 |
+
# Use a dict to have ordering
|
| 2286 |
+
origins = {
|
| 2287 |
+
(get_order(e), e): None for n in node.get_nodes() for e in n.node.origins
|
| 2288 |
+
}
|
| 2289 |
+
origins = list(origins.keys())
|
| 2290 |
+
if origins:
|
| 2291 |
+
_, last = max(origins, key=operator.itemgetter(0))
|
| 2292 |
+
V.graph.wrapper_code.enter_context(last)
|
| 2293 |
+
|
| 2294 |
+
@dynamo_timed
|
| 2295 |
+
def codegen(self):
|
| 2296 |
+
for node in self.nodes:
|
| 2297 |
+
try:
|
| 2298 |
+
log.debug(
|
| 2299 |
+
"Generating code for node %s with estimated runtime %f",
|
| 2300 |
+
node.get_name(),
|
| 2301 |
+
node.get_estimated_runtime(),
|
| 2302 |
+
)
|
| 2303 |
+
except Exception as e:
|
| 2304 |
+
log.debug(
|
| 2305 |
+
"Generating code for node %s with estimated runtime 0.0",
|
| 2306 |
+
node.get_name(),
|
| 2307 |
+
)
|
| 2308 |
+
|
| 2309 |
+
self.enter_context(node)
|
| 2310 |
+
|
| 2311 |
+
if not isinstance(node, NopKernelSchedulerNode):
|
| 2312 |
+
device = node.get_device()
|
| 2313 |
+
if (
|
| 2314 |
+
device != self.current_device
|
| 2315 |
+
or node.is_extern()
|
| 2316 |
+
or node.is_template()
|
| 2317 |
+
):
|
| 2318 |
+
self.flush()
|
| 2319 |
+
if device != self.current_device:
|
| 2320 |
+
if device.type == "cuda":
|
| 2321 |
+
if self.current_device and self.current_device.type == "cuda":
|
| 2322 |
+
V.graph.wrapper_code.codegen_device_guard_exit()
|
| 2323 |
+
assert device.index is not None, "device should have an index"
|
| 2324 |
+
V.graph.wrapper_code.codegen_device_guard_enter(device.index)
|
| 2325 |
+
elif self.current_device and self.current_device.type == "cuda":
|
| 2326 |
+
V.graph.wrapper_code.codegen_device_guard_exit()
|
| 2327 |
+
self.current_device = device
|
| 2328 |
+
|
| 2329 |
+
self.buffer_names_to_free.update(node.last_usage)
|
| 2330 |
+
|
| 2331 |
+
if node.is_template():
|
| 2332 |
+
node, *epilogue = node.get_nodes()
|
| 2333 |
+
self.get_backend(device).codegen_template(node, epilogue) # type: ignore[possibly-undefined]
|
| 2334 |
+
elif node.is_extern():
|
| 2335 |
+
self.codegen_extern_call(node)
|
| 2336 |
+
elif node.is_foreach():
|
| 2337 |
+
self.get_backend(device).codegen_foreach(node) # type: ignore[possibly-undefined]
|
| 2338 |
+
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
|
| 2339 |
+
self.get_backend(device).codegen_nodes(node.get_nodes()) # type: ignore[possibly-undefined]
|
| 2340 |
+
else:
|
| 2341 |
+
assert isinstance(node, NopKernelSchedulerNode)
|
| 2342 |
+
node.allocate()
|
| 2343 |
+
|
| 2344 |
+
if config.debug_check_inf_and_nan:
|
| 2345 |
+
V.graph.wrapper_code.generate_inf_and_nan_checker(node)
|
| 2346 |
+
|
| 2347 |
+
if config.triton.debug_sync_kernel:
|
| 2348 |
+
self.get_backend(device).codegen_sync() # type: ignore[possibly-undefined]
|
| 2349 |
+
|
| 2350 |
+
self.available_buffer_names.update(node.get_names())
|
| 2351 |
+
|
| 2352 |
+
if not isinstance(node, NopKernelSchedulerNode):
|
| 2353 |
+
device = node.get_device()
|
| 2354 |
+
if self.get_backend(device).ready_to_flush():
|
| 2355 |
+
self.flush()
|
| 2356 |
+
|
| 2357 |
+
if self.current_device and self.current_device.type == "cuda":
|
| 2358 |
+
# exit the outermost CUDA device guard. this is
|
| 2359 |
+
# important for nested indentation codegen-ing.
|
| 2360 |
+
V.graph.wrapper_code.codegen_device_guard_exit()
|
| 2361 |
+
|
| 2362 |
+
self.flush()
|
| 2363 |
+
|
| 2364 |
+
def is_unaligned_buffer(self, buf_name):
|
| 2365 |
+
if buf_name in V.graph.graph_inputs or buf_name in V.graph.constants:
|
| 2366 |
+
# all graph inputs or constants are assumed to be aligned
|
| 2367 |
+
return False
|
| 2368 |
+
node = self.name_to_node[buf_name]
|
| 2369 |
+
layout = node.node.get_layout()
|
| 2370 |
+
if isinstance(layout, ir.AliasedLayout):
|
| 2371 |
+
return not layout.maybe_guard_aligned()
|
| 2372 |
+
else:
|
| 2373 |
+
return False
|
| 2374 |
+
|
| 2375 |
+
|
| 2376 |
+
class BaseScheduling:
|
| 2377 |
+
def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
|
| 2378 |
+
"""
|
| 2379 |
+
Check whether node1 and node2 can be vertically fused or not.
|
| 2380 |
+
"""
|
| 2381 |
+
raise NotImplementedError()
|
| 2382 |
+
|
| 2383 |
+
def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
|
| 2384 |
+
"""
|
| 2385 |
+
Check whether node1 and node2 can be horizontally fused or not.
|
| 2386 |
+
"""
|
| 2387 |
+
raise NotImplementedError()
|
| 2388 |
+
|
| 2389 |
+
def fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
|
| 2390 |
+
"""
|
| 2391 |
+
Fuse two nodes
|
| 2392 |
+
"""
|
| 2393 |
+
if node1.is_foreach() or node2.is_foreach():
|
| 2394 |
+
return ForeachKernelSchedulerNode.fuse(node1, node2)
|
| 2395 |
+
else:
|
| 2396 |
+
return FusedSchedulerNode.fuse(node1, node2)
|
| 2397 |
+
|
| 2398 |
+
def group_fn(self, sizes):
|
| 2399 |
+
"""
|
| 2400 |
+
Process the iteration sizes in case a transformation needs to be applied.
|
| 2401 |
+
"""
|
| 2402 |
+
raise NotImplementedError()
|
| 2403 |
+
|
| 2404 |
+
def codegen_template(
|
| 2405 |
+
self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]
|
| 2406 |
+
):
|
| 2407 |
+
"""
|
| 2408 |
+
Given a template node, generate a kernel.
|
| 2409 |
+
|
| 2410 |
+
This function is only available for triton now. If the third-party backend behaves as a sub-class
|
| 2411 |
+
of TritonScheduling, it can override it or reuse it.
|
| 2412 |
+
"""
|
| 2413 |
+
raise NotImplementedError()
|
| 2414 |
+
|
| 2415 |
+
def codegen_nodes(self, nodes: List[SchedulerNode]):
|
| 2416 |
+
"""
|
| 2417 |
+
Generate a kernel given a list of pre-fused nodes.
|
| 2418 |
+
"""
|
| 2419 |
+
raise NotImplementedError()
|
| 2420 |
+
|
| 2421 |
+
def codegen_sync(self):
|
| 2422 |
+
"""
|
| 2423 |
+
Generate synchronization code for the kernel. This method depends on the hardware characteristics.
|
| 2424 |
+
"""
|
| 2425 |
+
raise NotImplementedError()
|
| 2426 |
+
|
| 2427 |
+
def ready_to_flush(self) -> bool:
|
| 2428 |
+
"""
|
| 2429 |
+
Check whether the backend is requesting the scheduler to flush the generated kernel.
|
| 2430 |
+
If not supported, please return False.
|
| 2431 |
+
"""
|
| 2432 |
+
return False
|
| 2433 |
+
|
| 2434 |
+
def flush(self):
|
| 2435 |
+
"""
|
| 2436 |
+
Flush the generated kernel and python wrapper code to the source code file.
|
| 2437 |
+
"""
|
| 2438 |
+
raise NotImplementedError()
|
| 2439 |
+
|
| 2440 |
+
def benchmark_fused_nodes(self, nodes):
|
| 2441 |
+
"""
|
| 2442 |
+
Benchmark fused list of nodes and return the execution time
|
| 2443 |
+
in milliseconds on randomly generated inputs.
|
| 2444 |
+
"""
|
| 2445 |
+
raise NotImplementedError()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py
ADDED
|
@@ -0,0 +1,1156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import builtins
|
| 2 |
+
import functools
|
| 3 |
+
import inspect
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import operator
|
| 7 |
+
import sys
|
| 8 |
+
import textwrap
|
| 9 |
+
import time
|
| 10 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 11 |
+
from io import StringIO
|
| 12 |
+
|
| 13 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 14 |
+
from unittest.mock import patch
|
| 15 |
+
|
| 16 |
+
import sympy
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch._dynamo.testing import rand_strided
|
| 20 |
+
from torch._dynamo.utils import counters, identity, preserve_rng_state
|
| 21 |
+
|
| 22 |
+
from . import config, ir
|
| 23 |
+
from .autotune_process import TensorMeta, TritonBenchmarkRequest
|
| 24 |
+
from .codecache import code_hash, PersistentCache, PyCodeCache
|
| 25 |
+
from .codegen.common import (
|
| 26 |
+
ChoiceCaller,
|
| 27 |
+
IndentedBuffer,
|
| 28 |
+
KernelTemplate,
|
| 29 |
+
PrimitiveInfoType,
|
| 30 |
+
)
|
| 31 |
+
from .codegen.triton import (
|
| 32 |
+
gen_common_triton_imports,
|
| 33 |
+
texpr,
|
| 34 |
+
TritonKernel,
|
| 35 |
+
TritonPrinter,
|
| 36 |
+
TritonScheduling,
|
| 37 |
+
)
|
| 38 |
+
from .codegen.triton_utils import config_of, signature_to_meta
|
| 39 |
+
from .exc import CUDACompileError
|
| 40 |
+
from .utils import (
|
| 41 |
+
do_bench,
|
| 42 |
+
get_dtype_size,
|
| 43 |
+
Placeholder,
|
| 44 |
+
sympy_dot,
|
| 45 |
+
sympy_product,
|
| 46 |
+
unique,
|
| 47 |
+
)
|
| 48 |
+
from .virtualized import V
|
| 49 |
+
|
| 50 |
+
log = logging.getLogger(__name__)
|
| 51 |
+
|
| 52 |
+
# correctness checks struggle with fp16/tf32
|
| 53 |
+
VERIFY: Dict[str, Any] = dict()
|
| 54 |
+
PRINT_AUTOTUNE = True
|
| 55 |
+
DEBUG = False
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class KernelNamespace:
|
| 59 |
+
pass
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# these objects are imported from the generated wrapper code
|
| 63 |
+
extern_kernels = KernelNamespace()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class PartialRender:
|
| 67 |
+
"""
|
| 68 |
+
Some parts of a template need to be generated at the end, but
|
| 69 |
+
inserted into the template at the start. This allows doing a bunch
|
| 70 |
+
of replacements after the initial render.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, code, replacement_hooks):
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.code = code
|
| 76 |
+
self.replacement_hooks = replacement_hooks
|
| 77 |
+
|
| 78 |
+
def finalize(self):
|
| 79 |
+
code = self.code
|
| 80 |
+
assert code is not None, "can only be called once"
|
| 81 |
+
self.code = None
|
| 82 |
+
for key, fn in self.replacement_hooks.items():
|
| 83 |
+
code = code.replace(key, fn())
|
| 84 |
+
return code
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class TritonTemplateKernel(TritonKernel):
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
kernel_name,
|
| 91 |
+
input_nodes,
|
| 92 |
+
output_node,
|
| 93 |
+
defines,
|
| 94 |
+
num_stages,
|
| 95 |
+
num_warps,
|
| 96 |
+
grid_fn,
|
| 97 |
+
meta,
|
| 98 |
+
call_sizes,
|
| 99 |
+
use_jit=True,
|
| 100 |
+
prefix_args=0,
|
| 101 |
+
suffix_args=0,
|
| 102 |
+
epilogue_fn=identity,
|
| 103 |
+
*,
|
| 104 |
+
index_dtype,
|
| 105 |
+
):
|
| 106 |
+
super().__init__(
|
| 107 |
+
sympy_product(output_node.get_size()),
|
| 108 |
+
sympy.Integer(1),
|
| 109 |
+
index_dtype=index_dtype,
|
| 110 |
+
)
|
| 111 |
+
self.input_nodes = input_nodes
|
| 112 |
+
self.output_node = output_node
|
| 113 |
+
self.named_input_nodes = {}
|
| 114 |
+
self.defines = defines
|
| 115 |
+
self.kernel_name = kernel_name
|
| 116 |
+
self.template_mask = None
|
| 117 |
+
self.use_jit = use_jit
|
| 118 |
+
self.num_stages = num_stages
|
| 119 |
+
self.num_warps = num_warps
|
| 120 |
+
self.grid_fn = grid_fn
|
| 121 |
+
self.meta = meta
|
| 122 |
+
self.call_sizes = call_sizes
|
| 123 |
+
# for templates with fixed epilogues
|
| 124 |
+
self.prefix_args = prefix_args
|
| 125 |
+
self.suffix_args = suffix_args
|
| 126 |
+
self.epilogue_fn = epilogue_fn
|
| 127 |
+
self.render_hooks = dict()
|
| 128 |
+
self.triton_meta: Optional[Dict[str, object]] = None
|
| 129 |
+
|
| 130 |
+
def need_numel_args(self):
|
| 131 |
+
return False
|
| 132 |
+
|
| 133 |
+
def estimate_kernel_num_bytes(self):
|
| 134 |
+
"""
|
| 135 |
+
Estimate the total number of bytes this kernel takes.
|
| 136 |
+
For in/out nodes, sizes are counted twice: once for reading and
|
| 137 |
+
once for writing.
|
| 138 |
+
"""
|
| 139 |
+
ninplace_args = len(unique(self.args.inplace_buffers.values()))
|
| 140 |
+
num_bytes = []
|
| 141 |
+
for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
|
| 142 |
+
size = V.graph.sizevars.size_hints(inp.get_size())
|
| 143 |
+
numel = functools.reduce(operator.mul, size)
|
| 144 |
+
dtype_size = get_dtype_size(inp.get_dtype())
|
| 145 |
+
num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
|
| 146 |
+
return sum(num_bytes)
|
| 147 |
+
|
| 148 |
+
def jit_lines(self):
|
| 149 |
+
if self.use_jit:
|
| 150 |
+
return "@triton.jit"
|
| 151 |
+
|
| 152 |
+
argdefs, _, signature = self.args.python_argdefs()
|
| 153 |
+
triton_meta = {
|
| 154 |
+
"signature": signature_to_meta(signature, size_dtype=self.index_dtype),
|
| 155 |
+
"device": V.graph.scheduler.current_device.index,
|
| 156 |
+
"device_type": V.graph.scheduler.current_device.type,
|
| 157 |
+
"constants": {},
|
| 158 |
+
}
|
| 159 |
+
triton_meta["configs"] = [config_of(signature)]
|
| 160 |
+
for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index]
|
| 161 |
+
triton_meta["constants"][arg_num] = 1 # type: ignore[index]
|
| 162 |
+
self.triton_meta = triton_meta
|
| 163 |
+
|
| 164 |
+
inductor_meta = {
|
| 165 |
+
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
|
| 166 |
+
"backend_hash": torch.utils._triton.triton_hash_with_backend(),
|
| 167 |
+
}
|
| 168 |
+
if config.profile_bandwidth or config.benchmark_kernel:
|
| 169 |
+
num_gb = self.estimate_kernel_num_bytes() / 1e9
|
| 170 |
+
inductor_meta["kernel_num_gb"] = num_gb
|
| 171 |
+
return f"""
|
| 172 |
+
@triton_heuristics.template(
|
| 173 |
+
num_stages={self.num_stages},
|
| 174 |
+
num_warps={self.num_warps},
|
| 175 |
+
triton_meta={triton_meta!r},
|
| 176 |
+
inductor_meta={inductor_meta!r},
|
| 177 |
+
)
|
| 178 |
+
@triton.jit
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
def def_kernel(self, *argnames):
|
| 182 |
+
"""
|
| 183 |
+
Hook called from template code to generate function def and
|
| 184 |
+
needed args.
|
| 185 |
+
"""
|
| 186 |
+
assert all(isinstance(x, str) for x in argnames)
|
| 187 |
+
renames = IndentedBuffer(initial_indent=1)
|
| 188 |
+
|
| 189 |
+
named_args = self.input_nodes[
|
| 190 |
+
self.prefix_args : len(self.input_nodes) - self.suffix_args
|
| 191 |
+
]
|
| 192 |
+
|
| 193 |
+
assert len(argnames) == len(named_args), (
|
| 194 |
+
len(argnames),
|
| 195 |
+
len(named_args),
|
| 196 |
+
self.prefix_args,
|
| 197 |
+
len(self.input_nodes),
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
for input_node in self.input_nodes[: self.prefix_args]:
|
| 201 |
+
# get args in correct order
|
| 202 |
+
self.args.input(input_node.get_name())
|
| 203 |
+
|
| 204 |
+
for name, input_node in zip(argnames, named_args):
|
| 205 |
+
arg_name = f"arg_{name}"
|
| 206 |
+
self.named_input_nodes[name] = input_node
|
| 207 |
+
self.args.input_buffers[input_node.get_name()] = arg_name
|
| 208 |
+
|
| 209 |
+
# The args may be duplicated, so renaming must be after args are de-duplicated.
|
| 210 |
+
for name in argnames:
|
| 211 |
+
input_node = self.named_input_nodes[name]
|
| 212 |
+
arg_name = self.args.input_buffers[input_node.get_name()]
|
| 213 |
+
if input_node.get_layout().offset == 0:
|
| 214 |
+
renames.writeline(f"{name} = {arg_name}")
|
| 215 |
+
else:
|
| 216 |
+
offset = texpr(self.rename_indexing(input_node.get_layout().offset))
|
| 217 |
+
renames.writeline(f"{name} = {arg_name} + {offset}")
|
| 218 |
+
|
| 219 |
+
for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
|
| 220 |
+
# get args in correct order
|
| 221 |
+
self.args.input(input_node.get_name())
|
| 222 |
+
|
| 223 |
+
def hook():
|
| 224 |
+
# python_argdefs() cannot be run until after the rest of the template lazily adds more args
|
| 225 |
+
arg_defs, *_ = self.args.python_argdefs()
|
| 226 |
+
code = IndentedBuffer()
|
| 227 |
+
code.splice(gen_common_triton_imports())
|
| 228 |
+
code.splice(self.jit_lines())
|
| 229 |
+
code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
|
| 230 |
+
with code.indent():
|
| 231 |
+
code.splice(self.defines)
|
| 232 |
+
code.splice(renames.getvalue())
|
| 233 |
+
return code.getvalue()
|
| 234 |
+
|
| 235 |
+
assert "<DEF_KERNEL>" not in self.render_hooks
|
| 236 |
+
self.render_hooks["<DEF_KERNEL>"] = hook
|
| 237 |
+
return "<DEF_KERNEL>"
|
| 238 |
+
|
| 239 |
+
def size(self, name: str, index: int):
|
| 240 |
+
"""
|
| 241 |
+
Hook called from template code to get the size of an arg.
|
| 242 |
+
Will add needed args to pass it in if it is dynamic.
|
| 243 |
+
"""
|
| 244 |
+
assert isinstance(index, int)
|
| 245 |
+
if name is None:
|
| 246 |
+
val = self.output_node.get_size()[index]
|
| 247 |
+
else:
|
| 248 |
+
assert isinstance(name, str)
|
| 249 |
+
val = self.named_input_nodes[name].get_size()[index]
|
| 250 |
+
return texpr(self.rename_indexing(val))
|
| 251 |
+
|
| 252 |
+
def stride(self, name, index):
|
| 253 |
+
"""
|
| 254 |
+
Hook called from template code to get the stride of an arg.
|
| 255 |
+
Will add needed args to pass it in if it is dynamic.
|
| 256 |
+
"""
|
| 257 |
+
assert isinstance(index, int)
|
| 258 |
+
if name is None:
|
| 259 |
+
val = self.output_node.get_stride()[index]
|
| 260 |
+
else:
|
| 261 |
+
assert isinstance(name, str)
|
| 262 |
+
val = self.named_input_nodes[name].get_stride()[index]
|
| 263 |
+
return texpr(self.rename_indexing(val))
|
| 264 |
+
|
| 265 |
+
def store_output(self, indices, val, mask):
|
| 266 |
+
"""
|
| 267 |
+
Hook called from template code to store the final output
|
| 268 |
+
(if the buffer hasn't been optimized away), then append any
|
| 269 |
+
epilogue fusions.
|
| 270 |
+
"""
|
| 271 |
+
assert isinstance(indices, (list, tuple))
|
| 272 |
+
assert isinstance(val, str)
|
| 273 |
+
assert isinstance(mask, str)
|
| 274 |
+
assert self.template_mask is None
|
| 275 |
+
indices = list(map(TritonPrinter.paren, indices))
|
| 276 |
+
index_symbols = [sympy.Symbol(x) for x in indices]
|
| 277 |
+
lengths = [V.graph.sizevars.simplify(s) for s in self.output_node.get_size()]
|
| 278 |
+
assert len(indices) == len(lengths)
|
| 279 |
+
|
| 280 |
+
# glue to make generated code use same indexing from template
|
| 281 |
+
for name, range_tree_entry in zip(
|
| 282 |
+
indices, self.range_trees[0].construct_entries(lengths)
|
| 283 |
+
):
|
| 284 |
+
range_tree_entry.set_name(name)
|
| 285 |
+
contiguous_index = sympy_dot(
|
| 286 |
+
ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
|
| 287 |
+
)
|
| 288 |
+
contiguous_index = self.rename_indexing(contiguous_index)
|
| 289 |
+
self.body.writeline("xindex = " + texpr(contiguous_index))
|
| 290 |
+
self.range_trees[0].lookup(sympy.Integer(1), sympy_product(lengths)).set_name(
|
| 291 |
+
"xindex"
|
| 292 |
+
)
|
| 293 |
+
self.template_mask = mask
|
| 294 |
+
self.template_indices = indices
|
| 295 |
+
output_index = self.output_node.get_layout().make_indexer()(index_symbols)
|
| 296 |
+
output_index = self.rename_indexing(output_index)
|
| 297 |
+
if output_index == contiguous_index:
|
| 298 |
+
output_index = sympy.Symbol("xindex")
|
| 299 |
+
|
| 300 |
+
epilogue_args = [val]
|
| 301 |
+
for input_node in itertools.chain(
|
| 302 |
+
self.input_nodes[: self.prefix_args],
|
| 303 |
+
self.input_nodes[len(self.input_nodes) - self.suffix_args :],
|
| 304 |
+
):
|
| 305 |
+
input_node.freeze_layout()
|
| 306 |
+
epilogue_args.append(input_node.make_loader()(index_symbols))
|
| 307 |
+
|
| 308 |
+
V.ops.store(
|
| 309 |
+
self.output_node.get_name(),
|
| 310 |
+
output_index,
|
| 311 |
+
self.epilogue_fn(*epilogue_args),
|
| 312 |
+
)
|
| 313 |
+
self.codegen_body()
|
| 314 |
+
|
| 315 |
+
def hook():
|
| 316 |
+
# more stuff might have been added since the codegen_body above
|
| 317 |
+
self.codegen_body()
|
| 318 |
+
return textwrap.indent(self.body.getvalue(), " ").strip()
|
| 319 |
+
|
| 320 |
+
assert "<STORE_OUTPUT>" not in self.render_hooks
|
| 321 |
+
self.render_hooks["<STORE_OUTPUT>"] = hook
|
| 322 |
+
return "<STORE_OUTPUT>"
|
| 323 |
+
|
| 324 |
+
def render(self, template, kwargs):
|
| 325 |
+
return PartialRender(
|
| 326 |
+
template.render(**self.template_env(), **kwargs),
|
| 327 |
+
self.render_hooks,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
def make_load(self, name, indices, mask):
|
| 331 |
+
"""
|
| 332 |
+
Optional helper called from template code to generate the code
|
| 333 |
+
needed to load from an tensor.
|
| 334 |
+
"""
|
| 335 |
+
assert isinstance(indices, (list, tuple))
|
| 336 |
+
assert isinstance(name, str)
|
| 337 |
+
assert isinstance(mask, str)
|
| 338 |
+
stride = self.named_input_nodes[name].get_stride()
|
| 339 |
+
indices = list(map(TritonPrinter.paren, indices))
|
| 340 |
+
assert len(indices) == len(stride)
|
| 341 |
+
index = " + ".join(
|
| 342 |
+
f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
|
| 343 |
+
)
|
| 344 |
+
return f"tl.load({name} + ({index}), {mask})"
|
| 345 |
+
|
| 346 |
+
def template_env(self):
|
| 347 |
+
"""
|
| 348 |
+
Generate the namespace visible in the template.
|
| 349 |
+
"""
|
| 350 |
+
return {
|
| 351 |
+
fn.__name__: fn
|
| 352 |
+
for fn in [
|
| 353 |
+
self.def_kernel,
|
| 354 |
+
self.size,
|
| 355 |
+
self.stride,
|
| 356 |
+
self.store_output,
|
| 357 |
+
self.make_load,
|
| 358 |
+
]
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
def indexing(
|
| 362 |
+
self,
|
| 363 |
+
index: sympy.Expr,
|
| 364 |
+
*,
|
| 365 |
+
dense_indexing=False,
|
| 366 |
+
copy_shape=None,
|
| 367 |
+
override_mask=None,
|
| 368 |
+
block_ptr=False,
|
| 369 |
+
):
|
| 370 |
+
"""
|
| 371 |
+
Override the default indexing to use our custom mask and force
|
| 372 |
+
dense indexing.
|
| 373 |
+
"""
|
| 374 |
+
return super().indexing(
|
| 375 |
+
index,
|
| 376 |
+
dense_indexing=False,
|
| 377 |
+
copy_shape=self.template_mask,
|
| 378 |
+
override_mask=self.template_mask,
|
| 379 |
+
block_ptr=block_ptr,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
def initialize_range_tree(self, pid_cache):
|
| 383 |
+
super().initialize_range_tree(pid_cache)
|
| 384 |
+
# ignore default codegen
|
| 385 |
+
self.body.clear()
|
| 386 |
+
self.indexing_code.clear()
|
| 387 |
+
|
| 388 |
+
def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
    """
    Emit the call site for this template kernel into the wrapper code.

    Two paths: the cpp_wrapper path computes the launch grid eagerly (may
    contain symbolic values printed later via cexpr), while the Python
    wrapper path emits a ``<name>.run(...)`` line that calls the grid
    function at runtime.
    """
    wrapper = V.graph.wrapper_code
    _, call_args, _ = self.args.python_argdefs()
    call_args = [str(a) for a in call_args]

    for i in range(len(call_args)):
        # Unspec (0-dim tensor) args are converted to Python scalars.
        if V.graph.is_unspec_arg(call_args[i]):
            call_args[i] = call_args[i] + ".item()"
        # NOTE(review): every element of call_args is already a str after
        # the comprehension above, so this isinstance check can never be
        # true — looks like dead code; confirm before removing.
        if isinstance(call_args[i], sympy.Symbol):
            call_args[i] = texpr(call_args[i])

    if V.graph.cpp_wrapper:
        # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime
        # if any dynamic dimension is involved. We rely on the Python version
        # of the grid function to generate those grid configs, which may contain
        # symbolic values. The wrapper will use cexpr to print out C++ code
        # appropriately for the grid configs.
        grid_args = [V.graph.sizevars.simplify(s) for s in self.call_sizes] + [
            self.meta
        ]
        grid = self.grid_fn(*grid_args)

        wrapper.generate_kernel_call(
            name,
            call_args,
            device_index=V.graph.scheduler.current_device.index,
            grid=grid,
            triton_meta=self.triton_meta,
        )
    else:
        # Python wrapper: fetch the raw CUDA stream for the current device.
        stream_name = wrapper.write_get_raw_stream(
            V.graph.scheduler.current_device.index
        )

        # The grid function is referenced by module path at runtime, so make
        # sure its module is imported (once) and the meta dict is emitted once.
        wrapper.add_import_once(f"import {self.grid_fn.__module__}")
        meta = wrapper.add_meta_once(self.meta)

        grid_call = [
            texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes
        ] + [meta]
        grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})"
        wrapper.writeline(
            f"{name}.run({', '.join(call_args)}, grid={grid_call}, stream={stream_name})"
        )
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
@functools.lru_cache(None)
|
| 435 |
+
def _jinja2_env():
|
| 436 |
+
try:
|
| 437 |
+
import jinja2
|
| 438 |
+
|
| 439 |
+
return jinja2.Environment(
|
| 440 |
+
undefined=jinja2.StrictUndefined,
|
| 441 |
+
)
|
| 442 |
+
except ImportError:
|
| 443 |
+
return None
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class TritonTemplate(KernelTemplate):
    """
    A jinja2-based Triton kernel template. Instances register themselves in
    the class-level ``all_templates`` registry by name; ``generate`` renders
    the template for a concrete config and returns a TritonTemplateCaller
    candidate for autotuning (or None on a sympy division-by-zero).
    """

    # Monotonic counter used to give each generated kernel a unique hash name.
    index_counter = itertools.count()
    # Global registry: template name -> TritonTemplate instance.
    all_templates: Dict[str, "TritonTemplate"] = dict()

    def __init__(self, name: str, grid: Any, source: str, debug=False):
        """
        name: unique template name (asserted unique in the registry).
        grid: callable computing the launch grid from output sizes + meta.
        source: jinja2 template source for the kernel body.
        """
        super().__init__(name)
        self.grid = grid
        self.template = self._template_from_string(source)
        assert name not in self.all_templates, "duplicate template name"
        self.all_templates[name] = self
        self.debug = debug

    def generate(
        self,
        input_nodes,
        layout,
        num_stages,
        num_warps,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        **kwargs,
    ):
        """
        Render this template for one config (kwargs become tl.constexpr
        defines) and return a TritonTemplateCaller wrapping the compiled
        module plus a benchmark request. Returns None if rendering hits a
        sympy ZeroDivisionError.
        """
        assert self.template, "requires jinja2"
        # Turn each kwarg into a "NAME : tl.constexpr = VAL" define line.
        defines = StringIO()
        for name, val in kwargs.items():
            defines.write(f"    {name} : tl.constexpr = {val}\n")
        defines = defines.getvalue()

        # Placeholder output buffer so codegen has a concrete output node.
        fake_out = ir.Buffer("buf_out", layout)
        kernel_name = f"triton_{self.name}"

        numel = sympy_product(layout.size)
        buffers = itertools.chain(input_nodes, (fake_out,))
        if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
            raise NotImplementedError(
                "64-bit indexing is not yet implemented for triton templates"
            )

        # Shared between the JIT render below and make_kernel_render later.
        kernel_options = dict(
            input_nodes=input_nodes,
            defines=defines,
            num_stages=num_stages,
            num_warps=num_warps,
            grid_fn=self.grid,
            meta=kwargs,
            call_sizes=layout.size,
            prefix_args=prefix_args,
            suffix_args=suffix_args,
            epilogue_fn=epilogue_fn,
            index_dtype="tl.int32",
        )
        # Patch get_dtype so lookups of the fake output resolve correctly
        # while rendering the template.
        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(fake_out)
        ), TritonTemplateKernel(
            kernel_name=kernel_name,
            output_node=fake_out,
            use_jit=True,
            **kernel_options,
        ) as kernel:
            try:
                code = kernel.render(self.template, kwargs).finalize()
            except ZeroDivisionError:
                # TODO(nmacchioni): fix sympy division by zero
                return None
            if self.debug:
                print("Generated Code:\n", code)
            # "extra" disambiguates cache entries for different configs of
            # the same template (sorted kwargs + stages/warps).
            extra = (
                "-".join(
                    [
                        *[
                            f"{kwarg}={repr(kwargs[kwarg])}"
                            for kwarg in sorted(kwargs.keys())
                        ],
                        f"num_stages={num_stages}",
                        f"num_warps={num_warps}",
                    ]
                )
                + "-"
            )
            mod = PyCodeCache.load(code, extra)
            _, call_args, _ = kernel.args.python_argdefs()

        # The kernel's leading args must be exactly the (deduped) inputs
        # followed by the output buffer.
        expected_args = list(unique(x.get_name() for x in input_nodes))
        expected_args.extend([fake_out.get_name()])
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
        )
        # Remaining args are size expressions; resolve them to concrete hints.
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :]),
            fallback=config.unbacked_symint_fallback,
        )

        kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"

        def make_kernel_render(out_node):
            # Deferred (non-JIT) render against the real output node, used
            # if this choice wins autotuning.
            kernel = TritonTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
                output_node=out_node,
                use_jit=False,
                **kernel_options,
            )
            render = functools.partial(
                kernel.render,
                self.template,
                kwargs,
            )
            return kernel, render

        # create the BenchmarkRequest
        assert mod.__file__ is not None
        grid = self.grid(
            *V.graph.sizevars.size_hints(
                layout.size,
                fallback=config.unbacked_symint_fallback,
            ),
            kwargs,
        )
        bmreq = TritonBenchmarkRequest(
            module_path=mod.__file__,
            module_cache_key=mod.key,
            kernel_name=kernel_name,
            grid=grid,
            extra_args=extra_args,
            num_stages=num_stages,
            num_warps=num_warps,
            matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
            input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(layout),
        )

        return TritonTemplateCaller(
            kernel_hash_name,
            input_nodes,
            layout,
            make_kernel_render,
            extra.strip("-").replace("-", ", "),
            bmreq,
            # log_info feeds the autotune log; -1 / None mark absent kwargs.
            log_info={
                "tile_shape": str(
                    (
                        kwargs.get("BLOCK_M", -1),
                        kwargs.get("BLOCK_K", -1),
                        kwargs.get("BLOCK_N", -1),
                    )
                ),
                "num_stages": num_stages,
                "num_warps": num_warps,
                "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
                "acc_type": str(kwargs.get("ACC_TYPE", None)),
            },
        )
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
class ExternKernelChoice:
    """
    Describes an external (e.g. ATen) kernel that can compete in autotuning.
    Registers the callable on the ``extern_kernels`` namespace under ``name``
    so generated code can reference it as ``extern_kernels.<name>``.
    """

    def __init__(
        self,
        kernel,
        cpp_kernel=None,
        *,
        name=None,
        has_out_variant=True,
        op_overload=None,
        use_fallback_kernel=False,
    ):
        super().__init__()
        name = name or kernel.__name__
        assert callable(kernel)
        assert not hasattr(extern_kernels, name), "duplicate extern kernel"
        self.name = name
        self.cpp_kernel_name = cpp_kernel
        self.has_out_variant = has_out_variant
        # Expose the callable as extern_kernels.<name> for generated code.
        setattr(extern_kernels, name, kernel)
        self.op_overload = op_overload
        self.use_fallback_kernel = use_fallback_kernel

    def to_callable(self):
        # Resolve the callable back off the registry (see __init__).
        return getattr(extern_kernels, self.name)

    def call_name(self):
        # Name used in generated Python code.
        return f"extern_kernels.{self.name}"

    @functools.lru_cache(None)
    def hash_key(self):
        """Stable cache key: name + callable identity + (best-effort) source."""
        fn = self.to_callable()
        parts = [
            self.name,
            getattr(fn, "__name__", ""),
            getattr(fn, "__module__", ""),
        ]
        try:
            # Source may be unavailable (builtins/C extensions) — best effort.
            parts.append(inspect.getsource(fn))
        except Exception:
            pass
        return code_hash("-".join(parts))

    def bind(
        self,
        input_nodes,
        layout,
        ordered_kwargs_for_cpp_kernel=(),
        **kwargs,
    ):
        """Create an ExternKernelCaller for these inputs/layout/kwargs."""
        # NOTE(review): this stores per-bind state on the shared choice
        # object; two concurrent binds with different orderings would clobber
        # each other — presumably binds are sequential here; confirm.
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
        return ExternKernelCaller(
            self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant
        )
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
class TritonTemplateCaller(ChoiceCaller):
    """
    ChoiceCaller for one generated Triton template kernel: holds the deferred
    render function and a TritonBenchmarkRequest used for autotuning.
    """

    def __init__(
        self,
        name,
        input_nodes,
        layout,
        make_kernel_render,
        debug_extra,
        bmreq,
        log_info: Optional[
            Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]
        ] = None,
    ):
        super().__init__(name, input_nodes, layout)
        # Callable producing (kernel, render) for the real output node.
        self.make_kernel_render = make_kernel_render
        # Human-readable config string (kwargs/stages/warps), for __str__.
        self.debug_extra = debug_extra
        self.bmreq: TritonBenchmarkRequest = bmreq
        if log_info is None:
            log_info = {}
        self.log_info: Dict[str, Any] = log_info
        # Always record backend/grid/stages/warps for the autotune log.
        self.log_info.update(
            {
                "backend": "Triton",
                "grid": str(self.bmreq.grid),
                "num_stages": self.bmreq.num_stages,
                "num_warps": self.bmreq.num_warps,
            }
        )

    def benchmark(self, *args, out):
        """Delegate timing to the benchmark request, writing into `out`."""
        assert self.bmreq is not None
        return self.bmreq.benchmark(*args, output_tensor=out)

    def __str__(self):
        return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"

    def call_name(self):
        return f"template_kernels.{self.name}"

    def hash_key(self):
        # Drop the trailing per-instance counter from the name; combine with
        # the compiled module's cache key.
        return "-".join(
            [
                self.name.rsplit("_", 1)[0],
                self.bmreq.module_cache_key,
            ]
        )

    def output_node(self):
        """Materialize this choice as a TritonTemplateBuffer in the graph."""
        return ir.TensorBox.create(
            ir.TritonTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
            )
        )

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return self.log_info
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
class ExternKernelCaller(ChoiceCaller):
    """
    ChoiceCaller for an ExternKernelChoice (e.g. an ATen op) competing in
    autotuning against generated Triton templates.
    """

    def __init__(
        self,
        choice: ExternKernelChoice,
        input_nodes,
        layout,
        kwargs=None,
        *,
        has_out_variant=True,
    ):
        super().__init__(choice.name, input_nodes, layout)
        self.choice = choice
        self.kwargs = kwargs or {}
        # Whether the kernel supports an out= variant; affects benchmarking
        # and which IR node type output_node() constructs.
        self.has_out_variant = has_out_variant

    def __str__(self):
        return f"ExternKernelCaller({self.choice.call_name()})"

    def benchmark(self, *args, out):
        """
        Time the kernel. With an out variant, defer to the base benchmark.
        Without one, run once to validate shape/stride and capture output
        for correctness checking, then time repeated calls with do_bench.
        """
        if self.has_out_variant:
            return super().benchmark(*args, out=out)
        else:
            algo = self.to_callable()
            out_new = algo(*args)
            # The allocated output must match the expected size/stride.
            torch._C._dynamo.guards.assert_size_stride(
                out_new, tuple(out.size()), tuple(out.stride())
            )
            out.copy_(out_new)  # for correctness checking
            return do_bench(lambda: algo(*args))

    def to_callable(self):
        # Pre-bind kwargs so benchmarking calls only pass positional args.
        fn = self.choice.to_callable()
        if self.kwargs:
            return functools.partial(fn, **self.kwargs)
        else:
            return fn

    def hash_key(self):
        # Include sorted kwargs so different configs cache separately.
        return "-".join(
            [
                self.choice.name,
                *[
                    f"{kwarg}={repr(self.kwargs[kwarg])}"
                    for kwarg in sorted(self.kwargs.keys())
                ],
                self.choice.hash_key(),
            ]
        )

    def output_node(self):
        """
        Materialize this choice as IR: a FallbackKernel in abi_compatible
        mode (when the choice opts in), otherwise ExternKernelOut /
        ExternKernelAlloc depending on whether an out variant exists.
        """
        if config.abi_compatible and self.choice.use_fallback_kernel:
            assert (
                self.choice.op_overload is not None
            ), "Please provide an op_overload to use ir.FallbackKernel"
            inner = ir.FallbackKernel.create(
                self.choice.op_overload, *self.input_nodes, **self.kwargs
            )
        else:
            cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc
            inner = cls(
                layout=self.layout,
                inputs=self.input_nodes,
                python_kernel_name=self.choice.call_name(),
                cpp_kernel_name=self.choice.cpp_kernel_name,
                # NOTE: set on the choice by ExternKernelChoice.bind();
                # presumably bind() always precedes output_node() — confirm.
                ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel,
                op_overload=self.choice.op_overload,
                kwargs=self.kwargs,
            )

        return ir.TensorBox.create(inner)

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "extern",
            "kernel_call_name": self.choice.call_name(),
        }
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
class ErrorFromChoice(RuntimeError):
    """
    Raised when benchmarking a candidate choice fails. The message is
    augmented with the offending choice and a repro of the inputs; the
    choice itself is kept on the exception for callers to inspect.
    """

    def __init__(self, msg, choice: ChoiceCaller, inputs_str):
        detailed_msg = f"{msg}\nFrom choice {choice}\n{inputs_str}"
        super().__init__(detailed_msg)
        self.choice = choice
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
class AlgorithmSelectorCache(PersistentCache):
    """
    Benchmarks a list of candidate ChoiceCallers (Triton templates, extern
    kernels, ...) and returns the output node of the fastest, using the
    PersistentCache lookup so repeated compilations skip re-benchmarking.
    """

    def __call__(
        self,
        name,
        choices: List[ChoiceCaller],
        input_nodes,
        layout,
        # optional dict mapping arg indices to the functions
        # generating a torch.Tensor for that input from the
        # corresponding ir.Buffer. if passed for a given
        # arg, the function will be called instead of
        # generating a random torch.Tensor for benchmarking.
        input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
        precompilation_timeout_seconds: int = 60 * 60,
    ):
        from .codegen.cuda.cuda_kernel import CUDATemplateCaller

        # TODO(nmacchioni): remove once CI tests are fixed
        choices = [choice for choice in choices if choice is not None]
        if len(choices) == 0:
            raise RuntimeError(
                "No choices to select, please consider adding ATEN into max_autotune_gemm_backends "
                "config (defined in torch/_inductor/config.py) to allow at least one choice. "
            )
        log.debug("Max autotune selects from %s choices.", str(len(choices)))

        if len(choices) == 1:
            if not isinstance(choices[0], CUDATemplateCaller):
                # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
                return choices[0].output_node()

        # Lazily build the benchmark function exactly once; cache_info()
        # below doubles as "did we actually benchmark?" detection.
        @functools.lru_cache(None)
        def make_benchmark_fn():
            return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)

        def precompile(choices):
            # Compile candidates in parallel before timing; disabled when the
            # timeout is None/<=0 or no workers are available.
            if (
                precompilation_timeout_seconds is None
                or precompilation_timeout_seconds <= 0
            ):
                return
            num_workers = min(
                config.compile_threads,
                torch.get_num_threads(),
                len(choices),
            )
            if num_workers <= 0:
                return
            log.info(
                "Multithreaded precompilation for %d choices using %d worker threads",
                len(choices),
                num_workers,
            )
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                futures = executor.map(
                    lambda c: c.precompile(),
                    [c for c in choices if hasattr(c, "precompile")],
                    timeout=precompilation_timeout_seconds,
                )
                try:
                    # Drain the futures iterator; CUDA compile failures are
                    # logged (the bad choice will lose at benchmark time).
                    iterator = iter(futures)
                    while True:
                        try:
                            next(iterator)
                        except CUDACompileError:
                            log.error(  # noqa: G201
                                "CUDA Compilation error", exc_info=True
                            )
                except TimeoutError:
                    log.warning(
                        f"Precompilation timed out after {precompilation_timeout_seconds} seconds."  # noqa: G004
                    )
                except StopIteration:
                    pass
                executor.shutdown(wait=True)

        def autotune(choices):
            # Best-effort precompile, then benchmark all choices.
            try:
                precompile(choices)
            except TimeoutError:
                log.warning(
                    "Precompilation phase took longer than timeout allowed. Continuing"
                )
                pass
            return make_benchmark_fn()(choices)

        if config.autotune_in_subproc:
            from .autotune_process import tuning_pool

            # do the optional warmup
            tuning_pool.initialize()

        autotune_start_ts = time.time()
        # PersistentCache.lookup only invokes `autotune` on a cache miss.
        timings = self.lookup(
            choices,
            name,
            repr([self.key_of(x) for x in input_nodes]),
            autotune,
        )
        autotune_elapse = time.time() - autotune_start_ts
        if timings == {} or choices[0] not in timings:
            return choices[0].output_node()

        # currsize > 0 means make_benchmark_fn ran, i.e. a real benchmark
        # happened this call (not just a cache hit).
        if make_benchmark_fn.cache_info().currsize:
            counters["inductor"]["select_algorithm_autotune"] += 1
        if (
            make_benchmark_fn.cache_info().currsize
            or log.getEffectiveLevel() == logging.DEBUG
            or config.trace.log_autotuning_results
        ):
            self.log_results(name, input_nodes, timings, autotune_elapse)
        selected_choice = builtins.min(timings, key=timings.__getitem__).output_node()
        log.debug("selected choice: %s", str(selected_choice))
        return selected_choice

    @classmethod
    def make_benchmark_fn(
        cls,
        choices,
        input_nodes,
        layout,
        input_gen_fns=None,
    ):
        """
        Build and return a function mapping choices -> {choice: timing},
        either in-process or (for Triton candidates) in a subprocess
        depending on config.autotune_in_subproc.
        """
        if input_gen_fns is None:
            input_gen_fns = {}

        # de-duplicate args
        unique_example_inputs = {
            x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
            for i, x in enumerate(input_nodes)
        }
        example_inputs = list(unique_example_inputs.values())
        # Extern kernels get views with the layout's size/stride/offset baked
        # in; Triton templates get the base tensors (see comments below).
        example_inputs_extern = [
            torch.as_strided(
                unique_example_inputs[input_node.get_name()],
                V.graph.sizevars.size_hints(
                    input_node.get_size(),
                    fallback=config.unbacked_symint_fallback,
                ),
                V.graph.sizevars.size_hints(
                    input_node.get_stride(),
                    fallback=config.unbacked_symint_fallback,
                ),
                V.graph.sizevars.size_hint(
                    input_node.get_layout().offset,
                    fallback=config.unbacked_symint_fallback,
                ),
            )
            for input_node in input_nodes
        ]

        out = cls.benchmark_example_value(layout)
        out_extern = torch.as_strided(
            out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
        )
        if VERIFY:
            # Use the first choice's result as the reference output.
            choices[0].benchmark(*example_inputs_extern, out=out_extern)
            expected = out_extern.clone()

        if DEBUG:
            print(f"{len(choices)} tuning requests:")

        def debug_str():
            # Repro snippet describing the benchmark inputs/output tensors.
            def tensor_repr(x):
                return (
                    f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, "
                    f"dtype={x.dtype!r}, device={x.device.type!r})"
                )

            lines = [
                "inputs = [",
            ]
            for x in example_inputs:
                lines.append(f"    {tensor_repr(x)},")
            lines += ["]", f"out = {tensor_repr(out)}", ""]
            return "\n".join(lines)

        def benchmark_choice_in_current_process(choice):
            out.zero_()
            if isinstance(choice, ExternKernelCaller):
                # aten kernels want the offset baked in for sliced tensors
                result = choice.benchmark(*example_inputs_extern, out=out_extern)
            else:
                # triton templates want the base pointer for sliced tensors
                result = choice.benchmark(*example_inputs, out=out)
            if VERIFY:
                torch.testing.assert_close(out_extern, expected, **VERIFY)
            torch.cuda.synchronize()  # shake out any CUDA errors
            return result

        def benchmark_in_current_process(choices):
            timings = {}
            for choice in choices:
                try:
                    timing = benchmark_choice_in_current_process(choice)
                except CUDACompileError as e:
                    # Failed compile: keep going, mark this choice unusable.
                    log.warning(
                        "CUDA compilation error: \n%s. \nIgnore this choice.", str(e)
                    )
                    timing = float("inf")
                except RuntimeError as e:
                    msg = str(e)
                    if "invalid argument" in msg:
                        msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n"
                        log.warning(msg)
                        timing = float("inf")
                    else:
                        if "illegal memory access" in msg:
                            msg += "\n\nEither error in template or triton bug.\n"
                        raise ErrorFromChoice(msg, choice, debug_str())  # noqa: TRY200
                except AssertionError as e:
                    # VERIFY comparison failed: surface as a hard error.
                    raise AssertionError(  # noqa: TRY200
                        f"Incorrect result from choice {choice}\n\n{e}"
                    )

                timings[choice] = timing

            return timings

        def benchmark_in_sub_process(choices):
            from . import autotune_process

            # only benchmark triton kernel in sub process for now.
            # ATen/Extern kernel are still benchmarked in the current process.
            extern = [c for c in choices if isinstance(c, ExternKernelCaller)]
            triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]

            timings = benchmark_in_current_process(extern)
            timings.update(autotune_process.benchmark_in_sub_process(triton))
            return timings

        benchmark = (
            benchmark_in_sub_process
            if config.autotune_in_subproc
            else benchmark_in_current_process
        )

        return benchmark

    @staticmethod
    def log_results(
        name: str,
        input_nodes: List[ir.IRNode],
        timings: Dict[ChoiceCaller, float],
        elapse: float,
    ):
        """Write autotune results to the debug log and (optionally) stderr."""
        V.debug.log_autotuning_results(name, input_nodes, timings, elapse)
        if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE:
            return
        # "AxBxC, DxExF, ..." summary of the input shapes.
        sizes = ", ".join(
            [
                "x".join(
                    map(
                        str,
                        V.graph.sizevars.size_hints(
                            n.get_size(), fallback=config.unbacked_symint_fallback
                        ),
                    )
                )
                for n in input_nodes
            ]
        )
        # Show all results at DEBUG level, otherwise only the top 10.
        n = None if log.getEffectiveLevel() == logging.DEBUG else 10
        top_k = sorted(timings, key=timings.__getitem__)[:n]
        best = top_k[0]
        best_time = timings[best]
        sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
        for choice in top_k:
            result = timings[choice]
            if result:
                sys.stderr.write(
                    f"  {choice.name} {result:.4f} ms {best_time/result:.1%}\n"
                )
            else:
                # Zero timing would divide by zero in the percentage above.
                sys.stderr.write(
                    f"  {choice.name} {result:.4f} ms <DIVIDED BY ZERO ERROR>\n"
                )

        autotune_type_str = (
            "SubProcess" if config.autotune_in_subproc else "SingleProcess"
        )
        sys.stderr.write(f"{autotune_type_str} AUTOTUNE takes {elapse:.4f} seconds\n")

    @staticmethod
    def benchmark_example_value(node):
        """
        Convert an ir.Buffer into a concrete torch.Tensor we can use for
        benchmarking.
        """
        if isinstance(node, ir.Layout):
            node = ir.Buffer("fake", node)
        # triton templates want the base tensor.
        if isinstance(node, ir.BaseView):
            node = node.unwrap_view()
        # preserve rng states to avoid the rand_strided call below changes
        # the rng states for the real model code.
        with preserve_rng_state():
            return rand_strided(
                V.graph.sizevars.size_hints(
                    node.get_size(),
                    fallback=config.unbacked_symint_fallback,
                ),
                V.graph.sizevars.size_hints(
                    node.get_stride(),
                    fallback=config.unbacked_symint_fallback,
                ),
                device=node.get_device(),
                dtype=node.get_dtype(),
                extra_size=node.layout.offset,
            )

    @staticmethod
    def key_of(node):
        """
        Extract the pieces of an ir.Buffer that we should invalidate cached
        autotuning results on.
        """
        sizevars = V.graph.sizevars
        # Device type, dtype, concrete sizes/strides, and storage offset.
        return (
            node.get_device().type,
            str(node.get_dtype()),
            *sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            *sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            sizevars.size_hint(
                node.get_layout().offset,
                fallback=config.unbacked_symint_fallback,
            ),
        )
|
| 1137 |
+
|
| 1138 |
+
|
| 1139 |
+
_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None
|
| 1140 |
+
|
| 1141 |
+
|
| 1142 |
+
def autotune_select_algorithm(*args, **kwargs):
    """
    Module-level entry point for algorithm selection; lazily creates the
    process-wide AlgorithmSelectorCache singleton and delegates to it.
    """
    global _ALGORITHM_SELECTOR_CACHE
    cache = _ALGORITHM_SELECTOR_CACHE
    if cache is None:
        cache = _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
    return cache(*args, **kwargs)
|
| 1147 |
+
|
| 1148 |
+
|
| 1149 |
+
def realize_inputs(*args):
    """
    Realize IR inputs for autotuning: a single arg returns the realized
    stride-1 node directly, multiple args return a list (one per arg).
    """
    if len(args) != 1:
        return [realize_inputs(arg) for arg in args]
    (node,) = args
    return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(node))
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
# ensure lowering is imported so that `extern_kernels.*` is populated
|
| 1156 |
+
from . import lowering # noqa: F401
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import itertools
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import sympy
|
| 7 |
+
from sympy import Expr
|
| 8 |
+
|
| 9 |
+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
| 10 |
+
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
|
| 11 |
+
from torch.utils._sympy.value_ranges import bound_sympy
|
| 12 |
+
|
| 13 |
+
from .utils import sympy_index_symbol, sympy_subs, VarRanges
|
| 14 |
+
from .virtualized import V
|
| 15 |
+
|
| 16 |
+
log = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# This class is a little awkward, because ShapeEnv is doing most of the heavy
|
| 20 |
+
# lifting and in some cases we should be directly passing through to ShapeEnv,
|
| 21 |
+
# but there is some extra inductor logic that needs to be handled here
|
| 22 |
+
class SizeVarAllocator:
    def __init__(self, shape_env=None):
        # Wraps a ShapeEnv (creating a fresh one if none is given) and mirrors
        # its var_to_val / replacements tables for use by inductor codegen.
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        self.var_to_val = self.shape_env.var_to_val
        self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
        # The basic idea is if we have some complicated sympy expression
        # f(s0), we may choose to precompute it on the host and then replace
        # all occurrences of that sympy expression with ps0, so that when we
        # codegen we simply reference ps0 directly without repeating
        # f(s0). Unlike regular size variables, ps variables cannot be
        # guarded upon; so if we are asked to guard on a Sympy expression
        # which potentially could have already had a precomputed replacement
        # on it, we are obligated to invert the precomputed replacements
        # (inv_precomputed_replacements).
        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = dict()
        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = dict()
        # Cached/memoized entry points; the caches invalidate themselves when
        # len(self.replacements) changes (see the make_*_cache helpers below).
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()

    def simplify(self, expr: Expr):
        """Expand ``expr`` and apply the ShapeEnv's symbol replacements."""
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]:
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: Dict[Tuple[Any, ...], Expr] = dict()
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: Dict[Tuple[Any, ...], Any] = dict()
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (*index_vars, *sizes, *index_formulas)
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.
        """

        expr = join_dimensions(self.simplify(expr))
        original_expr = expr

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if self.statically_known_leq(var_ranges[v], divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            # FloorDiv(base, divisor) after dropping terms that cannot affect it
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            # Try to prove the modulo is a no-op, in which case ModularIndexing
            # can be replaced by the cheaper FloorDiv.
            base = remove_zero_terms(base, divisor)
            base_pos = True
            if isinstance(base, ModularIndexing):
                # for modular indexing, biggest values from the ranges don't necessarily result in
                # the biggest result, the biggest result is modulus - 1
                base_s = base.args[2] - 1
            elif not base.has(ModularIndexing):
                # actual iteration range is to size-1
                iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
                base_lowest = sympy_subs(base, iter_ranges_zero)
                if self.statically_known_leq(0, base_lowest):  # type: ignore[arg-type]
                    # can't replace with indexing div if base can be negative
                    base_pos = True
                else:
                    base_pos = False
                iter_ranges = {k: v - 1 for k, v in var_ranges.items()}
                base_s = sympy_subs(base, iter_ranges)
            else:
                base_s = base
            if self.statically_known_lt(base_s, modulus * divisor) and base_pos:
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                    sympy.Wild("modulus"),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                ),
                visit_indexing_div,
            )

        # Iterate to a fixed point: a rewrite may expose further opportunities.
        if expr != original_expr:
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(
        self, index_vars: List[sympy.Symbol], sizes, index_formulas
    ):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop
        If channel_last = True, we will prevent the last dim fused with other dims
        """
        sizes = list(map(self.simplify, sizes))

        strides = [self.stride_vars(x, index_vars) for x in index_formulas]
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

        for i in range(len(sizes)):
            if sizes[i] == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            # True if dims a and b are contiguous in every index formula
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    v = sympy_index_symbol("_merge_tester")
                    expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            # Map an index over the pruned dims back onto the original dims
            # (dropped dims get index 0).
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.Integer(0))
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            # Drop entries of ``index`` corresponding to removed dims
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    # Note - [On Statically Known]
    #
    # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
    # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
    # true, we add a guard and return True, otherwise, False.
    #
    # def maybe_guard_foo(args):
    #     if size_hinted_check(args):
    #         return False # No guard, no optim
    #     guard(args) # Make a guard
    #     return True # Safe to apply optimization
    #
    # The prior system incurred a guard, and green lit an optimization.
    #
    # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
    # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
    # return False.
    #
    # def maybe_guard_foo(args):
    #     if all_static(args):
    #         return True # Safe to apply optimization
    #     else:
    #         return False # No guard, no optim

    # See Note - [On Statically Known]
    def is_expr_static_and_true(self, expr: Union[Expr, int]) -> bool:
        """True iff ``expr`` can be statically proven true (never guards)."""
        if expr in (True, False):
            return bool(expr)

        try:
            simplified = self.shape_env._maybe_evaluate_static(expr)
            if simplified is not None:
                return bool(simplified)
        except Exception:
            log.debug("Could not simplify %s", expr)

        return False

    def statically_known_equals(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right are equal.
        """
        return self.is_expr_static_and_true(sympy.Eq(left, right))  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
        """
        if len(left) != len(right):
            return False
        if all(self.statically_known_equals(l, r) for l, r in zip(left, right)):
            return True
        return False

    # See Note - [On Statically Known]
    def statically_known_leq(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
        """
        expr = left <= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_lt(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than right.
        """
        expr = left < right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_multiple_of(self, numerator: Expr, denominator: Expr) -> bool:
        """
        Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
        """
        expr = sympy.Eq(numerator % denominator, 0)
        return self.is_expr_static_and_true(expr)  # type: ignore[arg-type]

    # The guard functions require you to ALREADY KNOW that a particular
    # condition holds. If you don't know (you want to guard on an expression
    # being a particular value, and then get access to that value), use
    # the evaluate functions.

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        """Assert (via ShapeEnv guard) that left == right; returns left."""
        # ps variables cannot be guarded on, so undo precomputed replacements
        # before handing the expressions to the ShapeEnv.
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
        return left

    def guard_leq(self, left: Expr, right: Expr) -> None:
        """Guard that left <= right (expressed as left < right + 1)."""
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        """Guard that left < right."""
        assert self.shape_env.evaluate_expr(sympy.Lt(left, right))

    def expect_true(self, expr: Expr, *, msg: str) -> None:
        """Defer a runtime assert that ``expr`` holds (no compile-time guard)."""
        expr = sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        self.shape_env.defer_runtime_assert(expr, msg, fx_node=None)

    def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr:
        # Prefer returning the expression without unbacked symints
        if self.shape_env.is_unbacked_symint(left):
            self.expect_true(sympy.Eq(left, right), msg=msg)  # type: ignore[arg-type]
            return right
        elif self.shape_env.is_unbacked_symint(right):
            self.expect_true(sympy.Eq(left, right), msg=msg)  # type: ignore[arg-type]
            return left
        else:
            return self.guard_equals(left, right)

    def guarded_order(self, seq):
        """
        Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
        Used for generating block_ptrs.
        """
        seq = [*map(self.remove_precomputed_replacements, seq)]
        seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)]
        seq.sort()
        order = [-1] * len(seq)
        last_var = None
        for new_index, (_, orig_index, var) in enumerate(seq):
            order[orig_index] = new_index
            if last_var is not None:
                # guard that consecutive (sorted) entries keep this ordering
                self.guard_leq(last_var, var)
            last_var = var
        return order

    # The evaluate functions evaluate some symbolic sympy expression
    # (NB: not necessarily an Expr) and return what the concrete result
    # is, guarding on the expression being that result

    # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
    # as this will ensure that you actually have a sympy'ified expression,
    # and will prevent you from incorrectly writing evaluate_expr(a == b)
    # which does the wrong thing if a or b is a sympy expression
    def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
        assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
        return self.shape_env.evaluate_expr(sympy.sympify(left))

    def evaluate_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice"""
        lv = self.size_hint(left)
        rv = self.size_hint(right)
        if lv <= rv:
            self.guard_leq(left, right)
            return left
        else:
            self.guard_leq(right, left)
            return right

    def evaluate_max(self, left: Expr, right: Expr) -> Expr:
        """return the larger of left and right, and guard on that choice"""
        # Always choose the opposite of eval min for consistency
        # This means min(a, b) and max(a, b) produce the same guards
        min_val = self.evaluate_min(left, right)
        return right if min_val is left else left

    def evaluate_static_shape(self, left: Expr) -> int:
        """Pin ``left`` to its hinted value with a guard and return it as int."""
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def evaluate_static_shapes(self, left: List[Expr]) -> List[int]:
        """Apply evaluate_static_shape to each element of ``left``."""
        return [self.evaluate_static_shape(x) for x in left]

    def remove_precomputed_replacements(self, expr: Expr) -> Expr:
        """Substitute any ps* symbols in ``expr`` back to their defining exprs."""
        if any(s.name.startswith("ps") for s in expr.free_symbols):  # type: ignore[attr-defined]
            return sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        return expr

    def symbolic_hint(self, expr: Expr) -> Expr:
        # Substitute all hints into expr, but leave unbacked symints alone
        if not isinstance(expr, Expr):
            assert isinstance(expr, int)
            return expr
        free_symbols = expr.free_symbols
        if not free_symbols:
            return int(expr)  # type: ignore[return-value]
        expr = self.remove_precomputed_replacements(expr)
        return sympy_subs(expr, self.var_to_val)

    def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int:
        """Concrete int hint for ``expr``; ``fallback`` (clamped to the
        expression's value range when known) is used for unbacked symints.
        Raises if no hint is available and no fallback was given."""
        out = self.symbolic_hint(expr)
        if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
            # Use the provided heuristic fallback hint
            sym_vrs = {
                s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols
            }
            if all(vr is not None for vr in sym_vrs.values()):
                expr_vr = bound_sympy(expr, sym_vrs)  # type: ignore[arg-type]
                lower = self.size_hint(expr_vr.lower)  # type: ignore[arg-type]
                upper = self.size_hint(expr_vr.upper)  # type: ignore[arg-type]
                fallback = min(max(fallback, lower), upper)
            return fallback
        try:
            return int(out)
        except Exception:
            log.debug("failed on: %s", out)
            raise

    def size_hints(
        self,
        exprs: Iterable[Expr],
        *,
        fallback: Optional[int] = None,
    ) -> Tuple[int, ...]:
        """Tuple of size_hint() for each expression in ``exprs``."""
        return tuple(self.size_hint(x, fallback=fallback) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        """
        fn_cache = functools.lru_cache(maxsize)(fn)
        prior_len = len(self.replacements)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal prior_len
            if prior_len != len(self.replacements):
                prior_len = len(self.replacements)
                fn_cache.cache_clear()
            return fn_cache(*args, **kwargs)

        return wrapper

    def make_stride_vars_cache(self):
        """Build the memoized stride_vars entry point (hashable-args shim)."""
        cache = self._lru_cache(self._stride_vars)

        def stride_vars(
            index: Expr,
            vars: List[sympy.Symbol],
            support_vars: Optional[List[sympy.Symbol]] = None,
        ) -> List[Expr]:
            if not support_vars:
                support_vars = vars
            # lists -> tuples so the lru_cache key is hashable
            return cache(index, tuple(vars), tuple(support_vars))

        return stride_vars

    def _stride_vars(
        self, index: Expr, vars: List[sympy.Symbol], support_vars: List[sympy.Symbol]
    ) -> List[Expr]:
        """Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element

        """
        strides = []
        index = self.simplify(index)
        # remove any offset
        index = index - sympy_subs(
            index, {v: sympy.Integer(0) for v in support_vars if v != 0}
        )
        for i in range(len(vars)):
            # drop all the other dims
            index_dim = sympy_subs(
                index,
                {
                    support_vars[j]: sympy.Integer(0)
                    for j in range(len(support_vars))
                    if vars[i] != support_vars[j] and support_vars[j] != 0
                },
            )
            v = vars[i]
            if v == 0:
                strides.append(sympy.Integer(0))
            else:
                # TODO(jansel): should we use sympy.diff here?
                strides.append(
                    sympy_subs(index_dim, {v: sympy.Integer(1)})
                    - sympy_subs(index_dim, {v: sympy.Integer(0)})
                )
        return strides

    def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
        """Extract offset part of an indexing expression"""
        index = self.simplify(index)
        return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})

    def stride_hints(
        self,
        index: Expr,
        vars: List[sympy.Symbol],
        support_vars: Optional[List[sympy.Symbol]] = None,
    ) -> List[int]:
        """Concrete int strides for ``index``; 0 where no hint is available."""
        for v in index.free_symbols:
            if v.name.startswith("indirect"):  # type: ignore[attr-defined]
                index = sympy_subs(index, {v: 0})  # type: ignore[dict-item]
        result = []
        for s in self.stride_vars(index, vars, support_vars):
            try:
                result.append(self.size_hint(s))
            except TypeError:
                result.append(0)
        return result

    def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
        """Dim indices sorted by |stride| hint; zero strides sort last."""
        strides = tuple(map(abs, self.stride_hints(index, vars)))
        order = list(range(len(strides)))
        order.sort(key=lambda x: (strides[x] == 0, strides[x]))
        return order

    def lookup_precomputed_size(self, expr: Expr) -> Expr:
        """Return (allocating on first use) the ps* symbol standing for ``expr``.
        Plain numbers/symbols are returned unchanged."""
        if (
            isinstance(expr, (int, sympy.Symbol, sympy.Number))
            or expr.is_number
            or expr.is_symbol
        ):
            return expr
        expr = self.remove_precomputed_replacements(expr)
        if expr not in self.precomputed_replacements:
            sym = sympy_index_symbol(f"ps{len(self.precomputed_replacements)}")
            self.precomputed_replacements[expr] = sym
            self.inv_precomputed_replacements[sym] = expr
        return self.precomputed_replacements[expr]

    def free_symbols(self) -> Set[sympy.Symbol]:
        """Symbols with hints that are not eliminated by replacements."""
        return set(self.var_to_val.keys()) - set(self.replacements.keys())
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def join_dimensions(expr: Expr) -> Expr:
    """Fuse split-dimension patterns (ModularIndexing/FloorDiv sums) in ``expr``.

    Cheap pre-filter: only sums containing ModularIndexing can match, so
    everything else is returned untouched without hitting the cache.
    """
    is_candidate = isinstance(expr, sympy.Add) and expr.has(ModularIndexing)
    if is_candidate:
        return _join_dimensions_cached(expr)
    return expr
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
@functools.lru_cache(256)
def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
    becomes
    ModularIndexing(i0, 1, 128)
    ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
    becomes i0


    This type of pattern can come from view operations
    """
    assert isinstance(expr, sympy.Add)

    # Wildcards for matching terms of the form scale * ModularIndexing(...).
    # scale excludes 0 so a vanished term never matches.
    scale = sympy.Wild("scale", exclude=[0])
    base = sympy.Wild("base")
    divisor = sympy.Wild("divisor")
    mod1 = sympy.Wild("modulus")
    mod2 = sympy.Wild("modulus2")
    # Pass 1: fuse two ModularIndexing terms of the same base into one with a
    # combined modulus, then recurse (via join_dimensions) for further fusion.
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                # the partner term must pick up where term1's modulus ends:
                # scale*mod1 * ModularIndexing(base, divisor*mod1, mod2)
                m2 = term2.match(
                    m1[scale]
                    * m1[mod1]
                    * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2)
                )
                if m2 and term1 != term2:
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale]
                        * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2])
                    )
                    return expr
    # Pass 2: fuse a ModularIndexing term with the matching FloorDiv "top"
    # term, which collapses to a single FloorDiv (or the plain base).
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                m2 = term2.match(
                    m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1])
                )
                if m2 is not None:  # in case of success we get an empty dict here
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale] * FloorDiv(m1[base], m1[divisor])
                    )
                    return expr
    return expr
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
class SimplifyIndexing(V.WrapperHandler):  # type: ignore[name-defined]
    """
    A wrapper around .virtualize.ops that uses var range information to
    simplify ModularIndexing/FloorDiv.
    """

    def __init__(self, inner, var_ranges: VarRanges):
        super().__init__(inner)
        self.name = "SimplifyIndexing"

        def _range_simplify(index: Expr) -> Expr:
            # Delegate to the graph-wide allocator with our captured ranges.
            return V.graph.sizevars.simplify_with_ranges(index, var_ranges)

        self._simplify: Callable[[Expr], Expr] = _range_simplify

    def load(self, name: str, index: sympy.Expr):
        """Load with the index simplified first."""
        simplified = self._simplify(index)
        return self._inner.load(name, simplified)

    def store(self, name, index, value, mode=None):
        """Store with the index simplified first."""
        simplified = self._simplify(index)
        return self._inner.store(name, simplified, value, mode=mode)

    def store_reduction(self, name, index, value):
        """Reduction store with the index simplified first."""
        simplified = self._simplify(index)
        return self._inner.store_reduction(name, simplified, value)

    def index_expr(self, index, dtype):
        """Index expression with the index simplified first."""
        simplified = self._simplify(index)
        return self._inner.index_expr(simplified, dtype)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py
ADDED
|
@@ -0,0 +1,1428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import contextlib
|
| 5 |
+
import dataclasses
|
| 6 |
+
import enum
|
| 7 |
+
import functools
|
| 8 |
+
import getpass
|
| 9 |
+
import inspect
|
| 10 |
+
import io
|
| 11 |
+
import itertools
|
| 12 |
+
import logging
|
| 13 |
+
import math
|
| 14 |
+
import operator
|
| 15 |
+
import os
|
| 16 |
+
import platform
|
| 17 |
+
import re
|
| 18 |
+
import shutil
|
| 19 |
+
import sys
|
| 20 |
+
import tempfile
|
| 21 |
+
import textwrap
|
| 22 |
+
import time
|
| 23 |
+
import unittest
|
| 24 |
+
from dataclasses import fields
|
| 25 |
+
from datetime import datetime
|
| 26 |
+
from io import StringIO
|
| 27 |
+
from typing import (
|
| 28 |
+
Any,
|
| 29 |
+
Callable,
|
| 30 |
+
Dict,
|
| 31 |
+
Generic,
|
| 32 |
+
Iterable,
|
| 33 |
+
List,
|
| 34 |
+
NamedTuple,
|
| 35 |
+
Optional,
|
| 36 |
+
Protocol,
|
| 37 |
+
Set,
|
| 38 |
+
TypeVar,
|
| 39 |
+
Union,
|
| 40 |
+
ValuesView,
|
| 41 |
+
)
|
| 42 |
+
from unittest import mock
|
| 43 |
+
|
| 44 |
+
import sympy
|
| 45 |
+
from typing_extensions import Concatenate, ParamSpec
|
| 46 |
+
|
| 47 |
+
import torch
|
| 48 |
+
from torch._dynamo.device_interface import get_interface_for_device
|
| 49 |
+
from torch.autograd import DeviceType
|
| 50 |
+
from torch.autograd.profiler_util import EventList
|
| 51 |
+
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
|
| 52 |
+
from . import config
|
| 53 |
+
|
| 54 |
+
log = logging.getLogger(__name__)
|
| 55 |
+
|
| 56 |
+
_T = TypeVar("_T")
|
| 57 |
+
VarRanges = Dict[sympy.Expr, sympy.Expr]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:
    """
    Returns benchmark results by examining torch profiler events.
    This could be more accurate as it doesn't count CPU side overhead.
    However, this also requires manually excluding irrelevant event, e.g.
    vectorized_elementwise_kernel which is used to fill L2 cache,
    various CUDA events, etc, so could also be fragile.

    `warmup` and `rep` are time budgets in milliseconds: the actual number of
    warmup and measured iterations is derived from a quick runtime estimate.
    """

    fn()
    torch.cuda.synchronize()
    # Buffer used to flush the L2 cache between measured runs (256e6 bytes).
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        # Benchmark
        for i in range(n_repeat):
            # we clear the L2 cache before each run
            cache.zero_()
            # record time of `fn`
            fn()
        # Record clocks
        torch.cuda.synchronize()

    log.debug("raw events")
    log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

    # Keep only CUDA-side events; "Context Sync" is bookkeeping, not kernel work.
    filtered_events = EventList(
        [
            event
            for event in p.events()
            if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
        ]
    )
    if len(filtered_events) % n_repeat != 0:
        raise RuntimeError(
            "Failed to divide all profiling events into #repeat groups. "
            "#CUDA events: %d, #repeats: %s",
            len(filtered_events),
            n_repeat,
        )
    num_event_per_group = len(filtered_events) / n_repeat
    # Drop the first event of each group (the cache.zero_() fill kernel),
    # keeping only events attributable to `fn`.
    actual_events = EventList(
        [
            event
            for i, event in enumerate(filtered_events)
            if i % num_event_per_group != 0
        ]
    )
    actual_events._build_tree()
    actual_events = actual_events.key_averages()

    log.debug("profiling time breakdown")
    log.debug(actual_events.table(row_limit=-1))

    # Average total device time per iteration; presumably event times are in
    # microseconds so /1000.0 yields milliseconds -- confirm against profiler docs.
    res = sum(event.cuda_time_total for event in actual_events) / 1000.0 / n_repeat
    log.debug("profiling results: %s ms", res)
    return res
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def do_bench(*args, **kwargs):
    """Benchmark via triton.testing.do_bench, returning the first requested
    quantile (the 0.5 quantile / median when the caller passes none)."""

    @functools.lru_cache(None)
    def load_triton():
        try:
            # NB: Lazily load triton, as importing triton is slow
            # see https://github.com/openai/triton/issues/1599
            from triton.testing import do_bench as triton_do_bench
        except ImportError as exc:
            raise NotImplementedError("requires Triton") from exc

        # triton PR https://github.com/openai/triton/pull/1513 change the
        # quantile fields name from 'percentiles' to 'quantiles'
        # and change the default value from (0.5, 0.2, 0.8) to None.
        # This may break inductor since a caller expects a tuple may get a item.
        #
        # Add a wrapper to maintain the same behavior for inductor.
        # Maybe we should have own implementation of this function?
        return triton_do_bench, (
            "quantiles"
            if inspect.signature(triton_do_bench).parameters.get("quantiles")
            is not None
            else "percentiles"
        )

    triton_do_bench, quantile_field_name = load_triton()

    # Force a tuple result so indexing [0] below is always valid.
    if quantile_field_name not in kwargs:
        kwargs[quantile_field_name] = (0.5, 0.2, 0.8)
    return triton_do_bench(*args, **kwargs)[0]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@functools.lru_cache(None)
def has_torchvision_roi_align() -> bool:
    """Whether torchvision is importable and its roi_align op is registered
    under torch.ops.torchvision. Cached after the first call."""
    try:
        from torchvision.ops import roi_align  # noqa: F401

        return roi_align is not None and hasattr(
            getattr(torch.ops, "torchvision", None), "roi_align"
        )
    except ImportError:
        return False
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def conditional_product(*args):
    """Product of the truthy arguments; falsy values (0, None, ...) are skipped."""
    return functools.reduce(operator.mul, (value for value in args if value))
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
    """Normalize a device argument to a concrete torch.device.

    None falls back to the current default device; a string is parsed with
    torch.device; a non-cpu device without an explicit index gets the device
    type's current device index filled in.
    """
    if device is None:
        return torch.tensor(0.0).device  # default device
    if isinstance(device, str):
        device = torch.device(device)
    if device.type != "cpu" and device.index is None:
        device_interface = get_interface_for_device(device.type)
        return torch.device(device.type, index=device_interface.Worker.current_device())
    return device
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def sympy_product(it):
    """Product of an iterable of sympy expressions (sympy.Integer(1) when empty)."""
    return functools.reduce(operator.mul, it, sympy.Integer(1))
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def sympy_dot(seq1, seq2):
    """Expanded dot product of two equal-length sequences of sympy expressions."""
    assert len(seq1) == len(seq2)
    return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def unique(it: Iterable[_T]) -> ValuesView[_T]:
    """Deduplicate by object identity, preserving first-seen order."""
    seen = {}
    for item in it:
        seen.setdefault(id(item), item)
    return seen.values()
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def ceildiv(
    numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
) -> Union[int, sympy.Expr]:
    """Ceiling division; stays symbolic (CeilDiv) when either operand is a
    sympy expression, otherwise computes on plain ints."""
    if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr):
        return CeilDiv(numer, denom)
    # TODO: There is a bug in a call to this function, to repro:
    # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
    # --amp --only YituTechConvBert --dynamic-shapes
    assert isinstance(numer, int) and isinstance(
        denom, int
    ), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
    # -(a // -b) is ceiling division expressed via Python's floor division.
    return -(numer // -denom)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 greater than or equal to n.

    The previous bit-smearing implementation (n |= n >> 1 ... n >> 32) only
    propagated 64 bits, so inputs above 2**64 could return a value that is
    not a power of 2. Using int.bit_length handles arbitrarily large Python
    ints. For n <= 0 this returns 0, matching the old behavior.
    """
    if n <= 0:
        # The bit-twiddling version collapsed every non-positive input to 0.
        return 0
    # (n - 1).bit_length() is the exponent of the smallest power of 2 >= n.
    return 1 << (n - 1).bit_length()
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def _type_of(key):
|
| 241 |
+
# Use the function here to get rid of dependencies on the Triton during the codegen.
|
| 242 |
+
# Refer to Triton implementation here:
|
| 243 |
+
# https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
|
| 244 |
+
# `None` is nullptr. Implicitly convert to *i8.
|
| 245 |
+
if key is None:
|
| 246 |
+
return "*i8"
|
| 247 |
+
dtype_str = str(key).split(".")[-1]
|
| 248 |
+
tys = {
|
| 249 |
+
"bool": "i1",
|
| 250 |
+
"float8e4nv": "fp8e4nv",
|
| 251 |
+
"float8e5": "fp8e5",
|
| 252 |
+
"float8e4b15": "fp8e4b15",
|
| 253 |
+
"float8e4b15x4": "fp8e4b15x4",
|
| 254 |
+
"float8_e4m3fn": "fp8e4nv",
|
| 255 |
+
"float8_e5m2": "fp8e5",
|
| 256 |
+
"float16": "fp16",
|
| 257 |
+
"bfloat16": "bf16",
|
| 258 |
+
"float32": "fp32",
|
| 259 |
+
"float64": "fp64",
|
| 260 |
+
"int8": "i8",
|
| 261 |
+
"int16": "i16",
|
| 262 |
+
"int32": "i32",
|
| 263 |
+
"int64": "i64",
|
| 264 |
+
"uint8": "u8",
|
| 265 |
+
"uint16": "u16",
|
| 266 |
+
"uint32": "u32",
|
| 267 |
+
"uint64": "u64",
|
| 268 |
+
}
|
| 269 |
+
# reinterpret can create triton type
|
| 270 |
+
for v in list(tys.values()):
|
| 271 |
+
tys[v] = v
|
| 272 |
+
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def convert_shape_to_inductor(
    lst: Iterable[Union[int, torch.SymInt]]
) -> List[sympy.Expr]:
    """
    Gets the shape and stride of a tensor. For non-symbolic tensors, this is
    trivial. But for symbolic tensors, we need to map from SymIntNode into
    sympy.Expr.
    """
    return [
        i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst
    ]
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def convert_shape_to_symint(
    lst: Iterable[Union[int, sympy.Expr]]
) -> List[Union[int, torch.SymInt]]:
    """
    Takes a list of shapes from Inductor and converts them into symints (or just
    ints if all shapes are static).
    """
    from .virtualized import V

    # Plain ints pass through, concrete sympy Integers are unwrapped to int,
    # anything still symbolic becomes a SymInt backed by the graph's shape env.
    return [
        i
        if isinstance(i, int)
        else int(i)
        if isinstance(i, sympy.Integer)
        else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
        for i in lst
    ]
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def is_view(op: torch._ops.OpOverload):
    """
    Does this op overload have aliasing
    """
    assert isinstance(op, torch._ops.OpOverload)
    # An argument with alias_info means the output may alias an input.
    return any(a.alias_info is not None for a in op._schema.arguments)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def is_pointwise_use(use):
    """Whether this FX use is (transitively) pointwise: view/getitem uses
    recurse into their own users; otherwise checks the op's pointwise tag."""
    if not use.op == "call_function":
        return False

    if not (
        isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
    ):
        return False

    if use.target is operator.getitem or is_view(use.target):
        # Aliasing/indexing nodes do no compute themselves; the answer is
        # determined by all of their downstream users.
        return all(is_pointwise_use(u) for u in use.users)

    return torch.Tag.pointwise in use.target.tags
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def gen_gm_and_inputs(target, args, kwargs):
    """Wrap a single op call in a one-node FX GraphModule.

    Tensor positional args become graph placeholders (returned separately as
    the concrete inputs); every other arg is inlined as a constant. Tensor
    kwargs are not supported (asserted below).
    """
    g = torch.fx.Graph()
    g_args = []
    a_args = []
    for n, arg in enumerate(args):
        if isinstance(arg, torch.Tensor):
            g_args.append(g.placeholder(f"arg{n}"))
            a_args.append(arg)
        else:
            g_args.append(arg)
    assert all(not isinstance(x, torch.Tensor) for x in kwargs.values())
    node = g.call_function(target, tuple(g_args), kwargs)
    if (
        len(target._schema.returns) == 1
        and str(target._schema.returns[0].type) == "Tensor"
    ):
        # Wrap single-Tensor results in a tuple so the graph output shape is uniform.
        node = (node,)
    g.output(node)

    gm = torch.fx.GraphModule({}, g)
    return gm, a_args
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def synchronize(device: str = "cuda"):
    """Block until all queued work on `device` has finished.

    No-op for cpu, or when the device type is unavailable.
    """
    if device == "cpu":
        return
    device_interface = get_interface_for_device(device)
    if device_interface.is_available():
        device_interface.synchronize()
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def timed(
    model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda"
) -> float:
    """Wall-clock seconds for `times` sequential calls of model(*example_inputs),
    synchronizing the device before starting and after each call."""
    synchronize(device)
    # Fixed seed so stochastic models time the same work each invocation.
    torch.manual_seed(1337)
    t0 = time.perf_counter()
    for _ in range(times):
        result = model(*example_inputs)
        synchronize(device)
    t1 = time.perf_counter()
    # GC the result after timing
    assert result is not None  # type: ignore[possibly-undefined]
    return t1 - t0
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def print_performance(
    fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda"
):
    """Print (and return) the median per-call runtime of fn over `repeat`
    timing rounds of `times` calls each, normalized by `baseline`."""
    timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)])
    took = torch.median(timings) / times
    print(f"{took/baseline:.6f}")
    return took
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def precompute_method(obj: Any, method: str):
    """Replace obj.method() with a new method that returns a precomputed constant."""
    cached = getattr(obj, method)()
    setattr(obj, method, lambda: cached)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def precompute_methods(obj: Any, methods: List[str]):
    """Precompute each named zero-arg method of obj, replacing it with a
    constant-returning callable (see precompute_method)."""
    for name in methods:
        value = getattr(obj, name)()
        setattr(obj, name, lambda _cached=value: _cached)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def cmp(a, b) -> int:
    """Three-way comparison: -1 if a < b, 1 if a > b, else 0."""
    if a < b:
        return -1
    if a > b:
        return 1
    return 0
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def pad_listlike(x, size):
    """Broadcast a length-1 sequence to `size` elements; any other length
    passes through unchanged."""
    if len(x) != 1:
        return x
    return type(x)([x[0]] * size)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
# Used to ensure that iterating over a set is deterministic
|
| 409 |
+
def tuple_sorted(x):
|
| 410 |
+
if len(x) == 0:
|
| 411 |
+
return []
|
| 412 |
+
|
| 413 |
+
def sort_func(elem):
|
| 414 |
+
if isinstance(elem, str):
|
| 415 |
+
return elem
|
| 416 |
+
else:
|
| 417 |
+
# We expect `elem` to be `scheduler.BaseSchedulerNode` type here,
|
| 418 |
+
# but we are not able to do isinstance assert because of circular dependency
|
| 419 |
+
return elem.get_name()
|
| 420 |
+
|
| 421 |
+
return sorted(x, key=sort_func)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
# Typing helpers for the CachedMethod protocol / cache_on_self decorator.
P = ParamSpec("P")
RV = TypeVar("RV", covariant=True)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class CachedMethod(Generic[P, RV], Protocol):
    """Typing protocol for methods wrapped by cache_on_self: callable like the
    original method, plus a clear_cache(self) that drops the memoized value."""

    # NOTE(review): declared @staticmethod yet takes `self` -- it models the
    # plain function cache_on_self attaches as an attribute; confirm intent.
    @staticmethod
    def clear_cache(self) -> None:
        ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
        ...
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
    """Memoize a zero-arg method's result on the instance itself.

    The value is stored under a per-method attribute on the instance;
    clear_cache(self) removes it so the next call recomputes.
    """
    attr_name = f"__{fn.__name__}_cache"

    @functools.wraps(fn)
    def wrapper(self):
        try:
            return getattr(self, attr_name)
        except AttributeError:
            result = fn(self)
            setattr(self, attr_name, result)
            return result

    def clear_cache(self):
        if hasattr(self, attr_name):
            delattr(self, attr_name)

    wrapper.clear_cache = clear_cache  # type: ignore[attr-defined]
    return wrapper  # type: ignore[return-value]
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def aggregate_origins(node_schedule):
    """Union the FX origin nodes of everything in a node schedule.

    Accepts either a list of scheduler nodes (nodes without a backing `.node`
    are skipped) or a single ir.ExternKernel; anything else yields an empty set.
    """
    from . import ir

    if isinstance(node_schedule, list):
        return functools.reduce(
            operator.or_,
            [
                node.node.origins
                for node in node_schedule
                if hasattr(node, "node") and node.node
            ],
            set(),
        )
    elif isinstance(node_schedule, ir.ExternKernel):
        return node_schedule.origins
    else:
        return set()
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def get_fused_kernel_name(node_schedule, descriptive_names):
    """Build a descriptive "fused_*" kernel name from the schedule's origin nodes.

    `descriptive_names` selects the naming source:
      - "original_aten": top-level aten ops (pre-decomposition)
      - "torch": top-level torch ops (post-dynamo graph)
      - "inductor_node": raw inductor FX node names
    Raises NotImplementedError for any other value.

    Fix: removed the dead no-op statement `sources = sources` that followed
    the if/elif chain.
    """
    all_origins = aggregate_origins(node_schedule)
    if descriptive_names == "original_aten":
        # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
        sources = [
            origin.meta["original_aten"]._overloadpacket.__name__
            for origin in all_origins
            if origin.op == "call_function"
            and "original_aten" in origin.meta
            and origin.meta["original_aten"] is not None
        ]
        sources = sorted(set(sources))
    elif descriptive_names == "torch":
        # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
        sources = []
        for origin in all_origins:
            if origin.op == "call_function" and "source_fn_stack" in origin.meta:
                source_fn = origin.meta["source_fn_stack"][-1]
                if isinstance(source_fn[1], str):
                    sources.append(source_fn[1])
                else:
                    sources.append(source_fn[1].__name__)
        sources = sorted(set(sources))
    elif descriptive_names == "inductor_node":
        sources = [
            origin.name for origin in all_origins if origin.op == "call_function"
        ]
    else:
        raise NotImplementedError
    return "_".join(["fused"] + sources)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def get_kernel_metadata(node_schedule, wrapper):
    """Build (summary, detail) comment strings describing which source nodes
    and original ATen ops a generated kernel came from.

    `wrapper.comment` supplies the target language's comment leader.
    """
    all_origins = aggregate_origins(node_schedule)
    inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]

    from_node_dict = collections.defaultdict(list)
    original_aten_dict = collections.defaultdict(list)
    for node in inductor_nodes:
        if "original_aten" in node.meta and node.meta["original_aten"] is not None:
            key = str(node.meta["original_aten"]._overloadpacket)
            original_aten_dict[key].append(node.name)
        if "from_node" in node.meta:
            key = node.meta["from_node"][0][0]
            from_node_dict[key].append(node.name)
    metadata = (
        f"{wrapper.comment} Source Nodes: [{', '.join(sorted(from_node_dict.keys()))}], "
        f"Original ATen: [{', '.join(sorted(original_aten_dict.keys()))}]"
    )
    # trace back to original node here
    detailed_metadata = []
    for original_node, nodes in sorted(from_node_dict.items()):
        detailed_metadata.append(
            f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
        )
    return metadata, "\n".join(detailed_metadata)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def dominated_nodes(
    initial_queue: Iterable[torch.fx.Node], skip_filter=None
) -> Set[torch.fx.Node]:
    """Returns the set of nodes whose values depend on those within initial_queue"""
    # Breadth-unordered worklist walk over the `.users` edges.
    worklist = list(initial_queue)
    dominated = set(worklist)

    while worklist:
        current = worklist.pop()
        for user in current.users:
            if skip_filter is not None and skip_filter(user):
                continue
            if user in dominated:
                continue
            dominated.add(user)
            worklist.append(user)

    return dominated
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def gather_origins(args, kwargs):
    """Collect the FX origins of any unrealized IR arguments.

    "Unrealized" means a Pointwise IR node, possibly wrapped in
    TensorBox/StorageBox layers (unwrapped recursively).
    """
    import itertools

    from . import ir

    def is_unrealized_node(n):
        if isinstance(n, ir.TensorBox):
            return is_unrealized_node(n.data)
        if isinstance(n, ir.StorageBox):
            return is_unrealized_node(n.data)
        return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)

    kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
    arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
    return set(itertools.chain(*arg_origins, *kwarg_origins))
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def sympy_str(expr: sympy.Expr) -> str:
    """
    Normal sympy str is very slow, this is a lot faster. The result are
    somewhat worse, as it doesn't do as much simplification. So don't
    use this for final codegen.
    """
    if isinstance(expr, sympy.Symbol):
        return expr.name
    if isinstance(expr, sympy.Add):
        return " + ".join(map(sympy_str, expr.args))
    if isinstance(expr, sympy.Mul):
        return " * ".join(map(sympy_str, expr.args))

    if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)):
        return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
    # Fall back to sympy's own (slow) printer for anything unrecognized.
    return str(expr)
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def sympy_index_symbol(name: str) -> sympy.Symbol:
    """
    Used to generate an integer-nonnegative symbol.
    """
    # This should never be used for creating shape/stride symbols, as those
    # should all be allocated before Inductor.
    assert name[0] != "s"
    # NOTE: shape symbols are positive (> 0), but index variables are only
    # non-negative (>= 0).
    return sympy.Symbol(name, integer=True, nonnegative=True)
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
    """
    When the passed replacement symbol v is a string, it is converted to a symbol with name v that
    have the same replaced expression integer and nonnegative properties.
    """

    def to_symbol(replaced, replacement):
        assert isinstance(replaced, sympy.Expr)
        if isinstance(replacement, str):
            # Carry over the assumptions of the symbol being replaced so the
            # substitution does not lose integer/nonnegative facts.
            return sympy.Symbol(
                replacement,
                integer=replaced.is_integer,  # type: ignore[attr-defined]
                nonnegative=replaced.is_nonnegative,  # type: ignore[attr-defined]
            )
        else:
            return replacement

    # xreplace is faster than subs, but is way more picky
    return sympy.sympify(expr).xreplace(
        {k: to_symbol(k, v) for k, v in replacements.items()}
    )
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def free_symbol_startswith(index: sympy.Expr, prefix: str):
    """Whether any free symbol in `index` has a name starting with `prefix`."""
    return any(v.name.startswith(prefix) for v in index.free_symbols)  # type: ignore[attr-defined]
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def free_symbol_has(index: sympy.Expr, pattern: str):
    """Whether any free symbol in `index` has `pattern` as a substring of its name."""
    return any(pattern in v.name for v in index.free_symbols)  # type: ignore[attr-defined]
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def is_symbolic(a: Any) -> bool:
    """True if `a` is a SymInt, or a Tensor with any symbolic size or stride."""
    return isinstance(a, torch.SymInt) or (
        isinstance(a, torch.Tensor)
        and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride()))
    )
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def any_is_symbolic(*args: Any) -> bool:
    """True if any argument is symbolic (see is_symbolic)."""
    return any(is_symbolic(a) for a in args)
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def has_incompatible_cudagraph_ops(gm):
    """Whether `gm` contains any op (or unbacked-symbol output) that makes it
    unsafe to capture with CUDA graphs."""
    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

    forbidden_set = {
        "aten._fused_moving_avg_obs_fq_helper.default",
        "aten._fused_moving_avg_obs_fq_helper_functional.default",
        "aten.multinomial.default",
        "fbgemm.dense_to_jagged.default",
        "fbgemm.jagged_to_padded_dense.default",
        "run_and_save_rng_state",
        "run_with_rng_state",
        "aten._local_scalar_dense",
        # Technically, it's not necessary to ban this, because an
        # assert_scalar with constant arguments can be validly run
        # with CUDA graphs, but the operator is also pointless with
        # constant arguments, so might as well ban
        "aten._assert_scalar",
    }
    if torch.are_deterministic_algorithms_enabled():
        # Scatter/index-put style ops have nondeterministic-unsafe paths under
        # deterministic mode, so they are also excluded.
        forbidden_set.update(
            {
                "aten._unsafe_index_put.default",
                "aten.index_put.default",
                "aten.index_put_.default",
                "aten.scatter.src",
                "aten.scatter.reduce",
                "aten.scatter.value_reduce",
                "aten.scatter_add_",
                "aten.scatter_add.default",
                "aten.scatter_reduce.two",
                "aten.scatter_reduce_.two",
                "aten.scatter_reduce.two_out",
            }
        )
    for node in gm.graph.nodes:
        if str(node.target) in forbidden_set:
            return True
        # Unbacked (data-dependent) symbols cannot be baked into a graph capture.
        if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
            return True
    return False
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def output_node(gm: torch.fx.GraphModule):
    """Get the output node from an FX graph"""
    # The output node is always the last node of an FX graph (asserted below).
    last_node = next(iter(reversed(gm.graph.nodes)))
    assert last_node.op == "output"
    return last_node
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
# Attempt to import AttrsDescriptor from Triton
try:
    from triton.compiler.compiler import AttrsDescriptor

    attrs_descriptor_available = True
    # Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor
    attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
    ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
    divisible_by_8_available = "divisible_by_8" in attr_desc_fields
except ImportError:
    attrs_descriptor_available = False

# Define `instance_descriptor` function with clear conditional handling
if attrs_descriptor_available:

    def instance_descriptor(
        divisible_by_16=None,
        equal_to_1=None,
        ids_of_folded_args=None,
        divisible_by_8=None,
    ):
        """Construct a Triton AttrsDescriptor, passing only the fields that
        this Triton version actually declares."""
        # Prepare the arguments for AttrsDescriptor
        kwargs = {
            "divisible_by_16": divisible_by_16,
            "equal_to_1": equal_to_1,
        }

        # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor
        if ids_of_folded_args_available:
            kwargs["ids_of_folded_args"] = ids_of_folded_args
        if divisible_by_8_available:
            kwargs["divisible_by_8"] = divisible_by_8

        # Instantiate AttrsDescriptor with the prepared arguments
        return AttrsDescriptor(**kwargs)

else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    instance_descriptor = collections.namedtuple(  # type: ignore[no-redef]
        "instance_descriptor",
        ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
        defaults=[tuple(), tuple(), tuple(), tuple()],
    )
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
@functools.lru_cache(None)
def cache_dir() -> str:
    """Return (and create) the inductor on-disk cache directory.

    Honors the ``TORCHINDUCTOR_CACHE_DIR`` environment variable; otherwise
    falls back to a per-user directory under the system temp dir.  The
    result is memoized for the life of the process.
    """
    path = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if path is None:
        # Strip characters that are illegal in directory names (Windows set).
        safe_user = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
        path = os.path.join(tempfile.gettempdir(), "torchinductor_" + safe_user)
    os.makedirs(path, exist_ok=True)
    return path
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
@contextlib.contextmanager
def fresh_inductor_cache(cache_entries=None):
    """
    Contextmanager that provides a clean tmp cachedir for inductor.

    Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
    generated with this cache instance.
    """
    with tempfile.TemporaryDirectory() as tmp_root:
        with mock.patch.dict(os.environ, {"TORCHINDUCTOR_CACHE_DIR": tmp_root}):
            triton_dir = os.path.join(tmp_root, "triton")
            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_dir}):
                yield
                if isinstance(cache_entries, dict):
                    # The caller-supplied dict is filled in place after the
                    # body runs, so it must start out empty.
                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
                    if os.path.exists(triton_dir):
                        sizes = {
                            name: os.path.getsize(os.path.join(triton_dir, name))
                            for name in os.listdir(triton_dir)
                            if ".lock" not in name
                        }
                        cache_entries.update(sizes)
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def argsort(seq) -> List[int]:
    """Indices that order ``seq`` ascending, with inductor's tie convention.

    Implemented as a stable *descending* sort followed by a reversal, so
    that equal elements (e.g. equal strides) keep the relative order this
    file has always produced.
    """
    order = sorted(range(len(seq)), key=seq.__getitem__, reverse=True)
    order.reverse()
    return order
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
@functools.lru_cache(8)
def get_dtype_size(dtype):
    """Bytes per element of ``dtype``, probed via a 0-d tensor (cached)."""
    return torch.empty((), dtype=dtype).element_size()
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
class LineContext(NamedTuple):
|
| 789 |
+
context: Any
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
class IndentedBuffer:
|
| 793 |
+
tabwidth = 4
|
| 794 |
+
|
| 795 |
+
def __init__(self, initial_indent=0):
|
| 796 |
+
self._lines = []
|
| 797 |
+
self._indent = initial_indent
|
| 798 |
+
|
| 799 |
+
def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]:
|
| 800 |
+
buf = StringIO()
|
| 801 |
+
p = 1
|
| 802 |
+
linemap = []
|
| 803 |
+
for line in self._lines:
|
| 804 |
+
if isinstance(line, DeferredLineBase):
|
| 805 |
+
line = line()
|
| 806 |
+
if line is None:
|
| 807 |
+
continue
|
| 808 |
+
elif isinstance(line, LineContext):
|
| 809 |
+
linemap.append((p, line.context))
|
| 810 |
+
continue
|
| 811 |
+
assert isinstance(line, str)
|
| 812 |
+
buf.write(line)
|
| 813 |
+
buf.write("\n")
|
| 814 |
+
p += 1 + line.count("\n")
|
| 815 |
+
return buf.getvalue(), linemap
|
| 816 |
+
|
| 817 |
+
def getvalue(self) -> str:
|
| 818 |
+
v, _ = self.getvaluewithlinemap()
|
| 819 |
+
return v
|
| 820 |
+
|
| 821 |
+
def getrawvalue(self) -> str:
|
| 822 |
+
buf = StringIO()
|
| 823 |
+
for line in self._lines:
|
| 824 |
+
if isinstance(line, DeferredLineBase):
|
| 825 |
+
line = line()
|
| 826 |
+
if line is None:
|
| 827 |
+
continue
|
| 828 |
+
elif isinstance(line, LineContext):
|
| 829 |
+
continue
|
| 830 |
+
assert isinstance(line, str)
|
| 831 |
+
# backslash implies line continuation
|
| 832 |
+
if line.endswith("\\"):
|
| 833 |
+
buf.write(line[:-1])
|
| 834 |
+
else:
|
| 835 |
+
buf.write(line)
|
| 836 |
+
buf.write("\n")
|
| 837 |
+
return buf.getvalue()
|
| 838 |
+
|
| 839 |
+
def clear(self):
|
| 840 |
+
self._lines.clear()
|
| 841 |
+
|
| 842 |
+
def __bool__(self):
|
| 843 |
+
return bool(self._lines)
|
| 844 |
+
|
| 845 |
+
def prefix(self):
|
| 846 |
+
return " " * (self._indent * self.tabwidth)
|
| 847 |
+
|
| 848 |
+
def newline(self):
|
| 849 |
+
self.writeline("\n")
|
| 850 |
+
|
| 851 |
+
def writeline(self, line):
|
| 852 |
+
if isinstance(line, LineContext):
|
| 853 |
+
self._lines.append(line)
|
| 854 |
+
elif isinstance(line, DeferredLineBase):
|
| 855 |
+
self._lines.append(line.with_prefix(self.prefix()))
|
| 856 |
+
elif line.strip():
|
| 857 |
+
self._lines.append(f"{self.prefix()}{line}")
|
| 858 |
+
else:
|
| 859 |
+
self._lines.append("")
|
| 860 |
+
|
| 861 |
+
def writelines(self, lines):
|
| 862 |
+
for line in lines:
|
| 863 |
+
self.writeline(line)
|
| 864 |
+
|
| 865 |
+
def indent(self, offset=1):
|
| 866 |
+
@contextlib.contextmanager
|
| 867 |
+
def ctx():
|
| 868 |
+
self._indent += offset
|
| 869 |
+
try:
|
| 870 |
+
yield
|
| 871 |
+
finally:
|
| 872 |
+
self._indent -= offset
|
| 873 |
+
|
| 874 |
+
return ctx()
|
| 875 |
+
|
| 876 |
+
def do_indent(self, offset=1):
|
| 877 |
+
self._indent += offset
|
| 878 |
+
|
| 879 |
+
def do_unindent(self, offset=1):
|
| 880 |
+
self._indent -= offset
|
| 881 |
+
|
| 882 |
+
def splice(self, other_code, strip=False):
|
| 883 |
+
if isinstance(other_code, IndentedBuffer):
|
| 884 |
+
dedent = float("inf")
|
| 885 |
+
for line in other_code._lines:
|
| 886 |
+
if not isinstance(line, LineContext) and line:
|
| 887 |
+
dedent = min(dedent, len(line) - len(line.lstrip()))
|
| 888 |
+
if math.isinf(dedent):
|
| 889 |
+
dedent = 0
|
| 890 |
+
for line in other_code._lines:
|
| 891 |
+
if isinstance(line, LineContext):
|
| 892 |
+
self._lines.append(line)
|
| 893 |
+
else:
|
| 894 |
+
IndentedBuffer.writeline(self, line[int(dedent) :])
|
| 895 |
+
else:
|
| 896 |
+
other_code = textwrap.dedent(other_code)
|
| 897 |
+
if strip:
|
| 898 |
+
other_code = other_code.lstrip()
|
| 899 |
+
if not other_code:
|
| 900 |
+
return
|
| 901 |
+
other_code = other_code.rstrip()
|
| 902 |
+
for line in other_code.split("\n"):
|
| 903 |
+
self.writeline(line)
|
| 904 |
+
|
| 905 |
+
def __repr__(self):
|
| 906 |
+
return f"{type(self)}({self.getvalue()})"
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
class DeferredLineBase:
    """A line that can be 'unwritten' at a later time"""

    def __init__(self, line):
        # Whitespace-only lines collapse to the empty string.
        if not line.strip():
            line = ""
        self.line = line

    def __call__(self) -> Optional[str]:
        """Returns either self.line or None to indicate the line has been 'unwritten'"""
        raise NotImplementedError()

    def _new_line(self, line: str) -> DeferredLineBase:
        """Returns a new deferred line with the same condition"""
        raise NotImplementedError()

    def with_prefix(self, prefix):
        """New deferred line with ``prefix`` (e.g. indentation) prepended."""
        return self._new_line(f"{prefix}{self.line}")

    def lstrip(self):
        """New deferred line with leading whitespace removed."""
        return self._new_line(self.line.lstrip())

    def __getitem__(self, index):
        """Index/slice the underlying text, producing a new deferred line."""
        return self._new_line(self.line[index])

    def __bool__(self):
        return bool(self.line)

    def __len__(self):
        return len(self.line)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
@functools.lru_cache(None)
|
| 942 |
+
def is_big_gpu(index):
|
| 943 |
+
sms = torch.cuda.get_device_properties(index).multi_processor_count
|
| 944 |
+
if sms < 80: # V100
|
| 945 |
+
log.warning("not enough SMs to use max_autotune_gemm mode")
|
| 946 |
+
return False
|
| 947 |
+
return True
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def use_max_autotune() -> bool:
    """Whether any of the max-autotune config knobs is switched on."""
    return (
        config.max_autotune
        or config.max_autotune_gemm
        or config.search_autotune_cache
    )
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
|
| 957 |
+
return (
|
| 958 |
+
use_max_autotune()
|
| 959 |
+
and layout.device.type == "cuda"
|
| 960 |
+
and layout.dtype in allowed_layout_dtypes
|
| 961 |
+
and is_big_gpu(layout.device.index or 0)
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
def _use_autotune_backend(backend: str) -> bool:
    """Check whether ``backend`` appears in the comma-separated
    ``config.max_autotune_gemm_backends`` list (case-insensitive)."""
    enabled = config.max_autotune_gemm_backends.upper().split(",")
    return backend.upper() in [entry.strip() for entry in enabled]
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
def use_triton_template(layout, *, enable_int32=False):
|
| 972 |
+
layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
| 973 |
+
if enable_int32:
|
| 974 |
+
layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
|
| 975 |
+
return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
|
| 976 |
+
"TRITON"
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
def use_cutlass_template(layout):
|
| 981 |
+
from .codegen.cuda.cutlass_utils import try_import_cutlass
|
| 982 |
+
|
| 983 |
+
# Do not use cutlass template on ROCm
|
| 984 |
+
if torch.version.hip:
|
| 985 |
+
return False
|
| 986 |
+
|
| 987 |
+
layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
| 988 |
+
res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
|
| 989 |
+
"CUTLASS"
|
| 990 |
+
)
|
| 991 |
+
|
| 992 |
+
if res:
|
| 993 |
+
if not try_import_cutlass():
|
| 994 |
+
log.warning(
|
| 995 |
+
"Failed to import CUTLASS lib. Please check whether "
|
| 996 |
+
"_inductor.config.cuda.cutlass_dir is set correctly. "
|
| 997 |
+
"Skipping CUTLASS backend for now."
|
| 998 |
+
)
|
| 999 |
+
return False
|
| 1000 |
+
return res
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
def use_aten_gemm_kernels():
    """ATEN gemm kernels stay eligible unless max-autotune is on without
    ATEN in its backend list."""
    return not use_max_autotune() or _use_autotune_backend("ATEN")
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
class DebugDirManager:
|
| 1008 |
+
counter = itertools.count(0)
|
| 1009 |
+
prev_debug_name: str
|
| 1010 |
+
|
| 1011 |
+
def __init__(self):
|
| 1012 |
+
self.id = next(DebugDirManager.counter)
|
| 1013 |
+
|
| 1014 |
+
def __enter__(self):
|
| 1015 |
+
self.prev_debug_name = torch._dynamo.config.debug_dir_root
|
| 1016 |
+
self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
|
| 1017 |
+
torch._dynamo.config.debug_dir_root = self.new_name
|
| 1018 |
+
|
| 1019 |
+
def __exit__(self, *args):
|
| 1020 |
+
shutil.rmtree(self.new_name)
|
| 1021 |
+
torch._dynamo.config.debug_dir_root = self.prev_debug_name
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
def run_and_get_code(fn, *args, **kwargs):
|
| 1025 |
+
from .graph import GraphLowering
|
| 1026 |
+
|
| 1027 |
+
compile_to_module = GraphLowering.compile_to_module
|
| 1028 |
+
source_codes = []
|
| 1029 |
+
|
| 1030 |
+
def patched_compile_to_module(self):
|
| 1031 |
+
mod = compile_to_module(self)
|
| 1032 |
+
with open(mod.__file__) as f:
|
| 1033 |
+
source_codes.append(f.read())
|
| 1034 |
+
return mod
|
| 1035 |
+
|
| 1036 |
+
# If FX code caching is enabled, a hit prevents getting the code.
|
| 1037 |
+
with config.patch({"fx_graph_cache": False}):
|
| 1038 |
+
with mock.patch.object(
|
| 1039 |
+
GraphLowering, "compile_to_module", patched_compile_to_module
|
| 1040 |
+
):
|
| 1041 |
+
torch._dynamo.reset()
|
| 1042 |
+
result = fn(*args, **kwargs)
|
| 1043 |
+
return result, source_codes
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
def run_and_get_triton_code(fn, *args, **kwargs):
|
| 1047 |
+
_, source_codes = run_and_get_code(fn, *args, **kwargs)
|
| 1048 |
+
# Can have two outputs if backwards was eagerly compiled
|
| 1049 |
+
assert (
|
| 1050 |
+
1 <= len(source_codes) <= 2
|
| 1051 |
+
), f"expected one or two code outputs got {len(source_codes)}"
|
| 1052 |
+
return source_codes[0]
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
@contextlib.contextmanager
|
| 1056 |
+
def override_lowering(aten_op, override_fn):
|
| 1057 |
+
"""
|
| 1058 |
+
Override the lowering of aten_op with override_fn.
|
| 1059 |
+
The first argument of override_fn is the original lowering fn.
|
| 1060 |
+
"""
|
| 1061 |
+
from torch._inductor import lowering
|
| 1062 |
+
|
| 1063 |
+
orig_fn = lowering.lowerings[aten_op]
|
| 1064 |
+
try:
|
| 1065 |
+
lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn)
|
| 1066 |
+
yield
|
| 1067 |
+
finally:
|
| 1068 |
+
lowering.lowerings[aten_op] = orig_fn
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
def add_scheduler_init_hook(pre_fn, post_fn=None):
|
| 1072 |
+
"""
|
| 1073 |
+
Add hook functions to be called at the beginning and end of Scheduler.__init__.
|
| 1074 |
+
Used for unit tests.
|
| 1075 |
+
"""
|
| 1076 |
+
from torch._inductor.scheduler import Scheduler
|
| 1077 |
+
|
| 1078 |
+
orig_fn = Scheduler.__init__
|
| 1079 |
+
|
| 1080 |
+
def wrapper(scheduler, nodes):
|
| 1081 |
+
pre_fn(scheduler, nodes)
|
| 1082 |
+
out = orig_fn(scheduler, nodes)
|
| 1083 |
+
if post_fn:
|
| 1084 |
+
post_fn(scheduler, nodes)
|
| 1085 |
+
return out
|
| 1086 |
+
|
| 1087 |
+
return unittest.mock.patch.object(Scheduler, "__init__", wrapper)
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
def developer_warning(msg):
|
| 1091 |
+
"""
|
| 1092 |
+
Warnings that will be actionable for PyTorch developers, but not
|
| 1093 |
+
end users. Allows us to easily disable them in stable releases but
|
| 1094 |
+
keep them on for nightly builds.
|
| 1095 |
+
"""
|
| 1096 |
+
if config.developer_warnings:
|
| 1097 |
+
log.warning(msg)
|
| 1098 |
+
else:
|
| 1099 |
+
log.info(msg)
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
+
def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
    """
    Return the total number of bytes the arguments of tensor type takes.

    For in/out args, tensor sizes are counted twice: once for reading and
    once for writing.

    The first num_in_out_args arguments are in out tensors.
    """
    total = 0
    for idx, candidate in enumerate(args):
        if not isinstance(candidate, torch.Tensor):
            continue
        # In/out tensors are both read and written, so count them twice.
        multiplier = 2 if idx < num_in_out_args else 1
        total += candidate.numel() * candidate.element_size() * multiplier
    return total
|
| 1116 |
+
|
| 1117 |
+
|
| 1118 |
+
def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
    """Format a one-line kernel bandwidth report; slow kernels render in red."""
    text = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
    # Heuristic: a non-trivial kernel (> 0.012 ms) under 650 GB/s is "slow".
    is_slow = ms > 0.012 and gb_per_s < 650
    if color and is_slow:
        return red_text(text)
    return text
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
def get_benchmark_name():
    """
    An experimental API used only when config.benchmark_kernel is true.

    The benchmark name is only available at codegen time. So we can not
    directly call it in benchmark_all_kernels which is run after codegen.

    The function assumes the argument after --only is the benchmark name.
    It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc
    scripts, this function may return None.

    There are 2 flavors of --only argument we need handle:
    1. --only model_name
    2. --only=model_name
    """
    # Flavor 1: "--only model_name" (name must not look like another flag).
    try:
        pos = sys.argv.index("--only")
        nxt = pos + 1
        if nxt < len(sys.argv) and sys.argv[nxt] and not sys.argv[nxt].startswith("-"):
            return sys.argv[nxt]
    except ValueError:
        pass

    # Flavor 2: "--only=model_name".
    for token in sys.argv:
        if token.startswith("--only="):
            return token[len("--only=") :]
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
def is_ones(items):
    """True iff every element of ``items`` equals 1 (vacuously true if empty)."""
    return all(elem == 1 for elem in items)
|
| 1157 |
+
|
| 1158 |
+
|
| 1159 |
+
def is_zeros(items):
    """True iff every element of ``items`` equals 0 (vacuously true if empty)."""
    return all(elem == 0 for elem in items)
|
| 1161 |
+
|
| 1162 |
+
|
| 1163 |
+
def is_cpu_device(inputs):
    """True when every tensor in ``inputs`` lives on the CPU; non-tensor
    entries are ignored."""
    cpu = torch.device("cpu")
    return all(t.device == cpu for t in inputs if isinstance(t, torch.Tensor))
|
| 1169 |
+
|
| 1170 |
+
|
| 1171 |
+
def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
    """Map a sympy expression to the torch dtype inductor uses for it:
    int64 for integer-valued expressions, float64 otherwise."""
    assert isinstance(
        val, sympy.Expr
    ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
    if val.is_integer:  # type: ignore[attr-defined]
        return torch.int64
    return torch.float64
|
| 1179 |
+
|
| 1180 |
+
|
| 1181 |
+
@contextlib.contextmanager
def maybe_profile(should_profile, *args, **kwargs):
    """Yield a ``torch.profiler`` session when requested; otherwise act as a
    no-op context yielding ``None``."""
    if not should_profile:
        yield
        return
    with torch.profiler.profile(*args, **kwargs) as prof:
        yield prof
|
| 1188 |
+
|
| 1189 |
+
|
| 1190 |
+
def triton_config_to_hashable(cfg):
    """
    Convert triton config to a tuple that can uniquely identify it. We can use
    the return value as a dictionary key.
    """
    parts = sorted(cfg.kwargs.items())
    parts += [("num_warps", cfg.num_warps), ("num_stages", cfg.num_stages)]
    return tuple(parts)
|
| 1199 |
+
|
| 1200 |
+
|
| 1201 |
+
def parallel_num_threads():
    """Thread count for CPP parallel kernels; a non-positive config value
    means "use torch's default"."""
    requested = config.cpp.threads
    if requested >= 1:
        return requested
    return torch.get_num_threads()
|
| 1206 |
+
|
| 1207 |
+
|
| 1208 |
+
HAS_COLORAMA = True
|
| 1209 |
+
try:
|
| 1210 |
+
import colorama
|
| 1211 |
+
except ImportError:
|
| 1212 |
+
HAS_COLORAMA = False
|
| 1213 |
+
|
| 1214 |
+
|
| 1215 |
+
def _color_text(msg, color):
|
| 1216 |
+
if not HAS_COLORAMA:
|
| 1217 |
+
return msg
|
| 1218 |
+
|
| 1219 |
+
return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET
|
| 1220 |
+
|
| 1221 |
+
|
| 1222 |
+
def green_text(msg):
|
| 1223 |
+
return _color_text(msg, "green")
|
| 1224 |
+
|
| 1225 |
+
|
| 1226 |
+
def yellow_text(msg):
|
| 1227 |
+
return _color_text(msg, "yellow")
|
| 1228 |
+
|
| 1229 |
+
|
| 1230 |
+
def red_text(msg):
|
| 1231 |
+
return _color_text(msg, "red")
|
| 1232 |
+
|
| 1233 |
+
|
| 1234 |
+
def blue_text(msg):
|
| 1235 |
+
return _color_text(msg, "blue")
|
| 1236 |
+
|
| 1237 |
+
|
| 1238 |
+
@functools.lru_cache(None)
|
| 1239 |
+
def get_device_tflops(dtype):
|
| 1240 |
+
from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
|
| 1241 |
+
|
| 1242 |
+
assert dtype in (torch.float16, torch.bfloat16, torch.float32)
|
| 1243 |
+
|
| 1244 |
+
if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
|
| 1245 |
+
# Triton API change in https://github.com/openai/triton/pull/2293
|
| 1246 |
+
from torch._utils_internal import max_clock_rate
|
| 1247 |
+
|
| 1248 |
+
sm_clock = max_clock_rate()
|
| 1249 |
+
if dtype in (torch.float16, torch.bfloat16):
|
| 1250 |
+
return get_max_tensorcore_tflops(dtype, sm_clock)
|
| 1251 |
+
|
| 1252 |
+
if torch.backends.cuda.matmul.allow_tf32:
|
| 1253 |
+
return get_max_tensorcore_tflops(torch.float32, sm_clock)
|
| 1254 |
+
else:
|
| 1255 |
+
return get_max_simd_tflops(torch.float32, sm_clock)
|
| 1256 |
+
else:
|
| 1257 |
+
if dtype in (torch.float16, torch.bfloat16):
|
| 1258 |
+
return get_max_tensorcore_tflops(dtype)
|
| 1259 |
+
|
| 1260 |
+
if torch.backends.cuda.matmul.allow_tf32:
|
| 1261 |
+
return get_max_tensorcore_tflops(torch.float32)
|
| 1262 |
+
else:
|
| 1263 |
+
return get_max_simd_tflops(torch.float32)
|
| 1264 |
+
|
| 1265 |
+
|
| 1266 |
+
@functools.lru_cache(None)
|
| 1267 |
+
def get_gpu_dram_gbps():
|
| 1268 |
+
from triton.testing import get_dram_gbps
|
| 1269 |
+
|
| 1270 |
+
return get_dram_gbps()
|
| 1271 |
+
|
| 1272 |
+
|
| 1273 |
+
def is_welford_reduction(reduction_type):
    """Welford (mean/m2/weight) reductions are named with a ``welford`` prefix."""
    return reduction_type.startswith("welford")
|
| 1275 |
+
|
| 1276 |
+
|
| 1277 |
+
def reduction_num_outputs(reduction_type):
|
| 1278 |
+
return 3 if is_welford_reduction(reduction_type) else 1
|
| 1279 |
+
|
| 1280 |
+
|
| 1281 |
+
def get_max_y_grid():
    """Upper bound used for the launch grid's Y dimension (2**16 - 1)."""
    return 65535
|
| 1283 |
+
|
| 1284 |
+
|
| 1285 |
+
def is_linux() -> bool:
    """Whether the current host OS is Linux."""
    return platform.system() == "Linux"
|
| 1287 |
+
|
| 1288 |
+
|
| 1289 |
+
def has_free_symbols(itr: Iterable[Any]):
|
| 1290 |
+
return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr)
|
| 1291 |
+
|
| 1292 |
+
|
| 1293 |
+
def is_dynamic(*args):
|
| 1294 |
+
from . import ir
|
| 1295 |
+
|
| 1296 |
+
for t in args:
|
| 1297 |
+
if isinstance(t, ir.TensorBox):
|
| 1298 |
+
if has_free_symbols(t.data.get_size()) or (
|
| 1299 |
+
hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride())
|
| 1300 |
+
):
|
| 1301 |
+
return True
|
| 1302 |
+
elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)):
|
| 1303 |
+
assert hasattr(t, "get_size") and hasattr(t, "get_stride")
|
| 1304 |
+
if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()):
|
| 1305 |
+
return True
|
| 1306 |
+
elif not isinstance(t, ir.IRNode):
|
| 1307 |
+
continue
|
| 1308 |
+
else:
|
| 1309 |
+
raise TypeError(f"unexpected type for is_dynamic {type(t)}")
|
| 1310 |
+
|
| 1311 |
+
return False
|
| 1312 |
+
|
| 1313 |
+
|
| 1314 |
+
# Placeholder strings used in triton codegen.
class Placeholder(enum.Enum):
    # Stands in for the final name of a triton kernel, e.g. the "triton_"
    # in "def triton_".
    KERNEL_NAME = "KERNEL_NAME"

    # Stands in for the kernel's descriptive name; when
    # unique_kernel_names = False this is later replaced with a more
    # informative string.
    DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
|
| 1323 |
+
|
| 1324 |
+
|
| 1325 |
+
def pass_execution_and_save(func, gm, msg):
|
| 1326 |
+
from .pattern_matcher import stable_topological_sort
|
| 1327 |
+
|
| 1328 |
+
with tempfile.NamedTemporaryFile(
|
| 1329 |
+
mode="w",
|
| 1330 |
+
encoding="utf-8",
|
| 1331 |
+
delete=False,
|
| 1332 |
+
) as f:
|
| 1333 |
+
before_io = io.StringIO()
|
| 1334 |
+
after_io = io.StringIO()
|
| 1335 |
+
print(f"Before:\n{gm.graph}", file=f)
|
| 1336 |
+
print(gm.graph, file=before_io)
|
| 1337 |
+
start_time = datetime.now()
|
| 1338 |
+
func(gm.graph)
|
| 1339 |
+
time_elapsed = datetime.now() - start_time
|
| 1340 |
+
# recompile graph
|
| 1341 |
+
stable_topological_sort(gm.graph)
|
| 1342 |
+
gm.graph.lint()
|
| 1343 |
+
gm.recompile()
|
| 1344 |
+
|
| 1345 |
+
print(f"After:\n{gm.graph}", file=f)
|
| 1346 |
+
print(gm.graph, file=after_io)
|
| 1347 |
+
t = before_io.getvalue() == after_io.getvalue()
|
| 1348 |
+
log.info(
|
| 1349 |
+
"%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
|
| 1350 |
+
msg,
|
| 1351 |
+
f.name,
|
| 1352 |
+
t,
|
| 1353 |
+
time_elapsed,
|
| 1354 |
+
)
|
| 1355 |
+
|
| 1356 |
+
|
| 1357 |
+
def is_collective(node):
|
| 1358 |
+
from . import ir
|
| 1359 |
+
|
| 1360 |
+
return isinstance(node, ir.CollectiveKernel) or type(node) == ir._CollectiveKernel
|
| 1361 |
+
|
| 1362 |
+
|
| 1363 |
+
def is_wait(node):
|
| 1364 |
+
from . import ir
|
| 1365 |
+
|
| 1366 |
+
return isinstance(node, ir.Wait) or type(node) == ir._WaitKernel
|
| 1367 |
+
|
| 1368 |
+
|
| 1369 |
+
def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int):
|
| 1370 |
+
"Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
|
| 1371 |
+
num_rng_seed_offset_inputs = (
|
| 1372 |
+
2 if torch._functorch.config.functionalize_rng_ops else 0
|
| 1373 |
+
)
|
| 1374 |
+
return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
|
| 1375 |
+
|
| 1376 |
+
|
| 1377 |
+
def count_tangents(fx_g: torch.fx.GraphModule):
    """
    Infers which inputs are static for a backwards graph
    """

    def is_static_input(node):
        # Placeholders that are not gradients ("tangents") or RNG state
        # are saved values with fixed addresses (params/buffers).
        name = node.name
        return (
            "tangents" not in name
            and "bwd_seed" not in name
            and "bwd_base_offset" not in name
        )

    static_idxs = []
    placeholder_count = 0
    for node in fx_g.graph.nodes:
        if node.op != "placeholder":
            continue
        if is_static_input(node):
            static_idxs.append(placeholder_count)
        placeholder_count += 1

    # Static inputs must form a contiguous prefix of the placeholder list.
    assert static_idxs == list(range(len(static_idxs)))
    return len(static_idxs)
|
| 1399 |
+
|
| 1400 |
+
|
| 1401 |
+
@dataclasses.dataclass
class BoxedBool:
    """A mutable boolean box, so a callee can flip a flag the caller holds."""

    value: bool

    def __bool__(self):
        return self.value

    @staticmethod
    def disable(obj):
        """Set a BoxedBool to False in place (returning the box); any other
        value simply yields False."""
        if isinstance(obj, BoxedBool):
            obj.value = False
            return obj
        return False
|
| 1414 |
+
|
| 1415 |
+
|
| 1416 |
+
@contextlib.contextmanager
|
| 1417 |
+
def collect_defined_kernels(kernel_list):
|
| 1418 |
+
from .codegen.wrapper import WrapperCodeGen
|
| 1419 |
+
|
| 1420 |
+
orig_define_kernel = WrapperCodeGen.define_kernel
|
| 1421 |
+
|
| 1422 |
+
def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs):
|
| 1423 |
+
nonlocal kernel_list
|
| 1424 |
+
kernel_list.append(kernel_code)
|
| 1425 |
+
return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs)
|
| 1426 |
+
|
| 1427 |
+
with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel):
|
| 1428 |
+
yield
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import tempfile
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.autograd import DeviceType
|
| 7 |
+
from .utils import create_bandwidth_info_str, do_bench, get_num_bytes
|
| 8 |
+
|
| 9 |
+
_kernel_category_choices = [
|
| 10 |
+
"foreach",
|
| 11 |
+
"persistent_reduction",
|
| 12 |
+
"pointwise",
|
| 13 |
+
"reduction",
|
| 14 |
+
"split_scan",
|
| 15 |
+
"template",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_kernel_category_by_source_code(src_code):
    """
    Similar to get_kernel_category but use the source code. Call this API
    if we have not compile the src_code to module yet.
    """
    matches = [
        cat
        for cat in _kernel_category_choices
        if f"@triton_heuristics.{cat}" in src_code
    ]
    # Exactly one decorator match identifies the category unambiguously.
    return matches[0] if len(matches) == 1 else "unknown"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_kernel_category(kernel_mod):
    """
    Given the module defining a triton kernel, return the category of the kernel.
    Category can be one of:
    - pointwise
    - reduction
    - persistent_reduction

    Currently we simply decide the category depending on what decorator is imported
    by the kernel.
    """
    matches = [cat for cat in _kernel_category_choices if cat in kernel_mod.__dict__]
    # A single imported decorator name pins down the category.
    return matches[0] if len(matches) == 1 else "unknown"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_triton_kernel(mod):
|
| 52 |
+
from torch._inductor.triton_heuristics import CachingAutotuner
|
| 53 |
+
|
| 54 |
+
cand_list = [
|
| 55 |
+
v
|
| 56 |
+
for k, v in mod.__dict__.items()
|
| 57 |
+
if k.startswith("triton_") and isinstance(v, CachingAutotuner)
|
| 58 |
+
]
|
| 59 |
+
assert len(cand_list) == 1
|
| 60 |
+
return cand_list[0]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
    """
    An experimental API used only when config.benchmark_kernel is true.

    Run the kernel benchmarks for all the kernels cached in PyCodeCache.
    Used in the compiled modules.

    Put this method here rather than codegen it for convenience since its implementation
    does not change based on different graph modules being compiled.
    """
    from torch._inductor.codecache import PyCodeCache

    # Count the modules that actually exposed benchmark hooks, so we can warn
    # if none were found.
    nfound = 0
    for kernel_key, kernel_mod in PyCodeCache.cache.items():
        # Only modules codegen'ed with benchmark support define get_args/call.
        if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
            continue

        triton_kernel = get_triton_kernel(kernel_mod)
        kernel_category = get_kernel_category(kernel_mod)
        args = kernel_mod.get_args()
        # in_out_ptr arguments are both read and written; the count is passed
        # to get_num_bytes, which presumably weighs them accordingly in the
        # traffic estimate -- confirm in get_num_bytes.
        num_in_out_ptrs = len(
            [
                arg_name
                for arg_name in triton_kernel.fn.arg_names
                if arg_name.startswith("in_out_ptr")
            ]
        )
        # Prefer the byte count precomputed by inductor; fall back to
        # estimating it from the runtime arguments.
        num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None)
        if num_gb is None:
            num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9

        def get_info_str(ms, n_regs, n_spills, shared, prefix=""):
            # Format one result line: latency + bandwidth, plus
            # register/spill/shared-memory details when all are available.
            if not any(x is None for x in [n_regs, n_spills, shared]):
                kernel_detail_str = (
                    f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem"
                )
            else:
                kernel_detail_str = ""

            # ms is milliseconds, so ms/1e3 is seconds -> GB/s.
            gb_per_s = num_gb / (ms / 1e3)
            return create_bandwidth_info_str(
                ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str
            )

        # Fixed-width label: benchmark name, 3-letter category code, and a
        # truncated cache key identifying the kernel.
        kernel_desc = (
            f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}"
        )
        if benchmark_all_configs:
            # Benchmark every autotuning config; one line per config.
            assert hasattr(kernel_mod, "benchmark_all_configs")
            bench_result = kernel_mod.benchmark_all_configs(args)
            print(kernel_desc)
            for launcher, ms in bench_result.items():
                print(
                    f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
                )
        else:
            # Benchmark only the config the autotuner already selected.
            ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)
            assert (
                len(triton_kernel.launchers) == 1
            ), "Autotuner should have selected the best config"
            launcher = triton_kernel.launchers[0]
            print(
                get_info_str(
                    ms,
                    launcher.n_regs,
                    launcher.n_spills,
                    launcher.shared,
                    prefix=f"{kernel_desc} ",
                )
            )

        nfound += 1
    if nfound == 0:
        print(
            "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel being True"
        )
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@dataclasses.dataclass
class ProfileEvent:
    """One aggregated profiler event, bucketed by kernel category."""

    # Kernel category label, e.g. "triton_pointwise" or "unknown".
    category: str
    # Profiler event key (the kernel/event name).
    key: str
    # Self CUDA time in milliseconds, averaged across all benchmark runs.
    self_cuda_time_ms: float
    # the benchmark is run multiple times and we average the count across all the
    # runs. It should be an integer but define a float just in case.
    count: float
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def parse_profile_event_list(benchmark_name, event_list, wall_time_ms, nruns):
    """
    Summarize profiler events collected for a compiled-module benchmark.

    Buckets GPU-side events into kernel categories (triton pointwise /
    reduction / persistent reduction / unknown), prints a per-category table
    of self CUDA time, and emits a single "Output for tabulate:" line that
    can be gathered across benchmarks.

    Args:
        benchmark_name: label used in the tabulate output line.
        event_list: profiler key_averages() result; legacy-profiler events
            are rejected.
        wall_time_ms: total wall time of the benchmark, in milliseconds.
        nruns: number of benchmark runs the profile covers; times and counts
            are averaged over it.
    """
    def get_self_cuda_time(ev):
        """
        ev.self_cuda_time_total is in microsecond. Convert to millisecond.
        """
        return ev.self_cuda_time_total / 1000 / nruns

    # category name -> list[ProfileEvent]
    all_events = defaultdict(list)

    def add_event(ev, category):
        # Wrap the raw profiler event into our per-run-averaged record.
        profile_ev = ProfileEvent(
            category=category,
            key=ev.key,
            self_cuda_time_ms=get_self_cuda_time(ev),
            count=ev.count / nruns,  # average across all runs
        )
        all_events[category].append(profile_ev)

    for ev in event_list:
        assert not ev.is_legacy, "Don't support the legacy profiler"
        if ev.device_type == DeviceType.CPU:
            # ignore the event on CPU side
            continue

        # Classify triton kernels by their name prefix; everything else is
        # "unknown".
        category = "unknown"
        if ev.key.startswith("triton_"):
            if ev.key.startswith("triton_poi"):
                category = "triton_pointwise"
            elif ev.key.startswith("triton_red"):
                category = "triton_reduction"
            elif ev.key.startswith("triton_per"):
                category = "triton_persistent_reduction"
            else:
                category = "triton_unknown"

        add_event(ev, category)

    def report_category(category, profile_events):
        # Print one table for a category, sorted by self CUDA time
        # (descending), and return the category's total time in ms.
        from tabulate import tabulate

        profile_events.sort(key=lambda ev: ev.self_cuda_time_ms, reverse=True)

        rows = []
        total_time = 0.0
        print(f"\n == {category} category kernels == ")
        for ev in profile_events:
            total_time += ev.self_cuda_time_ms
            percent = f"{ev.self_cuda_time_ms / wall_time_ms * 100:.2f}%"
            rows.append([ev.key[:120], ev.self_cuda_time_ms, ev.count, percent])
        rows.append(
            ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"]
        )
        print(
            tabulate(
                rows, headers=["Kernel", "Self CUDA TIME (ms)", "Count", "Percent"]
            )
        )
        return total_time

    def report():
        # Fixed category order; also serves as the column order of the
        # tabulate output line below.
        category_list = [
            "triton_pointwise",
            "triton_reduction",
            "triton_persistent_reduction",
            "triton_unknown",
            "unknown",
        ]
        # Every bucketed event must belong to a known category.
        assert set(all_events.keys()).issubset(
            set(category_list)
        ), f"{list(all_events.keys())}"

        per_category_wall_time = {}
        total_cuda_ms = 0.0
        for category in category_list:
            if category in all_events:
                _time = report_category(category, all_events[category])
                per_category_wall_time[category] = _time
                total_cuda_ms += _time

        gpu_busy_percent = f"{total_cuda_ms / wall_time_ms * 100:.2f}%"
        print(f"\nPercent of time when GPU is busy: {gpu_busy_percent}")
        print(f"Total wall time {wall_time_ms:.3f} ms")

        # output such a line so we can gather such line from all compiled modules from all
        # benchmarks and tabulate it!
        # Columns: benchmark_name, pointwise_percent, reduction_percent, persistent_reduction_percent,
        # unknown_category_percent, GPU_busy_percent, wall_time_ms
        tabulate_line = f"Output for tabulate: {benchmark_name}"
        for category in category_list:
            percent = (
                f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%"
            )
            tabulate_line += f", {percent}"
        tabulate_line += f", {gpu_busy_percent}, {wall_time_ms:.3f}ms"

        print(tabulate_line)

    report()
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def compiled_module_main(benchmark_name, benchmark_compiled_module_fn):
    """
    This is the function called in __main__ block of a compiled module.
    """
    import argparse

    # CLI exposed by every compiled module when run standalone.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--benchmark-kernels",
        "-k",
        action="store_true",
        help="Whether to benchmark each individual kernels",
    )
    cli.add_argument(
        "--benchmark-all-configs",
        "-c",
        action="store_true",
        help="Whether to benchmark each individual config for a kernel",
    )
    cli.add_argument(
        "--profile",
        "-p",
        action="store_true",
        help="Whether to profile the compiled module",
    )
    opts = cli.parse_args()

    if opts.benchmark_kernels:
        # Per-kernel benchmarking mode: delegate and stop here.
        benchmark_all_kernels(benchmark_name, opts.benchmark_all_configs)
        return

    # Whole-module benchmarking: `times` iterations per round, `repeat` rounds.
    times = 10
    repeat = 10
    wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000

    if not opts.profile:
        return

    # Re-run the module under the PyTorch profiler to collect kernel events.
    with torch.profiler.profile(record_shapes=True) as p:
        benchmark_compiled_module_fn(times=times, repeat=repeat)

    path = f"{tempfile.gettempdir()}/compiled_module_profile.json"
    p.export_chrome_trace(path)
    print(f"Profiling result for a compiled module of benchmark {benchmark_name}:")
    print(f"Chrome trace for the profile is written to {path}")
    event_list = p.key_averages(group_by_input_shape=True)
    print(event_list.table(sort_by="self_cuda_time_total", row_limit=10))
    # The profiled section executed times * repeat runs in total.
    parse_profile_event_list(
        benchmark_name, event_list, wall_time_ms, times * repeat
    )
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/DimVector.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Dimname.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DynamicLibrary.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Utils.h>
|
| 4 |
+
#include <c10/macros/Export.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
|
| 7 |
+
namespace c10 {

// Error subtype thrown for dynamic-library failures (loading a shared
// library or resolving a symbol). Inherits all of c10::Error's constructors,
// so it behaves like Error but remains catchable as a distinct type.
class DynamicLibraryError : public Error {
  using Error::Error;
};

} // namespace c10
|
| 14 |
+
|
| 15 |
+
namespace at {

// RAII wrapper around a dynamically loaded shared library. Copy is
// disallowed (AT_DISALLOW_COPY_AND_ASSIGN); the destructor releases the
// handle. Implementation lives in the corresponding .cpp.
struct DynamicLibrary {
  AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);

  // Load the library named `name`. `alt_name` looks like an alternative
  // name to try -- TODO confirm fallback semantics against the .cpp.
  // `leak_handle` presumably keeps the OS handle open past destruction;
  // verify in the destructor's implementation.
  TORCH_API DynamicLibrary(
      const char* name,
      const char* alt_name = nullptr,
      bool leak_handle = false);

  // Resolve the symbol `name` from the loaded library.
  TORCH_API void* sym(const char* name);

  TORCH_API ~DynamicLibrary();

 private:
  // Stored copy of the constructor's leak_handle flag.
  bool leak_handle;
  // Opaque OS library handle; nullptr until/unless a library is loaded.
  void* handle = nullptr;
};

} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Formatting.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_meta_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_add_relu_meta_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_addmm_activation_meta_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_amp_update_scale_meta_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_coalesced_meta_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_meta_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_meta_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_ctc_loss_meta_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_efficientzerotensor_meta_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_meta_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_fused_sdp_choice_meta_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_index_put_impl_meta_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_linalg_det_meta_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_linalg_eigh_meta_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_linalg_slogdet_meta_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_linalg_solve_ex_meta_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_linalg_svd_meta_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_log_softmax_meta_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_log_softmax_backward_data_meta_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_mkldnn_transpose_meta_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_reshape_alias_meta_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_resize_output_meta_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_softmax_meta_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_softmax_backward_data_meta_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_meta_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_meta_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_meta_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_meta_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_meta_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_upsample_nearest_exact1d_meta_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_meta_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_upsample_nearest_exact2d_meta_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_meta_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_upsample_nearest_exact3d_meta_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h>
|
| 54 |
+
#include <ATen/ops/acos_meta_dispatch.h>
|
| 55 |
+
#include <ATen/ops/acosh_meta_dispatch.h>
|
| 56 |
+
#include <ATen/ops/adaptive_max_pool2d_meta_dispatch.h>
|
| 57 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h>
|
| 58 |
+
#include <ATen/ops/adaptive_max_pool3d_meta_dispatch.h>
|
| 59 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h>
|
| 60 |
+
#include <ATen/ops/add_meta_dispatch.h>
|
| 61 |
+
#include <ATen/ops/addbmm_meta_dispatch.h>
|
| 62 |
+
#include <ATen/ops/addcdiv_meta_dispatch.h>
|
| 63 |
+
#include <ATen/ops/addcmul_meta_dispatch.h>
|
| 64 |
+
#include <ATen/ops/addmm_meta_dispatch.h>
|
| 65 |
+
#include <ATen/ops/addmv_meta_dispatch.h>
|
| 66 |
+
#include <ATen/ops/all_meta_dispatch.h>
|
| 67 |
+
#include <ATen/ops/amax_meta_dispatch.h>
|
| 68 |
+
#include <ATen/ops/amin_meta_dispatch.h>
|
| 69 |
+
#include <ATen/ops/aminmax_meta_dispatch.h>
|
| 70 |
+
#include <ATen/ops/any_meta_dispatch.h>
|
| 71 |
+
#include <ATen/ops/arange_meta_dispatch.h>
|
| 72 |
+
#include <ATen/ops/argmax_meta_dispatch.h>
|
| 73 |
+
#include <ATen/ops/argmin_meta_dispatch.h>
|
| 74 |
+
#include <ATen/ops/as_strided_meta_dispatch.h>
|
| 75 |
+
#include <ATen/ops/asin_meta_dispatch.h>
|
| 76 |
+
#include <ATen/ops/asinh_meta_dispatch.h>
|
| 77 |
+
#include <ATen/ops/atan_meta_dispatch.h>
|
| 78 |
+
#include <ATen/ops/atan2_meta_dispatch.h>
|
| 79 |
+
#include <ATen/ops/atanh_meta_dispatch.h>
|
| 80 |
+
#include <ATen/ops/avg_pool2d_meta_dispatch.h>
|
| 81 |
+
#include <ATen/ops/avg_pool2d_backward_meta_dispatch.h>
|
| 82 |
+
#include <ATen/ops/avg_pool3d_meta_dispatch.h>
|
| 83 |
+
#include <ATen/ops/avg_pool3d_backward_meta_dispatch.h>
|
| 84 |
+
#include <ATen/ops/baddbmm_meta_dispatch.h>
|
| 85 |
+
#include <ATen/ops/bernoulli_meta_dispatch.h>
|
| 86 |
+
#include <ATen/ops/bitwise_and_meta_dispatch.h>
|
| 87 |
+
#include <ATen/ops/bitwise_left_shift_meta_dispatch.h>
|
| 88 |
+
#include <ATen/ops/bitwise_not_meta_dispatch.h>
|
| 89 |
+
#include <ATen/ops/bitwise_or_meta_dispatch.h>
|
| 90 |
+
#include <ATen/ops/bitwise_right_shift_meta_dispatch.h>
|
| 91 |
+
#include <ATen/ops/bitwise_xor_meta_dispatch.h>
|
| 92 |
+
#include <ATen/ops/bmm_meta_dispatch.h>
|
| 93 |
+
#include <ATen/ops/cat_meta_dispatch.h>
|
| 94 |
+
#include <ATen/ops/cauchy_meta_dispatch.h>
|
| 95 |
+
#include <ATen/ops/ceil_meta_dispatch.h>
|
| 96 |
+
#include <ATen/ops/clamp_meta_dispatch.h>
|
| 97 |
+
#include <ATen/ops/clamp_max_meta_dispatch.h>
|
| 98 |
+
#include <ATen/ops/clamp_min_meta_dispatch.h>
|
| 99 |
+
#include <ATen/ops/copy_sparse_to_sparse_meta_dispatch.h>
|
| 100 |
+
#include <ATen/ops/copysign_meta_dispatch.h>
|
| 101 |
+
#include <ATen/ops/cos_meta_dispatch.h>
|
| 102 |
+
#include <ATen/ops/cosh_meta_dispatch.h>
|
| 103 |
+
#include <ATen/ops/cumprod_meta_dispatch.h>
|
| 104 |
+
#include <ATen/ops/cumsum_meta_dispatch.h>
|
| 105 |
+
#include <ATen/ops/digamma_meta_dispatch.h>
|
| 106 |
+
#include <ATen/ops/div_meta_dispatch.h>
|
| 107 |
+
#include <ATen/ops/elu_meta_dispatch.h>
|
| 108 |
+
#include <ATen/ops/elu_backward_meta_dispatch.h>
|
| 109 |
+
#include <ATen/ops/embedding_renorm_meta_dispatch.h>
|
| 110 |
+
#include <ATen/ops/empty_meta_dispatch.h>
|
| 111 |
+
#include <ATen/ops/empty_strided_meta_dispatch.h>
|
| 112 |
+
#include <ATen/ops/eq_meta_dispatch.h>
|
| 113 |
+
#include <ATen/ops/erf_meta_dispatch.h>
|
| 114 |
+
#include <ATen/ops/erfc_meta_dispatch.h>
|
| 115 |
+
#include <ATen/ops/erfinv_meta_dispatch.h>
|
| 116 |
+
#include <ATen/ops/exp_meta_dispatch.h>
|
| 117 |
+
#include <ATen/ops/exp2_meta_dispatch.h>
|
| 118 |
+
#include <ATen/ops/expm1_meta_dispatch.h>
|
| 119 |
+
#include <ATen/ops/exponential_meta_dispatch.h>
|
| 120 |
+
#include <ATen/ops/eye_meta_dispatch.h>
|
| 121 |
+
#include <ATen/ops/fill_meta_dispatch.h>
|
| 122 |
+
#include <ATen/ops/floor_meta_dispatch.h>
|
| 123 |
+
#include <ATen/ops/floor_divide_meta_dispatch.h>
|
| 124 |
+
#include <ATen/ops/fmax_meta_dispatch.h>
|
| 125 |
+
#include <ATen/ops/fmin_meta_dispatch.h>
|
| 126 |
+
#include <ATen/ops/fmod_meta_dispatch.h>
|
| 127 |
+
#include <ATen/ops/frac_meta_dispatch.h>
|
| 128 |
+
#include <ATen/ops/fractional_max_pool2d_meta_dispatch.h>
|
| 129 |
+
#include <ATen/ops/fractional_max_pool2d_backward_meta_dispatch.h>
|
| 130 |
+
#include <ATen/ops/fractional_max_pool3d_meta_dispatch.h>
|
| 131 |
+
#include <ATen/ops/gather_meta_dispatch.h>
|
| 132 |
+
#include <ATen/ops/gcd_meta_dispatch.h>
|
| 133 |
+
#include <ATen/ops/ge_meta_dispatch.h>
|
| 134 |
+
#include <ATen/ops/gelu_meta_dispatch.h>
|
| 135 |
+
#include <ATen/ops/gelu_backward_meta_dispatch.h>
|
| 136 |
+
#include <ATen/ops/geometric_meta_dispatch.h>
|
| 137 |
+
#include <ATen/ops/glu_meta_dispatch.h>
|
| 138 |
+
#include <ATen/ops/gt_meta_dispatch.h>
|
| 139 |
+
#include <ATen/ops/hardshrink_meta_dispatch.h>
|
| 140 |
+
#include <ATen/ops/hardshrink_backward_meta_dispatch.h>
|
| 141 |
+
#include <ATen/ops/hardsigmoid_meta_dispatch.h>
|
| 142 |
+
#include <ATen/ops/hardsigmoid_backward_meta_dispatch.h>
|
| 143 |
+
#include <ATen/ops/hardswish_meta_dispatch.h>
|
| 144 |
+
#include <ATen/ops/hardtanh_meta_dispatch.h>
|
| 145 |
+
#include <ATen/ops/heaviside_meta_dispatch.h>
|
| 146 |
+
#include <ATen/ops/hypot_meta_dispatch.h>
|
| 147 |
+
#include <ATen/ops/i0_meta_dispatch.h>
|
| 148 |
+
#include <ATen/ops/igamma_meta_dispatch.h>
|
| 149 |
+
#include <ATen/ops/igammac_meta_dispatch.h>
|
| 150 |
+
#include <ATen/ops/index_meta_dispatch.h>
|
| 151 |
+
#include <ATen/ops/index_add_meta_dispatch.h>
|
| 152 |
+
#include <ATen/ops/index_copy_meta_dispatch.h>
|
| 153 |
+
#include <ATen/ops/index_fill_meta_dispatch.h>
|
| 154 |
+
#include <ATen/ops/index_reduce_meta_dispatch.h>
|
| 155 |
+
#include <ATen/ops/isin_meta_dispatch.h>
|
| 156 |
+
#include <ATen/ops/isneginf_meta_dispatch.h>
|
| 157 |
+
#include <ATen/ops/isposinf_meta_dispatch.h>
|
| 158 |
+
#include <ATen/ops/lcm_meta_dispatch.h>
|
| 159 |
+
#include <ATen/ops/le_meta_dispatch.h>
|
| 160 |
+
#include <ATen/ops/leaky_relu_meta_dispatch.h>
|
| 161 |
+
#include <ATen/ops/leaky_relu_backward_meta_dispatch.h>
|
| 162 |
+
#include <ATen/ops/lerp_meta_dispatch.h>
|
| 163 |
+
#include <ATen/ops/lgamma_meta_dispatch.h>
|
| 164 |
+
#include <ATen/ops/linalg_cholesky_ex_meta_dispatch.h>
|
| 165 |
+
#include <ATen/ops/linalg_cross_meta_dispatch.h>
|
| 166 |
+
#include <ATen/ops/linalg_inv_ex_meta_dispatch.h>
|
| 167 |
+
#include <ATen/ops/linalg_ldl_factor_ex_meta_dispatch.h>
|
| 168 |
+
#include <ATen/ops/linalg_ldl_solve_meta_dispatch.h>
|
| 169 |
+
#include <ATen/ops/linalg_lu_meta_dispatch.h>
|
| 170 |
+
#include <ATen/ops/linalg_lu_factor_ex_meta_dispatch.h>
|
| 171 |
+
#include <ATen/ops/linalg_lu_solve_meta_dispatch.h>
|
| 172 |
+
#include <ATen/ops/linalg_qr_meta_dispatch.h>
|
| 173 |
+
#include <ATen/ops/linalg_vector_norm_meta_dispatch.h>
|
| 174 |
+
#include <ATen/ops/linspace_meta_dispatch.h>
|
| 175 |
+
#include <ATen/ops/log_meta_dispatch.h>
|
| 176 |
+
#include <ATen/ops/log10_meta_dispatch.h>
|
| 177 |
+
#include <ATen/ops/log1p_meta_dispatch.h>
|
| 178 |
+
#include <ATen/ops/log2_meta_dispatch.h>
|
| 179 |
+
#include <ATen/ops/log_normal_meta_dispatch.h>
|
| 180 |
+
#include <ATen/ops/logaddexp_meta_dispatch.h>
|
| 181 |
+
#include <ATen/ops/logaddexp2_meta_dispatch.h>
|
| 182 |
+
#include <ATen/ops/logit_meta_dispatch.h>
|
| 183 |
+
#include <ATen/ops/logit_backward_meta_dispatch.h>
|
| 184 |
+
#include <ATen/ops/logspace_meta_dispatch.h>
|
| 185 |
+
#include <ATen/ops/lshift_meta_dispatch.h>
|
| 186 |
+
#include <ATen/ops/lt_meta_dispatch.h>
|
| 187 |
+
#include <ATen/ops/lu_unpack_meta_dispatch.h>
|
| 188 |
+
#include <ATen/ops/masked_fill_meta_dispatch.h>
|
| 189 |
+
#include <ATen/ops/masked_scatter_meta_dispatch.h>
|
| 190 |
+
#include <ATen/ops/max_meta_dispatch.h>
|
| 191 |
+
#include <ATen/ops/max_pool2d_with_indices_meta_dispatch.h>
|
| 192 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_meta_dispatch.h>
|
| 193 |
+
#include <ATen/ops/maximum_meta_dispatch.h>
|
| 194 |
+
#include <ATen/ops/mean_meta_dispatch.h>
|
| 195 |
+
#include <ATen/ops/min_meta_dispatch.h>
|
| 196 |
+
#include <ATen/ops/minimum_meta_dispatch.h>
|
| 197 |
+
#include <ATen/ops/mish_meta_dispatch.h>
|
| 198 |
+
#include <ATen/ops/mm_meta_dispatch.h>
|
| 199 |
+
#include <ATen/ops/mse_loss_meta_dispatch.h>
|
| 200 |
+
#include <ATen/ops/mul_meta_dispatch.h>
|
| 201 |
+
#include <ATen/ops/ne_meta_dispatch.h>
|
| 202 |
+
#include <ATen/ops/neg_meta_dispatch.h>
|
| 203 |
+
#include <ATen/ops/nextafter_meta_dispatch.h>
|
| 204 |
+
#include <ATen/ops/nll_loss_backward_meta_dispatch.h>
|
| 205 |
+
#include <ATen/ops/nll_loss_forward_meta_dispatch.h>
|
| 206 |
+
#include <ATen/ops/norm_meta_dispatch.h>
|
| 207 |
+
#include <ATen/ops/normal_meta_dispatch.h>
|
| 208 |
+
#include <ATen/ops/polygamma_meta_dispatch.h>
|
| 209 |
+
#include <ATen/ops/pow_meta_dispatch.h>
|
| 210 |
+
#include <ATen/ops/prod_meta_dispatch.h>
|
| 211 |
+
#include <ATen/ops/put_meta_dispatch.h>
|
| 212 |
+
#include <ATen/ops/random_meta_dispatch.h>
|
| 213 |
+
#include <ATen/ops/range_meta_dispatch.h>
|
| 214 |
+
#include <ATen/ops/reciprocal_meta_dispatch.h>
|
| 215 |
+
#include <ATen/ops/reflection_pad1d_meta_dispatch.h>
|
| 216 |
+
#include <ATen/ops/reflection_pad1d_backward_meta_dispatch.h>
|
| 217 |
+
#include <ATen/ops/reflection_pad3d_meta_dispatch.h>
|
| 218 |
+
#include <ATen/ops/reflection_pad3d_backward_meta_dispatch.h>
|
| 219 |
+
#include <ATen/ops/relu_meta_dispatch.h>
|
| 220 |
+
#include <ATen/ops/remainder_meta_dispatch.h>
|
| 221 |
+
#include <ATen/ops/renorm_meta_dispatch.h>
|
| 222 |
+
#include <ATen/ops/replication_pad1d_meta_dispatch.h>
|
| 223 |
+
#include <ATen/ops/replication_pad1d_backward_meta_dispatch.h>
|
| 224 |
+
#include <ATen/ops/replication_pad2d_meta_dispatch.h>
|
| 225 |
+
#include <ATen/ops/replication_pad3d_meta_dispatch.h>
|
| 226 |
+
#include <ATen/ops/resize_meta_dispatch.h>
|
| 227 |
+
#include <ATen/ops/resize_as_sparse_meta_dispatch.h>
|
| 228 |
+
#include <ATen/ops/round_meta_dispatch.h>
|
| 229 |
+
#include <ATen/ops/rrelu_with_noise_meta_dispatch.h>
|
| 230 |
+
#include <ATen/ops/rshift_meta_dispatch.h>
|
| 231 |
+
#include <ATen/ops/rsqrt_meta_dispatch.h>
|
| 232 |
+
#include <ATen/ops/scatter_meta_dispatch.h>
|
| 233 |
+
#include <ATen/ops/scatter_add_meta_dispatch.h>
|
| 234 |
+
#include <ATen/ops/scatter_reduce_meta_dispatch.h>
|
| 235 |
+
#include <ATen/ops/set_meta_dispatch.h>
|
| 236 |
+
#include <ATen/ops/sgn_meta_dispatch.h>
|
| 237 |
+
#include <ATen/ops/sigmoid_meta_dispatch.h>
|
| 238 |
+
#include <ATen/ops/sigmoid_backward_meta_dispatch.h>
|
| 239 |
+
#include <ATen/ops/sign_meta_dispatch.h>
|
| 240 |
+
#include <ATen/ops/signbit_meta_dispatch.h>
|
| 241 |
+
#include <ATen/ops/silu_meta_dispatch.h>
|
| 242 |
+
#include <ATen/ops/silu_backward_meta_dispatch.h>
|
| 243 |
+
#include <ATen/ops/sin_meta_dispatch.h>
|
| 244 |
+
#include <ATen/ops/sinc_meta_dispatch.h>
|
| 245 |
+
#include <ATen/ops/sinh_meta_dispatch.h>
|
| 246 |
+
#include <ATen/ops/slow_conv_transpose2d_meta_dispatch.h>
|
| 247 |
+
#include <ATen/ops/smooth_l1_loss_meta_dispatch.h>
|
| 248 |
+
#include <ATen/ops/softplus_meta_dispatch.h>
|
| 249 |
+
#include <ATen/ops/softplus_backward_meta_dispatch.h>
|
| 250 |
+
#include <ATen/ops/softshrink_meta_dispatch.h>
|
| 251 |
+
#include <ATen/ops/softshrink_backward_meta_dispatch.h>
|
| 252 |
+
#include <ATen/ops/sort_meta_dispatch.h>
|
| 253 |
+
#include <ATen/ops/sparse_resize_meta_dispatch.h>
|
| 254 |
+
#include <ATen/ops/sparse_resize_and_clear_meta_dispatch.h>
|
| 255 |
+
#include <ATen/ops/special_airy_ai_meta_dispatch.h>
|
| 256 |
+
#include <ATen/ops/special_bessel_j0_meta_dispatch.h>
|
| 257 |
+
#include <ATen/ops/special_bessel_j1_meta_dispatch.h>
|
| 258 |
+
#include <ATen/ops/special_bessel_y0_meta_dispatch.h>
|
| 259 |
+
#include <ATen/ops/special_bessel_y1_meta_dispatch.h>
|
| 260 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_meta_dispatch.h>
|
| 261 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_meta_dispatch.h>
|
| 262 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_meta_dispatch.h>
|
| 263 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_meta_dispatch.h>
|
| 264 |
+
#include <ATen/ops/special_entr_meta_dispatch.h>
|
| 265 |
+
#include <ATen/ops/special_erfcx_meta_dispatch.h>
|
| 266 |
+
#include <ATen/ops/special_hermite_polynomial_h_meta_dispatch.h>
|
| 267 |
+
#include <ATen/ops/special_hermite_polynomial_he_meta_dispatch.h>
|
| 268 |
+
#include <ATen/ops/special_i0e_meta_dispatch.h>
|
| 269 |
+
#include <ATen/ops/special_i1_meta_dispatch.h>
|
| 270 |
+
#include <ATen/ops/special_i1e_meta_dispatch.h>
|
| 271 |
+
#include <ATen/ops/special_laguerre_polynomial_l_meta_dispatch.h>
|
| 272 |
+
#include <ATen/ops/special_legendre_polynomial_p_meta_dispatch.h>
|
| 273 |
+
#include <ATen/ops/special_log_ndtr_meta_dispatch.h>
|
| 274 |
+
#include <ATen/ops/special_modified_bessel_i0_meta_dispatch.h>
|
| 275 |
+
#include <ATen/ops/special_modified_bessel_i1_meta_dispatch.h>
|
| 276 |
+
#include <ATen/ops/special_modified_bessel_k0_meta_dispatch.h>
|
| 277 |
+
#include <ATen/ops/special_modified_bessel_k1_meta_dispatch.h>
|
| 278 |
+
#include <ATen/ops/special_ndtri_meta_dispatch.h>
|
| 279 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_meta_dispatch.h>
|
| 280 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_meta_dispatch.h>
|
| 281 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta_dispatch.h>
|
| 282 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta_dispatch.h>
|
| 283 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta_dispatch.h>
|
| 284 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta_dispatch.h>
|
| 285 |
+
#include <ATen/ops/special_spherical_bessel_j0_meta_dispatch.h>
|
| 286 |
+
#include <ATen/ops/special_xlog1py_meta_dispatch.h>
|
| 287 |
+
#include <ATen/ops/special_zeta_meta_dispatch.h>
|
| 288 |
+
#include <ATen/ops/sqrt_meta_dispatch.h>
|
| 289 |
+
#include <ATen/ops/sub_meta_dispatch.h>
|
| 290 |
+
#include <ATen/ops/sum_meta_dispatch.h>
|
| 291 |
+
#include <ATen/ops/tan_meta_dispatch.h>
|
| 292 |
+
#include <ATen/ops/tanh_meta_dispatch.h>
|
| 293 |
+
#include <ATen/ops/tanh_backward_meta_dispatch.h>
|
| 294 |
+
#include <ATen/ops/threshold_meta_dispatch.h>
|
| 295 |
+
#include <ATen/ops/threshold_backward_meta_dispatch.h>
|
| 296 |
+
#include <ATen/ops/topk_meta_dispatch.h>
|
| 297 |
+
#include <ATen/ops/triangular_solve_meta_dispatch.h>
|
| 298 |
+
#include <ATen/ops/tril_meta_dispatch.h>
|
| 299 |
+
#include <ATen/ops/triu_meta_dispatch.h>
|
| 300 |
+
#include <ATen/ops/trunc_meta_dispatch.h>
|
| 301 |
+
#include <ATen/ops/unfold_meta_dispatch.h>
|
| 302 |
+
#include <ATen/ops/uniform_meta_dispatch.h>
|
| 303 |
+
#include <ATen/ops/upsample_bicubic2d_meta_dispatch.h>
|
| 304 |
+
#include <ATen/ops/upsample_bicubic2d_backward_meta_dispatch.h>
|
| 305 |
+
#include <ATen/ops/upsample_bilinear2d_meta_dispatch.h>
|
| 306 |
+
#include <ATen/ops/upsample_bilinear2d_backward_meta_dispatch.h>
|
| 307 |
+
#include <ATen/ops/upsample_linear1d_meta_dispatch.h>
|
| 308 |
+
#include <ATen/ops/upsample_linear1d_backward_meta_dispatch.h>
|
| 309 |
+
#include <ATen/ops/upsample_nearest1d_meta_dispatch.h>
|
| 310 |
+
#include <ATen/ops/upsample_nearest1d_backward_meta_dispatch.h>
|
| 311 |
+
#include <ATen/ops/upsample_nearest2d_meta_dispatch.h>
|
| 312 |
+
#include <ATen/ops/upsample_nearest2d_backward_meta_dispatch.h>
|
| 313 |
+
#include <ATen/ops/upsample_nearest3d_meta_dispatch.h>
|
| 314 |
+
#include <ATen/ops/upsample_nearest3d_backward_meta_dispatch.h>
|
| 315 |
+
#include <ATen/ops/upsample_trilinear3d_meta_dispatch.h>
|
| 316 |
+
#include <ATen/ops/upsample_trilinear3d_backward_meta_dispatch.h>
|
| 317 |
+
#include <ATen/ops/view_meta_dispatch.h>
|
| 318 |
+
#include <ATen/ops/view_as_complex_meta_dispatch.h>
|
| 319 |
+
#include <ATen/ops/view_as_real_meta_dispatch.h>
|
| 320 |
+
#include <ATen/ops/xlogy_meta_dispatch.h>
|
| 321 |
+
#include <ATen/ops/zero_meta_dispatch.h>
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/List.h>
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
| 5 |
+
|
| 6 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 7 |
+
#include <ATen/Functions.h>
|
| 8 |
+
#else
|
| 9 |
+
#include <ATen/ops/equal.h>
|
| 10 |
+
#endif
|
| 11 |
+
|
| 12 |
+
namespace at {
|
| 13 |
+
|
| 14 |
+
// Note [Tensor-subclass-like Tensors]
|
| 15 |
+
// Tensor-subclass-like is defined as:
|
| 16 |
+
// - a Tensor subclass (via __torch_dispatch__ in Python or extending
|
| 17 |
+
// TensorImpl in C++)
|
| 18 |
+
// - anything else that shares the same perils as Tensor subclasses.
|
| 19 |
+
// For example, many Tensor subclasses do not have storage and meta Tensors
|
| 20 |
+
// do not have storage either, so meta Tensors belong here.
|
| 21 |
+
//
|
| 22 |
+
// We should ensure that PyTorch internals supports Tensor-subclass-like
|
| 23 |
+
// objects. In particular, Tensor-subclass-like objects struggle with two
|
| 24 |
+
// classes of operations that are problematic for Tensor subclasses:
|
| 25 |
+
// 1. Because some Tensor subclasses do not have storage, .item() or
|
| 26 |
+
// .data_ptr() calls are not good.
|
| 27 |
+
// 2. Certain in-place operations can eliminate the typing of the Tensor
|
| 28 |
+
// subclass. For example:
|
| 29 |
+
// >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
|
| 30 |
+
// If input is a Tensor subclass, then the above ends up either erroring out
|
| 31 |
+
// or returning a regular non-Tensor-subclass Tensor!
|
| 32 |
+
|
| 33 |
+
// Dispatch keys carried by functorch's wrapper tensors (grad transforms,
// vmap) and by the functionalization pass.
constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
    {DispatchKey::FuncTorchGradWrapper,
     DispatchKey::FuncTorchBatched,
     DispatchKey::Functionalize});

// Full set of dispatch keys whose tensors must be treated as
// "Tensor-subclass-like" (see Note [Tensor-subclass-like Tensors] above):
// functorch wrappers, legacy vmap, sparse layouts, Python subclasses, and
// meta tensors (matched via the Meta backend *bit*, not a combined key).
constexpr auto kTensorSubclassLike =
    kFunctorchWrappedTensors |
    DispatchKeySet(
        {// WARNING: DO NOT put combined backend component + functionality keys
         // here, you will incorrectly always match on the functionality key
         // no matter the backend component
         DispatchKey::Batched,
         DispatchKey::Sparse,
         DispatchKey::SparseCsr,
         DispatchKey::Python}) |
    DispatchKeySet(BackendComponent::MetaBit);
|
| 49 |
+
|
| 50 |
+
inline bool isTensorSubclassLike(const Tensor& tensor) {
|
| 51 |
+
if (c10::impl::dispatch_mode_enabled())
|
| 52 |
+
return true;
|
| 53 |
+
auto key_set = tensor.unsafeGetTensorImpl()->key_set();
|
| 54 |
+
return !(key_set & kTensorSubclassLike).empty();
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
// Returns true if any tensor in `tensors` is subclass-like, or if a
// torch_dispatch mode is active (which makes every tensor interceptable).
inline bool areAnyTensorSubclassLike(TensorList tensors) {
  if (c10::impl::dispatch_mode_enabled()) {
    return true;
  }
  for (const auto& t : tensors) {
    if (isTensorSubclassLike(t)) {
      return true;
    }
  }
  return false;
}
|
| 62 |
+
|
| 63 |
+
inline bool areAnyOptionalTensorSubclassLike(
|
| 64 |
+
const c10::List<c10::optional<Tensor>>& tensors) {
|
| 65 |
+
if (c10::impl::dispatch_mode_enabled())
|
| 66 |
+
return true;
|
| 67 |
+
return std::any_of(
|
| 68 |
+
tensors.begin(), tensors.end(), [](const optional<Tensor>& opt_tensor) {
|
| 69 |
+
return (
|
| 70 |
+
opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
|
| 71 |
+
});
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
// Helper function to deal testing truthfulness of a scalar tensor
|
| 75 |
+
// in a Composite Compliant manner.
|
| 76 |
+
// NOTE: This function expects a scalar tensor of boolean dtype.
|
| 77 |
+
// Eg.
|
| 78 |
+
// Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
|
| 79 |
+
// Composite Compliant Patter : is_salar_tensor_true((t == 0).all())
|
| 80 |
+
inline bool is_scalar_tensor_true(const Tensor& t) {
|
| 81 |
+
TORCH_INTERNAL_ASSERT(t.dim() == 0)
|
| 82 |
+
TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool)
|
| 83 |
+
return at::equal(t, t.new_ones({}, t.options()));
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
#include <cuda_runtime.h>
|
| 5 |
+
#include <cuda_fp16.h>
|
| 6 |
+
|
| 7 |
+
#include <c10/macros/Export.h>
|
| 8 |
+
|
| 9 |
+
// Use TORCH_CUDA_CPP_API or TORCH_CUDA_CU_API for exports from this folder
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContext.h
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/CUDAContextLight.h>
|
| 4 |
+
|
| 5 |
+
// Preserved for BC, as many files depend on these includes
|
| 6 |
+
#include <ATen/Context.h>
|
| 7 |
+
#include <c10/cuda/CUDAStream.h>
|
| 8 |
+
#include <c10/util/Logging.h>
|
| 9 |
+
#include <ATen/cuda/Exceptions.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/ScalarType.h>
|
| 4 |
+
|
| 5 |
+
#include <cuda.h>
|
| 6 |
+
#include <library_types.h>
|
| 7 |
+
|
| 8 |
+
namespace at::cuda {
|
| 9 |
+
|
| 10 |
+
// Maps a C++ scalar type to the matching cudaDataType enumerator.
// Only the explicit specializations below are supported; instantiating the
// primary template produces an internal assert at runtime.
template <typename scalar_t>
cudaDataType getCudaDataType() {
  TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to cudaDataType.")
}

// Floating-point and complex mappings (available on both CUDA and ROCm).
template<> inline cudaDataType getCudaDataType<at::Half>() {
  return CUDA_R_16F;
}
template<> inline cudaDataType getCudaDataType<float>() {
  return CUDA_R_32F;
}
template<> inline cudaDataType getCudaDataType<double>() {
  return CUDA_R_64F;
}
template<> inline cudaDataType getCudaDataType<c10::complex<c10::Half>>() {
  return CUDA_C_16F;
}
template<> inline cudaDataType getCudaDataType<c10::complex<float>>() {
  return CUDA_C_32F;
}
template<> inline cudaDataType getCudaDataType<c10::complex<double>>() {
  return CUDA_C_64F;
}

// HIP doesn't define integral types
#ifndef USE_ROCM
template<> inline cudaDataType getCudaDataType<uint8_t>() {
  return CUDA_R_8U;
}
template<> inline cudaDataType getCudaDataType<int8_t>() {
  return CUDA_R_8I;
}
template<> inline cudaDataType getCudaDataType<int>() {
  return CUDA_R_32I;
}
#endif

// 16/64-bit integers and bfloat16 are likewise CUDA-only here.
#if !defined(USE_ROCM)
template<> inline cudaDataType getCudaDataType<int16_t>() {
  return CUDA_R_16I;
}
template<> inline cudaDataType getCudaDataType<int64_t>() {
  return CUDA_R_64I;
}
template<> inline cudaDataType getCudaDataType<at::BFloat16>() {
  return CUDA_R_16BF;
}
#endif
|
| 58 |
+
|
| 59 |
+
// Runtime counterpart of getCudaDataType<T>(): maps a c10::ScalarType value
// to the corresponding cudaDataType (or hipDataType under ROCm).
// Unsupported scalar types trip an internal assert.
inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) {
  switch (scalar_type) {
// HIP doesn't define integral types
#ifndef USE_ROCM
    case c10::ScalarType::Byte:
      return CUDA_R_8U;
    case c10::ScalarType::Char:
      return CUDA_R_8I;
    case c10::ScalarType::Int:
      return CUDA_R_32I;
#endif
    case c10::ScalarType::Half:
      return CUDA_R_16F;
    case c10::ScalarType::Float:
      return CUDA_R_32F;
    case c10::ScalarType::Double:
      return CUDA_R_64F;
    case c10::ScalarType::ComplexHalf:
      return CUDA_C_16F;
    case c10::ScalarType::ComplexFloat:
      return CUDA_C_32F;
    case c10::ScalarType::ComplexDouble:
      return CUDA_C_64F;
#if !defined(USE_ROCM)
    case c10::ScalarType::Short:
      return CUDA_R_16I;
    case c10::ScalarType::Long:
      return CUDA_R_64I;
    case c10::ScalarType::BFloat16:
      return CUDA_R_16BF;
// FP8 enumerators only exist from CUDA 11.8 onwards.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
    case c10::ScalarType::Float8_e4m3fn:
      return CUDA_R_8F_E4M3;
    case c10::ScalarType::Float8_e5m2:
      return CUDA_R_8F_E5M2;
#endif
#else // USE_ROCM
    case c10::ScalarType::BFloat16:
      return CUDA_R_16BF;
#if defined(HIP_NEW_TYPE_ENUMS)
    case c10::ScalarType::Float8_e4m3fnuz:
      return HIP_R_8F_E4M3_FNUZ;
    case c10::ScalarType::Float8_e5m2fnuz:
      return HIP_R_8F_E5M2_FNUZ;
#else
    // Older HIP headers lack the FNUZ enumerators; the raw values 1000/1001
    // match what newer headers define for them.
    case c10::ScalarType::Float8_e4m3fnuz:
      return static_cast<hipDataType>(1000);
    case c10::ScalarType::Float8_e5m2fnuz:
      return static_cast<hipDataType>(1001);
#endif
#endif
    default:
      TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
  }
}
|
| 114 |
+
|
| 115 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Exceptions.h
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cublas_v2.h>
|
| 4 |
+
#include <cusparse.h>
|
| 5 |
+
#include <c10/macros/Export.h>
|
| 6 |
+
|
| 7 |
+
#ifdef CUDART_VERSION
|
| 8 |
+
#include <cusolver_common.h>
|
| 9 |
+
#endif
|
| 10 |
+
|
| 11 |
+
#include <ATen/Context.h>
|
| 12 |
+
#include <c10/util/Exception.h>
|
| 13 |
+
#include <c10/cuda/CUDAException.h>
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
namespace c10 {

// Error subclass thrown for cuDNN failures, so callers can catch cuDNN
// problems separately from generic c10::Error.
class CuDNNError : public c10::Error {
  using Error::Error;
};

} // namespace c10
|
| 23 |
+
|
| 24 |
+
// Evaluates a cuDNN frontend expression and converts a bad status into a
// c10::CuDNNError carrying the frontend's own message.
#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...)                              \
  do {                                                                  \
    auto error_object = EXPR;                                           \
    if (!error_object.is_good()) {                                      \
      TORCH_CHECK_WITH(CuDNNError, false,                               \
            "cuDNN Frontend error: ", error_object.get_message());      \
    }                                                                   \
  } while (0)                                                           \

// Variant of AT_CUDNN_CHECK that starts the extra message on a new line so
// callers can append shape diagnostics.
#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)

// See Note [CHECK macro]
// Checks a cudnnStatus_t; CUDNN_STATUS_NOT_SUPPORTED gets an extra hint,
// since non-contiguous inputs are a common cause of that status.
#define AT_CUDNN_CHECK(EXPR, ...)                                             \
  do {                                                                        \
    cudnnStatus_t status = EXPR;                                              \
    if (status != CUDNN_STATUS_SUCCESS) {                                     \
      if (status == CUDNN_STATUS_NOT_SUPPORTED) {                             \
        TORCH_CHECK_WITH(CuDNNError, false,                                   \
            "cuDNN error: ",                                                  \
            cudnnGetErrorString(status),                                      \
            ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \
      } else {                                                                \
        TORCH_CHECK_WITH(CuDNNError, false,                                   \
            "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__);     \
      }                                                                       \
    }                                                                         \
  } while (0)
|
| 51 |
+
|
| 52 |
+
namespace at::cuda::blas {
// Translates a cublasStatus_t into its enumerator name for error messages.
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
} // namespace at::cuda::blas

// Checks a cublasStatus_t and raises on failure, naming the failing
// expression via stringization.
#define TORCH_CUDABLAS_CHECK(EXPR)                              \
  do {                                                          \
    cublasStatus_t __err = EXPR;                                \
    TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS,                 \
                "CUDA error: ",                                 \
                at::cuda::blas::_cublasGetErrorEnum(__err),     \
                " when calling `" #EXPR "`");                   \
  } while (0)

// Translates a cusparseStatus_t into a human-readable message.
const char *cusparseGetErrorString(cusparseStatus_t status);

// Checks a cusparseStatus_t and raises on failure, naming the failing
// expression via stringization.
#define TORCH_CUDASPARSE_CHECK(EXPR)                            \
  do {                                                          \
    cusparseStatus_t __err = EXPR;                              \
    TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS,               \
                "CUDA error: ",                                 \
                cusparseGetErrorString(__err),                  \
                " when calling `" #EXPR "`");                   \
  } while (0)
|
| 75 |
+
|
| 76 |
+
// cusolver related headers are only supported on cuda now
#ifdef CUDART_VERSION

namespace at::cuda::solver {
// Translates a cusolverStatus_t into a human-readable message.
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);

// Appended to every cusolver error so users know they can select a different
// linear-algebra backend instead of failing outright.
constexpr const char* _cusolver_backend_suggestion = \
  "If you keep seeing this error, you may use " \
  "`torch.backends.cuda.preferred_linalg_library()` to try " \
  "linear algebra operators with other supported backends. " \
  "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";

} // namespace at::cuda::solver

// When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
#define TORCH_CUSOLVER_CHECK(EXPR)                                      \
  do {                                                                  \
    cusolverStatus_t __err = EXPR;                                      \
    if ((CUDA_VERSION < 11500 &&                                        \
         __err == CUSOLVER_STATUS_EXECUTION_FAILED) ||                  \
        (CUDA_VERSION >= 11500 &&                                       \
         __err == CUSOLVER_STATUS_INVALID_VALUE)) {                     \
      TORCH_CHECK_LINALG(                                               \
          false,                                                        \
          "cusolver error: ",                                           \
          at::cuda::solver::cusolverGetErrorMessage(__err),             \
          ", when calling `" #EXPR "`",                                 \
          ". This error may appear if the input matrix contains NaN. ", \
          at::cuda::solver::_cusolver_backend_suggestion);              \
    } else {                                                            \
      TORCH_CHECK(                                                      \
          __err == CUSOLVER_STATUS_SUCCESS,                             \
          "cusolver error: ",                                           \
          at::cuda::solver::cusolverGetErrorMessage(__err),             \
          ", when calling `" #EXPR "`. ",                               \
          at::cuda::solver::_cusolver_backend_suggestion);              \
    }                                                                   \
  } while (0)

#else
// Without the CUDA runtime (e.g. ROCm builds) the check is a pass-through
// of the expression.
#define TORCH_CUSOLVER_CHECK(EXPR) EXPR
#endif
|
| 119 |
+
|
| 120 |
+
// Alias of c10's runtime-API check; kept in ATen for backward compatibility.
#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)

// For CUDA Driver API
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#if !defined(USE_ROCM)

#define AT_CUDA_DRIVER_CHECK(EXPR)                                                                              \
  do {                                                                                                          \
    CUresult __err = EXPR;                                                                                      \
    if (__err != CUDA_SUCCESS) {                                                                                \
      const char* err_str;                                                                                      \
      CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \
      if (get_error_str_err != CUDA_SUCCESS) {                                                                  \
        AT_ERROR("CUDA driver error: unknown error");                                                           \
      } else {                                                                                                  \
        AT_ERROR("CUDA driver error: ", err_str);                                                               \
      }                                                                                                         \
    }                                                                                                           \
  } while (0)

#else

// ROCm build: no NVRTC stub is available, so report the raw numeric code.
#define AT_CUDA_DRIVER_CHECK(EXPR)                                \
  do {                                                            \
    CUresult __err = EXPR;                                        \
    if (__err != CUDA_SUCCESS) {                                  \
      AT_ERROR("CUDA driver error: ", static_cast<int>(__err));   \
    }                                                             \
  } while (0)

#endif

// For CUDA NVRTC
//
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
// incorrectly produces the error string "NVRTC unknown error."
// The following maps it correctly.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#define AT_CUDA_NVRTC_CHECK(EXPR)                                                                   \
  do {                                                                                              \
    nvrtcResult __err = EXPR;                                                                       \
    if (__err != NVRTC_SUCCESS) {                                                                   \
      if (static_cast<int>(__err) != 7) {                                                           \
        AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err));  \
      } else {                                                                                      \
        AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE");                        \
      }                                                                                             \
    }                                                                                               \
  } while (0)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
#include <c10/core/Allocator.h>
|
| 5 |
+
#include <c10/macros/Macros.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
// Objective-C builds get the real Metal protocol types; plain C++ builds see
// opaque pointers with the same names so the class below compiles either way.
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
typedef id<MTLDevice> MTLDevice_t;
typedef id<MTLLibrary> MTLLibrary_t;
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
#else
typedef void* MTLDevice;
typedef void* MTLDevice_t;
typedef void* MTLLibrary_t;
typedef void* MTLComputePipelineState_t;
#endif

// NOTE(review): `using namespace std;` in a public header pollutes every
// includer's namespace; kept only for backward compatibility with code that
// has come to rely on it.
using namespace std;
|
| 26 |
+
|
| 27 |
+
namespace at::mps {
|
| 28 |
+
|
| 29 |
+
// Helper enum to check if a MPSGraph op is supported in a given macOS version.
// Enumerators are declared in increasing version order.
enum class MacOSVersion : uint32_t {
  MACOS_VER_13_0_PLUS = 0,
  MACOS_VER_13_1_PLUS,
  MACOS_VER_13_2_PLUS,
  MACOS_VER_13_3_PLUS,
  MACOS_VER_14_0_PLUS,
};
|
| 37 |
+
|
| 38 |
+
//-----------------------------------------------------------------
// MPSDevice
//
// MPSDevice is a singleton class that returns the default device
//-----------------------------------------------------------------

class TORCH_API MPSDevice {
 public:
  /**
   * MPSDevice should not be cloneable.
   */
  MPSDevice(MPSDevice& other) = delete;
  /**
   * MPSDevice should not be assignable.
   */
  void operator=(const MPSDevice&) = delete;
  /**
   * Gets single instance of the Device.
   */
  static MPSDevice* getInstance();
  /**
   * Returns the single device.
   */
  MTLDevice_t device() {
    return _mtl_device;
  }
  /**
   * Returns whether running on Ventura or newer
   */
  bool isMacOS13Plus(MacOSVersion version) const;

  // Returns a compute pipeline state for the named indexing kernel.
  // NOTE(review): presumably built from the indexing library below —
  // confirm in the implementation.
  MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel);
  // Returns the Metal shader library used for indexing kernels.
  MTLLibrary_t getMetalIndexingLibrary();

  ~MPSDevice();

 private:
  // Singleton instance storage.
  static MPSDevice* _device;
  // Underlying Metal device handle.
  MTLDevice_t _mtl_device;
  // Cached Metal library for indexing kernels.
  MTLLibrary_t _mtl_indexing_library;
  MPSDevice();
};
|
| 80 |
+
|
| 81 |
+
// True if an MPS device can be used in this process.
TORCH_API bool is_available();
// Version gate for MPS features; defaults to the lowest supported version.
// NOTE(review): presumably forwards to MPSDevice::isMacOS13Plus — confirm.
TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
// Returns the MPS allocator; `useSharedAllocator` selects the shared variant.
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
|
| 84 |
+
|
| 85 |
+
} // namespace at::mps
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <ATen/core/ATen_fwd.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
class Tensor;
|
| 8 |
+
|
| 9 |
+
namespace native {
|
| 10 |
+
|
| 11 |
+
// Kernel signature behind the CPU _amp_foreach_non_finite_check_and_unscale_
// stub. NOTE(review): by analogy with the ATen operator of the same name the
// arguments are (scaled grads, found_inf, inv_scale) — confirm against the
// kernel implementation.
using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
    TensorList,
    Tensor&,
    const Tensor&);

// Kernel signature behind the CPU _amp_update_scale_ stub.
using _amp_update_scale_cpu__fn = Tensor& (*)(
    Tensor&,
    Tensor&,
    const Tensor&,
    double,
    double,
    int64_t);

// Dispatch-stub declarations (see ATen/native/DispatchStub.h).
DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
|
| 26 |
+
|
| 27 |
+
} // namespace native
|
| 28 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/OpMathType.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <ATen/native/TransposeType.h>
|
| 6 |
+
#include <c10/util/complex.h>
|
| 7 |
+
#include <c10/core/ScalarType.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
|
| 10 |
+
namespace at::native::cpublas {
|
| 11 |
+
|
| 12 |
+
namespace internal {
// Normalizes the leading dimensions for degenerate shapes before dispatch;
// the ld* parameters are updated in place.
void normalize_last_dims(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t *lda, int64_t *ldb, int64_t *ldc);
} // namespace internal

// Type-erased GEMM kernel signature: the ScalarType selects the element type
// the void* buffers are reinterpreted as.
using gemm_fn = void(*)(
    at::ScalarType type,
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const Scalar& alpha,
    const void *a, int64_t lda,
    const void *b, int64_t ldb,
    const Scalar& beta,
    void *c, int64_t ldc);

// Dispatch-stub declaration (see ATen/native/DispatchStub.h).
DECLARE_DISPATCH(gemm_fn, gemm_stub);
|
| 30 |
+
|
| 31 |
+
template <typename scalar_t>
|
| 32 |
+
void gemm(
|
| 33 |
+
TransposeType transa, TransposeType transb,
|
| 34 |
+
int64_t m, int64_t n, int64_t k,
|
| 35 |
+
at::opmath_type<scalar_t> alpha,
|
| 36 |
+
const scalar_t *a, int64_t lda,
|
| 37 |
+
const scalar_t *b, int64_t ldb,
|
| 38 |
+
at::opmath_type<scalar_t> beta,
|
| 39 |
+
scalar_t *c, int64_t ldc) {
|
| 40 |
+
internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
| 41 |
+
gemm_stub(
|
| 42 |
+
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
|
| 43 |
+
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
// Precision-specific gemm entry points. Parameters follow the conventional
// BLAS gemm layout (transpose flags, m/n/k, alpha/beta, leading dimensions).

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    double alpha,
    const double *a, int64_t lda,
    const double *b, int64_t ldb,
    double beta,
    double *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    float beta,
    float *c, int64_t ldc);

// Mixed-precision variants: bfloat16/half inputs with float alpha/beta, and
// optionally a float output buffer.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    float beta,
    at::BFloat16 *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    float beta,
    at::Half *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    c10::complex<double> alpha,
    const c10::complex<double> *a, int64_t lda,
    const c10::complex<double> *b, int64_t ldb,
    c10::complex<double> beta,
    c10::complex<double> *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    c10::complex<float> alpha,
    const c10::complex<float> *a, int64_t lda,
    const c10::complex<float> *b, int64_t ldb,
    c10::complex<float> beta,
    c10::complex<float> *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t alpha,
    const int64_t *a, int64_t lda,
    const int64_t *b, int64_t ldb,
    int64_t beta,
    int64_t *c, int64_t ldc);

// Batched GEMM over arrays of matrix pointers.
template <typename scalar_t>
void gemm_batched(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t * const *a, int64_t lda,
    const scalar_t * const *b, int64_t ldb,
    const scalar_t beta,
    scalar_t * const *c, int64_t ldc);

// Batched GEMM over strided buffers (batch i starts at base + i*stride).
template <typename scalar_t>
void gemm_batched_with_stride(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t *a, int64_t lda, int64_t batch_stride_a,
    const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
    scalar_t beta,
    scalar_t *c, int64_t ldc, int64_t batch_stride_c);

// Type-erased axpy kernel signature (BLAS-style strided vector update).
using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);

// Dispatch-stub declaration (see ATen/native/DispatchStub.h).
DECLARE_DISPATCH(axpy_fn, axpy_stub);
|
| 150 |
+
|
| 151 |
+
template<typename scalar_t>
|
| 152 |
+
void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
|
| 153 |
+
if(n == 1)
|
| 154 |
+
{
|
| 155 |
+
incx = 1;
|
| 156 |
+
incy = 1;
|
| 157 |
+
}
|
| 158 |
+
axpy_stub(
|
| 159 |
+
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
|
| 160 |
+
n, a, x, incx, y, incy);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
// Precision-specific axpy entry points.
void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);

// Type-erased strided vector-copy kernel signature.
using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);

// Dispatch-stub declaration (see ATen/native/DispatchStub.h).
DECLARE_DISPATCH(copy_fn, copy_stub);
|
| 171 |
+
|
| 172 |
+
template<typename scalar_t>
|
| 173 |
+
void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
|
| 174 |
+
if(n == 1)
|
| 175 |
+
{
|
| 176 |
+
incx = 1;
|
| 177 |
+
incy = 1;
|
| 178 |
+
}
|
| 179 |
+
copy_stub(
|
| 180 |
+
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
|
| 181 |
+
n, x, incx, y, incy);
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
// Precision-specific copy entry points.
void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
|
| 188 |
+
|
| 189 |
+
} // namespace at::native::cpublas
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <utility>
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
namespace {
|
| 8 |
+
|
| 9 |
+
// operator_brackets_proxy is used in
|
| 10 |
+
// CompositeRandomAccessor in place of operator[].
|
| 11 |
+
// For some iterators, references returned by operator[]
|
| 12 |
+
// could become invalid, operator_brackets_proxy tries to
|
| 13 |
+
// resolve that by making accessor[n] to be equivalent to
|
| 14 |
+
// *(accessor + n).
|
| 15 |
+
template <typename Accessor>
|
| 16 |
+
class operator_brackets_proxy {
|
| 17 |
+
using reference = typename std::iterator_traits<Accessor>::reference;
|
| 18 |
+
using value_type = typename std::iterator_traits<Accessor>::value_type;
|
| 19 |
+
|
| 20 |
+
public:
|
| 21 |
+
C10_HOST_DEVICE
|
| 22 |
+
operator_brackets_proxy(Accessor const& accessor)
|
| 23 |
+
: accessor(accessor)
|
| 24 |
+
{}
|
| 25 |
+
|
| 26 |
+
C10_HOST_DEVICE
|
| 27 |
+
operator reference() {
|
| 28 |
+
return *accessor;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
C10_HOST_DEVICE
|
| 32 |
+
reference operator*() {
|
| 33 |
+
return *accessor;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
C10_HOST_DEVICE
|
| 37 |
+
operator_brackets_proxy& operator=(value_type const& val) {
|
| 38 |
+
*accessor = val;
|
| 39 |
+
return *this;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
private:
|
| 43 |
+
Accessor accessor;
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// references_holder is used as a surrogate for the
// references type from std::iterator_traits in CompositeRandomAccessor.
// It is assumed in CompositeRandomAccessor that
// References = tuple<Types&...>,
// Values = tuple<Types...> by default,
// but they could be anything as long as References could be
// cast to Values.
// If you plan to use it with STL, for example, you will need to
// define 'swap` and `get`(aka std::get) methods.
template <typename Values, typename References>
class references_holder {
public:
  using values = Values;
  using references = References;

  C10_HOST_DEVICE
  references_holder(references refs)
    : refs{std::move(refs)}
  {}

  // Implicit conversion to the underlying reference tuple.
  C10_HOST_DEVICE
  operator references() {
    return refs;
  }

  // Implicit conversion that materializes values out of the held
  // references (relies on References being convertible to Values).
  C10_HOST_DEVICE
  operator values() {
    return refs;
  }

  // Assigning a value tuple writes through the held references.
  C10_HOST_DEVICE
  references_holder& operator=(values vals) {
    refs = vals;
    return *this;
  }

  // Direct access to the reference tuple.
  C10_HOST_DEVICE
  references& data() {
    return refs;
  }

protected:
  references refs;
};
|
| 92 |
+
|
| 93 |
+
// CompositeRandomAccessor is essentially a simplified version of
// a random access iterator over two random access iterators.
// TupleInfo should contain a variadic type `tuple`, and a method `tie`,
// which constructs a tuple of references from a variadic list of arguments.
template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo>
class CompositeRandomAccessor {
  using self_type = CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfo>;

  using key_accessor_value_type =
    typename std::iterator_traits<KeyAccessor>::value_type;
  using value_accessor_value_type =
    typename std::iterator_traits<ValueAccessor>::value_type;
  using key_accessor_reference_type =
    typename std::iterator_traits<KeyAccessor>::reference;
  using value_accessor_reference_type =
    typename std::iterator_traits<ValueAccessor>::reference;

  using composite_value_type = typename TupleInfo::template tuple<
    key_accessor_value_type,
    value_accessor_value_type>;
  using composite_reference = typename TupleInfo::template tuple<
    key_accessor_reference_type,
    value_accessor_reference_type>;

public:
  using value_type = composite_value_type;
  // references_holder lets a (key_ref, value_ref) pair be assigned to and
  // converted from/to a (key, value) pair of values.
  using reference = references_holder<composite_value_type, composite_reference>;
  // Note that CompositeRandomAccessor does not hold key and values
  // in a specific datastructure, which means that a pointer to a (key, value)
  // is not defined. Hence we just use a pointer type of the KeyAccessor.
  using pointer = typename std::iterator_traits<KeyAccessor>::pointer;
  using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
  using iterator_category = std::random_access_iterator_tag;

  C10_HOST_DEVICE
  CompositeRandomAccessor() = default;

  C10_HOST_DEVICE
  CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
    : keys(keys), values(values)
  {}

  // Pointer-like operations {
  C10_HOST_DEVICE
  reference operator*() const {
    return TupleInfo::tie(*keys, *values);
  }

  // operator->() is supposed to return a pointer type.
  // Since CompositeRandomAccessor does not hold pointers to pairs,
  // we just return a pointer to a key.
  C10_HOST_DEVICE
  auto* operator->() const {
    return keys.operator->();
  }

  // accessor[idx] is made equivalent to *(accessor + idx); see
  // operator_brackets_proxy above for why a proxy is returned.
  C10_HOST_DEVICE
  reference operator[](difference_type idx) {
    return operator_brackets_proxy<self_type>(
      CompositeRandomAccessor(keys + idx, values + idx)
    );
  }
  // }

  // Prefix/postfix increment/decrement {
  // Both accessors are always advanced together, in lockstep.
  C10_HOST_DEVICE
  CompositeRandomAccessor& operator++() {
    ++keys;
    ++values;
    return *this;
  }

  C10_HOST_DEVICE
  CompositeRandomAccessor operator++(int) {
    CompositeRandomAccessor copy(*this);
    ++*this;
    return copy;
  }

  C10_HOST_DEVICE
  CompositeRandomAccessor& operator--() {
    --keys;
    --values;
    return *this;
  }

  C10_HOST_DEVICE
  CompositeRandomAccessor operator--(int) {
    CompositeRandomAccessor copy(*this);
    --*this;
    return copy;
  }
  // }

  // Arithmetic operations {
  C10_HOST_DEVICE
  CompositeRandomAccessor& operator+=(difference_type offset) {
    keys += offset;
    values += offset;
    return *this;
  }

  C10_HOST_DEVICE
  CompositeRandomAccessor operator+(difference_type offset) const {
    return CompositeRandomAccessor(keys + offset, values + offset);
  }

  C10_HOST_DEVICE
  friend CompositeRandomAccessor operator+(
    difference_type offset,
    const CompositeRandomAccessor& accessor
  ) {
    return accessor + offset;
  }

  C10_HOST_DEVICE
  CompositeRandomAccessor& operator-=(difference_type offset) {
    keys -= offset;
    values -= offset;
    return *this;
  }

  C10_HOST_DEVICE
  CompositeRandomAccessor operator-(difference_type offset) const {
    return CompositeRandomAccessor(keys - offset, values - offset);
  }

  // Distance is measured on the key accessor only, since keys and values
  // advance in lockstep.
  C10_HOST_DEVICE
  difference_type operator-(const CompositeRandomAccessor& other) const {
    return keys - other.keys;
  }
  // }

  // Comparison operators (all delegate to the key accessor) {
  C10_HOST_DEVICE
  bool operator==(const CompositeRandomAccessor& other) const {
    return keys == other.keys;
  }

  C10_HOST_DEVICE
  bool operator!=(const CompositeRandomAccessor& other) const {
    return keys != other.keys;
  }

  C10_HOST_DEVICE
  bool operator<(const CompositeRandomAccessor& other) const {
    return keys < other.keys;
  }

  C10_HOST_DEVICE
  bool operator<=(const CompositeRandomAccessor& other) const {
    return keys <= other.keys;
  }

  C10_HOST_DEVICE
  bool operator>(const CompositeRandomAccessor& other) const {
    return keys > other.keys;
  }

  C10_HOST_DEVICE
  bool operator>=(const CompositeRandomAccessor& other) const {
    return keys >= other.keys;
  }
  // }

protected:
  KeyAccessor keys;
  ValueAccessor values;
};
|
| 262 |
+
|
| 263 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
|
| 3 |
+
namespace at::native {
|
| 4 |
+
|
| 5 |
+
// Backward pass of the slow (non-optimized) 3d convolution on CPU.
// Returns (grad_input, grad_weight, grad_bias); output_mask selects which
// of the three gradients are computed.
std::tuple<Tensor, Tensor, Tensor> slow_conv3d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    std::array<bool, 3> output_mask);
|
| 13 |
+
|
| 14 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace at {

class Tensor;
struct TensorIterator;
class TensorBase;

namespace native {

// Signature of the device-dispatched tensor copy kernel.
using copy_fn = void (*)(TensorIterator&, bool non_blocking);

DECLARE_DISPATCH(copy_fn, copy_stub);

// Copy variant for when dst/src may have internally overlapping memory.
// NOTE(review): exact semantics for overlapping elements live in the
// implementation — confirm there before relying on a particular result.
TORCH_API void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src);

} // namespace native
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <vector>
|
| 5 |
+
|
| 6 |
+
#include <ATen/div_rtn.h>
|
| 7 |
+
#include <ATen/core/Tensor.h>
|
| 8 |
+
#include <c10/util/irange.h>
|
| 9 |
+
|
| 10 |
+
#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
|
| 11 |
+
TORCH_CHECK( \
|
| 12 |
+
T.dim() == DIM && T.size(DIM_SIZE) == SIZE, \
|
| 13 |
+
"Need " #T " of dimension ", \
|
| 14 |
+
DIM, \
|
| 15 |
+
" and " #T ".size[", \
|
| 16 |
+
DIM_SIZE, \
|
| 17 |
+
"] == ", \
|
| 18 |
+
SIZE, \
|
| 19 |
+
" but got input to be of shape ", \
|
| 20 |
+
T.sizes())
|
| 21 |
+
|
| 22 |
+
namespace at::native::internal {
|
| 23 |
+
namespace {
|
| 24 |
+
inline bool all_positive(IntArrayRef& arr) {
|
| 25 |
+
return std::all_of(
|
| 26 |
+
arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// True iff every element of `arr` is non-negative (>= 0).
// Takes the vector by const reference: the function only reads the
// elements, and a non-const reference would also reject temporaries.
inline bool all_nonnegative(const std::vector<int64_t>& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
}
|
| 33 |
+
|
| 34 |
+
} // namespace
|
| 35 |
+
|
| 36 |
+
// calculate the rear part of output tensor sizes (the spatial dims only).
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  std::vector<int64_t> sizes;
  for (const auto index : c10::irange(dim)) {
    // Standard dilated-convolution output-size formula, using div_rtn
    // (round toward negative infinity) before the +1:
    //   out = (in + 2*pad - (dilation*(kernel-1)+1)) / stride + 1
    // The input's last `dim` axes are the spatial ones.
    sizes.push_back(
        div_rtn<int64_t>(
            input.size(index + input.dim() - dim) + 2 * pad_size[index] -
                (dilation_size[index] * (kernel_size[index] - 1) + 1),
            stride_size[index]) +
        1);
  }
  return sizes;
}
|
| 55 |
+
|
| 56 |
+
// calculate the sizes of output tensor (channels and, if batched, batch
// dimension included).
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  // Prepend the output-channel count (= weight.size(0)) ...
  output_size.insert(output_size.begin(), weight.size(0));
  // ... and the batch size when the input has a batch dimension.
  if (input.dim() == dim + 2) {
    output_size.insert(output_size.begin(), input.size(0));
  }
  return output_size;
}
|
| 73 |
+
/*
  slow_conv_dilated_shape_check - check user-input to dilated convolution
  forward and backward functions.

  Raises (via TORCH_CHECK) on the first violated constraint; returns
  normally when all checks pass. `bias` and `grad_output` are optional:
  their checks run only when the tensor is defined.
*/
template <int64_t dim>
void slow_conv_dilated_shape_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  /*
    When the following tensors are defined:

    bias, grad_weight, grad_output

    then these are assumed to be contiguous without checking
    because of these tensors are made contiguous by calling
    .contiguous() method or by resizing of zero-sized tensors in
    forward/backward functions.

    When grad_weight is defined then it is assumed without
    checking to have the same shape as weight, see backward
    functions.
  */
  // Check size arguments: every size list must have exactly `dim` entries.
  TORCH_CHECK(
      kernel_size.size() == dim,
      "kernel sizes length should be ",
      dim,
      ", but got ",
      kernel_size.size());
  TORCH_CHECK(
      stride_size.size() == dim,
      "strides length should be ",
      dim,
      ", but got ",
      stride_size.size());
  TORCH_CHECK(
      dilation_size.size() == dim,
      "dilations length should be ",
      dim,
      ", but got ",
      dilation_size.size());
  TORCH_CHECK(
      pad_size.size() == dim,
      "pads length should be ",
      dim,
      ", but got ",
      pad_size.size());

  // Kernel, stride and dilation must be strictly positive; padding may be 0.
  TORCH_CHECK(
      all_positive(kernel_size),
      "kernel size should be greater than zero, but got ",
      kernel_size);
  TORCH_CHECK(
      all_positive(stride_size),
      "stride should be greater than zero, but got ",
      stride_size);
  TORCH_CHECK(
      all_positive(dilation_size),
      "dilation should be greater than zero, but got ",
      dilation_size);

  // check input
  TORCH_CHECK(input.defined(), "input must be defined");
  bool is_batch = input.dim() == dim + 2;
  int64_t n = (is_batch ? 2 : 1);  // number of leading non-spatial dims
  int64_t ndim = n + dim;
  if (!is_batch) {
    // input dim has to be dim + 1 if not batched
    TORCH_CHECK(
        input.dim() == dim + 1,
        "input must be 4D or 5D tensor but got ",
        input.dim(),
        "D tensor");
  }

  // check output sizes
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);

  TORCH_CHECK(
      all_nonnegative(output_size),
      "calculated output size ",
      output_size,
      " is too small (all sizes must be non-negative)");

  // check weight
  TORCH_CHECK(weight.defined(), "weight must be defined");
  TORCH_CHECK(
      weight.dim() == dim + 2,
      "weight must be ",
      dim + 2,
      "D tensor but got ",
      weight.dim(),
      "D tensor dim=",
      dim);
  TORCH_CHECK(
      weight.sizes().slice(2) == kernel_size,
      "weight[2:] shape ",
      weight.sizes().slice(2),
      " must be equal to kernel_size ",
      kernel_size);

  // input channels must match weight's in-channel dimension
  TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));

  // check bias when present
  if (bias.defined()) {
    TORCH_CHECK(
        bias.dim() == 1,
        "bias must be 1D tensor but got ",
        bias.dim(),
        "D tensor");
    TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
  }

  // check grad_output when present
  if (grad_output.defined()) {
    TORCH_CHECK(
        grad_output.dim() == ndim,
        "grad_output must be ",
        ndim,
        "D tensor but got ",
        grad_output.dim(),
        "D tensor");
    if (is_batch) {
      TORCH_CHECK(
          grad_output.size(0) == input.size(0),
          "grad_output.size(0)=",
          grad_output.size(0),
          " must be input.size(0)=",
          input.size(0));
    }
    TORCH_CHECK(
        grad_output.size(n - 1) == weight.size(0),
        "grad_output.size(",
        n - 1,
        ")=",
        grad_output.size(n - 1),
        " must be weight.size(0)=",
        weight.size(0));
    TORCH_CHECK(
        grad_output.sizes().slice(n) == output_size,
        "grad_output[",
        n,
        ":] shape",
        grad_output.sizes().slice(n),
        " must be equal to output size ",
        output_size);
  }
}
|
| 228 |
+
|
| 229 |
+
} // namespace at::native::internal
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Device.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/ScalarType.h>
|
| 6 |
+
#include <ATen/core/Tensor.h>
|
| 7 |
+
#include <ATen/native/utils/ParamsHash.h>
|
| 8 |
+
#include <c10/util/Exception.h>
|
| 9 |
+
#include <c10/util/irange.h>
|
| 10 |
+
|
| 11 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 12 |
+
#include <ATen/NativeFunctions.h>
|
| 13 |
+
#else
|
| 14 |
+
#include <ATen/ops/result_type_native.h>
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
#include <unordered_map>
|
| 18 |
+
#include <vector>
|
| 19 |
+
|
| 20 |
+
namespace at::native {
|
| 21 |
+
namespace {
|
| 22 |
+
// Check if tensor list has either a boolean tensor or a integer tensor
|
| 23 |
+
inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
|
| 24 |
+
return std::any_of(
|
| 25 |
+
tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
|
| 26 |
+
return at::isIntegralType(t.scalar_type(), includeBool);
|
| 27 |
+
});
|
| 28 |
+
}
|
| 29 |
+
// check if tensor list has bool tensors
|
| 30 |
+
inline bool has_bool_tensor(TensorList tensors) {
|
| 31 |
+
return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
|
| 32 |
+
return t.scalar_type() == ScalarType::Bool;
|
| 33 |
+
});
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
// Check foreach API restrictions
// - Tensor lists must be non-empty.
// - All TensorLists and ScalarLists must have the same number of elements.
// - Corresponding tensors must have the same size.
// This overload only enforces non-emptiness of a single list.
inline void check_foreach_api_restrictions(TensorList tensors) {
  TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
}
|
| 43 |
+
|
| 44 |
+
// As above, plus: the scalar list must pair one scalar with each tensor.
inline void check_foreach_api_restrictions(
    TensorList tensors,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors);
  TORCH_CHECK(
      tensors.size() == scalars.size(),
      "Tensor list must have same number of elements as scalar list.");
}
|
| 52 |
+
|
| 53 |
+
// Two-list variant: both lists must be non-empty and equally sized.
inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(
      tensors1.size() == tensors2.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors2.size());
}
|
| 65 |
+
|
| 66 |
+
// Three-list variant: all lists must be non-empty and equally sized
// (lists 2 and 3 are each compared against list 1).
inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(
      tensors1.size() == tensors2.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors2.size());
  TORCH_CHECK(
      tensors1.size() == tensors3.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors3.size());
}
|
| 86 |
+
|
| 87 |
+
// Three lists plus a scalar list: delegates the list checks, then requires
// one scalar per tensor.
inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  TORCH_CHECK(
      tensors1.size() == scalars.size(),
      "Tensor list must have same number of elements as scalar list, got ",
      tensors1.size(),
      " and ",
      scalars.size());
}
|
| 100 |
+
|
| 101 |
+
// Helper function called in check_fast_path_restrictions to check whether all
// corresponding tensors (aligning in index across the tensorLists) share the
// same device and dtype.
inline bool _check_tensors_share_device_and_dtype(
    ArrayRef<TensorList> tensorLists) {
  // The first tensor of the first list sets the expectation for all tensors.
  const auto expected_dtype = tensorLists[0][0].dtype();
  const auto expected_device = tensorLists[0][0].device();

  // Besides device/dtype, the fast path also needs strided layout and
  // non-overlapping-and-dense memory.
  auto is_tensor_okay = [&](const Tensor& tensor) {
    return tensor.dtype() == expected_dtype &&
        tensor.device() == expected_device && tensor.layout() == at::kStrided &&
        tensor.is_non_overlapping_and_dense();
  };

  for (const auto& tensorList : tensorLists) {
    for (const auto& tensor : tensorList) {
      if (!is_tensor_okay(tensor)) {
        return false;
      }
    }
  }

  return true;
}
|
| 125 |
+
|
| 126 |
+
// Helper function called in check_fast_path_restrictions to check if
// corresponding tensors in tensor lists have the same sizes and strides.
inline bool _check_tensors_share_sizes_and_strides(
    ArrayRef<TensorList> tensorLists) {
  // Compare every list against list 0; list 0 trivially matches itself.
  for (const auto i : c10::irange(1, tensorLists.size())) {
    for (const auto j : c10::irange(tensorLists[0].size())) {
      if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
          tensorLists[0][j].strides() != tensorLists[i][j].strides()) {
        return false;
      }
    }
  }

  return true;
}
|
| 141 |
+
|
| 142 |
+
// Helper function called in check_fast_path_restrictions to check whether
// all tensors type promote properly with the scalars in scalarList. This
// function assumes that _check_tensors_share_device_and_dtype has already been
// called so that all corresponding tensors in tensorLists have the same dtype.
// Then, it is sufficient to check the type promotion with just one tensorList.
inline bool _check_tensors_do_type_promotion_with_scalars(
    TensorList tensorList,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  for (const auto i : c10::irange(tensorList.size())) {
    // For division, integer inputs will result in float.
    if (does_op_promote_integer_inputs_to_float) {
      if (at::isIntegralType(
              tensorList[i].scalar_type(), /*includeBool*/ true)) {
        return false;
      }
    }
    if (!scalarList.empty()) {
      // A single scalar is broadcast against every tensor; otherwise the
      // lists are paired element-wise.
      const auto& scalar =
          scalarList.size() == 1 ? scalarList[0] : scalarList[i];
      const auto& tensor = tensorList[i];
      // note(mkozuki): This check might be responsible for
      // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
      if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
        return false;
      }
    }
  }

  return true;
}
|
| 173 |
+
|
| 174 |
+
// To go via 'fast' path, several conditions must be satisfied
// - All tensors in all lists must have the same dtype.
// - All tensors must be on the same device
// - All tensors must have strided layout
// - All tensors must be non-overlapping and dense
// - Resulting tensor must have the same dtype as the input one

// Please, make sure to call check_foreach_api_restrictions before calling this
// method. There is a set of preconditions that have to be satisfied.
// Returns true when the fused (fast-path) kernel may be used.
inline bool check_fast_path_restrictions(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  // Short-circuits left to right: device/dtype first, then shape/stride,
  // then scalar type promotion on the first list only.
  return _check_tensors_share_device_and_dtype(tensorLists) &&
      _check_tensors_share_sizes_and_strides(tensorLists) &&
      _check_tensors_do_type_promotion_with_scalars(
             tensorLists[0],
             scalarList,
             does_op_promote_integer_inputs_to_float);
}
|
| 194 |
+
|
| 195 |
+
inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
|
| 196 |
+
const Tensor& scalarList_,
|
| 197 |
+
int64_t expect_length) {
|
| 198 |
+
std::vector<c10::Scalar> scalarList;
|
| 199 |
+
TORCH_CHECK(
|
| 200 |
+
scalarList_.device() == c10::kCPU,
|
| 201 |
+
"Expected scalars to be on CPU, got ",
|
| 202 |
+
scalarList_.device(),
|
| 203 |
+
" instead.");
|
| 204 |
+
TORCH_CHECK(
|
| 205 |
+
scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
|
| 206 |
+
TORCH_CHECK(
|
| 207 |
+
scalarList_.dim() == 1,
|
| 208 |
+
"Expected packed scalar Tensor to be of dimension 1. Got ",
|
| 209 |
+
scalarList_.dim(),
|
| 210 |
+
" instead.");
|
| 211 |
+
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
|
| 212 |
+
kComplexHalf,
|
| 213 |
+
kHalf,
|
| 214 |
+
kBool,
|
| 215 |
+
kBFloat16,
|
| 216 |
+
scalarList_.scalar_type(),
|
| 217 |
+
"convert_tensor_to_scalar_list",
|
| 218 |
+
[&]() {
|
| 219 |
+
const scalar_t* scalar_data = scalarList_.data_ptr<scalar_t>();
|
| 220 |
+
TORCH_CHECK(
|
| 221 |
+
(expect_length == scalarList_.size(0)),
|
| 222 |
+
"Expected length of scalars to match input of length ",
|
| 223 |
+
expect_length,
|
| 224 |
+
" but got ",
|
| 225 |
+
scalarList_.size(0),
|
| 226 |
+
" instead.");
|
| 227 |
+
for (int64_t i = 0; i < scalarList_.size(0); i++) {
|
| 228 |
+
scalarList.emplace_back(scalar_data[i]);
|
| 229 |
+
}
|
| 230 |
+
});
|
| 231 |
+
return scalarList;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
inline bool can_use_fast_route(
|
| 235 |
+
ArrayRef<TensorList> tensorLists,
|
| 236 |
+
ArrayRef<Scalar> scalarList = {},
|
| 237 |
+
bool does_op_promote_integer_inputs_to_float = false) {
|
| 238 |
+
return check_fast_path_restrictions(
|
| 239 |
+
tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
inline bool can_use_fast_route(
|
| 243 |
+
TensorList tensors1,
|
| 244 |
+
TensorList tensors2,
|
| 245 |
+
bool does_op_promote_integer_inputs_to_float = false) {
|
| 246 |
+
return can_use_fast_route(
|
| 247 |
+
{tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
|
| 251 |
+
using IndicesT = std::vector<size_t>;
|
| 252 |
+
using nested_optional_tensorvec_t =
|
| 253 |
+
std::vector<std::vector<c10::optional<at::Tensor>>>;
|
| 254 |
+
using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
|
| 255 |
+
using FlatMap = std::unordered_map<
|
| 256 |
+
DeviceDtypeKey,
|
| 257 |
+
TensorsAndIndicesT,
|
| 258 |
+
ParamsHash<DeviceDtypeKey>>;
|
| 259 |
+
|
| 260 |
+
inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
|
| 261 |
+
const nested_optional_tensorvec_t& nested_tensorlist,
|
| 262 |
+
const bool with_indices) {
|
| 263 |
+
FlatMap grouped_tensors_with_indices;
|
| 264 |
+
|
| 265 |
+
TORCH_CHECK(!nested_tensorlist.empty());
|
| 266 |
+
TORCH_CHECK(!nested_tensorlist[0].empty());
|
| 267 |
+
const auto num_lists = nested_tensorlist.size();
|
| 268 |
+
const auto num_tensors = nested_tensorlist[0].size();
|
| 269 |
+
|
| 270 |
+
TORCH_CHECK(std::all_of(
|
| 271 |
+
nested_tensorlist.cbegin(),
|
| 272 |
+
nested_tensorlist.cend(),
|
| 273 |
+
[&](const auto& tensorlist) -> bool {
|
| 274 |
+
// note(crcrpar): Allow empty tensorlists following
|
| 275 |
+
// ref:
|
| 276 |
+
// https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
|
| 277 |
+
return tensorlist.size() == num_tensors || tensorlist.size() == 0;
|
| 278 |
+
}));
|
| 279 |
+
|
| 280 |
+
for (const auto& tensor_index : c10::irange(num_tensors)) {
|
| 281 |
+
const auto key = [&]() -> DeviceDtypeKey {
|
| 282 |
+
const auto t = nested_tensorlist[0][tensor_index];
|
| 283 |
+
TORCH_CHECK(
|
| 284 |
+
t.has_value(),
|
| 285 |
+
"Tensors of the first list of nested Tensor lists are supposed to be defined but ",
|
| 286 |
+
"the ",
|
| 287 |
+
tensor_index,
|
| 288 |
+
"-th Tensor is not.");
|
| 289 |
+
return {t->device(), t->scalar_type()};
|
| 290 |
+
}();
|
| 291 |
+
TORCH_CHECK(
|
| 292 |
+
std::all_of(
|
| 293 |
+
nested_tensorlist.cbegin(),
|
| 294 |
+
nested_tensorlist.cend(),
|
| 295 |
+
[&](const auto& tensorlist) -> bool {
|
| 296 |
+
if (tensorlist.size() == 0) {
|
| 297 |
+
return true;
|
| 298 |
+
}
|
| 299 |
+
const auto& tensor = tensorlist[tensor_index];
|
| 300 |
+
// note(crcrpar): Currently the scope of this function is
|
| 301 |
+
// optimizers so there could be `state_steps` and other scalars
|
| 302 |
+
// whose elements are float tensors no matter what the parameter's
|
| 303 |
+
// dtype is.
|
| 304 |
+
if (!tensor.has_value()) {
|
| 305 |
+
return true;
|
| 306 |
+
} else {
|
| 307 |
+
const auto s = tensor->scalar_type();
|
| 308 |
+
const auto d = tensor->device();
|
| 309 |
+
// Note: `step` or `state_step` is float32 by default.
|
| 310 |
+
if (key.first == d) {
|
| 311 |
+
return key.second == s || s == at::ScalarType::Float ||
|
| 312 |
+
s == at::ScalarType::Double;
|
| 313 |
+
} else if (d.is_cpu()) {
|
| 314 |
+
// note(crcrpar): There are some test cases (e.g.
|
| 315 |
+
// TestOptim::test_adam) where state_steps are on CPU and the
|
| 316 |
+
// others are on CUDA. Currently a state_step Tensor has the
|
| 317 |
+
// dtype of float.
|
| 318 |
+
return s == at::ScalarType::Float ||
|
| 319 |
+
s == at::ScalarType::Double;
|
| 320 |
+
} else {
|
| 321 |
+
return false;
|
| 322 |
+
}
|
| 323 |
+
}
|
| 324 |
+
}),
|
| 325 |
+
"Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
|
| 326 |
+
if (!grouped_tensors_with_indices.count(key)) {
|
| 327 |
+
grouped_tensors_with_indices.insert(
|
| 328 |
+
{key,
|
| 329 |
+
TensorsAndIndicesT{
|
| 330 |
+
[&]() -> nested_optional_tensorvec_t {
|
| 331 |
+
nested_optional_tensorvec_t nested_tensorvec;
|
| 332 |
+
nested_tensorvec.reserve(num_lists);
|
| 333 |
+
for (const auto& i : c10::irange(num_lists)) {
|
| 334 |
+
std::vector<c10::optional<at::Tensor>> tensors;
|
| 335 |
+
if (!nested_tensorlist[i].empty()) {
|
| 336 |
+
// NB: num_tensors is the max possible length for any of
|
| 337 |
+
// the inner lists of tensor references. Reserving the max
|
| 338 |
+
// trades memory for perf. This should not have significant
|
| 339 |
+
// impact.
|
| 340 |
+
tensors.reserve(num_tensors);
|
| 341 |
+
}
|
| 342 |
+
nested_tensorvec.emplace_back(tensors);
|
| 343 |
+
}
|
| 344 |
+
return nested_tensorvec;
|
| 345 |
+
}(),
|
| 346 |
+
[&]() -> IndicesT {
|
| 347 |
+
if (!with_indices) {
|
| 348 |
+
return {};
|
| 349 |
+
} else {
|
| 350 |
+
IndicesT indices;
|
| 351 |
+
indices.reserve(num_tensors);
|
| 352 |
+
return indices;
|
| 353 |
+
}
|
| 354 |
+
}()}});
|
| 355 |
+
}
|
| 356 |
+
for (const auto& list_index : c10::irange(num_lists)) {
|
| 357 |
+
if (!nested_tensorlist[list_index].empty()) {
|
| 358 |
+
grouped_tensors_with_indices[key].first[list_index].emplace_back(
|
| 359 |
+
nested_tensorlist[list_index][tensor_index]);
|
| 360 |
+
}
|
| 361 |
+
}
|
| 362 |
+
if (with_indices) {
|
| 363 |
+
grouped_tensors_with_indices[key].second.emplace_back(tensor_index);
|
| 364 |
+
}
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
return grouped_tensors_with_indices;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
} // namespace
|
| 371 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <c10/util/Optional.h>
|
| 5 |
+
|
| 6 |
+
namespace c10 {
|
| 7 |
+
class Scalar;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
struct TensorIterator;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha);
|
| 17 |
+
DECLARE_DISPATCH(addr_fn, addr_stub);
|
| 18 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/NumericUtils.h>
|
| 4 |
+
#include <ATen/native/Resize.h>
|
| 5 |
+
#include <c10/util/irange.h>
|
| 6 |
+
|
| 7 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 8 |
+
#include <ATen/Functions.h>
|
| 9 |
+
#else
|
| 10 |
+
#include <ATen/ops/empty.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
|
| 15 |
+
// ensure we get good values and indices for kthvalue, mode
|
| 16 |
+
// this will always be with the reducing dim as 1-d
|
| 17 |
+
inline void _reduction_with_indices_allocate_or_resize_output(
|
| 18 |
+
Tensor& values,
|
| 19 |
+
Tensor& indices,
|
| 20 |
+
const Tensor& self,
|
| 21 |
+
int64_t dim_,
|
| 22 |
+
bool keepdim) {
|
| 23 |
+
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
|
| 24 |
+
auto result_sizes = self.sizes().vec();
|
| 25 |
+
if (!result_sizes.empty()) {
|
| 26 |
+
result_sizes[dim] = 1;
|
| 27 |
+
}
|
| 28 |
+
if (values.defined()) {
|
| 29 |
+
TORCH_CHECK(
|
| 30 |
+
self.options().type_equal(values.options()),
|
| 31 |
+
"output values must be of same type as input");
|
| 32 |
+
if (!keepdim && values.dim() == self.dim() - 1) {
|
| 33 |
+
// unsqueeze to preserve passed in noncontiguous tensor in resize
|
| 34 |
+
values.unsqueeze_(dim);
|
| 35 |
+
}
|
| 36 |
+
resize_output(values, result_sizes);
|
| 37 |
+
} else {
|
| 38 |
+
values = at::empty(result_sizes, self.options());
|
| 39 |
+
}
|
| 40 |
+
if (indices.defined()) {
|
| 41 |
+
TORCH_CHECK(
|
| 42 |
+
indices.dtype() == kLong, "output indices must be of scalar type Long");
|
| 43 |
+
TORCH_CHECK(
|
| 44 |
+
indices.device() == self.device(),
|
| 45 |
+
"output indices must be on same device as input");
|
| 46 |
+
if (!keepdim && indices.dim() == self.dim() - 1) {
|
| 47 |
+
// unsqueeze to preserve passed in noncontiguous tensor in resize
|
| 48 |
+
indices.unsqueeze_(dim);
|
| 49 |
+
}
|
| 50 |
+
resize_output(indices, result_sizes);
|
| 51 |
+
} else {
|
| 52 |
+
indices = at::empty(result_sizes, self.options().dtype(kLong));
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
// ensure we get good values and indices for topk
|
| 57 |
+
inline void _allocate_or_resize_output_with_indices(
|
| 58 |
+
Tensor& values,
|
| 59 |
+
Tensor& indices,
|
| 60 |
+
const Tensor& self,
|
| 61 |
+
int64_t dim_,
|
| 62 |
+
int64_t k) {
|
| 63 |
+
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
|
| 64 |
+
auto result_sizes = self.sizes().vec();
|
| 65 |
+
if (!result_sizes.empty()) {
|
| 66 |
+
result_sizes[dim] = k;
|
| 67 |
+
}
|
| 68 |
+
if (values.defined()) {
|
| 69 |
+
TORCH_CHECK(
|
| 70 |
+
self.options().type_equal(values.options()),
|
| 71 |
+
"output values must be of same type as input");
|
| 72 |
+
values.resize_(result_sizes);
|
| 73 |
+
} else {
|
| 74 |
+
values = at::empty(result_sizes, self.options());
|
| 75 |
+
}
|
| 76 |
+
if (indices.defined()) {
|
| 77 |
+
TORCH_CHECK(
|
| 78 |
+
indices.dtype() == kLong, "output indices must be of scalar type Long");
|
| 79 |
+
TORCH_CHECK(
|
| 80 |
+
indices.device() == self.device(),
|
| 81 |
+
"output indices must be on same device as input");
|
| 82 |
+
indices.resize_(result_sizes);
|
| 83 |
+
} else {
|
| 84 |
+
indices = at::empty(result_sizes, self.options().dtype(kLong));
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <ATen/Generator.h>
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <stdexcept>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
class Tensor;
|
| 10 |
+
class TensorBase;
|
| 11 |
+
struct TensorIteratorBase;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
using unary_fn = void(*)(TensorIteratorBase&);
|
| 17 |
+
using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a);
|
| 18 |
+
|
| 19 |
+
inline namespace CPU_CAPABILITY {
|
| 20 |
+
void conj_kernel(TensorIteratorBase &iter);
|
| 21 |
+
void neg_kernel(TensorIteratorBase &iter);
|
| 22 |
+
void reciprocal_kernel(TensorIteratorBase &iter);
|
| 23 |
+
void rsqrt_kernel(TensorIteratorBase& iter);
|
| 24 |
+
void sqrt_kernel(TensorIteratorBase& iter);
|
| 25 |
+
} // namespace CPU_CAPABILITY
|
| 26 |
+
|
| 27 |
+
DECLARE_DISPATCH(unary_fn, abs_stub);
|
| 28 |
+
DECLARE_DISPATCH(unary_fn, angle_stub);
|
| 29 |
+
DECLARE_DISPATCH(unary_fn, conj_physical_stub);
|
| 30 |
+
DECLARE_DISPATCH(unary_fn, acos_stub);
|
| 31 |
+
DECLARE_DISPATCH(unary_fn, acosh_stub);
|
| 32 |
+
DECLARE_DISPATCH(unary_fn, asinh_stub);
|
| 33 |
+
DECLARE_DISPATCH(unary_fn, atanh_stub);
|
| 34 |
+
DECLARE_DISPATCH(unary_fn, asin_stub);
|
| 35 |
+
DECLARE_DISPATCH(unary_fn, atan_stub);
|
| 36 |
+
DECLARE_DISPATCH(unary_fn, bitwise_not_stub);
|
| 37 |
+
DECLARE_DISPATCH(unary_fn, logical_not_stub);
|
| 38 |
+
DECLARE_DISPATCH(unary_fn, ceil_stub);
|
| 39 |
+
DECLARE_DISPATCH(unary_fn, cos_stub);
|
| 40 |
+
DECLARE_DISPATCH(unary_fn, cosh_stub);
|
| 41 |
+
DECLARE_DISPATCH(unary_fn, digamma_stub);
|
| 42 |
+
DECLARE_DISPATCH(unary_fn, special_entr_stub);
|
| 43 |
+
DECLARE_DISPATCH(unary_fn, special_erfcx_stub);
|
| 44 |
+
DECLARE_DISPATCH(unary_fn, erf_stub);
|
| 45 |
+
DECLARE_DISPATCH(unary_fn, erfc_stub);
|
| 46 |
+
DECLARE_DISPATCH(unary_fn, erfinv_stub);
|
| 47 |
+
DECLARE_DISPATCH(unary_fn, exp_stub);
|
| 48 |
+
DECLARE_DISPATCH(unary_fn, exp2_stub);
|
| 49 |
+
DECLARE_DISPATCH(unary_fn, expm1_stub);
|
| 50 |
+
DECLARE_DISPATCH(unary_fn, floor_stub);
|
| 51 |
+
DECLARE_DISPATCH(unary_fn, frac_stub);
|
| 52 |
+
DECLARE_DISPATCH(unary_fn, frexp_stub);
|
| 53 |
+
DECLARE_DISPATCH(unary_fn, i0_stub);
|
| 54 |
+
DECLARE_DISPATCH(unary_fn, special_i0e_stub);
|
| 55 |
+
DECLARE_DISPATCH(unary_fn, special_i1_stub);
|
| 56 |
+
DECLARE_DISPATCH(unary_fn, special_i1e_stub);
|
| 57 |
+
DECLARE_DISPATCH(unary_fn, log_stub);
|
| 58 |
+
DECLARE_DISPATCH(unary_fn, log10_stub);
|
| 59 |
+
DECLARE_DISPATCH(unary_fn, log1p_stub);
|
| 60 |
+
DECLARE_DISPATCH(unary_fn, log2_stub);
|
| 61 |
+
DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
|
| 62 |
+
DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
|
| 63 |
+
DECLARE_DISPATCH(unary_fn, neg_stub);
|
| 64 |
+
|
| 65 |
+
DECLARE_DISPATCH(unary_fn, reciprocal_stub);
|
| 66 |
+
DECLARE_DISPATCH(unary_fn, round_stub);
|
| 67 |
+
DECLARE_DISPATCH(unary_fn, rsqrt_stub);
|
| 68 |
+
DECLARE_DISPATCH(unary_fn, sigmoid_stub);
|
| 69 |
+
DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
|
| 70 |
+
DECLARE_DISPATCH(unary_fn, sign_stub);
|
| 71 |
+
DECLARE_DISPATCH(unary_fn, signbit_stub);
|
| 72 |
+
DECLARE_DISPATCH(unary_fn, sgn_stub);
|
| 73 |
+
DECLARE_DISPATCH(unary_fn, sin_stub);
|
| 74 |
+
DECLARE_DISPATCH(unary_fn, sinc_stub);
|
| 75 |
+
DECLARE_DISPATCH(unary_fn, sinh_stub);
|
| 76 |
+
DECLARE_DISPATCH(unary_fn, sqrt_stub);
|
| 77 |
+
DECLARE_DISPATCH(unary_fn, tan_stub);
|
| 78 |
+
DECLARE_DISPATCH(unary_fn, tanh_stub);
|
| 79 |
+
DECLARE_DISPATCH(unary_fn, trigamma_stub);
|
| 80 |
+
DECLARE_DISPATCH(unary_fn, trunc_stub);
|
| 81 |
+
DECLARE_DISPATCH(unary_fn, lgamma_stub);
|
| 82 |
+
DECLARE_DISPATCH(unary_fn, special_airy_ai_stub);
|
| 83 |
+
DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
|
| 84 |
+
DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
|
| 85 |
+
DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
|
| 86 |
+
DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
|
| 87 |
+
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
|
| 88 |
+
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
|
| 89 |
+
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
|
| 90 |
+
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
|
| 91 |
+
DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub);
|
| 92 |
+
DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub);
|
| 93 |
+
DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub);
|
| 94 |
+
|
| 95 |
+
// NB: these are actually defined in Distribution
|
| 96 |
+
DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, c10::optional<Generator>), bernoulli_tensor_stub);
|
| 97 |
+
DECLARE_DISPATCH(void(*)(const TensorBase&, const double, c10::optional<Generator>), bernoulli_scalar_stub);
|
| 98 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional<Generator>), cauchy_stub);
|
| 99 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, c10::optional<Generator>), exponential_stub);
|
| 100 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, c10::optional<Generator>), geometric_stub);
|
| 101 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional<Generator>), log_normal_stub);
|
| 102 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional<Generator>), uniform_stub);
|
| 103 |
+
DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, c10::optional<Generator>), normal_stub);
|
| 104 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, c10::optional<Generator>), random_from_to_stub);
|
| 105 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, c10::optional<Generator>), random_full_64_bits_range_stub);
|
| 106 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, c10::optional<Generator>), random_stub);
|
| 107 |
+
|
| 108 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub);
|
| 109 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub);
|
| 110 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub);
|
| 111 |
+
DECLARE_DISPATCH(
|
| 112 |
+
void (*)(Tensor&, const Tensor&, int64_t, c10::optional<Generator>),
|
| 113 |
+
multinomial_with_replacement_stub);
|
| 114 |
+
DECLARE_DISPATCH(
|
| 115 |
+
void (*)(
|
| 116 |
+
TensorIteratorBase&,
|
| 117 |
+
c10::optional<double>,
|
| 118 |
+
c10::optional<double>,
|
| 119 |
+
c10::optional<double>),
|
| 120 |
+
nan_to_num_stub);
|
| 121 |
+
DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub);
|
| 122 |
+
|
| 123 |
+
// Missing unary functions
|
| 124 |
+
// digamma
|
| 125 |
+
// lgamma
|
| 126 |
+
// erfinv
|
| 127 |
+
// clone
|
| 128 |
+
// contiguous
|
| 129 |
+
// zero
|
| 130 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef ATOMIC_ADD_FLOAT
|
| 2 |
+
#define ATOMIC_ADD_FLOAT
|
| 3 |
+
|
| 4 |
+
#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__))
|
| 5 |
+
#include <ATen/native/cpu/Intrinsics.h>
|
| 6 |
+
#else
|
| 7 |
+
#define _mm_pause()
|
| 8 |
+
#endif
|
| 9 |
+
|
| 10 |
+
#include <atomic>
|
| 11 |
+
|
| 12 |
+
static inline void cpu_atomic_add_float(float* dst, float fvalue)
|
| 13 |
+
{
|
| 14 |
+
typedef union {
|
| 15 |
+
unsigned intV;
|
| 16 |
+
float floatV;
|
| 17 |
+
} uf32_t;
|
| 18 |
+
|
| 19 |
+
uf32_t new_value, old_value;
|
| 20 |
+
std::atomic<unsigned>* dst_intV = (std::atomic<unsigned>*)(dst);
|
| 21 |
+
|
| 22 |
+
old_value.floatV = *dst;
|
| 23 |
+
new_value.floatV = old_value.floatV + fvalue;
|
| 24 |
+
|
| 25 |
+
unsigned* old_intV = (unsigned*)(&old_value.intV);
|
| 26 |
+
while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) {
|
| 27 |
+
#ifdef __aarch64__
|
| 28 |
+
__asm__ __volatile__("yield;" : : : "memory");
|
| 29 |
+
#else
|
| 30 |
+
_mm_pause();
|
| 31 |
+
#endif
|
| 32 |
+
old_value.floatV = *dst;
|
| 33 |
+
new_value.floatV = old_value.floatV + fvalue;
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
#endif
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <ATen/core/IListRef.h>
|
| 6 |
+
|
| 7 |
+
namespace at { namespace native {
|
| 8 |
+
|
| 9 |
+
using cat_serial_fn = void(*)(const Tensor &, const MaterializedITensorListRef&, int64_t);
|
| 10 |
+
DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub);
|
| 11 |
+
|
| 12 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/DispatchStub.h>
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class TensorBase;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at { namespace native {
|
| 10 |
+
|
| 11 |
+
using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
|
| 12 |
+
DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel);
|
| 13 |
+
|
| 14 |
+
}} // at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <c10/util/ArrayRef.h>
|
| 5 |
+
|
| 6 |
+
/*
|
| 7 |
+
Depthwise 3x3 Winograd convolution operator
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
class Tensor;
|
| 12 |
+
|
| 13 |
+
namespace native {
|
| 14 |
+
|
| 15 |
+
using convolution_depthwise3x3_winograd_fn =
|
| 16 |
+
Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t);
|
| 17 |
+
|
| 18 |
+
DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub);
|
| 19 |
+
|
| 20 |
+
} // namespace native
|
| 21 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__))
|
| 4 |
+
/* Clang-compatible compiler, targeting x86/x86-64 */
|
| 5 |
+
#include <x86intrin.h>
|
| 6 |
+
#elif defined(_MSC_VER)
|
| 7 |
+
/* Microsoft C/C++-compatible compiler */
|
| 8 |
+
#include <intrin.h>
|
| 9 |
+
#if _MSC_VER <= 1900
|
| 10 |
+
#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y])
|
| 11 |
+
#endif
|
| 12 |
+
#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
|
| 13 |
+
/* GCC-compatible compiler, targeting x86/x86-64 */
|
| 14 |
+
#include <x86intrin.h>
|
| 15 |
+
#elif defined(__GNUC__) && defined(__ARM_NEON__)
|
| 16 |
+
/* GCC-compatible compiler, targeting ARM with NEON */
|
| 17 |
+
#include <arm_neon.h>
|
| 18 |
+
#elif defined(__GNUC__) && defined(__IWMMXT__)
|
| 19 |
+
/* GCC-compatible compiler, targeting ARM with WMMX */
|
| 20 |
+
#include <mmintrin.h>
|
| 21 |
+
#elif (defined(__GNUC__) || defined(__xlC__)) && \
|
| 22 |
+
(defined(__VEC__) || defined(__ALTIVEC__))
|
| 23 |
+
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
|
| 24 |
+
#include <altivec.h>
|
| 25 |
+
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
|
| 26 |
+
with the C++ types. => Can still use __bool/__vector */
|
| 27 |
+
#undef bool
|
| 28 |
+
#undef vector
|
| 29 |
+
#undef pixel
|
| 30 |
+
#elif defined(__GNUC__) && defined(__SPE__)
|
| 31 |
+
/* GCC-compatible compiler, targeting PowerPC with SPE */
|
| 32 |
+
#include <spe.h>
|
| 33 |
+
#endif
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// This file provides two functions to help write elementwise kernels:
|
| 4 |
+
//
|
| 5 |
+
// cpu_kernel(TensorIterator iter, <lambda>)
|
| 6 |
+
// cpu_kernel_vec(TensorIterator iter, <lambda>, <vec_lambda>)
|
| 7 |
+
//
|
| 8 |
+
// Both functions may generate vectorized code. The cpu_kernel implementation
|
| 9 |
+
// relies on the compiler's auto-vectorization. The cpu_kernel_vec
|
| 10 |
+
// implementation uses x86 SIMD intrinsics when available. These functions
|
| 11 |
+
// are only intended to be used in the ATen/native/cpu subdirectory, since files
|
| 12 |
+
// in other directories are not compiled with AVX/AVX2 enabled. See README.md
|
| 13 |
+
// for more details.
|
| 14 |
+
//
|
| 15 |
+
// For example, to write a multiplication kernel for float:
|
| 16 |
+
//
|
| 17 |
+
// cpu_kernel(iter, [](float a, float b) { return a * b; });
|
| 18 |
+
//
|
| 19 |
+
// Or you may write:
|
| 20 |
+
//
|
| 21 |
+
// cpu_kernel_vec(iter,
|
| 22 |
+
// [](float a, float b) { return a * b; },
|
| 23 |
+
// [](Vectorized<float> a, Vectorized<float> b) { return a * b; });
|
| 24 |
+
//
|
| 25 |
+
// See BinaryOpsKernel.cpp for the complete implementation
|
| 26 |
+
//
|
| 27 |
+
//
|
| 28 |
+
|
| 29 |
+
#include <stdint.h>
|
| 30 |
+
#include <c10/util/C++17.h>
|
| 31 |
+
#include <c10/util/Load.h>
|
| 32 |
+
#include <c10/util/irange.h>
|
| 33 |
+
#include <ATen/detail/FunctionTraits.h>
|
| 34 |
+
#include <ATen/native/cpu/IsContiguous.h>
|
| 35 |
+
#include <ATen/native/TensorIterator.h>
|
| 36 |
+
#include <ATen/native/TensorIteratorDynamicCasting.h>
|
| 37 |
+
#include <ATen/cpu/vec/vec.h>
|
| 38 |
+
|
| 39 |
+
#include <utility>
|
| 40 |
+
|
| 41 |
+
namespace at { namespace native { inline namespace CPU_CAPABILITY {
|
| 42 |
+
|
| 43 |
+
using namespace vec;
|
| 44 |
+
|
| 45 |
+
template <typename traits, std::size_t... INDEX>
|
| 46 |
+
typename traits::ArgsTuple
|
| 47 |
+
dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
|
| 48 |
+
std::index_sequence<INDEX...>) {
|
| 49 |
+
return std::make_tuple(
|
| 50 |
+
c10::load<typename traits::template arg<INDEX>::type>(
|
| 51 |
+
data[INDEX] + i * strides[INDEX])...);
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
template <typename traits>
|
| 55 |
+
typename traits::ArgsTuple
|
| 56 |
+
dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) {
|
| 57 |
+
using Indices = std::make_index_sequence<traits::arity>;
|
| 58 |
+
return dereference_impl<traits>(data, strides, i, Indices{});
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
template <typename traits, std::size_t... INDEX>
|
| 62 |
+
typename traits::ArgsTuple
|
| 63 |
+
dereference_vec_impl(char* C10_RESTRICT data[],
|
| 64 |
+
const typename traits::result_type& opt_scalar,
|
| 65 |
+
size_t S,
|
| 66 |
+
int64_t i,
|
| 67 |
+
std::index_sequence<INDEX...>) {
|
| 68 |
+
using Vec = typename traits::result_type;
|
| 69 |
+
using scalar_t = typename Vec::value_type;
|
| 70 |
+
return std::make_tuple(
|
| 71 |
+
S == INDEX + 1 ?
|
| 72 |
+
opt_scalar :
|
| 73 |
+
Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...);
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
template <typename traits>
|
| 77 |
+
typename traits::ArgsTuple
|
| 78 |
+
dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) {
|
| 79 |
+
using Indices = std::make_index_sequence<traits::arity>;
|
| 80 |
+
return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
template <typename func_t,
|
| 84 |
+
typename std::enable_if<!std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
|
| 85 |
+
static inline void
|
| 86 |
+
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
|
| 87 |
+
using traits = function_traits<func_t>;
|
| 88 |
+
using result_type = typename traits::result_type;
|
| 89 |
+
for (; i < n; i++) {
|
| 90 |
+
result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
|
| 91 |
+
*out_ptr = c10::guts::apply(std::forward<func_t>(op), dereference<traits>(
|
| 92 |
+
&data[1],
|
| 93 |
+
&strides[1],
|
| 94 |
+
i));
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
template <typename func_t,
|
| 99 |
+
typename std::enable_if<std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
|
| 100 |
+
static inline void
|
| 101 |
+
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
|
| 102 |
+
using traits = function_traits<func_t>;
|
| 103 |
+
for (; i < n; i++) {
|
| 104 |
+
c10::guts::apply(std::forward<func_t>(op), dereference<traits>(
|
| 105 |
+
&data[0],
|
| 106 |
+
&strides[0],
|
| 107 |
+
i));
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
// Basic loop operation (one output, N inputs). May be auto-vectorized
|
| 112 |
+
// by the compiler. Supports inputs and outputs of different types.
|
| 113 |
+
template <typename func_t>
|
| 114 |
+
static inline void
|
| 115 |
+
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
|
| 116 |
+
using traits = function_traits<func_t>;
|
| 117 |
+
constexpr int ntensors = traits::arity + 1;
|
| 118 |
+
|
| 119 |
+
// Copying strides to temporary array helps auto vectorization in older GCC
|
| 120 |
+
// versions.
|
| 121 |
+
int64_t strides[ntensors];
|
| 122 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 123 |
+
strides[arg] = strides_[arg];
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
execute_op(data, strides, i, n, std::forward<func_t>(op));
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// the recursive variadic template for iterating over the returned tuple
|
| 130 |
+
template<class T, size_t N>
|
| 131 |
+
struct TupleOutput {
|
| 132 |
+
static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
|
| 133 |
+
const T &tuple) {
|
| 134 |
+
TupleOutput<T, N - 1>::handle(data, strides, i, tuple);
|
| 135 |
+
|
| 136 |
+
auto output = std::get<N - 1>(tuple);
|
| 137 |
+
using output_type = decltype(output);
|
| 138 |
+
output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
|
| 139 |
+
*out_ptr = output;
|
| 140 |
+
}
|
| 141 |
+
};
|
| 142 |
+
|
| 143 |
+
// Base case for the above recursive template
|
| 144 |
+
template<class T>
|
| 145 |
+
struct TupleOutput<T, 1> {
|
| 146 |
+
static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
|
| 147 |
+
const T &tuple) {
|
| 148 |
+
auto output = std::get<0>(tuple);
|
| 149 |
+
using output_type = decltype(output);
|
| 150 |
+
output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
|
| 151 |
+
*out_ptr = output;
|
| 152 |
+
}
|
| 153 |
+
};
|
| 154 |
+
|
| 155 |
+
template<class... Args>
|
| 156 |
+
void handle_tuple_outputs(char* C10_RESTRICT data[],
|
| 157 |
+
const int64_t* strides,
|
| 158 |
+
int64_t i,
|
| 159 |
+
const std::tuple<Args...> &tuple) {
|
| 160 |
+
TupleOutput<decltype(tuple), sizeof...(Args)>::handle(data, strides, i, tuple);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
// Loop operation for `cpu_kernel_multiple_outputs`.
|
| 164 |
+
// 1. Use `c10::guts::apply` to make dynamic method invocation
|
| 165 |
+
// for the lambda passed in `cpu_kernel_multiple_outputs`.
|
| 166 |
+
// 2. Iterate over the members of the returned tuple, set the corresponding
|
| 167 |
+
// output tensor by the tuple member in `handle_tuple_outputs` function.
|
| 168 |
+
template <typename func_t>
|
| 169 |
+
static inline void
|
| 170 |
+
multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
|
| 171 |
+
using traits = function_traits<func_t>;
|
| 172 |
+
|
| 173 |
+
using result_type = typename traits::result_type;
|
| 174 |
+
constexpr int num_outputs = std::tuple_size<result_type>::value;
|
| 175 |
+
constexpr int ntensors = traits::arity + num_outputs;
|
| 176 |
+
|
| 177 |
+
// Copying strides to temporary array helps auto vectorization in older GCC
|
| 178 |
+
// versions.
|
| 179 |
+
int64_t strides[ntensors];
|
| 180 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 181 |
+
strides[arg] = strides_[arg];
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
for (; i < n; i++) {
|
| 185 |
+
auto output = c10::guts::apply(op, dereference<traits>(
|
| 186 |
+
&data[num_outputs],
|
| 187 |
+
&strides[num_outputs],
|
| 188 |
+
i));
|
| 189 |
+
handle_tuple_outputs(data, strides, i, output);
|
| 190 |
+
}
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
// Explicitly vectorized loop implementation. All inputs and outputs must be
|
| 194 |
+
// the same type and contiguous with one exception: a single input may be
|
| 195 |
+
// a scalar (stride 0). It's position is indicated by the argument `S`. If `S`
|
| 196 |
+
// is 0, then there are no scalar inputs.
|
| 197 |
+
template <typename func_t, typename vec_func_t>
|
| 198 |
+
static inline void
|
| 199 |
+
vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
|
| 200 |
+
using traits = function_traits<vec_func_t>;
|
| 201 |
+
using scalar_t = typename function_traits<func_t>::result_type;
|
| 202 |
+
using Vec = Vectorized<scalar_t>;
|
| 203 |
+
constexpr int ntensors = traits::arity + 1;
|
| 204 |
+
|
| 205 |
+
char* C10_RESTRICT data[ntensors];
|
| 206 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 207 |
+
data[arg] = data_[arg];
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
Vec opt_scalar = Vec(S > 0 ? *(scalar_t*)data[S] : scalar_t(0));
|
| 211 |
+
int64_t i = 0;
|
| 212 |
+
for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
|
| 213 |
+
auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
|
| 214 |
+
auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
|
| 215 |
+
auto out1 = c10::guts::apply(std::forward<vec_func_t>(vop), std::move(args1));
|
| 216 |
+
auto out2 = c10::guts::apply(std::forward<vec_func_t>(vop), std::move(args2));
|
| 217 |
+
out1.store(data[0] + i * sizeof(scalar_t));
|
| 218 |
+
out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
|
| 219 |
+
}
|
| 220 |
+
if (i < n) {
|
| 221 |
+
int64_t strides[ntensors];
|
| 222 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 223 |
+
strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t);
|
| 224 |
+
}
|
| 225 |
+
basic_loop(data, strides, i, n, std::forward<func_t>(op));
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
template <typename traits, typename cb_t>
|
| 231 |
+
static inline void unroll_contiguous_scalar_checks(
|
| 232 |
+
const int64_t* /*strides*/,
|
| 233 |
+
std::index_sequence<>,
|
| 234 |
+
cb_t&& cb) {
|
| 235 |
+
cb(0);
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
|
| 239 |
+
static inline void unroll_contiguous_scalar_checks(
|
| 240 |
+
const int64_t* strides,
|
| 241 |
+
std::index_sequence<INDEX0, INDEX...>,
|
| 242 |
+
cb_t&& cb) {
|
| 243 |
+
if (is_contiguous_scalar<traits, INDEX0 + 1>(strides)) {
|
| 244 |
+
cb(INDEX0 + 1);
|
| 245 |
+
} else {
|
| 246 |
+
unroll_contiguous_scalar_checks<traits>(strides, std::index_sequence<INDEX...>{}, std::forward<cb_t>(cb));
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
template <typename op_t, typename vop_t>
|
| 251 |
+
struct VectorizedLoop2d {
|
| 252 |
+
op_t op;
|
| 253 |
+
vop_t vop;
|
| 254 |
+
|
| 255 |
+
using traits = function_traits<op_t>;
|
| 256 |
+
static constexpr int ntensors = traits::arity + 1;
|
| 257 |
+
using data_t = std::array<char*, ntensors>;
|
| 258 |
+
|
| 259 |
+
VectorizedLoop2d(const op_t &op, vop_t vop):
|
| 260 |
+
op(op), vop(std::move(vop)) {}
|
| 261 |
+
|
| 262 |
+
static void advance(data_t &data, const int64_t *outer_strides) {
|
| 263 |
+
for (const auto arg : c10::irange(data.size())) {
|
| 264 |
+
data[arg] += outer_strides[arg];
|
| 265 |
+
}
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) {
|
| 269 |
+
data_t data;
|
| 270 |
+
std::copy_n(base, ntensors, data.data());
|
| 271 |
+
const int64_t *outer_strides = &strides[ntensors];
|
| 272 |
+
|
| 273 |
+
if (is_contiguous<traits>(strides)) {
|
| 274 |
+
for (const auto i C10_UNUSED : c10::irange(size1)) {
|
| 275 |
+
vectorized_loop(data.data(), size0, 0, op, vop);
|
| 276 |
+
advance(data, outer_strides);
|
| 277 |
+
}
|
| 278 |
+
} else {
|
| 279 |
+
using Indices = std::make_index_sequence<traits::arity>;
|
| 280 |
+
unroll_contiguous_scalar_checks<traits>(strides, Indices{}, [&](size_t idx) {
|
| 281 |
+
if (idx) {
|
| 282 |
+
for (const auto i C10_UNUSED : c10::irange(size1)) {
|
| 283 |
+
vectorized_loop(data.data(), size0, idx, op, vop);
|
| 284 |
+
advance(data, outer_strides);
|
| 285 |
+
}
|
| 286 |
+
} else {
|
| 287 |
+
for (const auto i C10_UNUSED : c10::irange(size1)) {
|
| 288 |
+
basic_loop(data.data(), strides, 0, size0, op);
|
| 289 |
+
advance(data, outer_strides);
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
});
|
| 293 |
+
}
|
| 294 |
+
}
|
| 295 |
+
};
|
| 296 |
+
|
| 297 |
+
template <typename op_t, typename vop_t>
|
| 298 |
+
VectorizedLoop2d<op_t, vop_t> make_vectorized_loop2d(
|
| 299 |
+
const op_t &op, const vop_t &vop) {
|
| 300 |
+
return VectorizedLoop2d<op_t, vop_t>(op, vop);
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
template <typename func_t>
|
| 304 |
+
void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
|
| 305 |
+
using traits = function_traits<func_t>;
|
| 306 |
+
// this could be extended to work with void return types
|
| 307 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
|
| 308 |
+
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
|
| 309 |
+
// dynamic casting not currently supported on CPU
|
| 310 |
+
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
|
| 311 |
+
|
| 312 |
+
iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
|
| 313 |
+
// basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that
|
| 314 |
+
// iter.for_each is ever sending to the loop lambda
|
| 315 |
+
basic_loop(data, strides, 0, n, std::forward<func_t>(op));
|
| 316 |
+
}, grain_size);
|
| 317 |
+
iter.cast_outputs();
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
// This function helps write elementwise kernels that requires multiple outputs.
|
| 321 |
+
// It follows the similar structure of cpu_kernel.
|
| 322 |
+
// Instead of `basic_loop` function, a new `multiple_outputs_loop` function is
|
| 323 |
+
// manipulated to handle multiple return values.
|
| 324 |
+
// For now `needs_dynamic_casting` check is not added as the passed lambda (`func_t`)
|
| 325 |
+
// of `multiple_outputs_loop` returns `std::tuple` instead of `scalar_t`.
|
| 326 |
+
// The `gpu_kernel_multiple_outputs` is also implemented without this check,
|
| 327 |
+
// We could extend `needs_dynamic_casting` to support both `std::tuple` and
|
| 328 |
+
// `thrust::tuple` in the future.
|
| 329 |
+
template <typename func_t>
|
| 330 |
+
void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
|
| 331 |
+
using traits = function_traits<func_t>;
|
| 332 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
|
| 333 |
+
|
| 334 |
+
iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
|
| 335 |
+
multiple_outputs_loop(data, strides, 0, n, std::forward<func_t>(op));
|
| 336 |
+
}, grain_size);
|
| 337 |
+
iter.cast_outputs();
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
|
| 341 |
+
void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) {
|
| 342 |
+
using traits = function_traits<func_t>;
|
| 343 |
+
// this could be extended to work with void return types
|
| 344 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
|
| 345 |
+
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
|
| 346 |
+
// dynamic casting not currently supported on CPU, but some kernels (like Fill)
|
| 347 |
+
// explicitly dynamic_cast, so we give the opt-out of checking.
|
| 348 |
+
if constexpr (check_dynamic_cast) {
|
| 349 |
+
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
iter.for_each(make_vectorized_loop2d(op, vop), grain_size);
|
| 353 |
+
iter.cast_outputs();
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
template <typename func_t>
|
| 357 |
+
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) {
|
| 358 |
+
using traits = function_traits<func_t>;
|
| 359 |
+
constexpr bool result_void = std::is_void<typename traits::result_type>::value;
|
| 360 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity &&
|
| 361 |
+
((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1)));
|
| 362 |
+
// dynamic casting not currently supported on CPU
|
| 363 |
+
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
|
| 364 |
+
|
| 365 |
+
iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
|
| 366 |
+
basic_loop(data, strides, 0, n, std::forward<func_t>(op));
|
| 367 |
+
}, range);
|
| 368 |
+
iter.cast_outputs();
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
template <typename func_t>
|
| 372 |
+
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) {
|
| 373 |
+
cpu_serial_kernel(iter, op, {0, iter.numel()});
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
template <typename func_t, typename vec_func_t>
|
| 377 |
+
void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) {
|
| 378 |
+
using traits = function_traits<func_t>;
|
| 379 |
+
// this could be extended to work with void return types
|
| 380 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
|
| 381 |
+
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
|
| 382 |
+
// dynamic casting not currently supported on CPU
|
| 383 |
+
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
|
| 384 |
+
|
| 385 |
+
iter.serial_for_each(make_vectorized_loop2d(op, vop), range);
|
| 386 |
+
iter.cast_outputs();
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
template <typename func_t, typename vec_func_t>
|
| 390 |
+
void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
|
| 391 |
+
cpu_serial_kernel_vec(iter, op, vop, {0, iter.numel()});
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
}}} // namespace at::native::<anonymous>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/DispatchStub.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
class Tensor;
|
| 6 |
+
|
| 7 |
+
namespace native {
|
| 8 |
+
|
| 9 |
+
using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
|
| 10 |
+
|
| 11 |
+
DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel);
|
| 12 |
+
DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel);
|
| 13 |
+
|
| 14 |
+
}} // at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Parallel.h>
|
| 4 |
+
#include <ATen/NumericUtils.h>
|
| 5 |
+
#include <ATen/cpu/vec/vec.h>
|
| 6 |
+
#include <ATen/cpu/vec/functional.h>
|
| 7 |
+
#include <ATen/native/ReductionType.h>
|
| 8 |
+
#include <c10/util/irange.h>
|
| 9 |
+
#include <ATen/OpMathType.h>
|
| 10 |
+
#include <ATen/native/cpu/utils.h>
|
| 11 |
+
#include <ATen/OpMathType.h>
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
inline namespace CPU_CAPABILITY {
|
| 15 |
+
|
| 16 |
+
using namespace vec;
|
| 17 |
+
|
| 18 |
+
#define AT_DISPATCH_REDUCTION_TYPES(op, ...) \
|
| 19 |
+
[&] { \
|
| 20 |
+
switch (op) { \
|
| 21 |
+
case ReductionType::SUM: { \
|
| 22 |
+
static constexpr auto reduce = ReductionType::SUM; \
|
| 23 |
+
return __VA_ARGS__(); \
|
| 24 |
+
} \
|
| 25 |
+
case ReductionType::MEAN: { \
|
| 26 |
+
static constexpr auto reduce = ReductionType::MEAN; \
|
| 27 |
+
return __VA_ARGS__(); \
|
| 28 |
+
} \
|
| 29 |
+
case ReductionType::MIN: { \
|
| 30 |
+
static constexpr auto reduce = ReductionType::MIN; \
|
| 31 |
+
return __VA_ARGS__(); \
|
| 32 |
+
} \
|
| 33 |
+
case ReductionType::MAX: { \
|
| 34 |
+
static constexpr auto reduce = ReductionType::MAX; \
|
| 35 |
+
return __VA_ARGS__(); \
|
| 36 |
+
} \
|
| 37 |
+
case ReductionType::PROD: { \
|
| 38 |
+
static constexpr auto reduce = ReductionType::PROD; \
|
| 39 |
+
return __VA_ARGS__(); \
|
| 40 |
+
} \
|
| 41 |
+
} \
|
| 42 |
+
}()
|
| 43 |
+
|
| 44 |
+
template <typename scalar_t, ReductionType reduce>
|
| 45 |
+
inline vec_scalar_t<scalar_t> init_value() {
|
| 46 |
+
using acc_t = vec_scalar_t<scalar_t>;
|
| 47 |
+
acc_t val;
|
| 48 |
+
if (reduce == ReductionType::SUM ||
|
| 49 |
+
reduce == ReductionType::MEAN) {
|
| 50 |
+
val = static_cast<acc_t>(0);
|
| 51 |
+
} else if (reduce == ReductionType::PROD) {
|
| 52 |
+
val = static_cast<acc_t>(1);
|
| 53 |
+
} else if (reduce == ReductionType::MAX) {
|
| 54 |
+
val = -std::numeric_limits<acc_t>::infinity();
|
| 55 |
+
} else {
|
| 56 |
+
TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
|
| 57 |
+
val = std::numeric_limits<acc_t>::infinity();
|
| 58 |
+
}
|
| 59 |
+
return val;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
template <typename scalar_t, ReductionType reduce>
|
| 63 |
+
inline vec_scalar_t<scalar_t> init_value(const c10::optional<Scalar>& initial) {
|
| 64 |
+
using acc_t = vec_scalar_t<scalar_t>;
|
| 65 |
+
if (initial.has_value()) {
|
| 66 |
+
return initial.value().to<acc_t>();
|
| 67 |
+
} else {
|
| 68 |
+
return init_value<scalar_t, reduce>();
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <typename scalar_t>
|
| 73 |
+
inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
|
| 74 |
+
using Vec = Vectorized<vec_scalar_t<scalar_t>>;
|
| 75 |
+
map<scalar_t>(
|
| 76 |
+
[val](Vec x) { return Vec(val); },
|
| 77 |
+
out,
|
| 78 |
+
out,
|
| 79 |
+
size);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
template <typename scalar_t, ReductionType reduce>
|
| 83 |
+
inline void init(scalar_t* out, int64_t size, const c10::optional<Scalar>& initial) {
|
| 84 |
+
using acc_t = vec_scalar_t<scalar_t>;
|
| 85 |
+
acc_t val = init_value<scalar_t, reduce>(initial);
|
| 86 |
+
init(out, size, val);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
// overload with `include_self`, used by scatter_reduce
|
| 90 |
+
template <typename scalar_t, ReductionType reduce>
|
| 91 |
+
inline void init(scalar_t* out, int64_t size, bool include_self = false) {
|
| 92 |
+
using acc_t = vec_scalar_t<scalar_t>;
|
| 93 |
+
if (!include_self) {
|
| 94 |
+
acc_t val = init_value<scalar_t, reduce>();
|
| 95 |
+
init(out, size, val);
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
template <typename scalar_t, ReductionType reduce>
|
| 100 |
+
inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
|
| 101 |
+
if (!include_self) {
|
| 102 |
+
init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
|
| 103 |
+
} else {
|
| 104 |
+
vec::convert(self_ptr, buffer_ptr, size);
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
template <typename scalar_t>
|
| 109 |
+
inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
|
| 110 |
+
_max(const scalar_t& x, const scalar_t& y) {
|
| 111 |
+
return at::_isnan(y) ? y : std::max(x, y);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
template <typename scalar_t>
|
| 115 |
+
inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
|
| 116 |
+
// vec::maximum propagates NaN
|
| 117 |
+
return vec::maximum(x, y);
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
template <typename vec_t>
|
| 121 |
+
inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
|
| 122 |
+
_max(const vec_t& x, const vec_t& y) {
|
| 123 |
+
// vec::maximum propagates NaN
|
| 124 |
+
return maximum(x, y);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
template <typename scalar_t>
|
| 128 |
+
inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
|
| 129 |
+
_min(const scalar_t& x, const scalar_t& y) {
|
| 130 |
+
return at::_isnan(y) ? y : std::min(x, y);
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
template <typename scalar_t>
|
| 134 |
+
inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
|
| 135 |
+
// vec::minimum propagates NaN
|
| 136 |
+
return vec::minimum(x, y);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
template <typename vec_t>
|
| 140 |
+
inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
|
| 141 |
+
_min(const vec_t& x, const vec_t& y) {
|
| 142 |
+
// vec::minimum propagates NaN
|
| 143 |
+
return minimum(x, y);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
template <typename scalar_t, typename accumut, typename Op,
|
| 147 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 148 |
+
inline void map_acc(
|
| 149 |
+
const Op& vec_fun,
|
| 150 |
+
accumut* output_data,
|
| 151 |
+
const accumut* input_data,
|
| 152 |
+
const scalar_t* input_data2,
|
| 153 |
+
int64_t size) {
|
| 154 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 155 |
+
using aVec = vec::Vectorized<accumut>;
|
| 156 |
+
int64_t d = 0;
|
| 157 |
+
constexpr int64_t kVecSize = Vec::size();
|
| 158 |
+
constexpr int64_t kaVecSize = aVec::size();
|
| 159 |
+
for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
|
| 160 |
+
Vec data2_vec = Vec::loadu(input_data2 + d);
|
| 161 |
+
auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
|
| 162 |
+
aVec input_vec0 = aVec::loadu(input_data + d);
|
| 163 |
+
aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
|
| 164 |
+
vec_fun(input_vec0, data2_avec0).store(output_data + d);
|
| 165 |
+
vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
|
| 166 |
+
}
|
| 167 |
+
if (size - d > 0) {
|
| 168 |
+
int64_t tail_size = size - d;
|
| 169 |
+
Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
|
| 170 |
+
auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
|
| 171 |
+
if (tail_size > kaVecSize) {
|
| 172 |
+
aVec input_vec0 = aVec::loadu(input_data + d);
|
| 173 |
+
aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
|
| 174 |
+
vec_fun(input_vec0, data2_avec0).store(output_data + d);
|
| 175 |
+
vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
|
| 176 |
+
} else {
|
| 177 |
+
aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
|
| 178 |
+
vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
// for Max and Min, propagate NaN:
|
| 184 |
+
template <typename T, ReductionType reduce>
|
| 185 |
+
inline T update(const T& x, const T& y) {
|
| 186 |
+
if (reduce == ReductionType::SUM ||
|
| 187 |
+
reduce == ReductionType::MEAN) {
|
| 188 |
+
return x + y;
|
| 189 |
+
} else if (reduce == ReductionType::PROD) {
|
| 190 |
+
return x * y;
|
| 191 |
+
} else if (reduce == ReductionType::MAX) {
|
| 192 |
+
return _max(x, y);
|
| 193 |
+
} else {
|
| 194 |
+
TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
|
| 195 |
+
return _min(x, y);
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
template <typename scalar_t, ReductionType reduce>
|
| 200 |
+
inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
|
| 201 |
+
using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
|
| 202 |
+
map2<scalar_t>(
|
| 203 |
+
[](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
|
| 204 |
+
out,
|
| 205 |
+
out,
|
| 206 |
+
data,
|
| 207 |
+
K);
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
template <typename scalar_t, ReductionType reduce,
|
| 211 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 212 |
+
inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
|
| 213 |
+
using opmath_t = at::opmath_type<scalar_t>;
|
| 214 |
+
using Vec = vec::Vectorized<opmath_t>;
|
| 215 |
+
map_acc<scalar_t, opmath_t>(
|
| 216 |
+
[](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
|
| 217 |
+
out,
|
| 218 |
+
out,
|
| 219 |
+
data,
|
| 220 |
+
K);
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
template <typename scalar_t, ReductionType reduce>
|
| 224 |
+
inline void write(scalar_t* out, int64_t count, int64_t K) {
|
| 225 |
+
using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
|
| 226 |
+
if (reduce == ReductionType::MEAN) {
|
| 227 |
+
if (count > 0) {
|
| 228 |
+
vec::map<scalar_t>(
|
| 229 |
+
[count](Vec x) { return x / Vec(count); },
|
| 230 |
+
out,
|
| 231 |
+
out,
|
| 232 |
+
K);
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
} // namespace CPU_CAPABILITY
|
| 238 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <ATen/native/ReductionType.h>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
|
| 9 |
+
using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 10 |
+
using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 11 |
+
using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 12 |
+
using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 13 |
+
using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 14 |
+
|
| 15 |
+
DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub);
|
| 16 |
+
DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub);
|
| 17 |
+
DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub);
|
| 18 |
+
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub);
|
| 19 |
+
DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub);
|
| 20 |
+
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub);
|
| 21 |
+
|
| 22 |
+
} // at::native
|