| | |
| | |
| |
|
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import dataclasses |
| | import logging |
| | import math |
| | import typing |
| | from typing import Any, Callable, Collection, Iterable, Sequence, Union |
| |
|
| | import numpy as np |
| | import onnx |
| | import onnx.reference.ops |
| | import onnx_ir as ir |
| |
|
| | import onnxscript.utils.utils as utils |
| | from onnxscript.ir import _tape |
| |
|
# Default cap (in elements) on each input tensor considered for constant folding.
DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 512

# Default cap (in elements) on a folded output tensor that may be stored.
DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512
| |
|
| |
|
# Standard-domain ops whose outputs vary between runs; never constant-folded.
_NON_DETERMINISTIC_OPS = frozenset(
    {
        "RandomUniform",
        "RandomNormal",
        "RandomUniformLike",
        "RandomNormalLike",
        "Multinomial",
    }
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
| |
|
| |
|
def _is_control_flow_op(node: ir.Node) -> bool:
    """Return True if *node* carries any subgraph attribute (GRAPH/GRAPHS)."""
    for attribute in node.attributes.values():
        if attribute.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
            return True
    return False
| |
|
| |
|
def _is_non_deterministic_op(node: ir.Node) -> bool:
    """Return True if *node* is a standard-domain random/sampling op."""
    if node.op_type not in _NON_DETERMINISTIC_OPS:
        return False
    return utils.is_onnx_domain(node.domain)
| |
|
| |
|
def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
    """Return True if *node* is the standard-domain op named *op_type*."""
    if node.op_type != op_type:
        return False
    return utils.is_onnx_domain(node.domain)
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
def _process_constant_node(node: ir.Node) -> None:
    """Sets const_value of output value of a Constant op node.

    Handles all single-attribute forms of Constant (value, value_float(s),
    value_int(s), value_string(s)); the output's shape and dtype are filled in
    from the materialized tensor.
    """
    if node.op_type != "Constant" or node.domain != "":
        return
    # A well-formed Constant node carries exactly one value-bearing attribute.
    if len(node.attributes) != 1:
        return
    attr_name, attr_value = next(iter(node.attributes.items()))
    if len(node.outputs) != 1:
        return
    ir_value = node.outputs[0]

    # Reference attributes (non-ir.Attr) cannot be materialized here.
    if attr_value is None or not isinstance(attr_value, ir.Attr):
        return

    const_value: ir.TensorProtocol
    if attr_name in {"value_float", "value_floats"}:
        const_value = ir.Tensor(
            np.array(attr_value.value, dtype=np.float32), name=ir_value.name
        )
    elif attr_name in {"value_int", "value_ints"}:
        const_value = ir.Tensor(np.array(attr_value.value, dtype=np.int64), name=ir_value.name)
    elif attr_name in {"value_string", "value_strings"}:
        const_value = ir.StringTensor(
            np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name
        )
    elif attr_name == "value":
        # The "value" attribute already holds a tensor.
        const_value = typing.cast(ir.TensorProtocol, attr_value.value)
    else:
        return

    ir_value.const_value = const_value
    ir_value.shape = const_value.shape
    ir_value.dtype = const_value.dtype
| |
|
| |
|
def basic_constant_propagation(nodes: Iterable[ir.Node]) -> None:
    """Perform basic constant propagation over *nodes*.

    Only marks outputs of Constant op nodes with their const_value.
    """
    for candidate in nodes:
        _process_constant_node(candidate)
| |
|
| |
|
class ReferenceEvaluator:
    """Thin, failure-tolerant wrapper over the onnx reference op implementations."""

    def get_evaluator(self, domain: str, op: str, version: int) -> Callable | None:
        """Return the reference `.eval` callable for the op, or None if unavailable."""
        try:
            implementation = onnx.reference.ops.load_op(domain, op, version)
        except Exception:
            return None
        return implementation.eval

    def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any:
        """Evaluate the op on concrete inputs; return None on any failure."""
        logger.debug("Evaluating %s::%s", domain, op)
        evaluator = self.get_evaluator(domain, op, version)
        if evaluator is None:
            return None
        try:
            return evaluator(*args, **kwargs)
        except Exception as e:
            logger.warning("Evaluation failed: %s", e)
            return None
| |
|
| |
|
# Shared evaluator instance used by the constant-folding pass.
_reference_evaluator = ReferenceEvaluator()
| |
|
| |
|
@dataclasses.dataclass
class Replacement:
    """A replacement for a node in the graph."""

    # Values that take over the roles of the replaced node's outputs.
    new_outputs: Sequence[ir.Value]
    # Nodes inserted in place of the replaced node.
    new_nodes: Sequence[ir.Node]
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
# A value's symbolic substitute: another IR value, the element list of a
# sequence value, or a (possibly symbolic) shape.
SymbolicValue = Union[ir.Value, list[ir.Value], ir.Shape]
| |
|
| |
|
class OptimizerState:
    """Per-run bookkeeping for the optimizer: the symbolic-value map."""

    def __init__(self):
        self._sym_value_map: dict[ir.Value, SymbolicValue] = {}
        self._initializer_inputs: list[set[ir.Value]] = []

    @property
    def symbolic_value_map(self) -> dict[ir.Value, SymbolicValue]:
        return self._sym_value_map

    def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None:
        """Return the symbolic value recorded for *value*, if any."""
        if value is None:
            return None
        return self._sym_value_map.get(value)

    def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
        """Record *sym_value* as the symbolic substitute for *value*."""
        self._sym_value_map[value] = sym_value

    def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
        """Interpret *value* as a shape: from a small 1-D INT64 constant,
        or from a previously recorded symbolic shape."""
        concrete = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
        if concrete is not None:
            # Only a 1-D tensor is a valid shape value.
            return ir.Shape(concrete.tolist()) if concrete.ndim == 1 else None
        recorded = self.get_sym_value(value)
        return recorded if isinstance(recorded, ir.Shape) else None
| |
|
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
# Builder used by partial evaluators to emit replacement nodes.
RewriterContext = _tape.Builder
# What a partial evaluator may return: a complete Replacement, the replacement
# output value(s), or None (no rewrite applies).
ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None]
PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue]
| |
|
| |
|
@dataclasses.dataclass
class PartialEvaluator:
    """A class that represents a partial-evaluator for a particular op.

    It is applicable for a specific version range (min_version, max_version) of
    the op; a None bound means unconstrained in that direction.
    """

    min_version: int | None
    max_version: int | None
    function: PartialEvaluatorFunction

    def valid_for(self, version: int) -> bool:
        """True when *version* lies within [min_version, max_version]."""
        if self.min_version is not None and version < self.min_version:
            return False
        if self.max_version is not None and version > self.max_version:
            return False
        return True
