"""
Contrastive Language-Audio Pretraining Model from LAION
--------------------------------------------------------
Paper: https://arxiv.org/abs/2211.06687
Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui
Support: LAION
"""
import os
import json
import torch
import librosa
import torchaudio
import transformers
import numpy as np
from pathlib import Path
from packaging import version
from .data import get_audio_features
from .data import int16_to_float32, float32_to_int16
from .clap_model import CLAP
from transformers import RobertaTokenizer
import wget
BASE_DIR = Path(__file__).resolve().parent
class CLAP_Module(torch.nn.Module):
    """Convenience wrapper around the LAION-CLAP model.

    Loads an audio-encoder config from ``model_configs/<amodel>.json``, builds the
    CLAP model, and exposes helpers to load checkpoints and extract audio/text
    embeddings plus a text-audio similarity score.

    Args:
        amodel: Audio encoder name; must match a JSON file in ``model_configs``.
        tmodel: Text encoder type, written into ``model_cfg["text_cfg"]``.
    """

    def __init__(self, amodel='HTSAT-tiny', tmodel='roberta') -> None:
        super(CLAP_Module, self).__init__()
        config_path = os.path.join(BASE_DIR, 'model_configs', f'{amodel}.json')
        with open(config_path, "r") as f:
            model_cfg = json.load(f)
        # Tokenizer is always RoBERTa regardless of `tmodel` (matches upstream usage).
        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
        model_cfg["text_cfg"]["model_type"] = tmodel
        self.model = CLAP(**model_cfg)
        self.model_cfg = model_cfg

    def tokenizer(self, text):
        """Tokenize ``text`` into fixed-length (77-token) padded PyTorch tensors."""
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        return result

    def load_ckpt(self, ckpt_folder_path, ckpt_name):
        """Load model weights from ``ckpt_folder_path/ckpt_name``.

        If the file is missing it is downloaded from the lukewys/laion_clap
        HuggingFace repository into ``ckpt_folder_path``.
        """
        ckpt_path = os.path.join(ckpt_folder_path, ckpt_name)
        if os.path.exists(ckpt_path):
            print(f'Load checkpoint from {ckpt_path}')
        else:
            download_link = 'https://huggingface.co/lukewys/laion_clap/resolve/main/'
            print(f'Download checkpoint from {download_link + ckpt_name}.')
            ckpt_path = wget.download(download_link + ckpt_name, ckpt_folder_path)
            print('Download completed!')
            print()
        # weights_only=False: checkpoint may contain pickled non-tensor objects;
        # only load checkpoints from this trusted source.
        checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
        # Strip the DataParallel/DistributedDataParallel "module." prefix if present.
        if next(iter(state_dict)).startswith("module"):
            state_dict = {k[7:]: v for k, v in state_dict.items()}
        if version.parse(transformers.__version__) >= version.parse("4.31.0"):
            # transformers >= 4.31 no longer registers position_ids as a buffer.
            # Bugfix: pop() instead of del — checkpoints saved without the key
            # previously raised KeyError here.
            state_dict.pop("text_branch.embeddings.position_ids", None)
        self.model.load_state_dict(state_dict)

    def get_audio_embedding(self, x, sr=16000, normalize=False, use_tensor=True):
        """Compute audio embeddings.

        Args:
            x: A file path, a list of file paths, or a list of waveforms
               (numpy arrays or tensors) sampled at ``sr``. File paths are
               loaded directly at 48 kHz; waveforms are resampled to 48 kHz.
            sr: Sample rate of waveform inputs (ignored for file paths).
            normalize: Passed through to the model's embedding call.
            use_tensor: If False, return a detached numpy array.
        """
        self.model.eval()
        if isinstance(x, str):
            x = [x]
        audio_input = []
        for audio_waveform in x:
            if isinstance(audio_waveform, str):
                # Load the waveform of shape (T,) directly at the model rate.
                audio_waveform, _ = librosa.load(audio_waveform, sr=48000)
            elif sr != 48000:
                # NOTE(review): torchaudio resample assumes a torch.Tensor input;
                # numpy waveforms at sr != 48000 would need conversion first — confirm callers.
                audio_waveform = torchaudio.functional.resample(audio_waveform, orig_freq=sr, new_freq=48000)
            if isinstance(audio_waveform, torch.Tensor):
                audio_waveform = audio_waveform.numpy()
            # Quantize to int16 and back to emulate the training-data pipeline.
            audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
            audio_waveform = torch.from_numpy(audio_waveform).float()
            temp_dict = get_audio_features(
                {}, audio_waveform, 480000,
                data_truncating='rand_trunc',
                data_filling='repeatpad',
                audio_cfg=self.model_cfg['audio_cfg'],
                require_grad=audio_waveform.requires_grad
            )
            audio_input.append(temp_dict)
        audio_embed = self.model.get_audio_embedding(audio_input, normalize)
        if not use_tensor:
            audio_embed = audio_embed.detach().cpu().numpy()
        return audio_embed

    def get_text_embedding(self, x, normalize=False, use_tensor=True):
        """Compute sentence embeddings, per-token word embeddings, and the
        index of each sequence's last real (non-padding) token.
        """
        self.model.eval()
        if isinstance(x, str):
            x = [x]
        token_data = self.tokenizer(x)
        # Last non-padding position per sequence (attention_mask is 1 on real tokens).
        sequence_lengths = (torch.ne(token_data['attention_mask'], 0).sum(-1) - 1)
        sentence_embeds = self.model.get_text_embedding(token_data, normalize)
        word_embeds = self.model.get_word_embedding(token_data)
        if not use_tensor:
            sentence_embeds = sentence_embeds.detach().cpu().numpy()
            word_embeds = word_embeds.detach().cpu().numpy()
        return sentence_embeds, word_embeds, sequence_lengths

    def get_clap_score(self, text, audio, sr=16000):
        """Cosine similarity between normalized text and audio embeddings.

        Args:
            text: Text string or list of strings.
            audio: Audio input accepted by :meth:`get_audio_embedding`.
            sr: Sample rate of waveform audio inputs.
        """
        sentence_embeds, word_embeds, sequence_lengths = self.get_text_embedding(text, normalize=True)
        # Bugfix: forward the caller-supplied `sr` (was hard-coded to 16000,
        # silently ignoring the parameter).
        audio_embeds = self.get_audio_embedding(audio, sr=sr, normalize=True)
        clap_score = torch.nn.functional.cosine_similarity(sentence_embeds, audio_embeds, dim=-1)
        return clap_score