Justin Chu commited on
Commit
1fc03e4
·
1 Parent(s): 1f2b9f9

Create optimized models

Browse files
BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55f3e4055b1a13bfa9a2452731d0d34f6a02d6b775a334362665892794165e4c
3
+ size 51726412
birdnet.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:920e2bd05aa6265e68e21711bc3bab5900cf7c4117f7d958723632ef75956295
3
+ size 66935346
scripts/optimize.py CHANGED
@@ -43,12 +43,16 @@ class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
43
 
44
  class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
45
  def pattern(self, op, x):
46
- return op.Split(x, _allow_other_inputs=True, _outputs=["split_out_1", "split_out_2"])
47
-
 
 
48
  def rewrite(self, op, x: ir.Value, **kwargs):
49
  zero = op.initializer(ir.tensor(np.array([0], dtype=np.int64)), "zero")
50
  batch_size = op.Gather(x, zero)
51
- sample_size = op.initializer(ir.tensor(np.array([144000], dtype=np.int32)), "sample_size")
 
 
52
  return batch_size, sample_size
53
 
54
 
@@ -59,13 +63,17 @@ class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
59
  def rewrite(self, op, x: ir.Value, **kwargs):
60
  return op.Identity(x)
61
 
 
62
  model = ir.load("model.onnx")
63
 
64
  # Set dynamic axes
65
  model.graph.inputs[0].shape = ir.Shape(["batch", 144000])
66
  model.graph.outputs[0].shape = ir.Shape(["batch", 6522])
67
 
68
- onnxscript.rewriter.rewrite(model, [ReplaceDftWithMatMulRule().rule(), ReplaceSplit().rule(), RemoveCast().rule()])
 
 
 
69
 
70
  # Change all int32 initializers to int64
71
  initializers = list(model.graph.initializers.values())
@@ -82,22 +90,39 @@ onnxscript.optimizer.optimize(
82
  model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
83
  )
84
 
 
85
  # Remove Slice-Reshape
86
  def remove_slice_reshape(model: ir.Model):
87
  mul_node = model.graph.node("model/MEL_SPEC1/Mul")
88
  first_reshape = model.graph.node("model/MEL_SPEC1/stft/frame/Reshape_1")
89
- first_shape = ir.val("first_shape", const_value=ir.tensor([-1, 72000, 2], dtype=ir.DataType.INT64))
 
 
90
  model.graph.initializers.add(first_shape)
91
  second_reshape = model.graph.node("model/MEL_SPEC2/stft/frame/Reshape_1")
92
- second_shape = ir.val("second_shape", const_value=ir.tensor([-1, 18000, 8], dtype=ir.DataType.INT64))
 
 
93
  model.graph.initializers.add(second_shape)
94
 
 
 
 
 
 
 
 
 
 
 
 
95
  # Replace with Mul-Reshape-Gather
96
  first_reshape.replace_input_with(0, mul_node.outputs[0])
97
  first_reshape.replace_input_with(1, first_shape)
98
- second_reshape.replace_input_with(0, first_reshape.outputs[0])
99
  second_reshape.replace_input_with(1, second_shape)
100
-
 
101
 
102
 
103
  remove_slice_reshape(model)
@@ -108,6 +133,8 @@ onnxscript.optimizer.optimize(
108
 
109
 
110
  onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
 
 
111
  model.ir_version = 10
112
  model.producer_name = "onnx-ir"
113
  model.graph.name = "BirdNET-v2.4"
 
43
 
44
  class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
45
  def pattern(self, op, x):
46
+ return op.Split(
47
+ x, _allow_other_inputs=True, _outputs=["split_out_1", "split_out_2"]
48
+ )
49
+
50
  def rewrite(self, op, x: ir.Value, **kwargs):
51
  zero = op.initializer(ir.tensor(np.array([0], dtype=np.int64)), "zero")
52
  batch_size = op.Gather(x, zero)
53
+ sample_size = op.initializer(
54
+ ir.tensor(np.array([144000], dtype=np.int32)), "sample_size"
55
+ )
56
  return batch_size, sample_size
57
 
58
 
 
63
  def rewrite(self, op, x: ir.Value, **kwargs):
64
  return op.Identity(x)
65
 
66
+
67
  model = ir.load("model.onnx")
68
 
69
  # Set dynamic axes
70
  model.graph.inputs[0].shape = ir.Shape(["batch", 144000])
71
  model.graph.outputs[0].shape = ir.Shape(["batch", 6522])
72
 
73
+ onnxscript.rewriter.rewrite(
74
+ model,
75
+ [ReplaceDftWithMatMulRule().rule(), ReplaceSplit().rule(), RemoveCast().rule()],
76
+ )
77
 
78
  # Change all int32 initializers to int64
79
  initializers = list(model.graph.initializers.values())
 
90
  model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
91
  )
92
 
93
+
94
  # Remove Slice-Reshape
95
  def remove_slice_reshape(model: ir.Model):
96
  mul_node = model.graph.node("model/MEL_SPEC1/Mul")
97
  first_reshape = model.graph.node("model/MEL_SPEC1/stft/frame/Reshape_1")
98
+ first_shape = ir.val(
99
+ "first_shape", const_value=ir.tensor([-1, 72000, 2], dtype=ir.DataType.INT64)
100
+ )
101
  model.graph.initializers.add(first_shape)
102
  second_reshape = model.graph.node("model/MEL_SPEC2/stft/frame/Reshape_1")
103
+ second_shape = ir.val(
104
+ "second_shape", const_value=ir.tensor([-1, 18000, 8], dtype=ir.DataType.INT64)
105
+ )
106
  model.graph.initializers.add(second_shape)
107
 
108
+ third_reshape = model.graph.node("model/MEL_SPEC1/stft/frame/Reshape_4")
109
+ third_shape = ir.val(
110
+ "third_shape", const_value=ir.tensor([-1, 511, 2048], dtype=ir.DataType.INT64)
111
+ )
112
+ model.graph.initializers.add(third_shape)
113
+ fourth_reshape = model.graph.node("model/MEL_SPEC2/stft/frame/Reshape_4")
114
+ fourth_shape = ir.val(
115
+ "fourth_shape", const_value=ir.tensor([-1, 511, 1024], dtype=ir.DataType.INT64)
116
+ )
117
+ model.graph.initializers.add(fourth_shape)
118
+
119
  # Replace with Mul-Reshape-Gather
120
  first_reshape.replace_input_with(0, mul_node.outputs[0])
121
  first_reshape.replace_input_with(1, first_shape)
122
+ second_reshape.replace_input_with(0, mul_node.outputs[0])
123
  second_reshape.replace_input_with(1, second_shape)
124
+ third_reshape.replace_input_with(1, third_shape)
125
+ fourth_reshape.replace_input_with(1, fourth_shape)
126
 
127
 
128
  remove_slice_reshape(model)
 
133
 
134
 
135
  onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
136
+ model.graph.inputs[0].name = "input"
137
+ model.graph.outputs[0].name = "output"
138
  model.ir_version = 10
139
  model.producer_name = "onnx-ir"
140
  model.graph.name = "BirdNET-v2.4"