|
|
"""Convert DFT operations in an ONNX model to equivalent MatMul operations.""" |
|
|
|
|
|
import onnxscript |
|
|
import onnx_ir as ir |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace a real-input, one-sided forward ``DFT`` with a ``MatMul``.

    Matched subgraph::

        signal = Reshape(x, <shape>)
        dft_output = DFT(signal, dft_length)

    Replacement::

        dft_output = Reshape(MatMul(x, W), [..., num_freqs, 2])

    ``W`` is a constant ``(dft_size, 2 * num_freqs)`` float32 matrix with
    ``num_freqs = dft_size // 2 + 1``. Its interleaved columns hold the real
    (``cos``) and imaginary (``-sin``) parts of the DFT basis. The trailing
    ``[num_freqs, 2]`` axes therefore line up with DFT's ``[..., 2]``
    real/imaginary output layout.

    The ``Reshape`` feeding the DFT is bypassed and ``x`` is multiplied
    directly. This assumes the reshape only appends the trailing size-1
    "real" axis that DFT expects. In other words, the last axis of ``x`` is
    the signal axis, and its length equals ``dft_length``. :meth:`check`
    rejects matches that visibly violate this where static shapes are
    available. It also rejects inverse and two-sided transforms, which
    :meth:`rewrite` does not emit.
    """

    def pattern(self, op, x, dft_length):
        """Match ``DFT(Reshape(x, ...), dft_length)`` and bind its output."""
        signal = op.Reshape(x, _allow_other_inputs=True)
        return op.DFT(signal, dft_length, _outputs=["dft_output"])

    def check(
        self,
        context,
        x: ir.Value,
        dft_length: ir.Value,
        dft_output: ir.Value,
        **_,
    ) -> bool:
        """Return whether :meth:`rewrite` reproduces this match exactly.

        The match is rejected when:

        - the DFT length is not a positive scalar constant;
        - the transform is inverse or two-sided;
        - statically known shapes contradict the class's assumptions (complex
          input, a signal axis other than the one before the real/imag axis,
          or a last axis of ``x`` that differs from ``dft_length``).

        Unknown (symbolic or missing) shapes are accepted.
        """
        del context  # unused
        dft_node = dft_output.producer()
        if dft_node is None:
            return False

        # The basis matrix is built at rewrite time, so the length must be a
        # compile-time scalar constant.
        length_tensor = ir.convenience.get_const_tensor(dft_length)
        if length_tensor is None:
            return False
        length_values = length_tensor.numpy()
        if length_values.size != 1:
            return False
        dft_size = int(length_values.item())
        if dft_size <= 0:
            return False

        # rewrite() only builds the forward, one-sided basis. The ONNX
        # defaults are inverse=0 and onesided=0, so onesided must be set to 1
        # explicitly.
        attributes = dft_node.attributes
        inverse = attributes.get("inverse")
        if inverse is not None and inverse.value != 0:
            return False
        onesided = attributes.get("onesided")
        if onesided is None or onesided.value != 1:
            return False

        # The DFT's input must end in the size-1 "real" axis. If an explicit
        # axis attribute is present (opset-17 form), it must point at the
        # axis just before that one.
        signal_shape = dft_node.inputs[0].shape
        if signal_shape is not None:
            rank = len(signal_shape)
            if rank < 2:
                return False
            if isinstance(signal_shape[-1], int) and signal_shape[-1] != 1:
                return False
            axis = attributes.get("axis")
            if axis is not None and axis.value % rank != rank - 2:
                return False

        # MatMul contracts the last axis of x against the dft_size basis rows.
        if x.shape is not None:
            if len(x.shape) == 0:
                return False
            n_samples = x.shape[-1]
            if isinstance(n_samples, int) and n_samples != dft_size:
                return False
        return True

    @staticmethod
    def _onesided_dft_matrix(dft_size: int) -> np.ndarray:
        """Return the ``(dft_size, 2 * (dft_size // 2 + 1))`` float32 DFT basis.

        Row ``n`` is a sample index and ``k`` a frequency index. Column
        ``2*k`` holds ``cos(2*pi*k*n/N)`` and column ``2*k + 1`` holds
        ``-sin(2*pi*k*n/N)``, where ``N = dft_size``.
        """
        num_freqs = dft_size // 2 + 1
        n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]
        k = np.arange(num_freqs, dtype=np.int64)[np.newaxis, :]
        # Reduce k*n modulo N in exact integer arithmetic, then evaluate the
        # angle in float64. Forming 2*pi*k*n directly in float32 loses
        # precision once k*n grows large. The result is cast to float32 only
        # at the end.
        angle = (2.0 * np.pi / dft_size) * ((k * n) % dft_size)
        basis = np.stack([np.cos(angle), -np.sin(angle)], axis=-1)
        return basis.reshape(dft_size, 2 * num_freqs).astype(np.float32)

    def rewrite(
        self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value
    ) -> ir.Value:
        """Emit ``Reshape(MatMul(x, W), [..., num_freqs, 2])`` for a match.

        Args:
            op: Rewriter builder used to create the replacement nodes.
            x: Input of the bypassed ``Reshape``. Its last axis is contracted
                against the ``dft_size`` rows of the basis matrix.
            dft_length: Constant scalar DFT length (validated by :meth:`check`).
            dft_output: Matched DFT output. Unused here; it is bound for
                :meth:`check`.

        Returns:
            The value that replaces ``dft_output``.
        """
        del dft_output  # only needed by check()
        dft_size = int(ir.convenience.get_const_tensor(dft_length).numpy().item())
        num_freqs = dft_size // 2 + 1

        dft_matrix = op.initializer(
            ir.tensor(self._onesided_dft_matrix(dft_size)),
            name=f"{x.name}_dft_matrix",
        )
        matmul_result = op.MatMul(x, dft_matrix)  # [..., 2 * num_freqs]

        # Split the interleaved last axis into [num_freqs, 2]. The leading
        # axes are taken from x rather than hard-coded, so any batch, frame
        # count, or DFT size works.
        if x.shape is not None:
            # With allowzero=0 (the default), Reshape copies each 0 entry from
            # the input dim at the same index. MatMul against a 2-D matrix
            # keeps x's rank, so rank(x) - 1 zeros preserve the leading axes
            # even when they are symbolic.
            leading = [0] * (len(x.shape) - 1)
            new_shape = op.initializer(
                ir.tensor(np.array([*leading, num_freqs, 2], dtype=np.int64)),
                name=f"{x.name}_dft_reshaped_shape",
            )
        else:
            # Rank unknown: compute the target shape at runtime.
            new_shape = op.Concat(
                op.Shape(matmul_result, end=-1),
                op.Constant(value_ints=[num_freqs, 2]),
                axis=0,
            )
        return op.Reshape(matmul_result, new_shape)
|
|
|
|
|
|
|
|
def main(argv=None) -> None:
    """Load a model, replace its DFT nodes with MatMul, optimize, and save it.

    Args:
        argv: Command-line arguments to parse instead of ``sys.argv[1:]``.
            The positional ``input`` and ``output`` paths default to
            ``perch_v2.onnx`` and ``perch_v2_no_dft.onnx``.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "input",
        nargs="?",
        default="perch_v2.onnx",
        help="source ONNX model (default: %(default)s)",
    )
    parser.add_argument(
        "output",
        nargs="?",
        default="perch_v2_no_dft.onnx",
        help="destination ONNX model (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    model = ir.load(args.input)
    model = onnxscript.rewriter.rewrite(
        model,
        [ReplaceDftWithMatMulRule().rule()],
    )
    onnxscript.optimizer.optimize(model)

    # Report any top-level DFT the rule did not match or declined to rewrite,
    # so an unconverted model is not saved silently.
    remaining = sum(1 for node in model.graph if node.op_type == "DFT")
    if remaining:
        print(
            f"warning: {remaining} DFT node(s) were not replaced",
            file=sys.stderr,
        )

    ir.save(model, args.output)


if __name__ == "__main__":
    main()
|
|
|