niobures's picture
CLAP (code, models, paper)
f8ffad7 verified
import sys
import time
from logging import getLogger
import json
import random
import librosa
import numpy as np
import ailia
# import original modules
sys.path.append('../../util')
from arg_utils import get_base_parser, update_parser, get_savepath # noqa
from model_utils import check_and_download_models # noqa
logger = getLogger(__name__)
# ======================
# Parameters
# ======================
# File names for the 2023 model: the *_WEIGHT_PATH entries are the .onnx
# files and the *_MODEL_PATH entries are the matching .onnx.prototxt files.
# Each (weight, model) pair is passed to check_and_download_models() and
# to ailia.Net() in main().
CAPTION_WEIGHT_PATH_2023 = 'msclap_2023_caption.onnx'
AUDIO_WEIGHT_PATH_2023 = 'msclap_2023_audio.onnx'
CAPTION_MODEL_PATH_2023 = 'msclap_2023_caption.onnx.prototxt'
AUDIO_MODEL_PATH_2023 = 'msclap_2023_audio.onnx.prototxt'
# Same layout for the 2022 model.
CAPTION_WEIGHT_PATH_2022 = 'msclap_2022_caption.onnx'
AUDIO_WEIGHT_PATH_2022 = 'msclap_2022_audio.onnx'
CAPTION_MODEL_PATH_2022 = 'msclap_2022_caption.onnx.prototxt'
AUDIO_MODEL_PATH_2022 = 'msclap_2022_audio.onnx.prototxt'
# Base URL handed to check_and_download_models() as the remote location.
REMOTE_PATH = "https://storage.googleapis.com/ailia-models/msclap/"
# ======================
# Argument Parser Config
# ======================
# Build the command-line interface on top of the shared base parser.
parser = get_base_parser(
    'msclap', None, None
)
parser.add_argument(
    "-a", "--audio", type=str,
    default="input.wav",
    help="Input audio file path."
)
parser.add_argument(
    "-t", "--text", type=str,
    default="captions.txt",
    help="Input text caption file path"
)
# Restrict the version to the two supported releases so that a typo is
# rejected by the parser instead of failing later during model setup.
parser.add_argument(
    "-v", "--version", type=str,
    default="2023",
    choices=("2022", "2023"),
    help="Version of the CLAP model (2022 or 2023)."
)
parser.add_argument(
    '-w', '--write_json',
    action='store_true',
    help='Flag to output results to json file.'
)
parser.add_argument(
    '--disable_ailia_tokenizer',
    action='store_true',
    help='disable ailia tokenizer.'
)
args = update_parser(parser, check_input_type=False)
# ======================
# Helper functions
# ======================
def read_audio(audio_path):
    """Decode an audio file with librosa without forcing a sampling rate.

    Args:
        audio_path: Path of the audio file to load.

    Returns:
        The ``(audio_time_series, sample_rate)`` pair produced by
        ``librosa.load(audio_path, sr=None)``.
    """
    # sr=None is forwarded unchanged; any resampling is done separately
    # by resample_audio().
    return librosa.load(audio_path, sr=None)
def resample_audio(audio_time_series, sample_rate, resample_rate):
    """Resample an audio signal to ``resample_rate``.

    The requested ``resample_rate`` is honoured as passed; it is no longer
    overwritten with a hard-coded 44100 inside the function.

    Args:
        audio_time_series: Audio samples as accepted by ``librosa.resample``.
        sample_rate: Sampling rate of ``audio_time_series``.
        resample_rate: Target sampling rate.

    Returns:
        A ``(audio_time_series, resample_rate)`` tuple. When the two rates
        already match, the input object is returned unchanged.
    """
    # Nothing to do when the signal is already at the target rate.
    if resample_rate == sample_rate:
        return audio_time_series, resample_rate

    resampled = librosa.resample(
        audio_time_series,
        orig_sr=sample_rate,
        target_sr=resample_rate,
        res_type='sinc_best',
    )
    return resampled, resample_rate
