File size: 1,768 Bytes
7d51a93 8876cf1 7d51a93 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import torch
import torch.nn as nn
from .duration_predictor import TTSDurationModel
class DPNetwork(TTSDurationModel):
    """Thin wrapper around ``TTSDurationModel`` that normalises mask inputs.

    Inheritance keeps state_dict keys (`sentence_encoder.*` / `predictor.*`)
    aligned with ONNX/checkpoints.
    """

    def __init__(
        self,
        vocab_size: int = 37,
        latent_channels: int = 144,
        style_dp: int = 8,
        style_dim: int = 16,
        sentence_encoder_cfg: dict | None = None,
        style_encoder_cfg: dict | None = None,
        predictor_cfg: dict | None = None,
    ):
        """Forward all hyper-parameters to the parent constructor.

        Args:
            vocab_size: Passed through unchanged to the parent.
            latent_channels: Passed to the parent as ``ref_in_channels``.
            style_dp: Passed through unchanged to the parent.
            style_dim: Passed through unchanged to the parent.
            sentence_encoder_cfg: Optional config dict, passed through as-is
                (``None`` is forwarded untouched; the parent decides defaults).
            style_encoder_cfg: Optional config dict, passed through as-is.
            predictor_cfg: Optional config dict, passed through as-is.
        """
        super().__init__(
            vocab_size=vocab_size,
            style_dp=style_dp,
            style_dim=style_dim,
            # The parent names this argument differently; only the keyword changes.
            ref_in_channels=latent_channels,
            sentence_encoder_cfg=sentence_encoder_cfg,
            style_encoder_cfg=style_encoder_cfg,
            predictor_cfg=predictor_cfg,
        )

    def forward(
        self,
        text_ids: torch.Tensor,
        z_ref: torch.Tensor | None = None,
        text_mask: torch.Tensor | None = None,
        ref_mask: torch.Tensor | None = None,
        style_dp: torch.Tensor | None = None,
        return_log: bool = False,
    ) -> torch.Tensor:
        """Coerce masks to float32, synthesise a missing ``ref_mask``, then delegate.

        Args:
            text_ids: Forwarded positionally to the parent ``forward``.
            z_ref: Optional reference tensor. When it is given without a
                ``ref_mask`` it must be 3-D, because its shape is unpacked
                into three values (presumably batch, channels, frames --
                confirm against the parent model).
            text_mask: Optional mask; cast to float32 if it has another dtype.
            ref_mask: Optional mask; cast to float32 if it has another dtype.
                When omitted and ``z_ref`` is given, an all-ones float32 mask
                of shape ``(z_ref.shape[0], 1, z_ref.shape[2])`` is created on
                ``z_ref``'s device.
            style_dp: Forwarded unchanged to the parent.
            return_log: Forwarded unchanged to the parent (presumably selects
                log-domain output -- confirm in ``TTSDurationModel``).

        Returns:
            Whatever the parent ``forward`` returns for the normalised inputs.
        """
        # The dtype guard skips the cast when the mask is already float32.
        if text_mask is not None and text_mask.dtype != torch.float32:
            text_mask = text_mask.float()

        if ref_mask is not None:
            if ref_mask.dtype != torch.float32:
                ref_mask = ref_mask.float()
        elif z_ref is not None:
            # No mask supplied: treat every reference position as valid.
            batch, _, ref_len = z_ref.shape
            ref_mask = torch.ones(
                batch, 1, ref_len, device=z_ref.device, dtype=torch.float32
            )

        return super().forward(
            text_ids,
            z_ref=z_ref,
            text_mask=text_mask,
            ref_mask=ref_mask,
            style_dp=style_dp,
            return_log=return_log,
        )
|