File size: 12,030 Bytes
b5fdc16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Basic rewrite rules for general optimization patterns.
This module contains fundamental optimization rules that are generally applicable
to most ONNX models, including cast elimination, transpose simplification,
shape operation fusion, and other common patterns.
"""
from __future__ import annotations
from typing import ClassVar, Sequence
from onnxscript import ir
from onnxscript.rewriter import _ir_utils as ir_utils
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
class SqueezeReshape(RewriteRuleClassBase):
    """Rewrites ``Reshape(Squeeze(x), [-1])`` to ``Identity(x)`` when ``x`` is 1D.

    This pattern arises from the translation of pytorch symints.
    """

    def __init__(self):
        # The rule is named "SqueezeReshape1d" and passes ``remove_nodes=False``
        # to the base class.
        super().__init__("SqueezeReshape1d", remove_nodes=False)

    def pattern(self, op, x):
        """Match a Squeeze whose result is reshaped with the constant ``[-1]``."""
        squeezed = op.Squeeze(x)
        return op.Reshape(squeezed, [-1])

    def rewrite(self, op, x: ir.Value):
        """Forward ``x`` unchanged through an Identity node."""
        return op.Identity(x)

    def check(self, context, x) -> MatchResult:
        """Accept the match only when ``x`` is known to have rank 1."""
        del context  # Unused
        outcome = MatchResult()
        if ir_utils.has_rank(x, 1):
            return outcome
        return outcome.fail("Input is not 1D")
class CastIdentity(RewriteRuleClassBase):
    """Replaces ``Cast(x, to=to)`` by ``Identity(x)`` when ``x`` already has type ``to``."""

    def pattern(self, op, x, to):
        """Match any Cast node, capturing its input and its ``to`` attribute."""
        return op.Cast(x, to=to)

    def rewrite(self, op, x: ir.Value, to: ir.Attr):
        """Replace the cast with an Identity of its input."""
        return op.Identity(x)

    def check(self, context, x, to) -> MatchResult:
        """Accept the match only when the input dtype equals the cast target."""
        outcome = MatchResult()
        if x.dtype == to.as_int():
            return outcome
        return outcome.fail("Input and output types are not the same")
class CastCast(RewriteRuleClassBase):
    """Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""

    # Collapses "cast type1 => type2 => type3" into "cast type1 => type3".
    # Dropping the intermediate cast is not valid for every combination of
    # types: e.g., it is not valid for float32 => float16 => float32 or
    # float32 => int32 => string.
    # TODO: fill out the list of allowed combinations: the entries below are
    # just a couple of (type2, type3) pairs that show up in practice and for
    # which the simplification is valid.
    _allowed_type2_type3: ClassVar = frozenset(
        (
            (ir.DataType.FLOAT, ir.DataType.FLOAT16),
            (ir.DataType.FLOAT, ir.DataType.BFLOAT16),
        )
    )

    def pattern(self, op, x, to, to_ignored):
        """Match two chained Cast nodes; ``to_ignored`` is the inner target type."""
        inner_cast = op.Cast(x, to=to_ignored)
        return op.Cast(inner_cast, to=to)

    def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
        """Cast the original input directly to the final type."""
        return op.Cast(x, to=to)

    def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult:
        """Accept only (intermediate, final) type pairs listed as allowed."""
        outcome = MatchResult()
        type2 = to_ignored.as_int()
        type3 = to.as_int()
        if (type2, type3) in self._allowed_type2_type3:
            return outcome
        return outcome.fail(
            f"Intermediate cast elimination not recognized as valid from {type2} to {type3}. "
            "Cast-Cast rule may be incomplete for this combination."
        )
