File size: 5,190 Bytes
7149468 884d307 7149468 884d307 1f2b9f9 884d307 1f2b9f9 884d307 1f2b9f9 884d307 7149468 1f2b9f9 1fc03e4 1f2b9f9 1fc03e4 1f2b9f9 1fc03e4 7149468 1f2b9f9 1fc03e4 1f2b9f9 7149468 1f2b9f9 1fc03e4 1f2b9f9 1fc03e4 1f2b9f9 1fc03e4 1f2b9f9 1fc03e4 1f2b9f9 1fc03e4 1f2b9f9 1fc03e4 1f2b9f9 7149468 1fc03e4 7149468 1f2b9f9 7149468 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import onnxscript
import onnx_ir as ir
import onnx_ir.passes.common
import numpy as np
def _one_sided_real_dft_matrix(dft_size: int) -> np.ndarray:
    """Return the real part of the one-sided DFT basis as a float32 matrix.

    Entry ``[n, k]`` is ``cos(2*pi*k*n / dft_size)`` for ``n`` in
    ``[0, dft_size)`` and ``k`` in ``[0, dft_size // 2]`` (DC to Nyquist), so
    ``x @ matrix`` equals ``Re(DFT(x))[..., : dft_size // 2 + 1]`` for a real
    ``x`` whose last axis has ``dft_size`` elements.

    Args:
        dft_size: DFT length ``N``.

    Returns:
        Array of shape ``(dft_size, dft_size // 2 + 1)`` and dtype float32.
    """
    num_freqs = dft_size // 2 + 1
    n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]  # (dft_size, 1)
    k = np.arange(num_freqs, dtype=np.int64)[np.newaxis, :]  # (1, num_freqs)
    # Reduce k*n modulo N exactly in integer arithmetic and evaluate the angle
    # in float64. Forming 2*pi*k*n directly in float32 loses precision once
    # k*n is large (about 1e-4 absolute cosine error at N=2048).
    phase_index = (n * k) % dft_size
    angles = 2.0 * np.pi * phase_index / dft_size
    return np.cos(angles).astype(np.float32)  # (dft_size, num_freqs)


class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace ``Squeeze(Slice(DFT(Reshape(x)), real part))`` by ``MatMul(x, C)``.

    The matched subgraph keeps only the real component of a DFT. For real
    input that component is ``x @ C`` with ``C`` a constant cosine matrix, so
    the whole subgraph collapses into a single MatMul against an initializer.

    NOTE(review): the DFT node's ``onesided`` / ``inverse`` / ``axis``
    attributes are not inspected. The rewrite assumes a forward, one-sided DFT
    over the last axis of ``x``, i.e. the matched Reshape only appends a
    trailing axis. It also assumes that axis has exactly ``dft_length``
    elements. Confirm these before reusing the rule on another model.
    """

    def pattern(self, op, x, dft_length):
        x = op.Reshape(x, _allow_other_inputs=True)
        dft = op.DFT(x, dft_length, _outputs=["dft_output"])
        # Keep index 0 of the trailing (real, imag) axis of the DFT output.
        real_part = op.Slice(dft, [0], [1], [-1])
        return op.Squeeze(real_part, [-1])

    def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
        """Emit ``MatMul(x, C)`` where ``C`` is the one-sided real DFT basis.

        Args:
            op: Rewriter builder.
            x: Input of the matched Reshape (the un-reshaped signal frames).
            dft_length: DFT length input; must be a constant.
            dft_output: Output of the matched DFT node (bound by the pattern,
                not needed here).

        Raises:
            ValueError: If ``dft_length`` is not a constant.
        """
        del dft_output  # Bound by the pattern; unused.
        length_tensor = ir.convenience.get_const_tensor(dft_length)
        if length_tensor is None:
            raise ValueError(
                f"DFT length {dft_length.name!r} is not a constant; "
                "cannot build a DFT matrix."
            )
        dft_size = int(length_tensor.numpy().item())
        dft_matrix = op.initializer(
            ir.tensor(_one_sided_real_dft_matrix(dft_size)),
            name=f"{x.name}_dft_matrix",
        )
        # The DFT axis is assumed to already be the last axis of x (see the
        # class docstring), so a plain matrix multiplication suffices.
        return op.MatMul(x, dft_matrix)
class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace a two-output ``Split`` by ``(Gather(x, [0]), [144000])``.

    The first output becomes element 0 of ``x`` (gathered with 1-D indices,
    so it keeps shape ``[1]``). The second becomes a constant INT32 ``[144000]``.

    NOTE(review): the pattern matches *any* Split with two outputs. It
    presumably targets a Split of the input's shape into (batch, samples);
    confirm against the model. 144000 must agree with the fixed input
    length the script assigns to graph input 0.
    """

    def pattern(self, op, x):
        return op.Split(
            x, _allow_other_inputs=True, _outputs=["split_out_1", "split_out_2"]
        )

    def rewrite(self, op, x: ir.Value, **kwargs):
        """Build the Gather / constant replacement for the two Split outputs.

        Args:
            op: Rewriter builder.
            x: Tensor that was being split.
            **kwargs: Remaining pattern bindings (``split_out_1``,
                ``split_out_2``); only used to name the new initializers.

        Returns:
            ``(batch_size, sample_size)`` replacing the two Split outputs.
        """
        # Derive initializer names from the matched Split output, so that a
        # second match (or an existing initializer named "zero") does not
        # create duplicate initializer names.
        prefix = getattr(kwargs.get("split_out_1"), "name", None) or x.name or "split"
        zero = op.initializer(
            ir.tensor(np.array([0], dtype=np.int64)), f"{prefix}_zero"
        )
        batch_size = op.Gather(x, zero)
        # NOTE(review): emitted as INT32; presumably reconciled with the rest
        # of the graph by the script's later int32 -> int64 initializer pass.
        sample_size = op.initializer(
            ir.tensor(np.array([144000], dtype=np.int32)), f"{prefix}_sample_size"
        )
        return batch_size, sample_size
