import torch
import torch.nn as nn

from .duration_predictor import TTSDurationModel


class DPNetwork(TTSDurationModel):
    """Thin wrapper around ``TTSDurationModel`` with mask normalisation.

    Inheritance keeps state_dict keys (`sentence_encoder.*` / `predictor.*`)
    aligned with ONNX/checkpoints.
    """

    def __init__(
        self,
        vocab_size: int = 37,
        latent_channels: int = 144,
        style_dp: int = 8,
        style_dim: int = 16,
        sentence_encoder_cfg: dict | None = None,
        style_encoder_cfg: dict | None = None,
        predictor_cfg: dict | None = None,
    ):
        """Forward all settings to the parent constructor.

        ``latent_channels`` is passed to the parent under the name
        ``ref_in_channels``; every other argument is passed through under its
        own name. The ``*_cfg`` dictionaries are handed over untouched
        (including ``None``); how the parent interprets them is not visible
        from this file.
        """
        super().__init__(
            vocab_size=vocab_size,
            style_dp=style_dp,
            style_dim=style_dim,
            ref_in_channels=latent_channels,
            sentence_encoder_cfg=sentence_encoder_cfg,
            style_encoder_cfg=style_encoder_cfg,
            predictor_cfg=predictor_cfg,
        )

    def forward(
        self,
        text_ids: torch.Tensor,
        z_ref: torch.Tensor | None = None,
        text_mask: torch.Tensor | None = None,
        ref_mask: torch.Tensor | None = None,
        style_dp: torch.Tensor | None = None,
        return_log: bool = False,
    ) -> torch.Tensor:
        """Normalise the masks, then delegate to ``TTSDurationModel.forward``.

        Args:
            text_ids: Passed to the parent unchanged.
            z_ref: Optional reference tensor. When it is given without a
                ``ref_mask`` it must be 3-D; its first and last dimensions
                size the default mask (presumably batch and time -- confirm
                against the parent model).
            text_mask: Optional mask; cast to ``float32`` if it has another
                dtype. Left as ``None`` when not supplied.
            ref_mask: Optional mask; cast to ``float32`` if it has another
                dtype. When ``None`` and ``z_ref`` is given, an all-ones
                ``float32`` mask of shape ``(z_ref.shape[0], 1,
                z_ref.shape[2])`` is created on ``z_ref``'s device.
            style_dp: Passed to the parent unchanged.
            return_log: Passed to the parent unchanged.

        Returns:
            Whatever the parent ``forward`` returns.

        Raises:
            ValueError: If ``ref_mask`` is ``None`` and ``z_ref`` is not 3-D
                (raised by the shape unpacking below).
        """
        if text_mask is not None and text_mask.dtype != torch.float32:
            text_mask = text_mask.float()

        if ref_mask is not None:
            if ref_mask.dtype != torch.float32:
                ref_mask = ref_mask.float()
        elif z_ref is not None:
            # Three-way unpack is deliberate: it rejects inputs that are not
            # 3-D, exactly as the original code did.
            batch, _, ref_len = z_ref.shape
            ref_mask = torch.ones(
                batch, 1, ref_len, device=z_ref.device, dtype=torch.float32
            )

        return super().forward(
            text_ids,
            z_ref=z_ref,
            text_mask=text_mask,
            ref_mask=ref_mask,
            style_dp=style_dp,
            return_log=return_log,
        )