| import copy |
| from dataclasses import dataclass |
| from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union |
|
|
| import torch |
|
|
| from ._compatibility import compatibility |
| from ._symbolic_trace import symbolic_trace |
| from .graph import Graph |
| from .graph_module import GraphModule |
| from .node import Node |
|
|
|
|
| if TYPE_CHECKING: |
| from .passes.utils.matcher_with_name_node_map_utils import InternalMatch |
|
|
# Public names of this module (what ``from <module> import *`` exposes).
__all__ = [
    "Match",
    "replace_pattern",
    "replace_pattern_with_filters",
    "ReplacedPatterns",
]
|
|
|
|
@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    """One place in the original graph where the pattern was matched."""

    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: dict[Node, Node]
|
|
|
|
@compatibility(is_backward_compatible=False)
@dataclass
class ReplacedPatterns:
    """A matched subgraph together with the nodes inserted in its place."""

    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: dict[Node, Node]
    # Nodes copied from the replacement graph into the original graph
    replacements: list[Node]
|
|
|
|
def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """Copy attributes referenced by ``gm``'s graph over from ``replacement``.

    After a rewrite, ``call_module`` / ``get_attr`` nodes copied from the
    replacement graph may refer to targets that only exist on ``replacement``.
    For every such node this function deep-copies the referenced attribute
    onto ``gm`` under the same qualified name.

    Args:
        gm: The rewritten GraphModule. Unused submodules are deleted first and
            the graph is linted at the end.
        replacement: The module the replacement graph was taken from. If it is
            a ``GraphModule`` its graph is linted before anything is copied.

    Raises:
        RuntimeError: If a ``call_module`` / ``get_attr`` target exists on
            neither ``gm`` nor ``replacement``, or if a non-Module object sits
            on the path where a parent module is required.
    """
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]:
        # Resolve a dotted target; ``None`` means "not present".
        module_path, _, attr_name = target.rpartition(".")
        try:
            mod: torch.nn.Module = gm.get_submodule(module_path)
        except AttributeError:
            return None
        return getattr(mod, attr_name, None)

    def set_nested_attr(root: torch.nn.Module, target: str, value: Any) -> None:
        # Install ``value`` at the dotted ``target``, creating empty parent
        # Modules where none exist yet (the same thing ``add_submodule`` does
        # for module targets). A plain ``setattr(root, "a.b", value)`` would
        # create an attribute literally named "a.b", which the graph can
        # never resolve.
        module_path, _, attr_name = target.rpartition(".")
        owner = root
        for atom in module_path.split(".") if module_path else ():
            child = getattr(owner, atom, None)
            if child is None:
                child = torch.nn.Module()
                setattr(owner, atom, child)
            elif not isinstance(child, torch.nn.Module):
                raise RuntimeError(
                    f"Cannot install attribute {target} during subgraph "
                    f"rewriting: {atom} exists but is not an nn.Module"
                )
            owner = child
        setattr(owner, attr_name, value)

    for node in gm.graph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        # CASE 1: the target already exists on `gm` (it was there before the
        # rewrite, or an earlier node with the same target installed it).
        if try_get_attr(gm, node.target) is not None:
            continue

        replacement_attr = try_get_attr(replacement, node.target)

        # CASE 2: the target exists only on neither side. Build ONE message
        # string; passing several positional arguments to RuntimeError would
        # render the error as a tuple.
        if replacement_attr is None:
            raise RuntimeError(
                f'Attempted to create a "{node.op}" node during subgraph '
                f"rewriting with target {node.target}, but the referenced "
                "attribute does not exist in the replacement GraphModule"
            )

        # CASE 3: the target exists only on `replacement`; copy it over.
        new_attr = copy.deepcopy(replacement_attr)
        if isinstance(replacement_attr, torch.nn.Module):
            gm.add_submodule(node.target, new_attr)
        else:
            set_nested_attr(gm, node.target, new_attr)

    gm.graph.lint()