| |
|
| |
|
class PartialEvaluatorRegistry:
    """A class that maintains a registry of evaluators for ops, keyed by (domain, op)."""

    def __init__(self):
        self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {}

    def lookup_evaluators(self, domain: str, opname: str, version: int):
        """Return the registered evaluator functions applicable at *version*."""
        evaluator_list = self.op_evaluators.get((domain, opname), [])
        return [
            evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version)
        ]

    def register(
        self, opname: str, domain: str = "", version=None
    ) -> Callable[[PartialEvaluatorFunction], PartialEvaluatorFunction]:
        """Decorator factory registering a partial evaluator for an op.

        Args:
            opname: Op type to register the evaluator for.
            domain: Op domain ("" is the default ONNX domain).
            version: None (any version), an int (exactly that version), or a
                (min_version, max_version) tuple where either bound may be None.

        Raises:
            ValueError: If *version* is not None, an int, or a tuple.
        """
        evaluator_list = self.op_evaluators.setdefault((domain, opname), [])
        if version is None:
            min_version = None
            max_version = None
        elif isinstance(version, int):
            min_version = version
            max_version = version
        elif isinstance(version, tuple):
            min_version, max_version = version
        else:
            # Previously an unsupported type fell through, leaving min/max
            # unbound and producing an obscure UnboundLocalError when the
            # decorator was applied. Fail fast with a clear message instead.
            raise ValueError(f"Unsupported version specification: {version!r}")

        def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction:
            evaluator_list.append(PartialEvaluator(min_version, max_version, function))
            return function

        return decorator
| |
|
| |
|
# Global registry of partial evaluators, keyed by (domain, op_type).
registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry()

# Decorator used below to register the per-op partial evaluators.
register = registry.register
| |
|
| |
|
def _same_shape(shape1: ir.Shape, shape2: ir.Shape) -> bool:
    """Return True when the two shapes are provably identical.

    A shape containing an anonymous (value-less) symbolic dimension is never
    considered equal to anything, since two anonymous dims need not match.
    """
    for dim in shape1:
        if isinstance(dim, ir.SymbolicDim) and dim.value is None:
            return False
    return shape1.dims == shape2.dims
| |
|
| |
|
def _get_numpy_value(
    val: ir.Value | None, dtype: ir.DataType | None = None, size_limit: int | None = None
) -> np.ndarray | None:
    """Returns the numpy value of a constant value, if available.

    It returns None if the value is not a constant value, or if the value is not of
    the specified element dtype, or if the size of the value exceeds the specified
    size_limit.
    """
    if val is None:
        return None
    const_value = val.const_value
    if const_value is not None:
        if dtype is not None and const_value.dtype != dtype:
            return None
        if size_limit is not None and const_value.size > size_limit:
            return None
        try:
            # View through the IR dtype's numpy type so that element types are
            # represented consistently in the returned array.
            array = const_value.numpy().view(const_value.dtype.numpy())
        except FileNotFoundError:
            # Externally-stored tensor data whose backing file is missing.
            logger.warning(
                "External data for value '%s' is not available. "
                "This may lead to incorrect constant folding.",
                val.name,
            )
            return None
        assert isinstance(array, np.ndarray)
        return array
    return None