class ExpandIdentity(RewriteRuleClassBase):
    """Replaces ``Expand(x, shape)`` by ``Identity(x)`` when ``shape`` equals the shape of ``x``."""

    def pattern(self, op, x, shape):
        """Match any Expand node."""
        return op.Expand(x, shape)

    def rewrite(self, op, x: ir.Value, shape: ir.Value):
        """Replace the expand with an Identity of its input."""
        return op.Identity(x)

    def check(self, context, x, shape) -> MatchResult:
        """Accept only when the constant target shape equals the known input shape."""
        outcome = MatchResult()

        shape_const = shape.const_value
        if shape_const is None:
            # Shape is not a constant and cannot be guessed.
            return outcome.fail("Shape is not a constant and cannot be guessed.")

        input_shape = x.shape
        if input_shape is None:
            # We don't know the shape of the input
            return outcome.fail("Input shape is not known.")

        # Convert the constant once; it is used for both the comparison and
        # the failure message.
        target_dims = shape_const.numpy().tolist()
        if input_shape.dims != tuple(target_dims):
            return outcome.fail(
                f"Input shape {input_shape.dims} does not match the shape {target_dims}."
            )
        return outcome
class ReshapeReshape(RewriteRuleClassBase):
    """Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``.

    The pattern matches only if both shapes are constants and the second
    reshape targets a shape made exclusively of positive values. An empty
    target shape (reshape to a scalar) contains no non-positive value and is
    therefore accepted.
    """

    def pattern(self, op, x, shape_ignored, shape):
        """Match two chained Reshape nodes; ``shape_ignored`` is the inner target."""
        return op.Reshape(op.Reshape(x, shape_ignored), shape)

    def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
        """Reshape the original input directly to the final shape."""
        return op.Reshape(x, shape)

    def check(self, context, x, shape_ignored, shape) -> MatchResult:
        """Validate that the intermediate reshape can be dropped.

        Returns:
            A successful ``MatchResult`` when both shapes are constant and the
            final shape has no zero or negative entry; a failed one otherwise.
        """
        check_result = MatchResult()
        if shape_ignored.const_value is None:
            return check_result.fail("Shape ignored is not a constant.")
        if shape.const_value is None:
            return check_result.fail("Shape is not a constant.")
        target_shape = shape.const_value.numpy()
        # An element-wise test is used instead of ``target_shape.min() <= 0``:
        # ``min()`` raises ValueError on a zero-size array, which happens for
        # the empty shape used to reshape to a scalar. ``any()`` over an empty
        # array is False, so that case is accepted.
        if (target_shape <= 0).any():
            return check_result.fail("Shape has non-positive values.")
        return check_result
