"""
UTMOS strong model.
Implementation from https://github.com/tarepan/SpeechMOS
"""
import math
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
import torchaudio # pyright: ignore [reportMissingTypeStubs]
from torch import Tensor, nn


class UTMOS22Strong(nn.Module):
"""Saeki_2022 paper's `UTMOS strong learner` inference model
(w/o Phoneme encoder)."""
def __init__(self):
"""Init."""
super().__init__() # pyright: ignore [reportUnknownMemberType]
feat_ssl, feat_domain_emb, feat_judge_emb, feat_rnn_h, feat_proj_h = (
768,
128,
128,
512,
2048,
)
feat_cat = feat_ssl + feat_domain_emb + feat_judge_emb
# SSL/DataDomainEmb/JudgeIdEmb/BLSTM/Projection
self.wav2vec2 = Wav2Vec2Model()
self.domain_emb = nn.Parameter(
data=torch.empty(1, feat_domain_emb), requires_grad=False
)
self.judge_emb = nn.Parameter(
data=torch.empty(1, feat_judge_emb), requires_grad=False
)
self.blstm = nn.LSTM(
input_size=feat_cat,
hidden_size=feat_rnn_h,
batch_first=True,
bidirectional=True,
)
self.projection = nn.Sequential(
nn.Linear(feat_rnn_h * 2, feat_proj_h), nn.ReLU(), nn.Linear(feat_proj_h, 1)
)
    def forward(self, wave: Tensor, sr: int) -> Tensor: # pylint: disable=invalid-name
        """wave-to-score :: (B, T) -> (B,)"""
        # Resampling :: (B, T) -> (B, T') - the wav2vec2 frontend expects 16 kHz audio
        wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000)
        # Feature extraction :: (B, T') -> (B, Frame, Feat)
        unit_series = self.wav2vec2(wave)
bsz, frm, _ = unit_series.size()
# DataDomain/JudgeId Embedding's Batch/Time expansion ::
# (B=1, Feat) -> (B=bsz, Frame=frm, Feat)
domain_series = self.domain_emb.unsqueeze(1).expand(bsz, frm, -1)
judge_series = self.judge_emb.unsqueeze(1).expand(bsz, frm, -1)
# Feature concatenation :: (B, Frame, Feat=f1) + (B, Frame, Feat=f2) +
# (B, Frame, Feat=f3) -> (B, Frame, Feat=f1+f2+f3)
cat_series = torch.cat([unit_series, domain_series, judge_series], dim=2)
# Frame-scale score estimation :: (B, Frame, Feat) -> (B, Frame, Feat)
# -> (B, Frame, Feat=1) - BLSTM/Projection
feat_series = self.blstm(cat_series)[0]
score_series = self.projection(feat_series)
# Utterance-scale score :: (B, Frame, Feat=1) -> (B, Feat=1)
# -> (B,) - Time averaging
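        # The trailing `* 2 + 3` maps the normalized prediction back onto the 1-5 MOS
        # scale (assuming the usual UTMOS target normalization of (mos - 3) / 2).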
utter_score = score_series.mean(dim=1).squeeze(1) * 2 + 3
return utter_score


class Wav2Vec2Model(nn.Module):
    """Wav2Vec2."""
def __init__(self):
super().__init__() # pyright: ignore [reportUnknownMemberType]
feat_h1, feat_h2 = 512, 768
feature_enc_layers = (
[(feat_h1, 10, 5)] + [(feat_h1, 3, 2)] * 4 + [(feat_h1, 2, 2)] * 2
)
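        # Each tuple is (channels, kernel, stride); the stack downsamples the 16 kHz
        # waveform by a factor of 5 * 2**6 = 320 samples per frame (~20 ms hop).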
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers
) # pyright: ignore [reportGeneralTypeIssues]
self.layer_norm = nn.LayerNorm(feat_h1)
self.post_extract_proj = nn.Linear(feat_h1, feat_h2)
self.dropout_input = nn.Dropout(0.1)
self.encoder = TransformerEncoder(feat_h2)
        # Remnant of the training-time model - unused by this inference-only forward,
        # presumably kept so that pretrained checkpoints load without missing keys.
        self.mask_emb = nn.Parameter(torch.FloatTensor(feat_h2))
def forward(self, source: Tensor):
"""FeatureEncoder + ContextTransformer"""
# Feature encoding
features = self.feature_extractor(source)
features = features.transpose(1, 2)
features = self.layer_norm(features)
features = self.post_extract_proj(features)
# Context transformer
x = self.encoder(features)
return x


class ConvFeatureExtractionModel(nn.Module):
"""Feature Encoder."""
def __init__(self, conv_layers: List[Tuple[int, int, int]]):
super().__init__() # pyright: ignore [reportUnknownMemberType]
def block(
n_in: int, n_out: int, k: int, stride: int, is_group_norm: bool = False
):
if is_group_norm:
return nn.Sequential(
nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
nn.Dropout(p=0.0),
                    nn.GroupNorm(n_out, n_out, affine=True),
nn.GELU(),
)
else:
return nn.Sequential(
nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
nn.Dropout(p=0.0),
nn.GELU(),
)
in_d = 1
self.conv_layers = nn.ModuleList()
for i, params in enumerate(conv_layers):
(dim, k, stride) = params
self.conv_layers.append(block(in_d, dim, k, stride, is_group_norm=i == 0))
in_d = dim
def forward(self, series: Tensor) -> Tensor:
""":: (B, T) -> (B, Feat, Frame)"""
series = series.unsqueeze(1)
for conv in self.conv_layers:
series = conv(series)
return series


class TransformerEncoder(nn.Module):
"""Transformer."""
def build_encoder_layer(self, feat: int):
"""Layer builder."""
return TransformerSentenceEncoderLayer(
embedding_dim=feat,
ffn_embedding_dim=3072,
num_attention_heads=12,
activation_fn="gelu",
dropout=0.1,
attention_dropout=0.1,
activation_dropout=0.0,
layer_norm_first=False,
)
def __init__(self, feat: int):
super().__init__() # pyright: ignore [reportUnknownMemberType]
self.required_seq_len_multiple = 2
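        # Convolutional relative positional encoding (wav2vec 2.0 style); its output
        # is added to the input features before the Transformer layers.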
self.pos_conv = nn.Sequential(
*[
nn.utils.weight_norm(
nn.Conv1d(feat, feat, kernel_size=128, padding=128 // 2, groups=16),
name="weight",
dim=2,
),
SamePad(128),
nn.GELU(),
]
)
self.layer_norm = nn.LayerNorm(feat)
self.layers = nn.ModuleList([self.build_encoder_layer(feat) for _ in range(12)])
def forward(self, x: Tensor) -> Tensor:
x_conv = self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
x = x + x_conv
x = self.layer_norm(x)
# pad to the sequence length dimension
x, pad_length = pad_to_multiple(
x, self.required_seq_len_multiple, dim=-2, value=0
)
if pad_length > 0:
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
padding_mask[:, -pad_length:] = True
else:
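            # pad_to_multiple(None, ...) simply returns (None, 0): no mask is needed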
padding_mask, _ = pad_to_multiple(
None, self.required_seq_len_multiple, dim=-1, value=True
)
# :: (B, T, Feat) -> (T, B, Feat)
x = x.transpose(0, 1)
for layer in self.layers:
x = layer(x, padding_mask)
# :: (T, B, Feat) -> (B, T, Feat)
x = x.transpose(0, 1)
        # undo padding
if pad_length > 0:
x = x[:, :-pad_length]
return x


class SamePad(nn.Module):
    """Tail inverse padding (drops the extra frame produced by even-kernel `padding=k//2`)."""
    def __init__(self, kernel_size: int):
        super().__init__() # pyright: ignore [reportUnknownMemberType]
        assert kernel_size % 2 == 0, "`SamePad` currently supports only even kernel sizes."
def forward(self, x: Tensor) -> Tensor:
return x[:, :, :-1]


def pad_to_multiple(
x: Optional[Tensor], multiple: int, dim: int = -1, value: float = 0
) -> Tuple[Optional[Tensor], int]:
"""Tail padding."""
if x is None:
return None, 0
tsz = x.size(dim)
m = tsz / multiple
remainder = math.ceil(m) * multiple - tsz
if m.is_integer():
return x, 0
pad_offset = (0,) * (-1 - dim) * 2
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder


class TransformerSentenceEncoderLayer(nn.Module):
"""Transformer Encoder Layer used in BERT/XLM style pre-trained models."""
def __init__(
self,
embedding_dim: int,
ffn_embedding_dim: int,
num_attention_heads: int,
activation_fn: str,
dropout: float,
attention_dropout: float,
activation_dropout: float,
layer_norm_first: bool,
) -> None:
super().__init__() # pyright: ignore [reportUnknownMemberType]
assert layer_norm_first is False, "`layer_norm_first` is fixed to `False`"
assert activation_fn == "gelu", "`activation_fn` is fixed to `gelu`"
feat = embedding_dim
self.self_attn = MultiheadAttention(
feat, num_attention_heads, attention_dropout
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.fc1 = nn.Linear(feat, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, feat)
self.self_attn_layer_norm = nn.LayerNorm(feat)
self.final_layer_norm = nn.LayerNorm(feat)
def forward(self, x: Tensor, self_attn_padding_mask: Optional[Tensor]):
# Res[Attn-Do]-LN
residual = x
x = self.self_attn(x, x, x, self_attn_padding_mask)
x = self.dropout1(x)
x = residual + x
x = self.self_attn_layer_norm(x)
# Res[SegFC-GELU-Do-SegFC-Do]-LN
residual = x
x = F.gelu(self.fc1(x)) # pyright: ignore [reportUnknownMemberType]
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
x = self.final_layer_norm(x)
return x


class MultiheadAttention(nn.Module):
"""Multi-headed attention."""
def __init__(self, embed_dim: int, num_heads: int, dropout: float):
super().__init__() # pyright: ignore [reportUnknownMemberType]
self.embed_dim, self.num_heads, self.p_dropout = embed_dim, num_heads, dropout
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor],
) -> Tensor:
"""
Args:
query :: (T, B, Feat)
            key_padding_mask :: (B, src_len) - mask to exclude keys that are pads,
                where padding elements are indicated by 1s.
"""
return F.multi_head_attention_forward(
query=query,
key=key,
value=value,
embed_dim_to_check=self.embed_dim,
num_heads=self.num_heads,
            in_proj_weight=torch.empty([0]), # unused when use_separate_proj_weight=True
in_proj_bias=torch.cat(
(self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=self.p_dropout,
out_proj_weight=self.out_proj.weight,
out_proj_bias=self.out_proj.bias,
training=False,
key_padding_mask=key_padding_mask.bool()
if key_padding_mask is not None
else None,
need_weights=False,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)[0]
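

if __name__ == "__main__":
    # Minimal smoke test - a sketch only: parameters are randomly initialized here
    # (no UTMOS checkpoint is loaded), so the printed score is meaningless and only
    # the tensor shapes are being exercised.
    model = UTMOS22Strong().eval()
    wave = torch.randn(2, 16000) # (B, T): two one-second clips of noise
    with torch.no_grad():
        score = model(wave, sr=16000)
    print(score.shape) # expected: torch.Size([2])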