xiaoanyu123's picture
Add files using upload-large-folder tool
6a22ec9 verified
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files.
from __future__ import annotations
import dataclasses
import logging
import math
import typing
from typing import Any, Callable, Collection, Iterable, Sequence, Union
import numpy as np
import onnx
import onnx.reference.ops
import onnx_ir as ir
import onnxscript.utils.utils as utils
from onnxscript.ir import _tape
# Default size limits for constant folding.
# NOTE(review): presumably supplied as input_size_limit / output_size_limit of the
# folding pass, where they are compared against tensor element counts -- confirm.
DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 512
DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512**2

# Op types that are treated as non-deterministic when they belong to the ONNX domain.
_NON_DETERMINISTIC_OPS = frozenset(
    (
        "RandomUniform",
        "RandomNormal",
        "RandomUniformLike",
        "RandomNormalLike",
        "Multinomial",
    )
)

logger = logging.getLogger(__name__)
def _is_control_flow_op(node: ir.Node) -> bool:
    """Return True if any attribute of ``node`` is of type GRAPH or GRAPHS."""
    for attribute in node.attributes.values():
        if attribute.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
            return True
    return False
def _is_non_deterministic_op(node: ir.Node) -> bool:
    """Return True if ``node`` is an ONNX-domain op listed in ``_NON_DETERMINISTIC_OPS``."""
    if node.op_type not in _NON_DETERMINISTIC_OPS:
        return False
    return utils.is_onnx_domain(node.domain)
def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
    """Return True if ``node`` has the given ``op_type`` and is in the ONNX domain."""
    if node.op_type != op_type:
        return False
    return utils.is_onnx_domain(node.domain)
# "Standard" evaluators are used to perform constant-folding.
# The API below works only for non-control-flow ops (ops without any graph-attributes).
# This currently uses ONNX's reference implementation. But we could also
# use ORT's implementation if we want to.
def _process_constant_node(node: ir.Node) -> None:
    """Record the value of a ``Constant`` node on its single output.

    Nothing happens unless ``node`` is a default-domain ``Constant`` with exactly
    one attribute and exactly one output, and that attribute is an ``ir.Attr``
    with one of the recognized names (``value``, ``value_float(s)``,
    ``value_int(s)``, ``value_string(s)``). On success the output's
    ``const_value``, ``shape`` and ``dtype`` are set from the resulting tensor.
    """
    is_default_domain_constant = node.op_type == "Constant" and node.domain == ""
    if not is_default_domain_constant:
        return
    if len(node.attributes) != 1:
        return
    ((attr_name, attr_value),) = node.attributes.items()
    if len(node.outputs) != 1:
        return
    output = node.outputs[0]
    # isinstance() is False for None, so this also rejects a missing attribute value.
    if not isinstance(attr_value, ir.Attr):
        return

    tensor: ir.TensorProtocol
    if attr_name == "value":
        tensor = typing.cast(ir.TensorProtocol, attr_value.value)
    elif attr_name in ("value_float", "value_floats"):
        tensor = ir.Tensor(np.array(attr_value.value, dtype=np.float32), name=output.name)
    elif attr_name in ("value_int", "value_ints"):
        tensor = ir.Tensor(np.array(attr_value.value, dtype=np.int64), name=output.name)
    elif attr_name in ("value_string", "value_strings"):
        tensor = ir.StringTensor(
            np.array(attr_value.value, dtype=np.bytes_), name=output.name
        )
    else:
        # Any other attribute name is left untouched.
        return

    output.const_value = tensor
    output.shape = tensor.shape  # type: ignore
    output.dtype = tensor.dtype
def basic_constant_propagation(nodes: Iterable[ir.Node]) -> None:
    """Run the Constant-node handler over every node in ``nodes``.

    Each node is passed to ``_process_constant_node``, which marks the output of
    a Constant op with its ``const_value``; other nodes are left unchanged.
    """
    for candidate in nodes:
        _process_constant_node(candidate)
class ReferenceEvaluator:
    """Evaluates ops through ``onnx.reference.ops``; failures are reported as None."""

    def get_evaluator(self, domain: str, op: str, version: int) -> Callable | None:
        """Return the ``eval`` callable of the loaded op implementation, or None.

        Any exception raised while loading the implementation results in None.
        """
        try:
            implementation = onnx.reference.ops.load_op(domain, op, version)
            evaluator = implementation.eval
        except Exception:
            return None
        return evaluator

    def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any:
        """Evaluate the op on ``args``/``kwargs``.

        Returns None when no evaluator is available or when evaluation raises
        (the exception is logged as a warning).
        """
        logger.debug("Evaluating %s::%s", domain, op)
        evaluator = self.get_evaluator(domain, op, version)
        if evaluator is None:
            return None
        try:
            result = evaluator(*args, **kwargs)
        except Exception as error:
            logger.warning("Evaluation failed: %s", error)
            return None
        return result


# Shared module-level evaluator instance.
_reference_evaluator = ReferenceEvaluator()
@dataclasses.dataclass
class Replacement:
    """A replacement for a node in the graph."""
    # Values that take the place of the replaced node's outputs.
    new_outputs: Sequence[ir.Value]
    # Nodes that are inserted in place of the replaced node.
    new_nodes: Sequence[ir.Node]
