| | |
| | |
| | """Basic types for the pattern matching and rewriter API.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import dataclasses |
| | import enum |
| | from collections import defaultdict |
| | from typing import TYPE_CHECKING, Any, MutableSequence, Sequence, Union |
| |
|
| | from onnxscript import ir |
| |
|
| | if TYPE_CHECKING: |
| | import onnxscript.rewriter._pattern_ir as _pattern_ir |
| | import onnxscript.rewriter._rewrite_rule as _rewrite_rule |
| |
|
| |
|
class MatchFailureInfo:
    """Describes why a pattern match failed and where.

    Attributes:
        reason: Human-readable explanation of the failure (may be empty).
        failure_sources: The nodes and/or values given as the location of
            the failure, stored as a tuple in the order they were passed.
    """

    def __init__(self, reason: str = "", *failure_source: ir.Node | ir.Value):
        """Record the failure reason and the offending nodes/values.

        Args:
            reason: Explanation of the failure.
            *failure_source: Zero or more ``ir.Node`` / ``ir.Value`` objects
                associated with the failure.

        Raises:
            AssertionError: If any item of ``failure_source`` is neither an
                ``ir.Node`` nor an ``ir.Value``.
        """
        self.reason = reason
        self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source
        # Sanity check on the variadic arguments. The message is only built
        # when the check fails.
        assert all(
            isinstance(item, (ir.Node, ir.Value)) for item in failure_source
        ), f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}"

    def __str__(self):
        """Return a readable rendering of the reason and the failure sources."""
        return f"MatchFailureInfo(reason={self.reason!r}, failure_sources={self.failure_sources!r})"
| |
|
| |
|
class MatchFailureError(MatchFailureInfo, Exception):
    """Exception form of :class:`MatchFailureInfo`.

    Raising the failure instead of returning it lets helper functions
    (for example those used during the condition-checking phase of a
    pattern match) be composed without each caller having to test for and
    forward a failure result explicitly.

    Because ``MatchFailureInfo`` precedes ``Exception`` in the base list,
    ``str(error)`` uses ``MatchFailureInfo.__str__``.
    """

    def __init__(self, reason: str = "", *failure_source: ir.Node | ir.Value):
        """Initialise both base classes.

        Args:
            reason: Explanation of the failure; also used as the sole
                exception argument.
            *failure_source: Nodes/values associated with the failure.
        """
        # The two bases take different arguments, so each initialiser is
        # invoked explicitly rather than through ``super()``.
        MatchFailureInfo.__init__(self, reason, *failure_source)
        Exception.__init__(self, reason)
| |
|
| |
|
class MatchResult:
    """The state object used by the pattern-matching algorithm.

    A match can either succeed or fail.
    If it succeeds, it returns a list of nodes that matched the pattern
    and a set of bindings for the variables in the pattern.

    Example:
    ::
        def pattern(x, shape1, shape2):
            t1 = op.Reshape(x, shape1)
            t2 = op.Reshape(t1, shape2)
            return t2
    The above pattern matches a sequence of two Reshape ops.
    The matched_nodes will contain the two Reshape ops, and the bindings will
    contain the values that are bound to the variables `x`, `shape1`, and `shape2`.

    Internally the state is a stack of ``PartialMatchResult`` objects. The
    bottom entry is the top-level match; ``enter_new_match`` pushes a fresh
    entry to try an alternative, which is later either discarded
    (``abandon_current_match``) or folded into its parent
    (``merge_current_match``).
    """

    def __init__(self) -> None:
        """Create a match state containing a single, successful top-level match."""
        self._partial_matches: list[PartialMatchResult] = [PartialMatchResult()]

    def __repr__(self) -> str:
        """Returns a string representation of the match result."""
        if not self._partial_matches:
            return "MatchResult()"
        return (
            f"MatchResult(success={bool(self)}, reason={self.reason!r}, nodes={self.nodes!r})"
        )

    @property
    def _current_match(self) -> PartialMatchResult:
        """Returns the partial match on top of the stack."""
        return self._partial_matches[-1]

    def _top_level_match(self, message: str) -> PartialMatchResult:
        """Return the current match, requiring that no sub-match is open.

        Args:
            message: Text of the ``ValueError`` raised when a sub-match is
                still on the stack.

        Raises:
            ValueError: If more than one partial match is on the stack.
        """
        if len(self._partial_matches) > 1:
            raise ValueError(message)
        return self._current_match

    def enter_new_match(self) -> None:
        """Starts a new sub-match to try out one of multiple alternatives."""
        self._partial_matches.append(PartialMatchResult())

    def abandon_current_match(self) -> PartialMatchResult:
        """Abandons the current alternative due to failure.

        Returns:
            The discarded partial match (so callers can inspect its failure).

        Raises:
            ValueError: If only the top-level match is on the stack.
        """
        if len(self._partial_matches) < 2:
            raise ValueError("No match to abandon.")
        return self._partial_matches.pop()

    def merge_current_match(self) -> None:
        """Merges a successful sub-match for an alternative with the parent one.

        The stack is modified only once the merge has succeeded, so any
        error raised here leaves the match state exactly as it was.

        Raises:
            ValueError: If there is no sub-match on the stack, or if the
                current sub-match has failed.
        """
        if len(self._partial_matches) < 2:
            raise ValueError("No match to merge.")
        current_match = self._partial_matches[-1]
        if not current_match:
            raise ValueError("Current match is not successful.")
        parent_match = self._partial_matches[-2]
        # Merge first and pop afterwards: if the merge raises, the sub-match
        # is still on the stack and can be abandoned by the caller.
        parent_match.merge(current_match)
        self._partial_matches.pop()

    def __bool__(self) -> bool:
        """Returns True if the current match is successful."""
        return bool(self._current_match)

    def fail(
        self,
        reason: str = "",
        failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
    ) -> MatchResult:
        """Marks the current (sub-)match as failed.

        Args:
            reason: Explanation of the failure.
            failure_source: A node/value, or a list of them, associated
                with the failure.

        Returns:
            This object, so that callers can write ``return match.fail(...)``.
        """
        self._current_match.fail(reason, failure_source)
        return self

    @property
    def reason(self) -> str:
        """Returns the reason for the failure."""
        return self._current_match.reason

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """Returns the list of nodes that matched the pattern."""
        return self._current_match.nodes

    def bind_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node):
        """Binds a pattern node to a matched node."""
        self.add_node(node)
        self._current_match.node_bindings[pattern_node] = node

    def add_node(self, node: ir.Node) -> None:
        """Adds a node to the list of matched nodes."""
        self._current_match.add_node(node)

    def bind_value(self, pattern_value: _pattern_ir.ValuePattern, value: Any) -> bool:
        """Binds a pattern value to a matched value.

        A named pattern value is bound through :meth:`bind` under its name.
        An unnamed one is keyed by the pattern value object itself.

        Returns:
            True if the binding was recorded or is consistent with an
            existing one; False (after failing the current match) if the
            pattern value is already bound to a different value.
        """
        var_name = pattern_value.name
        if var_name is not None:
            return self.bind(var_name, value)

        # Unnamed pattern value: look for an existing binding anywhere on the
        # stack, since enclosing matches may already have bound it.
        for match in self._partial_matches:
            if pattern_value in match.value_bindings:
                existing = match.value_bindings[pattern_value]
                if existing == value:
                    return True
                self._current_match.fail(
                    f"Binding failure: {pattern_value} bound to two different values.",
                    [existing, value],
                )
                return False
        self._current_match.value_bindings[pattern_value] = value
        return True

    def bind(self, var: str, value: Any) -> bool:
        """Binds a pattern variable name to a value.

        Returns:
            True if the binding was recorded or is consistent with an
            existing one; False (after failing the current match) if the
            variable is already bound to a different value.
        """
        for match in self._partial_matches:
            if var in match.bindings:
                existing = match.bindings[var]
                if existing == value:
                    return True
                self._current_match.fail(
                    f"Binding failure: {var} bound to two different values.",
                    [existing, value],
                )
                return False
        self._current_match.bindings[var] = value
        return True

    @property
    def bindings(self) -> dict[str, Any]:
        """Returns the bindings for the pattern variables."""
        return self._top_level_match(
            "Bindings can be accessed only at the top-level match."
        ).bindings

    @property
    def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
        """Returns the bindings for the value variables."""
        return self._top_level_match(
            "Value bindings can be accessed only at the top-level match."
        ).value_bindings

    @property
    def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]:
        """Returns the bindings for the node variables."""
        return self._top_level_match(
            "Node bindings can be accessed only at the top-level match."
        ).node_bindings

    @property
    def outputs(self) -> MutableSequence[ir.Value]:
        """Returns the list of output values that matched the pattern."""
        return self._top_level_match(
            "Outputs can be accessed only at the top-level match."
        ).outputs

    @property
    def failure_nodes_and_values(self) -> list[Union[ir.Node, ir.Value]]:
        """Returns the nodes and values that caused the failure."""
        return self._current_match._failure_nodes_and_values

    def lookup_node(self, pattern_node: _pattern_ir.NodePattern) -> ir.Node | None:
        """Looks up the node that matched the given pattern node."""
        for match in self._partial_matches:
            if pattern_node in match.node_bindings:
                return match.node_bindings[pattern_node]
        return None

    def num_matched_nodes(self) -> int:
        """Returns the number of nodes matched so far."""
        return sum(len(match.node_bindings) for match in self._partial_matches)
| |
|
| |
|
class PartialMatchResult:
    """The state object used by the pattern-matching algorithm for a sub-match.

    Holds the success flag, the matched nodes, the three binding tables
    (by variable name, by pattern value, by pattern node), the outputs and
    the failure diagnostics for one level of the match stack.
    """

    def __init__(self) -> None:
        """Create an empty, successful partial match."""
        self._success: bool = True
        # Nodes matched so far, in the order they were added.
        self._matched_nodes: MutableSequence[ir.Node] = []
        # Binding tables: by variable name, by pattern value, by pattern node.
        self._bindings: dict[str, Any] = {}
        self._value_bindings: dict[_pattern_ir.ValuePattern, ir.Value] = {}
        self._node_bindings: dict[_pattern_ir.NodePattern, ir.Node] = {}

        self._outputs: list[ir.Value] = []
        # Failure diagnostics: the last reason given, plus every node/value
        # reported across all calls to fail().
        self._reason: str = ""
        self._failure_nodes_and_values: list[Union[ir.Node, ir.Value]] = []

    def __bool__(self):
        """Return True while the match has not been failed."""
        return self._success

    def fail(
        self,
        reason: str = "",
        failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
    ) -> None:
        """Mark this partial match as failed.

        Args:
            reason: Explanation of the failure; replaces any earlier reason.
            failure_source: A node/value, or a list of them, appended to the
                accumulated failure sources.
        """
        self._success = False
        self._reason = reason
        if failure_source is None:
            return
        if isinstance(failure_source, list):
            self._failure_nodes_and_values.extend(failure_source)
        else:
            self._failure_nodes_and_values.append(failure_source)

    @property
    def reason(self) -> str:
        """The reason passed to the most recent call to fail()."""
        return self._reason

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """An immutable snapshot of the nodes matched so far."""
        return tuple(self._matched_nodes)

    def add_node(self, node: ir.Node) -> None:
        """Adds a node to the list of matched nodes."""
        self._matched_nodes.append(node)

    @property
    def bindings(self) -> dict[str, Any]:
        """The (mutable) table of bindings keyed by variable name."""
        return self._bindings

    @property
    def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
        """The (mutable) table of bindings keyed by pattern value."""
        return self._value_bindings

    @property
    def outputs(self) -> MutableSequence[ir.Value]:
        """The (mutable) list of output values."""
        return self._outputs

    @property
    def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]:
        """The (mutable) table of bindings keyed by pattern node."""
        return self._node_bindings

    def merge(self, other: PartialMatchResult) -> None:
        """Merges a successful sub-match for an alternative with the parent one.

        All three binding tables and the matched-node list of ``other`` are
        folded into this match. No conflict checking is done here; entries
        from ``other`` overwrite entries with the same key.

        Raises:
            NotImplementedError: If either match has failed.
            AssertionError: If ``other`` carries outputs.
        """
        if not (self._success and other._success):
            raise NotImplementedError("Merging failed matches is not yet supported.")

        self._bindings.update(other._bindings)
        # Node and value bindings must be carried over as well; otherwise
        # everything bound inside the merged alternative is lost once the
        # sub-match is popped.
        self._value_bindings.update(other._value_bindings)
        self._node_bindings.update(other._node_bindings)
        self._matched_nodes.extend(other.nodes)
        # A sub-match is not expected to have outputs.
        assert not other._outputs
| |
|
| |
|
class MatchStatus(enum.IntEnum):
    """The status of a pattern-matching operation.

    The numeric values are significant: ``MatchInfo.score`` multiplies the
    value by 100, so a larger value ranks a match higher.
    """

    NO_MATCH = 0
    CONDITION_FAILED = 1
    REPLACEMENT_FAILED = 2
    SUCCESS = 3
| |
|
| |
|
@dataclasses.dataclass
class MatchInfo:
    """The status of a pattern-matching operation. An extension of MatchResult.

    Bundles a match result with the node it was attempted at, the
    graph/function containing that node, and the resulting status.
    """

    match_result: MatchResult
    root_node: ir.Node
    container: ir.Graph | ir.Function
    status: MatchStatus

    def score(self) -> int:
        """Return a score for the match.

        The score is the number of matched nodes plus 100 times the numeric
        status value.
        """
        matched_count = len(self.match_result.nodes)
        status_weight = int(self.status.value) * 100
        return matched_count + status_weight

    def print(self):
        """Print a report of this match to standard output.

        The report shows the status, then (for anything but SUCCESS) the
        failure reason and the nodes/values associated with the failure,
        and finally the matched nodes.
        """
        separator = "-" * 80
        print(separator)
        print(f"Status: {self.status.name}")

        if self.status != MatchStatus.SUCCESS:
            reason = self.match_result.reason
            if not reason:
                print("Graph matching failed.")
            elif self.status == MatchStatus.CONDITION_FAILED:
                print(f"Graph matching failed due to failing check condition : {reason}")
            else:
                print(f"Graph matching failed: {reason}")

            failure_nodes_and_values = self.match_result.failure_nodes_and_values
            # The header is printed even when there is nothing to list.
            print("Failure at or around nodes/values:")
            for failure_cause in failure_nodes_and_values or ():
                failure_cause.display()

        print("Matched nodes:")
        # Imported at call time rather than at module import time.
        import onnxscript.rewriter._ir_utils as ir_utils

        ir_utils.display_nodes(self.match_result.nodes)
        print(separator)
| |
|
| |
|
class MatchContext:
    """A read-only context containing information about a pattern match.

    Exposes the model, the graph/function, the root node, the output values
    and all nodes of the matching subgraph through read-only properties.
    """

    def __init__(
        self,
        model: ir.Model,
        graph_or_function: ir.Graph | ir.Function,
        root: ir.Node,
        match_result: MatchResult,
    ) -> None:
        """Initialize the pattern match context.

        Args:
            model: The model being matched.
            graph_or_function: The graph or function being matched.
            root: The root node of the matching subgraph.
            match_result: The match result containing matched nodes and outputs.
        """
        self._model = model
        self._graph_or_function = graph_or_function
        self._root = root
        self._match_result = match_result

    @property
    def model(self) -> ir.Model:
        """The model being matched."""
        return self._model

    @property
    def graph_or_function(self) -> ir.Graph | ir.Function:
        """The graph or function being matched."""
        return self._graph_or_function

    @property
    def root(self) -> ir.Node:
        """The root node of the matching subgraph."""
        return self._root

    @property
    def output_values(self) -> Sequence[ir.Value]:
        """The output values of the matching subgraph."""
        return self._match_result.outputs

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """All the nodes of the matching subgraph."""
        return self._match_result.nodes

    def display(self, *, in_graph_order: bool = True) -> None:
        """Display the nodes in the pattern match context.

        Does nothing when the match contains no nodes.

        Args:
            in_graph_order: If True, display nodes in the order they appear in the
                graph/function. If False, display nodes in the order they appear
                in the match result.
        """
        matched = self.nodes
        if not matched:
            return

        if in_graph_order:
            # Walk the container and keep only the nodes that were matched.
            ordered = (
                candidate for candidate in self._graph_or_function if candidate in matched
            )
        else:
            ordered = iter(matched)

        for selected in ordered:
            selected.display()
| |
|
| |
|
class MatchingTracer:
    """A debugging helper class to trace the matching of a pattern against a graph.

    For every rule it keeps the highest-scoring match attempts seen so far
    (all attempts tied for the best score), and can print the overall best
    one at the end of matching.
    """

    def __init__(self) -> None:
        """Create a tracer with no recorded matches."""
        self._best_matches_map: dict[_rewrite_rule.RewriteRule, list[MatchInfo]] = defaultdict(
            list
        )

    @property
    def best_matches_map(self) -> dict[_rewrite_rule.RewriteRule, list[MatchInfo]]:
        """Mapping from each rule to its best-scoring match attempts."""
        return self._best_matches_map

    def log(
        self,
        rule: _rewrite_rule.RewriteRule,
        container: ir.Graph | ir.Function,
        node: ir.Node,
        match_result: MatchResult,
        status: MatchStatus,
    ) -> None:
        """Record a match attempt if it is at least as good as the rule's best so far.

        Attempts scoring 0 are ignored. A strictly better attempt replaces
        the recorded ones; an equally good attempt is added alongside them.

        Args:
            rule: The rule that was tried.
            container: The graph or function the node belongs to.
            node: The node the match was attempted at.
            match_result: The outcome of the attempt.
            status: The status reached by the attempt.
        """
        candidate = MatchInfo(match_result, node, container, status)
        candidate_score = candidate.score()
        if candidate_score == 0:
            return

        recorded = self._best_matches_map[rule]
        if recorded:
            incumbent_score = recorded[0].score()
            if candidate_score < incumbent_score:
                return
            if candidate_score > incumbent_score:
                recorded.clear()
        recorded.append(candidate)

    def report(self) -> None:
        """Print the best match over all rules, or a notice when there is none."""
        # Only the first recorded match of each rule takes part in the
        # comparison; rules with no recorded matches are skipped.
        leaders = [
            (rule, matches[0]) for rule, matches in self._best_matches_map.items() if matches
        ]
        # max() returns the first of several equally scored entries.
        winner = max(leaders, key=lambda entry: entry[1].score(), default=None)
        if winner is None or winner[1].score() <= 0:
            print("No matches found.")
            return

        best_rule, best_match = winner
        print(f"Rule: {best_rule}")
        best_match.print()
| |
|