File size: 7,630 Bytes
67e9774 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import itertools
import logging
from typing import Any, Callable
import torch
import torch._inductor.config as config
from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
Buffer,
get_free_symbols,
get_symbolic_inputs,
gm_original_output_strides,
ir_node_to_tensor,
Layout,
)
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.utils import do_bench_using_profiling
from torch._inductor.virtualized import V
log = logging.getLogger(__name__)
class SubgraphChoiceCaller(ir.ChoiceCaller):
"""
Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary
GraphModule. Compiles the Subgraph down to a module for benchmarking.
"""
def __init__(
self,
name: str,
input_nodes: list[Buffer],
layout: Layout,
description: str,
make_fx_graph: Callable[..., Any],
) -> None:
super().__init__(name, input_nodes, layout, description)
self.example_inputs = []
with V.fake_mode:
for inp in self.input_nodes:
# Here there will be no unbacked symbols, as SubgraphBuffer does not support them
assert len(get_free_symbols(inp.get_size(), unbacked_only=True)) == 0
assert len(get_free_symbols(inp.get_stride(), unbacked_only=True)) == 0
inp.data.freeze_layout() # type: ignore[attr-defined]
self.example_inputs.append(ir_node_to_tensor(inp))
self.gm = make_fx_graph(*self.example_inputs)
gm_original_output_strides(self.gm)
self.sym_inputs = get_symbolic_inputs(self.input_nodes)
def __str__(self) -> str:
return f"SubgraphCaller({self.name})"
def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
# Codegen Subgraph for benchmarking
# Need GraphLowering instead of SubgraphLowering to generate
# fully callable module
import torch._inductor.config as inductor_config
from torch._inductor.graph import GraphLowering
bm_graph_lowering = GraphLowering(
gm=self.gm,
example_inputs=self.example_inputs,
shape_env=V.graph._shape_env,
cpp_wrapper=V.graph.cpp_wrapper,
aot_mode=V.graph.aot_mode,
extern_node_serializer=V.graph.extern_node_serializer,
is_inference=V.graph.is_inference,
is_backward=V.graph.is_backward,
name=f"benchmark_{self.name}",
)
for sym_inp in self.sym_inputs:
bm_graph_lowering.graph_inputs[sym_inp.name] = sym_inp
bm_graph_lowering.graph_input_names.append(sym_inp.name)
sym_inputs = [
int(V.graph.sizevars.shape_env.size_hint(sym_var))
for sym_var in self.sym_inputs
]
if len(sym_inputs) == 0:
# Sanity check that args are same layout as example inputs
# Only do it if there are no symbolic inputs, otherwise
# the dynamic dim will be realized to the same size as args
for ar, example_inp in zip(args, self.example_inputs):
# Sanity check that args are same layout as example inputs
if isinstance(ar, torch.Tensor):
assert isinstance(example_inp, torch.Tensor)
assert ar.shape == example_inp.shape
assert ar.stride() == example_inp.stride()
if len(sym_inputs) == 0:
# Sanity check that args are same layout as example inputs
# Only do it if there are no symbolic inputs, otherwise
# the dynamic dim will be realized to the same size as args
for ar, example_inp in zip(args, self.example_inputs):
# Sanity check that args are same layout as example inputs
if isinstance(ar, torch.Tensor):
assert isinstance(example_inp, torch.Tensor)
assert ar.shape == example_inp.shape
assert ar.stride() == example_inp.stride()
with V.set_graph_handler(bm_graph_lowering):
# Don't bother autotuning on Triton here
with inductor_config.patch(
max_autotune=False,
max_autotune_gemm=False,
max_autotune_gemm_backends="ATEN",
):
bm_graph_lowering.run(*self.example_inputs)
mod = bm_graph_lowering.compile_to_module()
bm_func = mod.call
bm_func([*sym_inputs, *args])
if config.profile_bandwidth_with_do_bench_using_profiling:
return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
def hash_key(self) -> str:
return "-".join(
[
self.name.rsplit("_", 1)[0],
*[str(inp.get_size()) for inp in self.input_nodes],
*[str(inp.get_stride()) for inp in self.input_nodes],
str(self.gm.graph),
]
)
def output_node(self) -> ir.TensorBox:
return ir.TensorBox.create(
ir.SubgraphBuffer(
layout=self.layout,
input_nodes=self.input_nodes,
gm=self.gm,
example_inputs=self.example_inputs,
subgraph_name=self.name,
)
)
def info_dict(self) -> dict[str, Any]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return {
"backend": "subgraph",
"kernel_name": self.name,
}
def autoheuristic_id(self) -> str:
return f"subgraph_{self.name}"
class SubgraphTemplate(KernelTemplate):
    """
    A template for subgraph evaluation to be used in autotuning.

    Customized subgraphs built from this template can be appended as choices
    during the autotuning process, letting the tuner pick the best
    implementation for a complex operation.
    """

    # Class-wide counter; every instance gets a unique numeric name suffix.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        make_fx_graph: Callable[..., Any],
    ):
        """
        Initialize a subgraph template.

        Args:
            name: The name of this template
            graph: The FX graph
        """
        suffix = next(SubgraphTemplate.index_counter)
        self.name = f"{name}_{suffix}"
        self.make_fx_graph = make_fx_graph

    def generate(  # type: ignore[override]
        self,
        input_nodes: list[Buffer],
        layout: Layout,
        **kwargs: Any,
    ) -> SubgraphChoiceCaller:
        """
        Generate a SubgraphChoiceCaller instance for autotuning.

        Args:
            input_nodes: List of input nodes to the subgraph
            layout: Memory layout information for the output
            **kwargs: Additional keyword arguments (accepted for interface
                compatibility; unused here)

        Returns:
            SubgraphChoiceCaller: A callable object that can be used for autotuning
        """
        return SubgraphChoiceCaller(
            self.name,
            input_nodes,
            layout,
            "",
            self.make_fx_graph,
        )
|