File size: 4,876 Bytes
6a22ec9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import TypeVar
# Names exported by a star-import of this module. Every public helper defined
# or re-exported below is listed, including ``remove_unused_functions``, which
# is the ModelProto/ir.Model counterpart of ``remove_unused_nodes``.
__all__ = [
    "basic_constant_propagation",
    "fold_constants_ir",
    "fold_constants",
    "inline",
    "optimize_ir",
    "optimize",
    "remove_unused_functions",
    "remove_unused_nodes",
]
import onnx
import onnx_ir.passes.common as common_passes
import onnxscript.optimizer._constant_folding as constant_folding
from onnxscript import ir
from onnxscript.optimizer._constant_folding import (
basic_constant_propagation,
)
from onnxscript.optimizer._constant_folding import (
fold_constants as fold_constants_ir,
)
from onnxscript.optimizer._optimizer import optimize_ir
# Constrained type variable: lets ``optimize`` be annotated as returning the same
# kind of model (``onnx.ModelProto`` or ``ir.Model``) that it was given.
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
def optimize(
    model: _ModelProtoOrIr,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> _ModelProtoOrIr:
    """Run the optimizer over a model and return the optimized model.

    An ``ir.Model`` is passed straight to ``optimize_ir`` (which works in place)
    and the same object is returned. An ``onnx.ModelProto`` is first deserialized
    to the IR, optimized there, and then serialized into a proto that is returned.

    Args:
        model: The model to be optimized.
        num_iterations: Number of times the optimization loop is repeated.
        onnx_shape_inference: Applies node-level shape-inference as part of optimization
        stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
        input_size_limit: Will not apply constant folding to ops with any input of size
            greater than this. Does not apply to special ops like Shape() and Size().
        output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
            of the output tensor is greater than this.
        inline: If True, inlines all functions in the model.

    Returns:
        The optimized model. If the input was a ModelProto, the output will also be a
        ModelProto. If the input was an ir.Model, the output will also be an ir.Model.
    """
    received_ir = isinstance(model, ir.Model)
    if received_ir:
        # The IR model is optimized in place.
        # TODO(justinchuby): Maybe make functional
        working_model = model
    else:
        assert isinstance(model, onnx.ModelProto)
        working_model = ir.serde.deserialize_model(model)

    optimize_ir(
        working_model,
        num_iterations=num_iterations,
        onnx_shape_inference=onnx_shape_inference,
        stop_if_no_change=stop_if_no_change,
        input_size_limit=input_size_limit,
        output_size_limit=output_size_limit,
        inline=inline,
    )

    if received_ir:
        return model
    # Convert the optimized IR back into a proto for proto callers.
    return ir.serde.serialize_model(working_model)
def inline(model: ir.Model) -> None:
    """Inline every function call in the model (recursively), in place.

    Does nothing when the model has no functions.
    """
    if not model.functions:
        return
    inline_pass = common_passes.InlinePass()
    inline_pass(model)
def fold_constants(
    model: ir.Model | onnx.ModelProto, *args, **kwargs
) -> constant_folding.FoldConstantsResult:
    """Fold constants in a model in place.

    An ``ir.Model`` is folded directly. An ``onnx.ModelProto`` is deserialized to
    the IR, folded there, and the caller's proto is then overwritten with the
    serialized result.

    Args:
        model: The model to fold constants in.
        *args: Forwarded unchanged to ``constant_folding.fold_constants``.
        **kwargs: Forwarded unchanged to ``constant_folding.fold_constants``.

    Returns:
        The result object returned by ``constant_folding.fold_constants``.
    """
    if isinstance(model, ir.Model):
        return constant_folding.fold_constants(model, *args, **kwargs)

    assert isinstance(model, onnx.ModelProto)
    ir_model = ir.serde.deserialize_model(model)
    folding_result = constant_folding.fold_constants(ir_model, *args, **kwargs)
    # Build the replacement proto first; the caller's proto is only cleared once
    # the serialized result is available.
    folded_proto = ir.serde.serialize_model(ir_model)
    model.Clear()
    model.CopyFrom(folded_proto)
    return folding_result
def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
    """Remove unused nodes from a model, in place.

    An ``ir.Model`` is processed directly by ``RemoveUnusedNodesPass``. Any other
    input is treated as a proto: it is deserialized, the pass is run on the IR,
    and the proto's contents are replaced with the serialized result.
    """
    if isinstance(model, ir.Model):
        common_passes.RemoveUnusedNodesPass()(model)
        return

    deserialized = ir.serde.deserialize_model(model)
    pass_result = common_passes.RemoveUnusedNodesPass()(deserialized)
    # Serialize before clearing so the replacement exists before the original
    # contents are discarded.
    pruned_proto = ir.serde.serialize_model(pass_result.model)
    model.Clear()
    model.CopyFrom(pruned_proto)
def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None:
    """Remove unused functions from a model, in place.

    An ``ir.Model`` is processed directly by ``RemoveUnusedFunctionsPass``. Any
    other input is treated as a proto: it is deserialized, the pass is run on the
    IR, and the proto's contents are replaced with the serialized result.
    """
    if isinstance(model, ir.Model):
        common_passes.RemoveUnusedFunctionsPass()(model)
        return

    deserialized = ir.serde.deserialize_model(model)
    pass_result = common_passes.RemoveUnusedFunctionsPass()(deserialized)
    # Serialize before clearing so the replacement exists before the original
    # contents are discarded.
    pruned_proto = ir.serde.serialize_model(pass_result.model)
    model.Clear()
    model.CopyFrom(pruned_proto)
|