|
|
import time |
|
|
import sys |
|
|
import argparse |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
import ailia |
|
|
|
|
|
import soundfile as sf |
|
|
|
|
|
|
|
|
sys.path.append('../../util') |
|
|
from arg_utils import get_base_parser, update_parser, get_savepath |
|
|
from model_utils import check_and_download_models |
|
|
|
|
|
|
|
|
from logging import getLogger |
|
|
logger = getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default input audio clip bundled with this sample.
WAV_PATH = 'doublenoble_k7rain_part.wav'

# Default output path for the separated (vocal) signal.
SAVE_WAV_PATH = 'separated_voice.wav'

# Supported model architectures selectable via --arch.
MODEL_LISTS = ['base', 'large']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ======================
# Argument parsing
# ======================
parser = get_base_parser(
    # Fixed typo in the --help description: was 'RSource separation.'
    'Source separation.',
    WAV_PATH,
    SAVE_WAV_PATH,
)
parser.add_argument(
    '-n', '--onnx',
    action='store_true',
    default=False,
    help='Use onnxruntime'
)
parser.add_argument(
    '-st', '--stereo',
    action='store_true',
    default=False,
    help='Use stereo mode'
)
parser.add_argument(
    '-a', '--arch',
    default='base', choices=MODEL_LISTS,
    help='model lists: ' + ' | '.join(MODEL_LISTS)
)
parser.add_argument(
    '--ailia_audio', action='store_true',
    help='use ailia audio library'
)
args = update_parser(parser)

# Pick the DSP helper implementation after parsing: either the
# ailia.audio-backed utilities or the scipy-based ones. Both modules
# export the same function names, so the rest of the script is agnostic.
if args.ailia_audio:
    import ailia.audio as ailia_audio
    from unet_source_separation_utils_ailia import preemphasis, inv_preemphasis, lowpass, tfconvert, zero_pad, calc_time
else:
    from scipy import signal
    from unet_source_separation_utils import preemphasis, inv_preemphasis, lowpass, tfconvert, zero_pad, calc_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ======================
# Model selection and signal-processing parameters
# ======================
# All per-architecture settings hinge on the same choice, so compute the
# predicate once and derive each constant with a conditional expression.
_IS_BASE = args.arch == 'base'

WEIGHT_PATH = "second_voice_bank.best.opt2.onnx" if _IS_BASE else "RefineSpectrogramUnet.best.opt.onnx"
MODEL_PATH = WEIGHT_PATH + ".prototxt"
REMOTE_PATH = "https://storage.googleapis.com/ailia-models/unet_source_separation/"

# Target sample rate and STFT geometry for the selected model.
DESIRED_SR = 22050 if _IS_BASE else 44100
MULT = 2 ** 5 if _IS_BASE else 2 ** 6
WINDOW_LEN = 512 if _IS_BASE else 1024
HOP_LEN = 64 if _IS_BASE else 128

# Low-pass filter cutoff (Hz) applied in postprocessing; <= 0 disables it.
LPF_CUTOFF = 10000 if _IS_BASE else 20000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def src_sep(data, session):
    """Run the separation network on prepared input features.

    data : input feature(s) produced by tfconvert; when running through
        onnxruntime, data[0] and data[1] feed the model's two inputs.
    session : an ailia.Net when args.onnx is False, otherwise an
        onnxruntime.InferenceSession.
    Returns the model's first output (the separated spectrogram).
    """
    if not args.onnx:
        # ailia accepts the feature bundle directly.
        return session.run(data)[0]

    # onnxruntime needs explicit input/output names.
    model_inputs = session.get_inputs()
    feed = {
        model_inputs[0].name: data[0],
        model_inputs[1].name: data[1],
    }
    output_name = session.get_outputs()[0].name
    return session.run([output_name], feed)[0]
