File size: 7,630 Bytes
67e9774
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import itertools
import logging
from typing import Any, Callable

import torch
import torch._inductor.config as config
from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
    Buffer,
    get_free_symbols,
    get_symbolic_inputs,
    gm_original_output_strides,
    ir_node_to_tensor,
    Layout,
)
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.utils import do_bench_using_profiling
from torch._inductor.virtualized import V


log = logging.getLogger(__name__)


class SubgraphChoiceCaller(ir.ChoiceCaller):
    """

    Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary

    GraphModule. Compiles the Subgraph down to a module for benchmarking.

    """

    def __init__(

        self,

        name: str,

        input_nodes: list[Buffer],

        layout: Layout,

        description: str,

        make_fx_graph: Callable[..., Any],

    ) -> None:
        super().__init__(name, input_nodes, layout, description)

        self.example_inputs = []
        with V.fake_mode:
            for inp in self.input_nodes:
                # Here there will be no unbacked symbols, as SubgraphBuffer does not support them
                assert len(get_free_symbols(inp.get_size(), unbacked_only=True)) == 0
                assert len(get_free_symbols(inp.get_stride(), unbacked_only=True)) == 0

                inp.data.freeze_layout()  # type: ignore[attr-defined]
                self.example_inputs.append(ir_node_to_tensor(inp))

        self.gm = make_fx_graph(*self.example_inputs)
        gm_original_output_strides(self.gm)

        self.sym_inputs = get_symbolic_inputs(self.input_nodes)

    def __str__(self) -> str:
        return f"SubgraphCaller({self.name})"

    def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
        # Codegen Subgraph for benchmarking
        # Need GraphLowering instead of SubgraphLowering to generate
        # fully callable module
        import torch._inductor.config as inductor_config
        from torch._inductor.graph import GraphLowering

        bm_graph_lowering = GraphLowering(
            gm=self.gm,
            example_inputs=self.example_inputs,
            shape_env=V.graph._shape_env,
            cpp_wrapper=V.graph.cpp_wrapper,
            aot_mode=V.graph.aot_mode,
            extern_node_serializer=V.graph.extern_node_serializer,
            is_inference=V.graph.is_inference,
            is_backward=V.graph.is_backward,
            name=f"benchmark_{self.name}",
        )

        for sym_inp in self.sym_inputs:
            bm_graph_lowering.graph_inputs[sym_inp.name] = sym_inp
            bm_graph_lowering.graph_input_names.append(sym_inp.name)

        sym_inputs = [
            int(V.graph.sizevars.shape_env.size_hint(sym_var))
            for sym_var in self.sym_inputs
        ]

        if len(sym_inputs) == 0:
            # Sanity check that args are same layout as example inputs
            # Only do it if there are no symbolic inputs, otherwise
            # the dynamic dim will be realized to the same size as args
            for ar, example_inp in zip(args, self.example_inputs):
                # Sanity check that args are same layout as example inputs
                if isinstance(ar, torch.Tensor):
                    assert isinstance(example_inp, torch.Tensor)
                    assert ar.shape == example_inp.shape
                    assert ar.stride() == example_inp.stride()

        if len(sym_inputs) == 0:
            # Sanity check that args are same layout as example inputs
            # Only do it if there are no symbolic inputs, otherwise
            # the dynamic dim will be realized to the same size as args
            for ar, example_inp in zip(args, self.example_inputs):
                # Sanity check that args are same layout as example inputs
                if isinstance(ar, torch.Tensor):
                    assert isinstance(example_inp, torch.Tensor)
                    assert ar.shape == example_inp.shape
                    assert ar.stride() == example_inp.stride()

        with V.set_graph_handler(bm_graph_lowering):
            # Don't bother autotuning on Triton here
            with inductor_config.patch(
                max_autotune=False,
                max_autotune_gemm=False,
                max_autotune_gemm_backends="ATEN",
            ):
                bm_graph_lowering.run(*self.example_inputs)
                mod = bm_graph_lowering.compile_to_module()
                bm_func = mod.call

                bm_func([*sym_inputs, *args])
        if config.profile_bandwidth_with_do_bench_using_profiling:
            return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
        return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))

    def hash_key(self) -> str:
        return "-".join(
            [
                self.name.rsplit("_", 1)[0],
                *[str(inp.get_size()) for inp in self.input_nodes],
                *[str(inp.get_stride()) for inp in self.input_nodes],
                str(self.gm.graph),
            ]
        )

    def output_node(self) -> ir.TensorBox:
        return ir.TensorBox.create(
            ir.SubgraphBuffer(
                layout=self.layout,
                input_nodes=self.input_nodes,
                gm=self.gm,
                example_inputs=self.example_inputs,
                subgraph_name=self.name,
            )
        )

    def info_dict(self) -> dict[str, Any]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "subgraph",
            "kernel_name": self.name,
        }

    def autoheuristic_id(self) -> str:
        return f"subgraph_{self.name}"


class SubgraphTemplate(KernelTemplate):
    """
    Kernel template whose autotuning choices are subgraphs.

    An instance wraps a ``make_fx_graph`` callable. ``generate`` turns it
    into a ``SubgraphChoiceCaller`` that can be added to the set of choices
    considered during autotuning.
    """

    # Class-level counter: each new template takes the next value as its
    # name suffix.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        make_fx_graph: Callable[..., Any],
    ):
        """
        Initialize a subgraph template.

        Args:
            name: Base name of this template; the next counter value is
                appended as ``_<n>``.
            make_fx_graph: Callable passed on to ``SubgraphChoiceCaller``
                when a choice is generated.
        """
        suffix = next(SubgraphTemplate.index_counter)
        self.name = f"{name}_{suffix}"
        self.make_fx_graph = make_fx_graph

    def generate(  # type: ignore[override]
        self,
        input_nodes: list[Buffer],
        layout: Layout,
        **kwargs: Any,
    ) -> SubgraphChoiceCaller:
        """
        Create the ``SubgraphChoiceCaller`` for this template.

        Args:
            input_nodes: List of input nodes to the subgraph
            layout: Memory layout information for the output
            **kwargs: Accepted but not used.

        Returns:
            SubgraphChoiceCaller: The choice, named after this template and
            created with an empty description.
        """
        choice = SubgraphChoiceCaller(
            make_fx_graph=self.make_fx_graph,
            description="",
            layout=layout,
            input_nodes=input_nodes,
            name=self.name,
        )
        return choice