|
|
|
|
@compatibility(is_backward_compatible=True)
def replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, GraphModule],
    replacement: Union[Callable, GraphModule],
) -> list[Match]:
    """
    Find every non-overlapping occurrence of ``pattern`` (a set of operators
    plus their data dependencies) in the Graph of ``gm`` and substitute the
    ``replacement`` subgraph for each occurrence.

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: One ``Match`` per location in the original graph where
        ``pattern`` was matched; empty when nothing matched. ``Match`` is
        defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter


        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)


        def pattern(w1, w2):
            return torch.cat([w1, w2])


        def replacement(w1, w2):
            return torch.stack([w1, w2])


        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The code above first looks for ``pattern`` inside the ``forward``
    method of ``traced_module``. Matching follows use-def relationships
    rather than node names: ``p = torch.cat([a, b])`` in ``pattern``
    matches ``m = torch.cat([a, b])`` in the original ``forward`` even
    though the variable names (``p`` vs ``m``) differ.

    Only the *value* of the ``return`` statement in ``pattern`` takes part
    in matching; it may or may not line up with the ``return`` statement of
    the larger graph. Put differently, the pattern is not required to
    extend to the end of the larger graph.

    Each match is removed from the larger function and ``replacement`` is
    put in its place. When ``pattern`` matches several times, every
    non-overlapping match is replaced. When matches overlap, the first
    match found among the overlapping ones is replaced. ("First" means
    first in a topological ordering of the Nodes' use-def relationships.
    In most cases, the first Node is the parameter that appears directly
    after ``self``, while the last Node is whatever the function returns.)

    Two rules about parameters are important: every parameter of the
    ``pattern`` Callable must be used inside the Callable, and the
    parameters of the ``replacement`` Callable must match those of the
    pattern. The first rule explains why, in the example above,
    ``forward`` takes ``x, w1, w2`` while ``pattern`` takes only
    ``w1, w2``: ``pattern`` never uses ``x``, so it must not declare it.
    To illustrate the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    Here ``replacement`` has to declare the same number of parameters as
    ``pattern`` (both ``x`` and ``y``), even though ``y`` is unused in
    ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2
    """
    rewritten = _replace_pattern(gm, pattern, replacement)
    # Narrow each internal record down to the public ``Match`` tuple.
    return [Match(record.anchor, record.nodes_map) for record in rewritten]
