File size: 7,894 Bytes
59f1501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# mypy: ignore-errors


import operator
from typing import Callable

import sympy

import torch
import torch.fx as fx
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten


# Convenience alias for the ATen operator namespace used throughout this file.
aten = torch.ops.aten


def get_aten_target(node: fx.Node) -> Callable:
    """Return the OpOverloadPacket behind *node*'s target when one exists.

    Call-function nodes whose target is an ``OpOverload`` expose the parent
    packet via ``overloadpacket``; for any other target (strings, plain
    callables, packets themselves) the target is returned unchanged.
    """
    target = node.target
    return getattr(target, "overloadpacket", target)


# ATen ops that draw from a random-number generator.  Two calls with identical
# arguments still produce different values, so CSE must never deduplicate them
# (see the exclusion check in fx_graph_cse).
rand_ops = [
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
]


# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash to a node in the new graph
    token_map = {}  # map from hash to token

    from torch._inductor.pattern_matcher import (
        compute_mutation_region_ids,
        same_mutation_regions,
    )

    compute_mutation_region_ids(fx_g)  # type: ignore[arg-type]

    # Make a set of separate storages returned from the output, which will be preserved
    # when pruning.  This prevents us from deduplicating returned tensors which have
    # experienced identical operations, but are separate data structures in eager mode.
    output_node: fx.Node = list(fx_g.nodes)[-1]
    assert output_node.op == "output"

    def checkable_node(node: fx.Node) -> bool:
        """We can evaluate only nodes that represent tensors with defined storage."""
        if "val" not in node.meta or not isinstance(node.meta["val"], torch.Tensor):
            return False

        try:
            node.meta["val"].untyped_storage()
        except NotImplementedError:
            return False

        return True

    output_storages = {
        StorageWeakRef(n.meta["val"].untyped_storage())
        for n in output_node.all_input_nodes
        if checkable_node(n)
    }
    nodes_that_alias_outputs = {
        n
        for n in fx_g.nodes
        if checkable_node(n)
        and StorageWeakRef(n.meta["val"].untyped_storage()) in output_storages
    }

    for n in fx_g.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new graph without change
        # do not CSE away random operations
        if (
            n.op == "placeholder"
            or n.op == "output"
            or n.op == "get_attr"
            or get_aten_target(n) in rand_ops
            # aten.empty is non-deterministic, so don't CSE it.
            # Also, aten.empty is almost always fusible into its consumer,
            # so it's not worth CSEing.
            or get_aten_target(n) is aten.empty
            or n in nodes_that_alias_outputs
            # This CSE pass currently doesn't handle re-propogation of unbacked
            # meta where it'll sometimes eliminate a _local_scalar_dense but not
            # replace the meta of downstream users. eg. one bug we've seen is:
            #
            # _local_scalar_dense_11: "Sym(u14)" = torch.ops.aten._local_scalar_dense.default(select_10);
            # sym_sum_2: "Sym(u19 + u20 + u21)" = torch.sym_sum((_local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13)) # noqa: B950
            #
            # Notice how _local_scalar_dense_11 is u14 but sym_sum_2's meta is incorrectly the old
            # pre-cse value of u19.
            or (
                "val" in n.meta
                and isinstance(n.meta["val"], sympy.Symbol)
                and free_unbacked_symbols(n.meta["val"])
            )
        ):
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
        else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
            # substitute args and kwargs members to their mapping in env if exists
            # specs can be used to reconstruct nested list/dictionaries
            def substitute(arg_list):
                arg_list, spec = tree_flatten(arg_list)
                for i in range(len(arg_list)):
                    v = arg_list[i]
                    if isinstance(v, torch.fx.node.Node) and v in env:
                        arg_list[i] = env[v]
                    if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)):
                        arg_list[i] = v.node
                return tuple(arg_list), spec

            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # each token corresponds to a unique node
            # nodes with the same token can be substituted
            token = {
                "target": n.target,
                "args": args,
                "args_spec": args_spec,
                "kwargs": kwargs,
                "kwargs_spec": kwargs_spec,
            }

            # hash substituted args to a number, do not hash specs because specs are not hashable
            # We need to add type into hash to avoid situations like:
            # hash((primals_2, 1.0)) == hash((primals_2, 1))
            hash_arg = hash(
                (tuple((a, type(a)) for a in args), tuple((a, type(a)) for a in kwargs))
            )
            hash_val = (n.target, hash_arg)

            # check if a node has a substitute and can be eliminated
            hash_val_in_hash_env = hash_val in hash_env
            overwrite_due_to_mutation = False
            if hash_val_in_hash_env and token_map[hash_val] == token:
                duplicate_n_prev = hash_env[hash_val]
                if same_mutation_regions(n, duplicate_n_prev):
                    env[n] = duplicate_n_prev
                    continue
                else:
                    # any futures duplicates should replace with n, not duplicate_n_prev
                    overwrite_due_to_mutation = True

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            if overwrite_due_to_mutation or not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

    return new_graph


def raise_getitems(gm: fx.GraphModule) -> fx.GraphModule:
    # Pre-create a list of nodes to iterate over, as modifying the node order
    # during the loop can lead to infinite loops if not handled properly.
    getitem_nodes = list(
        gm.graph.find_nodes(op="call_function", target=operator.getitem)
    )

    # loop through getitem nodes in the graph and raise them to the parent node
    # in reverse order to perserve their original relative order
    for node in reversed(getitem_nodes):
        assert len(node.all_input_nodes) == 1
        parent = node.all_input_nodes[0]
        parent.append(node)

    gm.recompile()
    return gm


def strip_overloads(gm):
    """Replace every ``OpOverload`` target in ``gm``'s graph with its packet.

    Args:
        gm (fx.GraphModule): graph module to modify in place; it is
            recompiled before returning.
    """
    for node in gm.graph.nodes:
        target = node.target
        if isinstance(target, torch._ops.OpOverload):
            node.target = target.overloadpacket
    gm.recompile()


def get_placeholders(graph):
    return graph.find_nodes(op="placeholder")


def get_outputs(graph):
    for node in graph.find_nodes(op="output"):
        return pytree.tree_leaves(node.args[0])
    raise AssertionError("No output node found")