| |
|
| |
|
def _get_bool_value(val: ir.Value | None) -> bool | None:
    """Extract a scalar python bool from a constant value, if possible."""
    if val is None:
        return None
    array = _get_numpy_value(val)
    if array is None:
        return None
    if array.dtype != bool or array.size != 1:
        return None
    return array.item(0)
| |
|
| |
|
| | def _get_input(node: ir.Node, index: int) -> ir.Value | None: |
| | if index < len(node.inputs): |
| | return node.inputs[index] |
| | return None |
| |
|
| |
|
| | def _get_output(node: ir.Node, index: int) -> ir.Value | None: |
| | if index < len(node.outputs): |
| | return node.outputs[index] |
| | return None |
| |
|
| |
|
| | def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None: |
| | if type is not None: |
| | |
| | value.type = type |
| |
|
| |
|
def _get_input_element_type(node: ir.Node, index: int) -> int:
    """Element dtype (as its int value) of the input at *index*; UNDEFINED if unknown."""
    value = _get_input(node, index)
    if value is None or value.type is None:
        return ir.DataType.UNDEFINED.value
    return value.type.dtype.value
| |
|
| |
|
def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None:
    """Integer attribute *name* of *node*.

    Returns *default* when the attribute is absent, and None when it is present
    but not a concrete integer attribute (e.g. a reference attribute).
    """
    attr = node.attributes.get(name)
    if attr is None:
        return default
    if not isinstance(attr, ir.Attr):
        return None
    attr_val = attr.value
    return attr_val if isinstance(attr_val, int) else None
