"""Cleanup and optimize perch_v2_slim.onnx model.
This script can be applied after completing these steps:
1. Use `tf2onnx` to convert the tflite model to onnx
2. Apply onnxslim and onnxscript.optimize.optimizer on the model
3. Manually edit the model to remove the first DFT node (no-op) and fuse
the nodes that effectively take the magnitude of the DFT output into a ReduceL2.
"""
import onnx_ir as ir
import onnx_ir.passes.common
import onnxscript
import numpy as np
m = ir.load("perch_v2_slim.onnx")
# Pass 1: simplify MatMul/Reshape pairs and drop Expand nodes.
for node in m.graph:
    if node.op_type == "MatMul":
        print("Simplify MatMul + Reshape:", node.name)
        producer = node.inputs[0].producer()
        # Guard: graph inputs and initializers have no producer (None),
        # which would crash the op_type access below.
        if producer is not None and producer.op_type == "Reshape":
            # Skip the reshape: feed the MatMul directly from its input.
            node.replace_input_with(0, producer.inputs[0])
        for usage in node.outputs[0].uses():
            if usage.node.op_type == "Reshape":
                reshape_usages = list(usage.node.outputs[0].uses())
                # Keep the last Reshape (the one feeding ReduceMax): give it a
                # fully specified target shape instead of bypassing it.
                if reshape_usages and reshape_usages[0].node.op_type == "ReduceMax":
                    shape = ir.val(
                        "reshape_shape", const_value=ir.tensor([-1, 16, 4, 14795, 4])
                    )
                    m.graph.initializers.add(shape)
                    usage.node.replace_input_with(1, shape)
                    continue
                # Otherwise the Reshape is redundant: route its consumers to
                # the MatMul output so the Reshape becomes dead.
                usage.node.outputs[0].replace_all_uses_with(node.outputs[0])
    # Remove Expand: rewire its consumers straight to the Expand's input.
    if node.op_type == "Expand":
        print("Remove Expand:", node.name)
        node.outputs[0].replace_all_uses_with(node.inputs[0])
# Clean up any nodes left dead by the rewrites above
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)
# Do some const folding; generous limits so large weight tensors still fold
onnxscript.optimizer.optimize(
    m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
)
# Shared INT64 constant [1], used as the new second input of each Unsqueeze.
one_1d = ir.val("1d_one", const_value=ir.tensor([1], dtype=ir.DataType.INT64))
m.graph.initializers.add(one_1d)
# Pass 2: simplify Unsqueeze + Reshape — drop the Reshape and let the
# Unsqueeze produce the (batch, 160000, 1) tensor directly.
for node in m.graph:
    if node.op_type == "Reshape":
        producer = node.inputs[0].producer()
        if producer is not None and producer.op_type == "Unsqueeze":
            # Log only when we actually rewrite, not for every Reshape seen.
            print("Simplify Unsqueeze + Reshape:", node.name)
            producer.replace_input_with(1, one_1d)
            node.outputs[0].replace_all_uses_with(producer.outputs[0])
            producer.outputs[0].shape = ir.Shape(["batch", 160000, 1])
# Target shape for the model's first Reshape: (batch, 1, 160000, 1).
first_reshape_shape = ir.val(
    "first_reshape_shape", const_value=ir.tensor([-1, 1, 160000, 1])
)
m.graph.initializers.add(first_reshape_shape)
# Pass 3: simplify the first Reshape + Unsqueeze pair — fold the Unsqueeze
# into the Reshape by giving the Reshape the final 4-D target shape.
for node in m.graph:
    if node.op_type == "Unsqueeze":
        producer = node.inputs[0].producer()
        if producer is not None and producer.op_type == "Reshape":
            # Log only when the pattern actually matches.
            print("Simplify Reshape + Unsqueeze:", node.name)
            producer.replace_input_with(1, first_reshape_shape)
            node.outputs[0].replace_all_uses_with(producer.outputs[0])
            producer.outputs[0].shape = ir.Shape(["batch", 1, 160000, 1])
        # Only the first Unsqueeze (in graph order) is the targeted one.
        break
