|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import random |
|
|
|
|
|
import numpy as np |
|
|
import librosa |
|
|
|
|
|
import ailia |
|
|
|
|
|
|
|
|
sys.path.append('../../util') |
|
|
from arg_utils import get_base_parser, update_parser |
|
|
from model_utils import check_and_download_models |
|
|
|
|
|
|
|
|
from logging import getLogger |
|
|
|
|
|
logger = getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Identification model: extracts a speaker embedding from spectrogram windows.
WEIGHT_IDENT_PATH = './proposed_iden.onnx'
MODEL_IDENT_PATH = './proposed_iden.onnx.prototxt'
# Classifier head: maps an embedding to VoxCeleb1 speaker scores.
WEIGHT_CLASSIFIER_PATH = './proposed_classifier.onnx'
MODEL_CLASSIFIER_PATH = './proposed_classifier.onnx.prototxt'
# Verification model: embeddings compared by cosine similarity.
WEIGHT_VERI_PATH = './proposed_veri.onnx'
MODEL_VERI_PATH = './proposed_veri.onnx.prototxt'
# Remote storage the model files are downloaded from when missing locally.
REMOTE_PATH = \
    'https://storage.googleapis.com/ailia-models/auto_speech/'

# Default input wav used when none is given on the command line.
WAVE_PATH = "wav/id10283/oGZsanLiXsY/00004.wav"

# Sampling rate the models expect; inputs are resampled to this.
SAMPLING_RATE = 16000

# STFT parameters: window length / hop in milliseconds, FFT size in samples.
WINDOW_LENGTH = 25
WINDOW_STEP = 10
N_FFT = 512

# Target loudness (dBFS) used to level-normalize the waveform before STFT.
AUDIO_NORM_TARGET_dBFS = -30

# Default cosine-similarity threshold for the verification decision.
THRESHOLD = 0.26

# Full-scale value of a signed 16-bit sample, used for dBFS computation.
INT16_MAX = (2 ** 15) - 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Command-line interface. The base parser supplies the common ailia options
# (--input, --benchmark, --env_id, ...); the extra flags below enable the
# two-file verification mode and its decision threshold.
parser = get_base_parser(
    'AutoSpeech', WAVE_PATH, None, input_ftype='audio'
)
parser.add_argument(
    '-i1', '--input1', metavar='WAV', default=None,
    help='Specify an wav file to compare with the input2 wav. (verification mode)'
)
parser.add_argument(
    '-i2', '--input2', metavar='WAV', default=None,
    help='Specify an wav file to compare with the input1 wav. (verification mode)'
)
parser.add_argument(
    '-th', '--threshold',
    default=THRESHOLD, type=float,
    help='The similar threshold for verification.'
)
# Parsed arguments, shared module-wide by the eval functions below.
args = update_parser(parser)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_wave(path):
    """Load an audio file as a mono waveform at SAMPLING_RATE.

    Parameters
    ----------
    path : str
        Path of the audio file to load.

    Returns
    -------
    np.ndarray
        1-D float waveform, resampled to SAMPLING_RATE (16 kHz) when the
        source rate differs.
    """
    # sr=None preserves the file's native sampling rate.
    wav, source_sr = librosa.load(path, sr=None)

    # Resample only when needed. Keyword arguments are required: the
    # positional (y, orig_sr, target_sr) form was removed in librosa 0.10,
    # while orig_sr/target_sr keywords also work on older releases.
    if source_sr is not None and source_sr != SAMPLING_RATE:
        wav = librosa.resample(wav, orig_sr=source_sr, target_sr=SAMPLING_RATE)

    return wav
|
|
|
|
|
|
|
|
def voxceleb1_ids():
    """Return the VoxCeleb1 speaker IDs, one per line of the bundled list file."""
    with open("VoxCeleb1_ids.txt") as id_file:
        return [line.strip() for line in id_file]
