|
|
import sys
import time
import math
from logging import getLogger

import numpy as np
import scipy
# `import scipy` alone does not guarantee that the io.wavfile submodule is
# importable as an attribute on all scipy versions; import it explicitly.
import scipy.io.wavfile
import librosa

import ailia

sys.path.append('../../util')
from arg_utils import get_base_parser, update_parser, get_savepath
from model_utils import check_and_download_models
|
|
|
|
|
logger = getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ONNX weights: text-query encoder and the ResUNet separation network
QUERY_WEIGHT_PATH = 'audiosep_text.onnx'
SEPNET_WEIGHT_PATH = 'audiosep_resunet.onnx'

# ailia prototxt descriptions paired with the weights above
QUERY_MODEL_PATH = 'audiosep_text.onnx.prototxt'
SEPNET_MODEL_PATH = 'audiosep_resunet.onnx.prototxt'

# remote location used by check_and_download_models for missing files
REMOTE_PATH = "https://storage.googleapis.com/ailia-models/audiosep/"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ======================
# Argument parser setup
# ======================
parser = get_base_parser(
    'audiosep', "input.wav", None
)

# text description of the sound source to extract from the mixture
parser.add_argument(
    "-p", "--prompt", metavar="TEXT", type=str,
    default="water drops",
    help="Text query."
)

# fall back to the transformers RoBERTa tokenizer instead of ailia_tokenizer
parser.add_argument(
    '--disable_ailia_tokenizer',
    action='store_true',
    help='disable ailia tokenizer.'
)

