justinchuby commited on
Commit
485d186
·
verified ·
1 Parent(s): 77cd739

Upload optimize_data_model.py

Browse files
Files changed (1) hide show
  1. scripts/optimize_data_model.py +34 -0
scripts/optimize_data_model.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Optimize the data model ONNX file.
2
+
3
+ The birdnet_data_model_slim.onnx file is obtained with
4
+
5
+ python -m tf2onnx.convert --opset 18 --tflite 'BirdNET_GLOBAL_6K_V2.4_MData_Model_FP16.tflite' --output birdnet_data_model.onnx
6
+ onnxslim birdnet_data_model.onnx birdnet_data_model_slim.onnx
7
+ """
8
+
9
+ import onnxscript.optimizer
10
+ import onnx_ir as ir
11
+
12
+ model = ir.load("birdnet_data_model_slim.onnx")
13
+
14
+ # Remove add-mul-0
15
+
16
+ cast_1 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2_1__50")
17
+ add_1 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2_1")
18
+ add_1.outputs[0].replace_all_uses_with(cast_1.outputs[0])
19
+
20
+ cast_2 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2__60")
21
+ add_2 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2")
22
+ add_2.outputs[0].replace_all_uses_with(cast_2.outputs[0])
23
+
24
+ onnxscript.optimizer.optimize(model)
25
+
26
+ model.ir_version = 10
27
+ model.graph.name = "BirdNET-v2.4-Data_Model"
28
+ model.producer_name = "onnx-ir"
29
+ model.producer_version = None
30
+ model.graph.inputs[0].name = "input"
31
+ model.graph.outputs[0].name = "output"
32
+ model.graph.outputs[0].shape = ir.Shape(["batch", model.graph.outputs[0].shape[1]])
33
+
34
+ ir.save(model, "birdnet_data_model.onnx")