def resize_audio(audio_time_series, sample_rate, audio_duration, resample=False):
    """Pad or crop an audio signal to exactly ``audio_duration`` seconds.

    The signal is flattened first. If it is not longer than the target
    length, it is repeated end-to-end and cut to the target length;
    otherwise a randomly positioned window of the target length is cut out.

    Args:
        audio_time_series: Array of audio samples (flattened internally).
        sample_rate: Sampling rate of the samples.
        audio_duration: Desired duration in seconds. Fractional values are
            accepted; the target length is truncated to a whole number of
            samples.
        resample: Unused; kept so existing callers keep working.

    Returns:
        A 1-D array holding ``int(audio_duration * sample_rate)`` samples.

    Raises:
        ValueError: If ``audio_time_series`` contains no samples.
    """
    samples = audio_time_series.reshape(-1)
    total = samples.shape[0]
    if total == 0:
        raise ValueError("cannot resize an empty audio signal")

    # Slice bounds must be integers, so settle the target length once.
    target_len = int(audio_duration * sample_rate)

    if target_len >= total:
        # Signal is too short (or exactly long enough): tile it until the
        # target length is covered, then drop the excess samples.
        repeat_factor = int(np.ceil(target_len / total))
        return np.tile(samples, repeat_factor)[:target_len]

    # Signal is longer than requested: keep a random window of it.
    start_index = random.randrange(total - target_len)
    return samples[start_index:start_index + target_len]
def get_audio_embeddings(wav_input, sample_rate, model, version="2023"):
    """Run the audio model on a waveform and return its prediction.

    For versions '2023' and '2022' the waveform is first resampled to
    44100 Hz, resized to 7 seconds and given a leading batch axis. Any
    other version value passes ``wav_input`` to the model untouched.

    Args:
        wav_input: Audio samples.
        sample_rate: Sampling rate of ``wav_input``.
        model: Dict whose 'audio_model' entry exposes ``predict``.
        version: Model version string.

    Returns:
        Whatever ``model['audio_model'].predict`` returns.
    """
    target_rate = 44100
    clip_seconds = 7

    if version == '2023' or version == '2022':
        resampled, _ = resample_audio(wav_input, sample_rate, target_rate)
        clip = resize_audio(resampled, target_rate, clip_seconds)
        wav_input = clip[None]

    audio_net = model['audio_model']
    return audio_net.predict(wav_input)
def get_caption_embeddings(text_input, model, version="2023"):
    """Tokenize captions and return the first output of the caption model.

    Args:
        text_input: Iterable of caption strings.
        model: Dict with a callable 'tokenizer' and a 'caption_model'
            exposing ``predict``.
        version: Model version string; for '2023' every caption gets
            ' <|endoftext|>' appended before tokenization.

    Returns:
        Element 0 of ``model['caption_model'].predict(...)``.
    """
    # Pre-processing: only the 2023 model gets the end-of-text marker.
    captions = text_input
    if version == '2023':
        captions = [caption + ' <|endoftext|>' for caption in text_input]

    tokenize = model['tokenizer']
    encoded = dict(tokenize(captions, padding=True, return_tensors='np'))

    # Inference: the net takes the token ids and the attention mask.
    net_inputs = (encoded['input_ids'], encoded['attention_mask'])
    outputs = model['caption_model'].predict(net_inputs)
    return outputs[0]
def cossim(v1, v2):
    """Return the cosine similarity of ``v1`` and ``v2`` along the last axis.

    The inputs are combined with normal NumPy broadcasting, so one vector
    can be compared against a batch of vectors.
    """
    dot_product = np.sum(v1 * v2, axis=-1)
    norm_v1 = np.sum(v1 ** 2, axis=-1) ** 0.5
    norm_v2 = np.sum(v2 ** 2, axis=-1) ** 0.5
    return dot_product / (norm_v1 * norm_v2)
def print_sorted_dict(d):
    """Print ``key: value`` lines ordered by value, largest first.

    Keys are right-aligned to the longest key plus four leading spaces.
    An empty dict prints nothing instead of raising from ``max()``.

    Args:
        d: Mapping of string keys to sortable values.
    """
    if not d:
        return

    width = max(len(k) for k in d) + 4
    for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True):
        print(f'{k.rjust(width)}: {v}')
def save_sorted_dict_as_json(d, save_path='output.json'):
    """Write caption similarities to a JSON file, highest similarity first.

    The file holds a list of ``{"caption": ..., "similarity": ...}``
    objects. An empty dict produces an empty JSON list.

    Args:
        d: Mapping of caption text to similarity score.
        save_path: Destination file; defaults to 'output.json' in the
            current working directory, as before.
    """
    ranked = sorted(d.items(), key=lambda item: item[1], reverse=True)
    # float() converts NumPy scalars into something json can serialize.
    result = [
        {"caption": caption, "similarity": float(score)}
        for caption, score in ranked
    ]
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2)
# ======================
# Main functions
# ======================
def inference(model, input_text, input_wav, sample_rate, version):
    """Score every caption against the audio clip.

    Args:
        model: Dict holding the audio model, caption model and tokenizer.
        input_text: List of caption strings.
        input_wav: Audio samples.
        sample_rate: Sampling rate of ``input_wav``.
        version: Model version string forwarded to both embedding helpers.

    Returns:
        The result of ``cossim`` applied to the audio and caption embeddings.
    """
    # The audio embedding is computed first, then the caption embeddings.
    audio_vec = get_audio_embeddings(input_wav, sample_rate, model, version)
    caption_vecs = get_caption_embeddings(input_text, model, version)
    similarity = cossim(audio_vec, caption_vecs)
    return similarity
