File size: 5,771 Bytes
52e82ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Cleanup and optimize perch_v2_slim.onnx model.

This script can be applied after completing these steps:

1. Use `tf2onnx` to convert the tflite model to onnx
2. Apply onnxslim and onnxscript.optimize.optimizer on the model
3. Manually edit the model to remove the first DFT node (no-op) and fuse
    the nodes that effectively takes the magnitude of the DFT output with ReduceL2.
"""

import onnx_ir as ir
import onnx_ir.passes.common
import onnxscript
import numpy as np

m = ir.load("perch_v2_slim.onnx")

# Fold the Reshape nodes surrounding each MatMul and drop no-op Expand nodes.
for node in m.graph:
    if node.op_type == "MatMul":
        producer = node.inputs[0].producer()
        # Guard: producer() is None for graph inputs / initializers (the
        # later Unsqueeze/Reshape loops already guard this the same way).
        if producer is not None and producer.op_type == "Reshape":
            print("Simplify MatMul + Reshape:", node.name)
            # Skip the Reshape and feed the MatMul its original input.
            node.replace_input_with(0, producer.inputs[0])

        # Snapshot the uses: replace_all_uses_with() below adds new uses to
        # node.outputs[0] while we iterate, which would otherwise mutate the
        # collection mid-iteration.
        for usage in list(node.outputs[0].uses()):
            if usage.node.op_type != "Reshape":
                continue
            reshape_usages = list(usage.node.outputs[0].uses())
            # Keep the last Reshape (the one feeding ReduceMax) but pin its
            # target shape explicitly. NOTE(review): the shape constants here
            # are specific to perch_v2_slim — confirm on a different export.
            if reshape_usages and reshape_usages[0].node.op_type == "ReduceMax":
                shape = ir.val(
                    "reshape_shape", const_value=ir.tensor([-1, 16, 4, 14795, 4])
                )
                m.graph.initializers.add(shape)
                usage.node.replace_input_with(1, shape)
                continue
            # Any other Reshape after the MatMul is redundant: bypass it.
            reshape_node = usage.node
            reshape_node.outputs[0].replace_all_uses_with(node.outputs[0])

    # Expand is a no-op in this graph: forward its input to all consumers.
    if node.op_type == "Expand":
        print("Remove Expand:", node.name)
        node.outputs[0].replace_all_uses_with(node.inputs[0])

# Drop whatever nodes the rewrites above left dangling.
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)

# Constant-fold; the 1 GiB limits let even the largest initializers fold.
size_limit = 1024 * 1024 * 1024
onnxscript.optimizer.optimize(
    m, input_size_limit=size_limit, output_size_limit=size_limit
)

# Shared rank-1 INT64 constant [1], reused as an Unsqueeze axes input below.
one_1d = ir.val("1d_one", const_value=ir.tensor([1], dtype=ir.DataType.INT64))
m.graph.initializers.add(one_1d)

# Simplify Unsqueeze + Reshape: re-point the Unsqueeze axes at [1] and route
# the Reshape's consumers straight to the Unsqueeze output.
for node in m.graph:
    if node.op_type != "Reshape":
        continue
    producer = node.inputs[0].producer()
    if producer is not None and producer.op_type == "Unsqueeze":
        # Log only actual simplifications (the original printed for every
        # Reshape, simplified or not).
        print("Simplify Unsqueeze + Reshape:", node.name)
        producer.replace_input_with(1, one_1d)
        node.outputs[0].replace_all_uses_with(producer.outputs[0])
        # NOTE(review): shape hard-coded for this model's 160000-sample
        # input — confirm if the audio window size ever changes.
        producer.outputs[0].shape = ir.Shape(["batch", 160000, 1])

# Target shape for the fused first Reshape (batch, 1, 160000, 1).
first_reshape_shape = ir.val(
    "first_reshape_shape", const_value=ir.tensor([-1, 1, 160000, 1])
)
m.graph.initializers.add(first_reshape_shape)

# Simplify the first Reshape + Unsqueeze pair: give the Reshape the final
# target shape directly and bypass the Unsqueeze.
for node in m.graph:
    if node.op_type != "Unsqueeze":
        continue
    producer = node.inputs[0].producer()
    if producer is not None and producer.op_type == "Reshape":
        # Log only when the pattern actually matches (the original printed
        # for every Unsqueeze it visited).
        print("Simplify Reshape + Unsqueeze:", node.name)
        producer.replace_input_with(1, first_reshape_shape)
        node.outputs[0].replace_all_uses_with(producer.outputs[0])
        producer.outputs[0].shape = ir.Shape(["batch", 1, 160000, 1])
        # Only the first occurrence is rewritten, by design.
        break

# Fuse Conv + Sub into Conv: Conv(x) - c == Conv(x) with bias -c
# (ONNX Conv bias must be rank 1, hence the reshape to (-1,)).
for node in m.graph:
    if node.op_type != "Conv":
        continue
    print("Check Conv for fusion:", node.name)
    conv_node = node
    # Snapshot the uses: replace_all_uses_with() below redirects the Sub's
    # consumers onto conv_node.outputs[0], growing the collection we would
    # otherwise still be iterating.
    conv_uses = list(conv_node.outputs[0].uses())
    assert len(conv_uses) == 1
    for usage in conv_uses:
        if usage.node.op_type != "Sub":
            continue
        sub_node = usage.node
        print("  Fuse Sub into Conv:", sub_node.name)
        sub_value = sub_node.inputs[1]
        new_bias = np.negative(sub_value.const_value.numpy()).reshape((-1,))
        new_bias_val = ir.val(
            f"{sub_value.name}_neg",
            const_value=ir.tensor(new_bias),
        )
        m.graph.initializers.add(new_bias_val)
        if len(conv_node.inputs) == 2:
            # HACK: pad the optional bias slot via a private field — onnx_ir
            # exposes no public API for appending an absent optional input.
            # Re-check on every onnx_ir upgrade.
            conv_node._inputs = conv_node._inputs + (None,)
        # NOTE(review): if a Conv already carried a bias it would be
        # overwritten here rather than summed — assumes all Convs in this
        # model are bias-free; confirm.
        conv_node.replace_input_with(2, new_bias_val)
        sub_node.outputs[0].replace_all_uses_with(conv_node.outputs[0])

# Prune nodes orphaned by the Conv/Sub fusion.
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)

# Wipe every intermediate shape so the optimizer re-infers them from scratch;
# graph outputs keep theirs.
for node in m.graph:
    for value in node.outputs:
        if not value.is_graph_output():
            value.shape = None

# Re-label the leading dimension of the graph IO as a symbolic "batch" dim.
m.graph.inputs[0].shape = ir.Shape(["batch", *m.graph.inputs[0].shape[1:]])
for graph_output in m.graph.outputs:
    graph_output.shape = ir.Shape(["batch", *graph_output.shape[1:]])

# Second optimizer pass: folds constants and re-runs shape inference.
size_limit = 1024 * 1024 * 1024
onnxscript.optimizer.optimize(
    m, input_size_limit=size_limit, output_size_limit=size_limit
)

onnx_ir.passes.common.ClearMetadataAndDocStringPass()(m)

# Replace None dim with "batch"
for node in m.graph:
    for output in node.outputs:
        if output.shape is None:
            continue
        shape = ir.Shape(output.shape)
        for i in range(len(shape)):
            dim = shape[i]
            if isinstance(dim, ir.SymbolicDim) and dim.value is None:
                shape[i] = ir.SymbolicDim("batch")
        output.shape = shape

# Rename the graph IO to match the tflite model's tensor names.
m.graph.inputs[0].name = "inputs"
for value, io_name in zip(
    m.graph.outputs,
    ("spatial_embedding", "embedding", "spectrogram", "label"),
):
    value.name = io_name

# Swap the first two outputs so "embedding" comes first, matching tflite's
# output ordering (names stay attached to their values).
m.graph.outputs[0], m.graph.outputs[1] = m.graph.outputs[1], m.graph.outputs[0]

m.producer_name = "onnx-ir"
m.producer_version = None
m.ir_version = 10

ir.save(m, "perch_v2.onnx")