voiceclap-small / modeling_voiceclap.py
gijs's picture
Initial release: VoiceCLAP-Small (BUD-E-Whisper + MiniLM, dual-tower CLAP, 1 epoch on voiceclap_10)
f5fcbcd verified
"""VoiceCLAP-Small: dual-tower CLAP using BUD-E-Whisper-Small + MiniLM.
Standalone single-file implementation. Only depends on PyTorch and
HuggingFace `transformers` (for `BertModel`, `PreTrainedModel`, and
`PretrainedConfig`).
"""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig, BertModel, PreTrainedModel
try:
from .configuration_voiceclap import VoiceCLAPSmallConfig
except ImportError:
from configuration_voiceclap import VoiceCLAPSmallConfig
class _LayerNorm(nn.LayerNorm):
    """LayerNorm evaluated in float32, with the result cast back to the input dtype."""

    def forward(self, x):
        original_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(original_dtype)
def _sinusoids(length: int, channels: int, max_timescale: float = 10000.0) -> torch.Tensor:
    """Build the fixed sinusoidal positional-embedding table (Whisper layout).

    Args:
        length: number of positions (rows of the table).
        channels: embedding width; must be even and at least 4.
        max_timescale: largest timescale of the geometric progression; the
            inverse timescales run from 1 down to ``1 / max_timescale``.

    Returns:
        Float tensor of shape ``(length, channels)``. The first ``channels // 2``
        columns hold ``sin`` terms and the remaining columns the matching
        ``cos`` terms.

    Raises:
        ValueError: if ``channels`` is odd or smaller than 4.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``. Without
    # it, an odd width silently yields ``channels - 1`` columns, and
    # ``channels == 2`` divides by zero below.
    if channels % 2 != 0 or channels < 4:
        raise ValueError(f"channels must be an even integer >= 4, got {channels}")
    half = channels // 2
    log_timescale_increment = math.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class _MultiHeadAttention(nn.Module):
    """Multi-head self-attention over all positions (no mask), Whisper layout.

    The key projection has no bias; the query, value and output projections do.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        # Attribute names and registration order determine the state_dict keys.
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def _to_heads(self, t: torch.Tensor) -> torch.Tensor:
        # (B, T, D) -> (B, n_head, T, D // n_head)
        batch, ctx, width = t.shape
        return t.view(batch, ctx, self.n_head, width // self.n_head).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projections = (self.query(x), self.key(x), self.value(x))
        batch, ctx, width = projections[0].shape
        q, k, v = (self._to_heads(p) for p in projections)
        # SDPA's default scale is 1/sqrt(head_dim); no attention mask, not causal.
        context = F.scaled_dot_product_attention(q, k, v)
        merged = context.transpose(1, 2).reshape(batch, ctx, width)
        return self.out(merged)
class _ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer layer: ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        hidden = 4 * n_state
        self.attn = _MultiHeadAttention(n_state, n_head)
        self.attn_ln = _LayerNorm(n_state)
        # Sequential indices (0, 2) are part of the checkpoint keys.
        self.mlp = nn.Sequential(
            nn.Linear(n_state, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_state),
        )
        self.mlp_ln = _LayerNorm(n_state)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = x + self.attn(self.attn_ln(x))
        return attended + self.mlp(self.mlp_ln(attended))
class _WhisperAudioEncoder(nn.Module):
    """Whisper-style audio encoder. Takes a precomputed log-mel spectrogram.

    Pipeline:
      1. Two Conv1d + GELU layers; the second has stride 2.
      2. Fixed sinusoidal positions.
      3. ``n_layer`` pre-norm transformer blocks.
      4. A stride-2 average pool over time.
      5. A final LayerNorm and a linear projection to ``output_dim``.
    """

    def __init__(
        self,
        n_mels: int = 80,
        n_ctx: int = 1500,
        n_state: int = 768,
        n_head: int = 12,
        n_layer: int = 12,
        output_dim: int = 768,
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Fixed (non-learned) table with n_ctx rows. It is also the upper bound
        # on the number of frames forward() accepts after conv2.
        self.register_buffer("positional_embedding", _sinusoids(n_ctx, n_state))
        self.blocks = nn.ModuleList(
            [_ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = _LayerNorm(n_state)
        self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
        self.proj = nn.Linear(n_state, output_dim)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Encode a log-mel spectrogram into a sequence of frame features.

        Args:
            mel: ``(B, n_mels, T_mel)`` log-mel spectrogram.

        Returns:
            ``(B, T' // 2, output_dim)``, where ``T' = ceil(T_mel / 2)`` is the
            frame count after the stride-2 ``conv2``. The average pool halves it
            again.

        Raises:
            ValueError: if ``T'`` exceeds the ``n_ctx`` rows of
                ``positional_embedding``.
        """
        x = F.gelu(self.conv1(mel))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # (B, T', n_state)
        n_frames = x.size(1)
        n_positions = self.positional_embedding.size(0)
        if n_frames > n_positions:
            # Slicing below would silently return only n_positions rows, and the
            # addition would then fail with an opaque broadcasting error.
            raise ValueError(
                f"Input yields {n_frames} frames after conv2 but the encoder only "
                f"has {n_positions} positions (n_ctx); use at most "
                f"{2 * n_positions} mel frames."
            )
        x = x + self.positional_embedding[:n_frames].to(dtype=x.dtype, device=x.device)
        for block in self.blocks:
            x = block(x)
        # AvgPool1d pools over the last axis: move time there, pool, move it back.
        x = self.avg_pooler(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.ln_post(x)
        return self.proj(x)
class VoiceCLAPSmall(PreTrainedModel):
    """Dual-tower CLAP model: a Whisper-style audio encoder and a BERT text encoder.

    Both towers project into a shared ``config.embed_dim`` space and return
    L2-normalised embeddings, so cosine similarity reduces to a dot product.
    ``logit_scale`` and ``logit_bias`` are learned scalar parameters; the
    ``encode_*`` methods do not use them.
    """

    config_class = VoiceCLAPSmallConfig

    # Whisper-style STFT front end used by compute_log_mel (16 kHz input).
    _N_FFT = 400
    _HOP_LENGTH = 160

    def __init__(self, config: VoiceCLAPSmallConfig):
        super().__init__(config)
        self.audio_encoder = _WhisperAudioEncoder(
            n_mels=config.n_mels,
            n_ctx=config.n_ctx,
            n_state=config.n_state,
            n_head=config.n_head,
            n_layer=config.n_layer,
            output_dim=config.embed_dim,
        )
        self.audio_proj = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim),
            nn.GELU(),
            nn.Linear(config.embed_dim, config.embed_dim),
        )
        bert_config = BertConfig(
            vocab_size=config.text_vocab_size,
            hidden_size=config.text_hidden_dim,
            num_hidden_layers=config.text_num_layers,
            num_attention_heads=config.text_num_heads,
            intermediate_size=config.text_intermediate_size,
            max_position_embeddings=config.text_max_position_embeddings,
            layer_norm_eps=config.text_layer_norm_eps,
            pad_token_id=config.text_pad_token_id,
        )
        self.text_encoder = BertModel(bert_config, add_pooling_layer=False)
        self.text_proj = nn.Sequential(
            nn.Linear(config.text_hidden_dim, config.text_proj_hidden, bias=False),
            nn.GELU(),
            nn.Linear(config.text_proj_hidden, config.embed_dim, bias=False),
        )
        self.logit_scale = nn.Parameter(torch.zeros(()))
        self.logit_bias = nn.Parameter(torch.zeros(()))
        # Mel filterbank used by encode_waveform / compute_log_mel.
        # Shape is n_mels x (n_fft // 2 + 1), i.e. 80 x 201 for n_fft=400.
        # It is initialised to zeros and is a persistent buffer, so it is
        # expected to come from the loaded checkpoint.
        self.register_buffer(
            "mel_filters",
            torch.zeros(config.n_mels, self._N_FFT // 2 + 1),
            persistent=True,
        )
        self.post_init()

    @torch.no_grad()
    def compute_log_mel(
        self, waveform: torch.Tensor, sample_rate: int = 16000
    ) -> torch.Tensor:
        """Whisper-style log-mel spectrogram. waveform: (B, T) or (T,) at 16 kHz.

        Returns ``(B, n_mels, T // 160)``: one frame per 160-sample hop. The
        waveform is not padded or trimmed. Matches the training-time
        preprocessing bit-exactly so embeddings reproduce the published results.

        The input is cast to float32 without rescaling. Presumably it should
        already be float samples in [-1, 1], as in Whisper — confirm with the
        training pipeline.

        Raises:
            ValueError: if ``sample_rate`` is not 16000, or if ``waveform`` is
                neither 1-D nor 2-D.
        """
        if sample_rate != 16000:
            raise ValueError(f"sample_rate must be 16000, got {sample_rate}")
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if waveform.dim() != 2:
            # torch.stft only accepts (T,) or (B, T); fail with a clear message.
            raise ValueError(
                f"waveform must have shape (T,) or (B, T), got {tuple(waveform.shape)}"
            )
        device = self.mel_filters.device
        waveform = waveform.to(device=device, dtype=torch.float32)
        window = torch.hann_window(self._N_FFT, device=device)
        stft = torch.stft(
            waveform,
            n_fft=self._N_FFT,
            hop_length=self._HOP_LENGTH,
            window=window,
            return_complex=True,
        )
        # Drop the final STFT frame and take the power spectrum.
        magnitudes = stft[..., :-1].abs() ** 2
        mel = self.mel_filters.to(magnitudes.dtype) @ magnitudes
        log_spec = torch.clamp(mel, min=1e-10).log10()
        # Per-example dynamic-range floor: at most 8 log10 units (80 dB) below
        # each clip's own peak.
        log_spec = torch.maximum(log_spec, log_spec.amax(dim=(-2, -1), keepdim=True) - 8.0)
        return (log_spec + 4.0) / 4.0

    def encode_waveform(self, waveform: torch.Tensor, sample_rate: int = 16000) -> torch.Tensor:
        """Encode raw 16 kHz waveform; calls ``compute_log_mel`` then ``encode_audio``."""
        mel = self.compute_log_mel(waveform, sample_rate=sample_rate)
        return self.encode_audio(mel)

    def encode_audio(self, mel: torch.Tensor) -> torch.Tensor:
        """Embed a log-mel spectrogram.

        Args:
            mel: ``(B, n_mels, T_mel)``, e.g. the output of ``compute_log_mel``.

        Returns:
            ``(B, embed_dim)`` L2-normalised audio embeddings.
        """
        feats = self.audio_encoder(mel)  # (B, T', embed_dim)
        feats = feats.mean(dim=1)  # clip-level mean over all frames (no padding mask)
        feats = self.audio_proj(feats)
        return F.normalize(feats, dim=-1)

    def encode_text(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed tokenised text by mean-pooling BERT's last hidden state.

        Args:
            input_ids: ``(B, T)`` token ids.
            attention_mask: ``(B, T)`` mask of tokens to attend to and to pool.
                If omitted, it is derived as ``input_ids != text_pad_token_id``.

        Returns:
            ``(B, embed_dim)`` L2-normalised text embeddings.
        """
        if attention_mask is None:
            attention_mask = (input_ids != self.config.text_pad_token_id).long()
        out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out.last_hidden_state  # (B, T, H)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # Masked mean; the clamp avoids 0/0 for an all-padding row.
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        feats = self.text_proj(pooled)
        return F.normalize(feats, dim=-1)