class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
    """Turn every ``Cast`` node into an ``Identity`` of its input.

    NOTE(review): this matches all Casts regardless of the target dtype, so
    genuine dtype conversions are discarded too. It is only safe if the
    model's casts are between types the downstream ops accept
    interchangeably; confirm for any new model.
    """

    def pattern(self, op, x):
        cast_result = op.Cast(x)
        return cast_result

    def rewrite(self, op, x: ir.Value, **kwargs):
        # Forward the uncast input; the Identity is presumably removed by the
        # optimizer pass that runs after the rewrite.
        passthrough = op.Identity(x)
        return passthrough
# Raise the optimizer's input/output size limits to 1 GiB.
_OPTIMIZER_SIZE_LIMIT = 1 << 30


def _widen_int32_initializers(graph: ir.Graph) -> None:
    """Replace every INT32 initializer of *graph* with an INT64 copy.

    Each replacement keeps the original name and takes over all uses of the
    value it replaces.
    """
    # Snapshot the values first: the loop body mutates graph.initializers.
    for old in tuple(graph.initializers.values()):
        if old.dtype != ir.DataType.INT32:
            continue
        widened = old.const_value.numpy().astype(np.int64)
        replacement = ir.val(old.name, const_value=ir.tensor(widened))
        graph.initializers.pop(old.name)
        graph.initializers.add(replacement)
        old.replace_all_uses_with(replacement)


model = ir.load("model.onnx")

# Make the batch dimension symbolic; the sample / feature dims stay fixed.
model.graph.inputs[0].shape = ir.Shape(["batch", 144000])
model.graph.outputs[0].shape = ir.Shape(["batch", 6522])

onnxscript.rewriter.rewrite(
    model,
    [ReplaceDftWithMatMulRule().rule(), ReplaceSplit().rule(), RemoveCast().rule()],
)

_widen_int32_initializers(model.graph)

onnxscript.optimizer.optimize(
    model,
    input_size_limit=_OPTIMIZER_SIZE_LIMIT,
    output_size_limit=_OPTIMIZER_SIZE_LIMIT,
)
def remove_slice_reshape(model: ir.Model):
    """Feed the STFT framing Reshapes of both MEL_SPEC branches with constants.

    * ``Reshape_1`` of each branch is rewired to read directly from the
      output of ``model/MEL_SPEC1/Mul``, with a constant INT64 target shape.
    * ``Reshape_4`` of each branch keeps its data input but gets a constant
      INT64 target shape.

    The four shapes are registered as initializers named ``first_shape`` …
    ``fourth_shape``. Nodes that previously fed the replaced inputs are not
    removed here; presumably a subsequent optimizer pass drops them.

    Raises whatever ``model.graph.node`` raises if a named node is missing.
    """
    graph = model.graph
    mul_output = graph.node("model/MEL_SPEC1/Mul").outputs[0]

    # (reshape node, initializer name, target shape, also rewire data input?)
    # 72000 * 2 and 18000 * 8 both cover 144000 samples per batch item.
    plan = (
        ("model/MEL_SPEC1/stft/frame/Reshape_1", "first_shape", [-1, 72000, 2], True),
        ("model/MEL_SPEC2/stft/frame/Reshape_1", "second_shape", [-1, 18000, 8], True),
        ("model/MEL_SPEC1/stft/frame/Reshape_4", "third_shape", [-1, 511, 2048], False),
        ("model/MEL_SPEC2/stft/frame/Reshape_4", "fourth_shape", [-1, 511, 1024], False),
    )

    # Look up every node and register every shape before touching any edge.
    pending = []
    for node_name, shape_name, dims, takes_mul in plan:
        reshape = graph.node(node_name)
        shape_value = ir.val(
            shape_name, const_value=ir.tensor(dims, dtype=ir.DataType.INT64)
        )
        graph.initializers.add(shape_value)
        pending.append((reshape, shape_value, takes_mul))

    for reshape, shape_value, takes_mul in pending:
        if takes_mul:
            reshape.replace_input_with(0, mul_output)
        reshape.replace_input_with(1, shape_value)
remove_slice_reshape(model)

# Second optimizer pass (the original note calls it DCE) to clean up what the
# manual rewiring above left disconnected. The limits are 1 GiB, as before.
onnxscript.optimizer.optimize(model, input_size_limit=2**30, output_size_limit=2**30)

# Strip metadata / doc strings, then give the graph stable public names.
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
io_graph = model.graph
io_graph.inputs[0].name = "input"
io_graph.outputs[0].name = "output"

model.ir_version = 10
model.producer_name = "onnx-ir"
io_graph.name = "BirdNET-v2.4"

ir.save(model, "birdnet.onnx")
|