# check_input_type=False: the input is an audio file, not an image
args = update_parser(parser, check_input_type=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Functions below are taken from https://github.com/Audio-AGI/AudioSep, which was released under MIT license. |
|
|
Modified to be run with numpy arrays instead of torch tensors |
|
|
""" |
|
|
def preprocess_mag(mag):
    """Standardize and pad a magnitude spectrogram for the separation net.

    The values along each entry of the last axis (the frequency axis, as
    used by `inference`) are standardized to zero mean / unit variance,
    the time axis (axis 2) is zero-padded up to the next multiple of 2**5,
    and the final frequency bin is dropped.
    """
    # move the last axis to position 1 so the (2, 3) reduction normalizes
    # each of its entries over the remaining axes, then move it back
    swapped = np.transpose(mag, (0, 3, 2, 1))
    mean = np.mean(swapped, axis=(2, 3), keepdims=True)
    std = np.std(swapped, axis=(2, 3), keepdims=True)
    swapped = (swapped - mean) / (std + 1e-5)
    out = np.transpose(swapped, (0, 3, 2, 1))

    # pad time frames to a multiple of 32 (five stride-2 stages divide evenly)
    pad_len = math.ceil(out.shape[2] / 2 ** 5) * 2 ** 5 - out.shape[2]
    out = np.pad(out, ((0, 0), (0, 0), (0, pad_len), (0, 0)))

    # drop the last frequency bin (presumably the Nyquist bin of a 1025-bin
    # STFT -- confirm against the network's expected input width)
    return out[:, :, :, :out.shape[-1] - 1]
|
|
|
|
|
def spectrogram_phase(input, eps=0.):
    """Return the STFT magnitude and its unit phase (cos, sin) components.

    Args:
        input: waveform array passed straight to librosa.stft.
        eps: lower clip applied to the squared magnitude before the square
            root. NOTE(review): with the default eps=0. a silent frame
            yields mag == 0 and the divisions below emit a warning; callers
            here pass a positive eps.

    Returns:
        (mag, cos, sin) arrays of the STFT's shape.
    """
    stft_out = librosa.stft(
        input,
        n_fft=2048,
        hop_length=320,
        win_length=2048,
        window='hann',
        center=True,
        pad_mode='reflect'
    )
    re_part = np.real(stft_out)
    im_part = np.imag(stft_out)
    mag = np.clip(re_part * re_part + im_part * im_part, eps, np.inf) ** 0.5
    return mag, re_part / mag, im_part / mag
|
|
|
|
|
def wav_to_spectrogram(input, eps=1e-10):
    """Per-channel magnitude / cos-phase / sin-phase spectrograms.

    Args:
        input: waveform array with channels on axis 1, i.e.
            (batch_size, channels_num, samples).
        eps: magnitude floor forwarded to spectrogram_phase.

    Returns:
        (sps, coss, sins): the spectrogram_phase results of every channel,
        each concatenated along axis 1.
    """
    per_channel = [
        spectrogram_phase(input[:, ch, :], eps=eps)
        for ch in range(input.shape[1])
    ]
    mags, coss, sins = zip(*per_channel)
    return (
        np.concatenate(mags, axis=1),
        np.concatenate(coss, axis=1),
        np.concatenate(sins, axis=1),
    )
|
|
|
|
|
def sigmoid(x):
    """Numerically stable logistic sigmoid.

    Computed as exp(-logaddexp(0, -x)) == 1 / (1 + exp(-x)). The naive
    form overflows in np.exp(-x) for large negative x (RuntimeWarning,
    inf in an intermediate); logaddexp avoids that while returning the
    same values.
    """
    return np.exp(-np.logaddexp(0, -x))
|
|
|
|
|
def feature_maps_to_wav(
    input_tensor,
    sp,
    sin_in,
    cos_in,
    audio_length,
):
    """Turn the separation network's output maps back into a waveform.

    Args:
        input_tensor: (batch_size, channels, time_steps, freq_bins) network
            output holding three stacked maps per position: magnitude mask,
            real mask and imaginary mask.
        sp: mixture magnitude spectrogram (broadcast against the masks).
        sin_in: sine of the mixture STFT phase.
        cos_in: cosine of the mixture STFT phase.
        audio_length: target number of samples for the inverse STFT.

    Returns:
        1-D numpy array containing the separated waveform.
    """
    batch_size, _, time_steps, freq_bins = input_tensor.shape

    # split the stacked output into its three mask components
    x = input_tensor.reshape(
        batch_size,
        1,
        1,
        3,
        time_steps,
        freq_bins,
    )

    mask_mag = sigmoid(x[:, :, :, 0, :, :])
    _mask_real = np.tanh(x[:, :, :, 1, :, :])
    _mask_imag = np.tanh(x[:, :, :, 2, :, :])

    # normalize the complex mask to unit magnitude to extract its phase
    _, phase = librosa.magphase(_mask_real + 1j * _mask_imag)

    mask_cos = np.real(phase)
    mask_sin = np.imag(phase)

    # rotate the mixture phase by the mask phase (angle-sum identities)
    out_cos = (
        cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
    )
    out_sin = (
        sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
    )

    # Bug fix: this used np.max(..., 0), which REDUCES over axis 0 instead
    # of clamping elementwise. The reference implementation applies a ReLU
    # here, i.e. np.maximum(..., 0). With batch size 1 and non-negative
    # operands the old form happened to produce the same values, but it was
    # wrong in general.
    out_mag = np.maximum(sp[:, None, :, :, :] * mask_mag, 0)

    out_real = out_mag * out_cos
    out_imag = out_mag * out_sin

    shape = (
        batch_size,
        1,
        time_steps,
        freq_bins,
    )
    out_real = out_real.reshape(shape)
    out_imag = out_imag.reshape(shape)

    # inverse STFT back to the time domain; only the first batch/channel is
    # used, matching how inference() calls this with batch size 1
    x = librosa.istft(
        (out_real + 1j * out_imag)[0, 0].astype('complex64').transpose((1, 0)),
        n_fft=2048,
        hop_length=320,
        win_length=2048,
        window='hann',
        center=True,
        length=audio_length,
    )

    return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(model, input_text, input_wav):
    """Separate from input_wav the sound source described by input_text.

    Args:
        model: dict with 'tokenizer', 'querynet' and 'sepnet' entries.
        input_text: text prompt describing the target source.
        input_wav: waveform array shaped (batch, channels, samples).

    Returns:
        The separated waveform produced by feature_maps_to_wav.
    """
    # tokenize the prompt into (input_ids, attention_mask) for the query net
    tok = model['tokenizer']
    encoded = dict(tok(input_text, return_tensors='np', padding=True))
    query_input = (encoded['input_ids'], encoded['attention_mask'])

    # mixture spectrogram plus phase components
    sps, cos_in, sin_in = wav_to_spectrogram(input_wav)
    n_frames = sps.shape[-1]

    # reorder to (1, batch, time, freq) as the network and masks expect
    sps, cos_in, sin_in = (
        a.transpose((0, 2, 1))[None] for a in (sps, cos_in, sin_in)
    )

    sep_input = preprocess_mag(sps)

    # text embedding conditions the separation network
    text_embed = model['querynet'].predict(query_input)[0]
    net_out = model['sepnet'].predict((text_embed, sep_input))[0]

    # crop the time padding added by preprocess_mag
    net_out = net_out[:, :, :n_frames, :]

    return feature_maps_to_wav(net_out, sps, sin_in, cos_in, input_wav.shape[-1])
|
|
|
|
|
def split_audio(model):
    """Load the input audio, run text-queried separation, save a WAV file.

    Args:
        model: dict with 'querynet', 'sepnet' and 'tokenizer' entries,
            consumed by inference().
    """
    input_text = args.prompt
    # (1, 1, samples) mono waveform resampled to the 32 kHz the models expect
    input_wav = librosa.load(args.input[0], sr=32000, mono=True)[0][None, None, :]

    logger.info("input_text: %s" % input_text)

    logger.info('inference has started...')
    if args.benchmark:
        logger.info('BENCHMARK mode')
        total_time_estimation = 0
        for i in range(args.benchmark_count):
            start = int(round(time.time() * 1000))
            output = inference(model, input_text, input_wav)
            end = int(round(time.time() * 1000))
            estimation_time = (end - start)

            logger.info(f'\tailia processing estimation time {estimation_time} ms')
            # the first run is treated as warm-up and excluded from the average
            if i != 0:
                total_time_estimation = total_time_estimation + estimation_time

        # bug fix: guard against ZeroDivisionError when --benchmark_count is 1
        if args.benchmark_count > 1:
            logger.info(f'\taverage time estimation {total_time_estimation / (args.benchmark_count - 1)} ms')
    else:
        output = inference(model, input_text, input_wav)

    sp = 'output.wav' if args.savepath is None else args.savepath
    # clip before scaling so out-of-range samples saturate instead of
    # wrapping around when cast to int16
    pcm = np.round(np.clip(output, -1.0, 1.0) * 32767).astype(np.int16)
    scipy.io.wavfile.write(sp, 32000, pcm)

    logger.info(f"Separated audio has been saved to {sp}")

    logger.info('Script finished successfully.')
|
|
|
|
|
|
|
|
def main():
    """Download model files if needed, build the networks, run separation."""
    # fetch weights/prototxt from REMOTE_PATH when not present locally
    check_and_download_models(QUERY_WEIGHT_PATH, QUERY_MODEL_PATH, REMOTE_PATH)
    check_and_download_models(SEPNET_WEIGHT_PATH, SEPNET_MODEL_PATH, REMOTE_PATH)

    env_id = args.env_id

    # bug fix: sepnet was created without env_id, so it silently ignored the
    # user-selected execution environment while querynet honored it
    querynet = ailia.Net(None, QUERY_WEIGHT_PATH, env_id=env_id)
    sepnet = ailia.Net(None, SEPNET_WEIGHT_PATH, env_id=env_id)

    if args.disable_ailia_tokenizer:
        from transformers import RobertaTokenizer
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    else:
        import ailia_tokenizer
        tokenizer = ailia_tokenizer.RobertaTokenizer.from_pretrained('./tokenizer/')

    model = {
        'querynet': querynet,
        'sepnet': sepnet,
        'tokenizer': tokenizer,
    }

    split_audio(model)
|
|
|
|
|
# script entry point
if __name__ == '__main__':
    main()
|
|
|