# Perch-onnx: scripts/convert_dft_to_matmul.py
# Author: justinchuby, commit 1efbb9a (verified): "Create scripts to simplify and benchmark model"
"""Convert DFT operations in an ONNX model to equivalent MatMul operations."""
import argparse
import sys

import numpy as np
import onnx_ir as ir
import onnxscript
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace ``Reshape -> DFT`` on a real signal with a single MatMul.

    A forward, one-sided DFT of a real signal of length N is linear, so it equals
    ``x @ W`` where ``W`` has shape (N, 2 * (N // 2 + 1)). The columns of ``W``
    interleave cosine and negative-sine terms: [re_0, im_0, re_1, im_1, ...].
    Reshaping the MatMul result to ``[..., N // 2 + 1, 2]`` yields the
    (real, imag) pairs in the last axis, which is DFT's output layout.

    ``check`` accepts a match only when that identity provably does not break:
    constant length, forward, one-sided, float32, no implicit padding or
    truncation, and a transform over the axis just before the trailing 1.
    Unknown shapes and dtypes are accepted optimistically. Only provable
    mismatches are rejected.
    """

    def pattern(self, op, x, dft_length):
        # Only the Reshape's data input is bound (as ``x``); its shape input is
        # ignored. The rewrite consumes ``x`` itself, i.e. the tensor *before*
        # the Reshape. ``check`` verifies, when shapes are known, that this
        # Reshape only appends a trailing axis of size 1.
        x = op.Reshape(x, _allow_other_inputs=True)
        dft = op.DFT(x, dft_length, _outputs=["dft_output"])
        return dft

    def check(self, context, x, dft_length, dft_output, **_):
        """Return True only if the MatMul rewrite is equivalent to the matched DFT."""
        del context  # unused
        dft_size = self._const_dft_size(dft_length)
        dft_node = dft_output.producer()
        if dft_size is None or dft_node is None:
            return False
        # The constant matrix implements only the forward, one-sided transform.
        if self._int_attr(dft_node, "inverse", 0) != 0:
            return False
        if self._int_attr(dft_node, "onesided", 0) != 1:
            return False
        # The matrix is float32, and MatMul requires both operands to share a dtype.
        if x.dtype is not None and x.dtype != ir.DataType.FLOAT:
            return False

        x_shape = x.shape
        if x_shape is not None:
            if len(x_shape) == 0:
                return False
            # MatMul contracts x's last axis, so it must be exactly N long.
            # DFT would zero-pad or truncate other lengths; MatMul cannot.
            if isinstance(x_shape[-1], int) and x_shape[-1] != dft_size:
                return False

        signal_shape = dft_node.inputs[0].shape  # output of the matched Reshape
        if signal_shape is not None:
            rank = len(signal_shape)
            # A real-valued DFT input is laid out as [..., N, 1] and must be
            # transformed over axis rank - 2.
            if rank < 2:
                return False
            if isinstance(signal_shape[-1], int) and signal_shape[-1] != 1:
                return False
            if self._dft_axis(dft_node) % rank != rank - 2:
                return False
            if x_shape is not None:
                # The Reshape must do nothing except append a trailing 1 to x.
                if len(x_shape) != rank - 1:
                    return False
                for x_dim, s_dim in zip(x_shape, signal_shape):
                    if isinstance(x_dim, int) and isinstance(s_dim, int) and x_dim != s_dim:
                        return False
        return True

    def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value, **_):
        """Emit ``Reshape(MatMul(x, W), [..., N // 2 + 1, 2])`` for the matched DFT."""
        dft_size = self._const_dft_size(dft_length)
        assert dft_size is not None  # guaranteed by check()
        num_freqs = dft_size // 2 + 1

        # Name initializers after the DFT output, which is unique per matched
        # node. This avoids collisions if one x feeds more than one DFT.
        prefix = dft_output.name or x.name
        dft_matrix = op.initializer(
            ir.tensor(self._onesided_dft_matrix(dft_size)),
            name=f"{prefix}_dft_matrix",
        )
        matmul_result = op.MatMul(x, dft_matrix)  # [..., 2 * num_freqs]

        tail = [num_freqs, 2]
        if x.shape is not None:
            # Reshape treats 0 as "copy this dim from the input" (allowzero=0).
            # That preserves every leading batch/frame dim without hard-coding it.
            shape_values = [0] * (len(x.shape) - 1) + tail
            new_shape = op.initializer(
                ir.tensor(np.array(shape_values, dtype=np.int64)),
                name=f"{prefix}_dft_reshaped_shape",
            )
        else:
            # Rank unknown: build the target shape at runtime from the MatMul output.
            tail_shape = op.initializer(
                ir.tensor(np.array(tail, dtype=np.int64)),
                name=f"{prefix}_dft_tail_shape",
            )
            new_shape = op.Concat(op.Shape(matmul_result, end=-1), tail_shape, axis=0)
        return op.Reshape(matmul_result, new_shape)

    @staticmethod
    def _onesided_dft_matrix(dft_size):
        """Return the float32 (N, 2 * (N // 2 + 1)) one-sided DFT matrix.

        Column 2k holds cos(2*pi*k*n/N) and column 2k+1 holds -sin(2*pi*k*n/N).
        """
        num_freqs = dft_size // 2 + 1
        n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]  # (N, 1)
        k = np.arange(num_freqs, dtype=np.int64)[np.newaxis, :]  # (1, num_freqs)
        # Reduce k*n modulo N exactly in integers and compute the angle in
        # float64. A float32 product 2*pi*k*n/N loses ~1e-4 absolute accuracy
        # for large k*n.
        angle = (2.0 * np.pi / dft_size) * ((k * n) % dft_size)
        matrix = np.stack([np.cos(angle), -np.sin(angle)], axis=-1)
        return matrix.reshape(dft_size, 2 * num_freqs).astype(np.float32)

    @staticmethod
    def _const_dft_size(dft_length):
        """Return the constant, positive scalar DFT length, or None if unavailable."""
        if dft_length is None:
            return None
        tensor = ir.convenience.get_const_tensor(dft_length)
        if tensor is None:
            return None
        values = tensor.numpy()
        if values.size != 1:
            return None
        size = int(values.item())
        return size if size > 0 else None

    @staticmethod
    def _int_attr(node, name, default):
        """Return integer attribute ``name`` of ``node``, or ``default`` if absent."""
        attr = node.attributes.get(name)
        return default if attr is None else int(attr.value)

    @staticmethod
    def _dft_axis(dft_node):
        """Return the DFT transform axis for the matched two-input DFT node."""
        attr = dft_node.attributes.get("axis")
        if attr is not None:
            # Opsets 17-19: axis is an attribute.
            return int(attr.value)
        graph = dft_node.graph
        opset = graph.opset_imports.get("") if graph is not None else None
        if opset is not None and opset < 20:
            return 1  # opsets 17-19 default
        # Opset 20+: axis is an optional third input. The pattern matched the
        # two-input form, so the default of -2 applies.
        return -2
def main(argv=None):
    """Rewrite matching DFT nodes of an ONNX model into MatMul and save the result.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
            Positional ``input`` and ``output`` paths default to
            ``perch_v2.onnx`` and ``perch_v2_no_dft.onnx``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "input", nargs="?", default="perch_v2.onnx", help="Input ONNX model path."
    )
    parser.add_argument(
        "output",
        nargs="?",
        default="perch_v2_no_dft.onnx",
        help="Output ONNX model path.",
    )
    args = parser.parse_args(argv)

    model = ir.load(args.input)
    # Both calls below modify the ir.Model in place.
    onnxscript.rewriter.rewrite(
        model,
        [ReplaceDftWithMatMulRule().rule()],
    )
    onnxscript.optimizer.optimize(model)

    # The rule's check() rejects DFTs it cannot prove equivalent; surface them
    # (top-level graph only) instead of silently leaving them in place.
    remaining = sum(node.op_type == "DFT" for node in model.graph)
    if remaining:
        print(f"warning: {remaining} DFT node(s) were not rewritten", file=sys.stderr)

    ir.save(model, args.output)


if __name__ == "__main__":
    main()