""" UTMOS strong model. Implementation from https://github.com/tarepan/SpeechMOS """ import math from typing import List, Optional, Tuple import torch import torch.nn.functional as F import torchaudio # pyright: ignore [reportMissingTypeStubs] from torch import Tensor, nn class UTMOS22Strong(nn.Module): """Saeki_2022 paper's `UTMOS strong learner` inference model (w/o Phoneme encoder).""" def __init__(self): """Init.""" super().__init__() # pyright: ignore [reportUnknownMemberType] feat_ssl, feat_domain_emb, feat_judge_emb, feat_rnn_h, feat_proj_h = ( 768, 128, 128, 512, 2048, ) feat_cat = feat_ssl + feat_domain_emb + feat_judge_emb # SSL/DataDomainEmb/JudgeIdEmb/BLSTM/Projection self.wav2vec2 = Wav2Vec2Model() self.domain_emb = nn.Parameter( data=torch.empty(1, feat_domain_emb), requires_grad=False ) self.judge_emb = nn.Parameter( data=torch.empty(1, feat_judge_emb), requires_grad=False ) self.blstm = nn.LSTM( input_size=feat_cat, hidden_size=feat_rnn_h, batch_first=True, bidirectional=True, ) self.projection = nn.Sequential( nn.Linear(feat_rnn_h * 2, feat_proj_h), nn.ReLU(), nn.Linear(feat_proj_h, 1) ) def forward(self, wave: Tensor, sr: int) -> Tensor: # pylint: disable=invalid-name """wave-to-score :: (B, T) -> (B,)""" # Feature extraction :: (B, T) -> (B, Frame, Feat) unit_series = self.wav2vec2(wave) bsz, frm, _ = unit_series.size() # DataDomain/JudgeId Embedding's Batch/Time expansion :: # (B=1, Feat) -> (B=bsz, Frame=frm, Feat) domain_series = self.domain_emb.unsqueeze(1).expand(bsz, frm, -1) judge_series = self.judge_emb.unsqueeze(1).expand(bsz, frm, -1) # Feature concatenation :: (B, Frame, Feat=f1) + (B, Frame, Feat=f2) + # (B, Frame, Feat=f3) -> (B, Frame, Feat=f1+f2+f3) cat_series = torch.cat([unit_series, domain_series, judge_series], dim=2) # Frame-scale score estimation :: (B, Frame, Feat) -> (B, Frame, Feat) # -> (B, Frame, Feat=1) - BLSTM/Projection feat_series = self.blstm(cat_series)[0] score_series = self.projection(feat_series) # Utterance-scale score :: (B, Frame, Feat=1) -> (B, Feat=1) # -> (B,) - Time averaging utter_score = score_series.mean(dim=1).squeeze(1) * 2 + 3 return utter_score class Wav2Vec2Model(nn.Module): """Wav2Vev2.""" def __init__(self): super().__init__() # pyright: ignore [reportUnknownMemberType] feat_h1, feat_h2 = 512, 768 feature_enc_layers = ( [(feat_h1, 10, 5)] + [(feat_h1, 3, 2)] * 4 + [(feat_h1, 2, 2)] * 2 ) self.feature_extractor = ConvFeatureExtractionModel( conv_layers=feature_enc_layers ) # pyright: ignore [reportGeneralTypeIssues] self.layer_norm = nn.LayerNorm(feat_h1) self.post_extract_proj = nn.Linear(feat_h1, feat_h2) self.dropout_input = nn.Dropout(0.1) self.encoder = TransformerEncoder(feat_h2) # Remnants self.mask_emb = nn.Parameter(torch.FloatTensor(feat_h2)) def forward(self, source: Tensor): """FeatureEncoder + ContextTransformer""" # Feature encoding features = self.feature_extractor(source) features = features.transpose(1, 2) features = self.layer_norm(features) features = self.post_extract_proj(features) # Context transformer x = self.encoder(features) return x class ConvFeatureExtractionModel(nn.Module): """Feature Encoder.""" def __init__(self, conv_layers: List[Tuple[int, int, int]]): super().__init__() # pyright: ignore [reportUnknownMemberType] def block( n_in: int, n_out: int, k: int, stride: int, is_group_norm: bool = False ): if is_group_norm: return nn.Sequential( nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), nn.Dropout(p=0.0), nn.GroupNorm(dim, dim, affine=True), nn.GELU(), ) else: return nn.Sequential( nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), nn.Dropout(p=0.0), nn.GELU(), ) in_d = 1 self.conv_layers = nn.ModuleList() for i, params in enumerate(conv_layers): (dim, k, stride) = params self.conv_layers.append(block(in_d, dim, k, stride, is_group_norm=i == 0)) in_d = dim def forward(self, series: Tensor) -> Tensor: """:: (B, T) -> (B, Feat, Frame)""" series = series.unsqueeze(1) for conv in self.conv_layers: series = conv(series) return series class TransformerEncoder(nn.Module): """Transformer.""" def build_encoder_layer(self, feat: int): """Layer builder.""" return TransformerSentenceEncoderLayer( embedding_dim=feat, ffn_embedding_dim=3072, num_attention_heads=12, activation_fn="gelu", dropout=0.1, attention_dropout=0.1, activation_dropout=0.0, layer_norm_first=False, ) def __init__(self, feat: int): super().__init__() # pyright: ignore [reportUnknownMemberType] self.required_seq_len_multiple = 2 self.pos_conv = nn.Sequential( *[ nn.utils.weight_norm( nn.Conv1d(feat, feat, kernel_size=128, padding=128 // 2, groups=16), name="weight", dim=2, ), SamePad(128), nn.GELU(), ] ) self.layer_norm = nn.LayerNorm(feat) self.layers = nn.ModuleList([self.build_encoder_layer(feat) for _ in range(12)]) def forward(self, x: Tensor) -> Tensor: x_conv = self.pos_conv(x.transpose(1, 2)).transpose(1, 2) x = x + x_conv x = self.layer_norm(x) # pad to the sequence length dimension x, pad_length = pad_to_multiple( x, self.required_seq_len_multiple, dim=-2, value=0 ) if pad_length > 0: padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) padding_mask[:, -pad_length:] = True else: padding_mask, _ = pad_to_multiple( None, self.required_seq_len_multiple, dim=-1, value=True ) # :: (B, T, Feat) -> (T, B, Feat) x = x.transpose(0, 1) for layer in self.layers: x = layer(x, padding_mask) # :: (T, B, Feat) -> (B, T, Feat) x = x.transpose(0, 1) # undo paddding if pad_length > 0: x = x[:, :-pad_length] return x class SamePad(nn.Module): """Tail inverse padding.""" def __init__(self, kernel_size: int): super().__init__() # pyright: ignore [reportUnknownMemberType] assert kernel_size % 2 == 0, "`SamePad` now support only even kernel." def forward(self, x: Tensor) -> Tensor: return x[:, :, :-1] def pad_to_multiple( x: Optional[Tensor], multiple: int, dim: int = -1, value: float = 0 ) -> Tuple[Optional[Tensor], int]: """Tail padding.""" if x is None: return None, 0 tsz = x.size(dim) m = tsz / multiple remainder = math.ceil(m) * multiple - tsz if m.is_integer(): return x, 0 pad_offset = (0,) * (-1 - dim) * 2 return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder class TransformerSentenceEncoderLayer(nn.Module): """Transformer Encoder Layer used in BERT/XLM style pre-trained models.""" def __init__( self, embedding_dim: int, ffn_embedding_dim: int, num_attention_heads: int, activation_fn: str, dropout: float, attention_dropout: float, activation_dropout: float, layer_norm_first: bool, ) -> None: super().__init__() # pyright: ignore [reportUnknownMemberType] assert layer_norm_first is False, "`layer_norm_first` is fixed to `False`" assert activation_fn == "gelu", "`activation_fn` is fixed to `gelu`" feat = embedding_dim self.self_attn = MultiheadAttention( feat, num_attention_heads, attention_dropout ) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(activation_dropout) self.dropout3 = nn.Dropout(dropout) self.fc1 = nn.Linear(feat, ffn_embedding_dim) self.fc2 = nn.Linear(ffn_embedding_dim, feat) self.self_attn_layer_norm = nn.LayerNorm(feat) self.final_layer_norm = nn.LayerNorm(feat) def forward(self, x: Tensor, self_attn_padding_mask: Optional[Tensor]): # Res[Attn-Do]-LN residual = x x = self.self_attn(x, x, x, self_attn_padding_mask) x = self.dropout1(x) x = residual + x x = self.self_attn_layer_norm(x) # Res[SegFC-GELU-Do-SegFC-Do]-LN residual = x x = F.gelu(self.fc1(x)) # pyright: ignore [reportUnknownMemberType] x = self.dropout2(x) x = self.fc2(x) x = self.dropout3(x) x = residual + x x = self.final_layer_norm(x) return x class MultiheadAttention(nn.Module): """Multi-headed attention.""" def __init__(self, embed_dim: int, num_heads: int, dropout: float): super().__init__() # pyright: ignore [reportUnknownMemberType] self.embed_dim, self.num_heads, self.p_dropout = embed_dim, num_heads, dropout self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) def forward( self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor], ) -> Tensor: """ Args: query :: (T, B, Feat) key_padding_mask :: (B, src_len) - mask to exclude keys that are pads , where padding elements are indicated by 1s. """ return F.multi_head_attention_forward( query=query, key=key, value=value, embed_dim_to_check=self.embed_dim, num_heads=self.num_heads, in_proj_weight=torch.empty([0]), in_proj_bias=torch.cat( (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias) ), bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=self.p_dropout, out_proj_weight=self.out_proj.weight, out_proj_bias=self.out_proj.bias, training=False, key_padding_mask=key_padding_mask.bool() if key_padding_mask is not None else None, need_weights=False, use_separate_proj_weight=True, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, )[0]