File size: 6,832 Bytes
1f1b5fd 650fcdf 1f1b5fd 650fcdf 1f1b5fd 650fcdf 1f1b5fd 650fcdf 1f1b5fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import onnxscript
import onnx_ir as ir
import onnx_ir.passes.common
import numpy as np
import onnxslim
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace the real part of a one-sided DFT with a MatMul by a cosine matrix.

    Matches ``Squeeze(Slice(DFT(Reshape(x), dft_length), [0], [1], [-1]), [-1])``,
    i.e. index 0 of the DFT's trailing (real, imaginary) axis. It replaces that
    subgraph with ``MatMul(x, C)``, where ``C[n, k] = cos(2*pi*k*n / N)`` has
    shape ``(N, N // 2 + 1)`` and ``N`` is the constant ``dft_length``.

    NOTE(review): the DFT ``axis`` is not checked. The rewrite assumes the
    transform runs over the last axis of the pre-Reshape ``x``; confirm this
    for the target model.
    """

    def pattern(self, op, x, dft_length):
        # The binding ``x`` is the Reshape *input*. The rewrite multiplies that
        # tensor directly and skips the Reshape.
        reshaped = op.Reshape(x, _allow_other_inputs=True)
        dft = op.DFT(reshaped, dft_length, _outputs=["dft_output"])
        # Keep only the real component of the trailing complex axis.
        real_part = op.Slice(dft, [0], [1], [-1])
        return op.Squeeze(real_part, [-1])

    def check(self, context, dft_length: ir.Value, dft_output: ir.Value, **_) -> bool:
        """Accept only forward, one-sided DFTs whose length is a constant scalar."""
        dft_node = dft_output.producer()
        if dft_node is None:
            return False
        # The cosine matrix is baked in at rewrite time, so the length must be
        # a known constant.
        length = ir.convenience.get_const_tensor(dft_length)
        if length is None or length.numpy().size != 1:
            return False
        # The replacement emits N // 2 + 1 bins with no 1/N scaling. That only
        # equals a forward (inverse=0), one-sided (onesided=1) DFT. In the ONNX
        # spec, both attributes default to 0.
        onesided = dft_node.attributes.get("onesided")
        inverse = dft_node.attributes.get("inverse")
        if onesided is None or onesided.value != 1:
            return False
        return inverse is None or inverse.value == 0

    @staticmethod
    def _real_dft_matrix(dft_size: int) -> np.ndarray:
        """Return the float32 ``(dft_size, dft_size // 2 + 1)`` cosine matrix."""
        n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]
        k = np.arange(dft_size // 2 + 1, dtype=np.int64)[np.newaxis, :]
        # Reduce k*n modulo N exactly in integers, then take the cosine in
        # float64. Forming 2*pi*k*n directly in float32 loses ~1e-4..1e-3 rad
        # of phase once k*n reaches the millions.
        phase_index = (n * k) % dft_size
        return np.cos(2.0 * np.pi * phase_index / dft_size).astype(np.float32)

    def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
        # check() guarantees dft_length is a constant scalar.
        dft_size = int(ir.convenience.get_const_tensor(dft_length).numpy().item())
        dft_matrix = op.initializer(
            ir.tensor(self._real_dft_matrix(dft_size)), name=f"{x.name}_dft_matrix"
        )
        # (..., N) @ (N, N // 2 + 1) -> (..., N // 2 + 1)
        return op.MatMul(x, dft_matrix)
class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace a two-output ``Split`` with ``Gather(x, [0])`` plus a constant.

    The first output becomes ``x[[0]]``. The second becomes the constant
    ``[SAMPLE_SIZE]``.

    NOTE(review): this presumably targets a Split of a 1-D shape tensor
    ``[batch, samples]``. The pattern matches *any* two-output ``Split``, so it
    relies on the model containing only that one; confirm.
    """

    # Static samples-per-example; must agree with the input shape the script
    # assigns (``["batch", 144000]``).
    SAMPLE_SIZE = 144000

    def pattern(self, op, x):
        return op.Split(
            x, _allow_other_inputs=True, _outputs=["split_out_1", "split_out_2"]
        )

    def rewrite(self, op, x: ir.Value, **kwargs):
        # Initializer names are derived from the matched input so that several
        # matches cannot register different initializers under one fixed name.
        zero = op.initializer(
            ir.tensor(np.array([0], dtype=np.int64)), name=f"{x.name}_zero"
        )
        batch_size = op.Gather(x, zero)
        # Kept as int32 as before. The script later promotes every int32
        # initializer to int64.
        sample_size = op.initializer(
            ir.tensor(np.array([self.SAMPLE_SIZE], dtype=np.int32)),
            name=f"{x.name}_sample_size",
        )
        return batch_size, sample_size
class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
    """Turn every matched ``Cast`` node into an ``Identity``.

    NOTE(review): the pattern specifies no ``to`` attribute, and the rewrite
    ignores it. This only preserves semantics where dropping the conversion
    does not break downstream dtype expectations; confirm for the target model.
    """

    def pattern(self, op, x):
        cast_out = op.Cast(x)
        return cast_out

    def rewrite(self, op, x: ir.Value, **kwargs):
        # Forward the Cast input unchanged.
        passthrough = op.Identity(x)
        return passthrough
class RemoveReversedSequenceFork(onnxscript.rewriter.RewriteRuleClassBase):
    """Collapse a Transpose/ReverseSequence fork into Slice-based reversal.

    Matched subgraph:

    - Each branch ``x`` / ``y``: Transpose -> ReverseSequence -> Unsqueeze.
    - Then Concat -> Mul(scale) -> Add(bias) -> Transpose.

    Replacement:

    1. Reverse ``x`` and ``y`` along their last axis with a negative-step Slice.
    2. Stack them on a new trailing axis.
    3. Apply the same scale/bias.
    4. Transpose with ``perm=[0, 3, 2, 1]``.

    NOTE(review): the pattern does not check the Transpose perms, the
    ReverseSequence axes or ``sequence_lens``, the Unsqueeze axes, or the
    Concat axis. The rewrite is only equivalent if ReverseSequence reverses the
    full last axis of ``x``/``y`` in this model; confirm.
    """

    def pattern(self, op, x, y, scale, bias):
        x = op.Transpose(x)
        y = op.Transpose(y)
        x = op.ReverseSequence(x, _allow_other_inputs=True)
        y = op.ReverseSequence(y, _allow_other_inputs=True)
        x = op.Unsqueeze(x, _allow_other_inputs=True)
        y = op.Unsqueeze(y, _allow_other_inputs=True)
        concat = op.Concat(x, y)
        mul = op.Mul(concat, scale)
        add = op.Add(mul, bias)
        return op.Transpose(add)

    def rewrite(self, op, x, y, scale, bias, **kwargs):
        # Per the original author's note: x is (batch, 511, 96).
        # Initializer names are derived from the matched input so that several
        # matches (e.g. one per MEL_SPEC branch) cannot collide on fixed names.
        prefix = x.name
        neg_one = op.initializer(
            ir.tensor(np.array([-1], dtype=np.int64)), name=f"{prefix}_neg_one"
        )
        int64_min = op.initializer(
            ir.tensor(np.array([np.iinfo(np.int64).min], dtype=np.int64)),
            name=f"{prefix}_int64_min",
        )
        # Slice(starts=-1, ends=INT64_MIN, axes=-1, steps=-1) reverses the last axis.
        x = op.Slice(x, neg_one, int64_min, neg_one, neg_one)
        y = op.Slice(y, neg_one, int64_min, neg_one, neg_one)
        # Add a trailing axis to each branch and stack the branches along it.
        x = op.Unsqueeze(x, neg_one)
        y = op.Unsqueeze(y, neg_one)
        concat = op.Concat(x, y, axis=3)  # per original note: batch, 511, 96, 2
        mul = op.Mul(concat, scale)
        add = op.Add(mul, bias)
        return op.Transpose(add, perm=[0, 3, 2, 1])  # batch, 2, 96, 511
model = ir.load("model.onnx")

# Make the batch dimension symbolic on the graph input and output.
model.graph.inputs[0].shape = ir.Shape(["batch", 144000])
model.graph.outputs[0].shape = ir.Shape(["batch", 6522])

# First rewrite pass: DFT -> MatMul, Split -> Gather + constant, Cast -> Identity.
first_pass_rules = [
    ReplaceDftWithMatMulRule().rule(),
    ReplaceSplit().rule(),
    RemoveCast().rule(),
]
onnxscript.rewriter.rewrite(model, first_pass_rules)

# Promote every int32 initializer to int64. Each one is swapped for a new
# same-named value, and all of its uses are rewired to it.
int32_initializers = [
    value
    for value in model.graph.initializers.values()
    if value.dtype == ir.DataType.INT32
]
for old_value in int32_initializers:
    widened = old_value.const_value.numpy().astype(np.int64)
    replacement = ir.val(old_value.name, const_value=ir.tensor(widened))
    model.graph.initializers.pop(old_value.name)
    model.graph.initializers.add(replacement)
    old_value.replace_all_uses_with(replacement)

_SIZE_LIMIT = 1024 * 1024 * 1024  # 1 GiB cap for constant folding
onnxscript.optimizer.optimize(
    model, input_size_limit=_SIZE_LIMIT, output_size_limit=_SIZE_LIMIT
)
# Remove Slice-Reshape
def remove_slice_reshape(model: ir.Model):
    """Rewire the STFT framing Reshapes and give them constant target shapes.

    Nodes are looked up by their exported names.

    - The two ``.../stft/frame/Reshape_1`` nodes (MEL_SPEC1 and MEL_SPEC2)
      get their data input (index 0) replaced with the first output of
      ``model/MEL_SPEC1/Mul``.
    - All four Reshapes get a new INT64 shape initializer (index 1) with a -1
      leading dimension.

    Per the original comments, this replaces a Slice-Reshape chain with
    Mul-Reshape-Gather. Nodes left unused are removed by a later optimizer
    pass. If a named node is missing, the error comes from ``Graph.node``.
    """
    graph = model.graph
    mul_node = graph.node("model/MEL_SPEC1/Mul")

    # (node name, initializer name, new target shape, rewire data input to Mul?)
    reshape_specs = (
        ("model/MEL_SPEC1/stft/frame/Reshape_1", "first_shape", [-1, 72000, 2], True),
        ("model/MEL_SPEC2/stft/frame/Reshape_1", "second_shape", [-1, 18000, 8], True),
        ("model/MEL_SPEC1/stft/frame/Reshape_4", "third_shape", [-1, 511, 2048], False),
        ("model/MEL_SPEC2/stft/frame/Reshape_4", "fourth_shape", [-1, 511, 1024], False),
    )

    # Look up each node and register its shape initializer, in table order.
    planned = []
    for node_name, shape_name, dims, feeds_from_mul in reshape_specs:
        reshape = graph.node(node_name)
        shape_value = ir.val(
            shape_name, const_value=ir.tensor(dims, dtype=ir.DataType.INT64)
        )
        graph.initializers.add(shape_value)
        planned.append((reshape, shape_value, feeds_from_mul))

    # Replace with Mul-Reshape-Gather
    for reshape, shape_value, feeds_from_mul in planned:
        if feeds_from_mul:
            reshape.replace_input_with(0, mul_node.outputs[0])
        reshape.replace_input_with(1, shape_value)
remove_slice_reshape(model)

# Optimize again (dead-code elimination) to drop the nodes orphaned by the rewiring.
onnxscript.optimizer.optimize(
    model, input_size_limit=1024**3, output_size_limit=1024**3
)


def _slim(m):
    """Round-trip *m* through onnxslim and return a fresh onnx_ir model."""
    return ir.from_proto(onnxslim.slim(ir.to_proto(m)))


print("Slimming model...")
model = _slim(model)

print("Removing reversed sequence fork...")
onnxscript.rewriter.rewrite(model, [RemoveReversedSequenceFork.rule()])

# Slim once more; the original intent here was to get onnxslim's shape inference.
model = _slim(model)

onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)

# Final I/O names and model metadata.
model.graph.inputs[0].name = "input"
model.graph.outputs[0].name = "output"
model.ir_version = 10
model.producer_name = "onnx-ir"
model.graph.name = "BirdNET-v2.4"

ir.save(model, "birdnet.onnx")
|