|
|
|
|
|
|
|
|
def recognize_one_audio(input_path):
    """Separate the voice from one audio file and write the result to disk.

    Pipeline: load -> channel layout -> resample -> pre-emphasis ->
    time-frequency features -> model inference -> low-pass + inverse
    pre-emphasis -> save. Reads many module-level globals (args,
    DESIRED_SR, WINDOW_LEN, HOP_LEN, MULT, LPF_CUTOFF, model paths).
    """
    logger.info('Loading wavfile...')
    wav, sr = sf.read(input_path)

    # Model expects float32 input.
    if wav.dtype != np.float32:
        wav = wav.astype(np.float32)

    # Normalize the array layout to (channels, samples).
    if wav.ndim == 2 :
        if args.stereo:
            # Keep both channels: (samples, 2) -> (2, samples).
            wav = np.transpose(wav,(1,0))
        else:
            # Downmix stereo to mono by averaging the two channels.
            wav = (wav[:,0][np.newaxis,:] + wav[:,1][np.newaxis,:])/2
    else:
        # Mono file: add a leading channel axis.
        wav = wav[np.newaxis,:]

    # Log the clip duration (helper from the selected utils module).
    calc_time(wav.shape[1], sr)

    logger.info('Converting sample rate...')
    # Resample only when the file's rate differs from the model's rate.
    if not sr == DESIRED_SR :
        if args.ailia_audio:
            wav = ailia.audio.resample(wav,sr,DESIRED_SR)
        else:
            wav = signal.resample_poly(wav, DESIRED_SR, sr, axis=1)

    logger.info('Generating input feature...')
    # Pre-emphasis boosts high frequencies before the STFT-based features.
    wav = preemphasis(wav)

    # Time-frequency features, padded so the frame count is a multiple
    # of MULT (required by the U-Net's down/upsampling path).
    input_feature = tfconvert(wav, WINDOW_LEN, HOP_LEN, MULT)

    # Create the inference session: ailia by default, onnxruntime with -n.
    if not args.onnx :
        logger.info('Use ailia')
        env_id = args.env_id
        logger.info(f'env_id: {env_id}')
        # reuse_interstage lowers peak memory during inference.
        memory_mode = ailia.get_memory_mode(reuse_interstage=True)
        session = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=env_id, memory_mode=memory_mode)
    else :
        logger.info('Use onnxruntime')
        import onnxruntime
        session = onnxruntime.InferenceSession(WEIGHT_PATH)

    logger.info('Start inference...')
    if args.benchmark:
        # Benchmark mode: run 5 times and report per-run wall-clock time.
        logger.info('BENCHMARK mode')
        for c in range(5) :
            start = int(round(time.time() * 1000))
            sep = src_sep(input_feature, session)
            end = int(round(time.time() * 1000))
            logger.info("\tprocessing time {} ms".format(end-start))
    else:
        sep = src_sep(input_feature, session)

    logger.info('Start postprocessing...')
    # Remove artifacts above the cutoff; LPF_CUTOFF > 0 for both archs here.
    if LPF_CUTOFF > 0 :
        sep = lowpass(sep, LPF_CUTOFF, DESIRED_SR)

    # Undo pre-emphasis and clip to the valid [-1, 1] sample range.
    out_wav = inv_preemphasis(sep).clip(-1.,1.)
    # Back to (samples, channels) as expected by soundfile.write.
    out_wav = out_wav.swapaxes(0,1)

    savepath = get_savepath(args.savepath, input_path)
    logger.info(f'saved at : {savepath}')

    sf.write(savepath, out_wav, DESIRED_SR)

    logger.info('Saved separated signal. ')
    logger.info('Script finished successfully.')
|
|
|
|
|
|
|
|
def main():
    """Fetch model files if missing, then separate each input audio file."""
    check_and_download_models(WEIGHT_PATH, MODEL_PATH, REMOTE_PATH)

    for audio_path in args.input:
        recognize_one_audio(audio_path)
|
|
|
|
|
# Script entry point when run directly (not imported).
if __name__ == "__main__":
    main()
|
|
|