File size: 1,143 Bytes
485d186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
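"""Optimize the BirdNET data-model ONNX file.

The input file, birdnet_data_model_slim.onnx, is produced with:

    python -m tf2onnx.convert --opset 18 --tflite 'BirdNET_GLOBAL_6K_V2.4_MData_Model_FP16.tflite' --output birdnet_data_model.onnx
    onnxslim birdnet_data_model.onnx birdnet_data_model_slim.onnx

The optimized model is saved as birdnet_data_model.onnx, overwriting the
intermediate file written by the tf2onnx step above.
"""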

import onnxscript.optimizer
import onnx_ir as ir

def _bypass_node(graph: ir.Graph, keep_name: str, drop_name: str) -> None:
    """Make every consumer of ``drop_name``'s first output read ``keep_name``'s first output.

    Only the first output of each node is rewired. The bypassed node is left in
    the graph with no remaining consumers; presumably the optimizer pass run
    afterwards removes it as dead code.

    Args:
        graph: Graph containing both nodes.
        keep_name: Name of the node whose first output replaces the other.
        drop_name: Name of the node whose first output is bypassed.
    """
    keep = graph.node(keep_name)
    drop = graph.node(drop_name)
    drop.outputs[0].replace_all_uses_with(keep.outputs[0])


model = ir.load("birdnet_data_model_slim.onnx")

# Remove the "add-mul-0" pattern: bypass each SelectV2 node by wiring its
# consumers directly to the matching upstream "__NN" node. NOTE(review): the
# "__NN" nodes are presumably Cast nodes (the earlier variable names suggested
# so) -- confirm against the slimmed graph.
_BYPASSES = (
    # (node whose output is kept, node whose output is bypassed)
    ("BirdNET/MNET_CONVERT/SelectV2_1__50", "BirdNET/MNET_CONVERT/SelectV2_1"),
    ("BirdNET/MNET_CONVERT/SelectV2__60", "BirdNET/MNET_CONVERT/SelectV2"),
)
for keep_name, drop_name in _BYPASSES:
    _bypass_node(model.graph, keep_name, drop_name)

# Run the onnxscript optimizer; the bypassed nodes now have no consumers, so
# they should be dropped as unused.
onnxscript.optimizer.optimize(model)

# Normalize model metadata and give the graph I/O stable, readable names.
model.ir_version = 10
model.graph.name = "BirdNET-v2.4-Data_Model"
model.producer_name = "onnx-ir"
model.producer_version = None
model.graph.inputs[0].name = "input"
output = model.graph.outputs[0]
output.name = "output"

# Make the leading output dimension a symbolic "batch" dim. All other dims are
# kept as-is, so this works for any output rank >= 1 instead of assuming rank 2.
dims = list(output.shape) if output.shape is not None else []
if not dims:
    raise ValueError(
        "graph output has no known shape; cannot mark its batch dimension as dynamic"
    )
dims[0] = "batch"
output.shape = ir.Shape(dims)

# NOTE: this is the same file name the tf2onnx conversion step writes to, so
# that intermediate file is overwritten.
ir.save(model, "birdnet_data_model.onnx")