def estimate_best_caption(model):
    """Rank the captions in ``args.text`` against the audio in ``args.audio``.

    Reads one caption per line, runs inference (optionally several times
    in benchmark mode), prints the captions sorted by similarity and, when
    ``args.write_json`` is set, also saves them as JSON.

    Args:
        model: Dict holding the audio model, caption model and tokenizer.
    """
    # Load inputs. The encoding is explicit so that non-ASCII captions do
    # not depend on the platform default.
    with open(args.text, 'r', encoding='utf-8') as f:
        input_text = f.read().splitlines()

    input_wav, sample_rate = read_audio(args.audio)
    input_wav = input_wav[None]
    logger.info("input_text: %s", input_text)

    # inference
    logger.info('inference has started...')
    if args.benchmark:
        logger.info('BENCHMARK mode')
        total_time_estimation = 0
        timed_runs = 0
        # Run at least once so that `output` is always defined.
        for i in range(max(args.benchmark_count, 1)):
            start = int(round(time.time() * 1000))
            output = inference(model, input_text, input_wav, sample_rate, args.version)
            end = int(round(time.time() * 1000))
            estimation_time = (end - start)

            # Logging
            logger.info(f'\tailia processing estimation time {estimation_time} ms')
            # The first iteration is excluded from the average.
            if i != 0:
                total_time_estimation = total_time_estimation + estimation_time
                timed_runs += 1

        # With a single run there is nothing to average; skip the report
        # rather than divide by zero.
        if timed_runs:
            logger.info(f'\taverage time estimation {total_time_estimation / timed_runs} ms')
    else:
        output = inference(model, input_text, input_wav, sample_rate, args.version)

    similarities = dict(zip(input_text, output))

    print("Similarity: ")
    print_sorted_dict(similarities)

    if args.write_json:
        save_sorted_dict_as_json(similarities)

    logger.info('Script finished successfully.')
def main():
    """Download the model files, build the nets and tokenizer, and run.

    Raises:
        ValueError: If ``args.version`` is neither '2022' nor '2023'.
    """
    version = args.version

    # Pick the file set for the requested version once, up front.
    if version == '2023':
        caption_weight, caption_proto = CAPTION_WEIGHT_PATH_2023, CAPTION_MODEL_PATH_2023
        audio_weight, audio_proto = AUDIO_WEIGHT_PATH_2023, AUDIO_MODEL_PATH_2023
    elif version == '2022':
        caption_weight, caption_proto = CAPTION_WEIGHT_PATH_2022, CAPTION_MODEL_PATH_2022
        audio_weight, audio_proto = AUDIO_WEIGHT_PATH_2022, AUDIO_MODEL_PATH_2022
    else:
        # Fail early with a clear message instead of an unbound-local
        # error further down.
        raise ValueError(
            f"Unsupported CLAP version {version!r}; expected '2022' or '2023'."
        )

    # model files check and download
    check_and_download_models(caption_weight, caption_proto, REMOTE_PATH)
    check_and_download_models(audio_weight, audio_proto, REMOTE_PATH)

    env_id = args.env_id

    # disable FP16
    # sys.platform is the lowercase string 'darwin' on macOS, so the
    # comparison has to use that spelling to ever match.
    if "FP16" in ailia.get_environment(args.env_id).props or sys.platform == 'darwin':
        logger.warning('This model do not work on FP16. So use CPU mode.')
        env_id = 0

    # initialize
    caption_model = ailia.Net(caption_proto, caption_weight, env_id=env_id)
    audio_model = ailia.Net(audio_proto, audio_weight, env_id=env_id)

    if version == '2023':
        # The 2023 model uses a GPT-2 tokenizer with '!' registered as the
        # padding token.
        if args.disable_ailia_tokenizer:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained('gpt2')
        else:
            from ailia_tokenizer import GPT2Tokenizer
            tokenizer = GPT2Tokenizer.from_pretrained('./tokenizer_gpt2/')
        tokenizer.add_special_tokens({'pad_token': '!'})
    else:
        # The 2022 model uses a BERT tokenizer.
        if args.disable_ailia_tokenizer:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        else:
            from ailia_tokenizer import BertTokenizer
            tokenizer = BertTokenizer.from_pretrained('./tokenizer_bert/')

    model = {
        'caption_model': caption_model,
        'audio_model': audio_model,
        'tokenizer': tokenizer,
    }

    estimate_best_caption(model)
# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()