|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import TypeVar |
|
|
|
|
|
# Public API of the optimizer package. Every public (non-underscore) helper
# defined or re-exported by this module is listed here, including
# ``remove_unused_functions``, which is the counterpart of
# ``remove_unused_nodes``.
__all__ = [
    "basic_constant_propagation",
    "fold_constants_ir",
    "fold_constants",
    "inline",
    "optimize_ir",
    "optimize",
    "remove_unused_functions",
    "remove_unused_nodes",
]
|
|
|
|
|
import onnx |
|
|
import onnx_ir.passes.common as common_passes |
|
|
|
|
|
import onnxscript.optimizer._constant_folding as constant_folding |
|
|
from onnxscript import ir |
|
|
from onnxscript.optimizer._constant_folding import ( |
|
|
basic_constant_propagation, |
|
|
) |
|
|
from onnxscript.optimizer._constant_folding import ( |
|
|
fold_constants as fold_constants_ir, |
|
|
) |
|
|
from onnxscript.optimizer._optimizer import optimize_ir |
|
|
|
|
|
# Constrained type variable used by ``optimize`` so that its return type
# matches its argument type: ``onnx.ModelProto`` in -> ``onnx.ModelProto`` out,
# ``ir.Model`` in -> ``ir.Model`` out.
_ModelProtoOrIr = TypeVar(
    "_ModelProtoOrIr",
    onnx.ModelProto,
    ir.Model,
)
|
|
|
|
|
|
|
|
def optimize(
    model: _ModelProtoOrIr,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> _ModelProtoOrIr:
    """Run the optimizer on a model given either as IR or as a protobuf.

    Args:
        model: Model to optimize, as an ``ir.Model`` or an ``onnx.ModelProto``.
        num_iterations: How many times the optimization loop is repeated.
        onnx_shape_inference: Whether to apply node-level shape inference as
            part of optimization.
        stop_if_no_change: End the optimization loop early once an iteration
            detects no change.
        input_size_limit: Ops having any input larger than this are not
            constant-folded. Special ops such as Shape() and Size() are exempt.
        output_size_limit: A foldable op is not rewritten into a Constant op
            when its output tensor is larger than this.
        inline: Whether to inline all functions in the model.

    Returns:
        The optimized model, of the same kind as ``model``. An ``ir.Model``
        is passed straight to ``optimize_ir`` and the same object is returned.
        A ``ModelProto`` is deserialized to IR, optimized, and serialized into
        the returned ``ModelProto``.
    """

    def _run_optimizer(target: ir.Model) -> None:
        # Single call site so both input kinds get identical options.
        optimize_ir(
            target,
            num_iterations=num_iterations,
            onnx_shape_inference=onnx_shape_inference,
            stop_if_no_change=stop_if_no_change,
            input_size_limit=input_size_limit,
            output_size_limit=output_size_limit,
            inline=inline,
        )

    if isinstance(model, ir.Model):
        _run_optimizer(model)
        return model

    assert isinstance(model, onnx.ModelProto)
    ir_model = ir.serde.deserialize_model(model)
    _run_optimizer(ir_model)
    return ir.serde.serialize_model(ir_model)
|
|
|
|
|
|
|
|
def inline(model: ir.Model) -> None:
    """Recursively inline every function call in ``model``.

    The pass result is discarded and nothing is returned, so the effect is on
    ``model`` itself. Models that define no functions are left untouched.
    """
    if not model.functions:
        # Nothing to inline.
        return
    common_passes.InlinePass()(model)
|
|
|
|
|
|
|
|
def fold_constants(
    model: ir.Model | onnx.ModelProto, *args, **kwargs
) -> constant_folding.FoldConstantsResult:
    """Constant-fold ``model`` in place.

    Extra positional and keyword arguments are forwarded unchanged to
    ``constant_folding.fold_constants``.

    An ``ir.Model`` is folded directly. An ``onnx.ModelProto`` is converted to
    IR, folded, and then its contents are replaced with the serialized result,
    so the caller's proto object reflects the folding.

    Returns:
        The ``FoldConstantsResult`` produced by the underlying folding call.
    """
    if not isinstance(model, ir.Model):
        assert isinstance(model, onnx.ModelProto)
        ir_model = ir.serde.deserialize_model(model)
        fold_result = constant_folding.fold_constants(ir_model, *args, **kwargs)
        # Serialize before clearing ``model``: the deserialized IR may still
        # refer to data owned by the original proto, so keep this order.
        folded_proto = ir.serde.serialize_model(ir_model)
        model.Clear()
        model.CopyFrom(folded_proto)
        return fold_result

    return constant_folding.fold_constants(model, *args, **kwargs)
|
|
|
|
|
|
|
|
def _run_pass_in_place(model: ir.Model | onnx.ModelProto, pass_) -> None:
    """Apply an IR pass to ``model`` so the caller's object reflects the result.

    Args:
        model: Either an ``ir.Model``, to which ``pass_`` is applied directly,
            or an ``onnx.ModelProto``, which is deserialized to IR, transformed
            by ``pass_``, and has its contents replaced by the serialized result.
        pass_: A pass instance callable on an ``ir.Model`` whose return value
            exposes the transformed model as ``.model``.
    """
    if isinstance(model, ir.Model):
        # NOTE(review): the returned model is ignored here (as in the original
        # wrappers), which assumes ``pass_`` mutates ``model`` in place.
        pass_(model)
        return

    # Narrow the type the same way ``optimize`` and ``fold_constants`` do.
    assert isinstance(model, onnx.ModelProto)
    model_ir = ir.serde.deserialize_model(model)
    model_ir = pass_(model_ir).model
    # Serialize before clearing ``model``: the deserialized IR may still refer
    # to data owned by the original proto.
    new_proto = ir.serde.serialize_model(model_ir)
    model.Clear()
    model.CopyFrom(new_proto)


def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
    """Removes unused nodes from a model inplace.

    Runs ``RemoveUnusedNodesPass``. For an ``onnx.ModelProto`` the proto's
    contents are replaced with the transformed model.
    """
    _run_pass_in_place(model, common_passes.RemoveUnusedNodesPass())


def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None:
    """Removes unused functions from a model inplace.

    Runs ``RemoveUnusedFunctionsPass``. For an ``onnx.ModelProto`` the proto's
    contents are replaced with the transformed model.
    """
    _run_pass_in_place(model, common_passes.RemoveUnusedFunctionsPass())
|
|
|