Upload optimize_data_model.py
Browse files
scripts/optimize_data_model.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Optimize the data model ONNX file.
|
| 2 |
+
|
| 3 |
+
The birdnet_data_model_slim.onnx file is obtained with
|
| 4 |
+
|
| 5 |
+
python -m tf2onnx.convert --opset 18 --tflite 'BirdNET_GLOBAL_6K_V2.4_MData_Model_FP16.tflite' --output birdnet_data_model.onnx
|
| 6 |
+
onnxslim birdnet_data_model.onnx birdnet_data_model_slim.onnx
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import onnxscript.optimizer
|
| 10 |
+
import onnx_ir as ir
|
| 11 |
+
|
| 12 |
+
model = ir.load("birdnet_data_model_slim.onnx")
|
| 13 |
+
|
| 14 |
+
# Remove add-mul-0
|
| 15 |
+
|
| 16 |
+
cast_1 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2_1__50")
|
| 17 |
+
add_1 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2_1")
|
| 18 |
+
add_1.outputs[0].replace_all_uses_with(cast_1.outputs[0])
|
| 19 |
+
|
| 20 |
+
cast_2 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2__60")
|
| 21 |
+
add_2 = model.graph.node("BirdNET/MNET_CONVERT/SelectV2")
|
| 22 |
+
add_2.outputs[0].replace_all_uses_with(cast_2.outputs[0])
|
| 23 |
+
|
| 24 |
+
onnxscript.optimizer.optimize(model)
|
| 25 |
+
|
| 26 |
+
model.ir_version = 10
|
| 27 |
+
model.graph.name = "BirdNET-v2.4-Data_Model"
|
| 28 |
+
model.producer_name = "onnx-ir"
|
| 29 |
+
model.producer_version = None
|
| 30 |
+
model.graph.inputs[0].name = "input"
|
| 31 |
+
model.graph.outputs[0].name = "output"
|
| 32 |
+
model.graph.outputs[0].shape = ir.Shape(["batch", model.graph.outputs[0].shape[1]])
|
| 33 |
+
|
| 34 |
+
ir.save(model, "birdnet_data_model.onnx")
|