class SlicesSplit(RewriteRuleClassBase):
    """Replaces ``Slice(x, ...), Slice(x, ...)`` by ``Split(x, ...)`` if possible.

    The rule only applies when the two slices cut the last axis of ``x`` into
    two equal, contiguous halves: the first slice covers ``[0, d/2)`` and the
    second ``[d/2, d)`` where ``d`` is the statically known, even size of the
    last dimension.
    """

    def pattern(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
        """Match two Slice nodes reading the same input ``x``."""
        return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)

    def check(self, context, x, begin0, end0, axes0, begin1, end1, axes1) -> MatchResult:
        """Validate that both slices together are equivalent to a 2-way Split.

        Returns:
            A successful ``MatchResult`` when every condition listed in the
            class docstring holds; a failed one carrying the reason otherwise.
        """
        check_result = MatchResult()

        # Both slices must act on the same, single, constant axis.
        if axes0.const_value is None or axes1.const_value is None:
            return check_result.fail("Axes are not equal or not constant.")
        axes = axes0.const_value.numpy().tolist()
        if axes != axes1.const_value.numpy().tolist():
            return check_result.fail("Axes are not equal or not constant.")
        if len(axes) != 1:
            return check_result.fail("Axes has more than one dimension.")

        # The shape is required further down to read the last dimension, so a
        # missing shape is rejected up front instead of deriving the rank from
        # another attribute first.
        shape = x.shape
        if shape is None:
            return check_result.fail("Shape is not known.")
        rank = len(shape)
        if rank == 0:
            # A rank-0 input has no last dimension to split.
            return check_result.fail("Input is a scalar.")
        if axes[0] != -1 and axes[0] != rank - 1:
            return check_result.fail("Axes is not -1 or last dimension.")

        if (
            begin0.const_value is None
            or end0.const_value is None
            or begin1.const_value is None
            or end1.const_value is None
        ):
            return check_result.fail("Begin or end are not constant values.")
        if begin0.const_value.numpy().tolist() != [0]:
            return check_result.fail("First begin value is not 0.")
        e0 = end0.const_value.numpy().tolist()
        b1 = begin1.const_value.numpy().tolist()
        e1 = end1.const_value.numpy().tolist()
        # The second slice must start exactly where the first one stops.
        if e0[0] != b1[0]:
            return check_result.fail("End0 is not equal to Begin1.")

        last_dim = shape[-1]
        if not isinstance(last_dim, int):
            return check_result.fail("Last dimension is not known.")
        if last_dim != e1[0]:
            return check_result.fail("Last dimension is not equal to End1.")
        # With ``num_outputs=2`` an odd dimension is split with the larger
        # chunk first, whereas slices at ``last_dim // 2`` put the larger chunk
        # last; the rewrite is only equivalent for an even dimension.
        if last_dim % 2 != 0:
            return check_result.fail("Last dimension is not even.")
        if last_dim // 2 != b1[0]:
            return check_result.fail("Half of the last dimension is not equal to Begin1.")
        return check_result

    def rewrite(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
        """Emit a single Split producing two outputs along the last axis."""
        return op.Split(x, num_outputs=2, axis=-1, _outputs=2)
class TransposeIdentity(RewriteRuleClassBase):
    """Replaces ``Transpose(x, perm=perm)`` by ``Identity(x)``
    when the permutation is the identity.
    """

    def pattern(self, op, x, perm):
        """Match any Transpose node with an explicit ``perm`` attribute."""
        return op.Transpose(x, perm=perm)

    def rewrite(self, op, x: ir.Value, perm: ir.Attr):
        """Replace the transpose with an Identity of its input."""
        return op.Identity(x)

    def check(self, context, x: ir.Value, perm: ir.Attr) -> MatchResult:
        """Accept only a non-reference INTS ``perm`` equal to ``(0, 1, ..., n-1)``."""
        outcome = MatchResult()
        if perm.is_ref():
            return outcome.fail("Permutation is a reference attribute.")
        if perm.type != ir.AttributeType.INTS:
            return outcome.fail("Permutation is not identity.")
        axes = tuple(perm.as_ints())
        if axes != tuple(range(len(axes))):
            return outcome.fail("Permutation is not identity.")
        return outcome
class TransposeTranspose(RewriteRuleClassBase):
    """Fuses ``Transpose(Transpose(., perm=perm1), perm=perm2)``.

    The pair is replaced by ``Identity`` when both permutations are inverse
    of each other, and by a single ``Transpose`` with the composed
    permutation otherwise.
    """

    def pattern(self, op, x, perm1, perm2):
        """Match two chained Transpose nodes with explicit permutations."""
        return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)

    def check(self, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> MatchResult:
        """Validate that both permutations are concrete and composable.

        Rejecting unusable permutations here keeps ``rewrite`` from failing
        on the length assertion in ``_apply_transpose``.
        """
        check_result = MatchResult()
        if perm1.is_ref() or perm2.is_ref():
            return check_result.fail("Permutation is a reference attribute.")
        if (
            perm1.type != ir.AttributeType.INTS
            or perm2.type != ir.AttributeType.INTS
        ):
            return check_result.fail("Permutation is not a list of integers.")
        if len(perm1.as_ints()) != len(perm2.as_ints()):
            return check_result.fail("Permutations do not have the same length.")
        return check_result

    def _apply_transpose(self, perm: Sequence[int], on: list[int]) -> list[int]:
        """Return ``on`` permuted by ``perm``: element ``i`` is ``on[perm[i]]``."""
        # Internal invariant: ``check`` already rejects mismatched lengths.
        assert len(perm) == len(on), "length mismatch"
        return [on[p] for p in perm]

    def _apply_transposes(
        self, perms: list[Sequence[int]], on: list[int] | None = None
    ) -> list[int]:
        """Apply every permutation of ``perms`` in order.

        Args:
            perms: Permutations to apply, first one first.
            on: Initial axis order; defaults to ``[0, 1, ..., n-1]`` where
                ``n`` is the length of the first permutation.

        Returns:
            The axis order obtained after applying all permutations.
        """
        if on is None:
            on = list(range(len(perms[0])))
        for p in perms:
            on = self._apply_transpose(p, on)
        return on

    def rewrite(self, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
        """Emit Identity if the composition is the identity, else one Transpose."""
        composed = self._apply_transposes([perm1.as_ints(), perm2.as_ints()])
        if composed == list(range(len(composed))):
            return op.Identity(x)
        return op.Transpose(x, perm=composed)
class UnsqueezeUnsqueeze(RewriteRuleClassBase):
    """Merges ``Unsqueeze(Unsqueeze(., axes1), axes2)`` into a single Unsqueeze."""

    def pattern(self, op, x, axes1, axes2):
        """Match two chained Unsqueeze nodes; ``axes1`` belongs to the inner one."""
        inner = op.Unsqueeze(x, axes1)
        return op.Unsqueeze(inner, axes2)

    def check(self, context, x, axes1, axes2) -> MatchResult:
        """Accept only constant, single-element, non-negative axes."""
        del context, x  # Unused
        outcome = MatchResult()
        # Currently restricted to a single-element, non-negative axis on both nodes.
        inner_axis = ir_utils.get_singleton_value(axes1)
        outer_axis = ir_utils.get_singleton_value(axes2)
        if inner_axis is None or outer_axis is None:
            return outcome.fail("Axes are not constant.")
        if inner_axis < 0 or outer_axis < 0:
            return outcome.fail("Axes are negative.")
        return outcome

    def rewrite(self, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
        """Emit one Unsqueeze whose axes list combines both original axes."""
        inner_axis = ir_utils.get_singleton_value(axes1)
        outer_axis = ir_utils.get_singleton_value(axes2)
        if inner_axis < outer_axis:
            merged_axes = [inner_axis, outer_axis]
        else:
            # The outer insertion lands at or before the inner one, which
            # pushes the inner axis one position to the right.
            merged_axes = [outer_axis, inner_axis + 1]
        axes_tensor = ir.tensor(merged_axes, dtype=ir.DataType.INT64)
        return op.Unsqueeze(x, op.Constant(value=axes_tensor))
# Module-level rule instances, one per rule class, listed alphabetically.
cast_cast_rule = CastCast.rule()
cast_identity_rule = CastIdentity.rule()
expand_identity_rule = ExpandIdentity.rule()
reshape_reshape_rule = ReshapeReshape.rule()
slice_split_rule = SlicesSplit.rule()
squeeze_reshape_1d_rule = SqueezeReshape.rule()
transpose_identity_rule = TransposeIdentity.rule()
transpose_transpose_rule = TransposeTranspose.rule()
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
def basic_optimization_rules() -> RewriteRuleSet:
    """Build the set of basic optimization rules defined in this module.

    The rules cover fundamental optimizations such as:

    - Eliminating redundant cast operations
    - Simplifying consecutive operations of the same type
    - Removing identity operations
    - Optimizing shape manipulation operations

    They are intended as a generally safe first optimization pass, to be run
    before other more specialized optimizations.

    Returns:
        RewriteRuleSet: A collection of basic optimization rules
    """
    # The order of the rules in the list is kept as originally written.
    rules = [
        cast_cast_rule,
        cast_identity_rule,
        expand_identity_rule,
        reshape_reshape_rule,
        slice_split_rule,
        transpose_identity_rule,
        transpose_transpose_rule,
        unsqueeze_unsqueeze_rule,
        squeeze_reshape_1d_rule,
    ]
    return RewriteRuleSet(rules)
|