File size: 17,750 Bytes
6a22ec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Basic types for the pattern matching and rewriter API."""

from __future__ import annotations

import dataclasses
import enum
from collections import defaultdict
from typing import TYPE_CHECKING, Any, MutableSequence, Sequence, Union

from onnxscript import ir

if TYPE_CHECKING:
    import onnxscript.rewriter._pattern_ir as _pattern_ir
    import onnxscript.rewriter._rewrite_rule as _rewrite_rule


class MatchFailureInfo:
    """Describes why a pattern match failed and what it failed on.

    Attributes:
        reason: Human-readable explanation of the failure (may be empty).
        failure_sources: The nodes and/or values passed to the constructor
            as the source of the failure, kept as a tuple in call order.
    """

    def __init__(
        self,
        reason: str = "",
        *failure_source: ir.Node | ir.Value,
    ):
        """Record the failure reason and the IR objects associated with it.

        Args:
            reason: Explanation of the failure.
            *failure_source: Zero or more nodes/values tied to the failure.
        """
        self.reason = reason
        self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source
        # Debug-time sanity check on the types of the variadic arguments.
        assert not any(
            not isinstance(source, (ir.Node, ir.Value)) for source in failure_source
        ), (
            f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}"
        )

    def __str__(self):
        reason, sources = self.reason, self.failure_sources
        return f"MatchFailureInfo(reason={reason!r}, failure_sources={sources!r})"


class MatchFailureError(MatchFailureInfo, Exception):
    """A match failure that can be raised and caught as an exception.

    Raising the failure (instead of returning it) lets helper functions used
    while checking a match, e.g. in the condition-checking phase, signal a
    failure without every caller having to test for it and pass it along
    explicitly.
    """

    def __init__(
        self,
        reason: str = "",
        *failure_source: ir.Node | ir.Value,
    ):
        """Initialize both base classes.

        Args:
            reason: Explanation of the failure; also used as the exception
                argument.
            *failure_source: Zero or more nodes/values tied to the failure.
        """
        # The two base initializers are independent: one sets the exception
        # args, the other stores the reason and the failure sources.
        Exception.__init__(self, reason)
        MatchFailureInfo.__init__(self, reason, *failure_source)


class MatchResult:
    """The state object used by the pattern-matching algorithm.

    A match can either succeed or fail.
    If it succeeds, it returns a list of nodes that matched the pattern
    and a set of bindings for the variables in the pattern.

    Internally the state is kept as a stack of ``PartialMatchResult`` objects.
    The bottom entry is the top-level match; further entries are pushed by
    :meth:`enter_new_match` while an alternative is being tried, and removed
    again by :meth:`abandon_current_match` or :meth:`merge_current_match`.

    Example:
    ::
        def pattern(x, shape1, shape2):
            t1 = op.Reshape(x, shape1)
            t2 = op.Reshape(t1, shape2)
            return t2
    The above pattern matches a sequence of two Reshape ops.
    The matched_nodes will contain the two Reshape ops, and the bindings will
    contain the values that are bound to the variables `x`, `shape1`, and `shape2`.
    """

    def __init__(self) -> None:
        # We use a stack of partial matches to handle OR patterns that require backtracking.
        self._partial_matches: list[PartialMatchResult] = [PartialMatchResult()]

    def __repr__(self) -> str:
        """Returns a string representation of the match result."""
        if not self._partial_matches:
            return "MatchResult()"
        return (
            f"MatchResult(success={bool(self)}, reason={self.reason!r}, nodes={self.nodes!r})"
        )

    @property
    def _current_match(self) -> PartialMatchResult:
        """Returns the current match result."""
        return self._partial_matches[-1]

    def _require_top_level(self, what: str) -> PartialMatchResult:
        """Returns the top-level partial match, or raises inside a sub-match.

        Args:
            what: Capitalized name of the data being accessed; used as the
                start of the error message.

        Raises:
            ValueError: If a sub-match is currently open.
        """
        if len(self._partial_matches) > 1:
            raise ValueError(f"{what} can be accessed only at the top-level match.")
        return self._current_match

    def enter_new_match(self) -> None:
        """Starts a new sub-match to try out one of multiple alternatives."""
        self._partial_matches.append(PartialMatchResult())

    def abandon_current_match(self) -> PartialMatchResult:
        """Abandons the current alternative due to failure.

        Returns:
            The partial match that was removed from the stack.

        Raises:
            ValueError: If there is no open sub-match.
        """
        if len(self._partial_matches) < 2:
            raise ValueError("No match to abandon.")
        return self._partial_matches.pop()

    def merge_current_match(self) -> None:
        """Merges a successful sub-match for an alternative with the parent one.

        The stack is modified only after the merge has succeeded, so a raised
        error leaves the stack as it was.

        Raises:
            ValueError: If there is no open sub-match, or if the current
                sub-match has failed.
        """
        if len(self._partial_matches) < 2:
            raise ValueError("No match to merge.")
        current_match = self._partial_matches[-1]
        if not current_match:
            raise ValueError("Current match is not successful.")
        # Merge into the parent first; drop the sub-match only once that worked.
        self._partial_matches[-2].merge(current_match)
        self._partial_matches.pop()

    def __bool__(self) -> bool:
        """Returns True if the current match is successful."""
        return bool(self._current_match)

    def fail(
        self,
        reason: str = "",
        failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
    ) -> MatchResult:
        """Marks the current (innermost) match as failed.

        Args:
            reason: Explanation of the failure.
            failure_source: A node/value, or a list of them, associated with
                the failure.

        Returns:
            This object, so the call can be used directly in a return statement.
        """
        self._current_match.fail(reason, failure_source)
        return self

    @property
    def reason(self) -> str:
        """Returns the reason for the failure."""
        return self._current_match.reason

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """Returns the list of nodes that matched the pattern."""
        return self._current_match.nodes

    def bind_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node):
        """Binds a pattern node to a matched node."""
        self.add_node(node)
        self._current_match.node_bindings[pattern_node] = node

    def add_node(self, node: ir.Node) -> None:
        """Adds a node to the list of matched nodes."""
        self._current_match.add_node(node)

    def bind_value(self, pattern_value: _pattern_ir.ValuePattern, value: Any) -> bool:
        """Binds a pattern value to a matched value.

        A named pattern value is bound by name via :meth:`bind`; an unnamed
        one is bound using the pattern-value object itself as the key.

        Returns:
            True if the binding was recorded or an equal binding already
            existed on the stack. False if the pattern value is already bound
            to a different value, in which case the current match is failed.
        """
        var_name = pattern_value.name
        # TODO(rama): Simplify the following. We currently bind values to
        # pattern variables in two different ways: via their name, or via the
        # pattern-value itself.
        if var_name is None:
            for match in self._partial_matches:
                if pattern_value in match.value_bindings:
                    # TODO(rama): Use appropriate equality-check here.
                    if match.value_bindings[pattern_value] == value:
                        return True
                    self._current_match.fail(
                        f"Binding failure: {pattern_value} bound to two different values.",
                        [match.value_bindings[pattern_value], value],
                    )
                    return False
            self._current_match.value_bindings[pattern_value] = value
            return True
        return self.bind(var_name, value)

    def bind(self, var: str, value: Any) -> bool:
        """Binds the pattern variable named ``var`` to ``value``.

        Every partial match on the stack is searched for an existing binding.

        Returns:
            True if the binding was recorded or an equal binding already
            existed. False if ``var`` is already bound to a different value,
            in which case the current match is failed.
        """
        for match in self._partial_matches:
            if var in match.bindings:
                # TODO(rama): Use appropriate equality-check here.
                if match.bindings[var] == value:
                    return True
                self._current_match.fail(
                    f"Binding failure: {var} bound to two different values.",
                    [match.bindings[var], value],
                )
                return False
        self._current_match.bindings[var] = value
        return True

    @property
    def bindings(self) -> dict[str, Any]:
        """Returns the bindings for the pattern variables."""
        return self._require_top_level("Bindings").bindings

    @property
    def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
        """Returns the bindings for the value variables."""
        return self._require_top_level("Value bindings").value_bindings

    @property
    def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]:
        """Returns the bindings for the node variables."""
        return self._require_top_level("Node bindings").node_bindings

    @property
    def outputs(self) -> MutableSequence[ir.Value]:
        """Returns the list of output values that matched the pattern."""
        return self._require_top_level("Outputs").outputs

    @property
    def failure_nodes_and_values(self) -> list[Union[ir.Node, ir.Value]]:
        """Returns the nodes and values that caused the failure."""
        return self._current_match._failure_nodes_and_values

    def lookup_node(self, pattern_node: _pattern_ir.NodePattern) -> ir.Node | None:
        """Looks up the node that matched the given pattern node.

        Returns None if no partial match on the stack has a binding for it.
        """
        for match in self._partial_matches:
            if pattern_node in match.node_bindings:
                return match.node_bindings[pattern_node]
        return None

    def num_matched_nodes(self) -> int:
        """Returns the number of nodes matched so far."""
        return sum(len(match.node_bindings) for match in self._partial_matches)


class PartialMatchResult:
    """The state object used by the pattern-matching algorithm for a sub-match.

    It records whether the sub-match is still successful, the nodes matched
    so far, three kinds of bindings (by variable name, by pattern value, by
    pattern node), the outputs, and failure details.
    """

    def __init__(self) -> None:
        self._success: bool = True
        # For a successful match, _matched_nodes is a list of values that matched the pattern.
        # These include the internal nodes of the pattern that were matched, but not
        # the leaves (sub-trees) that match against the variables in the pattern.
        # These represent the values that will be replaced by the replacement pattern.
        self._matched_nodes: MutableSequence[ir.Node] = []
        # For a successful match, bindings is a dictionary of mapping pattern-variable-names
        # to values.
        self._bindings: dict[str, Any] = {}
        # Bindings keyed by the pattern value / pattern node object itself.
        self._value_bindings: dict[_pattern_ir.ValuePattern, ir.Value] = {}
        self._node_bindings: dict[_pattern_ir.NodePattern, ir.Node] = {}

        self._outputs: list[ir.Value] = []
        # For a failed match, _reason is a string that describes the reason for the failure.
        self._reason: str = ""
        # Track the node(s) or value(s) that caused the failure.
        self._failure_nodes_and_values: list[Union[ir.Node, ir.Value]] = []

    def __bool__(self):
        """Returns True while the match has not been marked as failed."""
        return self._success

    def fail(
        self,
        reason: str = "",
        failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
    ) -> None:
        """Marks this match as failed.

        The reason replaces any earlier one. Failure sources accumulate
        across calls.

        Args:
            reason: Explanation of the failure.
            failure_source: A single node/value, or a list of them, to add to
                the recorded failure sources. None adds nothing.
        """
        self._success = False
        self._reason = reason
        if failure_source is not None:
            if isinstance(failure_source, list):
                self._failure_nodes_and_values.extend(failure_source)
            else:
                self._failure_nodes_and_values.append(failure_source)

    @property
    def reason(self) -> str:
        """The reason given in the most recent call to :meth:`fail`."""
        return self._reason

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """A tuple snapshot of the nodes matched so far."""
        return tuple(self._matched_nodes)

    def add_node(self, node: ir.Node) -> None:
        """Adds a node to the list of matched nodes."""
        self._matched_nodes.append(node)

    @property
    def bindings(self) -> dict[str, Any]:
        """The mutable mapping from pattern-variable names to values."""
        return self._bindings

    @property
    def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
        """The mutable mapping from pattern values to matched values."""
        return self._value_bindings

    @property
    def outputs(self) -> MutableSequence[ir.Value]:
        """The mutable list of output values."""
        return self._outputs

    @property
    def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]:
        """The mutable mapping from pattern nodes to matched nodes."""
        return self._node_bindings

    def merge(self, other: PartialMatchResult) -> None:
        """Merges a successful sub-match for an alternative with the parent one.

        All three kinds of bindings and the matched nodes of ``other`` are
        folded into this match. Entries from ``other`` overwrite entries with
        the same key.

        Raises:
            NotImplementedError: If either match has failed.
        """
        if not (self._success and other._success):
            # This should not happen currently.
            raise NotImplementedError("Merging failed matches is not yet supported.")
        # Note: outputs should be set only at end of the (top-level) match. There
        # should be no outputs in the sub-match.
        assert not other._outputs
        # Merge the two successful matches. Matching algorithm responsible for ensuring
        # that the two matches are compatible. No need to check for conflicts here.
        self._bindings.update(other._bindings)
        # Value and node bindings made inside the sub-match must be carried
        # over too, otherwise they are lost once the sub-match is discarded.
        self._value_bindings.update(other._value_bindings)
        self._node_bindings.update(other._node_bindings)
        self._matched_nodes.extend(other.nodes)


class MatchStatus(enum.IntEnum):
    """Integer-valued status codes for a pattern-matching operation."""

    # No successful match found for entire pattern graph.
    NO_MATCH = 0
    # Subsequent validation check failed.
    CONDITION_FAILED = 1
    # Replacement subgraph could not be created.
    REPLACEMENT_FAILED = 2
    # A successful match was found.
    SUCCESS = 3


@dataclasses.dataclass
class MatchInfo:
    """A match result together with where it was attempted and how it ended."""

    # The matcher state for this attempt.
    match_result: MatchResult
    # The node at which the match was attempted.
    root_node: ir.Node
    # The graph or function containing the root node.
    container: ir.Graph | ir.Function
    # How the attempt ended.
    status: MatchStatus

    def score(self) -> int:
        """Return a score for the match.

        The score is 100 times the status value plus the number of matched
        nodes.
        """
        status_weight = 100 * int(self.status.value)
        return status_weight + len(self.match_result.nodes)

    def print(self):
        """Print the status, any failure details, and the matched nodes."""
        rule_line = "-" * 80
        print(rule_line)
        print(f"Status: {self.status.name}")
        if self.status != MatchStatus.SUCCESS:
            reason = self.match_result.reason
            if not reason:
                print("Graph matching failed.")
            elif self.status == MatchStatus.CONDITION_FAILED:
                print(f"Graph matching failed due to failing check condition : {reason}")
            else:
                print(f"Graph matching failed: {reason}")
            culprits = self.match_result.failure_nodes_and_values
            # The heading is printed even when there are no culprits to list.
            print("Failure at or around nodes/values:")
            for culprit in culprits or ():
                culprit.display()
        print("Matched nodes:")
        import onnxscript.rewriter._ir_utils as ir_utils

        ir_utils.display_nodes(self.match_result.nodes)
        print(rule_line)


class MatchContext:
    """A read-only view of the context in which a pattern matched.

    It exposes the model, the enclosing graph/function, the root node, and,
    through the wrapped match result, the output values and nodes of the
    matching subgraph.
    """

    def __init__(
        self,
        model: ir.Model,
        graph_or_function: ir.Graph | ir.Function,
        root: ir.Node,
        match_result: MatchResult,
    ) -> None:
        """Store the pieces that make up the match context.

        Args:
            model: The model being matched.
            graph_or_function: The graph or function being matched.
            root: The root node of the matching subgraph.
            match_result: The match result containing matched nodes and outputs.
        """
        self._match_result = match_result
        self._root = root
        self._graph_or_function = graph_or_function
        self._model = model

    @property
    def model(self) -> ir.Model:
        """Model passed in at construction."""
        return self._model

    @property
    def graph_or_function(self) -> ir.Graph | ir.Function:
        """Graph or function passed in at construction."""
        return self._graph_or_function

    @property
    def root(self) -> ir.Node:
        """Root node of the matching subgraph."""
        return self._root

    @property
    def output_values(self) -> Sequence[ir.Value]:
        """Output values of the matching subgraph, as held by the match result."""
        return self._match_result.outputs

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """Nodes of the matching subgraph, as held by the match result."""
        return self._match_result.nodes

    def display(self, *, in_graph_order: bool = True) -> None:
        """Display the nodes in the pattern match context.

        Nothing is displayed when the match has no nodes.

        Args:
            in_graph_order: If True, display nodes in the order they appear in the
                graph/function. If False, display nodes in the order they appear
                in the match result.
        """
        matched = self.nodes
        if not matched:
            return

        if in_graph_order:
            # Walk the container and keep only the nodes that are in the match.
            to_show = (
                candidate for candidate in self._graph_or_function if candidate in matched
            )
        else:
            # Keep the order of the match result.
            to_show = iter(matched)

        for shown in to_show:
            shown.display()


class MatchingTracer:
    """A debugging helper that traces pattern matching against a graph.

    For every rule it keeps the highest-scoring match attempts logged so far
    (all attempts tied for the best score), and can print a report of the
    best one.
    """

    def __init__(self) -> None:
        # Rule -> attempts sharing that rule's best score, in logging order.
        self._best_matches_map: dict[_rewrite_rule.RewriteRule, list[MatchInfo]] = defaultdict(
            list
        )

    @property
    def best_matches_map(self) -> dict[_rewrite_rule.RewriteRule, list[MatchInfo]]:
        """The mapping from each rule to its best-scoring match attempts."""
        return self._best_matches_map

    def log(
        self,
        rule: _rewrite_rule.RewriteRule,
        container: ir.Graph | ir.Function,
        node: ir.Node,
        match_result: MatchResult,
        status: MatchStatus,
    ) -> None:
        """Record a match attempt if it is at least as good as the rule's best.

        An attempt with score 0 is ignored. A lower-scoring attempt is
        dropped, an equal one is appended, and a higher one replaces what was
        stored for the rule.
        """
        candidate = MatchInfo(match_result, node, container, status)
        candidate_score = candidate.score()
        if candidate_score == 0:
            return
        leaders = self._best_matches_map[rule]
        if leaders:
            leader_score = leaders[0].score()
            if candidate_score < leader_score:
                return
            if candidate_score > leader_score:
                leaders.clear()
        leaders.append(candidate)

    def report(self) -> None:
        """Print the best match across all rules, or a note that none exists.

        When several rules tie, the first one encountered wins.
        """
        top_score = 0
        winner = None
        for rule, candidates in self._best_matches_map.items():
            if candidates:
                head = candidates[0]
                head_score = head.score()
                if head_score > top_score:
                    top_score = head_score
                    winner = (rule, head)

        if winner is None:
            print("No matches found.")
            return
        winning_rule, winning_match = winner
        print(f"Rule: {winning_rule}")
        winning_match.print()