# The optimizer tracks an optional symbolic value for each value in the model.
# The symbolic value attached to a value X can be:
# - another IR value Y (indicating that X is equal to Y)
# - a list of IR values [Y1, Y2, ...] (indicating that X is a sequence of values Y1, Y2, ...)
# - a Shape object (indicating that X is a shape value)
# A Shape object as a symbolic value indicates that the corresponding value is a
# 1-D (or 0-D) tensor of INT64 values. The values in this object may be constants
# or symbolic dimension values (like "batch_size", "sequence_length", etc.).
# Currently, we assume that symbolic dimensions are also guaranteed to be non-negative.
# TODO: Add support for negative symbolic dimensions.
SymbolicValue = Union[ir.Value, list[ir.Value], ir.Shape]
class OptimizerState:
    """Holds the symbolic values the optimizer has associated with IR values."""

    def __init__(self):
        self._sym_value_map: dict[ir.Value, SymbolicValue] = {}
        self._initializer_inputs: list[set[ir.Value]] = []

    @property
    def symbolic_value_map(self) -> dict[ir.Value, SymbolicValue]:
        """The underlying value -> symbolic value mapping."""
        return self._sym_value_map

    def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None:
        """Return the symbolic value recorded for ``value``, or None if there is none."""
        return None if value is None else self._sym_value_map.get(value)

    def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
        """Record ``sym_value`` as the symbolic value of ``value``."""
        self._sym_value_map[value] = sym_value

    def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
        """Return ``value`` interpreted as a shape, if known.

        A constant INT64 value of at most 10 elements is used when available: it
        yields a Shape if it is 1-D and None otherwise. Without such a constant,
        the recorded symbolic value is returned if it is an ``ir.Shape``.
        """
        constant = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
        if constant is None:
            symbolic = self.get_sym_value(value)
            # TODO use shape of value if available
            return symbolic if isinstance(symbolic, ir.Shape) else None
        if constant.ndim != 1:
            return None
        return ir.Shape(constant.tolist())
# The "partial evaluators" below are non-standard evaluators. They are used to perform
# partial evaluation and/or static program analysis (abstract interpretation).
# A partial-evaluator function takes a node, a RewriterContext, and an OptimizerState,
# and returns a Replacement for the node or None (if no replacement is needed). It may
# also return just the ir.Value or ir.Values to replace the output values of the node,
# when the new nodes can be inferred from the RewriterContext used to build them.
RewriterContext = _tape.Builder
ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None]
PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue]
@dataclasses.dataclass
class PartialEvaluator:
    """A partial-evaluator function together with the op versions it applies to.

    The evaluator applies to versions in the inclusive range
    ``[min_version, max_version]``. Either bound may be None, meaning the range
    is unbounded in that direction.
    """

    min_version: int | None
    max_version: int | None
    function: PartialEvaluatorFunction

    def valid_for(self, version: int) -> bool:
        """Returns True if this evaluator is applicable for the given version."""
        if self.min_version is not None and version < self.min_version:
            return False
        if self.max_version is not None and version > self.max_version:
            return False
        return True
class PartialEvaluatorRegistry:
    """A class that maintains a registry of evaluators for ops."""

    def __init__(self):
        # Maps (domain, op name) to the evaluators registered for that op.
        self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {}

    def lookup_evaluators(self, domain: str, opname: str, version: int):
        """Return the functions registered for ``(domain, opname)`` that are valid for ``version``."""
        return [
            evaluator.function
            for evaluator in self.op_evaluators.get((domain, opname), [])
            if evaluator.valid_for(version)
        ]

    def register(
        self, opname: str, domain: str = "", version=None
    ) -> Callable[[PartialEvaluatorFunction], PartialEvaluatorFunction]:
        """Return a decorator that registers a partial evaluator for an op.

        Args:
            opname: The op type the evaluator handles.
            domain: The op domain.
            version: None (any version), an int (exactly that version), or a
                ``(min_version, max_version)`` tuple whose entries may be None.

        Raises:
            TypeError: If ``version`` is not None, an int, or a tuple.
        """
        if version is None:
            min_version = max_version = None
        elif isinstance(version, int):
            min_version = max_version = version
        elif isinstance(version, tuple):
            min_version, max_version = version
        else:
            # Fail at registration time rather than with a NameError when the
            # decorator is later applied.
            raise TypeError(
                "version must be None, an int, or a (min_version, max_version) tuple, "
                f"not {type(version).__name__}"
            )
        evaluator_list = self.op_evaluators.setdefault((domain, opname), [])

        def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction:
            evaluator_list.append(PartialEvaluator(min_version, max_version, function))
            return function

        return decorator


registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry()
register = registry.register
def _same_shape(shape1: ir.Shape, shape2: ir.Shape) -> bool:
    """Return True if the two shapes are known to be identical.

    Shapes are compared through their ``dims``. A symbolic dimension of
    ``shape1`` whose value is None stands for an unknown size, so its presence
    makes the result False: (Batch, 1024) equals (Batch, 1024), but
    (None, 1024) is not considered equal to (None, 1024).
    """
    for dim in shape1:
        if isinstance(dim, ir.SymbolicDim) and dim.value is None:
            return False
    return shape1.dims == shape2.dims
def _get_numpy_value(
    val: ir.Value | None, dtype: ir.DataType | None = None, size_limit: int | None = None
) -> np.ndarray | None:
    """Return the constant held by ``val`` as a numpy array, when available.

    None is returned when ``val`` is None or carries no constant, when ``dtype``
    is given and differs from the constant's dtype, when ``size_limit`` is given
    and the constant's size exceeds it, or when the constant's external data
    file cannot be found (a warning is logged in that case).
    """
    if val is None:
        return None
    tensor = val.const_value
    if tensor is None:
        return None
    if dtype is not None and tensor.dtype != dtype:
        return None
    if size_limit is not None and tensor.size > size_limit:
        return None
    try:
        # Reinterpret the array with `.view()` because some implementations of
        # ir.TensorProtocol (e.g. PyTorch<=2.7) do not use ml_dtypes for bfloat16 etc.
        array = tensor.numpy().view(tensor.dtype.numpy())
    except FileNotFoundError:
        # External data is not available.
        logger.warning(
            "External data for value '%s' is not available. "
            "This may lead to incorrect constant folding.",
            val.name,
        )
        return None
    assert isinstance(array, np.ndarray)
    return array
