import time
import sys
import argparse

import numpy as np

import ailia

# Extend the import path with the shared utility directory. This has to run
# before the imports below; presumably `arg_utils` and `model_utils` live
# there -- verify against the repository layout.
sys.path.append('../../util')
from arg_utils import get_base_parser, update_parser, get_savepath
from model_utils import check_and_download_models

from logging import getLogger
logger = getLogger(__name__)

import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.io import wavfile
from deep_music_enhancer_utils import (
    read_audio,
    SingleSong
)


# Default input and output audio files handed to the base argument parser.
WAV_PATH = 'input.wav'
SAVE_WAV_PATH = 'output.wav'

# ResNet model files, one weight/prototxt pair per `--model` choice.
WEIGHT_PATH_RESNET = 'resnet.onnx'
MODEL_PATH_RESNET = 'resnet.onnx.prototxt'
WEIGHT_PATH_RESNET_BN = 'resnetbn.onnx'
MODEL_PATH_RESNET_BN = 'resnetbn.onnx.prototxt'
WEIGHT_PATH_RESNET_DA = 'resnetda.onnx'
MODEL_PATH_RESNET_DA = 'resnetda.onnx.prototxt'
WEIGHT_PATH_RESNET_DO = 'resnetdo.onnx'
MODEL_PATH_RESNET_DO = 'resnetdo.onnx.prototxt'

# U-Net model files, one weight/prototxt pair per `--model` choice.
WEIGHT_PATH_UNET = 'unet.onnx'
MODEL_PATH_UNET = 'unet.onnx.prototxt'
WEIGHT_PATH_UNET_BN = 'unetbn.onnx'
MODEL_PATH_UNET_BN = 'unetbn.onnx.prototxt'
WEIGHT_PATH_UNET_DA = 'unetda.onnx'
MODEL_PATH_UNET_DA = 'unetda.onnx.prototxt'
WEIGHT_PATH_UNET_DO = 'unetdo.onnx'
MODEL_PATH_UNET_DO = 'unetdo.onnx.prototxt'

# Base URL passed to check_and_download_models() together with the file names.
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/deep-music-enhancer/'


# Build the command-line interface on top of the shared base parser.
parser = get_base_parser(
    'On Filter Generalization for Music Bandwidth Extension Using Deep Neural Networks',
    WAV_PATH,
    SAVE_WAV_PATH
)
# NOTE(review): the base parser presumably already defines --input/-i; this
# re-registration relies on it tolerating the duplicate -- confirm in arg_utils.
parser.add_argument(
    '--input', '-i', metavar='WAV', default=WAV_PATH,
    help='input audio'
)
parser.add_argument(
    '--ailia_audio', action='store_true',
    help='use ailia audio library'
)
parser.add_argument(
    '--vis', action='store_true',
    help='save visualized spectrogram'
)
parser.add_argument(
    '--model', type=str, default='unet',
    choices=[
        'resnet', 'resnet_bn', 'resnet_da', 'resnet_do',
        'unet', 'unet_bn', 'unet_da', 'unet_do'
    ],
    help='model variant to run'
)
# Parsed once at import time; the functions below read this module-level `args`.
args = update_parser(parser, check_input_type=False)


