# niobures's picture
# Deep-Music-Enhancer (code, models, paper)
# 7b6980a verified
import time
import sys
import argparse
import numpy as np
import ailia # noqa: E402
# import original modules
sys.path.append('../../util')
from arg_utils import get_base_parser, update_parser, get_savepath # noqa: E402
from model_utils import check_and_download_models # noqa: E402
# logger
from logging import getLogger # noqa: E402
logger = getLogger(__name__)
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.io import wavfile
from deep_music_enhancer_utils import (
read_audio,
SingleSong
)
# ======================
# PARAMETERS
# ======================
# Input audio path used as the default for --input, and the output path
# handed to the base parser.
WAV_PATH = 'input.wav'
SAVE_WAV_PATH = 'output.wav'

# One (weight, model) file pair per selectable model. Each model description
# file is named after its ONNX weight file with '.prototxt' appended.
WEIGHT_PATH_RESNET = 'resnet.onnx'
MODEL_PATH_RESNET = WEIGHT_PATH_RESNET + '.prototxt'
WEIGHT_PATH_RESNET_BN = 'resnetbn.onnx'
MODEL_PATH_RESNET_BN = WEIGHT_PATH_RESNET_BN + '.prototxt'
WEIGHT_PATH_RESNET_DA = 'resnetda.onnx'
MODEL_PATH_RESNET_DA = WEIGHT_PATH_RESNET_DA + '.prototxt'
WEIGHT_PATH_RESNET_DO = 'resnetdo.onnx'
MODEL_PATH_RESNET_DO = WEIGHT_PATH_RESNET_DO + '.prototxt'
WEIGHT_PATH_UNET = 'unet.onnx'
MODEL_PATH_UNET = WEIGHT_PATH_UNET + '.prototxt'
WEIGHT_PATH_UNET_BN = 'unetbn.onnx'
MODEL_PATH_UNET_BN = WEIGHT_PATH_UNET_BN + '.prototxt'
WEIGHT_PATH_UNET_DA = 'unetda.onnx'
MODEL_PATH_UNET_DA = WEIGHT_PATH_UNET_DA + '.prototxt'
WEIGHT_PATH_UNET_DO = 'unetdo.onnx'
MODEL_PATH_UNET_DO = WEIGHT_PATH_UNET_DO + '.prototxt'

# Base URL passed to check_and_download_models together with the file names.
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/deep-music-enhancer/'
# ======================
# Argument Parser Config
# ======================
# Build the parser from the shared helper, giving it the description text and
# the default input / output paths.
parser = get_base_parser(
    'On Filter Generalization for Music Bandwidth Extension Using Deep Neural Networks',
    WAV_PATH,
    SAVE_WAV_PATH,
)
# overwrite
# NOTE(review): --input is re-declared here; presumably the base parser already
# defines it and tolerates redefinition -- confirm in arg_utils.
parser.add_argument(
    '--input', '-i', metavar='WAV', default=WAV_PATH,
    help='input audio'
)
parser.add_argument(
    '--ailia_audio', action='store_true',
    help='use ailia audio library'
)
parser.add_argument(
    '--vis', action='store_true',
    help='save visualized spectrogram'
)
parser.add_argument(
    '--model', type=str, default='unet',
    choices=[
        'resnet', 'resnet_bn', 'resnet_da', 'resnet_do',
        'unet', 'unet_bn', 'unet_da', 'unet_do'
    ],
    # Previously the only option without a description in --help output.
    help='model variant to use (default: %(default)s)'
)
# check_input_type=False is forwarded to the shared helper, which returns the
# parsed arguments used by the rest of this script.
args = update_parser(parser, check_input_type=False)
# ======================
# Main function
# ======================
def _save_spectrogram(signal, sample_rate, path):
    """Plot the spectrogram of ``signal`` and save the image to ``path``."""
    # Use a dedicated figure and close it afterwards: drawing on pyplot's
    # implicit current figure would stack every spectrogram on the same axes
    # and keep the figure alive for the rest of the process.
    fig = plt.figure()
    try:
        plt.specgram(signal, Fs=sample_rate)
        fig.savefig(path)
    finally:
        plt.close(fig)


