|
|
import sys |
|
|
import time |
|
|
from logging import getLogger |
|
|
import json |
|
|
|
|
|
import random |
|
|
|
|
|
import librosa |
|
|
import numpy as np |
|
|
|
|
|
import ailia |
|
|
|
|
|
|
|
|
sys.path.append('../../util') |
|
|
from arg_utils import get_base_parser, update_parser, get_savepath |
|
|
from model_utils import check_and_download_models |
|
|
|
|
|
logger = getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2023 MS-CLAP model files (text encoder tokenized with GPT-2 — see main()).
CAPTION_WEIGHT_PATH_2023 = 'msclap_2023_caption.onnx'
AUDIO_WEIGHT_PATH_2023 = 'msclap_2023_audio.onnx'
CAPTION_MODEL_PATH_2023 = 'msclap_2023_caption.onnx.prototxt'
AUDIO_MODEL_PATH_2023 = 'msclap_2023_audio.onnx.prototxt'

# 2022 MS-CLAP model files (text encoder tokenized with BERT — see main()).
CAPTION_WEIGHT_PATH_2022 = 'msclap_2022_caption.onnx'
AUDIO_WEIGHT_PATH_2022 = 'msclap_2022_audio.onnx'
CAPTION_MODEL_PATH_2022 = 'msclap_2022_caption.onnx.prototxt'
AUDIO_MODEL_PATH_2022 = 'msclap_2022_audio.onnx.prototxt'