def audio_bandwidth_extension(net):
    """Run the network over the input audio and write the extended signal.

    For each entry of ``FILTERS_TEST`` the first input file is wrapped in a
    ``SingleSong``, every item of that dataset is passed through
    ``net.predict``, and the predictions are joined along the last axis,
    clipped and written as a WAV file to ``args.savepath``. With ``--vis``
    a spectrogram of the input signal and of the output signal is also
    saved as a PNG in the current directory.

    Args:
        net: object exposing ``predict(x)``; ``main`` passes an ``ailia.Net``.

    Note:
        Every filter writes to the same ``args.savepath``, so the file left
        on disk is the result of the last filter in ``FILTERS_TEST``.
    """
    FILTERS_TEST = [('cheby1', 6), ('butter', 6)]
    c_SAMPLE_RATE = 44100
    c_WAV_SAMPLE_LEN = 8192
    cutoff = 11025
    duration = None
    start = 0

    # The input file does not depend on the filter, so resolve it once.
    input_name = args.input[0]
    input_name_without_ext = os.path.splitext(os.path.basename(input_name))[0]
    hq_path = input_name

    for filter_ in FILTERS_TEST:
        logger.info('filter: {}, input_name: {}'.format(filter_, input_name))

        song_data = SingleSong(
            c_WAV_SAMPLE_LEN,
            filter_,
            hq_path,
            cutoff=cutoff,
            duration=duration,
            start=start
        )

        # Output buffer provided by the dataset; filled chunk by chunk below.
        y_full = song_data.preallocate()

        idx_start_chunk = 0
        for i in tqdm(range(len(song_data))):
            # Only the first element of each item is fed to the network.
            x, _ = song_data[i]
            # Prepend a leading axis (presumably the batch dimension).
            x = x[np.newaxis, :, :]

            y = net.predict(x)

            idx_end_chunk = idx_start_chunk + y.shape[0]
            y_full[idx_start_chunk:idx_end_chunk] = y
            idx_start_chunk = idx_end_chunk

        # Join the per-chunk predictions along the last axis.
        y_full = np.concatenate(y_full, axis=-1)

        x_full, _ = song_data.get_full_signals()
        # Keep the samples inside [-1, 1) before writing them out.
        y_full = np.clip(y_full, -1, 1 - np.finfo(np.float32).eps)

        wavfile.write(args.savepath, c_SAMPLE_RATE, y_full.T)

        if args.vis:
            # Plot the first c_SAMPLE_RATE * 5 rows of column 0 (presumably
            # five seconds of the first channel). The figure is cleared
            # before each plot so spectrograms do not pile up on one axes.
            plt.clf()
            plt.specgram(x_full.T[:c_SAMPLE_RATE*5, 0], Fs=c_SAMPLE_RATE)
            plt.savefig('{}_{}_input_spec.png'.format(input_name_without_ext, filter_[0]))

            plt.clf()
            plt.specgram(y_full.T[:c_SAMPLE_RATE*5, 0], Fs=c_SAMPLE_RATE)
            plt.savefig('{}_{}_output_spec.png'.format(input_name_without_ext, filter_[0]))

    logger.info('Script finished successfully.')


def main():
    """Pick the model files for ``args.model``, load the network and run it.

    The weight and prototxt files are checked and downloaded via
    ``check_and_download_models`` before the ``ailia.Net`` is created on the
    environment given by ``args.env_id``.
    """
    # One (weight, prototxt) pair per value accepted by --model. A table
    # lookup fails loudly on an unknown name instead of leaving the paths
    # unbound, as the previous if/elif chain without an else would.
    model_files = {
        'resnet': (WEIGHT_PATH_RESNET, MODEL_PATH_RESNET),
        'resnet_bn': (WEIGHT_PATH_RESNET_BN, MODEL_PATH_RESNET_BN),
        'resnet_da': (WEIGHT_PATH_RESNET_DA, MODEL_PATH_RESNET_DA),
        'resnet_do': (WEIGHT_PATH_RESNET_DO, MODEL_PATH_RESNET_DO),
        'unet': (WEIGHT_PATH_UNET, MODEL_PATH_UNET),
        'unet_bn': (WEIGHT_PATH_UNET_BN, MODEL_PATH_UNET_BN),
        'unet_da': (WEIGHT_PATH_UNET_DA, MODEL_PATH_UNET_DA),
        'unet_do': (WEIGHT_PATH_UNET_DO, MODEL_PATH_UNET_DO),
    }
    weight_path, model_path = model_files[args.model]

    env_id = args.env_id

    check_and_download_models(weight_path, model_path, REMOTE_PATH)
    net = ailia.Net(model_path, weight_path, env_id=env_id)

    audio_bandwidth_extension(net)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()