|
|
|
|
|
|
|
|
"""Basic rewrite rules for general optimization patterns. |
|
|
|
|
|
This module contains fundamental optimization rules that are generally applicable |
|
|
to most ONNX models, including cast elimination, transpose simplification, |
|
|
shape operation fusion, and other common patterns. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import ClassVar, Sequence |
|
|
|
|
|
from onnxscript import ir |
|
|
from onnxscript.rewriter import _ir_utils as ir_utils |
|
|
from onnxscript.rewriter._basics import MatchResult |
|
|
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet |
|
|
|
|
|
|
|
|
class SqueezeReshape(RewriteRuleClassBase):
    """Rewrites ``Reshape(Squeeze(x), [-1])`` as ``Identity(x)`` when ``x`` is 1D.

    This pattern arises from the translation of pytorch symints.
    """

    def __init__(self):
        # The rule is registered under the name "SqueezeReshape1d" and asks the
        # base class not to remove the matched nodes.
        super().__init__("SqueezeReshape1d", remove_nodes=False)

    def pattern(self, op, x):
        """Match a ``Squeeze`` whose result is reshaped with the constant ``[-1]``."""
        squeezed = op.Squeeze(x)
        return op.Reshape(squeezed, [-1])

    def check(self, context, x) -> MatchResult:
        """Accept the match only if ``ir_utils.has_rank(x, 1)`` holds."""
        del context  # Unused
        outcome = MatchResult()
        if ir_utils.has_rank(x, 1):
            return outcome
        return outcome.fail("Input is not 1D")

    def rewrite(self, op, x: ir.Value):
        """Forward ``x`` unchanged through an ``Identity`` node."""
        return op.Identity(x)
|
|
|
|
|
|
|
|
class CastIdentity(RewriteRuleClassBase):
    """Turns ``Cast(x, to=to)`` into ``Identity(x)`` when ``x`` already has dtype ``to``."""

    def pattern(self, op, x, to):
        """Match any ``Cast`` node, binding its input and its ``to`` attribute."""
        return op.Cast(x, to=to)

    def check(self, context, x, to) -> MatchResult:
        """Accept the match only if ``x.dtype`` equals the requested target type."""
        outcome = MatchResult()
        target_type = to.as_int()
        if x.dtype != target_type:
            return outcome.fail("Input and output types are not the same")
        return outcome

    def rewrite(self, op, x: ir.Value, to: ir.Attr):
        """Replace the no-op cast with an ``Identity`` on ``x``."""
        return op.Identity(x)
|
|
|
|
|
|
|
|
class CastCast(RewriteRuleClassBase):
    """Simplifies ``Cast(Cast(X, to=t2), to=t3)`` into ``Cast(X, to=t3)``.

    The rewrite is applied only when the ``(t2, t3)`` pair is listed in
    ``_allowed_type2_type3``; every other combination is rejected by ``check``.
    NOTE(review): dropping an intermediate cast is not value-preserving for
    arbitrary type combinations, which is presumably why an explicit
    allow-list is used -- extend it with care.
    """

    # (intermediate type, final type) pairs for which the inner cast may be dropped.
    _allowed_type2_type3: ClassVar = frozenset(
        (
            (ir.DataType.FLOAT, ir.DataType.FLOAT16),
            (ir.DataType.FLOAT, ir.DataType.BFLOAT16),
        )
    )

    def pattern(self, op, x, to, to_ignored):
        """Match two chained ``Cast`` nodes; ``to_ignored`` is the inner target type."""
        inner_cast = op.Cast(x, to=to_ignored)
        return op.Cast(inner_cast, to=to)

    def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult:
        """Accept the match only for allow-listed (intermediate, final) type pairs."""
        outcome = MatchResult()
        intermediate_type = to_ignored.as_int()
        final_type = to.as_int()
        if (intermediate_type, final_type) in self._allowed_type2_type3:
            return outcome
        return outcome.fail(
            f"Intermediate cast elimination not recognized as valid from {intermediate_type} to {final_type}. "
            "Cast-Cast rule may be incomplete for this combination."
        )

    def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
        """Cast ``x`` directly to the final type, skipping the intermediate cast."""
        return op.Cast(x, to=to)
|
|
|
|
|
|
|
|
class ExpandIdentity(RewriteRuleClassBase):
    """Turns ``Expand(x, shape)`` into ``Identity(x)`` when ``shape`` equals the shape of ``x``."""

    def pattern(self, op, x, shape):
        """Match any ``Expand`` node, binding its data and shape inputs."""
        return op.Expand(x, shape)

    def check(self, context, x, shape) -> MatchResult:
        """Accept the match only if the constant target shape equals ``x.shape.dims``.

        The match is rejected when ``shape`` has no constant value or when the
        shape of ``x`` is unknown.
        """
        outcome = MatchResult()
        target = shape.const_value
        if target is None:
            return outcome.fail("Shape is not a constant and cannot be guessed.")
        input_shape = x.shape
        if input_shape is None:
            return outcome.fail("Input shape is not known.")
        requested_dims = target.numpy().tolist()
        if input_shape.dims != tuple(requested_dims):
            return outcome.fail(
                f"Input shape {input_shape.dims} does not match the shape {requested_dims}."
            )
        return outcome

    def rewrite(self, op, x: ir.Value, shape: ir.Value):
        """Replace the no-op expand with an ``Identity`` on ``x``."""
        return op.Identity(x)
|
|
|
|
|
|
|
|
class ReshapeReshape(RewriteRuleClassBase):
    """Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``.

    The pattern matches only if both shapes are constants and the second
    reshape targets a shape made of strictly positive values (no ``0`` or
    ``-1`` placeholders whose meaning would depend on the intermediate tensor).
    An empty target shape (reshape to a scalar) contains no such placeholder
    and is accepted as well.
    """

    def pattern(self, op, x, shape_ignored, shape):
        """Match two chained ``Reshape`` nodes; ``shape_ignored`` is the inner target shape."""
        return op.Reshape(op.Reshape(x, shape_ignored), shape)

    def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
        """Reshape ``x`` directly to the final shape, skipping the intermediate reshape."""
        return op.Reshape(x, shape)

    def check(self, context, x, shape_ignored, shape) -> MatchResult:
        """Accept the match only for constant shapes with a fully explicit final shape.

        Returns:
            A ``MatchResult`` that is marked as failed when either shape is not
            a constant or when the final shape holds a non-positive value.
        """
        check_result = MatchResult()
        if shape_ignored.const_value is None:
            return check_result.fail("Shape ignored is not a constant.")
        if shape.const_value is None:
            return check_result.fail("Shape is not a constant.")
        target = shape.const_value.numpy()
        # ``min()`` raises ValueError on a zero-size array, so only inspect the
        # values when the target shape actually has entries.
        if target.size > 0 and target.min() <= 0:
            return check_result.fail("Shape has non-positive values.")
        return check_result
|
|
|
|
|
|
|
|
class SlicesSplit(RewriteRuleClassBase):
    """Replaces ``Slice(x, ...), Slice(x, ...)``
    by ``Split(x, ...)`` if possible.

    The rewrite applies when both slices cut the last axis of ``x`` into two
    equal halves: the first covers ``[0, n / 2)`` and the second
    ``[n / 2, n)``, where ``n`` is the statically known, even size of the
    last dimension.
    """

    def pattern(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
        """Match two ``Slice`` nodes reading the same input ``x``."""
        return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)

    def check(self, context, x, begin0, end0, axes0, begin1, end1, axes1) -> MatchResult:
        """Verify that the two slices are exactly the two halves of the last axis.

        Returns:
            A ``MatchResult`` that is marked as failed when any slice parameter
            is not a constant, the shape of ``x`` is unknown, or the slices do
            not partition the last dimension into two equal halves.
        """
        check_result = MatchResult()
        if (
            axes0.const_value is None
            or axes1.const_value is None
            or axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist()
        ):
            return check_result.fail("Axes are not equal or not constant.")
        axes = axes0.const_value.numpy().tolist()
        if len(axes) != 1:
            return check_result.fail("Axes has more than one dimension.")

        # The size of the last dimension is needed below, so an unknown shape
        # is rejected up front rather than guessing the rank some other way.
        shape = x.shape
        if shape is None:
            return check_result.fail("Shape is not known.")
        rank = len(shape)
        if rank == 0:
            # A scalar has no last axis to slice or split.
            return check_result.fail("Input is a scalar.")
        if axes[0] != -1 and axes[0] != rank - 1:
            return check_result.fail("Axes is not -1 or last dimension.")

        if (
            begin0.const_value is None
            or end0.const_value is None
            or begin1.const_value is None
            or end1.const_value is None
        ):
            return check_result.fail("Begin or end are not constant values.")
        if begin0.const_value.numpy().tolist() != [0]:
            return check_result.fail("First begin value is not 0.")
        e0, b1, e1 = (
            end0.const_value.numpy().tolist(),
            begin1.const_value.numpy().tolist(),
            end1.const_value.numpy().tolist(),
        )
        if e0[0] != b1[0]:
            return check_result.fail("End0 is not equal to Begin1.")

        last_dim = shape[-1]
        if not isinstance(last_dim, int):
            return check_result.fail("Last dimension is not known.")
        if last_dim != e1[0]:
            return check_result.fail("Last dimension is not equal to End1.")
        # Split(num_outputs=2) on an odd dimension makes the *first* chunk the
        # larger one, whereas slices cut at ``last_dim // 2`` make the second
        # one larger; the rewrite is only equivalent for even sizes.
        if last_dim % 2 != 0:
            return check_result.fail("Last dimension is not even.")
        if last_dim // 2 != b1[0]:
            return check_result.fail("Last dimension is not equal to Begin1.")
        return check_result

    def rewrite(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
        """Emit a single two-way ``Split`` of ``x`` along its last axis."""
        return op.Split(x, num_outputs=2, axis=-1, _outputs=2)
|
|
|
|
|
|
|
|
class TransposeIdentity(RewriteRuleClassBase):
    """Turns ``Transpose(x, perm=perm)`` into ``Identity(x)``
    when the permutation is the identity ``(0, 1, ..., n - 1)``.
    """

    def pattern(self, op, x, perm):
        """Match any ``Transpose`` node, binding its input and ``perm`` attribute."""
        return op.Transpose(x, perm=perm)

    def check(self, context, x: ir.Value, perm: ir.Attr) -> MatchResult:
        """Accept the match only for a concrete INTS attribute equal to the identity."""
        outcome = MatchResult()
        if perm.is_ref():
            return outcome.fail("Permutation is a reference attribute.")
        if perm.type != ir.AttributeType.INTS:
            return outcome.fail("Permutation is not identity.")
        order = tuple(perm.as_ints())
        if order != tuple(range(len(order))):
            return outcome.fail("Permutation is not identity.")
        return outcome

    def rewrite(self, op, x: ir.Value, perm: ir.Attr):
        """Replace the no-op transpose with an ``Identity`` on ``x``."""
        return op.Identity(x)
|
|
|
|
|
|
|
|
class TransposeTranspose(RewriteRuleClassBase):
    """Collapses ``Transpose(Transpose(x, perm=perm1), perm=perm2)``.

    The two permutations are composed. When the composition is the identity
    the pair becomes ``Identity(x)``; otherwise it becomes a single
    ``Transpose`` carrying the composed permutation.
    """

    def pattern(self, op, x, perm1, perm2):
        """Match two chained ``Transpose`` nodes; ``perm1`` belongs to the inner one."""
        inner = op.Transpose(x, perm=perm1)
        return op.Transpose(inner, perm=perm2)

    def check(self, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> MatchResult:
        """Reject the match when either permutation is a reference attribute."""
        outcome = MatchResult()
        if not (perm1.is_ref() or perm2.is_ref()):
            return outcome
        return outcome.fail("Permutation is a reference attribute.")

    def _apply_transpose(self, perm: Sequence[int], on: list[int]) -> list[int]:
        """Return ``on`` reordered so that position ``i`` holds ``on[perm[i]]``."""
        assert len(perm) == len(on), "length mismatch"
        return [on[source] for source in perm]

    def _apply_transposes(
        self, perms: list[Sequence[int]], on: list[int] | None = None
    ) -> list[int]:
        """Apply each permutation of ``perms`` in order.

        When ``on`` is omitted, the starting point is ``[0, 1, ..., n - 1]``
        with ``n`` the length of the first permutation.
        """
        current = list(range(len(perms[0]))) if on is None else on
        for perm in perms:
            current = self._apply_transpose(perm, current)
        return current

    def rewrite(self, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
        """Emit ``Identity`` or a single ``Transpose`` with the composed permutation."""
        untouched = list(range(len(perm1.as_ints())))
        composed = self._apply_transposes([perm1.as_ints(), perm2.as_ints()])
        if composed != untouched:
            return op.Transpose(x, perm=composed)
        return op.Identity(x)
|
|
|
|
|
|
|
|
class UnsqueezeUnsqueeze(RewriteRuleClassBase):
    """Fuses ``Unsqueeze(Unsqueeze(x, axes1), axes2)`` into a single ``Unsqueeze``.

    Only applied when ``ir_utils.get_singleton_value`` yields a non-negative
    value for both axes inputs.
    """

    def pattern(self, op, x, axes1, axes2):
        """Match two chained ``Unsqueeze`` nodes; ``axes1`` belongs to the inner one."""
        inner = op.Unsqueeze(x, axes1)
        return op.Unsqueeze(inner, axes2)

    def check(self, context, x, axes1, axes2) -> MatchResult:
        """Accept the match only when both axes are known and non-negative."""
        outcome = MatchResult()
        del context  # Unused
        del x  # Unused

        inner_axis = ir_utils.get_singleton_value(axes1)
        outer_axis = ir_utils.get_singleton_value(axes2)
        if inner_axis is None or outer_axis is None:
            return outcome.fail("Axes are not constant.")
        if inner_axis < 0 or outer_axis < 0:
            return outcome.fail("Axes are negative.")
        return outcome

    def rewrite(self, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
        """Emit one ``Unsqueeze`` whose axes reproduce both insertions."""
        inner_axis = ir_utils.get_singleton_value(axes1)
        outer_axis = ir_utils.get_singleton_value(axes2)
        if inner_axis < outer_axis:
            merged_axes = [inner_axis, outer_axis]
        else:
            # The outer axis lands at or before the inner one, which pushes
            # the inner axis one position to the right in the final tensor.
            merged_axes = [outer_axis, inner_axis + 1]
        axes_tensor = ir.tensor(merged_axes, dtype=ir.DataType.INT64)
        return op.Unsqueeze(x, op.Constant(value=axes_tensor))
|
|
|
|
|
|
|
|
|
|
|
# Module-level rule instances, one per rule class defined above.

# Rules that replace a no-op node with Identity.
cast_identity_rule = CastIdentity.rule()
expand_identity_rule = ExpandIdentity.rule()
transpose_identity_rule = TransposeIdentity.rule()
squeeze_reshape_1d_rule = SqueezeReshape.rule()

# Rules that merge two chained nodes of the same kind.
cast_cast_rule = CastCast.rule()
reshape_reshape_rule = ReshapeReshape.rule()
transpose_transpose_rule = TransposeTranspose.rule()
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()

# Rule that replaces a pair of Slice nodes with one Split.
slice_split_rule = SlicesSplit.rule()
|
|
|
|
|
|
|
|
def basic_optimization_rules() -> RewriteRuleSet:
    """Build the rule set holding the basic optimization rules of this module.

    The rules cover fundamental optimizations such as:
    - Eliminating redundant cast operations
    - Simplifying consecutive operations of the same type
    - Removing identity operations
    - Optimizing shape manipulation operations

    They are generally safe to apply as a first optimization pass,
    before other more specialized optimizations.

    Returns:
        RewriteRuleSet: A collection of basic optimization rules
    """
    # The order of the list is the order in which the rules are handed to the set.
    ordered_rules = [
        cast_cast_rule,
        cast_identity_rule,
        expand_identity_rule,
        reshape_reshape_rule,
        slice_split_rule,
        transpose_identity_rule,
        transpose_transpose_rule,
        unsqueeze_unsqueeze_rule,
        squeeze_reshape_1d_rule,
    ]
    return RewriteRuleSet(ordered_rules)
|
|
|