import numpy as np
import torch
import torch.nn as nn
import torchaudio
from transformers import BertModel, BertTokenizer, HubertConfig, HubertModel

from src.models.base_model import BaseModel, BaseMultimodalModel


class SkewSimilarity(nn.Module):
    """Bilinear similarity with a learned skew-symmetric matrix."""

    def __init__(self, embedding_dim):
        super().__init__()
        std = np.sqrt(2.0 / embedding_dim)
        self.J = nn.Parameter(std * torch.randn(embedding_dim, embedding_dim))

    def forward(self, x, y):
        # Skew-symmetric part of J, so that forward(x, y) == -forward(y, x).
        J_ = 0.5 * (self.J - self.J.transpose(1, 0))
        return torch.sum(torch.matmul(x, J_) * y, dim=-1, keepdim=True)
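
# A minimal sketch (variable names here are illustrative, not part of the repo)
# of the antisymmetry the skew-symmetric J enforces: swapping the two inputs
# flips the sign of the score, so sigmoid(d) and sigmoid(-d) sum to 1 and the
# pairwise preference is order-consistent.
#
#   sim = SkewSimilarity(embedding_dim=4)
#   x, y = torch.randn(3, 4), torch.randn(3, 4)
#   assert torch.allclose(sim(x, y), -sim(y, x), atol=1e-6)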


class L2Normalize(nn.Module):
    """L2-normalize vectors along the last dimension, with an epsilon guard."""

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        return x / (torch.norm(x, dim=-1, keepdim=True) + self.eps)


class ResampleAudio:
    """Callable that resamples a waveform to a fixed target sample rate."""

    def __init__(self, new_sample_rate=16000):
        self.new_sample_rate = new_sample_rate

    def __call__(self, audio_tensor, original_sample_rate):
        return torchaudio.functional.resample(audio_tensor, original_sample_rate, self.new_sample_rate)
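
# Illustrative use (the file name is hypothetical): load any clip and bring it
# to the 16 kHz rate the HuBERT backbone expects.
#
#   waveform, sr = torchaudio.load("clip.wav")
#   waveform_16k = ResampleAudio(new_sample_rate=16000)(waveform, sr)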


class NeuralSBS(nn.Module, BaseModel):
    """Audio-only side-by-side comparison model on a HuBERT backbone."""

    def __init__(self,
                 backbone_name="facebook/hubert-base-ls960",
                 embedding_dim=768,
                 device='cpu',
                 weights='src/weights/nsbs_small_v6.pth'):
        nn.Module.__init__(self)
        self.device = device

        self.resampler = ResampleAudio(new_sample_rate=16000)

        self.config = HubertConfig.from_pretrained(backbone_name)
        self.features = HubertModel.from_pretrained(backbone_name, config=self.config)
        self.features.config.output_hidden_states = True
        self.hubert_output_dim = self.config.hidden_size

        if self.hubert_output_dim != embedding_dim:
            self.projection = nn.Linear(self.hubert_output_dim, embedding_dim)
        else:
            self.projection = nn.Identity()

        self.embedding_dim = embedding_dim
        self.norm = L2Normalize()
        self.similarity = SkewSimilarity(embedding_dim=embedding_dim)

        # Keep the HuBERT encoder layers trainable (relevant when fine-tuning).
        for layer in self.features.encoder.layers:
            for param in layer.parameters():
                param.requires_grad = True

        BaseModel.__init__(self, weights)

    def _load_weights(self, weights: str = 'src/weights/nsbs_small_v6.pth') -> torch.nn.Module:
        self.load_state_dict(torch.load(weights, weights_only=True, map_location=self.device))
        self.eval()
        return self

    def preprocess_audios(self, audio) -> torch.Tensor:
        waveform, sample_rate = torchaudio.load(audio)
        if waveform.shape[0] > 1:
            # Downmix to mono; the pair logic below assumes one channel per clip.
            waveform = waveform.mean(dim=0, keepdim=True)
        audio = self.resampler(waveform, sample_rate)
        # Pad or trim to a fixed 3 s at 16 kHz (48000 samples).
        desired_length = 48000
        if audio.shape[1] < desired_length:
            padding = desired_length - audio.shape[1]
            audio = torch.nn.functional.pad(audio, (0, padding))
        elif audio.shape[1] > desired_length:
            audio = audio[:, :desired_length]
        return audio

    def forward(self, x):
        # x: (batch, 2, samples) -- a pair of fixed-length waveforms.
        B, N, S = x.shape
        inp = x.view(-1, S)

        outputs = self.features(inp).last_hidden_state
        # Mean-pool HuBERT frames into one embedding per clip.
        f = torch.mean(outputs, dim=1)
        f = f.view(B, N, -1)
        f = self.projection(f)

        f1, f2 = f[:, 0, :], f[:, 1, :]

        f1 = self.norm(f1)
        f2 = self.norm(f2)

        d = self.similarity(f1, f2)

        return d

    def predict(self, audios):
        with torch.no_grad():
            first_audio = self.preprocess_audios(audios[0]).unsqueeze(0)
            second_audio = self.preprocess_audios(audios[1]).unsqueeze(0)
            audio_pair = torch.cat([first_audio, second_audio], dim=1)
            return torch.sigmoid(self.forward(audio_pair))
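
# A minimal usage sketch (the .wav paths are hypothetical; the weights path is
# the repo default). predict returns a probability-like preference score for
# the pair; since the similarity is antisymmetric, swapping the two clips
# yields one minus the score.
#
#   model = NeuralSBS(device='cpu')
#   score = model.predict(['sample_a.wav', 'sample_b.wav'])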


class CrossAttention(nn.Module):
    """Pre-norm multi-head cross-attention from audio queries to text keys/values."""

    def __init__(self, audio_dim, text_dim, attention_dim, num_heads=2):
        super().__init__()
        self.audio_proj = nn.Linear(audio_dim, attention_dim)
        self.text_proj = nn.Linear(text_dim, attention_dim)
        self.mha = nn.MultiheadAttention(embed_dim=attention_dim, num_heads=num_heads, batch_first=True)

        self.pre_attention_norm = nn.LayerNorm(attention_dim)
        self.post_attention_norm = nn.LayerNorm(attention_dim)
        self.ffn = nn.Sequential(
            nn.Linear(attention_dim, attention_dim * 4),
            nn.GELU(),
            nn.Linear(attention_dim * 4, attention_dim)
        )
        self.ffn_norm = nn.LayerNorm(attention_dim)

    def forward(self, audio_features, text_features):
        query = self.audio_proj(audio_features)
        key = self.text_proj(text_features)
        # Keys and values share the same text projection.
        value = key
        query = self.pre_attention_norm(query)

        attn_output, attn_weights = self.mha(query, key, value, need_weights=True)
        attn_output = query + attn_output

        attn_output = self.post_attention_norm(attn_output)

        ffn_output = self.ffn(attn_output)
        output = attn_output + ffn_output
        output = self.ffn_norm(output)

        return output, attn_weights
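
# A shape walkthrough (tensor sizes are illustrative): audio queries of length
# Ta attend over text tokens of length Tt, so the output keeps the audio
# sequence length while the weights map audio frames to text tokens.
#
#   ca = CrossAttention(audio_dim=768, text_dim=768, attention_dim=512)
#   audio = torch.randn(2, 149, 768)   # (B, Ta, audio_dim)
#   text = torch.randn(2, 32, 768)     # (B, Tt, text_dim)
#   out, w = ca(audio, text)           # out: (2, 149, 512), w: (2, 149, 32)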


class MultiModalNeuralSBS(nn.Module, BaseMultimodalModel):
    """Side-by-side comparison model that conditions audio on text via cross-attention."""

    def __init__(self,
                 backbone_name="facebook/hubert-base-ls960",
                 bert_model_name="google-bert/bert-base-uncased",
                 embedding_dim=768,
                 attention_dim=512,
                 max_text_seq_length=512,
                 device='cpu',
                 weights='src/weights/ca_nsbs_eng_v0.pth'):
        nn.Module.__init__(self)

        self.device = device

        self.config = HubertConfig.from_pretrained(backbone_name)
        self.features = HubertModel.from_pretrained(backbone_name, config=self.config)
        self.features.config.output_hidden_states = True
        self.hubert_output_dim = self.config.hidden_size

        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        self.bert_output_dim = self.bert_model.config.hidden_size

        self.cross_attention = CrossAttention(self.hubert_output_dim, self.bert_output_dim, attention_dim)
        self.max_text_seq_length = max_text_seq_length

        if attention_dim != embedding_dim:
            self.projection = nn.Sequential(
                nn.Linear(attention_dim, embedding_dim),
                nn.LayerNorm(embedding_dim)
            )
        else:
            self.projection = nn.Identity()

        self.embedding_dim = embedding_dim
        self.norm = L2Normalize()
        self.similarity = SkewSimilarity(embedding_dim=embedding_dim)

        self.resampler = ResampleAudio(new_sample_rate=16000)

        # Keep the HuBERT encoder layers trainable (relevant when fine-tuning).
        for layer in self.features.encoder.layers:
            for param in layer.parameters():
                param.requires_grad = True

        BaseMultimodalModel.__init__(self, weights)

    def _load_weights(self, weights: str = 'src/weights/ca_nsbs_eng_v0.pth') -> torch.nn.Module:
        self.load_state_dict(torch.load(weights, weights_only=True, map_location=self.device))
        self.eval()
        return self

    def preprocess_audio(self, audio) -> torch.Tensor:
        waveform, sample_rate = torchaudio.load(audio)
        if waveform.shape[0] > 1:
            # Downmix to mono; the pair logic in forward assumes one channel per clip.
            waveform = waveform.mean(dim=0, keepdim=True)
        audio = self.resampler(waveform, sample_rate)
        # Pad or trim to a fixed 3 s at 16 kHz (48000 samples).
        desired_length = 48000
        if audio.shape[1] < desired_length:
            padding = desired_length - audio.shape[1]
            audio = torch.nn.functional.pad(audio, (0, padding))
        elif audio.shape[1] > desired_length:
            audio = audio[:, :desired_length]
        return audio

    def preprocess_text(self, text):
        # Returns a transformers BatchEncoding ready for the BERT encoder.
        text_inputs = self.bert_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_text_seq_length
        )
        return text_inputs

    def pad_attention_weights(self, attention_weights, max_text_seq_len):
        # Pad or truncate the text axis so attention maps have a fixed width.
        B, audio_seq_len, text_seq_len = attention_weights.size()
        if text_seq_len < max_text_seq_len:
            padding = max_text_seq_len - text_seq_len
            attention_weights = torch.nn.functional.pad(attention_weights, (0, padding))
        elif text_seq_len > max_text_seq_len:
            attention_weights = attention_weights[:, :, :max_text_seq_len]
        return attention_weights
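
    # Illustrative shapes (`model` stands for an instance of this class, and
    # the tensor is random, not a real attention map): a map over 32 text
    # tokens is padded out to max_text_seq_length so maps computed for
    # different texts can be stacked uniformly.
    #
    #   w = torch.rand(1, 149, 32)                      # (B, audio frames, text tokens)
    #   w_fixed = model.pad_attention_weights(w, 512)   # -> (1, 149, 512)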

    def forward(self, x, text_inputs):
        # x: (batch, 2, samples); text_inputs: a tokenized BERT batch.
        B, N, S = x.shape

        inp = x.view(-1, S)
        audio_outputs = self.features(inp).last_hidden_state
        audio_f = audio_outputs.view(B, N, -1, self.hubert_output_dim)

        text_outputs = self.bert_model(**text_inputs.to(x.device)).last_hidden_state

        # Attend each clip's audio frames over the shared text tokens.
        cross_attended_features = []
        attention_weights_list = []
        for i in range(N):
            audio_f_pair = audio_f[:, i, :, :]
            weighted_text_f, attention_weights = self.cross_attention(audio_f_pair, text_outputs)
            cross_attended_features.append(weighted_text_f)

            attention_weights = self.pad_attention_weights(attention_weights, self.max_text_seq_length)
            attention_weights_list.append(attention_weights)

        cross_attended_features = torch.stack(cross_attended_features, dim=1)
        attention_weights = torch.stack(attention_weights_list, dim=1)

        # Mean-pool over audio frames, then project and compare the pair.
        f = torch.mean(cross_attended_features, dim=2)
        f = self.projection(f)

        f1, f2 = f[:, 0, :], f[:, 1, :]
        f1 = self.norm(f1)
        f2 = self.norm(f2)

        d = self.similarity(f1, f2)
        return d, attention_weights

    def predict(self, audios, text):
        with torch.no_grad():
            first_audio = self.preprocess_audio(audios[0]).unsqueeze(0)
            second_audio = self.preprocess_audio(audios[1]).unsqueeze(0)
            audio_pair = torch.cat([first_audio, second_audio], dim=1)
            text_inputs = self.preprocess_text(text)
            return torch.sigmoid(self.forward(audio_pair, text_inputs)[0])
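
# A minimal usage sketch (the .wav paths and transcript string are
# hypothetical; the weights path is the repo default). The [0] in predict
# keeps the similarity score and drops the returned attention maps.
#
#   model = MultiModalNeuralSBS(device='cpu')
#   score = model.predict(['sample_a.wav', 'sample_b.wav'], 'the reference transcript')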