|
|
""" |
|
|
This is an example of how to implement real-time processing of the DTLN TF Lite
|
|
model in python. |
|
|
|
|
|
Please change the path of the .wav file in the sf.read call below before running the script.
|
|
For .whl files of the TF Lite runtime go to:
|
|
https://www.tensorflow.org/lite/guide/python |
|
|
|
|
|
Author: Nils L. Westhausen (nils.westhausen@uol.de) |
|
|
Version: 30.06.2020 |
|
|
|
|
|
This code is licensed under the terms of the MIT-license. |
|
|
""" |
|
|
|
|
|
import soundfile as sf |
|
|
import numpy as np |
|
|
import tflite_runtime.interpreter as tflite |
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# STFT framing parameters: a 512-sample analysis window advanced by
# 128 samples per hop (75% overlap) — 32 ms / 8 ms at 16 kHz.
block_len, block_shift = 512, 128
|
|
|
|
|
# Load the two TF Lite interpreters that make up the DTLN model and
# allocate their tensors. model_1 consumes magnitude spectra; model_2
# consumes the time-domain estimate built from model_1's mask (see the
# processing loop further down).
interpreter_1 = tflite.Interpreter(model_path='./pretrained_model/model_1.tflite')
interpreter_1.allocate_tensors()
interpreter_2 = tflite.Interpreter(model_path='./pretrained_model/model_2.tflite')
interpreter_2.allocate_tensors()
|
|
|
|
|
|
|
|
# Query the tensor metadata of both interpreters. As used below, input
# index 0 carries the signal/features and index 1 the recurrent states.
input_details_1 = interpreter_1.get_input_details()
output_details_1 = interpreter_1.get_output_details()
input_details_2 = interpreter_2.get_input_details()
output_details_2 = interpreter_2.get_output_details()

# Zero-initialised recurrent states, shaped to match each model's
# state-input tensor.
states_1 = np.zeros(input_details_1[1]['shape'], dtype=np.float32)
states_2 = np.zeros(input_details_2[1]['shape'], dtype=np.float32)
|
|
|
|
|
# Load the noisy speech file. Change this path to your own .wav file.
audio, fs = sf.read('path/to/your/favorite/.wav')

# The model is trained on 16 kHz speech only.
if fs != 16000:
    raise ValueError('This model only supports 16k sampling rate.')

# Robustness fix: sf.read returns a 2-D (samples, channels) array for
# multichannel files, which would break the 1-D frame indexing below.
# Down-mix to mono by averaging channels; mono input is unaffected.
if audio.ndim > 1:
    audio = np.mean(audio, axis=-1)
|
|
|
|
|
# Preallocate the enhanced output signal and the sliding frame buffers
# used for overlapped analysis / overlap-add synthesis.
out_file = np.zeros(len(audio))
in_buffer = np.zeros(block_len, dtype=np.float32)
out_buffer = np.zeros(block_len, dtype=np.float32)

# Number of full hops that fit into the signal after accounting for the
# (block_len - block_shift) samples needed to prime the input buffer.
num_blocks = (audio.shape[0] - (block_len - block_shift)) // block_shift

# Wall-clock time per processed block, reported at the end.
time_array = []
|
|
|
|
|
# Stream the signal block by block: each iteration consumes block_shift
# new input samples and produces block_shift finished output samples.
for idx in range(num_blocks):
    start_time = time.time()
    # Shift the analysis buffer left by block_shift and append the next
    # block_shift input samples (overlapping frames).
    in_buffer[:-block_shift] = in_buffer[block_shift:]
    in_buffer[-block_shift:] = audio[idx*block_shift:(idx*block_shift)+block_shift]
    # FFT of the current frame, split into magnitude and phase.
    in_block_fft = np.fft.rfft(in_buffer)
    in_mag = np.abs(in_block_fft)
    in_phase = np.angle(in_block_fft)
    # Reshape to the (batch, time, feature) layout the model expects.
    in_mag = np.reshape(in_mag, (1,1,-1)).astype('float32')
    # First model: predict a spectral mask from the magnitude spectrum.
    # Input index 1 is the recurrent-state tensor, index 0 the magnitudes.
    interpreter_1.set_tensor(input_details_1[1]['index'], states_1)
    interpreter_1.set_tensor(input_details_1[0]['index'], in_mag)
    interpreter_1.invoke()
    # Fetch the mask and carry the updated states into the next iteration.
    out_mask = interpreter_1.get_tensor(output_details_1[0]['index'])
    states_1 = interpreter_1.get_tensor(output_details_1[1]['index'])
    # Apply the mask to the magnitudes, restore the noisy phase, and
    # transform back to the time domain.
    estimated_complex = in_mag * out_mask * np.exp(1j * in_phase)
    estimated_block = np.fft.irfft(estimated_complex)
    # Reshape the time-domain estimate for the second model.
    estimated_block = np.reshape(estimated_block, (1,1,-1)).astype('float32')
    # Second model: refine the time-domain estimate (again, input index 1
    # is the state tensor, index 0 the signal).
    interpreter_2.set_tensor(input_details_2[1]['index'], states_2)
    interpreter_2.set_tensor(input_details_2[0]['index'], estimated_block)
    interpreter_2.invoke()
    out_block = interpreter_2.get_tensor(output_details_2[0]['index'])
    states_2 = interpreter_2.get_tensor(output_details_2[1]['index'])
    # Overlap-add synthesis: shift the output buffer, zero the freed tail,
    # then accumulate the new block on top.
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = np.zeros((block_shift))
    out_buffer += np.squeeze(out_block)
    # The first block_shift samples of out_buffer are now fully summed;
    # write them to the output signal.
    out_file[idx*block_shift:(idx*block_shift)+block_shift] = out_buffer[:block_shift]
    # Record the wall-clock processing time of this block.
    time_array.append(time.time()-start_time)
|
|
|
|
|
|
|
|
# Persist the enhanced signal and report the average per-block latency.
sf.write('out.wav', out_file, fs)
print('Processing Time [ms]:')
# np.mean over the plain Python list is equivalent to stacking first.
print(np.mean(time_array) * 1000)
print('Processing finished.')
|
|
|