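"""Noise suppression with DTLN (Dual-signal Transformation LSTM Network).

Audio is processed block by block: the first ONNX model predicts a spectral
mask from each block's magnitude spectrum, and the second post-processes the
masked block in the time domain. Inference uses the ailia SDK by default, or
ONNX Runtime when --onnx is given.
"""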
import sys
import time
from logging import getLogger

import onnxruntime
import numpy as np
import soundfile as sf

import ailia

sys.path.append('../../util')
from model_utils import check_and_download_models
from arg_utils import get_base_parser, get_savepath, update_parser

logger = getLogger(__name__)


WEIGHT1_PATH = "dtln1.onnx"
MODEL1_PATH = "dtln1.onnx.prototxt"
WEIGHT2_PATH = "dtln2.onnx"
MODEL2_PATH = "dtln2.onnx.prototxt"
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/dtln/'

SAMPLE_RATE = 16000

WAV_PATH = '1221-135766-0000.wav'
SAVE_WAV_PATH = 'output.wav'


parser = get_base_parser(
    'Dual-signal Transformation LSTM Network', WAV_PATH, SAVE_WAV_PATH,
    input_ftype='audio'
)
parser.add_argument(
    '--onnx',
    action='store_true',
    help='By default, the ailia SDK is used, but with this option, you can switch to using ONNX Runtime'
)
parser.add_argument(
    '--shift',
    default=128, type=int,
    help='block shift (hop size) in samples for the overlap-add processing'
)
args = update_parser(parser)

block_shift = args.shift
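
# The network processes the signal in 512-sample blocks advanced by
# `block_shift` samples (128 by default, i.e. 75% overlap; at 16 kHz this is
# a 32 ms window with an 8 ms hop) and reassembles the output by overlap-add.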


def predict(audio, models):
    """Denoise `audio` block by block with the two DTLN stages."""
    block_len = 512
    out_file = np.zeros((len(audio)))

    # Sliding input window and overlap-add output buffer.
    in_buffer = np.zeros((block_len)).astype('float32')
    out_buffer = np.zeros((block_len)).astype('float32')

    num_blocks = (audio.shape[0] - (block_len - block_shift)) // block_shift

    time_array = []

    # Model 1 consumes the block magnitude spectrum (257 rFFT bins) plus a
    # recurrent state tensor; model 2 consumes the 512-sample time-domain
    # block plus its own state tensor.
    inp_shape = [(1, 1, 257), (1, 2, 128, 2)]
    model_input_names_1 = ['input_2', 'input_3']
    interpreter_1 = models[0]
    interpreter_2 = models[1]
    model_inputs_1 = {
        model_input_names_1[0]: np.zeros(inp_shape[0], dtype=np.float32),
        model_input_names_1[1]: np.zeros(inp_shape[1], dtype=np.float32),
    }

    model_input_names_2 = ['input_4', 'input_5']
    inp_shape = [(1, 1, 512), (1, 2, 128, 2)]
    model_inputs_2 = {
        model_input_names_2[0]: np.zeros(inp_shape[0], dtype=np.float32),
        model_input_names_2[1]: np.zeros(inp_shape[1], dtype=np.float32),
    }
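
    # The (1, 2, 128, 2) tensors hold the recurrent (LSTM) states of the two
    # separation cores. Feeding each block's output state back in, as the
    # loop below does, lets the models run block by block on a stream
    # instead of over the whole signal at once.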

    for idx in range(num_blocks):
        start_time = time.time()

        # Shift the input buffer left and append the next chunk of samples.
        in_buffer[:-block_shift] = in_buffer[block_shift:]
        in_buffer[-block_shift:] = audio[idx * block_shift:(idx * block_shift) + block_shift]

        # FFT of the current block, split into magnitude and phase.
        in_block_fft = np.fft.rfft(in_buffer)
        in_mag = np.abs(in_block_fft)
        in_phase = np.angle(in_block_fft)

        in_mag = np.reshape(in_mag, (1, 1, -1)).astype('float32')

        model_inputs_1[model_input_names_1[0]] = in_mag

        if args.onnx:
            model_outputs_1 = interpreter_1.run(
                None, {'input_2': model_inputs_1['input_2'],
                       'input_3': model_inputs_1['input_3']})
        else:
            inputs = [model_inputs_1['input_2'], model_inputs_1['input_3']]
            model_outputs_1 = interpreter_1.run(inputs)

        out_mask = model_outputs_1[0]

        # Carry the first model's recurrent state over to the next block.
        model_inputs_1[model_input_names_1[1]] = model_outputs_1[1]

        # Apply the predicted mask to the magnitude spectrum and transform
        # back to the time domain using the noisy phase.
        estimated_complex = in_mag * out_mask * np.exp(1j * in_phase)
        estimated_block = np.fft.irfft(estimated_complex)

        estimated_block = np.reshape(estimated_block, (1, 1, -1)).astype('float32')

        model_inputs_2[model_input_names_2[0]] = estimated_block

        if args.onnx:
            model_outputs_2 = interpreter_2.run(
                None, {'input_4': model_inputs_2['input_4'],
                       'input_5': model_inputs_2['input_5']})
        else:
            inputs = [model_inputs_2['input_4'], model_inputs_2['input_5']]
            model_outputs_2 = interpreter_2.run(inputs)

        out_block = model_outputs_2[0]

        # Carry the second model's recurrent state over to the next block.
        model_inputs_2[model_input_names_2[1]] = model_outputs_2[1]

        # Overlap-add the enhanced block into the output buffer.
        out_buffer[:-block_shift] = out_buffer[block_shift:]
        out_buffer[-block_shift:] = np.zeros((block_shift))
        out_buffer += np.squeeze(out_block)

        # The first `block_shift` samples of the buffer are now final.
        out_file[idx * block_shift:(idx * block_shift) + block_shift] = out_buffer[:block_shift]
        time_array.append(time.time() - start_time)

    return out_file
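
# A minimal sketch of calling predict() directly (file names are
# placeholders; `models` is loaded as in main() below):
#
#   audio, fs = sf.read('noisy.wav')       # 16 kHz mono input expected
#   enhanced = predict(audio, models)
#   sf.write('denoised.wav', enhanced, fs)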


def recognize_from_audio(models):
    """Run denoising on each input file and write the enhanced WAV."""
    for audio_path in args.input:
        logger.info(audio_path)

        audio, fs = sf.read(audio_path)

        if fs != SAMPLE_RATE:
            logger.error('This model only supports a 16 kHz sampling rate.')
            continue

        logger.info('Start inference...')
        if args.benchmark:
            logger.info('BENCHMARK mode')
            start = int(round(time.time() * 1000))
            output = predict(audio, models)
            end = int(round(time.time() * 1000))
            estimation_time = (end - start)
            logger.info(f'\ttotal processing time {estimation_time} ms')
        else:
            output = predict(audio, models)

        savepath = get_savepath(args.savepath, audio_path, ext='.wav')
        logger.info(f'saved at : {savepath}')
        sf.write(savepath, output, fs)

    logger.info('Script finished successfully.')


def main():
    check_and_download_models(WEIGHT1_PATH, MODEL1_PATH, REMOTE_PATH)
    check_and_download_models(WEIGHT2_PATH, MODEL2_PATH, REMOTE_PATH)

    env_id = args.env_id

    # Load the two DTLN stages with the selected runtime.
    if args.onnx:
        models = [onnxruntime.InferenceSession(WEIGHT1_PATH),
                  onnxruntime.InferenceSession(WEIGHT2_PATH)]
    else:
        models = [ailia.Net(MODEL1_PATH, WEIGHT1_PATH, env_id=env_id),
                  ailia.Net(MODEL2_PATH, WEIGHT2_PATH, env_id=env_id)]

    recognize_from_audio(models)


if __name__ == '__main__':
    main()
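
# Example invocations (a sketch; the script file name is assumed):
#   python3 dtln.py                                # default sample WAV
#   python3 dtln.py --input speech.wav --savepath clean.wav
#   python3 dtln.py --onnx                         # ONNX Runtime backend
#   python3 dtln.py --benchmark                    # log total processing time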