# BlueV2 / models/duration_predictor.py
# Last commit 8876cf1 (notmax123): DP ref encoder uses compressed latent channels
# from config; README v2 ONNX path; pytest smoke tests.
import torch
import torch.nn as nn
from .text_encoder import (
AttnEncoder,
TextEmbedderWrapper,
ConvNeXtWrapper,
)
class DPReferenceEncoder(nn.Module):
    """Summarize a reference latent sequence into one flat style vector.

    The reference goes through a 1x1 conv projection to ``d_model`` channels
    and a ConvNeXt stack. ``num_queries`` learned query vectors then read it
    out through two cross-attention layers. The second layer's query is the
    learned queries plus the first layer's output (a residual connection).
    The (B, num_queries, query_dim) result is flattened to
    (B, num_queries * query_dim).

    Args:
        in_channels: Channel count of ``z_ref`` along dim 1.
        d_model: Working width of the conv stack; also the key/value width of
            both attention layers.
        hidden_dim: ConvNeXt hidden width, forwarded as the integer ratio
            ``hidden_dim // d_model``.
        num_blocks: Number of ConvNeXt layers.
        num_queries: Number of learned query vectors.
        query_dim: Width of each query and of each attention output
            (``nn.MultiheadAttention`` requires it to be divisible by
            ``num_heads``).
        num_heads: Number of attention heads.
        kernel_size: Kernel size forwarded to ``ConvNeXtWrapper``.
        dilation_lst: Optional dilations forwarded to ``ConvNeXtWrapper``.
    """

    def __init__(
        self,
        in_channels: int = 144,
        d_model: int = 64,
        hidden_dim: int = 256,
        num_blocks: int = 4,
        num_queries: int = 8,
        query_dim: int = 16,
        num_heads: int = 2,
        kernel_size: int = 5,
        dilation_lst: list = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_queries = num_queries
        self.query_dim = query_dim
        expansion = hidden_dim // d_model

        def cross_attention():
            # Queries are query_dim wide; keys/values are the d_model-wide
            # conv features.
            return nn.MultiheadAttention(
                embed_dim=query_dim,
                num_heads=num_heads,
                kdim=d_model,
                vdim=d_model,
                batch_first=True,
            )

        # Keep this creation order: it determines the state_dict key order and
        # the sequence of RNG draws used for parameter initialization.
        self.input_proj = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.convnext = ConvNeXtWrapper(
            d_model,
            n_layers=num_blocks,
            expansion_factor=expansion,
            kernel_size=kernel_size,
            dilation_lst=dilation_lst,
        )
        self.ref_keys = nn.Parameter(0.02 * torch.randn(num_queries, query_dim))
        self.attn1 = cross_attention()
        self.attn2 = cross_attention()

    def forward(self, z_ref: torch.Tensor, mask: torch.Tensor = None):
        """Encode ``z_ref`` into a flat style vector.

        Args:
            z_ref: Reference latents with ``in_channels`` channels in dim 1
                (presumably (B, C, T) -- TODO confirm).
            mask: Optional mask; positions where ``mask.squeeze(1) == 0`` are
                ignored as keys by both attention layers. Presumably (B, 1, T)
                with 1 marking valid frames -- TODO confirm.

        Returns:
            Tensor of shape (B, num_queries * query_dim).
        """
        batch = z_ref.shape[0]
        feats = self.convnext(self.input_proj(z_ref), mask=mask)
        # Channels-last layout for the batch_first attention layers.
        memory = feats.transpose(1, 2)
        pad_mask = None if mask is None else mask.squeeze(1).eq(0)

        queries = self.ref_keys[None].expand(batch, -1, -1)
        shared = dict(key=memory, value=memory, key_padding_mask=pad_mask, need_weights=False)
        first, _ = self.attn1(query=queries, **shared)
        styled, _ = self.attn2(query=queries + first, **shared)
        return styled.reshape(batch, -1)
class DPTextEncoder(nn.Module):
    """Encode token ids into one sentence-level embedding per sequence.

    A learned sentence token is prepended to the embedded text. The sequence
    goes through a ConvNeXt stack, then an attention encoder with a residual
    connection around it. The output at the prepended token's position is
    passed through a bias-free 1x1 conv and returned.

    Args:
        vocab_size: Number of token ids accepted by the text embedder.
        d_model: Channel width used throughout; also the output width.
    """

    def __init__(self, vocab_size=37, d_model=64):
        super().__init__()
        self.d_model = d_model
        self.text_embedder = TextEmbedderWrapper(vocab_size, d_model)
        self.convnext = ConvNeXtWrapper(d_model, n_layers=6, expansion_factor=4)
        self.sentence_token = nn.Parameter(torch.randn(1, d_model, 1) * 0.02)
        self.attn_encoder = AttnEncoder(
            channels=d_model,
            n_heads=2,
            filter_channels=d_model * 4,
            n_layers=2,
        )
        # A Sequential holding one child named "net", so the projection's
        # parameter is registered as ``proj_out.net.weight``.
        head = nn.Sequential()
        head.add_module("net", nn.Conv1d(d_model, d_model, 1, bias=False))
        self.proj_out = head

    def forward(self, text_ids, mask=None):
        """Return the sentence embedding for each sequence in the batch.

        Args:
            text_ids: Token ids; must be 2-D (it is unpacked into two sizes),
                presumably (B, T).
            mask: Optional mask that is concatenated with a (B, 1, 1) tensor
                along dim 2, so it must be (B, 1, T). Presumably 1 marks valid
                tokens and 0 padding -- TODO confirm.

        Returns:
            Tensor of shape (B, d_model).
        """
        batch, _ = text_ids.shape
        # assumes the embedder returns (B, T, d_model); move channels to dim 1.
        hidden = self.text_embedder(text_ids).transpose(1, 2)
        has_mask = mask is not None
        if has_mask:
            hidden = hidden * mask

        # Prepend the learned sentence token at time index 0.
        sentence = self.sentence_token.expand(batch, -1, -1)
        hidden = torch.cat((sentence, hidden), dim=2)
        if has_mask:
            # The prepended token is always valid.
            token_mask = torch.ones(batch, 1, 1, device=mask.device)
            mask = torch.cat((token_mask, mask), dim=2)

        conv_feats = self.convnext(hidden, mask=mask)
        # Residual connection around the attention encoder.
        fused = self.attn_encoder(conv_feats, mask=mask) + conv_feats

        pooled = self.proj_out(fused[:, :, :1])
        if has_mask:
            # Scales by the sentence token's mask entry (built as ones above).
            pooled = pooled * mask[:, :, :1]
        return pooled.squeeze(2)
class DurationEstimator(nn.Module):
    """Map a sentence embedding and a style embedding to one duration per item.

    A two-layer MLP (Linear -> PReLU -> Linear) is applied to the
    concatenation of ``text_emb`` and the flattened ``style_emb`` and yields
    one scalar per batch item. That scalar is treated as a log value:
    ``forward`` exponentiates it unless ``return_log`` is set.

    Args:
        text_dim: Width of ``text_emb`` (dim 1).
        style_dim: Width of ``style_emb`` after flattening to 2-D.
        hidden_dim: Width of the hidden layer. Defaults to 128, the value that
            used to be hard-coded, so the default parameter shapes are
            unchanged.
    """

    def __init__(self, text_dim=64, style_dim=128, hidden_dim=128):
        super().__init__()
        # Keep the ``layers`` ModuleList layout: parameter names
        # (``layers.0.*``, ``layers.1.*``) are part of the state_dict.
        self.layers = nn.ModuleList([
            nn.Linear(text_dim + style_dim, hidden_dim),
            nn.Linear(hidden_dim, 1),
        ])
        self.activation = nn.PReLU()

    def forward(self, text_emb, style_emb, text_mask=None, return_log=False):
        """Predict durations.

        Args:
            text_emb: Tensor of shape (B, text_dim).
            style_emb: Tensor whose dims after the first flatten to
                ``style_dim`` values; anything above 2-D is reshaped to
                (B, -1).
            text_mask: Accepted for interface compatibility; not used here.
            return_log: If True, return the raw (log-domain) prediction.

        Returns:
            Tensor of shape (B,): durations, or log durations when
            ``return_log`` is True.
        """
        if style_emb.dim() > 2:
            style_emb = style_emb.reshape(style_emb.shape[0], -1)
        features = torch.cat([text_emb, style_emb], dim=1)
        hidden = self.activation(self.layers[0](features))
        log_dur = self.layers[1](hidden).squeeze(1)
        return log_dur if return_log else torch.exp(log_dur)
class TTSDurationModel(nn.Module):
    """Predict durations from text ids plus a reference-derived style embedding.

    Built from three submodules, configured from optional dicts:

    * ``sentence_encoder`` (DPTextEncoder): ``text_ids`` -> (B, char_emb_dim).
    * ``ref_encoder`` (DPReferenceEncoder): ``z_ref`` -> flat style embedding
      of width ``n_style * style_value_dim``.
    * ``predictor`` (DurationEstimator): both embeddings -> one duration
      (or log duration) per item.

    Args:
        vocab_size: Token vocabulary size for the sentence encoder.
        style_dp: Fallback number of style queries when
            ``style_encoder_cfg["style_token_layer"]["n_style"]`` is absent.
        style_dim: Fallback width of each style query when
            ``style_encoder_cfg["style_token_layer"]["style_value_dim"]`` is
            absent.
        ref_in_channels: Channel count of ``z_ref``.
        sentence_encoder_cfg: Optional dict; reads ``char_emb_dim``.
        style_encoder_cfg: Optional dict; reads ``proj_in.odim``,
            ``convnext.{intermediate_dim,num_layers,dilation_lst}`` and
            ``style_token_layer.{n_style,style_value_dim,n_heads}``.
        predictor_cfg: Optional dict; reads ``sentence_dim``, ``n_style`` and
            ``style_dim``. Missing keys default to the matching encoder
            widths.
    """

    def __init__(
        self,
        vocab_size=37,
        style_dp=8,
        style_dim=16,
        ref_in_channels=144,
        sentence_encoder_cfg=None,
        style_encoder_cfg=None,
        predictor_cfg=None,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        se_cfg = sentence_encoder_cfg or {}
        st_cfg = style_encoder_cfg or {}
        pr_cfg = predictor_cfg or {}

        # Sentence encoder width; DPTextEncoder outputs vectors of this size.
        se_d_model = se_cfg.get("char_emb_dim", 64)

        # Reference/style encoder settings.
        st_d_model = st_cfg.get("proj_in", {}).get("odim", 64)
        st_convnext = st_cfg.get("convnext", {})
        st_hidden_dim = st_convnext.get("intermediate_dim", 256)
        st_num_blocks = st_convnext.get("num_layers", 4)
        st_dilation = st_convnext.get("dilation_lst", None)
        st_token_layer = st_cfg.get("style_token_layer", {})
        st_num_queries = st_token_layer.get("n_style", style_dp)
        st_query_dim = st_token_layer.get("style_value_dim", style_dim)
        st_num_heads = st_token_layer.get("n_heads", 2)

        # The predictor consumes the sentence encoder's output, so its text
        # width defaults to se_d_model. A fixed default of 64 would build a
        # mismatched first Linear whenever only char_emb_dim is configured.
        pr_text_dim = pr_cfg.get("sentence_dim", se_d_model)
        # The reference encoder emits n_style * style_value_dim values per item.
        pr_style_dim = pr_cfg.get("n_style", st_num_queries) * pr_cfg.get("style_dim", st_query_dim)

        self.sentence_encoder = DPTextEncoder(vocab_size=vocab_size, d_model=se_d_model)
        self.ref_encoder = DPReferenceEncoder(
            in_channels=ref_in_channels,
            d_model=st_d_model,
            hidden_dim=st_hidden_dim,
            num_blocks=st_num_blocks,
            num_queries=st_num_queries,
            query_dim=st_query_dim,
            num_heads=st_num_heads,
            dilation_lst=st_dilation,
        )
        self.predictor = DurationEstimator(text_dim=pr_text_dim, style_dim=pr_style_dim)

    def forward(self, text_ids, z_ref=None, text_mask=None, ref_mask=None, style_dp=None, return_log=False):
        """Predict durations for a batch.

        Args:
            text_ids: Token ids for the sentence encoder.
            z_ref: Reference latents; used only when ``style_dp`` is None.
            text_mask: Optional mask forwarded to the sentence encoder and to
                the predictor.
            ref_mask: Optional mask for ``z_ref``.
            style_dp: Precomputed style embedding. It takes precedence over
                ``z_ref``; if it has more than two dims the predictor flattens
                it to (B, -1).
            return_log: If True, return log durations instead of durations.

        Returns:
            One value per batch item, as produced by the predictor.

        Raises:
            ValueError: If neither ``z_ref`` nor ``style_dp`` is given. This is
                checked before any encoder runs.
        """
        if style_dp is None and z_ref is None:
            raise ValueError("Either z_ref or style_dp must be provided")
        text_emb = self.sentence_encoder(text_ids, mask=text_mask)
        if style_dp is not None:
            style_emb = style_dp
        else:
            style_emb = self.ref_encoder(z_ref, mask=ref_mask)
        return self.predictor(text_emb, style_emb, text_mask=text_mask, return_log=return_log)