import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, PreTrainedModel, PretrainedConfig


class Pooling(nn.Module):
    def __init__(self):
        super().__init__()

    def compute_length_from_mask(self, mask):
        """
        mask: (batch_size, T)
        Assumes a sampling rate of 16 kHz and a frame shift of 20 ms.
        """
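        # Worked example of the arithmetic below (assuming 16 kHz input, as stated above):
        # a 1-second waveform has 16000 samples, which maps to
        # floor((16000 - 1) / 320) + 1 = 50 frames at a 20 ms (320-sample) frame shift.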
        wav_lens = torch.sum(mask, dim=1)
        feat_lens = torch.div(wav_lens - 1, 16000 * 0.02, rounding_mode="floor") + 1
        feat_lens = feat_lens.int().tolist()
        return feat_lens

    def forward(self, x, mask):
        raise NotImplementedError

class MeanPooling(Pooling):
    def __init__(self):
        super().__init__()

    def forward(self, xs, mask):
        """
        xs: (batch_size, T, feat_dim)
        mask: (batch_size, T)

        => output: (batch_size, feat_dim)
        """
        feat_lens = self.compute_length_from_mask(mask)
        pooled_list = []
        for x, feat_len in zip(xs, feat_lens):
            # Average only over the valid (non-padded) frames of each utterance.
            pooled = torch.mean(x[:feat_len], dim=0)
            pooled_list.append(pooled)
        pooled = torch.stack(pooled_list, dim=0)
        return pooled


class AttentiveStatisticsPooling(Pooling):
    """
    Attentive Statistics Pooling
    Paper: Attentive Statistics Pooling for Deep Speaker Embedding
    Link: https://arxiv.org/pdf/1803.10963.pdf
    """
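    # The pooled embedding concatenates an attention-weighted mean and standard
    # deviation over the valid frames (hence the output dimension is feat_dim * 2):
    #   alpha_t = softmax_t(v^T tanh(W h_t + b))
    #   mu      = sum_t alpha_t * h_t
    #   sigma   = sqrt(sum_t alpha_t * h_t^2 - mu^2)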
    def __init__(self, input_size):
        super().__init__()
        self._indim = input_size
        self.sap_linear = nn.Linear(input_size, input_size)
        self.attention = nn.Parameter(torch.FloatTensor(input_size, 1))
        torch.nn.init.normal_(self.attention, mean=0, std=1)

    def forward(self, xs, mask):
        """
        xs: (batch_size, T, feat_dim)
        mask: (batch_size, T)

        => output: (batch_size, feat_dim*2)
        """
        feat_lens = self.compute_length_from_mask(mask)
        pooled_list = []
        for x, feat_len in zip(xs, feat_lens):
            x = x[:feat_len].unsqueeze(0)
            h = torch.tanh(self.sap_linear(x))
            w = torch.matmul(h, self.attention).squeeze(dim=2)
            w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1)
            # Attention-weighted mean and standard deviation over the valid frames.
            mu = torch.sum(x * w, dim=1)
            rh = torch.sqrt((torch.sum((x ** 2) * w, dim=1) - mu ** 2).clamp(min=1e-5))
            x = torch.cat((mu, rh), 1).squeeze(0)
            pooled_list.append(x)
        return torch.stack(pooled_list)


class EmotionRegression(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim, dropout=0.5):
        super().__init__()
        p = dropout

        # MLP head: input_dim -> (Linear + LayerNorm + ReLU + Dropout) x num_layers -> output_dim
        self.fc = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Dropout(p)
            )
        ])
        for _ in range(num_layers - 1):
            self.fc.append(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Dropout(p)
                )
            )
        self.out = nn.Sequential(
            nn.Linear(hidden_dim, output_dim)
        )

        self.inp_drop = nn.Dropout(p)

    def get_repr(self, x):
        h = self.inp_drop(x)
        for fc in self.fc:
            h = fc(h)
        return h

    def forward(self, x):
        h = self.get_repr(x)
        result = self.out(h)
        return result


class SERConfig(PretrainedConfig):
    model_type = "ser"

    def __init__(
        self,
        num_classes: int = 3,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 24,
        hidden_size: int = 1024,
        classifier_hidden_layers: int = 1,
        classifier_dropout_prob: float = 0.5,
        ssl_type: str = "microsoft/wavlm-large",
        torch_dtype: str = "float32",
        **kwargs,
    ):
        self.num_classes = num_classes
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        self.classifier_hidden_layers = classifier_hidden_layers
        self.classifier_dropout_prob = classifier_dropout_prob
        self.ssl_type = ssl_type
        self.torch_dtype = torch_dtype
        super().__init__(**kwargs)
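
# Note: the default hidden_size / num_hidden_layers / num_attention_heads above match
# the configuration of the microsoft/wavlm-large encoder that SERModel loads by default.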


class SERModel(PreTrainedModel):
    config_class = SERConfig

    def __init__(self, config):
        super().__init__(config)
        # Pretrained SSL speech encoder (e.g. WavLM); its convolutional feature
        # extractor is frozen (no gradient updates).
        self.ssl_model = AutoModel.from_pretrained(config.ssl_type)
        self.ssl_model.freeze_feature_encoder()

        self.pool_model = AttentiveStatisticsPooling(config.hidden_size)

        self.ser_model = EmotionRegression(config.hidden_size * 2,
                                           config.hidden_size,
                                           config.classifier_hidden_layers,
                                           config.num_classes,
                                           dropout=config.classifier_dropout_prob)

    def forward(self, x, mask):
        # x: raw waveform (batch_size, T); mask: sample-level attention mask (batch_size, T).
        ssl = self.ssl_model(x, attention_mask=mask).last_hidden_state

        ssl = self.pool_model(ssl, mask)

        pred = self.ser_model(ssl)

        return pred
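

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original pipeline).
    # Assumptions: the model consumes raw 16 kHz mono waveforms plus a binary
    # sample-level attention mask; instantiating SERModel downloads the
    # microsoft/wavlm-large weights, and the regression head here is randomly
    # initialized (in practice a trained checkpoint would be loaded instead).
    config = SERConfig(num_classes=3)
    model = SERModel(config)
    model.eval()

    batch_size, num_samples = 2, 32000  # two 2-second clips at 16 kHz
    x = torch.randn(batch_size, num_samples)
    mask = torch.ones(batch_size, num_samples, dtype=torch.long)
    mask[1, 16000:] = 0  # the second clip is only 1 second long; the rest is padding

    with torch.no_grad():
        pred = model(x, mask)
    print(pred.shape)  # expected: torch.Size([2, 3])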