justinchuby commited on
Commit
52e82ec
·
verified ·
1 Parent(s): 3217423

Create scripts/cleanup.py

Browse files
Files changed (1) hide show
  1. scripts/cleanup.py +157 -0
scripts/cleanup.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Cleanup and optimize perch_v2_slim.onnx model.

This script can be applied after completing these steps:

1. Use `tf2onnx` to convert the tflite model to onnx
2. Apply onnxslim and onnxscript.optimizer.optimize on the model
3. Manually edit the model to remove the first DFT node (no-op) and fuse
   the nodes that effectively takes the magnitude of the DFT output with ReduceL2.
"""

import numpy as np

import onnx_ir as ir
import onnx_ir.passes.common
import onnxscript
import onnxscript.optimizer  # explicit: onnxscript.optimizer.optimize is called below

m = ir.load("perch_v2_slim.onnx")

# Pass 1: simplify MatMul surrounded by Reshapes, and bypass Expand nodes.
for node in m.graph:
    if node.op_type == "MatMul":
        print("Simplify MatMul + Reshape:", node.name)
        if node.inputs[0].producer().op_type == "Reshape":
            # Skip the reshape: feed the MatMul directly from the
            # Reshape's own input. (Renamed from `input` to avoid
            # shadowing the builtin.)
            matmul_input = node.inputs[0].producer().inputs[0]
            node.replace_input_with(0, matmul_input)

        for usage in node.outputs[0].uses():
            if usage.node.op_type == "Reshape":
                reshape_usages = list(usage.node.outputs[0].uses())
                # Keep the last Reshape (the one feeding ReduceMax) but
                # rewrite its target shape for the now-unreshaped MatMul
                # output. NOTE(review): 14795 is model-specific — confirm
                # against the slim model if it is regenerated.
                if reshape_usages[0].node.op_type == "ReduceMax":
                    shape = ir.val(
                        "reshape_shape", const_value=ir.tensor([-1, 16, 4, 14795, 4])
                    )
                    m.graph.initializers.add(shape)
                    usage.node.replace_input_with(1, shape)
                    continue
                # Any other Reshape consumer is bypassed entirely.
                reshape_node = usage.node
                reshape_output = reshape_node.outputs[0]
                reshape_output.replace_all_uses_with(node.outputs[0])

    # Remove Expand by wiring its consumers to its input.
    # NOTE(review): assumes every Expand in this model is a no-op
    # (broadcasting makes it redundant) — confirm on the actual graph.
    if node.op_type == "Expand":
        print("Remove Expand:", node.name)
        expand_input = node.inputs[0]
        expand_output = node.outputs[0]
        expand_output.replace_all_uses_with(expand_input)

# Clean up any unused nodes
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)
# Do some const folding; limits are raised so large initializers still fold.
onnxscript.optimizer.optimize(
    m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
)

# Shared INT64 [1] constant, used as the Unsqueeze axes input below.
one_1d = ir.val("1d_one", const_value=ir.tensor([1], dtype=ir.DataType.INT64))
m.graph.initializers.add(one_1d)

# Pass 2: simplify Unsqueeze + Reshape — keep the Unsqueeze, drop the Reshape.
for node in m.graph:
    if node.op_type == "Reshape":
        print("Simplify Unsqueeze + Reshape:", node.name)
        if (
            node.inputs[0].producer()
            and node.inputs[0].producer().op_type == "Unsqueeze"
        ):
            unsqueeze_node = node.inputs[0].producer()
            unsqueeze_node.replace_input_with(1, one_1d)
            node.outputs[0].replace_all_uses_with(unsqueeze_node.outputs[0])
            # 160000 samples = model-specific audio window length.
            unsqueeze_node.outputs[0].shape = ir.Shape(["batch", 160000, 1])

first_reshape_shape = ir.val(
    "first_reshape_shape", const_value=ir.tensor([-1, 1, 160000, 1])
)
m.graph.initializers.add(first_reshape_shape)

# Pass 3: simplify the first Reshape + Unsqueeze — keep the Reshape,
# drop the Unsqueeze. Only the first occurrence is rewritten (break).
for node in m.graph:
    if node.op_type == "Unsqueeze":
        print("Simplify Reshape + Unsqueeze:", node.name)
        if node.inputs[0].producer() and node.inputs[0].producer().op_type == "Reshape":
            reshape_node = node.inputs[0].producer()
            reshape_node.replace_input_with(1, first_reshape_shape)
            node.outputs[0].replace_all_uses_with(reshape_node.outputs[0])
            reshape_node.outputs[0].shape = ir.Shape(["batch", 1, 160000, 1])
            break

# Pass 4: fuse Conv + Sub into Conv — a following `y - c` becomes a Conv
# bias of `-c` (valid because Conv adds its bias after the convolution).
for node in m.graph:
    if node.op_type == "Conv":
        print("Check Conv for fusion:", node.name)
        conv_node = node
        # Fusion is only valid when the Conv output feeds exactly one node.
        assert len(conv_node.outputs[0].uses()) == 1
        for usage in conv_node.outputs[0].uses():
            if usage.node.op_type == "Sub":
                sub_node = usage.node
                print(" Fuse Sub into Conv:", sub_node.name)
                sub_value = sub_node.inputs[1]
                # Bias must be rank-1 per the ONNX Conv spec; negate the
                # subtrahend and flatten it.
                new_bias = (np.negative(sub_value.const_value.numpy())).reshape((-1,))
                new_bias_val = ir.val(
                    f"{sub_value.name}_neg",
                    const_value=ir.tensor(new_bias),
                )
                m.graph.initializers.add(new_bias_val)
                if len(conv_node.inputs) == 2:
                    # HACK: Conv has no bias slot yet; extend the private
                    # inputs tuple so replace_input_with(2, ...) works.
                    # TODO(review): switch to a public onnx_ir API for
                    # appending an optional input if/when one exists.
                    conv_node._inputs = conv_node._inputs + (None,)
                conv_node.replace_input_with(2, new_bias_val)
                sub_node.outputs[0].replace_all_uses_with(conv_node.outputs[0])

# Clean up any unused nodes
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)

# Clear all intermediate shapes and re-infer shapes (graph outputs keep theirs).
for node in m.graph:
    for output in node.outputs:
        if output.is_graph_output():
            continue
        output.shape = None

# Make the leading batch dimension symbolic on the graph IO.
m.graph.inputs[0].shape = ir.Shape(["batch", *m.graph.inputs[0].shape[1:]])
for output in m.graph.outputs:
    output.shape = ir.Shape(["batch", *output.shape[1:]])

# Re-run the optimizer so shape inference repopulates intermediate shapes.
onnxscript.optimizer.optimize(
    m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
)

onnx_ir.passes.common.ClearMetadataAndDocStringPass()(m)

# Replace any anonymous (None) symbolic dim with "batch".
for node in m.graph:
    for output in node.outputs:
        if output.shape is None:
            continue
        shape = ir.Shape(output.shape)
        for i in range(len(shape)):
            dim = shape[i]
            if isinstance(dim, ir.SymbolicDim) and dim.value is None:
                shape[i] = ir.SymbolicDim("batch")
        output.shape = shape

# Rename IO and match the tflite model
m.graph.inputs[0].name = "inputs"
m.graph.outputs[0].name = "spatial_embedding"
m.graph.outputs[1].name = "embedding"
m.graph.outputs[2].name = "spectrogram"
m.graph.outputs[3].name = "label"

# Swap the first two outputs so "embedding" comes first, matching the
# tflite model's output order.
out_0 = m.graph.outputs[0]
out_1 = m.graph.outputs[1]
m.graph.outputs[1] = out_0
m.graph.outputs[0] = out_1

m.producer_name = "onnx-ir"
m.producer_version = None
m.ir_version = 10

ir.save(m, "perch_v2.onnx")