| import torch |
| import torch.nn as nn |
|
|
| from .text_encoder import ( |
| AttnEncoder, |
| TextEmbedderWrapper, |
| ConvNeXtWrapper, |
| ) |
|
|
|
|
class DPReferenceEncoder(nn.Module):
    """Summarise a reference feature sequence into one flat style vector.

    The reference features go through a 1x1 ``Conv1d`` projection and a
    ``ConvNeXtWrapper`` stack. A set of learned query tokens (``ref_keys``)
    then attends to the result twice. The first attention's output is added
    back onto the queries. The second attention's per-query outputs are
    flattened into a single vector per batch item.

    Args:
        in_channels: Channel count of the incoming reference features.
        d_model: Width of the projection / ConvNeXt stack; also the key and
            value dimension of both attention layers.
        hidden_dim: ConvNeXt intermediate width. The expansion factor handed
            to ``ConvNeXtWrapper`` is ``hidden_dim // d_model`` (integer
            division).
        num_blocks: Number of ConvNeXt layers.
        num_queries: Number of learned query tokens.
        query_dim: Embedding size of each query token and of the attention
            outputs.
        num_heads: Heads per ``nn.MultiheadAttention``. PyTorch requires it to
            divide ``query_dim``.
        kernel_size: ConvNeXt kernel size.
        dilation_lst: Optional dilations forwarded unchanged to
            ``ConvNeXtWrapper``.
    """

    def __init__(
        self,
        in_channels: int = 144,
        d_model: int = 64,
        hidden_dim: int = 256,
        num_blocks: int = 4,
        num_queries: int = 8,
        query_dim: int = 16,
        num_heads: int = 2,
        kernel_size: int = 5,
        dilation_lst: list = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_queries = num_queries
        self.query_dim = query_dim

        # Keep this construction order: it determines which random numbers
        # each parameter's initialisation consumes.
        self.input_proj = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.convnext = ConvNeXtWrapper(
            d_model,
            n_layers=num_blocks,
            expansion_factor=hidden_dim // d_model,
            kernel_size=kernel_size,
            dilation_lst=dilation_lst,
        )
        # Small-scale random init for the learned query tokens.
        self.ref_keys = nn.Parameter(torch.randn(num_queries, query_dim) * 0.02)

        # Both attention layers share one configuration: queries live in
        # query_dim, keys/values come from the d_model-wide feature sequence.
        attn_kwargs = dict(
            embed_dim=query_dim,
            num_heads=num_heads,
            kdim=d_model,
            vdim=d_model,
            batch_first=True,
        )
        self.attn1 = nn.MultiheadAttention(**attn_kwargs)
        self.attn2 = nn.MultiheadAttention(**attn_kwargs)

    def forward(self, z_ref: torch.Tensor, mask: torch.Tensor = None):
        """Encode ``z_ref`` into a flat style embedding.

        Args:
            z_ref: Channels-first reference features, presumably
                ``(batch, in_channels, time)`` given the ``Conv1d`` input
                projection (TODO confirm against callers).
            mask: Optional mask with a singleton dim 1 (assumed
                ``(batch, 1, time)``). Positions equal to 0 are excluded from
                attention as padding.

        Returns:
            Tensor of shape ``(batch, num_queries * query_dim)``.
        """
        batch = z_ref.shape[0]
        feats = self.convnext(self.input_proj(z_ref), mask=mask)
        # The attention layers are batch_first, so put time on dim 1.
        memory = feats.transpose(1, 2)

        pad_mask = None if mask is None else (mask.squeeze(1) == 0)

        queries = self.ref_keys.unsqueeze(0).expand(batch, -1, -1)
        attended, _ = self.attn1(
            query=queries,
            key=memory,
            value=memory,
            key_padding_mask=pad_mask,
            need_weights=False,
        )
        # Residual update of the queries before the second attention pass.
        refined = queries + attended
        style, _ = self.attn2(
            query=refined,
            key=memory,
            value=memory,
            key_padding_mask=pad_mask,
            need_weights=False,
        )
        return style.reshape(batch, -1)
|
|
|
|
class DPTextEncoder(nn.Module):
    """Encode a token sequence into a single sentence-level vector.

    A learned sentence token is prepended to the embedded text. The sequence
    then goes through a ConvNeXt stack and an attention encoder, with a
    residual connection around the attention encoder. The output at the
    prepended position is projected by a bias-free 1x1 convolution.

    Args:
        vocab_size: Vocabulary size passed to ``TextEmbedderWrapper``.
        d_model: Channel width used throughout; also the output size.
    """

    def __init__(self, vocab_size=37, d_model=64):
        super().__init__()
        self.d_model = d_model
        self.text_embedder = TextEmbedderWrapper(vocab_size, d_model)
        self.convnext = ConvNeXtWrapper(d_model, n_layers=6, expansion_factor=4)
        # Learned token prepended at position 0; its final state is the output.
        self.sentence_token = nn.Parameter(torch.randn(1, d_model, 1) * 0.02)
        self.attn_encoder = AttnEncoder(
            channels=d_model,
            n_heads=2,
            filter_channels=d_model * 4,
            n_layers=2,
        )
        # Registered under the name "net" so its parameters are keyed
        # "proj_out.net.*".
        self.proj_out = nn.Sequential()
        self.proj_out.add_module("net", nn.Conv1d(d_model, d_model, 1, bias=False))

    def forward(self, text_ids, mask=None):
        """Return the sentence embedding for ``text_ids``.

        Args:
            text_ids: 2-D tensor of token ids, ``(batch, time)``. The shape
                unpacking below rejects other ranks.
            mask: Optional mask broadcastable against the channels-first
                embedded text, presumably ``(batch, 1, time)`` (TODO confirm).
                It is extended by one always-valid position for the sentence
                token.

        Returns:
            Tensor of shape ``(batch, d_model)``.
        """
        batch_size, _ = text_ids.shape
        # Assumes the embedder returns (batch, time, d_model); switch to
        # channels-first for the convolutional stack.
        x = self.text_embedder(text_ids).transpose(1, 2)
        if mask is not None:
            x = x * mask

        u_token = self.sentence_token.expand(batch_size, -1, -1)
        x = torch.cat([u_token, x], dim=2)

        if mask is not None:
            # The prepended sentence token is always valid. Build its mask
            # entry with the mask's own dtype and device, so the concatenation
            # neither mixes dtypes nor silently turns a half/bool mask into
            # float32.
            mask = torch.cat([mask.new_ones(batch_size, 1, 1), mask], dim=2)

        x = self.convnext(x, mask=mask)
        conv_out = x
        # Residual connection around the attention encoder.
        x = self.attn_encoder(x, mask=mask)
        x = x + conv_out

        out = self.proj_out(x[:, :, :1])
        if mask is not None:
            # mask[:, :, :1] is the sentence-token entry added above.
            out = out * mask[:, :, :1]
        return out.squeeze(2)
|
|
|
|
class DurationEstimator(nn.Module):
    """Predict a per-item duration from a sentence embedding and a style embedding.

    The two embeddings are concatenated and passed through
    ``Linear -> PReLU -> Linear(..., 1)``. The raw scalar output is treated as
    a log duration. It is returned as-is when ``return_log`` is true;
    otherwise it is exponentiated.

    Args:
        text_dim: Feature size of ``text_emb``.
        style_dim: Feature size of ``style_emb`` after flattening.
    """

    def __init__(self, text_dim=64, style_dim=128):
        super().__init__()
        hidden = 128
        # A ModuleList in this exact order keeps the parameter names
        # layers.0.*, layers.1.* and activation.*.
        self.layers = nn.ModuleList(
            [nn.Linear(text_dim + style_dim, hidden), nn.Linear(hidden, 1)]
        )
        self.activation = nn.PReLU()

    def forward(self, text_emb, style_emb, text_mask=None, return_log=False):
        """Return one duration (or log duration) per batch item.

        Args:
            text_emb: ``(batch, text_dim)`` sentence embedding.
            style_emb: Style embedding. Any tensor with more than two
                dimensions is flattened to ``(batch, -1)`` first.
            text_mask: Accepted for interface compatibility; not used.
            return_log: If true, skip the final ``exp``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        if style_emb.dim() > 2:
            style_emb = style_emb.reshape(style_emb.shape[0], -1)

        hidden_layer, out_layer = self.layers
        features = torch.cat([text_emb, style_emb], dim=1)
        log_dur = out_layer(self.activation(hidden_layer(features))).squeeze(1)

        return log_dur if return_log else torch.exp(log_dur)
|
|
|
|
class TTSDurationModel(nn.Module):
    """Duration model: sentence encoder + reference style encoder + predictor.

    Args:
        vocab_size: Token vocabulary size for the sentence encoder.
        style_dp: Default number of style query tokens, used when
            ``style_encoder_cfg["style_token_layer"]["n_style"]`` is absent.
        style_dim: Default per-token style size, used when
            ``style_encoder_cfg["style_token_layer"]["style_value_dim"]`` is
            absent.
        ref_in_channels: Channel count of the reference features ``z_ref``.
        sentence_encoder_cfg: Optional dict. ``char_emb_dim`` sets the sentence
            encoder width (default 64).
        style_encoder_cfg: Optional nested dict with ``proj_in.odim``,
            ``convnext.{intermediate_dim, num_layers, dilation_lst}`` and
            ``style_token_layer.{n_style, style_value_dim, n_heads}``.
        predictor_cfg: Optional dict with ``sentence_dim``, ``n_style`` and
            ``style_dim``. Missing entries default to the corresponding encoder
            sizes, so the predictor's input width matches the encoder outputs.
    """

    def __init__(
        self,
        vocab_size=37,
        style_dp=8,
        style_dim=16,
        ref_in_channels=144,
        sentence_encoder_cfg=None,
        style_encoder_cfg=None,
        predictor_cfg=None,
    ):
        super().__init__()
        self.vocab_size = vocab_size

        se_cfg = sentence_encoder_cfg or {}
        st_cfg = style_encoder_cfg or {}
        pr_cfg = predictor_cfg or {}

        # Sentence encoder width.
        se_d_model = se_cfg.get("char_emb_dim", 64)

        # Style encoder: input projection width.
        st_proj = st_cfg.get("proj_in", {})
        st_d_model = st_proj.get("odim", 64)

        # Style encoder: ConvNeXt stack.
        st_convnext = st_cfg.get("convnext", {})
        st_hidden_dim = st_convnext.get("intermediate_dim", 256)
        st_num_blocks = st_convnext.get("num_layers", 4)
        st_dilation = st_convnext.get("dilation_lst", None)

        # Style encoder: learned query tokens / attention.
        st_token_layer = st_cfg.get("style_token_layer", {})
        st_num_queries = st_token_layer.get("n_style", style_dp)
        st_query_dim = st_token_layer.get("style_value_dim", style_dim)
        st_num_heads = st_token_layer.get("n_heads", 2)

        # Predictor input sizes default to the encoders' output sizes. The
        # sentence encoder emits se_d_model features, so that is the default
        # text width (previously a hard-coded 64, which broke any config with
        # char_emb_dim != 64 and no explicit sentence_dim).
        pr_text_dim = pr_cfg.get("sentence_dim", se_d_model)
        pr_style_dim = pr_cfg.get("n_style", st_num_queries) * pr_cfg.get("style_dim", st_query_dim)

        self.sentence_encoder = DPTextEncoder(vocab_size=vocab_size, d_model=se_d_model)
        self.ref_encoder = DPReferenceEncoder(
            in_channels=ref_in_channels,
            d_model=st_d_model,
            hidden_dim=st_hidden_dim,
            num_blocks=st_num_blocks,
            num_queries=st_num_queries,
            query_dim=st_query_dim,
            num_heads=st_num_heads,
            dilation_lst=st_dilation,
        )
        self.predictor = DurationEstimator(text_dim=pr_text_dim, style_dim=pr_style_dim)

    def forward(self, text_ids, z_ref=None, text_mask=None, ref_mask=None, style_dp=None, return_log=False):
        """Predict durations for ``text_ids``.

        Args:
            text_ids: Token ids for the sentence encoder.
            z_ref: Reference features for the style encoder. Used only when
                ``style_dp`` is None.
            text_mask: Optional text mask, passed to the sentence encoder and
                the predictor.
            ref_mask: Optional mask for ``z_ref``.
            style_dp: Precomputed style embedding. Takes precedence over
                ``z_ref``. This is unrelated to the constructor argument of the
                same name.
            return_log: If true, return the predictor's log-domain output.

        Returns:
            Whatever ``DurationEstimator`` returns: one value per batch item.

        Raises:
            ValueError: If neither ``z_ref`` nor ``style_dp`` is provided.
        """
        text_emb = self.sentence_encoder(text_ids, mask=text_mask)

        if style_dp is not None:
            style_emb = style_dp
        elif z_ref is not None:
            style_emb = self.ref_encoder(z_ref, mask=ref_mask)
        else:
            raise ValueError("Either z_ref or style_dp must be provided")

        return self.predictor(text_emb, style_emb, text_mask=text_mask, return_log=return_log)
|
|