|
|
|
|
| |
| @compatibility(is_backward_compatible=False) |
| def replace_pattern_with_filters( |
| gm: GraphModule, |
| pattern: Union[Callable, Graph, GraphModule], |
| replacement: Union[Callable, Graph, GraphModule, None] = None, |
| match_filters: Optional[ |
| list[Callable[["InternalMatch", Graph, Graph], bool]] |
| ] = None, |
| ignore_literals: bool = False, |
| |
| replacement_callback: Optional[ |
| Callable[["InternalMatch", Graph, Graph], Graph] |
| ] = None, |
| node_name_match: str = "", |
| ) -> list[ReplacedPatterns]: |
| """ |
| See replace_pattern for documentation. This function is an overload with an additional match_filter argument. |
| |
| Args: |
| ``match_filters``: A list of functions that take in |
| (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating |
| whether the match satisfies the condition. |
| See matcher_utils.py for definition of InternalMatch. |
| ``replacement_callback``: A function that takes in a match and returns a |
| Graph to be used as the replacement. This allows you to construct a |
| replacement graph based on the match. |
| ``replacement_callback``: Node name to match. If not empty, it will try to match the node name. |
| """ |
|
|
| return _replace_pattern( |
| gm, |
| pattern, |
| replacement, |
| match_filters, |
| ignore_literals, |
| replacement_callback, |
| node_name_match, |
| ) |
|
|
|
|
| def _replace_pattern( |
| gm: GraphModule, |
| pattern: Union[Callable, Graph, GraphModule], |
| replacement: Union[Callable, Graph, GraphModule, None] = None, |
| match_filters: Optional[ |
| list[Callable[["InternalMatch", Graph, Graph], bool]] |
| ] = None, |
| ignore_literals: bool = False, |
| |
| replacement_callback: Optional[ |
| Callable[["InternalMatch", Graph, Graph], Graph] |
| ] = None, |
| node_name_match: str = "", |
| ) -> list[ReplacedPatterns]: |
| from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher |
|
|
| if match_filters is None: |
| match_filters = [] |
|
|
| |
| original_graph: Graph = gm.graph |
|
|
| if isinstance(pattern, GraphModule): |
| pattern_graph = pattern.graph |
| elif isinstance(pattern, Graph): |
| pattern_graph = pattern |
| else: |
| pattern_graph = symbolic_trace(pattern).graph |
|
|
| matcher = SubgraphMatcher( |
| pattern_graph, |
| match_output=False, |
| match_placeholder=False, |
| remove_overlapping_matches=True, |
| ignore_literals=ignore_literals, |
| ) |
| _matches: list[InternalMatch] = matcher.match( |
| original_graph, node_name_match=node_name_match |
| ) |
|
|
| |
| _matches = [ |
| m |
| for m in _matches |
| if all( |
| match_filter(m, original_graph, pattern_graph) |
| for match_filter in match_filters |
| ) |
| ] |
|
|
| if isinstance(replacement, GraphModule): |
| common_replacement_graph = replacement.graph |
| elif isinstance(replacement, Graph): |
| common_replacement_graph = replacement |
| elif callable(replacement): |
| common_replacement_graph = symbolic_trace(replacement).graph |
| else: |
| assert replacement_callback is not None, ( |
| "Must provide either a replacement GraphModule or a replacement callback" |
| ) |
| common_replacement_graph = None |
|
|
| |
| match_changed_node: dict[Node, Node] = {} |
|
|
| match_and_replacements = [] |
| for match in _matches: |
| if replacement_callback is not None: |
| replacement_graph = replacement_callback( |
| match, original_graph, pattern_graph |
| ) |
| else: |
| assert common_replacement_graph is not None, ( |
| "Must provide either a replacement GraphModule or a replacement callback" |
| ) |
| replacement_graph = common_replacement_graph |
| replacement_placeholders = [ |
| n for n in replacement_graph.nodes if n.op == "placeholder" |
| ] |
|
|
| |
|
|
| |
| |
| assert len(match.placeholder_nodes) == len(replacement_placeholders) |
| val_map: dict[Node, Node] = {} |
| for rn, gn in zip(replacement_placeholders, match.placeholder_nodes): |
| if isinstance(gn, Node): |
| val_map[rn] = match_changed_node.get(gn, gn) |
| if gn != val_map[rn]: |
| |
| gn_ind = match.placeholder_nodes.index(gn) |
| match.placeholder_nodes[gn_ind] = match_changed_node[gn] |
| map_key = list(match.nodes_map.keys())[ |
| list(match.nodes_map.values()).index(gn) |
| ] |
| match.nodes_map[map_key] = match_changed_node[gn] |
| else: |
| val_map[rn] = gn |
|
|
| |
| user_nodes: set[Node] = set() |
| for n in match.returning_nodes: |
| user_nodes.update(n.users) |
|
|
| first_user_node = None |
| if len(user_nodes) == 0: |
| first_user_node = None |
| elif len(user_nodes) == 1: |
| first_user_node = next(iter(user_nodes)) |
| else: |
| |
| |
| for n in original_graph.nodes: |
| if n in user_nodes: |
| first_user_node = n |
| break |
|
|
| first_next_node = None |
| if first_user_node is None: |
| |
| |
| next_node = None |
| for n in reversed(original_graph.nodes): |
| if n in match.returning_nodes: |
| first_next_node = next_node |
| break |
| else: |
| next_node = n |
| insert_point = ( |
| first_user_node if first_user_node is not None else first_next_node |
| ) |
| assert insert_point is not None, "The insert point can't be None" |
| with original_graph.inserting_before(insert_point): |
| copied_returning_nodes = original_graph.graph_copy( |
| replacement_graph, val_map |
| ) |
|
|
| if isinstance(copied_returning_nodes, Node): |
| copied_returning_nodes = (copied_returning_nodes,) |
|
|
| |
| replacement_nodes: list[Node] = [ |
| v for v in val_map.values() if v not in match.placeholder_nodes |
| ] |
|
|
| |
| |
| assert len(match.returning_nodes) == len(copied_returning_nodes) |
| for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): |
| gn.replace_all_uses_with(copied_node) |
| match_changed_node[gn] = copied_node |
| |
| for node in reversed(pattern_graph.nodes): |
| if node.op != "placeholder" and node.op != "output": |
| gn = match.nodes_map[node] |
| gm.graph.erase_node(gn) |
|
|
| match_and_replacements.append( |
| ReplacedPatterns( |
| anchor=match.anchors[0], |
| nodes_map=match.nodes_map, |
| replacements=replacement_nodes, |
| ) |
| ) |
|
|
| |
| |
| gm.recompile() |
|
|
| |
| |
| if isinstance(replacement, torch.nn.Module): |
| _replace_attributes(gm, replacement) |
|
|
| return match_and_replacements |
|
|