# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import TypeVar

__all__ = [
    "basic_constant_propagation",
    "fold_constants_ir",
    "fold_constants",
    "inline",
    "optimize_ir",
    "optimize",
    "remove_unused_functions",
    "remove_unused_nodes",
]

import onnx
import onnx_ir.passes.common as common_passes

import onnxscript.optimizer._constant_folding as constant_folding
from onnxscript import ir
from onnxscript.optimizer._constant_folding import (
    basic_constant_propagation,
)
from onnxscript.optimizer._constant_folding import (
    fold_constants as fold_constants_ir,
)
from onnxscript.optimizer._optimizer import optimize_ir

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)


def _ensure_model_proto(model: object) -> onnx.ModelProto:
    """Return ``model`` unchanged if it is an ``onnx.ModelProto``.

    An explicit check is used instead of ``assert`` so that it is not stripped
    when Python runs with ``-O``.

    Raises:
        TypeError: If ``model`` is not an ``onnx.ModelProto``.
    """
    if not isinstance(model, onnx.ModelProto):
        raise TypeError(
            f"Expected an onnx.ModelProto or ir.Model, got {type(model).__name__}."
        )
    return model


def _run_pass_on_proto_inplace(
    model_proto: onnx.ModelProto, ir_pass: ir.passes.PassBase
) -> None:
    """Apply an IR pass to a ModelProto and write the result back into it.

    The proto is deserialized to IR, the pass is run, and the model returned by
    the pass is serialized back into ``model_proto``. Callers that hold a
    reference to the proto therefore observe the change.
    """
    model_ir = ir.serde.deserialize_model(model_proto)
    model_ir = ir_pass(model_ir).model
    # Message.CopyFrom clears the destination before merging, so no Clear() is needed.
    model_proto.CopyFrom(ir.serde.serialize_model(model_ir))


def optimize(
    model: _ModelProtoOrIr,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> _ModelProtoOrIr:
    """Optimizes a model.

    Args:
        model: The model to be optimized. An ``ir.Model`` is optimized in place.
            An ``onnx.ModelProto`` is left untouched, and an optimized copy is
            returned instead.
        num_iterations: Number of times the optimization loop is repeated.
        onnx_shape_inference: Applies node-level shape-inference as part of
            optimization.
        stop_if_no_change: Stop the optimization loop if no change is detected
            in an iteration.
        input_size_limit: Will not apply constant folding to ops with any input
            of size greater than this. Does not apply to special ops like
            Shape() and Size().
        output_size_limit: Will not rewrite any foldable-op into a Constant op
            if the size of the output tensor is greater than this.
        inline: If True, inlines all functions in the model.

    Returns:
        The optimized model. If the input was a ModelProto, the output will
        also be a ModelProto. If the input was an ir.Model, the output will
        also be an ir.Model (the same object that was passed in).

    Raises:
        TypeError: If ``model`` is neither an ``ir.Model`` nor an
            ``onnx.ModelProto``.
    """
    if isinstance(model, ir.Model):
        # In this case, optimize is done inplace.
        # TODO(justinchuby): Maybe make functional
        model_ir = model
    else:
        # Work on an IR copy; the caller's proto is not modified.
        model_ir = ir.serde.deserialize_model(_ensure_model_proto(model))

    optimize_ir(
        model_ir,
        num_iterations=num_iterations,
        onnx_shape_inference=onnx_shape_inference,
        stop_if_no_change=stop_if_no_change,
        input_size_limit=input_size_limit,
        output_size_limit=output_size_limit,
        inline=inline,
    )

    if isinstance(model, ir.Model):
        return model
    # Move the model back to the proto
    return ir.serde.serialize_model(model_ir)


def inline(model: ir.Model | onnx.ModelProto) -> None:
    """Inline all function calls (recursively) in the model, in place.

    Nothing is done when the model defines no functions.

    Args:
        model: The model to inline. An ``onnx.ModelProto`` is rewritten in
            place, like the other helpers in this module.

    Raises:
        TypeError: If ``model`` is neither an ``ir.Model`` nor an
            ``onnx.ModelProto``.
    """
    if isinstance(model, ir.Model):
        if model.functions:
            common_passes.InlinePass()(model)
        return

    model_proto = _ensure_model_proto(model)
    if model_proto.functions:
        _run_pass_on_proto_inplace(model_proto, common_passes.InlinePass())


def fold_constants(
    model: ir.Model | onnx.ModelProto, *args, **kwargs
) -> constant_folding.FoldConstantsResult:
    """Fold constants in a model in place.

    Args:
        model: The model to fold. An ``onnx.ModelProto`` is rewritten in place.
        *args: Forwarded to ``constant_folding.fold_constants``.
        **kwargs: Forwarded to ``constant_folding.fold_constants``.

    Returns:
        The result reported by ``constant_folding.fold_constants``.

    Raises:
        TypeError: If ``model`` is neither an ``ir.Model`` nor an
            ``onnx.ModelProto``.
    """
    if isinstance(model, ir.Model):
        return constant_folding.fold_constants(model, *args, **kwargs)

    model_proto = _ensure_model_proto(model)
    model_ir = ir.serde.deserialize_model(model_proto)
    result = constant_folding.fold_constants(model_ir, *args, **kwargs)
    # Move the model back to the proto (CopyFrom clears the destination first).
    model_proto.CopyFrom(ir.serde.serialize_model(model_ir))
    return result


def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
    """Removes unused nodes from a model inplace.

    Raises:
        TypeError: If ``model`` is neither an ``ir.Model`` nor an
            ``onnx.ModelProto``.
    """
    if isinstance(model, ir.Model):
        common_passes.RemoveUnusedNodesPass()(model)
    else:
        _run_pass_on_proto_inplace(
            _ensure_model_proto(model), common_passes.RemoveUnusedNodesPass()
        )


def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None:
    """Removes unused functions from a model inplace.

    Raises:
        TypeError: If ``model`` is neither an ``ir.Model`` nor an
            ``onnx.ModelProto``.
    """
    if isinstance(model, ir.Model):
        common_passes.RemoveUnusedFunctionsPass()(model)
    else:
        _run_pass_on_proto_inplace(
            _ensure_model_proto(model), common_passes.RemoveUnusedFunctionsPass()
        )