|
|
|
|
|
|
|
|
def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
    """Scale *wav* so that its RMS level matches *target_dBFS*.

    increase_only / decrease_only restrict the direction of the adjustment;
    enabling both at once raises ValueError.
    """
    if increase_only and decrease_only:
        raise ValueError("Both increase only and decrease only are set")
    # Current signal level in dBFS, relative to int16 full scale.
    rms = np.sqrt(np.mean((wav * INT16_MAX) ** 2))
    current_dBFS = 20 * np.log10(rms / INT16_MAX)
    gain_dB = target_dBFS - current_dBFS
    # Leave the signal untouched when the change goes against the
    # requested direction.
    if (gain_dB < 0 and increase_only) or (gain_dB > 0 and decrease_only):
        return wav
    return wav * (10 ** (gain_dB / 20))
|
|
|
|
|
|
|
|
def wav_to_spectrogram(wav):
    """Return the magnitude STFT of *wav*, shaped (frames, freq_bins), float32."""
    hop = int(SAMPLING_RATE * WINDOW_STEP / 1000)    # 10 ms hop
    win = int(SAMPLING_RATE * WINDOW_LENGTH / 1000)  # 25 ms window
    spectrum = librosa.core.stft(
        wav, n_fft=N_FFT, hop_length=hop, win_length=win,
    )
    # Transpose so that time is the leading axis.
    return np.abs(spectrum).astype(np.float32).T
|
|
|
|
|
|
|
|
def generate_sequence(feature, partial_n_frames, shift=None):
    """Slice *feature* (frames, bins) into overlapping fixed-size windows.

    Each window spans partial_n_frames rows; consecutive windows start
    *shift* rows apart (default: half a window). Inputs shorter than one
    window are lengthened by repeating every row until a window fits.
    """
    # Guarantee at least one complete window.
    while feature.shape[0] <= partial_n_frames:
        feature = np.repeat(feature, 2, axis=0)
    if shift is None:
        shift = partial_n_frames // 2

    starts = range(0, feature.shape[0] - partial_n_frames + 1, shift)
    windows = [feature[s: s + partial_n_frames] for s in starts]
    return np.stack(windows, axis=0)
|
|
|
|
|
|
|
|
def cosine_similar(a, b, data_is_normalized=False):
    """Pairwise cosine similarity between the row vectors of *a* and *b*.

    With data_is_normalized=True the rows are assumed unit-length and the
    plain dot product is returned.
    """
    if not data_is_normalized:
        norm = np.linalg.norm
        a = np.asarray(a) / norm(a, axis=1, keepdims=True)
        b = np.asarray(b) / norm(b, axis=1, keepdims=True)
    return np.dot(a, b.T)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess(wav):
    """Turn a waveform into a standardized batch of spectrogram windows."""
    # Level-normalize, then compute the magnitude spectrogram.
    leveled = normalize_volume(wav, AUDIO_NORM_TARGET_dBFS, increase_only=True)
    feature = wav_to_spectrogram(leveled)

    # Fixed 300-frame windows, as expected by the network input.
    sequence = generate_sequence(feature, partial_n_frames=300)

    # Standardize with the dataset statistics shipped next to the model.
    mean = np.load('mean.npy')
    std = np.load('std.npy')
    return (sequence - mean) / std
|
|
|
|
|
|
|
|
def predict(wav, net, net_classifier=None):
    """Run the embedding network on *wav*.

    Returns the utterance embedding (mean over windows), or — when
    *net_classifier* is supplied — the speaker indices sorted by
    descending classifier score.
    """
    sequence = preprocess(wav)

    # Embed every window, then average into one utterance-level embedding.
    embeddings = net.predict([sequence])[0]
    embedding = np.mean(embeddings, axis=0, keepdims=True)

    if not net_classifier:
        return embedding

    # Classification head: rank speakers by score, best first.
    scores = net_classifier.predict([embedding])[0]
    return np.argsort(scores[0])[::-1]
