"""VoiceCLAP-Small: dual-tower CLAP using BUD-E-Whisper-Small + MiniLM. Standalone single-file implementation. Only depends on PyTorch and HuggingFace `transformers` (for `BertModel`, `PreTrainedModel`, and `PretrainedConfig`). """ import math from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from transformers import BertConfig, BertModel, PreTrainedModel try: from .configuration_voiceclap import VoiceCLAPSmallConfig except ImportError: from configuration_voiceclap import VoiceCLAPSmallConfig class _LayerNorm(nn.LayerNorm): def forward(self, x): return super().forward(x.float()).type(x.dtype) def _sinusoids(length: int, channels: int, max_timescale: float = 10000.0) -> torch.Tensor: assert channels % 2 == 0 log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :] return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) class _MultiHeadAttention(nn.Module): def __init__(self, n_state: int, n_head: int): super().__init__() self.n_head = n_head self.query = nn.Linear(n_state, n_state) self.key = nn.Linear(n_state, n_state, bias=False) self.value = nn.Linear(n_state, n_state) self.out = nn.Linear(n_state, n_state) def forward(self, x: torch.Tensor) -> torch.Tensor: q = self.query(x) k = self.key(x) v = self.value(x) n_batch, n_ctx, n_state = q.shape head_dim = n_state // self.n_head q = q.view(n_batch, n_ctx, self.n_head, head_dim).transpose(1, 2) k = k.view(n_batch, n_ctx, self.n_head, head_dim).transpose(1, 2) v = v.view(n_batch, n_ctx, self.n_head, head_dim).transpose(1, 2) out = F.scaled_dot_product_attention(q, k, v) out = out.transpose(1, 2).reshape(n_batch, n_ctx, n_state) return self.out(out) class _ResidualAttentionBlock(nn.Module): def __init__(self, n_state: int, n_head: int): super().__init__() self.attn = _MultiHeadAttention(n_state, n_head) self.attn_ln = _LayerNorm(n_state) n_mlp = n_state * 4 self.mlp = nn.Sequential(nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)) self.mlp_ln = _LayerNorm(n_state) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.attn_ln(x)) x = x + self.mlp(self.mlp_ln(x)) return x class _WhisperAudioEncoder(nn.Module): """Whisper-style audio encoder. Takes a precomputed log-mel spectrogram.""" def __init__( self, n_mels: int = 80, n_ctx: int = 1500, n_state: int = 768, n_head: int = 12, n_layer: int = 12, output_dim: int = 768, ): super().__init__() self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) self.register_buffer("positional_embedding", _sinusoids(n_ctx, n_state)) self.blocks = nn.ModuleList( [_ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] ) self.ln_post = _LayerNorm(n_state) self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2) self.proj = nn.Linear(n_state, output_dim) def forward(self, mel: torch.Tensor) -> torch.Tensor: # mel: (B, n_mels, T_mel) x = F.gelu(self.conv1(mel)) x = F.gelu(self.conv2(x)) x = x.permute(0, 2, 1) # (B, T', D) T = x.size(1) x = x + self.positional_embedding[:T].to(dtype=x.dtype, device=x.device) for block in self.blocks: x = block(x) x = x.permute(0, 2, 1) x = self.avg_pooler(x) x = x.permute(0, 2, 1) x = self.ln_post(x) x = self.proj(x) return x class VoiceCLAPSmall(PreTrainedModel): config_class = VoiceCLAPSmallConfig def __init__(self, config: VoiceCLAPSmallConfig): super().__init__(config) self.audio_encoder = _WhisperAudioEncoder( n_mels=config.n_mels, n_ctx=config.n_ctx, n_state=config.n_state, n_head=config.n_head, n_layer=config.n_layer, output_dim=config.embed_dim, ) self.audio_proj = nn.Sequential( nn.Linear(config.embed_dim, config.embed_dim), nn.GELU(), nn.Linear(config.embed_dim, config.embed_dim), ) bert_config = BertConfig( vocab_size=config.text_vocab_size, hidden_size=config.text_hidden_dim, num_hidden_layers=config.text_num_layers, num_attention_heads=config.text_num_heads, intermediate_size=config.text_intermediate_size, max_position_embeddings=config.text_max_position_embeddings, layer_norm_eps=config.text_layer_norm_eps, pad_token_id=config.text_pad_token_id, ) self.text_encoder = BertModel(bert_config, add_pooling_layer=False) self.text_proj = nn.Sequential( nn.Linear(config.text_hidden_dim, config.text_proj_hidden, bias=False), nn.GELU(), nn.Linear(config.text_proj_hidden, config.embed_dim, bias=False), ) self.logit_scale = nn.Parameter(torch.zeros(())) self.logit_bias = nn.Parameter(torch.zeros(())) # Mel filterbank used by encode_waveform / compute_log_mel. # 80 mel bins x 201 freq bins for n_fft=400, sr=16000 (Whisper-style). self.register_buffer( "mel_filters", torch.zeros(config.n_mels, 201), persistent=True, ) self.post_init() @torch.no_grad() def compute_log_mel( self, waveform: torch.Tensor, sample_rate: int = 16000 ) -> torch.Tensor: """Whisper-style log-mel spectrogram. waveform: (B, T) or (T,) at 16 kHz. Returns (B, n_mels, T_mel). Matches the training-time preprocessing bit-exactly so embeddings reproduce the published results. """ if sample_rate != 16000: raise ValueError(f"sample_rate must be 16000, got {sample_rate}") if waveform.dim() == 1: waveform = waveform.unsqueeze(0) device = self.mel_filters.device waveform = waveform.to(device=device, dtype=torch.float32) window = torch.hann_window(400, device=device) stft = torch.stft(waveform, n_fft=400, hop_length=160, window=window, return_complex=True) magnitudes = stft[..., :-1].abs() ** 2 mel = self.mel_filters.to(magnitudes.dtype) @ magnitudes log_spec = torch.clamp(mel, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.amax(dim=(-2, -1), keepdim=True) - 8.0) log_spec = (log_spec + 4.0) / 4.0 return log_spec def encode_waveform(self, waveform: torch.Tensor, sample_rate: int = 16000) -> torch.Tensor: """Encode raw 16 kHz waveform; calls ``compute_log_mel`` then ``encode_audio``.""" mel = self.compute_log_mel(waveform, sample_rate=sample_rate) return self.encode_audio(mel) def encode_audio(self, mel: torch.Tensor) -> torch.Tensor: feats = self.audio_encoder(mel) # (B, T', D) feats = feats.mean(dim=1) # clip-level mean feats = self.audio_proj(feats) return F.normalize(feats, dim=-1) def encode_text( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: if attention_mask is None: attention_mask = (input_ids != self.config.text_pad_token_id).long() out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask) hidden = out.last_hidden_state # (B, T, H) mask = attention_mask.unsqueeze(-1).to(hidden.dtype) pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9) feats = self.text_proj(pooled) return F.normalize(feats, dim=-1)