def _get_bool_value(val: ir.Value | None) -> bool | None:
    """Return the value of a single-element boolean constant, or None otherwise."""
    if val is None:
        return None
    array = _get_numpy_value(val)
    if array is None or array.size != 1 or array.dtype != bool:
        return None
    return array.item(0)
def _get_input(node: ir.Node, index: int) -> ir.Value | None:
    """Return input ``index`` of ``node``, or None if the node has fewer inputs."""
    operands = node.inputs
    return operands[index] if index < len(operands) else None
def _get_output(node: ir.Node, index: int) -> ir.Value | None:
    """Return output ``index`` of ``node``, or None if the node has fewer outputs."""
    results = node.outputs
    return results[index] if index < len(results) else None
def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None:
    """Assign ``type`` to ``value.type``; a None ``type`` leaves ``value`` unchanged."""
    if type is None:
        return
    # TODO: merge types
    value.type = type
def _get_input_element_type(node: ir.Node, index: int) -> int:
    """Return the integer dtype code of input ``index``.

    ``ir.DataType.UNDEFINED.value`` is returned when the input is missing or has
    no type.
    """
    operand = _get_input(node, index)
    if operand is None or operand.type is None:
        return ir.DataType.UNDEFINED.value
    return operand.type.dtype.value
def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None:
    """Return the integer value of attribute ``name`` of ``node``.

    ``default`` is returned when the attribute is absent. None is returned when
    the attribute is present but is not an ``ir.Attr`` holding an int.
    """
    if name not in node.attributes:
        return default
    attr = node.attributes[name]
    if isinstance(attr, ir.Attr) and isinstance(attr.value, int):
        return attr.value
    # This is an invalid model: attribute has invalid/unexpected type.
    # For now, we just return None. We could raise an error too.
    return None
@register("Abs")
def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace an Abs node by Identity when its input is a non-negative shape value.

    Only applies when the input has a known shape value. Symbolic dimensions
    (e.g. SequenceLength in [1, 1, SequenceLength]) are assumed to be
    non-negative, so Abs is redundant unless a known integer entry is negative.
    """
    operand = _get_input(node, 0)
    shape_value = state.get_shape_value(operand)
    if shape_value is None:
        return None
    for dim in shape_value:
        if isinstance(dim, int) and dim < 0:
            return None
    return op.Identity(operand)
@register("Gather")
def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace a Gather node by a constant when applicable.

    Currently, handles the case of gathering from a shape tensor along axis 0
    with constant 1-D indices. The gathered entries are recorded as the symbolic
    value of the output; a Constant is returned only when all of them are ints.
    Returns None when the pattern does not apply.
    """
    data = _get_input(node, 0)
    indices = _get_input(node, 1)
    if data is None or indices is None:
        return None
    shape_value = state.get_shape_value(data)
    if shape_value is None:
        return None
    # ONNX Gather defaults `axis` to 0 when the attribute is omitted.
    if _get_int_attribute(node, "axis", 0) != 0:
        return None
    indices_array = _get_numpy_value(indices)
    if indices_array is None or indices_array.ndim != 1:
        return None
    try:
        gathered = [shape_value[index] for index in indices_array.tolist()]
    except IndexError:
        # Index out of range for the shape: invalid model, leave the node alone.
        return None
    output = _get_output(node, 0)
    if output is not None:
        state.set_sym_value(output, ir.Shape(gathered))
    if all(isinstance(dim, int) for dim in gathered):
        return op.Constant(value_ints=ir.AttrInt64s("value_ints", gathered))
    return None
@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace a Reshape node by Identity when the target shape equals the input shape."""
    data = _get_input(node, 0)
    target = _get_input(node, 1)
    if data is None or target is None:
        return None
    data_shape = data.shape
    target_shape = state.get_shape_value(target)
    if target_shape is None or data_shape is None:
        return None
    # No need to check for special values like -1, 0, etc. here
    return op.Identity(data) if _same_shape(data_shape, target_shape) else None
@register("Cast")
def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace a Cast to the input's own element type by Identity.

    As a side effect, the output's shape is copied from the input (when known)
    and, when the cast is not redundant, the output's type is set from ``to``.
    """
    operand = _get_input(node, 0)
    result = _get_output(node, 0)
    if operand is None or result is None:
        return None

    # TODO(rama): Parts of the following logic (implementing type/shape inference
    # for Cast op) should be unnecessary. Generic incremental shape-inference
    # should handle this. Only the optimization to eliminate redundant Cast ops
    # should be needed here.
    if (known_shape := operand.shape) is not None:
        result.shape = known_shape.copy()

    source_dtype = _get_input_element_type(node, 0)
    target_dtype = _get_int_attribute(node, "to", None)
    if target_dtype is None:
        return None
    if source_dtype == target_dtype:
        return op.Identity(operand)
    result.type = ir.TensorType(ir.DataType(target_dtype))
    return None
