File size: 4,876 Bytes
6a22ec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import TypeVar

__all__ = [
    "basic_constant_propagation",
    "fold_constants_ir",
    "fold_constants",
    "inline",
    "optimize_ir",
    "optimize",
    "remove_unused_nodes",
]

import onnx
import onnx_ir.passes.common as common_passes

import onnxscript.optimizer._constant_folding as constant_folding
from onnxscript import ir
from onnxscript.optimizer._constant_folding import (
    basic_constant_propagation,
)
from onnxscript.optimizer._constant_folding import (
    fold_constants as fold_constants_ir,
)
from onnxscript.optimizer._optimizer import optimize_ir

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)


def optimize(
    model: _ModelProtoOrIr,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> _ModelProtoOrIr:
    """Optimizes a model.

    Args:
        model: The model to be optimized.
        num_iterations: Number of times the optimization loop is repeated.
        onnx_shape_inference: Applies node-level shape-inference as part of optimization
        input_size_limit: Will not apply constant folding to ops with any input of size
            greater than this. Does not apply to special ops like Shape() and Size().
        output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
            of the output tensor is greater than this.
        stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
        inline: If True, inlines all functions in the model.

    Returns:
        The optimized model. If the input was a ModelProto, the output will also be a
        ModelProto. If the input was an ir.Model, the output will also be an ir.Model.
    """
    # Normalize to an ir.Model so both input kinds share one optimization path.
    came_from_proto = not isinstance(model, ir.Model)
    if came_from_proto:
        assert isinstance(model, onnx.ModelProto)
        model_ir = ir.serde.deserialize_model(model)
    else:
        # The IR model is optimized in place.
        # TODO(justinchuby): Maybe make functional
        model_ir = model
    optimize_ir(
        model_ir,
        num_iterations=num_iterations,
        onnx_shape_inference=onnx_shape_inference,
        stop_if_no_change=stop_if_no_change,
        input_size_limit=input_size_limit,
        output_size_limit=output_size_limit,
        inline=inline,
    )
    if came_from_proto:
        # Convert back so the caller gets the same type it passed in.
        return ir.serde.serialize_model(model_ir)
    return model_ir


def inline(model: ir.Model) -> None:
    """Inline all function calls (recursively) in the model."""
    # Nothing to do when the model defines no functions.
    if not model.functions:
        return
    common_passes.InlinePass()(model)


def fold_constants(
    model: ir.Model | onnx.ModelProto, *args, **kwargs
) -> constant_folding.FoldConstantsResult:
    """Fold constants in a model in place."""
    if not isinstance(model, ir.Model):
        assert isinstance(model, onnx.ModelProto)
        proto = model
        # Work on an IR copy, then write the result back into the proto.
        model_ir = ir.serde.deserialize_model(proto)
        result = constant_folding.fold_constants(model_ir, *args, **kwargs)
        updated_proto = ir.serde.serialize_model(model_ir)
        proto.Clear()
        proto.CopyFrom(updated_proto)
        return result
    # ir.Model is folded directly, in place.
    return constant_folding.fold_constants(model, *args, **kwargs)


def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
    """Removes unused nodes from a model inplace."""
    if isinstance(model, ir.Model):
        common_passes.RemoveUnusedNodesPass()(model)
        return
    # ModelProto: round-trip through the IR, then copy the result back in place.
    model_ir = ir.serde.deserialize_model(model)
    cleaned = common_passes.RemoveUnusedNodesPass()(model_ir).model
    updated_proto = ir.serde.serialize_model(cleaned)
    model.Clear()
    model.CopyFrom(updated_proto)


def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None:
    """Removes unused functions from a model inplace."""
    if isinstance(model, ir.Model):
        common_passes.RemoveUnusedFunctionsPass()(model)
        return
    # ModelProto: round-trip through the IR, then copy the result back in place.
    model_ir = ir.serde.deserialize_model(model)
    cleaned = common_passes.RemoveUnusedFunctionsPass()(model_ir).model
    updated_proto = ir.serde.serialize_model(cleaned)
    model.Clear()
    model.CopyFrom(updated_proto)