| |
|
| |
|
| | @register("Abs") |
| | def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | """Replace an Abs node by Identity when applicable. |
| | |
| | Currently, addresses Abs applied to symbolic shapes. |
| | """ |
| | input = _get_input(node, 0) |
| | input_sym_value = state.get_shape_value(input) |
| | if input_sym_value is None: |
| | return None |
| | if any(isinstance(d, int) and d < 0 for d in input_sym_value): |
| | return None |
| | |
| | |
| | |
| | return op.Identity(input) |
| |
|
| |
|
| | @register("Gather") |
| | def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | """Replace a Gather node by a constant when applicable. |
| | |
| | Currently, handles the case of Gathering from a shape tensor. |
| | """ |
| | input = _get_input(node, 0) |
| | indices = _get_input(node, 1) |
| | if input is None or indices is None: |
| | return None |
| | input_sym_value = state.get_shape_value(input) |
| | if input_sym_value is None: |
| | return None |
| | axis = _get_int_attribute(node, "axis", None) |
| | if axis != 0: |
| | return None |
| | indices_numpy_value = _get_numpy_value(indices) |
| | if indices_numpy_value is None: |
| | return None |
| | if indices_numpy_value.ndim != 1: |
| | return None |
| | gathered = [input_sym_value[i] for i in indices_numpy_value] |
| | output = _get_output(node, 0) |
| | if output is not None: |
| | state.set_sym_value(output, ir.Shape(gathered)) |
| | if all(isinstance(d, int) for d in gathered): |
| | return op.Constant(value_ints=ir.AttrInt64s("value_ints", gathered)) |
| | return None |
| |
|
| |
|
| | @register("Reshape") |
| | def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | """Replace a Reshape node by Identity when applicable.""" |
| | input = _get_input(node, 0) |
| | shape = _get_input(node, 1) |
| | if input is None or shape is None: |
| | return None |
| |
|
| | input_shape = input.shape |
| | shape_value = state.get_shape_value(shape) |
| |
|
| | if shape_value is None or input_shape is None: |
| | return None |
| |
|
| | |
| | if _same_shape(input_shape, shape_value): |
| | return op.Identity(input) |
| | return None |
| |
|
| |
|
| | @register("Cast") |
| | def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | input = _get_input(node, 0) |
| | output = _get_output(node, 0) |
| |
|
| | if input is None or output is None: |
| | return None |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | input_shape = input.shape |
| | if input_shape is not None: |
| | output.shape = input_shape.copy() |
| |
|
| | input_dtype = _get_input_element_type(node, 0) |
| | output_dtype = _get_int_attribute(node, "to", None) |
| | if output_dtype is not None: |
| | if input_dtype == output_dtype: |
| | return op.Identity(input) |
| | output.type = ir.TensorType(ir.DataType(output_dtype)) |
| | return None |
| |
|
| |
|
| | @register("CastLike") |
| | def cast_like(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | input0 = node.inputs[0] |
| | source_element_type = _get_input_element_type(node, 0) |
| | target_element_type = _get_input_element_type(node, 1) |
| |
|
| | if target_element_type == ir.DataType.UNDEFINED: |
| | return None |
| | if source_element_type == target_element_type: |
| | return op.Identity(input0) |
| | return op.Cast(input0, to=target_element_type) |
| |
|
| |
|
| | @register("Shape") |
| | def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | input = node.inputs[0] |
| | if input is None: |
| | return None |
| | shape = input.shape |
| | if shape is None: |
| | return None |
| | start = _get_int_attribute(node, "start", 0) |
| | end = _get_int_attribute(node, "end", None) |
| | shape_slice = shape[start:end] |
| | output = _get_output(node, 0) |
| | if output is not None: |
| | state.set_sym_value(output, ir.Shape(shape_slice)) |
| | if all(isinstance(d, int) for d in shape_slice): |
| | return op.Constant(value_ints=ir.AttrInt64s("value_ints", list(shape_slice))) |
| | return None |
| |
|
| |
|
| | @register("Size") |
| | def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | input = _get_input(node, 0) |
| | if input is None: |
| | return None |
| | shape = input.shape |
| | if shape is None: |
| | return None |
| | size = 1 |
| | for d in shape: |
| | if not isinstance(d, int): |
| | return None |
| | size *= d |
| | return op.Constant(value_int=size) |
| |
|
| |
|
| | @register("If") |
| | def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | cond_input = _get_input(node, 0) |
| | cond = _get_bool_value(cond_input) |
| | if cond is not None: |
| | |
| | branch = "then_branch" if cond else "else_branch" |
| | graph_attr = node.attributes.get(branch) |
| | if graph_attr is None: |
| | return None |
| | if graph_attr.type != ir.AttributeType.GRAPH: |
| | return None |
| | assert isinstance(graph_attr, ir.Attr) |
| | graph = graph_attr.as_graph() |
| | |
| | formal_outs = list(graph.outputs) |
| | graph.outputs.clear() |
| | actual_outs = node.outputs |
| | renamings = { |
| | formal.name: actual.name |
| | for formal, actual in zip(formal_outs, actual_outs) |
| | if actual is not None |
| | } |
| | |
| |
|
| | def rename(name): |
| | return renamings.get(name, name) |
| |
|
| | graph_nodes = list(graph) |
| | graph.remove(graph_nodes) |
| | for sub_node in graph_nodes: |
| | |
| | for v in sub_node.outputs: |
| | v.name = rename(v.name) |
| | |
| | sub_node.name = f"{node.name}_{sub_node.name}" |
| |
|
| | |
| | return Replacement(formal_outs, graph_nodes) |
| | return None |
| |
|
| |
|
| | @register("Identity") |
| | def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | del op |
| | input = node.inputs[0] |
| | output = node.outputs[0] |
| | if input is not None and output is not None: |
| | state.set_sym_value(output, input) |
| | return None |
| |
|
| |
|
| | @register("SequenceConstruct") |
| | def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | del op |
| | output = node.outputs[0] |
| | if output is not None: |
| | state.set_sym_value(output, list(node.inputs)) |
| | return None |
| |
|
| |
|
| | @register("Concat") |
| | def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | """Replace a Concat node with a single input by Identity""" |
| |
|
| | |
| | inputs = node.inputs |
| | if len(inputs) == 1: |
| | return op.Identity(inputs[0]) |
| |
|
| | axis = _get_int_attribute(node, "axis", None) |
| | if axis is None: |
| | return None |
| |
|
| | |
| | def has_zero_size(operand: ir.Value | None) -> bool: |
| | if operand is None: |
| | return False |
| | if (shape := operand.shape) is None: |
| | return False |
| | try: |
| | |
| | dim_size = shape[axis] |
| | except IndexError: |
| | return False |
| | return dim_size == 0 |
| |
|
| | new_inputs = [x for x in inputs if not has_zero_size(x)] |
| | if len(new_inputs) != len(inputs): |
| | if new_inputs: |
| | |
| | logger.debug( |
| | "Concat: removing zero-length operand(s) %s => %s", inputs, new_inputs |
| | ) |
| | return op.Concat(*new_inputs, axis=axis) |
| | elif inputs: |
| | |
| | |
| | logger.debug("Concat: removing all zero-length operands %s", inputs) |
| | return op.Identity(inputs[0]) |
| | else: |
| | |
| | return None |
| |
|
| | |
| |
|
| | |
| |
|
| | if axis != 0: |
| | return None |
| | shapes = [state.get_shape_value(input) for input in inputs] |
| | if any(shape is None for shape in shapes): |
| | return None |
| | concatenated = ir.Shape(dim for shape in shapes for dim in shape.dims) |
| | output = node.outputs[0] |
| | if output is None: |
| | return None |
| | state.set_sym_value(output, concatenated) |
| | return None |
| |
|
| |
|
| | @register("Dropout", version=(12, None)) |
| | def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | """Replace a Dropout by Identity when applicable.""" |
| |
|
| | def optimized_dropout(): |
| | input = node.inputs[0] |
| | output = op.Identity(input) |
| | if len(node.outputs) == 1: |
| | return output |
| | else: |
| | true_tensor = ir.tensor([True]) |
| | input_shape = op.Shape(input) |
| | mask = op.ConstantOfShape(input_shape, value=true_tensor) |
| | return output, mask |
| |
|
| | inputs = node.inputs |
| | if (len(inputs) <= 2) or inputs[2] is None: |
| | |
| | return optimized_dropout() |
| | if _get_bool_value(inputs[2]) is False: |
| | |
| | return optimized_dropout() |
| | ratio = _get_numpy_value(inputs[1]) |
| | if ratio is None: |
| | return None |
| | if ratio.size != 1: |
| | return None |
| | if ratio.item() == 0: |
| | |
| | return optimized_dropout() |
| | return None |
| |
|
| |
|
| | @register("Expand") |
| | def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | """Replace an Expand node by Identity when applicable.""" |
| | if len(node.inputs) != 2: |
| | return None |
| | if (input := node.inputs[0]) is None: |
| | return None |
| | if (input_shape := input.shape) is None: |
| | |
| | return None |
| | if (expanded_shape := _get_numpy_value(node.inputs[1])) is None: |
| | |
| | expanded_sym_shape = state.get_shape_value(node.inputs[1]) |
| | if expanded_sym_shape is None or not _same_shape(input_shape, expanded_sym_shape): |
| | return None |
| | return op.Identity(input) |
| | if expanded_shape.ndim != 1: |
| | |
| | return None |
| | if input_shape.dims == tuple(expanded_shape.tolist()): |
| | return op.Identity(input) |
| | return None |
| |
|
| |
|
| | @register("ConcatFromSequence") |
| | def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | input = node.inputs[0] |
| | inputs = state.get_sym_value(input) |
| | if inputs is None or any(x is None for x in inputs): |
| | return None |
| | new_axis = _get_int_attribute(node, "new_axis", 0) |
| | axis = _get_int_attribute(node, "axis", None) |
| | if axis is None: |
| | return None |
| | if input is not None and isinstance(inputs, list): |
| | if new_axis == 0: |
| | logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs]) |
| | return op.Concat(*inputs, axis=axis) |
| | if new_axis == 1: |
| | |
| | axis_value = op.Constant(value_int=axis) |
| | unsqueezed_inputs = [] |
| | for node_input in inputs: |
| | unsqueezed_input = op.Unsqueeze( |
| | node_input, axis_value, _outputs=[f"{node_input.name}_unsqueeze"] |
| | ) |
| | unsqueezed_inputs.append(unsqueezed_input) |
| | |
| | logger.debug( |
| | "ConcatFromSequence => Concat %s", [x.name for x in unsqueezed_inputs] |
| | ) |
| | return op.Concat(*unsqueezed_inputs, axis=axis) |
| | return None |
| |
|
| |
|
| | @register("SplitToSequence") |
| | def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | """Rewriting pattern. |
| | |
| | From |
| | |
| | splits = onnx::SplitToSequence(input, split, axis=axis) |
| | |
| | to |
| | |
| | split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis) |
| | splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) |
| | |
| | or |
| | |
| | split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1) |
| | splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) |
| | |
| | where number of output tensors in `splits` is statically known. |
| | onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator. |
| | This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. |
| | """ |
| | input = node.inputs[0] |
| | split = node.inputs[1] |
| | output = node.outputs[0] |
| |
|
| | if input is None or split is None or output is None: |
| | return None |
| |
|
| | axis = _get_int_attribute(node, "axis", 0) |
| | if axis is None: |
| | return None |
| | shape = input.shape |
| | if shape is None: |
| | return None |
| | rank = len(shape) |
| | if axis < 0: |
| | axis = axis + rank |
| | if axis < 0 or axis >= rank: |
| | return None |
| | split_dimension_size = shape[axis] |
| | if not isinstance(split_dimension_size, int): |
| | return None |
| |
|
| | split_value = _get_numpy_value(split) |
| | if split_value is None: |
| | return None |
| | assert isinstance(split_value, np.ndarray) |
| |
|
| | if split_value.ndim == 0: |
| | |
| | num_outputs = math.ceil(split_dimension_size / split_value.item()) |
| | split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] |
| | split_values = op.Split( |
| | input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs |
| | ) |
| | elif split_value.ndim == 1: |
| | |
| | num_outputs = split_value.size |
| | split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] |
| | split_values = op.Split(input, split, axis=axis, _outputs=split_outputs) |
| | else: |
| | return None |
| |
|
| | |
| | if isinstance(split_values, ir.Value): |
| | split_values = [split_values] |
| |
|
| | keepdims = _get_int_attribute(node, "keepdims", 1) |
| | if keepdims is None: |
| | return None |
| | if keepdims == 0: |
| | |
| | axis_val = op.Constant(value_ints=[axis], _outputs=[f"{output.name}_axis"]) |
| | squeezed_values = [] |
| | for i in range(num_outputs): |
| | squeezed = op.Squeeze( |
| | split_values[i], axis_val, _outputs=[f"{split_outputs[i]}_squeeze"] |
| | ) |
| | squeezed_values.append(squeezed) |
| | split_values = squeezed_values |
| |
|
| | logger.debug("SplitToSequence => Split + SequenceConstruct") |
| |
|
| | if isinstance(split_values, ir.Value): |
| | split_values = [split_values] |
| | return op.SequenceConstruct(*split_values) |
| |
|
| |
|
| | @register("SequenceAt") |
| | def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue: |
| | input = node.inputs[0] |
| | position = node.inputs[1] |
| | output = node.outputs[0] |
| | if input is not None and position is not None: |
| | input_vals = state.get_sym_value(input) |
| | position_val = _get_numpy_value(position) |
| | if isinstance(input_vals, list) and position_val is not None: |
| | if position_val.size != 1: |
| | return None |
| | position_val = position_val.item() |
| | try: |
| | result = input_vals[position_val] |
| | except IndexError: |
| | return None |
| | state.set_sym_value(output, result) |
| | logger.debug("SequenceAt %s => %s", input.name, result.name) |
| | return op.Identity(result) |
| | return None |
| |
|
| |
|
| | def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None: |
| | def merge_dims(dim1, dim2): |
| | if dim1 == dim2: |
| | return dim1 |
| | if not isinstance(dim1, ir.SymbolicDim): |
| | return dim1 |
| | if not isinstance(dim2, ir.SymbolicDim): |
| | return dim2 |
| | if dim1.value is None: |
| | return dim2 |
| | return dim1 |
| |
|
| | if shape1 is None: |
| | return shape2 |
| | if shape2 is None: |
| | return shape1 |
| | if len(shape1) != len(shape2): |
| | raise ValueError("Shapes must have the same rank.") |
| | return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) |
| |
|
| |
|
| | class FoldConstantsPass(ir.passes.InPlacePass): |
| | """A pass that folds constant expressions in the model. |
| | |
| | Attributes: |
| | shape_inference: Whether to perform shape inference. |
| | input_size_limit: Maximum size of input tensors to fold. |
| | output_size_limit: Maximum size of output tensors to fold. |
| | always_fold_ops: Collection of op types that should always be folded. |
        For ops from the default opset, only op_type is needed (e.g. "Transpose"),
| | otherwise specify the domain with ``{domain}::{op_type}``. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | *, |
| | shape_inference: bool, |
| | input_size_limit: int, |
| | output_size_limit: int, |
| | always_fold_ops: Collection[str] = frozenset(["Transpose"]), |
| | ) -> None: |
| | self.shape_inference = shape_inference |
| | self.input_size_limit = input_size_limit |
| | self.output_size_limit = output_size_limit |
| | ops = [] |
| | for name in always_fold_ops: |
| | domain, op_type = name.split("::", 1) if "::" in name else ("", name) |
| | if domain == "ai.onnx": |
| | domain = "" |
| | ops.append((domain, op_type)) |
| | self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops) |
| |
|
| | self._opset_imports: dict[str, int] = {} |
| | self._counts: dict[str, int] = {} |
| | self._sizes: dict[str, int] = {} |
| | self._modified: bool = False |
| | self._state = OptimizerState() |
| | self._reset() |
| |
|
    def _reset(self) -> None:
        """Reset internal states for a new run."""
        # NOTE(review): _opset_imports is intentionally not cleared here —
        # presumably it is refreshed from the model by the caller; confirm.
        self._counts = {}
        self._sizes = {}
        self._modified = False
        self._state = OptimizerState()
| |
|
    def _do_inference(self, node: ir.Node) -> None:
        """Run ONNX node-level shape/type inference and merge results into outputs.

        Failures (missing schema, inference errors) are non-fatal: the node
        simply keeps its current, possibly partial, type information.
        """
        output_types = {}

        # Serialize a small constant input (<= 20 elements) to a TensorProto
        # so inference can use it for data propagation.
        def get_constant_value(x: ir.Value) -> onnx.TensorProto | None:
            value = _get_numpy_value(x, size_limit=20)
            if value is not None:
                assert x.const_value is not None
                return ir.serde.serialize_tensor(x.const_value)
            return None

        # Serialize the IR type (and shape, when known) to a TypeProto.
        def get_type(value: ir.Value) -> onnx.TypeProto | None:
            if value.type is not None:
                type_proto = ir.serde.serialize_type(value.type)
                if value.shape is not None:
                    ir.serde.serialize_shape_into(type_proto, value.shape)
                return type_proto
            return None

        input_types = {x.name: get_type(x) for x in node.inputs if x is not None}
        input_data = {x.name: get_constant_value(x) for x in node.inputs if x is not None}
        input_data = {k: v for k, v in input_data.items() if v is not None}
        if any(t is None for t in input_types.values()):
            logger.debug(
                "Skipping shape inference for node %s due to missing input type.",
                node.name,
            )
        else:
            try:
                schema = onnx.defs.get_schema(
                    node.op_type, self._opset_imports[node.domain], node.domain
                )
                output_types = onnx.shape_inference.infer_node_outputs(
                    schema,
                    ir.serde.serialize_node(node),
                    input_types,
                    input_data,
                )
                for output in node.outputs:
                    if output.name in output_types:
                        inferred_type = output_types[output.name]
                        # Merge the inferred shape with any existing shape;
                        # the inferred type overwrites the existing type.
                        inferred_shape = ir.serde.deserialize_type_proto_for_shape(
                            inferred_type
                        )
                        output.shape = _merge_shapes(output.shape, inferred_shape)
                        output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
            except Exception as e:
                logger.debug(
                    "Skipping shape inference for node %s due to exception: %s",
                    node.name,
                    e,
                )
| |
|
| | def new_constant(self, node: ir.Node, value) -> ir.Node | None: |
| | irvalue = node.outputs[0] |
| | if not isinstance(value, np.ndarray): |
| | |
| | |
| | |
| | logger.info( |
| | "Skip storing constant folded value %s due to unsupported type %s.", |
| | irvalue.name, |
| | type(value), |
| | ) |
| | return None |
| |
|
| | tensor = ir.tensor(value) |
| | tensor.name = irvalue.name |
| | irvalue.const_value = tensor |
| |
|
| | if value.size > self.output_size_limit: |
| | |
| | |
| | removed_input_size = 0 |
| | for input in node.inputs: |
| | if (input is not None) and (len(input.uses()) == 1): |
| | array = _get_numpy_value(input) |
| | if array is not None: |
| | removed_input_size += array.size |
| | increased_size = value.size - removed_input_size |
| | if increased_size > 0: |
| | logger.info( |
| | "Skip storing constant folded nvalue %s due to large size %s.", |
| | irvalue.name, |
| | value.size, |
| | ) |
| | return None |
| |
|
| | logger.debug( |
| | "New constant for value %s dtype: %s shape: %s", |
| | irvalue.name, |
| | value.dtype, |
| | value.shape, |
| | ) |
| |
|
| | attributes = ir.convenience.convert_attributes({"value": tensor}) |
| | node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) |
| | return node |
| |
|
| | def process_node(self, node: ir.Node) -> Replacement | None: |
| | """Process a node and return a Replacement if the node can be replaced.""" |
| | for i, value in enumerate(node.inputs): |
| | sym_value = self._state.get_sym_value(value) |
| | if isinstance(sym_value, ir.Value): |
| | logger.debug( |
| | "Node [%s]: Replacing input %s with %s", |
| | node.name, |
| | value.name, |
| | sym_value.name, |
| | ) |
| | node.replace_input_with(i, sym_value) |
| | self._modified = True |
| | |
| |
|
| | |
| | if self.shape_inference and not _is_control_flow_op(node): |
| | self._do_inference(node) |
| |
|
| | if node.domain not in self._opset_imports: |
| | return None |
| | version = self._opset_imports[node.domain] |
| | op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) |
| | for optimizer in op_optimizers: |
| | assert optimizer |
| | context = RewriterContext() |
| | output = optimizer(node, context, self._state) |
| | if output is not None: |
| | if isinstance(output, Replacement): |
| | return output |
| | if isinstance(output, ir.Value): |
| | output = [output] |
| | return Replacement(output, context.nodes) |
| |
|
| | if _is_control_flow_op(node) or _is_non_deterministic_op(node): |
| | return None |
| |
|
| | if _is_onnx_op(node, "Constant"): |
| | _process_constant_node(node) |
| | return None |
| |
|
| | if any(x.is_graph_input() for x in node.inputs if x is not None): |
| | |
| | return None |
| |
|
| | |
| | if any(x.const_value is None for x in node.inputs if x is not None): |
| | if logger.isEnabledFor(logging.DEBUG): |
| | logger.debug( |
| | "Skipping constant folding for node %s because it has non-constant inputs", |
| | node, |
| | [x.name for x in node.inputs if x is not None], |
| | ) |
| | return None |
| |
|
| | input_tensors = [x.const_value if x is not None else None for x in node.inputs] |
| | if any( |
| | tensor.size > self.input_size_limit |
| | for tensor in input_tensors |
| | if tensor is not None |
| | ): |
| | if (node.domain, node.op_type) in self.always_fold_ops and all( |
| | len(input.consumers()) == 1 for input in node.inputs if input is not None |
| | ): |
| | |
| | |
| | logger.debug( |
| | "Folding large constant for node %s because it is in the always_fold_ops list", |
| | node, |
| | ) |
| | else: |
| | |
| | if logger.isEnabledFor(logging.DEBUG): |
| | input_sizes = [ |
| | tensor.size for tensor in input_tensors if tensor is not None |
| | ] |
| | logger.debug( |
| | "Skipping constant folding for node %s due to large input size: %s", |
| | node, |
| | input_sizes, |
| | ) |
| | return None |
| |
|
| | input_values = [_get_numpy_value(x) for x in node.inputs] |
| |
|
| | def convert(av): |
| | if av.type == ir.AttributeType.TENSOR: |
| | return ir.serde.serialize_tensor(av.value) |
| | return av.value |
| |
|
| | attr_values = {name: convert(attr) for name, attr in node.attributes.items()} |
| | outputs = _reference_evaluator.evaluate( |
| | node.domain, node.op_type, version, *input_values, **attr_values |
| | ) |
| |
|
| | if outputs is None: |
| | return None |
| | if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): |
| | replacement = self.new_constant(node, outputs) |
| | if _is_onnx_op(node, "ConstantOfShape") or replacement is None: |
| | return None |
| | return Replacement(replacement.outputs, [replacement]) |
| | else: |
| | logger.warning( |
| | "Skipping constant folding for op %s with multiple outputs.", node.op_type |
| | ) |
| | return None |
| |
|
| | def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None: |
| | logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) |
| |
|
| | ir.convenience.replace_nodes_and_values( |
| | root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs |
| | ) |
| |
|
| | self._modified = True |
| |
|
| | |
| | |
| |
|
| | def visit_attribute(self, attr: ir.Attr) -> None: |
| | if attr.is_ref(): |
| | return |
| | if attr.type == ir.AttributeType.GRAPH: |
| | self.visit_graph(attr.as_graph()) |
| | elif attr.type == ir.AttributeType.GRAPHS: |
| | for graph in attr.as_graphs(): |
| | self.visit_graph(graph) |
| |
|
| | def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None: |
| | replacement = self.process_node(node) |
| | if replacement is None: |
| | |
| | for attr in node.attributes.values(): |
| | self.visit_attribute(attr) |
| | return |
| | else: |
| | self.replace_node(node, replacement, root) |
| |
|
    def visit_graph(self, graph: ir.Graph) -> None:
        """Fold constants in ``graph``, then rebind outputs to known equivalent values."""
        for node in graph:
            self.visit_node(node, graph)

        # A graph output may have been discovered (during folding) to be
        # identical to another value; rebind the output slot directly when the
        # replacement is safe (see _sym_value_can_replace_graph_output).
        for i, output in enumerate(graph.outputs):
            if output is None:
                continue
            sym_value = self._state.get_sym_value(output)
            if not isinstance(sym_value, ir.Value):
                # No usable substitute: either nothing is known for this output
                # or the symbolic value is not a concrete ir.Value.
                continue
            if not _sym_value_can_replace_graph_output(graph, sym_value, output):
                continue
            # Preserve the externally visible output name on the new value
            # before swapping it into the output slot.
            sym_value.name = output.name
            graph.outputs[i] = sym_value
            self._modified = True
| |
|
| | def visit_function(self, function: ir.Function) -> None: |
| | for node in function: |
| | self.visit_node(node, function) |
| |
|
| | def call(self, model: ir.Model) -> FoldConstantsResult: |
| | self._reset() |
| | self._opset_imports = model.opset_imports |
| | self.visit_graph(model.graph) |
| | for function in model.functions.values(): |
| | |
| | self.visit_function(function) |
| | return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map) |
| |
|
| |
|
| | def _sym_value_can_replace_graph_output( |
| | graph: ir.Graph, sym_value: ir.Value, output: ir.Value |
| | ) -> bool: |
| | if (producer := sym_value.producer()) is None: |
| | |
| | |
| | return False |
| | if producer.graph is not graph: |
| | |
| | return False |
| | if sym_value.is_graph_output(): |
| | |
| | |
| | |
| | return False |
| | return True |
| |
|
| |
|
@dataclasses.dataclass
class FoldConstantsResult(ir.passes.PassResult):
    """Result of the constant-folding pass.

    Extends ``ir.passes.PassResult`` with the symbolic-value map accumulated
    while folding, so callers can inspect which values were resolved.
    """

    # Maps IR values to the symbolic value they were resolved to during folding.
    symbolic_value_map: dict[ir.Value, SymbolicValue]

    def __bool__(self) -> bool:
        # Truthiness of the result mirrors whether the model was modified.
        return self.modified
| |
|
| |
|
def fold_constants(
    model: ir.Model,
    *,
    onnx_shape_inference: bool = False,
    input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    always_fold_ops: Collection[str] = frozenset(["Transpose"]),
) -> FoldConstantsResult:
    """
    Applies constant folding optimization to the model.

    Args:
        model: The ONNX model to optimize.
        onnx_shape_inference: Whether to enable ONNX shape inference during
            constant folding. Defaults to False.
        input_size_limit: The maximum size of input tensors
            that can be considered for constant folding. Defaults to
            `DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT`.
        output_size_limit: The maximum size of output tensors
            that can be stored after constant folding. Defaults to
            `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
        always_fold_ops: A collection of op types that should always be folded,
            regardless of their input or output sizes. For ops from the default opset,
            only op_type is needed (e.g. "Transpose"), otherwise specify the domain
            with ``{domain}::{op_type}``.

    Returns:
        An instance of `FoldConstantsResult`.

    """
    # Thin convenience wrapper: configure the pass and invoke it once.
    folder_pass = FoldConstantsPass(
        shape_inference=onnx_shape_inference,
        input_size_limit=input_size_limit,
        output_size_limit=output_size_limit,
        always_fold_ops=always_fold_ops,
    )
    return folder_pass(model)
| |
|