@register("CastLike")
def cast_like(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Rewrite CastLike as Identity or Cast once the target element type is known.

    Returns None when the element type of the second input is undefined.
    """
    source = node.inputs[0]
    source_type = _get_input_element_type(node, 0)
    target_type = _get_input_element_type(node, 1)
    if target_type == ir.DataType.UNDEFINED:
        return None
    if source_type != target_type:
        return op.Cast(source, to=target_type)
    return op.Identity(source)
@register("Shape")
def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Evaluate a Shape node from the statically known shape of its input.

    The ``start``/``end`` slice of the input shape is recorded as the symbolic
    value of the output. A Constant is returned only when every selected
    dimension is an int.
    """
    source = node.inputs[0]
    if source is None:
        return None
    known_shape = source.shape
    if known_shape is None:
        return None
    first = _get_int_attribute(node, "start", 0)
    last = _get_int_attribute(node, "end", None)
    selected = known_shape[first:last]
    result = _get_output(node, 0)
    if result is not None:
        state.set_sym_value(result, ir.Shape(selected))
    if not all(isinstance(dim, int) for dim in selected):
        return None
    return op.Constant(value_ints=ir.AttrInt64s("value_ints", list(selected)))
@register("Size")
def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace a Size node by a constant when every input dimension is a known int."""
    operand = _get_input(node, 0)
    if operand is None:
        return None
    known_shape = operand.shape
    if known_shape is None:
        return None
    dims = list(known_shape)
    if not all(isinstance(dim, int) for dim in dims):
        return None
    # The element count is the product of the dimensions (1 for a rank-0 shape).
    return op.Constant(value_int=math.prod(dims))
@register("If")
def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Inline the selected branch of an If node whose condition is a constant.

    Returns None when the condition is not a constant boolean or the chosen
    branch attribute is missing or is not a graph. Otherwise the branch's nodes
    are removed from the subgraph and returned, together with the branch's
    outputs, as the replacement of the If node.
    """
    condition = _get_bool_value(_get_input(node, 0))
    if condition is None:
        return None

    # cond is a constant-value: inline the branch
    branch_name = "then_branch" if condition else "else_branch"
    branch_attr = node.attributes.get(branch_name)
    if branch_attr is None or branch_attr.type != ir.AttributeType.GRAPH:
        return None
    assert isinstance(branch_attr, ir.Attr)
    branch_graph = branch_attr.as_graph()

    # Copy the graph outputs and clear the graph outputs so that the values are free to move
    branch_outputs = list(branch_graph.outputs)
    branch_graph.outputs.clear()

    # Map each branch output name to the name of the matching If output.
    # TODO: Extend renaming to intermediate values.
    output_names = {}
    for formal, actual in zip(branch_outputs, node.outputs):
        if actual is not None:
            output_names[formal.name] = actual.name

    inlined_nodes = list(branch_graph)
    branch_graph.remove(inlined_nodes)
    for inlined in inlined_nodes:
        # TODO: handle renaming inside subgraphs in nodes
        for produced in inlined.outputs:
            produced.name = output_names.get(produced.name, produced.name)
        # Avoid name collision.
        inlined.name = f"{node.name}_{inlined.name}"

    # TODO: we should handle initializers as well!
    return Replacement(branch_outputs, inlined_nodes)
@register("Identity")
def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Record that the output of an Identity node equals its input; never replaces the node."""
    del op
    source, target = node.inputs[0], node.outputs[0]
    if source is None or target is None:
        return None
    state.set_sym_value(target, source)
    return None
@register("SequenceConstruct")
def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Record the list of inputs as the symbolic value of the output sequence.

    The node itself is never replaced.
    """
    del op
    sequence = node.outputs[0]
    if sequence is None:
        return None
    state.set_sym_value(sequence, list(node.inputs))
    return None
@register("Concat")
def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Simplify a Concat node.

    - Concat(x) becomes Identity(x).
    - Operands whose size along ``axis`` is statically 0 are dropped; if all are
      dropped, the node becomes Identity of the first operand.
    - Otherwise, for ``axis == 0`` and operands that all carry a shape value,
      the concatenated shape is recorded as the symbolic value of the output.
    """
    operands = node.inputs
    # Replace Concat(x) by Identity(x)
    if len(operands) == 1:
        return op.Identity(operands[0])

    axis = _get_int_attribute(node, "axis", None)
    if axis is None:
        return None

    # Eliminate zero-length operands from Concat
    def is_empty_along_axis(operand: ir.Value | None) -> bool:
        # A missing operand (invalid model) or an unknown shape is not treated as empty.
        if operand is None or operand.shape is None:
            return False
        try:
            # We have already checked that axis is an int value (!= None)
            dim_size = operand.shape[axis]  # type: ignore[index]
        except IndexError:
            return False
        # False if symbolic or None or non-zero int value
        return dim_size == 0

    non_empty = [x for x in operands if not is_empty_along_axis(x)]
    if len(non_empty) != len(operands):
        if non_empty:
            # Remove zero-length operands from Concat
            logger.debug(
                "Concat: removing zero-length operand(s) %s => %s", operands, non_empty
            )
            return op.Concat(*non_empty, axis=axis)
        if operands:
            # All operands are zero-length. Concat is a no-op, but we need to use one of the
            # inputs to get the other dimensions correct:
            logger.debug("Concat: removing all zero-length operands %s", operands)
            return op.Identity(operands[0])
        # No inputs: invalid model.
        return None

    # Track value of tensors that carry a shape value (only for axis 0):
    if axis != 0:
        return None
    operand_shapes = [state.get_shape_value(operand) for operand in operands]
    if any(operand_shape is None for operand_shape in operand_shapes):
        return None
    merged = ir.Shape(dim for operand_shape in operand_shapes for dim in operand_shape.dims)  # type: ignore[union-attr]
    result = node.outputs[0]
    if result is None:
        return None
    state.set_sym_value(result, merged)
    return None
@register("Dropout", version=(12, None))
def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace a Dropout by Identity when applicable.

    This applies when no training_mode input is given, when training_mode is
    the constant False, or when the ratio is a constant scalar equal to 0.
    If the node has a mask output, an all-True mask of the input's shape is
    produced as well.
    """

    def as_identity():
        data = node.inputs[0]
        passthrough = op.Identity(data)
        if len(node.outputs) == 1:
            return passthrough
        all_true = ir.tensor([True])
        mask = op.ConstantOfShape(op.Shape(data), value=all_true)
        return passthrough, mask

    inputs = node.inputs
    # No training_mode specified:
    training_mode_missing = len(inputs) <= 2 or inputs[2] is None
    # ... or training_mode is False: dropout is not applied.
    if training_mode_missing or _get_bool_value(inputs[2]) is False:
        return as_identity()

    ratio = _get_numpy_value(inputs[1])
    # Only scalar dropout ratio is supported.
    if ratio is None or ratio.size != 1:
        return None
    # dropout ratio is 0: dropout is not applied.
    return as_identity() if ratio.item() == 0 else None
@register("Expand")
def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace an Expand node by Identity when the target shape equals the input shape."""
    if len(node.inputs) != 2:
        return None
    data, target = node.inputs
    if data is None:
        return None
    data_shape = data.shape
    if data_shape is None:
        # Input shape is not known.
        return None

    target_array = _get_numpy_value(target)
    if target_array is None:
        # Target shape is not a known constant: fall back to its symbolic shape value.
        symbolic_target = state.get_shape_value(target)
        if symbolic_target is not None and _same_shape(data_shape, symbolic_target):
            return op.Identity(data)
        return None

    if target_array.ndim != 1:
        # Target shape must be a 1D tensor. Erroneous model.
        return None
    if data_shape.dims == tuple(target_array.tolist()):
        return op.Identity(data)
    return None
@register("ConcatFromSequence")
def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace ConcatFromSequence by Concat when the sequence elements are known.

    The input sequence must have a list of values as its symbolic value.
    With ``new_axis == 0`` the elements are concatenated directly; with
    ``new_axis == 1`` each element is first unsqueezed at ``axis``.
    Returns None when the pattern does not apply.
    """
    sequence = node.inputs[0]
    elements = state.get_sym_value(sequence)
    # A symbolic value may also be a single ir.Value or an ir.Shape; only a list
    # of values describes a sequence, so check the type before iterating it.
    if not isinstance(elements, list) or any(x is None for x in elements):
        return None
    new_axis = _get_int_attribute(node, "new_axis", 0)
    axis = _get_int_attribute(node, "axis", None)
    if axis is None:
        return None
    if new_axis == 0:
        logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in elements])
        return op.Concat(*elements, axis=axis)
    if new_axis == 1:
        # Unsqueeze the inputs with concat axis if new_axis is 1
        axis_value = op.Constant(value_int=axis)
        unsqueezed_inputs = [
            op.Unsqueeze(element, axis_value, _outputs=[f"{element.name}_unsqueeze"])
            for element in elements
        ]
        # Send unsqueezed outputs to Concat
        logger.debug(
            "ConcatFromSequence => Concat %s", [x.name for x in unsqueezed_inputs]
        )
        return op.Concat(*unsqueezed_inputs, axis=axis)
    return None