|
|
|
|
|
|
|
|
def eval_identification(net, net_classifier):
    """Identification mode: log the top-5 VoxCeleb1 speakers for each input."""
    ids = voxceleb1_ids()

    for input_path in args.input:
        logger.info(f'input: {input_path}')

        wav = read_wave(input_path)

        logger.info('Start inference...')
        if args.benchmark:
            # Repeat inference five times and report each wall-clock time.
            logger.info('BENCHMARK mode')
            for _ in range(5):
                start = int(round(time.time() * 1000))
                idx = predict(wav, net, net_classifier)
                end = int(round(time.time() * 1000))
                logger.info(f'\tailia processing time {end - start} ms')
        else:
            idx = predict(wav, net, net_classifier)

        logger.info(' Top5: %s' % ', '.join(ids[i] for i in idx[:5]))

    logger.info('Script finished successfully.')
|
|
|
|
|
|
|
|
def eval_verification(net):
    """Verification mode: decide whether input1 and input2 share a speaker.

    Both inputs are embedded with *net* and compared by cosine similarity
    against args.threshold.
    """
    threshold = args.threshold
    input1 = args.input1
    input2 = args.input2

    # Guard clauses: both inputs must be specified and be existing files.
    if input1 is None:
        logger.error('input1 is not specified')
        sys.exit(-1)
    if not os.path.isfile(input1):
        logger.error('specified input1 is not file path nor directory path')
        sys.exit(-1)
    if input2 is None:
        logger.error('input2 is not specified')
        sys.exit(-1)
    if not os.path.isfile(input2):
        logger.error('specified input2 is not file path nor directory path')
        sys.exit(-1)

    logger.info(f'input1: {input1}')
    logger.info(f'input2: {input2}')

    wav1 = read_wave(input1)
    wav2 = read_wave(input2)

    logger.info('Start inference...')
    if args.benchmark:
        # Repeat the embedding pass five times and report each duration.
        logger.info('BENCHMARK mode')
        for _ in range(5):
            start = int(round(time.time() * 1000))
            output = predict(wav1, net)
            output2 = predict(wav2, net)
            end = int(round(time.time() * 1000))
            logger.info(f'\tailia processing time {end - start} ms')
    else:
        output = predict(wav1, net)
        output2 = predict(wav2, net)

    similar = cosine_similar(output, output2)
    logger.info(' similar: %.8f' % similar[0])
    logger.info(' verification: %s (threshold: %.3f)' %
                ('match' if similar[0] >= threshold else 'unmatch', threshold))

    logger.info('Script finished successfully.')
|
|
|
|
|
|
|
|
def main():
    """Entry point: fetch the required models, then run the selected mode.

    Verification mode is chosen when either --input1 or --input2 is given;
    otherwise identification mode runs on args.input.
    """
    verification = bool(args.input1 or args.input2)

    # Download model files that are not present locally.
    if verification:
        check_and_download_models(WEIGHT_VERI_PATH, MODEL_VERI_PATH, REMOTE_PATH)
    else:
        logger.info('Checking identification model...')
        check_and_download_models(WEIGHT_IDENT_PATH, MODEL_IDENT_PATH, REMOTE_PATH)
        logger.info('Checking classification model...')
        check_and_download_models(WEIGHT_CLASSIFIER_PATH, MODEL_CLASSIFIER_PATH, REMOTE_PATH)

    env_id = args.env_id

    if verification:
        net = ailia.Net(MODEL_VERI_PATH, WEIGHT_VERI_PATH, env_id=env_id)
        eval_verification(net)
    else:
        net = ailia.Net(MODEL_IDENT_PATH, WEIGHT_IDENT_PATH, env_id=env_id)
        net_classifier = ailia.Net(MODEL_CLASSIFIER_PATH, WEIGHT_CLASSIFIER_PATH, env_id=env_id)
        eval_identification(net, net_classifier)


if __name__ == '__main__':
    main()
|
|
|