File size: 5,180 Bytes
8e60cc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
Contrastive Language-Audio Pretraining Model from LAION
--------------------------------------------------------
Paper: https://arxiv.org/abs/2211.06687
Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui
Support: LAION
"""
import os
import json
import torch
import librosa
import torchaudio
import transformers
import numpy as np
from pathlib import Path
from packaging import version

from .data import get_audio_features
from .data import int16_to_float32, float32_to_int16
from .clap_model import CLAP

from transformers import RobertaTokenizer
import wget

# Directory containing this file; used to locate the bundled model_configs/*.json.
BASE_DIR = Path(__file__).resolve().parent

class CLAP_Module(torch.nn.Module):
    """Wrapper around the LAION CLAP audio/text dual encoder.

    Builds the CLAP model from a bundled JSON config selected by the audio
    encoder name, and exposes helpers to load checkpoints, embed audio and
    text, and compute a CLAP (cosine) similarity score between them.
    """

    def __init__(self, amodel='HTSAT-tiny', tmodel='roberta') -> None:
        """Construct the model from `model_configs/{amodel}.json`.

        Args:
            amodel: Audio-encoder name; selects the JSON config file.
            tmodel: Text-encoder type written into the config's `text_cfg`.
        """
        super(CLAP_Module, self).__init__()

        config_path = os.path.join(BASE_DIR, 'model_configs', f'{amodel}.json')
        with open(config_path, "r") as f:
            model_cfg = json.load(f)

        # Tokenizer is fixed to roberta-base regardless of `tmodel`.
        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")

        model_cfg["text_cfg"]["model_type"] = tmodel
        self.model = CLAP(**model_cfg)
        self.model_cfg = model_cfg

    def tokenizer(self, text):
        """Tokenize `text` (str or list of str) into fixed-length (77) pt tensors."""
        return self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )

    def load_ckpt(self, ckpt_folder_path, ckpt_name):
        """Load model weights from `ckpt_folder_path/ckpt_name`.

        Downloads the checkpoint from the LAION Hugging Face repository
        when the local file does not exist.
        """
        ckpt_path = os.path.join(ckpt_folder_path, ckpt_name)

        if os.path.exists(ckpt_path):
            print(f'Load checkpoint from {ckpt_path}')
        else:
            download_link = 'https://huggingface.co/lukewys/laion_clap/resolve/main/'
            print(f'Download checkpoint from {download_link + ckpt_name}.')
            ckpt_path = wget.download(download_link + ckpt_name, ckpt_folder_path)
            print('Download completed!')
        print()

        # NOTE(security): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)

        # Lightning-style checkpoints nest weights under "state_dict".
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        # Strip a DataParallel-style "module." prefix if present.
        if next(iter(state_dict)).startswith("module"):
            state_dict = {k[7:]: v for k, v in state_dict.items()}

        # transformers >= 4.31 dropped position_ids from RoBERTa embeddings.
        # Use pop(..., None) — del raised KeyError for checkpoints that were
        # saved without this buffer.
        if version.parse(transformers.__version__) >= version.parse("4.31.0"):
            state_dict.pop("text_branch.embeddings.position_ids", None)

        self.model.load_state_dict(state_dict)

    def get_audio_embedding(self, x, sr=16000, normalize=False, use_tensor=True):
        """Embed audio inputs.

        Args:
            x: A file path, a list of file paths, or a list of 1-D waveforms
               (tensor/ndarray) sampled at `sr`.
            sr: Sample rate of waveform inputs; waveforms are resampled to
                48 kHz (file inputs are loaded at 48 kHz directly).
            normalize: Forwarded to the model's audio-embedding call.
            use_tensor: If False, return a detached NumPy array instead of a tensor.

        Returns:
            Audio embeddings, one row per input.
        """
        self.model.eval()
        if isinstance(x, str):
            x = [x]

        audio_input = []
        for waveform in x:
            if isinstance(waveform, str):
                # Load waveform of shape (T,) directly at the model's 48 kHz rate.
                waveform, _ = librosa.load(waveform, sr=48000)
            elif sr != 48000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=48000)

            if isinstance(waveform, torch.Tensor):
                waveform = waveform.numpy()

            # Round-trip through int16 to match the quantization used in training.
            waveform = int16_to_float32(float32_to_int16(waveform))
            waveform = torch.from_numpy(waveform).float()

            sample = get_audio_features(
                {}, waveform, 480000,
                data_truncating='rand_trunc',
                data_filling='repeatpad',
                audio_cfg=self.model_cfg['audio_cfg'],
                require_grad=waveform.requires_grad
            )
            audio_input.append(sample)

        audio_embed = self.model.get_audio_embedding(audio_input, normalize)

        if not use_tensor:
            audio_embed = audio_embed.detach().cpu().numpy()

        return audio_embed

    def get_text_embedding(self, x, normalize=False, use_tensor=True):
        """Embed text inputs.

        Args:
            x: A string or list of strings.
            normalize: Forwarded to the model's text-embedding call.
            use_tensor: If False, return detached NumPy arrays.

        Returns:
            Tuple of (sentence embeddings, word embeddings, per-input sequence
            lengths derived from the attention mask).
        """
        self.model.eval()
        if isinstance(x, str):
            x = [x]

        token_data = self.tokenizer(x)
        # Index of the last non-padding token for each sequence.
        sequence_lengths = (torch.ne(token_data['attention_mask'], 0).sum(-1) - 1)
        sentence_embeds = self.model.get_text_embedding(token_data, normalize)
        word_embeds = self.model.get_word_embedding(token_data)

        if not use_tensor:
            sentence_embeds = sentence_embeds.detach().cpu().numpy()
            word_embeds = word_embeds.detach().cpu().numpy()

        return sentence_embeds, word_embeds, sequence_lengths

    def get_clap_score(self, text, audio, sr=16000):
        """Cosine similarity between normalized text and audio embeddings.

        Args:
            text: String or list of strings.
            audio: File path(s) or waveform(s) accepted by get_audio_embedding.
            sr: Sample rate of waveform inputs.
        """
        sentence_embeds, _, _ = self.get_text_embedding(text, normalize=True)
        # BUG FIX: forward the caller's `sr` — previously hard-coded to 16000,
        # silently mis-resampling waveforms recorded at other rates.
        audio_embeds = self.get_audio_embedding(audio, sr=sr, normalize=True)

        return torch.nn.functional.cosine_similarity(sentence_embeds, audio_embeds, dim=-1)