@register("SplitToSequence")
def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Rewriting pattern.

    From

        splits = onnx::SplitToSequence(input, split, axis=axis)

    to

        split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis)
        splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n)

    or

        split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1)
        splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n)

    where number of output tensors in `splits` is statically known.
    onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator.
    This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly.

    Returns None (no rewrite) when the `split` input is omitted or not a
    constant, when the split dimension is not statically known, or when a
    scalar `split` cannot be expressed exactly through `num_outputs`.
    """
    # `split` is an optional input, so it may be absent from node.inputs altogether.
    input = _get_input(node, 0)
    split = _get_input(node, 1)
    output = _get_output(node, 0)
    if input is None or split is None or output is None:
        return None

    axis = _get_int_attribute(node, "axis", 0)
    keepdims = _get_int_attribute(node, "keepdims", 1)
    if axis is None or keepdims is None:
        return None

    shape = input.shape
    if shape is None:
        return None
    rank = len(shape)
    if axis < 0:
        axis = axis + rank
    if axis < 0 or axis >= rank:
        return None
    split_dimension_size = shape[axis]
    if not isinstance(split_dimension_size, int):
        return None

    split_value = _get_numpy_value(split)
    if split_value is None:
        return None

    if split_value.ndim == 0:
        # split into chunks all of size 'split' if possible.
        chunk_size = split_value.item()
        if not isinstance(chunk_size, int) or chunk_size <= 0:
            # Invalid model; also guards against division by zero.
            return None
        num_outputs = -(-split_dimension_size // chunk_size)  # ceiling division
        if num_outputs < 1:
            return None
        # Split(num_outputs=n) makes chunks of size ceil(dim / n), whereas
        # SplitToSequence makes chunks of size 'split' (last one possibly smaller).
        # These only agree when there is one chunk or the chunk sizes coincide;
        # e.g. dim=10, split=6 must give (6, 4) but num_outputs=2 gives (5, 5).
        even_chunk = -(-split_dimension_size // num_outputs)
        if num_outputs != 1 and even_chunk != chunk_size:
            return None
        split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
        split_values = op.Split(
            input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs
        )
    elif split_value.ndim == 1:
        # split into 'size(split)' chunks
        num_outputs = split_value.size
        if num_outputs < 1:
            return None
        split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
        split_values = op.Split(input, split, axis=axis, _outputs=split_outputs)
    else:
        return None

    # If Split returns a single value, we need to wrap it into a list.
    if isinstance(split_values, ir.Value):
        split_values = [split_values]

    if keepdims == 0:
        # squeeze the split dimension if keepdims is 0
        axis_val = op.Constant(value_ints=[axis], _outputs=[f"{output.name}_axis"])
        split_values = [
            op.Squeeze(chunk, axis_val, _outputs=[f"{chunk_name}_squeeze"])
            for chunk, chunk_name in zip(split_values, split_outputs)
        ]

    logger.debug("SplitToSequence => Split + SequenceConstruct")
    return op.SequenceConstruct(*split_values)
@register("SequenceAt")
def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace SequenceAt by Identity of the selected element when it is known.

    Requires the sequence to have a list of values as its symbolic value and
    the position to be a single-element constant that is a valid index.
    """
    sequence = node.inputs[0]
    position = node.inputs[1]
    result = node.outputs[0]
    if sequence is None or position is None:
        return None

    elements = state.get_sym_value(sequence)
    position_array = _get_numpy_value(position)
    if not isinstance(elements, list) or position_array is None:
        return None
    if position_array.size != 1:
        return None

    index = position_array.item()
    try:
        selected = elements[index]  # type: ignore[index]
    except IndexError:
        return None
    state.set_sym_value(result, selected)
    logger.debug("SequenceAt %s => %s", sequence.name, selected.name)
    return op.Identity(selected)
