|
|
"""Optimize the data model ONNX file. |
|
|
|
|
|
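The input file, birdnet_data_model_slim.onnx, is produced with:

    python -m tf2onnx.convert --opset 18 --tflite 'BirdNET_GLOBAL_6K_V2.4_MData_Model_FP16.tflite' --output birdnet_data_model.onnx
    onnxslim birdnet_data_model.onnx birdnet_data_model_slim.onnx

This script bypasses two nodes of the BirdNET/MNET_CONVERT subgraph, runs the
onnxscript optimizer, normalizes the model metadata and I/O names, and writes
the result to birdnet_data_model.onnx. Note that this overwrites the raw
tf2onnx output produced by the first command above.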
|
|
""" |
|
|
|
|
|
import onnxscript.optimizer |
|
|
import onnx_ir as ir |
|
|
|
|
|
model = ir.load("birdnet_data_model_slim.onnx")

# Graph surgery. Each pair below is (replacement node, bypassed node).
# Every consumer of the bypassed node's first output is rewired to read the
# replacement node's first output instead. Pairs are processed in the order
# listed. The bypassed nodes are not deleted here. Once unused, they are
# presumably dropped by the onnxscript optimizer pass run later in this script.
# NOTE(review): the replacement nodes are assumed to be Cast ops and the
# bypassed nodes Add ops. Op types are not checked here; confirm against the
# slimmed graph.
_NODE_BYPASSES = (
    ("BirdNET/MNET_CONVERT/SelectV2_1__50", "BirdNET/MNET_CONVERT/SelectV2_1"),
    ("BirdNET/MNET_CONVERT/SelectV2__60", "BirdNET/MNET_CONVERT/SelectV2"),
)

for replacement_name, bypassed_name in _NODE_BYPASSES:
    replacement_node = model.graph.node(replacement_name)
    bypassed_node = model.graph.node(bypassed_name)
    bypassed_node.outputs[0].replace_all_uses_with(replacement_node.outputs[0])
|
|
|
|
|
# Run the onnxscript optimizer. For an ir.Model it currently works in place
# and returns the same object. Rebinding the result keeps this script correct
# if the API ever returns a new model instead.
model = onnxscript.optimizer.optimize(model)

# Normalize model-level metadata.
model.ir_version = 10
model.producer_name = "onnx-ir"
model.producer_version = None
model.graph.name = "BirdNET-v2.4-Data_Model"

# Give the (first) graph input and output stable, human-readable names.
graph_input = model.graph.inputs[0]
graph_input.name = "input"
graph_output = model.graph.outputs[0]
graph_output.name = "output"

# Make the leading dimension of the output symbolic ("batch"). Keep every
# other dimension as-is; previously only dim 1 was kept, which silently
# dropped any further dims.
# NOTE(review): only the output is touched; the input shape is left as the
# converter produced it. Confirm the input batch dim is dynamic as well.
output_shape = graph_output.shape
if output_shape is None or len(output_shape) == 0:
    raise ValueError(
        "graph output 'output' has no known non-scalar shape; "
        "cannot set a symbolic batch dimension"
    )
graph_output.shape = ir.Shape(["batch", *tuple(output_shape)[1:]])

ir.save(model, "birdnet_data_model.onnx")
|
|
|