|
|
""" |
|
|
This is an example how to implement real time processing of the DTLN ONNX |
|
|
model in python. |
|
|
|
|
|
Please change the name of the .wav file in the sf.read() call below before running the script. |
|
|
For the ONNX runtime call: $ pip install onnxruntime |
|
|
|
|
|
|
|
|
|
|
|
Author: Nils L. Westhausen (nils.westhausen@uol.de) |
|
|
Version: 03.07.2020 |
|
|
|
|
|
This code is licensed under the terms of the MIT-license. |
|
|
""" |
|
|
|
|
|
import soundfile as sf |
|
|
import numpy as np |
|
|
import time |
|
|
import onnxruntime |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
block_len = 512 |
|
|
block_shift = 128 |
|
|
|
|
|
interpreter_1 = onnxruntime.InferenceSession('./model_1.onnx') |
|
|
model_input_names_1 = [inp.name for inp in interpreter_1.get_inputs()] |
|
|
|
|
|
model_inputs_1 = { |
|
|
inp.name: np.zeros( |
|
|
[dim if isinstance(dim, int) else 1 for dim in inp.shape], |
|
|
dtype=np.float32) |
|
|
for inp in interpreter_1.get_inputs()} |
|
|
|
|
|
interpreter_2 = onnxruntime.InferenceSession('./model_2.onnx') |
|
|
model_input_names_2 = [inp.name for inp in interpreter_2.get_inputs()] |
|
|
|
|
|
model_inputs_2 = { |
|
|
inp.name: np.zeros( |
|
|
[dim if isinstance(dim, int) else 1 for dim in inp.shape], |
|
|
dtype=np.float32) |
|
|
for inp in interpreter_2.get_inputs()} |
|
|
|
|
|
|
|
|
audio,fs = sf.read('path/to/your/favorite.wav') |
|
|
|
|
|
if fs != 16000: |
|
|
raise ValueError('This model only supports 16k sampling rate.') |
|
|
|
|
|
out_file = np.zeros((len(audio))) |
|
|
|
|
|
in_buffer = np.zeros((block_len)).astype('float32') |
|
|
out_buffer = np.zeros((block_len)).astype('float32') |
|
|
|
|
|
num_blocks = (audio.shape[0] - (block_len-block_shift)) // block_shift |
|
|
|
|
|
time_array = [] |
|
|
for idx in range(num_blocks): |
|
|
start_time = time.time() |
|
|
|
|
|
in_buffer[:-block_shift] = in_buffer[block_shift:] |
|
|
in_buffer[-block_shift:] = audio[idx*block_shift:(idx*block_shift)+block_shift] |
|
|
|
|
|
in_block_fft = np.fft.rfft(in_buffer) |
|
|
in_mag = np.abs(in_block_fft) |
|
|
in_phase = np.angle(in_block_fft) |
|
|
|
|
|
in_mag = np.reshape(in_mag, (1,1,-1)).astype('float32') |
|
|
|
|
|
model_inputs_1[model_input_names_1[0]] = in_mag |
|
|
|
|
|
model_outputs_1 = interpreter_1.run(None, model_inputs_1) |
|
|
|
|
|
out_mask = model_outputs_1[0] |
|
|
|
|
|
model_inputs_1[model_input_names_1[1]] = model_outputs_1[1] |
|
|
|
|
|
estimated_complex = in_mag * out_mask * np.exp(1j * in_phase) |
|
|
estimated_block = np.fft.irfft(estimated_complex) |
|
|
|
|
|
estimated_block = np.reshape(estimated_block, (1,1,-1)).astype('float32') |
|
|
|
|
|
|
|
|
model_inputs_2[model_input_names_2[0]] = estimated_block |
|
|
|
|
|
model_outputs_2 = interpreter_2.run(None, model_inputs_2) |
|
|
|
|
|
out_block = model_outputs_2[0] |
|
|
|
|
|
model_inputs_2[model_input_names_2[1]] = model_outputs_2[1] |
|
|
|
|
|
out_buffer[:-block_shift] = out_buffer[block_shift:] |
|
|
out_buffer[-block_shift:] = np.zeros((block_shift)) |
|
|
out_buffer += np.squeeze(out_block) |
|
|
|
|
|
out_file[idx*block_shift:(idx*block_shift)+block_shift] = out_buffer[:block_shift] |
|
|
time_array.append(time.time()-start_time) |
|
|
|
|
|
|
|
|
sf.write('out.wav', out_file, fs) |
|
|
print('Processing Time [ms]:') |
|
|
print(np.mean(np.stack(time_array))*1000) |
|
|
print('Processing finished.') |