| """ | |
| Contrastive Language-Audio Pretraining Model from LAION | |
| -------------------------------------------------------- | |
| Paper: https://arxiv.org/abs/2211.06687 | |
| Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui | |
| Support: LAION | |
| """ | |
| import os | |
| import json | |
| import torch | |
| import librosa | |
| import torchaudio | |
| import transformers | |
| import numpy as np | |
| from pathlib import Path | |
| from packaging import version | |
| from .data import get_audio_features | |
| from .data import int16_to_float32, float32_to_int16 | |
| from .clap_model import CLAP | |
| from transformers import RobertaTokenizer | |
| import wget | |
| BASE_DIR = Path(__file__).resolve().parent | |
class CLAP_Module(torch.nn.Module):
    """Inference wrapper around LAION's Contrastive Language-Audio Pretraining model.

    Loads a model config named after the audio encoder (e.g. ``HTSAT-tiny``)
    from ``model_configs/`` next to this file, builds the CLAP network, and
    exposes helpers for checkpoint loading and text/audio embedding.

    Args:
        amodel: Audio-encoder name; must match a JSON file in ``model_configs``.
        tmodel: Text-encoder type written into ``text_cfg``. NOTE(review): the
            tokenizer is fixed to ``roberta-base`` regardless of this value.
    """

    def __init__(self, amodel='HTSAT-tiny', tmodel='roberta') -> None:
        super(CLAP_Module, self).__init__()
        config_path = os.path.join(BASE_DIR, 'model_configs', f'{amodel}.json')
        with open(config_path, "r") as f:
            model_cfg = json.load(f)
        # Tokenizer is always roberta-base, independent of `tmodel`.
        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
        model_cfg["text_cfg"]["model_type"] = tmodel
        self.model = CLAP(**model_cfg)
        self.model_cfg = model_cfg

    def tokenizer(self, text):
        """Tokenize `text` into fixed-length (77-token) padded PyTorch tensors."""
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        return result

    def load_ckpt(self, ckpt_folder_path, ckpt_name):
        """Load model weights, downloading `ckpt_name` from HuggingFace if absent.

        Args:
            ckpt_folder_path: Directory that holds (or will receive) the checkpoint.
            ckpt_name: Checkpoint file name on the LAION CLAP HuggingFace repo.
        """
        ckpt_path = os.path.join(ckpt_folder_path, ckpt_name)
        if os.path.exists(ckpt_path):
            print(f'Load checkpoint from {ckpt_path}')
        else:
            download_link = 'https://huggingface.co/lukewys/laion_clap/resolve/main/'
            print(f'Download checkpoint from {download_link + ckpt_name}.')
            ckpt_path = wget.download(download_link + ckpt_name, ckpt_folder_path)
            print('Download completed!')
            print()
        # weights_only=False: checkpoint may contain pickled non-tensor objects;
        # only load checkpoints from the trusted LAION repo.
        checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
        # Strip the "module." prefix left over from DataParallel training.
        if next(iter(state_dict)).startswith("module"):
            state_dict = {k[7:]: v for k, v in state_dict.items()}
        # transformers >= 4.31 dropped position_ids from roberta's state dict;
        # pop (not del) so checkpoints saved without the key also load cleanly.
        if version.parse(transformers.__version__) >= version.parse("4.31.0"):
            state_dict.pop("text_branch.embeddings.position_ids", None)
        self.model.load_state_dict(state_dict)

    def get_audio_embedding(self, x, sr=16000, normalize=False, use_tensor=True):
        """Embed audio given as file path(s) or raw waveform(s).

        Args:
            x: A path, a list of paths, or a list of 1-D waveforms
                (numpy arrays or torch tensors) sampled at `sr`.
            sr: Sample rate of raw waveforms in `x`; paths are loaded at
                48 kHz directly, raw waveforms are resampled to 48 kHz.
            normalize: If True, L2-normalize the returned embeddings.
            use_tensor: Return a torch tensor if True, else a numpy array.

        Returns:
            Audio embeddings, one row per input item.
        """
        self.model.eval()
        if isinstance(x, str):
            x = [x]
        audio_input = []
        for audio_waveform in x:
            if isinstance(audio_waveform, str):
                # Load the waveform of shape (T,); librosa resamples to 48 kHz.
                audio_waveform, _ = librosa.load(audio_waveform, sr=48000)
            elif sr != 48000:
                # torchaudio's resample requires a tensor; accept numpy input
                # by converting first (the original crashed on numpy here).
                if not isinstance(audio_waveform, torch.Tensor):
                    audio_waveform = torch.from_numpy(
                        np.asarray(audio_waveform, dtype=np.float32))
                audio_waveform = torchaudio.functional.resample(
                    audio_waveform, orig_freq=sr, new_freq=48000)
            if isinstance(audio_waveform, torch.Tensor):
                # detach/cpu so grad-tracking or GPU tensors convert cleanly.
                audio_waveform = audio_waveform.detach().cpu().numpy()
            # Quantize to int16 and back to match training preprocessing.
            audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
            audio_waveform = torch.from_numpy(audio_waveform).float()
            temp_dict = get_audio_features(
                {}, audio_waveform, 480000,
                data_truncating='rand_trunc',
                data_filling='repeatpad',
                audio_cfg=self.model_cfg['audio_cfg'],
                require_grad=audio_waveform.requires_grad,
            )
            audio_input.append(temp_dict)
        audio_embed = self.model.get_audio_embedding(audio_input, normalize)
        if not use_tensor:
            audio_embed = audio_embed.detach().cpu().numpy()
        return audio_embed

    def get_text_embedding(self, x, normalize=False, use_tensor=True):
        """Embed text.

        Args:
            x: A string or list of strings.
            normalize: If True, L2-normalize the sentence embeddings.
            use_tensor: Return torch tensors if True, else numpy arrays
                (sequence lengths stay a tensor either way, as before).

        Returns:
            (sentence_embeds, word_embeds, sequence_lengths) where
            sequence_lengths holds the index of each sequence's last
            non-padding token.
        """
        self.model.eval()
        if isinstance(x, str):
            x = [x]
        token_data = self.tokenizer(x)
        # Index of the last non-padding token per sequence.
        sequence_lengths = (torch.ne(token_data['attention_mask'], 0).sum(-1) - 1)
        sentence_embeds = self.model.get_text_embedding(token_data, normalize)
        word_embeds = self.model.get_word_embedding(token_data)
        if not use_tensor:
            sentence_embeds = sentence_embeds.detach().cpu().numpy()
            word_embeds = word_embeds.detach().cpu().numpy()
        return sentence_embeds, word_embeds, sequence_lengths

    def get_clap_score(self, text, audio, sr=16000):
        """Cosine similarity between normalized text and audio embeddings.

        Args:
            text: Text prompt(s).
            audio: Audio path(s) or raw waveform(s).
            sr: Sample rate of raw waveforms in `audio`.
        """
        sentence_embeds, _, _ = self.get_text_embedding(text, normalize=True)
        # Bug fix: forward the caller's `sr` instead of hardcoding 16000.
        audio_embeds = self.get_audio_embedding(audio, sr=sr, normalize=True)
        clap_score = torch.nn.functional.cosine_similarity(
            sentence_embeds, audio_embeds, dim=-1)
        return clap_score