File size: 8,239 Bytes
f5fcbcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""VoiceCLAP-Small: dual-tower CLAP using BUD-E-Whisper-Small + MiniLM.

Standalone single-file implementation. Only depends on PyTorch and
HuggingFace `transformers` (for `BertModel`, `PreTrainedModel`, and
`PretrainedConfig`).
"""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig, BertModel, PreTrainedModel

try:
    from .configuration_voiceclap import VoiceCLAPSmallConfig
except ImportError:
    from configuration_voiceclap import VoiceCLAPSmallConfig


class _LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the caller's dtype.

    Both the input and the affine parameters are upcast to float32 for the
    normalization. Upcasting only the input would mix a float32 activation
    with reduced-precision parameters once the module has been cast (e.g.
    ``model.half()``), which raises a dtype-mismatch error.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over the trailing ``normalized_shape`` dims.

        Returns a tensor with the same shape and dtype as ``x``.
        """
        # ``.float()`` is a no-op for float32 parameters, so full-precision
        # models behave exactly as before.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normed = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return normed.type(x.dtype)


def _sinusoids(length: int, channels: int, max_timescale: float = 10000.0) -> torch.Tensor:
    """Build a table of sinusoidal positional embeddings.

    Args:
        length: Number of positions (rows) to generate.
        channels: Embedding width; must be even. The first half of each row
            holds sines and the second half the matching cosines.
        max_timescale: Largest timescale; the inverse timescales are spaced
            geometrically from ``1`` down to ``1 / max_timescale``.

    Returns:
        A ``(length, channels)`` float tensor.

    Raises:
        ValueError: If ``channels`` is odd.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if channels % 2 != 0:
        raise ValueError(f"channels must be even, got {channels}")
    half = channels // 2
    # With a single sin/cos pair (channels == 2) there is only one timescale,
    # so the increment is irrelevant; clamp the denominator to avoid dividing
    # by zero in that case.
    log_timescale_increment = math.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class _MultiHeadAttention(nn.Module):
    """Unmasked multi-head self-attention with separate q/k/v/out projections.

    The key projection carries no bias; the other three projections do.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        # Attribute names and creation order are kept stable: they determine
        # the state-dict keys and the order in which random init is drawn.
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, T, D)`` into per-head layout ``(B, n_head, T, D // n_head)``."""
        batch, seq_len, width = projected.shape
        per_head = width // self.n_head
        return projected.view(batch, seq_len, self.n_head, per_head).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the sequence axis of ``x`` (B, T, D); returns (B, T, D)."""
        projections = (self.query, self.key, self.value)
        q_heads, k_heads, v_heads = (self._split_heads(layer(x)) for layer in projections)
        attended = F.scaled_dot_product_attention(q_heads, k_heads, v_heads)
        batch, _, seq_len, _ = attended.shape
        # Move heads back next to the feature axis and merge them.
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.out(merged)


class _ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer layer: self-attention then a 4x-wide GELU MLP.

    Each sub-layer normalizes its input first and adds its output back onto
    the residual stream.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.attn = _MultiHeadAttention(n_state, n_head)
        self.attn_ln = _LayerNorm(n_state)
        hidden_width = 4 * n_state
        feed_forward = [
            nn.Linear(n_state, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, n_state),
        ]
        # Positional Sequential keeps the "mlp.0" / "mlp.2" parameter names.
        self.mlp = nn.Sequential(*feed_forward)
        self.mlp_ln = _LayerNorm(n_state)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sub-layers; output has the shape of ``x``."""
        attn_update = self.attn(self.attn_ln(x))
        x = x + attn_update
        mlp_update = self.mlp(self.mlp_ln(x))
        return x + mlp_update


class _WhisperAudioEncoder(nn.Module):
    """Whisper-style audio encoder operating on a precomputed log-mel spectrogram.

    Pipeline: two GELU-activated 1-D convolutions (the second with stride 2),
    fixed sinusoidal position embeddings, a stack of residual attention
    blocks, average pooling over time (factor 2), a final layer norm and a
    linear projection to ``output_dim``.
    """

    def __init__(
        self,
        n_mels: int = 80,
        n_ctx: int = 1500,
        n_state: int = 768,
        n_head: int = 12,
        n_layer: int = 12,
        output_dim: int = 768,
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Fixed (non-learned) table with one row per post-convolution frame.
        position_table = _sinusoids(n_ctx, n_state)
        self.register_buffer("positional_embedding", position_table)
        self.blocks = nn.ModuleList(
            _ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)
        )
        self.ln_post = _LayerNorm(n_state)
        self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
        self.proj = nn.Linear(n_state, output_dim)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Encode ``mel`` of shape (B, n_mels, T_mel) into frame features.

        Returns a (B, T_out, output_dim) tensor, where the time axis has been
        reduced by the stride-2 convolution and then by the pooling layer.
        """
        conv_out = F.gelu(self.conv2(F.gelu(self.conv1(mel))))
        hidden = conv_out.permute(0, 2, 1)  # channels-first -> (B, T', D)
        n_frames = hidden.size(1)
        positions = self.positional_embedding[:n_frames].to(
            dtype=hidden.dtype, device=hidden.device
        )
        hidden = hidden + positions
        for layer in self.blocks:
            hidden = layer(hidden)
        # AvgPool1d pools the last axis, so go channels-first around it.
        pooled = self.avg_pooler(hidden.permute(0, 2, 1)).permute(0, 2, 1)
        return self.proj(self.ln_post(pooled))


class VoiceCLAPSmall(PreTrainedModel):
    """Dual-tower model that embeds audio and text into one shared space.

    The audio tower is a Whisper-style encoder followed by a two-layer MLP
    projection; the text tower is a ``BertModel`` (without pooling layer)
    followed by a bias-free two-layer MLP projection. ``encode_audio`` /
    ``encode_waveform`` and ``encode_text`` each return L2-normalized
    embeddings of width ``config.embed_dim``.
    """

    config_class = VoiceCLAPSmallConfig

    def __init__(self, config: VoiceCLAPSmallConfig):
        super().__init__(config)
        # --- audio tower -------------------------------------------------
        self.audio_encoder = _WhisperAudioEncoder(
            n_mels=config.n_mels,
            n_ctx=config.n_ctx,
            n_state=config.n_state,
            n_head=config.n_head,
            n_layer=config.n_layer,
            output_dim=config.embed_dim,
        )
        self.audio_proj = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim),
            nn.GELU(),
            nn.Linear(config.embed_dim, config.embed_dim),
        )
        # --- text tower --------------------------------------------------
        bert_config = BertConfig(
            vocab_size=config.text_vocab_size,
            hidden_size=config.text_hidden_dim,
            num_hidden_layers=config.text_num_layers,
            num_attention_heads=config.text_num_heads,
            intermediate_size=config.text_intermediate_size,
            max_position_embeddings=config.text_max_position_embeddings,
            layer_norm_eps=config.text_layer_norm_eps,
            pad_token_id=config.text_pad_token_id,
        )
        self.text_encoder = BertModel(bert_config, add_pooling_layer=False)
        self.text_proj = nn.Sequential(
            nn.Linear(config.text_hidden_dim, config.text_proj_hidden, bias=False),
            nn.GELU(),
            nn.Linear(config.text_proj_hidden, config.embed_dim, bias=False),
        )
        # Scalar similarity scale/bias parameters. They are not read by any
        # method in this class; presumably used by training/scoring code.
        self.logit_scale = nn.Parameter(torch.zeros(()))
        self.logit_bias = nn.Parameter(torch.zeros(()))

        # Mel filterbank used by encode_waveform / compute_log_mel.
        # 80 mel bins x 201 freq bins for n_fft=400, sr=16000 (Whisper-style).
        # Registered as an all-zero persistent buffer; the real filterbank is
        # presumably restored from the checkpoint's state dict.
        self.register_buffer(
            "mel_filters",
            torch.zeros(config.n_mels, 201),
            persistent=True,
        )
        self.post_init()

    @torch.no_grad()
    def compute_log_mel(
        self, waveform: torch.Tensor, sample_rate: int = 16000
    ) -> torch.Tensor:
        """Whisper-style log-mel spectrogram. waveform: (B, T) or (T,) at 16 kHz.

        Returns (B, n_mels, T_mel) in float32, on the device of the
        ``mel_filters`` buffer. Matches the training-time preprocessing
        bit-exactly so embeddings reproduce the published results.

        Raises:
            ValueError: If ``sample_rate`` is not 16000 (no resampling is done).
        """
        if sample_rate != 16000:
            raise ValueError(f"sample_rate must be 16000, got {sample_rate}")
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        device = self.mel_filters.device
        waveform = waveform.to(device=device, dtype=torch.float32)
        window = torch.hann_window(400, device=device)
        stft = torch.stft(waveform, n_fft=400, hop_length=160, window=window, return_complex=True)
        # Power spectrum; the final STFT frame is dropped.
        magnitudes = stft[..., :-1].abs() ** 2
        mel = self.mel_filters.to(magnitudes.dtype) @ magnitudes
        log_spec = torch.clamp(mel, min=1e-10).log10()
        # Floor each clip at 8 log10 units below its own maximum.
        log_spec = torch.maximum(log_spec, log_spec.amax(dim=(-2, -1), keepdim=True) - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec

    def encode_waveform(self, waveform: torch.Tensor, sample_rate: int = 16000) -> torch.Tensor:
        """Encode raw 16 kHz waveform; calls ``compute_log_mel`` then ``encode_audio``."""
        mel = self.compute_log_mel(waveform, sample_rate=sample_rate)
        return self.encode_audio(mel)

    def encode_audio(self, mel: torch.Tensor) -> torch.Tensor:
        """Embed a log-mel spectrogram of shape (B, n_mels, T_mel).

        Returns L2-normalized clip embeddings of shape (B, embed_dim).
        """
        # ``compute_log_mel`` always yields float32. Align the input with the
        # encoder's parameter dtype so the first convolution does not fail
        # with an input/weight dtype mismatch when the model holds
        # reduced-precision weights. No-op when the dtypes already agree.
        mel = mel.to(dtype=self.audio_encoder.conv1.weight.dtype)
        feats = self.audio_encoder(mel)            # (B, T', D)
        feats = feats.mean(dim=1)                  # clip-level mean
        feats = self.audio_proj(feats)
        return F.normalize(feats, dim=-1)

    def encode_text(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed tokenized text.

        Args:
            input_ids: Token ids of shape (B, T).
            attention_mask: Optional (B, T) mask with 1 for real tokens. When
                omitted, every token different from
                ``config.text_pad_token_id`` is treated as real.

        Returns:
            L2-normalized text embeddings of shape (B, embed_dim).
        """
        if attention_mask is None:
            attention_mask = (input_ids != self.config.text_pad_token_id).long()
        out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out.last_hidden_state                                  # (B, T, H)
        # Masked mean over tokens; the clamp avoids dividing by zero for a
        # row whose mask is entirely zero.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        feats = self.text_proj(pooled)
        return F.normalize(feats, dim=-1)