# Remote location the weight/prototxt files are downloaded from.
REMOTE_PATH = "https://storage.googleapis.com/ailia-models/msclap/"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Command-line interface. `args` is consumed module-wide (estimate_best_caption,
# main), so it is built once at import time, matching the other ailia demos.
parser = get_base_parser(
    'msclap', None, None
)
parser.add_argument(
    "-a", "--audio", type=str,
    default="input.wav",
    help="Input audio file path."
)
parser.add_argument(
    "-t", "--text", type=str,
    default="captions.txt",
    help="Input text caption file path"
)
parser.add_argument(
    "-v", "--version", type=str,
    default="2023",
    # Reject unknown versions at parse time; previously an invalid value fell
    # through main()'s if/elif and crashed later with a NameError.
    choices=["2022", "2023"],
    help="Version of the CLAP model (2022 or 2023)."
)
parser.add_argument(
    '-w', '--write_json',
    action='store_true',
    help='Flag to output results to json file.'
)
parser.add_argument(
    '--disable_ailia_tokenizer',
    action='store_true',
    help='disable ailia tokenizer.'
)
args = update_parser(parser, check_input_type=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_audio(audio_path):
    r"""Load an audio file and return (samples, native sample rate).

    The file is loaded at its original sampling rate (``sr=None``);
    resampling is handled separately by ``resample_audio``.
    """
    samples, native_rate = librosa.load(audio_path, sr=None)
    return samples, native_rate
|
|
|
|
|
def resample_audio(audio_time_series, sample_rate, resample_rate):
    r"""Resample audio to ``resample_rate`` if it is not already at that rate.

    Parameters
    ----------
    audio_time_series : np.ndarray
        Raw audio samples.
    sample_rate : int
        Sampling rate of ``audio_time_series``.
    resample_rate : int
        Target sampling rate.

    Returns
    -------
    tuple of (np.ndarray, int)
        The (possibly resampled) audio and the rate it is now at.
    """
    # BUG FIX: the original body overwrote `resample_rate` with a hard-coded
    # 44100, silently ignoring the caller-supplied target rate.
    if resample_rate != sample_rate:
        audio_time_series = librosa.resample(
            audio_time_series,
            orig_sr=sample_rate,
            target_sr=resample_rate,
            res_type='sinc_best'
        )
    return audio_time_series, resample_rate
|
|
|
|
|
|
|
|
def resize_audio(audio_time_series, sample_rate, audio_duration, resample=False):
    r"""Force audio to exactly ``audio_duration * sample_rate`` samples.

    Short clips are tiled until they cover the target length and then
    truncated; long clips are randomly cropped to a window of the target
    length. (``resample`` is accepted for interface compatibility but unused.)
    """
    flat = audio_time_series.reshape(-1)
    target_len = audio_duration * sample_rate
    n_samples = flat.shape[0]

    if target_len >= n_samples:
        # Too short: repeat the clip enough times to cover target_len,
        # then cut it to size.
        reps = int(np.ceil(target_len / n_samples))
        flat = np.tile(flat, reps)[0:target_len]
    else:
        # Too long: pick a random window of exactly target_len samples.
        offset = random.randrange(n_samples - target_len)
        flat = flat[offset:offset + target_len]
    return flat
|
|
|
|
|
def get_audio_embeddings(wav_input, sample_rate, model, version="2023"):
    """Embed a waveform with the CLAP audio encoder.

    Both released versions preprocess identically: resample to 44.1 kHz,
    tile/crop to 7 seconds, then add a leading batch axis.
    """
    if version in ('2022', '2023'):
        resampled, _ = resample_audio(wav_input, sample_rate, 44100)
        wav_input = resize_audio(resampled, 44100, 7)[None]
    return model['audio_model'].predict(wav_input)
|
|
|
|
|
def get_caption_embeddings(text_input, model, version="2023"):
    """Embed a list of caption strings with the CLAP text encoder.

    Returns the first output of the caption model (the text embeddings).
    """
    captions = text_input
    # The 2023 text model expects an explicit end-of-text marker per caption.
    if version == '2023':
        captions = [f'{c} <|endoftext|>' for c in captions]

    tokenized = model['tokenizer'](captions, padding=True, return_tensors='np')
    tokenized = dict(tokenized)

    feeds = (tokenized['input_ids'], tokenized['attention_mask'])
    return model['caption_model'].predict(feeds)[0]
|
|
|
|
|
def cossim(v1, v2):
    """Cosine similarity between v1 and v2 along the last axis."""
    dot = np.sum(v1 * v2, axis=-1)
    norm1 = np.sqrt(np.sum(v1 ** 2, axis=-1))
    norm2 = np.sqrt(np.sum(v2 ** 2, axis=-1))
    return dot / (norm1 * norm2)
|
|
|
|
|
def print_sorted_dict(d):
    """Print caption->score pairs, highest score first, right-aligned keys."""
    widest = max(len(key) for key in d.keys())
    ranked = sorted(d.items(), key=lambda item: item[1], reverse=True)
    for key, score in ranked:
        # Pad so every colon lines up, with a 4-space margin.
        indent = ' ' * (widest - len(key) + 4)
        print(f'{indent + key}: {score}')
|
|
|
|
|
def save_sorted_dict_as_json(d, save_path='output.json'):
    """Write caption->similarity pairs to a JSON file, highest score first.

    Parameters
    ----------
    d : dict
        Maps caption (str) to similarity score (float-convertible).
    save_path : str
        Output path; defaults to 'output.json' to preserve the original
        behavior.
    """
    # NOTE: the original computed a key-width `m_len` here (copied from the
    # print helper) but never used it; removed.
    result = [
        {"caption": caption, "similarity": float(score)}
        for caption, score in sorted(d.items(), key=lambda x: x[1], reverse=True)
    ]
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(model, input_text, input_wav, sample_rate, version):
    """Return the cosine similarity between the audio and each caption."""
    audio_emb = get_audio_embeddings(input_wav, sample_rate, model, version)
    text_emb = get_caption_embeddings(input_text, model, version)
    return cossim(audio_emb, text_emb)
|
|
|
|
|
def estimate_best_caption(model):
    """Rank the candidate captions by similarity to the input audio.

    Reads one caption per line from ``args.text`` and the audio from
    ``args.audio``, computes audio/text CLAP embeddings and their cosine
    similarity, prints the ranked result, and optionally writes it to
    output.json. In benchmark mode the inference is timed over
    ``args.benchmark_count`` runs.

    Parameters
    ----------
    model : dict
        Holds 'caption_model', 'audio_model' and 'tokenizer' (built in main()).
    """
    # One candidate caption per line.
    with open(args.text, 'r') as f:
        input_text = f.read().splitlines()

    input_wav, sample_rate = read_audio(args.audio)
    # Add a leading batch dimension.
    input_wav = input_wav[None]

    logger.info("input_text: %s" % input_text)

    logger.info('inference has started...')
    if args.benchmark:
        logger.info('BENCHMARK mode')
        total_time_estimation = 0
        for i in range(args.benchmark_count):
            start = int(round(time.time() * 1000))
            output = inference(model, input_text, input_wav, sample_rate, args.version)
            end = int(round(time.time() * 1000))
            estimation_time = (end - start)

            logger.info(f'\tailia processing estimation time {estimation_time} ms')
            # The first run is treated as warm-up and excluded from the average.
            if i != 0:
                total_time_estimation = total_time_estimation + estimation_time

        logger.info(f'\taverage time estimation {total_time_estimation / (args.benchmark_count - 1)} ms')
    else:
        output = inference(model, input_text, input_wav, sample_rate, args.version)

    print(f"Similarity: ")
    print_sorted_dict(dict(zip(input_text, output)))

    if args.write_json:
        save_sorted_dict_as_json(dict(zip(input_text, output)))

    logger.info('Script finished successfully.')
|
|
|
|
|
|
|
|
def main():
    """Entry point: download model files, build runtime and tokenizer, rank captions.

    All configuration comes from the module-level ``args``.
    """
    # Fetch the ONNX weights and prototxt files for the selected version.
    if args.version == '2023':
        check_and_download_models(
            CAPTION_WEIGHT_PATH_2023,
            CAPTION_MODEL_PATH_2023,
            REMOTE_PATH
        )
        check_and_download_models(
            AUDIO_WEIGHT_PATH_2023,
            AUDIO_MODEL_PATH_2023,
            REMOTE_PATH
        )
    elif args.version == '2022':
        check_and_download_models(
            CAPTION_WEIGHT_PATH_2022,
            CAPTION_MODEL_PATH_2022,
            REMOTE_PATH
        )
        check_and_download_models(
            AUDIO_WEIGHT_PATH_2022,
            AUDIO_MODEL_PATH_2022,
            REMOTE_PATH
        )
    else:
        # Previously an unknown version fell through silently and crashed
        # later with a NameError on caption_model; fail fast instead.
        raise ValueError(
            f"Unsupported CLAP version: {args.version!r} (expected '2022' or '2023')"
        )

    env_id = args.env_id

    # The model does not work on FP16 back ends, so fall back to the CPU
    # environment there and on macOS.
    # BUG FIX: sys.platform is the lowercase 'darwin' on macOS; the original
    # compared against 'Darwin' (the platform.system() spelling), so the
    # macOS fallback never triggered.
    if "FP16" in ailia.get_environment(args.env_id).props or sys.platform == 'darwin':
        logger.warning('This model do not work on FP16. So use CPU mode.')
        env_id = 0

    if args.version == '2023':
        caption_model = ailia.Net(CAPTION_MODEL_PATH_2023, CAPTION_WEIGHT_PATH_2023, env_id=env_id)
        audio_model = ailia.Net(AUDIO_MODEL_PATH_2023, AUDIO_WEIGHT_PATH_2023, env_id=env_id)
        if args.disable_ailia_tokenizer:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained('gpt2')
            # Define a pad token so batched captions can be padded
            # (get_caption_embeddings tokenizes with padding=True).
            tokenizer.add_special_tokens({'pad_token': '!'})
        else:
            from ailia_tokenizer import GPT2Tokenizer
            tokenizer = GPT2Tokenizer.from_pretrained('./tokenizer_gpt2/')
            tokenizer.add_special_tokens({'pad_token': '!'})
    else:  # args.version == '2022' (anything else raised above)
        caption_model = ailia.Net(CAPTION_MODEL_PATH_2022, CAPTION_WEIGHT_PATH_2022, env_id=env_id)
        audio_model = ailia.Net(AUDIO_MODEL_PATH_2022, AUDIO_WEIGHT_PATH_2022, env_id=env_id)
        if args.disable_ailia_tokenizer:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        else:
            from ailia_tokenizer import BertTokenizer
            tokenizer = BertTokenizer.from_pretrained('./tokenizer_bert/')

    model = {
        'caption_model': caption_model,
        'audio_model': audio_model,
        'tokenizer': tokenizer
    }

    estimate_best_caption(model)
|
|
|
|
|
# Standard script entry point.
if __name__ == '__main__':
    main()