# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """The Pattern IR: used to describe (source) patterns of rewrite rules.""" from __future__ import annotations import abc import contextlib import inspect import itertools from collections.abc import Mapping from typing import ( Any, Callable, Iterable, Iterator, Protocol, Sequence, TypeVar, Union, ) import onnxscript.rewriter._basics as _basics from onnxscript import ir T = TypeVar("T") class Pattern(Protocol[T]): # type: ignore[misc] """This is essentially a Predicate[T], that is, a Callable[[T], bool] bound to the name "matches".""" def matches(self, item: T) -> bool: ... class StringPattern(abc.ABC, Pattern[str]): """Abstract base class for string patterns.""" @abc.abstractmethod def matches(self, item: str) -> bool: pass @abc.abstractmethod def __str__(self) -> str: pass class StringConstantPattern(StringPattern): """Matches strings with given value.""" def __init__(self, value: str): self._value = value def matches(self, item: str) -> bool: return item == self._value def __str__(self) -> str: return self._value def value(self) -> str: return self._value class PrefixPattern(StringPattern): """Matches strings with a given prefix.""" def __init__(self, value: str) -> None: self._value = value def matches(self, value: str) -> bool: return value.startswith(self._value) def __str__(self) -> str: return f"{self._value}*" class AttrPattern(Pattern[ir.Attr]): """Base class for an attribute pattern. Matches any attribute value by default.""" def __init__(self, name: str | None, *, can_match_none: bool = False): self._name = name self._can_match_none = can_match_none @property def name(self) -> str | None: return self._name @property def can_match_none(self) -> bool: """Indicates whether this pattern can match a None attribute.""" return self._can_match_none def matches(self, attr: ir.Attr) -> bool: return True def __str__(self) -> str: return self._name if self._name is not None else "anonymous:" + str(id(self)) class AttrVar(AttrPattern): """Represents a pattern variable used to match against attribute values.""" def __init__(self, name: str | None, *, can_match_none: bool = False): super().__init__(name, can_match_none=can_match_none) # TODO: Support tensors. Align with usage elsewhere. SupportedAttrTypes = Union[ int, float, str, Sequence[int], Sequence[float], Sequence[str], ] class AttrConstantPattern(AttrPattern): """Matches attributes with given value. Uses standard equality for matching. For list-valued attributes, the order of elements matters. If order is immaterial, we need to define a separate pattern for that. """ def __init__(self, value: SupportedAttrTypes): super().__init__(None) self._value = value def matches(self, attr: ir.Attr) -> bool: return isinstance(attr, ir.Attr) and attr.value == self._value def __str__(self) -> str: return str(self._value) def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> AttrPattern: """Represents promotion of values allowed as keyword-arguments in a pattern-builder call to an AttrPattern.""" if isinstance(value, AttrPattern): return value if isinstance(value, Var): # This is a hack. Currently, when we create pattern-variables, we create them as Var, # and change them to AttrPattern if/when used in an attribute context. We could use type # annotations to distinguish between ValuePattern and AttrPattern, but forces users to # use these type annotations. # TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.) if value.check_method is not None: raise ValueError( "Pattern variables used in attributes must not have check_method set." ) return AttrVar(value.name, can_match_none=value.can_match_none) if isinstance(value, (int, float, str)): return AttrConstantPattern(value) if isinstance(value, Sequence): if all(isinstance(i, (int, float)) for i in value): return AttrConstantPattern(value) if all(isinstance(i, str) for i in value): return AttrConstantPattern(value) raise ValueError("Only lists of int/float/str can be used as an AttrPattern") raise TypeError(f"Cannot convert {type(value)} to AttrPattern") class OpsetPatternBuilder: """Represents an opset pattern and a pattern builder. (i) It is used to create a NodePattern (via OpPatternBuilder). Example usage: :: z = op.Matmul(x, y) Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. (ii) It contains a domain pattern matched against the actual opset domain used in the input model. """ def __init__(self, domain: StringPattern | str, record: bool = False) -> None: if isinstance(domain, str): domain = StringConstantPattern(domain) self._domain_pattern = domain if record: self._nodes: list[NodePattern] | None = [] else: self._nodes = None def domain_pattern(self) -> StringPattern: return self._domain_pattern def __getattr__(self, op_name: str) -> OpPatternBuilder: return OpPatternBuilder(self, op_name) def submodule(self, name: str) -> OpPatternBuilder: """This method is used to match against submodule ops with prefix.""" return OpPatternBuilder(self, PrefixPattern(name)) def __str__(self) -> str: return str(self._domain_pattern) def add_node(self, node: NodePattern) -> None: if self._nodes is not None: self._nodes.append(node) def nodes(self) -> Sequence[NodePattern]: if self._nodes is None: raise ValueError("Nodes were not recorded.") return self._nodes onnxop = OpsetPatternBuilder("") torch_module_op = OpsetPatternBuilder(PrefixPattern("pkg.torch")) class OpPatternBuilder: """A utility class to build a NodePattern. It is used primarily to create a NodePattern. Example usage: :: z = op.Matmul(x, y) Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. """ def __init__( self, pattern_builder: OpsetPatternBuilder, op_name: str | Pattern[str], ) -> None: self.pattern_builder = pattern_builder self.op_name = op_name def __call__( self, *args, _domain: str | None = None, _version: int | None = None, _outputs: int | list[str | None] = 1, _allow_other_attributes: bool | None = None, _allow_other_inputs: bool | None = None, _check: Callable | None = None, **kwargs, ): if _version is not None: raise ValueError( "The pattern builder does not support '_version' keyword argument. " "Version restrictions should be handled by rewrite rules." ) if _domain is None: opset_pattern = self.pattern_builder.domain_pattern() elif isinstance(_domain, str): opset_pattern = StringConstantPattern(_domain) else: # TODO(rama): allow OpsetPatternBuilder as _domain. raise TypeError("_domain must be a string.") if isinstance(_outputs, int): _outputs = [None for _ in range(_outputs)] elif not isinstance(_outputs, Sequence) or not all( isinstance(x, (str, type(None))) for x in _outputs ): raise ValueError("_outputs must be an int or a list[str|None].") inputs = [_to_value_pattern(x) for x in args] attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()} node_pattern = NodePattern( opset_pattern, self.op_name, inputs, attributes, _outputs, allow_other_attributes=_allow_other_attributes, allow_other_inputs=_allow_other_inputs, check=_check, ) self.pattern_builder.add_node(node_pattern) output_values = node_pattern.outputs # Unpack outputs if there is only one output, the common case. if len(output_values) == 1: return output_values[0] else: return output_values def _to_value_pattern( x: ValuePattern | int | float | Callable | None, ) -> ValuePattern | None: """Promotes an input-value used to construct a NodePattern to a ValuePattern. Example usage: :: x = op.MatMul(a, b) z = op.Add(x, 0) In this example, `a, `b`, and `x` are ValuePatterns used to construct a NodePattern. `0` is a constant (int) value, and is automatically promoted to a ValuePattern. Note that this is a shorthand for creating a Constant pattern. The user can more explicitly write this as: :: z = op.Add(x, op.Constant(0)) If a callable is provided, it will be converted to a ValuePattern with the callable as the check attribute. """ if x is None or isinstance(x, ValuePattern): return x if isinstance(x, (int, float)): return Constant(x) if isinstance(x, Sequence): if all(isinstance(i, (int, float)) for i in x): return Constant(x) raise ValueError("Only lists of int/float can be used as a ValuePattern") if callable(x): return ValuePattern(None, check=x) raise TypeError(f"Cannot convert {type(x)} to ValuePattern") _pattern_builder: OpsetPatternBuilder = onnxop @contextlib.contextmanager def pattern_builder(builder: OpsetPatternBuilder): global _pattern_builder prev_builder = _pattern_builder _pattern_builder = builder yield _pattern_builder = prev_builder class ValuePattern: """Base class for all patterns that match against IR values. This is used primarily to provide operator overloadings for arithmetic operations, so that we can write patterns like `x + 1` and `1 + x`. """ def __init__( self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False ) -> None: self._name = name self._check = check self._can_match_none = can_match_none # Note: uses will be computed only when the full graph-pattern is constructed. self._uses: list[tuple[NodePattern, int]] = [] def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern: del node_map return ValuePattern(self._name, check=self._check) @property def name(self) -> str | None: return self._name @property def check_method(self) -> Callable | None: return self._check @property def can_match_none(self) -> bool: """Indicates whether this variable can match a None input.""" return self._can_match_none def producer(self) -> NodePattern | None: return None def uses(self) -> Sequence[tuple[NodePattern, int]]: return self._uses def append_use(self, node: NodePattern, index: int): self._uses.append((node, index)) def __repr__(self) -> str: return f"ValuePattern({self._name!r})" def __add__(self, other): return _pattern_builder.Add(self, other) def __radd__(self, other): return _pattern_builder.Add(other, self) def __sub__(self, other): return _pattern_builder.Sub(self, other) def __rsub__(self, other): return _pattern_builder.Sub(other, self) def __mul__(self, other): return _pattern_builder.Mul(self, other) def __rmul__(self, other): return _pattern_builder.Mul(other, self) def __truediv__(self, other): return _pattern_builder.Div(self, other) def __rtruediv__(self, other): return _pattern_builder.Div(other, self) def __pow__(self, other): return _pattern_builder.Pow(self, other) def __str__(self) -> str: return self._name if self._name is not None else "anonymous:" + str(id(self)) class NodePattern: """Represents a pattern that matches against a Node. This differs from a NodeOutputPattern in that it matches against a node (which may produce 1 or more outputs), whereas a NodeOutputPattern matches against a specific output of a node. Args: domain: pattern to match against the domain of the node. op: pattern or string constant to match against the op_type of the node. inputs: sequence of ValuePatterns (or constants) to match against the inputs of the node. attributes: dictionary of attribute patterns to match against the attributes of the node. outputs: specifies pattern-variable-name for outputs (or None) allow_other_attributes: specifies whether other attributes (not mentioned in `attributes`) are allowed in the node. """ def __init__( self, domain: StringPattern, op: str | Pattern[str], inputs: Sequence[int | float | ValuePattern | None], attributes: dict[str, AttrPattern], outputs: Sequence[str | None], *, allow_other_attributes: bool | None, allow_other_inputs: bool | None, check: Callable | None = None, ): if allow_other_attributes is None: # Default behavior: allow other unmatched attributes in the node. allow_other_attributes = True if allow_other_inputs is None: # TODO(rama): Should we default to True? For now, we preserve the current behavior. allow_other_inputs = False self.domain = domain self.op = StringConstantPattern(op) if isinstance(op, str) else op self.inputs = [_to_value_pattern(x) for x in inputs] self.attributes = attributes self.allow_other_attributes = allow_other_attributes self.allow_other_inputs = allow_other_inputs self._check = check # In the common case, domain and op are constants, which can be used to optimize matching. if isinstance(op, str) and isinstance(domain, StringConstantPattern): # TODO(rama): support overloaded operators. overload = "" self._op_identifier: ir.OperatorIdentifier | None = ( domain.value(), op, overload, ) else: self._op_identifier = None self.outputs = [NodeOutputPattern(self, i, name) for i, name in enumerate(outputs)] # Update uses for inputs. for index, value in enumerate(self.inputs): if value is not None: value.append_use(self, index) def __str__(self) -> str: inputs = ", ".join(str(v) for v in self.inputs) outputs = ", ".join(str(v) for v in self.outputs) attributes = ", ".join(f"{k}={v}" for k, v in self.attributes.items()) op = str(self.op) domain = str(self.domain) qualified_op = f"{domain}.{op}" if domain else op inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs return f"{outputs} = {qualified_op} ({inputs_and_attributes})" def op_identifier(self) -> ir.OperatorIdentifier | None: return self._op_identifier @property def op_type(self) -> str: return str(self.op) @property def check_method(self) -> Callable | None: return self._check def matches(self, node: ir.Node, match: _basics.MatchResult) -> _basics.MatchResult: """Matches the pattern represented by self against a node. This is purely a local node-level match, and does not consider the subgraph rooted at the node. We check the domain, op_type, and attributes of the node, but not the inputs. """ # TODO(rama): Ensure we handle "" and "onnx.ai" correctly. if not self.op.matches(node.op_type): return match.fail( f"OpType mismatch: expected {self.op}, got {node.op_type}.", node ) if not self.domain.matches(node.domain): return match.fail( f"Domain mismatch: expected {self.domain}, got {node.domain}.", node ) for name, attr_pattern in self.attributes.items(): attr_value = node.attributes.get(name) if attr_value is None: if not attr_pattern.can_match_none: return match.fail(f"Attribute {name} not found in node.", node) elif not attr_pattern.matches(attr_value): return match.fail( f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}.", node, ) if attr_pattern.name is not None: if not match.bind(attr_pattern.name, attr_value): return match if not self.allow_other_attributes: for name in node.attributes: # TODO: Support matching default nodes for attributes. if name not in self.attributes: return match.fail(f"Attribute {name} not expected in node.", node) return match def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePattern: inputs = [(v.clone(node_map) if v is not None else None) for v in self.inputs] if swap: assert len(inputs) == 2, ( "Internal error: commutative swap applies only to binary ops." ) inputs = [inputs[1], inputs[0]] outputs = [value.name for value in self.outputs] copied = NodePattern( self.domain, self.op, inputs, self.attributes, outputs, allow_other_attributes=self.allow_other_attributes, allow_other_inputs=self.allow_other_inputs, check=self._check, ) node_map[self] = copied return copied class NodeOutputPattern(ValuePattern): """Represents a pattern that matches against a specific output of a Node. This is the primary pattern used to match against computed values, that is values computed using a specific op. """ def __init__( self, producer: NodePattern, output_index: int, name: str | None = None ) -> None: super().__init__(name) self._producer = producer self._output_index = output_index def clone(self, node_map: dict[NodePattern, NodePattern]) -> NodeOutputPattern: return node_map[self._producer].outputs[self._output_index] # return NodeOutputPattern(node_map[self._producer], self._output_index, self._name) @property def output_index(self) -> int: return self._output_index def producer(self) -> NodePattern: return self._producer class Var(ValuePattern): """Represents a pattern-variable.""" def __init__( self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False ) -> None: super().__init__(name, check=check, can_match_none=can_match_none) def clone(self, node_map: dict[NodePattern, NodePattern]) -> Var: """Clones the pattern-variable, preserving its name and check method.""" return Var(self.name, check=self.check_method, can_match_none=self.can_match_none) class AnyValue(ValuePattern): """Represents a pattern that matches against any value.""" def __init__(self) -> None: super().__init__(None) def clone(self, node_map: dict[NodePattern, NodePattern]) -> AnyValue: # A single instance of AnyValue suffices. return self ANY_VALUE = AnyValue() class Constant(ValuePattern): """Represents a pattern that matches against a scalar constant value.""" def __init__( self, value: int | float | Sequence[int] | Sequence[float], rel_tol: float = 1e-5, abs_tol: float = 1e-8, ) -> None: super().__init__(None) self._value = list(value) if isinstance(value, Sequence) else value self._rel_tol = rel_tol self._abs_tol = abs_tol def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant: del node_map return Constant(self._value, self._rel_tol, self._abs_tol) @property def value(self) -> int | float | list[int] | list[float]: return self._value def __str__(self) -> str: return str(self._value) class OpIdDispatchOr(ValuePattern): """Represents a (restricted) form of value pattern disjunction that enables deterministic matching.""" def __init__( self, op_to_pattern: Mapping[ir.OperatorIdentifier, tuple[Any, ValuePattern]], name: str | None = None, tag_var: str | None = None, ) -> None: """ Initialize an OpIdDispatchOr pattern. Args: op_to_pattern: A dictionary mapping operator identifiers to tuples of tag values and patterns. The keys are operator identifiers, and the values are tuples containing a tag value and a pattern to match against. name: An optional variable name for the pattern. Defaults to None. If present, this name will be bound to the value matched by the pattern. tag_var: An optional variable name for the tag. Defaults to None. If present, it will be bound to a value indicating which alternative was matched. """ super().__init__(name) self._op_to_pattern = op_to_pattern self._tag_var = tag_var @property def tag_var(self) -> str | None: """Returns the tag variable associated with the OrValue pattern.""" return self._tag_var def clone(self, node_map: dict[NodePattern, NodePattern]) -> OpIdDispatchOr: return OpIdDispatchOr( {k: (v[0], v[1].clone(node_map)) for k, v in self._op_to_pattern.items()}, self.name, self._tag_var, ) def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern] | None: """Returns the pattern that should be tried for the given value.""" producer = value.producer() if producer is not None: id = producer.op_identifier() if id is not None and id in self._op_to_pattern: return self._op_to_pattern[id] return None class BacktrackingOr(ValuePattern): """Represents an unrestricted form of OR pattern implemented using backtracking.""" def __init__( self, values: Sequence[ValuePattern], name: str | None = None, tag_var: str | None = None, tag_values: Sequence[Any] | None = None, ) -> None: """ Initialize a BacktrackingOr pattern. Args: values: A sequence of value patterns to match against. name: An optional variable name for the pattern. Defaults to None. If present, this name will be bound to the value matched by the pattern. tag_var: An optional variable name for the tag. Defaults to None. If present, it will be bound to a value (from tag_values) indicating which alternative was matched. tag_values: An optional sequence of values to bind to the tag_var. Defaults to None. If present, the length of tag_values must match the number of alternatives in values. In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used. """ super().__init__(name) if tag_values is not None: if tag_var is None: raise ValueError("tag_var must be specified if tag_values is provided.") if len(tag_values) != len(values): raise ValueError( "tag_values must have the same length as the number of alternatives." ) else: tag_values = tuple(range(len(values))) self._tag_var = tag_var self._tag_values = tag_values self._values = values @property def tag_var(self) -> str | None: """Returns the tag variable associated with the OrValue pattern.""" return self._tag_var def clone(self, node_map: dict[NodePattern, NodePattern]) -> BacktrackingOr: return BacktrackingOr( [v.clone(node_map) for v in self._values], self.name, self._tag_var, self._tag_values, ) def OrValue( values: Sequence[ValuePattern], name: str | None = None, tag_var: str | None = None, tag_values: Sequence[Any] | None = None, ) -> ValuePattern: """ Creates an OR pattern. Args: values: A sequence of value patterns to match against. name: An optional variable name for the pattern. Defaults to None. If present, this name will be bound to the value matched by the pattern. tag_var: An optional variable name for the tag. Defaults to None. If present, it will be bound to a value (from tag_values) indicating which alternative was matched. tag_values: An optional sequence of values to bind to the tag_var. Defaults to None. If present, the length of tag_values must match the number of alternatives in values. In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used. """ if tag_values is not None: if tag_var is None: raise ValueError("tag_var must be specified if tag_values is provided.") if len(tag_values) != len(values): raise ValueError( "tag_values must have the same length as the number of alternatives." ) else: tag_values = tuple(range(len(values))) def make_op_id_or_pattern() -> OpIdDispatchOr | None: mapping: dict[ir.OperatorIdentifier, tuple[Any, NodeOutputPattern]] = {} for i, alternative in enumerate(values): if not isinstance(alternative, NodeOutputPattern): return None producer = alternative.producer() id = producer.op_identifier() if id is None or id in mapping: return None mapping[id] = (tag_values[i], alternative) return OpIdDispatchOr(mapping, name, tag_var) optimized_pattern = make_op_id_or_pattern() return optimized_pattern or BacktrackingOr( values, name, tag_var, tag_values if tag_var else None ) def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]: """Returns all nodes used in a pattern, given the outputs of the pattern.""" node_patterns: list[NodePattern] = [] def visit(value_patterns: Sequence[ValuePattern | None]) -> None: for value_pattern in value_patterns: if isinstance(value_pattern, NodeOutputPattern): node_pattern = value_pattern.producer() if node_pattern not in node_patterns: node_patterns.append(node_pattern) visit(node_pattern.inputs) visit(outputs) node_patterns.reverse() return node_patterns def _add_backward_slice( node: NodePattern, backward_slice: set[NodePattern], backward_slice_values: set[ValuePattern], ) -> None: """Adds all nodes in the backward slice of given node to the set `backward_slice`. The backward slice of a node is the set of all nodes that are reachable from the node in a backward traversal from the given node. """ if node in backward_slice: return backward_slice.add(node) for value_pattern in node.inputs: if isinstance(value_pattern, NodeOutputPattern): _add_backward_slice( value_pattern.producer(), backward_slice, backward_slice_values ) elif isinstance(value_pattern, (OpIdDispatchOr, BacktrackingOr)): backward_slice_values.add(value_pattern) class GraphPattern: """Represents a pattern that can be matched against a subgraph.""" def __init__( self, inputs: Sequence[ValuePattern], outputs: Sequence[ValuePattern], nodes: Sequence[NodePattern], ) -> None: self._inputs = inputs self._outputs = outputs if len(outputs) == 0: raise ValueError("GraphPattern must have at least one output") self._nodes = nodes # _nodes_in_pattern(outputs) # Determine the output nodes of the pattern. These are a minimal set of nodes # whose backward-slices cover the entire pattern. output_nodes: set[NodePattern] = set() covered: set[NodePattern] = set() choice_values_returned: set[ValuePattern] = set() covered_choice_values: set[ValuePattern] = set() for value_pattern in outputs: if not isinstance(value_pattern, ValuePattern): raise TypeError( f"Invalid type {type(value_pattern)} for graph pattern output." ) if isinstance(value_pattern, NodeOutputPattern): candidate = value_pattern.producer() if candidate not in covered: output_nodes.add(candidate) _add_backward_slice(candidate, covered, covered_choice_values) elif isinstance(value_pattern, (OpIdDispatchOr, BacktrackingOr)): choice_values_returned.add(value_pattern) # check if all choice_values_returned are contained in covered_choice_values: # We don't yet support the use of a choice-value as a "root" of the search. # This is a limitation of the current implementation, and will be fixed in the future. if not (choice_values_returned <= covered_choice_values): raise NotImplementedError("Returning uncovered choice-values is not supported.") self.output_nodes: list[NodePattern] = list(output_nodes) @property def output_node(self) -> NodePattern: if len(self.output_nodes) != 1: raise ValueError("GraphPattern does not have unique output node.") return self.output_nodes[0] def node(self, index: int) -> NodePattern: return self._nodes[index] def num_nodes(self) -> int: return len(self._nodes) def __len__(self) -> int: return self.num_nodes() @property def inputs(self) -> Sequence[ValuePattern]: return self._inputs @property def outputs(self) -> Sequence[ValuePattern]: return self._outputs def __iter__(self) -> Iterator[NodePattern]: return iter(self._nodes) def __reversed__(self) -> Iterator[NodePattern]: return reversed(self._nodes) @property def has_single_output_node(self) -> bool: return len(self.output_nodes) == 1 @property def num_outputs(self) -> int: return len(self._outputs) def commute(self) -> Sequence[GraphPattern]: def commute_node(node: NodePattern) -> Iterable[bool]: if node.op_identifier() == ("", "Add", "") or node.op_identifier() == ( "", "Mul", "", ): # Try with and without swapping inputs. return [False, True] # No swapping of inputs return [False] iteration_space = [commute_node(node) for node in self._nodes] def copy_graph(swap_list: Iterable[bool]) -> GraphPattern: if not any(swap_list): # No need to swap inputs of any node return self # Create a copy of the graph, with swapped inputs for the nodes that need it. node_map: dict[NodePattern, NodePattern] = {} new_inputs = [v.clone(node_map) for v in self._inputs] new_nodes = [ node.clone(node_map, swap) for node, swap in zip(self._nodes, swap_list) ] new_outputs = [v.clone(node_map) for v in self._outputs] return GraphPattern(new_inputs, new_outputs, new_nodes) return [copy_graph(swap_list) for swap_list in itertools.product(*iteration_space)] def __str__(self) -> str: inputs = ", ".join(str(v) for v in self._inputs) outputs = ", ".join(str(v) for v in self._outputs) nodes = "\n ".join(str(n) for n in self._nodes) return f"pattern ({inputs}) {{\n {nodes}\n return {outputs}\n}}" def _to_graph_pattern(pattern_constructor: Callable) -> GraphPattern: """Convert a pattern-construction function to a GraphPattern. A pattern-construction function will return values as below: :: def pattern(op, x: Var, shape1: Var, shape2: Var): ... return outputs We create a pattern graph by creating pattern-variables for each parameter of the function, and calling the function. The returned values are normalized to a list of ValuePatterns, which represent the outputs of the pattern graph. Args: pattern_constructor: Callable Returns: GraphPattern: A representation of the pattern that can be matched against a subgraph. """ _pattern_vars = inspect.signature(pattern_constructor).parameters pattern_inputs = [Var(v) for v in _pattern_vars][1:] # Skip the first parameter builder = OpsetPatternBuilder("", record=True) with pattern_builder(builder): pattern_outputs = pattern_constructor(builder, *pattern_inputs) # TODO(rama): classify inputs as value/attribute vars # Returned value could be a single ValuePattern or a list of ValuePatterns. # Normalize representation to a list of ValuePatterns. if isinstance(pattern_outputs, ValuePattern): pattern_outputs = [pattern_outputs] return GraphPattern(pattern_inputs, pattern_outputs, builder.nodes())