def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
    """Combine two shapes dimension by dimension.

    If either shape is None the other one is returned as is. Otherwise, for each
    pair of dimensions: equal dimensions are kept, a non-symbolic dimension is
    preferred over a symbolic one, and a symbolic dimension with a value is
    preferred over one without.

    Raises:
        ValueError: If both shapes are given and their ranks differ.
    """
    if shape1 is None:
        return shape2
    if shape2 is None:
        return shape1
    if len(shape1) != len(shape2):
        raise ValueError("Shapes must have the same rank.")

    def pick(first, second):
        if first == second:
            return first
        # Prefer int value over symbolic dim
        if not isinstance(first, ir.SymbolicDim):
            return first
        if not isinstance(second, ir.SymbolicDim):
            return second
        return second if first.value is None else first

    return ir.Shape([pick(first, second) for first, second in zip(shape1, shape2)])
class FoldConstantsPass(ir.passes.InPlacePass):
"""A pass that folds constant expressions in the model.
Attributes:
shape_inference: Whether to perform shape inference.
input_size_limit: Maximum size of input tensors to fold.
output_size_limit: Maximum size of output tensors to fold.
always_fold_ops: Collection of op types that should always be folded.
For ops from the default opset, only op_type is needed (e.g. "Transpose"),
otherwise specify the domain with ``{domain}::{op_type}``.
"""
def __init__(
    self,
    *,
    shape_inference: bool,
    input_size_limit: int,
    output_size_limit: int,
    always_fold_ops: Collection[str] = frozenset(["Transpose"]),
) -> None:
    """Store the pass configuration and initialize the per-run state.

    Entries of ``always_fold_ops`` are either a bare op type (domain ``""``) or
    ``{domain}::{op_type}``; the domain ``ai.onnx`` is normalized to ``""``.
    """
    self.shape_inference = shape_inference
    self.input_size_limit = input_size_limit
    self.output_size_limit = output_size_limit

    normalized = set()
    for qualified_name in always_fold_ops:
        domain, separator, op_type = qualified_name.partition("::")
        if not separator:
            # No explicit domain: the whole name is the op type.
            domain, op_type = "", qualified_name
        elif domain == "ai.onnx":
            domain = ""
        normalized.add((domain, op_type))
    self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(normalized)

    self._opset_imports: dict[str, int] = {}
    self._counts: dict[str, int] = {}
    self._sizes: dict[str, int] = {}
    self._modified: bool = False
    self._state = OptimizerState()
    self._reset()
def _reset(self) -> None:
    """Discard counters, the modified flag and the optimizer state before a new run."""
    self._counts, self._sizes = {}, {}
    self._modified = False
    self._state = OptimizerState()
def _do_inference(self, node: ir.Node) -> None:
    """Run ONNX node-level shape/type inference and update the outputs of ``node``.

    Inference is skipped (with a debug message) when any present input has no
    type. Any exception raised while looking up the schema or inferring is
    logged at debug level and otherwise ignored.
    """

    # TODO: handle optional inputs
    def serialize_small_constant(x: ir.Value) -> onnx.TensorProto | None:
        # Only constants of at most 20 elements are handed to shape inference.
        if _get_numpy_value(x, size_limit=20) is None:
            return None
        assert x.const_value is not None
        return ir.serde.serialize_tensor(x.const_value)

    def serialize_type(value: ir.Value) -> onnx.TypeProto | None:
        if value.type is None:
            return None
        type_proto = ir.serde.serialize_type(value.type)
        if value.shape is not None:
            ir.serde.serialize_shape_into(type_proto, value.shape)
        return type_proto

    input_types = {x.name: serialize_type(x) for x in node.inputs if x is not None}
    all_input_data = {
        x.name: serialize_small_constant(x) for x in node.inputs if x is not None
    }
    input_data = {
        name: tensor for name, tensor in all_input_data.items() if tensor is not None
    }

    if any(input_type is None for input_type in input_types.values()):
        logger.debug(
            "Skipping shape inference for node %s due to missing input type.",
            node.name,
        )
        return

    # TODO: pass in constant values, ir_version
    try:
        schema = onnx.defs.get_schema(
            node.op_type, self._opset_imports[node.domain], node.domain
        )
        output_types = onnx.shape_inference.infer_node_outputs(
            schema,
            ir.serde.serialize_node(node),
            input_types,  # type: ignore[arg-type]
            input_data,  # type: ignore[arg-type]
        )
        for output in node.outputs:
            if output.name not in output_types:
                continue
            inferred_type = output_types[output.name]
            # TODO: merge types, check for conflicts
            inferred_shape = ir.serde.deserialize_type_proto_for_shape(inferred_type)
            output.shape = _merge_shapes(output.shape, inferred_shape)
            output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
    except Exception as e:
        logger.debug(
            "Skipping shape inference for node %s due to exception: %s",
            node.name,
            e,
        )