def audio_bandwidth_extension(net):
    """Run the model over the input audio and save the predicted signal.

    For every entry of ``FILTERS_TEST`` the input file is wrapped in a
    ``SingleSong`` dataset, each item of the dataset is passed through
    ``net.predict``, and the predictions are concatenated along the last
    axis, clipped to ``[-1, 1 - eps]`` and written to ``args.savepath``.
    When ``args.vis`` is set, spectrogram images of the input and output
    signals are saved in the current directory as
    ``<input stem>_<filter name>_input_spec.png`` and
    ``<input stem>_<filter name>_output_spec.png``.

    Args:
        net: model object exposing ``predict(x)``; ``main`` passes an
            ``ailia.Net`` instance.
    """
    # (filter name, second parameter) pairs handed to SingleSong as-is;
    # presumably the second value is the filter order -- TODO confirm.
    FILTERS_TEST = [('cheby1', 6), ('butter', 6)]
    c_SAMPLE_RATE = 44100
    c_WAV_SAMPLE_LEN = 8192
    cutoff = 11025
    duration = None
    start = 0

    # args.input may be a list of paths or a single path string; indexing a
    # plain string with [0] would yield only its first character.
    inputs = args.input
    input_name = inputs[0] if isinstance(inputs, (list, tuple)) else inputs
    input_name_without_ext = os.path.splitext(os.path.basename(input_name))[0]
    hq_path = input_name

    # Loop-invariant values.
    spec_samples = c_SAMPLE_RATE * 5  # number of samples shown in the spectrograms
    clip_max = 1 - np.finfo(np.float32).eps

    for filter_ in FILTERS_TEST:
        logger.info('filter: {}, input_name: {}'.format(filter_, input_name))

        # create dataset
        song_data = SingleSong(
            c_WAV_SAMPLE_LEN,
            filter_,
            hq_path,
            cutoff=cutoff,
            duration=duration,
            start=start
        )

        y_full = song_data.preallocate()  # preallocation to keep individual output chunks
        idx_start_chunk = 0  # model works on chunks of audio, these are concatenated later
        for i in tqdm(range(len(song_data))):
            x, _ = song_data[i]
            x = x[np.newaxis, :, :]  # add a leading axis before prediction
            y = net.predict(x)
            idx_end_chunk = idx_start_chunk + y.shape[0]
            y_full[idx_start_chunk:idx_end_chunk] = y
            idx_start_chunk = idx_end_chunk

        y_full = np.concatenate(y_full, axis=-1)  # create full song out of chunks
        x_full, _ = song_data.get_full_signals()
        y_full = np.clip(y_full, -1, clip_max)

        # save audio
        # NOTE(review): every filter writes to the same args.savepath, so only
        # the output of the last filter in FILTERS_TEST remains on disk.
        wavfile.write(args.savepath, c_SAMPLE_RATE, y_full.T)

        # save spec
        if args.vis:
            # Column 0 of the transposed signals is plotted (presumably the
            # first audio channel -- confirm against SingleSong).
            _save_spectrogram(
                x_full.T[:spec_samples, 0],
                c_SAMPLE_RATE,
                '{}_{}_input_spec.png'.format(input_name_without_ext, filter_[0]),
            )
            _save_spectrogram(
                y_full.T[:spec_samples, 0],
                c_SAMPLE_RATE,
                '{}_{}_output_spec.png'.format(input_name_without_ext, filter_[0]),
            )

    logger.info('Script finished successfully.')
def main():
    """Pick the files for ``args.model``, load the network and run inference.

    Raises:
        ValueError: if ``args.model`` is not one of the supported model names.
    """
    # Model name -> (weight file, model description file).
    model_files = {
        'resnet': (WEIGHT_PATH_RESNET, MODEL_PATH_RESNET),
        'resnet_bn': (WEIGHT_PATH_RESNET_BN, MODEL_PATH_RESNET_BN),
        'resnet_da': (WEIGHT_PATH_RESNET_DA, MODEL_PATH_RESNET_DA),
        'resnet_do': (WEIGHT_PATH_RESNET_DO, MODEL_PATH_RESNET_DO),
        'unet': (WEIGHT_PATH_UNET, MODEL_PATH_UNET),
        'unet_bn': (WEIGHT_PATH_UNET_BN, MODEL_PATH_UNET_BN),
        'unet_da': (WEIGHT_PATH_UNET_DA, MODEL_PATH_UNET_DA),
        'unet_do': (WEIGHT_PATH_UNET_DO, MODEL_PATH_UNET_DO),
    }
    try:
        weight_path, model_path = model_files[args.model]
    except KeyError:
        # Fail with a clear message instead of an unbound-variable error
        # further down if an unexpected model name ever gets here.
        raise ValueError('unsupported model: {}'.format(args.model)) from None

    env_id = args.env_id

    # model files check and download
    check_and_download_models(weight_path, model_path, REMOTE_PATH)

    net = ailia.Net(model_path, weight_path, env_id=env_id)
    audio_bandwidth_extension(net)


if __name__ == "__main__":
    main()