# src/models/neural_sbs.py
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from transformers import BertModel, BertTokenizer, HubertConfig, HubertModel

from src.models.base_model import BaseModel, BaseMultimodalModel


class SkewSimilarity(nn.Module):
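    """Learned skew-symmetric bilinear similarity.

    Scores a pair of embeddings as x^T A y with A = (J - J^T) / 2, so the
    score flips sign when the inputs are swapped: d(x, y) = -d(y, x).
    """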
def __init__(self, embedding_dim):
super(SkewSimilarity, self).__init__()
std = np.sqrt(2.0 / embedding_dim)
self.J = nn.Parameter(std * torch.randn(embedding_dim, embedding_dim))
def forward(self, x, y):
J_ = 0.5 * (self.J - self.J.transpose(1, 0))
return torch.sum(torch.matmul(x, J_) * y, dim=-1, keepdim=True)
class L2Normalize(nn.Module):
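    """Scale vectors to unit L2 norm along the last dimension; eps guards against division by zero."""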
def __init__(self, eps=1e-5):
super(L2Normalize, self).__init__()
self.eps = eps
def forward(self, x):
return x / (torch.norm(x, dim=-1, keepdim=True) + self.eps)
class ResampleAudio:
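    """Callable that resamples a waveform tensor to a target sample rate (16 kHz by default)."""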
def __init__(self, new_sample_rate=16000):
self.new_sample_rate = new_sample_rate
def __call__(self, audio_tensor, original_sample_rate):
return torchaudio.functional.resample(audio_tensor, original_sample_rate, self.new_sample_rate)
class NeuralSBS(nn.Module, BaseModel):
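    """Audio-only neural side-by-side (SBS) comparison model.

    Embeds two clips with a HuBERT backbone, mean-pools the frame features,
    L2-normalizes, and scores the pair with a skew-symmetric similarity;
    `predict` maps the raw score through a sigmoid.
    """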
def __init__(self,
backbone_name="facebook/hubert-base-ls960",
embedding_dim=768,
device='cpu',
                 weights='src/weights/nsbs_small_v6.pth'):
nn.Module.__init__(self)
self.device = device
self.resampler = ResampleAudio(new_sample_rate=16000)
self.config = HubertConfig.from_pretrained(backbone_name)
self.features = HubertModel.from_pretrained(backbone_name, config=self.config)
self.features.config.output_hidden_states = True
self.hubert_output_dim = self.config.hidden_size
if self.hubert_output_dim != embedding_dim:
self.projection = nn.Linear(self.hubert_output_dim, embedding_dim)
else:
self.projection = nn.Identity()
self.embedding_dim = embedding_dim
self.norm = L2Normalize()
self.similarity = SkewSimilarity(embedding_dim=embedding_dim)
        # Keep all HuBERT encoder layers trainable (the backbone is fine-tuned).
        for layer in self.features.encoder.layers:
            for param in layer.parameters():
                param.requires_grad = True
BaseModel.__init__(self, weights)
def _load_weights(self, weights: str = 'src/weights/nsbs_small_v6.pth') -> torch.nn.Module:
self.load_state_dict(torch.load(weights, weights_only=True, map_location=self.device))
self.eval()
return self
    def preprocess_audios(self, audio) -> torch.Tensor:
        # Load the file, resample to 16 kHz, then pad or trim to a fixed
        # 3-second window (48000 samples) so every clip has the same length.
        waveform, sample_rate = torchaudio.load(audio)
        audio = self.resampler(waveform, sample_rate)
        desired_length = 48000  # 3 s at 16 kHz
        if audio.shape[1] < desired_length:
            padding = desired_length - audio.shape[1]
            audio = torch.nn.functional.pad(audio, (0, padding))
        elif audio.shape[1] > desired_length:
            audio = audio[:, :desired_length]
        return audio
    def forward(self, x):
        B, N, S = x.shape  # batch size, clips per comparison (2), samples
        inp = x.view(-1, S)  # [B * N, S]
        outputs = self.features(inp).last_hidden_state
        f = torch.mean(outputs, dim=1)  # mean-pool over time frames
        f = f.view(B, N, -1)  # [B, N, hubert_output_dim]
        f = self.projection(f)
        f1, f2 = f[:, 0, :], f[:, 1, :]
        f1 = self.norm(f1)
        f2 = self.norm(f2)
        d = self.similarity(f1, f2)
        return d
    def predict(self, audios):
        first_audio = self.preprocess_audios(audios[0]).unsqueeze(0)   # [1, 1, S] for mono input
        second_audio = self.preprocess_audios(audios[1]).unsqueeze(0)  # [1, 1, S] for mono input
        audio_pair = torch.cat([first_audio, second_audio], dim=1)     # [1, 2, S]
        with torch.no_grad():
            return torch.sigmoid(self.forward(audio_pair))
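
# Example usage (a minimal sketch; the WAV paths are hypothetical and the
# default checkpoint must exist at src/weights/nsbs_small_v6.pth):
#
#     model = NeuralSBS(device='cpu')
#     prob = model.predict(['clip_a.wav', 'clip_b.wav'])  # tensor of shape [1, 1]
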
class CrossAttention(nn.Module):
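    """Cross-attention block in which audio frames query text tokens.

    Audio features are projected to queries and text features to shared
    keys/values; multi-head attention is followed by a residual connection,
    layer norm, and a feed-forward block. Returns the attended features and
    the attention weights.
    """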
def __init__(self, audio_dim, text_dim, attention_dim, num_heads=2):
super().__init__()
self.audio_proj = nn.Linear(audio_dim, attention_dim)
self.text_proj = nn.Linear(text_dim, attention_dim)
self.mha = nn.MultiheadAttention(embed_dim=attention_dim, num_heads=num_heads, batch_first=True)
self.pre_attention_norm = nn.LayerNorm(attention_dim)
self.post_attention_norm = nn.LayerNorm(attention_dim)
self.ffn = nn.Sequential(
nn.Linear(attention_dim, attention_dim * 4),
nn.GELU(),
nn.Linear(attention_dim * 4, attention_dim)
)
self.ffn_norm = nn.LayerNorm(attention_dim)
def forward(self, audio_features, text_features):
# audio_features: [B, audio_seq_len, audio_dim]
# text_features: [B, text_seq_len, text_dim]
query = self.audio_proj(audio_features)
key = self.text_proj(text_features)
        value = key  # projected text tokens serve as both keys and values
query = self.pre_attention_norm(query)
attn_output, attn_weights = self.mha(query, key, value, need_weights=True)
attn_output = query + attn_output
attn_output = self.post_attention_norm(attn_output)
ffn_output = self.ffn(attn_output)
output = attn_output + ffn_output
output = self.ffn_norm(output)
return output, attn_weights
class MultiModalNeuralSBS(nn.Module, BaseMultimodalModel):
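    """Cross-attention multimodal side-by-side (SBS) comparison model.

    Each clip's HuBERT frame features attend over BERT token embeddings of a
    shared text; the attended features are mean-pooled, projected,
    L2-normalized, and scored pairwise with a skew-symmetric similarity.
    `forward` also returns the (padded) attention weights for inspection.
    """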
def __init__(self,
backbone_name="facebook/hubert-base-ls960",
bert_model_name="google-bert/bert-base-uncased",
embedding_dim=768,
attention_dim=512,
max_text_seq_length=512,
device='cpu',
                 weights='src/weights/ca_nsbs_eng_v0.pth'):
nn.Module.__init__(self)
self.device = device
self.config = HubertConfig.from_pretrained(backbone_name)
self.features = HubertModel.from_pretrained(backbone_name, config=self.config)
self.features.config.output_hidden_states = True
self.hubert_output_dim = self.config.hidden_size
self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
self.bert_model = BertModel.from_pretrained(bert_model_name)
self.bert_output_dim = self.bert_model.config.hidden_size
self.cross_attention = CrossAttention(self.hubert_output_dim, self.bert_output_dim, attention_dim)
self.max_text_seq_length = max_text_seq_length
if attention_dim != embedding_dim:
self.projection = nn.Sequential(
nn.Linear(attention_dim, embedding_dim),
nn.LayerNorm(embedding_dim)
)
else:
self.projection = nn.Identity()
self.embedding_dim = embedding_dim
self.norm = L2Normalize()
self.similarity = SkewSimilarity(embedding_dim=embedding_dim)
self.resampler = ResampleAudio(new_sample_rate=16000)
        # Keep all HuBERT encoder layers trainable (the backbone is fine-tuned).
        for layer in self.features.encoder.layers:
            for param in layer.parameters():
                param.requires_grad = True
BaseMultimodalModel.__init__(self, weights)
def _load_weights(self, weights: str = 'src/weights/ca_nsbs_eng_v0.pth') -> torch.nn.Module:
self.load_state_dict(torch.load(weights, weights_only=True, map_location=self.device))
self.eval()
return self
    def preprocess_audio(self, audio) -> torch.Tensor:
        # Load the file, resample to 16 kHz, then pad or trim to a fixed
        # 3-second window (48000 samples) so every clip has the same length.
        waveform, sample_rate = torchaudio.load(audio)
        audio = self.resampler(waveform, sample_rate)
        desired_length = 48000  # 3 s at 16 kHz
        if audio.shape[1] < desired_length:
            padding = desired_length - audio.shape[1]
            audio = torch.nn.functional.pad(audio, (0, padding))
        elif audio.shape[1] > desired_length:
            audio = audio[:, :desired_length]
        return audio
def preprocess_text(self, text) -> torch.Tensor:
text_inputs = self.bert_tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_text_seq_length
)
return text_inputs
    def pad_attention_weights(self, attention_weights, max_text_seq_len):
        # Pad or truncate along the text axis so attention maps share a fixed
        # width and can be stacked across calls.
        B, audio_seq_len, text_seq_len = attention_weights.size()
if text_seq_len < max_text_seq_len:
padding = max_text_seq_len - text_seq_len
attention_weights = torch.nn.functional.pad(attention_weights, (0, padding))
elif text_seq_len > max_text_seq_len:
attention_weights = attention_weights[:, :, :max_text_seq_len]
return attention_weights
def forward(self, x, text_inputs):
B, N, S = x.shape
inp = x.view(-1, S)
audio_outputs = self.features(inp).last_hidden_state
audio_f = audio_outputs.view(B, N, -1, self.hubert_output_dim) # [B, N, audio_seq_len, hubert_output_dim]
text_outputs = self.bert_model(**text_inputs.to(x.device)).last_hidden_state # [B, text_seq_len, bert_output_dim]
cross_attended_features = []
attention_weights_list = []
for i in range(N):
audio_f_pair = audio_f[:, i, :, :] # [B, audio_seq_len, hubert_output_dim]
weighted_text_f, attention_weights = self.cross_attention(audio_f_pair, text_outputs)
cross_attended_features.append(weighted_text_f)
attention_weights = self.pad_attention_weights(attention_weights, self.max_text_seq_length)
attention_weights_list.append(attention_weights)
cross_attended_features = torch.stack(cross_attended_features, dim=1) # [B, N, audio_seq_len, attention_dim]
attention_weights = torch.stack(attention_weights_list, dim=1) # [B, N, audio_seq_len, max_text_seq_len]
f = torch.mean(cross_attended_features, dim=2) # [B, N, attention_dim]
f = self.projection(f) # [B, N, embedding_dim]
f1, f2 = f[:, 0, :], f[:, 1, :]
f1 = self.norm(f1)
f2 = self.norm(f2)
d = self.similarity(f1, f2)
return d, attention_weights
    def predict(self, audios, text):
        first_audio = self.preprocess_audio(audios[0]).unsqueeze(0)   # [1, 1, S]
        second_audio = self.preprocess_audio(audios[1]).unsqueeze(0)  # [1, 1, S]
        audio_pair = torch.cat([first_audio, second_audio], dim=1)    # [1, 2, S]
        text_inputs = self.preprocess_text(text)
        with torch.no_grad():
            return torch.sigmoid(self.forward(audio_pair, text_inputs)[0])
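
# Example usage (a minimal sketch; the WAV paths and text are hypothetical and
# the default checkpoint must exist at src/weights/ca_nsbs_eng_v0.pth):
#
#     model = MultiModalNeuralSBS(device='cpu')
#     prob = model.predict(['clip_a.wav', 'clip_b.wav'], 'reference transcript')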