def new_constant(self, node: ir.Node, value) -> ir.Node | None:
    """Build a Constant node holding ``value`` for the first output of ``node``.

    Returns None when ``value`` is not a numpy array, or when it exceeds
    ``output_size_limit`` and storing it would add more elements than are freed
    by constant inputs that have a single use. For a numpy ``value`` the
    output's ``const_value`` is set even when None is returned.
    """
    irvalue = node.outputs[0]
    if not isinstance(value, np.ndarray):
        # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
        # So, a constant-value of type sequence is not folded, but it can be used
        # to optimize subsequent operations when possible.
        logger.info(
            "Skip storing constant folded value %s due to unsupported type %s.",
            irvalue.name,
            type(value),
        )
        return None

    tensor = ir.tensor(value)
    tensor.name = irvalue.name
    irvalue.const_value = tensor

    if value.size > self.output_size_limit:
        # Handle examples like Transpose(weight) to be folded even if the size is large,
        # as long as weight has no other uses. This won't increase model size.
        single_use_arrays = (
            _get_numpy_value(operand)
            for operand in node.inputs
            if operand is not None and len(operand.uses()) == 1
        )
        removed_input_size = sum(
            array.size for array in single_use_arrays if array is not None
        )
        if value.size - removed_input_size > 0:
            logger.info(
                "Skip storing constant folded nvalue %s due to large size %s.",
                irvalue.name,
                value.size,
            )
            return None

    logger.debug(
        "New constant for value %s dtype: %s shape: %s",
        irvalue.name,
        value.dtype,
        value.shape,
    )
    constant_attributes = ir.convenience.convert_attributes({"value": tensor})
    return ir.Node(
        "", "Constant", inputs=[], attributes=constant_attributes, num_outputs=1
    )
def process_node(self, node: ir.Node) -> Replacement | None:
    """Process a node and return a Replacement if the node can be replaced.

    The steps are, in order:
      1. Inputs whose symbolic value is another ir.Value are rewired to it.
      2. Incremental shape inference is run (if enabled and the node has no
         graph attributes).
      3. Registered partial evaluators for the op are tried; the first
         non-None result is returned.
      4. Otherwise the node is constant-folded with the reference evaluator,
         provided it is not a control-flow, non-deterministic or Constant op,
         none of its inputs is a graph input, all inputs are constants, and
         the input size limit is respected (or the op is in ``always_fold_ops``
         and is the only consumer of its inputs).

    Returns None when the node should be kept as is.
    """
    for i, value in enumerate(node.inputs):
        sym_value = self._state.get_sym_value(value)
        if isinstance(sym_value, ir.Value):
            logger.debug(
                "Node [%s]: Replacing input %s with %s",
                node.name,
                value.name,  # type: ignore[union-attr]
                sym_value.name,
            )
            node.replace_input_with(i, sym_value)
            self._modified = True
            # TODO(rama): consider merging type/other info from both values

    # Do incremental shape inference
    if self.shape_inference and not _is_control_flow_op(node):
        self._do_inference(node)

    if node.domain not in self._opset_imports:
        return None
    version = self._opset_imports[node.domain]
    op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
    for optimizer in op_optimizers:
        assert optimizer
        context = RewriterContext()
        output = optimizer(node, context, self._state)
        if output is not None:
            if isinstance(output, Replacement):
                return output
            if isinstance(output, ir.Value):
                output = [output]
            return Replacement(output, context.nodes)

    if _is_control_flow_op(node) or _is_non_deterministic_op(node):
        return None

    if _is_onnx_op(node, "Constant"):
        _process_constant_node(node)
        return None

    if any(x.is_graph_input() for x in node.inputs if x is not None):
        # Do not fold any graph inputs to preserve graph signature
        return None

    # Ensure all node inputs are constants
    if any(x.const_value is None for x in node.inputs if x is not None):
        if logger.isEnabledFor(logging.DEBUG):
            # The format string needs one placeholder per argument; otherwise
            # the logging module reports a formatting error instead of the message.
            logger.debug(
                "Skipping constant folding for node %s because it has non-constant inputs: %s",
                node,
                [x.name for x in node.inputs if x is not None],
            )
        return None

    input_tensors = [x.const_value if x is not None else None for x in node.inputs]
    if any(
        tensor.size > self.input_size_limit
        for tensor in input_tensors
        if tensor is not None
    ):
        if (node.domain, node.op_type) in self.always_fold_ops and all(
            len(input.consumers()) == 1 for input in node.inputs if input is not None
        ):
            # If the op is in always_fold_ops and all inputs are used only by this node,
            # we can still fold it even if the input size exceeds the limit.
            logger.debug(
                "Folding large constant for node %s because it is in the always_fold_ops list",
                node,
            )
        else:
            # Skip folding large tensors
            if logger.isEnabledFor(logging.DEBUG):
                input_sizes = [
                    tensor.size for tensor in input_tensors if tensor is not None
                ]
                logger.debug(
                    "Skipping constant folding for node %s due to large input size: %s",
                    node,
                    input_sizes,
                )
            return None

    input_values = [_get_numpy_value(x) for x in node.inputs]

    def convert(av):
        # Tensor attributes are passed to the evaluator in serialized form.
        if av.type == ir.AttributeType.TENSOR:
            return ir.serde.serialize_tensor(av.value)
        return av.value

    attr_values = {name: convert(attr) for name, attr in node.attributes.items()}
    outputs = _reference_evaluator.evaluate(
        node.domain, node.op_type, version, *input_values, **attr_values
    )

    if outputs is None:
        return None
    if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
        # new_constant also records const_value on the output, so it is called
        # even for ConstantOfShape, which is then kept as is.
        replacement = self.new_constant(node, outputs)
        if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
            return None
        return Replacement(replacement.outputs, [replacement])
    else:
        logger.warning(
            "Skipping constant folding for op %s with multiple outputs.", node.op_type
        )
    return None
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:
    """Substitute ``node`` in ``root`` with the contents of ``replacement``.

    Args:
        node: The node to be replaced. Its outputs are the values handed to
            the replacement helper as the old values.
        replacement: An object exposing ``new_nodes`` and ``new_outputs``,
            which are handed to the helper as the new nodes and new values.
        root: The graph or function that is forwarded to
            ``ir.convenience.replace_nodes_and_values``.
    """
    logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)

    # ``node`` is passed twice: once on its own as the second positional
    # argument and once as the single-element list of nodes being replaced.
    old_nodes = [node]
    old_values = node.outputs
    ir.convenience.replace_nodes_and_values(
        root,
        node,
        old_nodes,
        replacement.new_nodes,
        old_values,
        replacement.new_outputs,
    )

    # Record that this pass changed something.
    self._modified = True
    # TODO: what about new opset_imports?
    # TODO: track statistics about replaced nodes and sizes of new constants
