|
|
import os |
|
|
import sys |
|
|
import time |
|
|
|
|
|
import numpy as np |
|
|
import librosa |
|
|
import soundfile as sf |
|
|
|
|
|
import ailia |
|
|
from audio_utils import Audio |
|
|
|
|
|
|
|
|
sys.path.append('../../util') |
|
|
from arg_utils import get_base_parser, update_parser, get_savepath |
|
|
from model_utils import check_and_download_models |
|
|
|
|
|
from logging import getLogger |
|
|
|
|
|
logger = getLogger(__name__)


# ======================
# Parameters
# ======================

# VoiceFilter separation network (downloaded on demand from REMOTE_PATH)
WEIGHT_PATH = 'model.onnx'
MODEL_PATH = 'model.onnx.prototxt'

# Speaker-embedding (d-vector) network used to condition the separator
WEIGHT_EMB_PATH = 'embedder.onnx'
MODEL_EMB_PATH = 'embedder.onnx.prototxt'

REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/voicefilter/'

# Default input mixture and default output file name
WAVE_PATH = "mixed.wav"
SAVE_PATH = 'output.wav'

# All audio is resampled to / processed at 16 kHz
SAMPLING_RATE = 16000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ======================
# Argument Parser Config
# ======================

parser = get_base_parser(
    'VoiceFilter', WAVE_PATH, SAVE_PATH, input_ftype='audio'
)
# Reference recording of the target speaker; its embedding steers
# the separation toward that speaker's voice.
parser.add_argument(
    '-r', '--reference_file',
    default="ref-voice.wav", type=str,
    help='path of reference wav file'
)
args = update_parser(parser)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_wave(path):
    """Load *path* as mono audio and return it at SAMPLING_RATE.

    The file is loaded at its native rate (``sr=None``) and resampled
    only when that rate differs from the target.
    """
    data, native_sr = librosa.load(path, sr=None)

    needs_resample = native_sr is not None and native_sr != SAMPLING_RATE
    if needs_resample:
        data = librosa.resample(
            data, orig_sr=native_sr, target_sr=SAMPLING_RATE)

    return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def audio_recognition(net, embedder):
    """Extract the reference speaker's voice from each input mixture.

    Computes a d-vector speaker embedding from ``args.reference_file``,
    then for every file in ``args.input`` predicts a soft spectrogram
    mask, applies it to the mixture magnitude, and writes the
    reconstructed waveform to the save path.

    Parameters
    ----------
    net : ailia.Net
        VoiceFilter separation model; inputs are the batched magnitude
        spectrogram and the d-vector, output is the mask.
    embedder : ailia.Net
        Speaker-embedding model; input is the mel spectrogram of the
        reference voice.
    """
    reference_file = args.reference_file
    if not reference_file or not os.path.exists(reference_file):
        logger.error('reference_file:%s is NG.' % reference_file)
        sys.exit(-1)

    audio = Audio()

    # Build the speaker d-vector from the reference recording.
    dvec_wav = read_wave(reference_file)
    dvec_mel = audio.get_mel(dvec_wav)
    output = embedder.predict([dvec_mel])
    dvec = output[0]
    dvec = np.expand_dims(dvec, axis=0)  # add batch dimension

    for soundf_path in args.input:
        logger.info(soundf_path)

        # Magnitude/phase spectrogram of the mixture; batch axis for the net.
        mixed_wav = read_wave(soundf_path)
        mag, phase = audio.wav2spec(mixed_wav)
        mag = np.expand_dims(mag, axis=0)

        logger.info('Start inference...')
        if args.benchmark:
            logger.info('BENCHMARK mode')
            total_time_estimation = 0
            for i in range(args.benchmark_count):
                start = int(round(time.time() * 1000))
                output = net.predict([mag, dvec])
                end = int(round(time.time() * 1000))
                estimation_time = (end - start)

                logger.info(f'\tailia processing estimation time {estimation_time} ms')
                # The first iteration is treated as warm-up and excluded
                # from the average.
                if i != 0:
                    total_time_estimation = total_time_estimation + estimation_time

            # FIX: guard against ZeroDivisionError when benchmark_count == 1
            # (no timed iterations remain after discarding the warm-up).
            if args.benchmark_count > 1:
                logger.info(f'\taverage time estimation {total_time_estimation / (args.benchmark_count - 1)} ms')
        else:
            output = net.predict([mag, dvec])

        mask = output[0]

        # Apply the predicted soft mask and rebuild the waveform using
        # the mixture's original phase.
        est_mag = mag * mask
        est_wav = audio.spec2wav(est_mag[0], phase)

        savepath = get_savepath(args.savepath, soundf_path, ext='.wav')
        logger.info(f'saved at : {savepath}')
        sf.write(savepath, est_wav, SAMPLING_RATE, 'PCM_24')

    logger.info('Script finished successfully.')
|
|
|
|
|
|
|
|
def main():
    """Entry point: fetch the ONNX models, build the networks, run separation."""
    # Model files are fetched from remote storage when not cached locally.
    logger.info('Checking voicefilter model...')
    check_and_download_models(WEIGHT_PATH, MODEL_PATH, REMOTE_PATH)
    logger.info('Checking embedder model...')
    check_and_download_models(WEIGHT_EMB_PATH, MODEL_EMB_PATH, REMOTE_PATH)

    # Both networks run on the same ailia environment (CPU/GPU selection).
    net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id)
    embedder = ailia.Net(MODEL_EMB_PATH, WEIGHT_EMB_PATH, env_id=args.env_id)

    audio_recognition(net, embedder)
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
main() |
|
|
|