# Pass 4: fuse Conv + Sub into Conv by folding the subtrahend into the bias.
for node in m.graph:
    if node.op_type == "Conv":
        print("Check Conv for fusion:", node.name)
        conv_node = node
        conv_uses = list(conv_node.outputs[0].uses())
        # The rewrite is only valid when the Conv output has a single
        # consumer; raise (not assert) so the check survives `python -O`.
        if len(conv_uses) != 1:
            raise RuntimeError(
                f"Conv {conv_node.name} has {len(conv_uses)} consumers; expected 1"
            )
        for usage in conv_uses:
            if usage.node.op_type == "Sub":
                sub_node = usage.node
                print(" Fuse Sub into Conv:", sub_node.name)
                sub_value = sub_node.inputs[1]
                # Conv(x) - c  ==  Conv(x) with bias -c (plus any old bias).
                new_bias = (np.negative(sub_value.const_value.numpy())).reshape((-1,))
                if len(conv_node.inputs) > 2 and conv_node.inputs[2] is not None:
                    # The original code dropped an existing bias here; fold a
                    # constant bias in instead, and fail loudly otherwise.
                    old_bias = conv_node.inputs[2].const_value
                    if old_bias is None:
                        raise RuntimeError(
                            f"Conv {conv_node.name} has a non-constant bias; cannot fuse"
                        )
                    new_bias = new_bias + old_bias.numpy().reshape((-1,))
                new_bias_val = ir.val(
                    f"{sub_value.name}_neg",
                    const_value=ir.tensor(new_bias),
                )
                m.graph.initializers.add(new_bias_val)
                if len(conv_node.inputs) == 2:
                    # HACK: onnx_ir exposes no public API to append an optional
                    # input slot, so extend the private tuple directly.
                    conv_node._inputs = conv_node._inputs + (None,)
                conv_node.replace_input_with(2, new_bias_val)
                sub_node.outputs[0].replace_all_uses_with(conv_node.outputs[0])
# Clean up any nodes left dead by the fusion
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)
# Drop every intermediate shape so the optimizer re-infers them from scratch,
# then pin a symbolic "batch" leading dimension on the graph inputs/outputs.
for graph_node in m.graph:
    for value in graph_node.outputs:
        if not value.is_graph_output():
            value.shape = None
model_input = m.graph.inputs[0]
model_input.shape = ir.Shape(["batch", *model_input.shape[1:]])
for graph_output in m.graph.outputs:
    graph_output.shape = ir.Shape(["batch", *graph_output.shape[1:]])
# Shape re-inference happens as part of optimization.
onnxscript.optimizer.optimize(
    m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
)
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(m)
# Rename any anonymous symbolic dimension left by inference to "batch".
for graph_node in m.graph:
    for value in graph_node.outputs:
        if value.shape is None:
            continue
        new_shape = ir.Shape(value.shape)
        for axis in range(len(new_shape)):
            dim = new_shape[axis]
            if isinstance(dim, ir.SymbolicDim) and dim.value is None:
                new_shape[axis] = ir.SymbolicDim("batch")
        value.shape = new_shape
# Rename the graph I/O to match the tflite model's tensor names.
m.graph.inputs[0].name = "inputs"
for idx, io_name in enumerate(
    ["spatial_embedding", "embedding", "spectrogram", "label"]
):
    m.graph.outputs[idx].name = io_name
# Reorder so "embedding" is output 0 and "spatial_embedding" is output 1.
m.graph.outputs[0], m.graph.outputs[1] = m.graph.outputs[1], m.graph.outputs[0]
m.producer_name = "onnx-ir"
m.producer_version = None
m.ir_version = 10
ir.save(m, "perch_v2.onnx")