def visit_attribute(self, attr: ir.Attr) -> None:
    """Visit every graph carried by ``attr``.

    Reference attributes are ignored, as are attributes whose type is neither
    ``GRAPH`` nor ``GRAPHS``.

    Args:
        attr: The attribute to inspect.
    """
    if attr.is_ref():
        return

    kind = attr.type
    if kind == ir.AttributeType.GRAPHS:
        for subgraph in attr.as_graphs():
            self.visit_graph(subgraph)
    elif kind == ir.AttributeType.GRAPH:
        self.visit_graph(attr.as_graph())
def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
    """Process ``node`` and either replace it or visit its attributes.

    Args:
        node: The node to process.
        root: The graph or function passed on to ``replace_node`` when
            ``process_node`` yields a replacement.
    """
    replacement = self.process_node(node)
    if replacement is not None:
        self.replace_node(node, replacement, root)
        return

    # The node stays as it is, so visit its attributes instead.
    for attr in node.attributes.values():
        self.visit_attribute(attr)
def visit_graph(self, graph: ir.Graph) -> None:
    """Visit all nodes of ``graph``, then try to replace its outputs.

    After the node walk, each graph output whose recorded symbolic value is an
    ``ir.Value`` that passes ``_sym_value_can_replace_graph_output`` is
    replaced by that value, which is renamed to the output's name.

    Args:
        graph: The graph to visit; its ``outputs`` may be updated in place.
    """
    for node in graph:
        self.visit_node(node, graph)

    # Replace outputs whose symbolic value is another value. These are
    # typically outputs of Identity nodes.
    for index, graph_output in enumerate(graph.outputs):
        if graph_output is None:
            continue
        candidate = self._state.get_sym_value(graph_output)
        # Only an ir.Value can stand in for an output, and only when the
        # helper accepts it for this graph.
        if isinstance(candidate, ir.Value) and _sym_value_can_replace_graph_output(
            graph, candidate, graph_output
        ):
            # The candidate takes over the output's name.
            candidate.name = graph_output.name
            graph.outputs[index] = candidate
            self._modified = True
def visit_function(self, function: ir.Function) -> None:
    """Visit each node of ``function``, passing the function itself as the root.

    Args:
        function: The function whose nodes are visited.
    """
    for fn_node in function:
        self.visit_node(fn_node, function)
def call(self, model: ir.Model) -> FoldConstantsResult:
    """Run the pass over the main graph and every function of ``model``.

    Args:
        model: The model to process.

    Returns:
        A ``FoldConstantsResult`` built from the model, the modification
        flag, and the symbolic value map held by the pass state.
    """
    self._reset()
    self._opset_imports = model.opset_imports

    self.visit_graph(model.graph)
    # TODO(rama): Should we specialize functions?
    for model_function in model.functions.values():
        self.visit_function(model_function)

    result = FoldConstantsResult(model, self._modified, self._state.symbolic_value_map)
    return result
def _sym_value_can_replace_graph_output(
    graph: ir.Graph, sym_value: ir.Value, output: ir.Value
) -> bool:
    """Return whether ``sym_value`` may take the place of an output of ``graph``.

    Args:
        graph: The graph whose output is being considered for replacement.
        sym_value: The candidate replacement value.
        output: The output that would be replaced (not inspected here).

    Returns:
        True only if ``sym_value`` has a producer node, that node belongs to
        ``graph``, and ``sym_value`` is not already a graph output.
    """
    source_node = sym_value.producer()
    if source_node is None:
        # A value without a producer is some graph's input, and ONNX does not
        # allow a graph input to also be a graph output.
        return False
    if source_node.graph is not graph:
        # Only a value produced inside this graph can be one of its outputs.
        return False
    # A value that is already a graph output must not be renamed, otherwise
    # the output it currently represents would lose its name.
    return not sym_value.is_graph_output()
@dataclasses.dataclass
class FoldConstantsResult(ir.passes.PassResult):
    """Pass result for constant folding, extended with a symbolic value map."""

    # Mapping from IR values to their symbolic values.
    symbolic_value_map: dict[ir.Value, SymbolicValue]

    def __bool__(self) -> bool:
        """Return ``self.modified``.

        Kept for backward compatibility: ``fold_constants`` used to return a
        plain boolean telling whether the model was modified.
        """
        return self.modified
def fold_constants(
    model: ir.Model,
    *,
    onnx_shape_inference: bool = False,
    input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    always_fold_ops: Collection[str] = frozenset(["Transpose"]),
) -> FoldConstantsResult:
    """Apply the constant folding optimization to a model.

    Args:
        model: The ONNX model to optimize.
        onnx_shape_inference: Whether to enable ONNX shape inference during
            constant folding. Defaults to False.
        input_size_limit: The maximum size of input tensors that can be
            considered for constant folding. Defaults to
            `DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT`.
        output_size_limit: The maximum size of output tensors that can be
            stored after constant folding. Defaults to
            `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
        always_fold_ops: A collection of op types that should always be
            folded, regardless of their input or output sizes. For ops from
            the default opset, only the op_type is needed (e.g. "Transpose");
            otherwise specify the domain with ``{domain}::{op_type}``.

    Returns:
        An instance of `FoldConstantsResult`.
    """
    constant_folder = FoldConstantsPass(
        always_fold_ops=always_fold_ops,
        output_size_limit=output_size_limit,
        input_size_limit=input_size_limit,
        shape_inference=onnx_shape_inference,
    )
    result = constant_folder(model)
    return result  # type: ignore[return-value]